Comments (6)

karalets commented on August 20, 2024

I like this and agree we need a compact way to write this up; currently it is much too contrived.

I'll spend some time playing with it and see if I have any useful comments.

eb8680 commented on August 20, 2024

I like the idea of pyro.random_module (but hopefully with a shorter, punchier name - pyro.nn.lift?) for lifting a nn.Module to a stochastic function that returns new nn.Modules with parameters sampled from a prior. Here's a slightly different way to generate a guide automatically that seems more Pyronic (?):

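import pyro
import pyro.poutine as poutine

def make_guide(fn, sites=None):
    def guide(*args, **kwargs):
        # trace fn once, blocking its sites from the guide's own trace
        model_trace = poutine.block(poutine.trace(fn).get_trace)(*args, **kwargs)
        # bind a local name: assigning to `sites` here would raise UnboundLocalError
        site_map = sites if sites is not None else {name: name for name in model_trace.nodes}
        for name, node in model_trace.nodes.items():
            # skip observed sites; make_site_posterior is a hypothetical helper
            if node["type"] == "sample" and not node["is_observed"] and name in site_map:
                pyro.sample(site_map[name], make_site_posterior(node, *args, **kwargs))

    return guide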

guide_dist = make_guide(pyro.random_module(mod, prior))

As written this generates mean-field guides, but you can write more sophisticated guides in a similar style.
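To make the mechanics concrete without depending on any particular Pyro version, here is a minimal pure-Python sketch of the same mean-field guide generator. The sample, param and get_trace functions, the global stacks, and make_site_posterior are hypothetical toy stand-ins for Pyro's primitives and poutine. They are not Pyro's API. A "distribution" here is just a (loc, scale) pair for a Normal.

```python
import random

# Toy stand-ins for Pyro's primitives and poutine.trace (hypothetical, not
# Pyro's API). A "distribution" is just a (loc, scale) pair for a Normal.
PARAM_STORE = {}   # global parameter store, keyed by name
_TRACES = []       # stack of active traces; only the innermost one records,
                   # which mimics poutine.block(poutine.trace(fn))


def param(name, init):
    return PARAM_STORE.setdefault(name, init)


def sample(name, dist, obs=None):
    loc, scale = dist
    value = obs if obs is not None else random.gauss(loc, scale)
    if _TRACES:
        _TRACES[-1][name] = {"type": "sample", "dist": dist,
                             "value": value, "is_observed": obs is not None}
    return value


def get_trace(fn, *args, **kwargs):
    """Run fn once and return a dict of its sample sites, in execution order."""
    _TRACES.append({})
    try:
        fn(*args, **kwargs)
        return _TRACES[-1]
    finally:
        _TRACES.pop()


def make_site_posterior(name, site):
    """Hypothetical helper: an independent Normal with learnable loc and scale."""
    return (param(name + "_loc", 0.0), param(name + "_scale", 1.0))


def make_guide(fn, sites=None):
    def guide(*args, **kwargs):
        model_trace = get_trace(fn, *args, **kwargs)
        # use a local name so the closure never rebinds `sites`
        site_map = sites if sites is not None else {n: n for n in model_trace}
        for name, site in model_trace.items():
            # observed sites are data, so the guide must not sample them
            if site["type"] == "sample" and not site["is_observed"] and name in site_map:
                sample(site_map[name], make_site_posterior(name, site))
    return guide


def model(x, y):
    w = sample("w", (0.0, 1.0))
    b = sample("b", (0.0, 1.0))
    sample("y", (w * x + b, 0.1), obs=y)


guide = make_guide(model)
guide_trace = get_trace(guide, 2.0, 1.0)
print(list(guide_trace))  # ['w', 'b']
```

The guide mirrors every latent site of the model with its own independent Normal, which is what makes it mean-field. Passing sites={"w": "q_w"} would restrict and rename the mirrored sites.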

eb8680 commented on August 20, 2024

Riffing on this some more because I quite like it: there's no reason the parameter-lifting operation has to be nn-specific. Imagine a poutine operation poutine.lift(fn, prior) that overrides each pyro.param call in fn with a pyro.sample call internally using the provided prior. Then for nn.Modules we can just write

pyro.random_module = lambda name, mod, prior: poutine.lift(pyro.module, prior)(name, mod)

but now the same principles, as well as guide generators like the one above, can be applied to any stochastic function that has pyro.param calls.
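Here is a toy version of the proposed lift. It is plain Python with hypothetical stand-ins, not Pyro's real poutine.lift. While a lift is active, each param call whose name appears in the prior is replaced by a sample call from that prior. Names not in the prior stay ordinary params.

```python
import random

# Toy sketch of the proposed lift (hypothetical, not Pyro's poutine.lift):
# while a lift is active, param(name, init) becomes sample(name, prior[name]).
PARAM_STORE = {}
SAMPLED = []    # names of sample sites hit, kept for illustration
_PRIORS = []    # stack of active priors


def sample(name, dist):
    loc, scale = dist
    SAMPLED.append(name)
    return random.gauss(loc, scale)


def param(name, init):
    if _PRIORS and name in _PRIORS[-1]:
        return sample(name, _PRIORS[-1][name])
    return PARAM_STORE.setdefault(name, init)


def lift(fn, prior):
    def lifted(*args, **kwargs):
        _PRIORS.append(prior)
        try:
            return fn(*args, **kwargs)
        finally:
            _PRIORS.pop()
    return lifted


def regressor(x):
    # an ordinary parameterized (deterministic) function
    return param("weight", 2.0) * x + param("bias", 0.5)


print(regressor(3.0))  # 6.5, and no sample sites are hit
stoch_regressor = lift(regressor, {"weight": (0.0, 1.0)})
stoch_regressor(3.0)   # same signature; "weight" is now drawn from the prior
print(SAMPLED)         # ['weight']
```

Nothing in lift knows about modules. Under the same toy assumptions, a random_module would just be lift applied to whatever function registers a module's params.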

I'm not completely happy with this structure, though, because conceptually it would be nicer if, like the proposed pyro.random_module, lift(fn, prior) returned a distribution over fn-like callables that could be called to sample a single fn with new values but no more randomness at the original pyro.param sites. I'm not sure how to do this, since different execution traces may contain entirely different sets of pyro.param sites that can only be determined at runtime.

Edit: Ok, I thought about the last problem some more. I don't think there's a way to do that in general, and it's not even a probabilistically coherent request because the joint distribution over fns and traces doesn't factor that way when the appearance of a pyro.param in a trace is determined by pyro.sample or pyro.observe statements in the trace (or else, if it does, it's no longer guaranteed to have a density).

However, suppose we happen to know that all execution traces of a stochastic function fn will contain the same pyro.param sites. In that case the distribution does factor that way and we should in principle be able to create a nn.Module from fn and lift it with pyro.random_module:

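import torch.nn as nn
import pyro.poutine as poutine

class LiftableFunction(nn.Module):
    def __init__(self, fn, *args, **kwargs):
        super().__init__()  # required before registering any nn.Parameter
        self.fn = fn
        initial_trace = poutine.block(poutine.trace(fn).get_trace)(*args, **kwargs)
        for name, node in initial_trace.nodes.items():
            if node["type"] == "param":
                # XXX something like this? not exactly correct: this registers the
                # traced value, but fn still reads its params from the param store
                # ("." is not allowed in nn.Module parameter names)
                self.register_parameter("_weight_" + name.replace(".", "_"),
                                        nn.Parameter(node["value"].detach()))

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)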
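The precondition for this class is that every execution trace touches the same param sites. You can check that empirically, though you can't prove it, by running fn several times and comparing the sets of param names. The check below is a hypothetical toy in plain Python, not a Pyro utility. In branching, random.random() stands in for an upstream pyro.sample that drives control flow. That is exactly the case from the edit above where the joint distribution doesn't factor.

```python
import random

# Hypothetical toy check (not a Pyro API): run fn several times and collect the
# set of param names touched in each run.
PARAM_STORE = {}
_SEEN = []


def param(name, init):
    if _SEEN:
        _SEEN[-1].add(name)
    return PARAM_STORE.setdefault(name, init)


def param_sites(fn, *args):
    _SEEN.append(set())
    try:
        fn(*args)
        return frozenset(_SEEN[-1])
    finally:
        _SEEN.pop()


def has_fixed_param_sites(fn, *args, runs=100):
    # empirical only: agreement over `runs` executions does not prove it for all traces
    return len({param_sites(fn, *args) for _ in range(runs)}) == 1


def linear(x):
    return param("w", 1.0) * x + param("b", 0.0)


def branching(x):
    # random.random() stands in for an upstream pyro.sample that drives control flow
    if random.random() < 0.5:
        return param("a", 1.0) * x
    return param("c", 2.0) * x


random.seed(0)
print(has_fixed_param_sites(linear, 1.0))     # True
print(has_fixed_param_sites(branching, 1.0))  # False
```

Only functions like linear, where the set of param sites is fixed, are candidates for wrapping in LiftableFunction.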

ngoodman commented on August 20, 2024

i like the idea of a more general "lift" function that promotes params to samples! you're right that it couldn't cleanly separate the new randomness from the original randomness in the fn. it's not totally clear to me if this is an important separation. without that separation the bayesian nn example would look something like:

stoch_classifier = pyro.random_module("classifier", classify, prior)
class_weights = stoch_classifier.forward(data)  # use the net (as an ordinary *stochastic* fn)

which is actually simpler! we've basically just upgraded the deterministic (but parameterized) function defined by the module to a stochastic function of the same signature.

btw. the make_guide function defined in the above comment makes sense only if the samples don't affect control flow. (which is true in the module case, but maybe not generally?)
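Here is a small toy illustration of that caveat. It uses hypothetical pure-Python stand-ins for sample and trace, not Pyro's API. The generated guide mirrors one trace of the model, so the model's z decides which branch site the guide emits. Meanwhile the guide draws its own z independently. The two disagree about half the time, so the guide proposes combinations (say, a negative z together with the "pos" site) that no single execution of the model could produce.

```python
import random

# Hypothetical toy stand-ins (not Pyro's API); only the innermost trace records.
_TRACES = []


def sample(name, dist):
    loc, scale = dist
    value = random.gauss(loc, scale)
    if _TRACES:
        _TRACES[-1][name] = value
    return value


def get_trace(fn, *args):
    _TRACES.append({})
    try:
        fn(*args)
        return _TRACES[-1]
    finally:
        _TRACES.pop()


def model():
    z = sample("z", (0.0, 1.0))
    if z > 0:                      # the sampled value drives control flow
        sample("pos", (z, 1.0))
    else:
        sample("neg", (z, 1.0))


def make_guide(fn):
    # mean-field guide that mirrors one trace of fn (the make_guide idea above)
    def guide(*args):
        model_trace = get_trace(fn, *args)
        for name in model_trace:
            sample(name, (0.0, 1.0))
    return guide


guide = make_guide(model)
random.seed(1)
mismatches = 0
for _ in range(1000):
    t = get_trace(guide)
    # the branch site was chosen by the model's z, not the guide's z
    if ("pos" in t) != (t["z"] > 0):
        mismatches += 1
print(mismatches)  # about 500; the exact count depends on the seed
```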

ngoodman commented on August 20, 2024

for future reference, a nice but pretty straightforward use of bayesian rnns: https://arxiv.org/pdf/1704.02798.pdf

eb8680 commented on August 20, 2024

Addressed by #121