

Neural Processes

This is an implementation of Neural Processes for 1D regression, accompanying my blog post.

Structure of the repo

The implementation uses TensorFlow in R:

  • The file NP_core.R contains functions to define the loss function and carry out posterior prediction.
  • The files NP_architecture*.R specify the neural network architectures for the encoder h and the decoder g. (Note: when changing the network architecture, e.g. when fitting a new model, you need to run tf$reset_default_graph() or restart your R session.)
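To give a feel for what the encoder h and decoder g compute, here is a minimal base-R sketch of one forward pass. It is not the code in NP_architecture*.R: the single tanh hidden layer, the missing bias terms, the softplus for the standard deviation and the helper names (init_weights, aggregate_r, z_params) are all assumptions for illustration, and the weights are random rather than trained.

```r
# Hypothetical sketch of an NP forward pass (not the repo's architecture).
# Biases are omitted for brevity; weights are random, not trained.
set.seed(1)

dim_r <- 2L
dim_z <- 2L
dim_h_hidden <- 32L
dim_g_hidden <- 32L

# small random weights stand in for trained parameters
init_weights <- function(n_in, n_out) {
  matrix(rnorm(n_in * n_out, sd = 0.1), nrow = n_in, ncol = n_out)
}
W_h1 <- init_weights(2, dim_h_hidden)
W_h2 <- init_weights(dim_h_hidden, dim_r)
W_mu <- init_weights(dim_r, dim_z)
W_sigma <- init_weights(dim_r, dim_z)
W_g1 <- init_weights(1 + dim_z, dim_g_hidden)
W_g2 <- init_weights(dim_g_hidden, 1)

# encoder h: maps each (x_i, y_i) pair to a representation r_i (one row per point)
h <- function(x, y) {
  tanh(cbind(x, y) %*% W_h1) %*% W_h2   # N x dim_r
}

# aggregation: the mean over context points makes r invariant to their order
aggregate_r <- function(r_i) colMeans(r_i)

# map r to the mean and standard deviation of q(z | context)
z_params <- function(r) {
  list(
    mu = as.vector(r %*% W_mu),
    sigma = as.vector(log1p(exp(r %*% W_sigma)))  # softplus keeps sigma positive
  )
}

# decoder g: maps (x_target, z) to the predicted mean of y at each target input
g <- function(x, z) {
  z_rep <- matrix(z, nrow = length(x), ncol = length(z), byrow = TRUE)
  as.vector(tanh(cbind(x, z_rep) %*% W_g1) %*% W_g2)
}

x_c <- c(-1, 0, 1)
y_c <- sin(x_c)
q <- z_params(aggregate_r(h(x_c, y_c)))
z <- q$mu + q$sigma * rnorm(dim_z)      # one reparameterised sample of z
y_pred <- g(seq(-4, 4, length.out = 100), z)
```

Averaging the r_i is what lets the same network handle context sets of any size and in any order.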

All experiments can be found in the "experiments" folder (where they appear in the same order as in the blog post):

  • The first experiment involves training an NP on a single small data set.
  • The second experiment involves training an NP on a small class of functions of the form a * sin(x).
  • The third experiment involves training an NP on repeated draws from a Gaussian process (GP).
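For the third experiment, each training iteration needs a fresh function drawn from a GP. The sketch below shows one standard way to do this in base R, via a Cholesky factor of a squared-exponential covariance matrix. The function names (rbf_kernel, sample_gp), the kernel choice and its default hyperparameters are hypothetical; they have not been checked against GP_helpers.R.

```r
# Hypothetical sketch: draw functions from a zero-mean GP with a squared-exponential kernel.
rbf_kernel <- function(x1, x2, lengthscale = 1, variance = 1) {
  sq_dist <- outer(x1, x2, function(a, b) (a - b)^2)
  variance * exp(-0.5 * sq_dist / lengthscale^2)
}

sample_gp <- function(x, n_draws = 1, lengthscale = 1, variance = 1, jitter = 1e-6) {
  # small jitter on the diagonal keeps the Cholesky factorisation numerically stable
  K <- rbf_kernel(x, x, lengthscale, variance) + diag(jitter, length(x))
  L <- t(chol(K))  # chol() returns the upper factor; L is lower, K = L %*% t(L)
  # each column is one draw of f(x) ~ N(0, K)
  L %*% matrix(rnorm(length(x) * n_draws), nrow = length(x), ncol = n_draws)
}

set.seed(2)
x_obs <- runif(20, -3, 3)
y_draws <- sample_gp(x_obs, n_draws = 3)
```

In a training loop like the one in the example code, the line y_obs <- a * sin(x_obs) could be replaced by y_obs <- as.vector(sample_gp(x_obs)) to train on GP draws instead.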

Example code

Loading all the libraries and helper functions:

library(tidyverse)
library(tensorflow)
library(patchwork)

source("NP_core.R")
source("GP_helpers.R")
source("helpers_for_plotting.R")
source("NP_architecture1.R")

Setting up the NP model:

sess <- tf$Session()

# specify global variables for the dimensionality of r and z and of the hidden layers of h and g
dim_r <- 2L
dim_z <- 2L
dim_h_hidden <- 32L
dim_g_hidden <- 32L

# placeholders for training inputs
x_context <- tf$placeholder(tf$float32, shape(NULL, 1))
y_context <- tf$placeholder(tf$float32, shape(NULL, 1))
x_target <- tf$placeholder(tf$float32, shape(NULL, 1))
y_target <- tf$placeholder(tf$float32, shape(NULL, 1))

# build the NP graph; init_NP() returns the training op and the loss
train_op_and_loss <- init_NP(x_context, y_context, x_target, y_target, learning_rate = 0.001)

# initialise
init <- tf$global_variables_initializer()
sess$run(init)

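init_NP() builds the loss inside the TensorFlow graph. As a rough illustration of what such a loss looks like, here is a base-R sketch of the commonly used NP objective: a Monte Carlo estimate of the negative ELBO, where q(z | all) is computed from context and target points together and q(z | context) from the context points alone. The diagonal Gaussian posteriors, the fixed observation noise and the function names (kl_diag_gaussians, np_loss) are assumptions; NP_core.R may parameterise things differently.

```r
# Hypothetical base-R sketch of the NP training objective (negative ELBO).
# Assumes diagonal Gaussian q(z) and a Gaussian likelihood with fixed noise sd.

# KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ), summed over independent dimensions
kl_diag_gaussians <- function(mu_q, sigma_q, mu_p, sigma_p) {
  sum(log(sigma_p / sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 * sigma_p^2) - 0.5)
}

# Monte Carlo negative ELBO:
#   -E_{q(z | all)}[ sum_t log p(y_t | z, x_t) ] + KL( q(z | all) || q(z | context) )
np_loss <- function(y_target, mu_y_samples, noise_sd, q_all, q_context) {
  # mu_y_samples: n_samples x n_target matrix of decoder means, one row per z ~ q(z | all)
  loglik <- apply(mu_y_samples, 1, function(mu_y) {
    sum(dnorm(y_target, mean = mu_y, sd = noise_sd, log = TRUE))
  })
  -mean(loglik) + kl_diag_gaussians(q_all$mu, q_all$sigma, q_context$mu, q_context$sigma)
}

y_t <- c(0.1, -0.4, 0.7)
q <- list(mu = c(0, 0), sigma = c(1, 1))
loss <- np_loss(y_t, matrix(y_t, nrow = 2, ncol = 3, byrow = TRUE), 1, q, q)
```

The KL term pulls the posterior given all points towards the posterior given only the context, which is what lets the trained model make predictions from the context alone.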

Now, sampling training data from functions of the form y = a * sin(x), with a drawn uniformly from [-2, 2], we can fit the model as follows:

n_iter <- 10000

for(iter in 1:n_iter){
  # sample data (x_obs, y_obs)
  N <- 20
  x_obs <- runif(N, -3, 3)
  a <- runif(1, -2, 2)
  y_obs <- a * sin(x_obs)
  
  # sample N_context for training
  N_context <- sample(1:10, 1)
  
  # use helper function to pick a random context set
  feed_dict <- helper_context_and_target(x_obs, y_obs, N_context, x_context, y_context, x_target, y_target)
  
  # optimisation step (res[[2]] holds the current loss)
  res <- sess$run(train_op_and_loss, feed_dict = feed_dict)
  
  if(iter %% 1e3 == 0){
    cat(sprintf("loss = %1.3f\n", res[[2]]))
  }
}
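helper_context_and_target() fills the placeholders from a single sampled data set. A hypothetical base-R version of the splitting step might look like the following; each array is an N x 1 matrix to match shape(NULL, 1). Using all N points as the target set is an assumption here, not necessarily what the repo's helper does, and split_context_target is an illustrative name.

```r
# Hypothetical sketch of a context/target split (not the repo's helper_context_and_target).
# Picks N_context random points as context and uses all N points as targets.
split_context_target <- function(x_obs, y_obs, N_context) {
  idx <- sample(seq_along(x_obs), N_context)
  list(
    x_context = matrix(x_obs[idx], ncol = 1),  # N_context x 1, matching shape(NULL, 1)
    y_context = matrix(y_obs[idx], ncol = 1),
    x_target  = matrix(x_obs, ncol = 1),       # N x 1
    y_target  = matrix(y_obs, ncol = 1)
  )
}

set.seed(3)
x_obs <- runif(20, -3, 3)
y_obs <- 1.5 * sin(x_obs)
batch <- split_context_target(x_obs, y_obs, N_context = 5)
```

Resampling both the data set and N_context at every iteration exposes the model to context sets of varying size, which is what it faces at prediction time.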

Prediction using the trained model:

# context set at prediction time: two points on y = sin(x), i.e. a = 1
x0 <- c(0, 1)
y0 <- 1*sin(x0)

# prediction grid
x_star <- seq(-4, 4, length.out = 100)

# plot posterior draws
plot_posterior_draws(x0, y0, x_star, n_draws = 50)
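plot_posterior_draws() plots sampled functions. If you want numeric summaries instead, the pointwise mean and a 95% band can be computed from a matrix of draws as sketched below. The summarise_draws() function is hypothetical, and the draws here are faked from a * sin(x) with a random a rather than taken from the trained NP; with real model output, each row would come from decoding one sample of z given the context (x0, y0).

```r
# Hypothetical sketch: summarise posterior function draws pointwise.
# `draws` has one row per posterior sample and one column per x_star value.
summarise_draws <- function(x_star, draws, probs = c(0.025, 0.975)) {
  data.frame(
    x = x_star,
    mean = colMeans(draws),
    lower = apply(draws, 2, quantile, probs = probs[1]),
    upper = apply(draws, 2, quantile, probs = probs[2])
  )
}

set.seed(4)
x_star <- seq(-4, 4, length.out = 100)
# fake draws for illustration only: a * sin(x) with a ~ N(1, 0.2^2)
a_draws <- rnorm(50, mean = 1, sd = 0.2)
draws <- outer(a_draws, x_star, function(a, x) a * sin(x))
summary_df <- summarise_draws(x_star, draws)
head(summary_df)
```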

Contributors

kasparmartens
