
dreambooth-docker's Introduction

A Dockerfile for running DreamBooth locally on Windows/Linux

DreamBooth is a method to personalize text-to-image models such as Stable Diffusion given just a few (3 to 5) images of a subject.

The training script in this repo is adapted from ShivamShrirao's diffusers repo. See here for the detailed training commands.

The Dockerfile copies ShivamShrirao's train_dreambooth.py to the container's root directory, so replace every train_dreambooth.py in the original docs with /train_dreambooth.py. For example, to run accelerate launch train_dreambooth.py, run the following:

sudo docker run -it --gpus=all --ipc=host -v $(pwd):/train -e HUGGING_FACE_HUB_TOKEN=$(cat ~/.huggingface/token)  smy20011/dreambooth:latest \
  accelerate launch /train_dreambooth.py (your arguments here)
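
To avoid retyping the long docker prefix, the command above could be wrapped in a small shell function. This is only a sketch: the name dreambooth_docker and the HF_TOKEN_FILE and DRY_RUN variables are hypothetical conveniences, not part of this repo. With DRY_RUN=1 it prints the command (token masked) instead of starting a container.

```shell
# Hypothetical wrapper around the docker command shown above.
# HF_TOKEN_FILE and DRY_RUN are illustrative names, not options of the image.
dreambooth_docker() {
  token_file="${HF_TOKEN_FILE:-$HOME/.huggingface/token}"
  if [ ! -r "$token_file" ]; then
    echo "token file not readable: $token_file" >&2
    return 1
  fi
  if [ "${DRY_RUN:-0}" = "1" ]; then
    # Print the command with the token masked, without running Docker.
    echo "sudo docker run -it --gpus=all --ipc=host -v $(pwd):/train -e HUGGING_FACE_HUB_TOKEN=*** smy20011/dreambooth:latest accelerate launch /train_dreambooth.py $*"
    return 0
  fi
  # Same invocation as in this README, with the training arguments appended.
  sudo docker run -it --gpus=all --ipc=host \
    -v "$(pwd):/train" \
    -e HUGGING_FACE_HUB_TOKEN="$(cat "$token_file")" \
    smy20011/dreambooth:latest \
    accelerate launch /train_dreambooth.py "$@"
}
```

Usage would then look like dreambooth_docker --max_train_steps=400, followed by whatever other arguments the training script accepts.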

Running locally

On Windows, install WSL with Docker support. On Linux, make sure the NVIDIA driver and a recent version of Docker are installed. If the container still cannot access the GPU, make sure the NVIDIA Container Toolkit is installed.

For Windows users, follow the guides here and here. You can skip the "CUDA Support for WSL2" section. This YouTube video is also a useful reference.

Open the WSL terminal on Windows or a terminal on Linux. Make sure huggingface-cli is installed.
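
One way to confirm that Docker can reach the GPU is to run nvidia-smi inside a throwaway container, as the NVIDIA Container Toolkit documentation suggests. The sketch below wraps that in a function; the name check_docker_gpu and the DOCKER override variable are hypothetical and only there so the check also works on setups that do not need sudo.

```shell
# Run nvidia-smi in a temporary container to confirm Docker can see the GPU.
# DOCKER is a hypothetical override, e.g. DOCKER=docker when sudo is not needed.
check_docker_gpu() {
  if ${DOCKER:-sudo docker} run --rm --gpus=all ubuntu nvidia-smi; then
    echo "GPU is visible inside Docker"
  else
    echo "GPU is NOT visible inside Docker; check the driver and the NVIDIA Container Toolkit" >&2
    return 1
  fi
}
```

Calling check_docker_gpu should print the usual nvidia-smi table when everything is set up; a failure here usually points at the driver or the container toolkit rather than at the training image.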
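
A quick preflight check can report which command-line tools are missing before you start. This is a minimal sketch; missing_tools is a hypothetical helper, and the list of tools is an assumption based on the commands used in this README.

```shell
# Print the name of each given tool that is not on PATH.
# Returns non-zero if at least one tool is missing.
missing_tools() {
  rc=0
  for tool in "$@"; do
    if ! command -v "$tool" >/dev/null 2>&1; then
      echo "$tool"
      rc=1
    fi
  done
  return "$rc"
}

# Tool list is an assumption drawn from the commands in this README.
if missing=$(missing_tools docker huggingface-cli); then
  echo "all required tools found"
else
  echo "missing tools:"
  echo "$missing"
fi
```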

Dog toy example

You need to accept the model license before downloading or using the weights. In this example we'll use model version v1-4, so you'll need to visit its card, read the license and tick the checkbox if you agree.

You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to this section of the documentation.

Run the following command to authenticate with your token:

huggingface-cli login
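
The docker commands in this README read the token from ~/.huggingface/token, so it can help to confirm that the login actually wrote a file there. The sketch below also looks in ~/.cache/huggingface/token, on the assumption that newer huggingface_hub releases store the token there instead; if that is where yours ends up, adjust the path in the docker command accordingly. The function name find_hf_token is hypothetical.

```shell
# Print the path of the first readable Hugging Face token file, if any.
# The second location is an assumption about newer huggingface_hub releases.
find_hf_token() {
  for candidate in "$HOME/.huggingface/token" "$HOME/.cache/huggingface/token"; do
    if [ -r "$candidate" ]; then
      echo "$candidate"
      return 0
    fi
  done
  return 1
}

if token_path=$(find_hf_token); then
  echo "token found at $token_path"
else
  echo "no token found; run huggingface-cli login" >&2
fi
```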

Now let's get our dataset. Download images from here and save them in a directory. This will be our training data.

In your terminal, cd to that directory and run the following commands:

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"

sudo docker run -it --gpus=all --ipc=host -v $(pwd):/train -e HUGGING_FACE_HUB_TOKEN=$(cat ~/.huggingface/token)  smy20011/dreambooth:latest \
  accelerate launch /train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=400
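
Because the command mounts only the current directory into the container (at /train), the instance images need to live somewhere under it. How relative paths resolve inside the container depends on the image's working directory, which this README does not state. As a sanity check before training, a sketch like the following counts the images in INSTANCE_DIR; the helper name count_images and the list of extensions are assumptions.

```shell
# Count image files directly inside a directory (not in subdirectories).
# The extension list is an assumption; extend it if your data uses other formats.
count_images() {
  find "$1" -maxdepth 1 -type f \
    \( -iname '*.jpg' -o -iname '*.jpeg' -o -iname '*.png' \) | wc -l | tr -d ' '
}

INSTANCE_DIR="${INSTANCE_DIR:-path-to-instance-images}"
if [ -d "$INSTANCE_DIR" ]; then
  echo "found $(count_images "$INSTANCE_DIR") training images in $INSTANCE_DIR"
else
  echo "$INSTANCE_DIR does not exist under $(pwd)" >&2
fi
```

DreamBooth is described above as needing only a few (3 to 5) images, so a count of zero most likely means the path is wrong.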

Training with prior-preservation loss

Prior preservation is used to avoid overfitting and language drift. Refer to the paper to learn more about it. For prior preservation, we first generate images using the model with a class prompt, then use them during training alongside our own data. According to the paper, it's recommended to generate num_epochs * num_samples images for prior preservation; 200-300 works well for most cases.
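
Schematically, the training objective with prior preservation is a weighted sum of two terms. The notation below is a simplification introduced here for illustration, not the paper's exact formula: the first term is the usual loss on your instance images with the instance prompt, the second is the same loss on the generated class images with the class prompt, and the weight corresponds to the --prior_loss_weight argument.

```latex
\mathcal{L} = \mathcal{L}_{\text{instance}} + \lambda \, \mathcal{L}_{\text{prior}},
\qquad \lambda = \texttt{prior\_loss\_weight}
```

With --prior_loss_weight=1.0, as in the command below, both terms contribute equally.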

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

sudo docker run -it --gpus=all --ipc=host -v $(pwd):/train -e HUGGING_FACE_HUB_TOKEN=$(cat ~/.huggingface/token)  smy20011/dreambooth:latest \
  accelerate launch /train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=200 \
  --max_train_steps=800

Training on a 16GB GPU

With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to train DreamBooth on a 16GB GPU.

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

sudo docker run -it --gpus=all --ipc=host -v $(pwd):/train -e HUGGING_FACE_HUB_TOKEN=$(cat ~/.huggingface/token)  smy20011/dreambooth:latest \
  accelerate launch /train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=2 --gradient_checkpointing \
  --use_8bit_adam \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=200 \
  --max_train_steps=800
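
Before launching the 16GB configuration, it may be worth checking how much memory the GPU actually reports. The sketch below queries nvidia-smi when it is available; the helper name has_enough_vram and the 15000 MiB default threshold are assumptions (cards sold as 16GB often report somewhat less than 16384 MiB), not values from this repo.

```shell
# Succeed if the reported memory (MiB) meets the required minimum (MiB).
# The 15000 MiB default is a rough assumption for "a 16GB GPU".
has_enough_vram() {
  total_mib=$1
  required_mib=${2:-15000}
  [ "$total_mib" -ge "$required_mib" ]
}

# Query the first GPU's total memory when nvidia-smi is available.
if command -v nvidia-smi >/dev/null 2>&1; then
  total=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | head -n 1)
  if [ -n "$total" ] && has_enough_vram "$total"; then
    echo "GPU reports ${total} MiB; the 16GB settings should fit"
  else
    echo "GPU reports '${total}' MiB; the 16GB settings may not fit" >&2
  fi
else
  echo "nvidia-smi not found; cannot check GPU memory" >&2
fi
```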
