susie

Code for the paper Zero-Shot Robotic Manipulation With Pretrained Image-Editing Diffusion Models.

This repository contains the code for training the high-level image-editing diffusion model on video data. For training the low-level policy, head over to the BridgeData V2 repository. We use the gc_ddpm_bc agent, unmodified, with an action prediction horizon of 4 and the delta_goals relabeling strategy.
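
For readers unfamiliar with goal relabeling, here is a rough sketch of one plausible reading of delta_goals: each frame is paired with a goal frame drawn uniformly from a bounded window of future frames. This is an illustration only, not the BridgeData V2 implementation. The function names and the (2, 5) window are hypothetical, and the action prediction horizon is not modeled. Check the BridgeData V2 code for the actual definition.

```python
import random


def sample_delta_goal(index, traj_len, goal_delta, rng=random):
    """Pick a goal frame uniformly from [index + lo, min(traj_len, index + hi)).

    Assumption: this mirrors the general idea of delta-style relabeling;
    the real delta_goals in BridgeData V2 may differ in its details.
    """
    lo, hi = goal_delta
    start = index + lo
    stop = min(traj_len, index + hi)
    if start >= stop:
        raise ValueError("no valid goal frame for this index")
    return rng.randrange(start, stop)


def relabel(traj_len, goal_delta, rng=random):
    """Return (observation index, goal index) pairs for frames with a valid goal."""
    lo, _ = goal_delta
    # The last `lo` frames have no frame far enough in the future, so skip them.
    return [
        (i, sample_delta_goal(i, traj_len, goal_delta, rng))
        for i in range(traj_len - lo)
    ]


# Hypothetical window: goals lie 2 to 4 frames ahead of the observation.
pairs = relabel(traj_len=10, goal_delta=(2, 5), rng=random.Random(0))
```

Each pair can then be turned into an (observation, goal image) training example for the goal-conditioned policy.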

For integration with the CALVIN simulator and reproducing our simulated results, see our fork of the calvin-sim repo and the corresponding documentation in the BridgeData V2 repository.

  • Creating datasets: this repo uses dlimp for dataloading. Check out the scripts/ directory inside dlimp for creating TFRecords in a compatible format.
  • Installation: run pip install -r requirements.txt to install the package versions confirmed to work with this codebase, then run pip install -e . to install this package. Only tested with Python 3.10. You'll also have to manually install Jax for your platform (see the Jax installation instructions). Make sure you install the Jax version specified in requirements.txt (rather than using --upgrade as suggested in the Jax docs).
  • Training: once you have filled in the missing dataset paths in configs/base.py, you can start training by running python scripts/train.py --config configs/base.py:base.
  • Evaluation: robot evaluation scripts are provided in the scripts/robot directory. You probably won't be able to run them, since you don't have our robot setup, but they are there for reference. See create_sample_fn in susie/model.py for canonical sampling code.
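
Since a mismatched Jax version is an easy mistake to make during installation, a small check like the one below may help. It is a sketch, not part of this repo. It assumes requirements.txt pins Jax with a line of the form jax==X.Y.Z, and the helper names are hypothetical.

```python
import re
import sys
from importlib import metadata


def pinned_version(requirements_text, package):
    """Return the version pinned with `==` for `package`, or None if not pinned."""
    # Allow optional extras such as jax[cuda11_pip]; stop at whitespace, ';' or '#'.
    pattern = re.compile(
        rf"^\s*{re.escape(package)}(?:\[[^\]]*\])?\s*==\s*([^\s;#]+)",
        re.IGNORECASE,
    )
    for line in requirements_text.splitlines():
        match = pattern.match(line)
        if match:
            return match.group(1)
    return None


def installed_version(package):
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_environment(requirements_text):
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems = []
    if sys.version_info[:2] != (3, 10):
        found_py = f"{sys.version_info[0]}.{sys.version_info[1]}"
        problems.append(f"Python 3.10 expected, found {found_py}")
    wanted = pinned_version(requirements_text, "jax")
    found = installed_version("jax")
    if wanted is None:
        problems.append("no `jax==` pin found in the requirements text")
    elif found != wanted:
        problems.append(f"jax {wanted} expected, found {found}")
    return problems


# Example usage, run from the repository root:
#   print(check_environment(open("requirements.txt").read()))
```

The check only compares version strings. It does not verify that the installed Jax build matches your accelerator.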

Model Weights

The UNet weights for our best-performing model, trained on BridgeData and Something-Something for 40k steps, are hosted on HuggingFace. They can be loaded using FlaxUNet2DConditionModel.from_pretrained("kvablack/susie", subfolder="unet"). Use them with the standard Stable Diffusion v1-5 VAE and text encoder.

Here's a quickstart for getting out-of-the-box subgoals using this repo:

from susie.model import create_sample_fn
from susie.jax_utils import initialize_compilation_cache
import requests
import numpy as np
from PIL import Image

# set up the JAX compilation cache so repeated runs can reuse compiled code
initialize_compilation_cache()

IMAGE_URL = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_12.jpg"

# build the sampling function from the hosted weights
sample_fn = create_sample_fn("kvablack/susie")
# download the example frame and resize it to 256x256
image = np.array(Image.open(requests.get(IMAGE_URL, stream=True).raw).resize((256, 256)))
# generate a subgoal image for the language instruction
image_out = sample_fn(image, "open the drawer")

# to display the images if you're in a Jupyter notebook
display(Image.fromarray(image))
display(Image.fromarray(image_out))
