Comments (7)

YuxinWenRick commented on July 23, 2024

Hi, thanks for your interest.

Yeah, it would be cool to apply PEZ to SDXL. I think the straightforward way is to optimize separate prompts for each text encoder and feed each prompt to the corresponding text encoder. This might require some small modifications to the diffusers pipeline here: https://github.com/huggingface/diffusers/blob/b9feed87958c27074b0618cc543696c05f58e2c9/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L316. Instead of having one universal prompt, we can pass a list of prompts. I don't have the SDXL model weights for now, but I think they will be public this month, so as soon as I have the weights, I will play around with it.
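
For the generation side, here is a minimal sketch of feeding one optimized prompt to each encoder, assuming the public SDXL base checkpoint and diffusers' prompt / prompt_2 arguments (which route to the first and second text encoder respectively); the prompt strings below are placeholders:

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")

# Placeholders for prompts optimized separately against ViT-L and ViT-bigG.
prompt_for_encoder_1 = "<hard prompt optimized with ViT-L>"
prompt_for_encoder_2 = "<hard prompt optimized with ViT-bigG>"

image = pipe(
    prompt=prompt_for_encoder_1,    # goes to tokenizer / text_encoder
    prompt_2=prompt_for_encoder_2,  # goes to tokenizer_2 / text_encoder_2 (ViT-bigG)
    num_inference_steps=30,
).images[0]
image.save("sdxl_two_prompts.png")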

For now, maybe the simplest way is just to use a prompt optimized only with ViT-bigG, since ViT-bigG is the main text encoder for SDXL (I believe so), so it should kind of work. To do so, you can simply set args.clip_model = "ViT-bigG-14" and args.clip_pretrain = "laion2b_s39b_b160k".
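
A rough sketch of that change, assuming the repo's optimize_prompt / read_json helpers from optim_utils and the default sample_config.json (the target image URL is a placeholder):

import argparse
import open_clip
from optim_utils import optimize_prompt, read_json, download_image

# Start from the repo's default PEZ settings, then point them at SDXL's main text encoder.
args = argparse.Namespace()
args.__dict__.update(read_json("sample_config.json"))
args.clip_model = "ViT-bigG-14"
args.clip_pretrain = "laion2b_s39b_b160k"

device = "cuda"
model, _, preprocess = open_clip.create_model_and_transforms(
    args.clip_model, pretrained=args.clip_pretrain, device=device
)

# Placeholder target image; any PIL image works.
target_images = [download_image("https://example.com/target.jpg")]
learned_prompt = optimize_prompt(model, preprocess, args, device, target_images=target_images)
print(learned_prompt)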

YuxinWenRick commented on July 23, 2024

Hi manzonif, thank you for sharing the details. I have been busy with a conference deadline recently, but I will try my best to test it either this month or the next. I appreciate your understanding and patience.

To delve a bit deeper into the conceptual framework I had in mind earlier, there are two ways I am considering:

  1. Optimizing two independent prompts for the two text encoders.
  2. Optimizing a universal prompt using an ensemble of the two text encoders (see the sketch after this list).
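
A rough sketch of option 2, assuming both SDXL text encoders share the same CLIP vocabulary so one set of hard tokens can be embedded by either encoder (all names below are illustrative, not the repo's API):

def ensemble_pez_loss(shared_token_ids, token_embedding_tables, encode_fns, image_features):
    """Average a CLIP cosine objective over both SDXL text encoders for one shared hard prompt.

    shared_token_ids:       projected token ids shared by both encoders, shape [batch, prompt_len]
    token_embedding_tables: each encoder's token-embedding layer (768-d and 1280-d)
    encode_fns:             per-encoder callables mapping prompt embeddings -> pooled text features
    image_features:         the target image features, one tensor per encoder
    """
    losses = []
    for embed, encode, img_feat in zip(token_embedding_tables, encode_fns, image_features):
        txt_feat = encode(embed(shared_token_ids))  # same tokens, embedded per encoder
        txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
        losses.append(1.0 - (txt_feat * img_feat).sum(dim=-1).mean())  # cosine distance
    return sum(losses) / len(losses)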

manzonif commented on July 23, 2024

Certainly! I wish you a good conference.

0x1355 commented on July 23, 2024

Gotcha. Will try it out 😎

StableInfo commented on July 23, 2024

Any updates?

manzonif commented on July 23, 2024

Hi, I've given it a try, but it seems it's not working as expected. It's not learning.
I'm doing my best, but I'm new to Python and Torch, so there might be something I'm overlooking in my code (maybe even some unforgivable mistakes :-) ).
I tried to use the "ViT-bigG-14" CLIP model, but it's too big for my 4090 (24 GB).
I'm using the latest diffusers==0.20.0.
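
(One untested idea for the memory issue: open_clip can load the evaluation model in half precision via its precision argument, which roughly halves the CLIP model's footprint; whether it then fits next to the SDXL pipeline on 24 GB I haven't checked.)

import torch
import open_clip

# Untested: load the evaluation CLIP model in fp16 to reduce its memory footprint.
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    "ViT-bigG-14",
    pretrained="laion2b_s39b_b160k",
    precision="fp16",  # open_clip also accepts e.g. "fp32" and "amp"
    device=device,
)
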
Here is my prompt inversion script:

import open_clip
import torch
from torchvision import transforms
import argparse
import datetime
import os
import copy
from transformers.optimization import Adafactor, AdafactorSchedule
from optim_utils import * 
from diffusers import DDPMScheduler, DPMSolverMultistepScheduler
from modified_stable_diffusion_xl_pipeline import ModifiedStableDiffusionPipelineXL

    
args = argparse.Namespace()
args.iter = 1000
args.prompt_len = 8
args.lr = 0.1
args.weight_decay = 0.1
args.opt_iters = 3000
args.eval_step = 50
args.prompt_bs = 1
args.loss_weight = 1.0
args.print_step = 100
args.batch_size = 1
# args.clip_model = "ViT-bigG-14"
# args.clip_pretrain =  "laion2b_s39b_b160k"
args.clip_model = "ViT-H-14"
args.clip_pretrain =  "laion2b_s32b_b79k"
best_loss = -999
eval_loss = -99999
best_text = ""

weight_dtype = torch.bfloat16

device = "cuda" if torch.cuda.is_available() else "cpu"


def initialize_prompt(tokenizers_list, token_embeddings_list, args, device):
    prompt_len = args.prompt_len
    # randomly optimize prompt embeddings    
    prompt_embeds_list = []
    dummy_embeds_list = []
    dummy_ids_list = []
    prompt_ids = torch.randint(len(tokenizers_list[0].encoder), (args.prompt_bs, prompt_len)).to(device)
    for tokenizer, token_embeddings in zip(tokenizers_list, token_embeddings_list):

        prompt_embeds = token_embeddings(prompt_ids).detach()
        prompt_embeds.requires_grad = True
        # initialize the template
        # -1 for optimized tokens
        dummy_ids = [tokenizer.bos_token_id] + [-1] * prompt_len + [tokenizer.eos_token_id] + [0] * (75 - prompt_len)
        dummy_ids = torch.tensor([dummy_ids] * args.prompt_bs).to(device)
        # for getting dummy embeds; -1 won't work for token_embedding
        tmp_dummy_ids = [tokenizer.bos_token_id] + [0] * prompt_len + [tokenizer.eos_token_id] + [0] * (75 - prompt_len)
        tmp_dummy_ids = torch.tensor([tmp_dummy_ids] * args.prompt_bs).to(device)
    
        dummy_embeds = token_embeddings(tmp_dummy_ids).detach()
        dummy_embeds.requires_grad = False
        prompt_embeds_list.append(prompt_embeds)
        dummy_embeds_list.append(dummy_embeds)
        dummy_ids_list.append(dummy_ids)
    
    return prompt_embeds_list, dummy_embeds_list, dummy_ids_list

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
# scheduler = DDPMScheduler(
#     beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
# )


pipe = ModifiedStableDiffusionPipelineXL.from_pretrained(
    model_id,
    scheduler=scheduler,
    torch_dtype=weight_dtype,
    variant="fp16", 
    use_safetensors=True
)
pipe = pipe.to(device)

pipe.vae.requires_grad_(False)
pipe.vae.eval()

pipe.unet.requires_grad_(True)
pipe.unet.train()

clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(args.clip_model, pretrained=args.clip_pretrain, device=device) 

image_length = 1024
tokenizers_list = [pipe.tokenizer, pipe.tokenizer_2] if pipe.tokenizer is not None else [pipe.tokenizer_2]
token_embeddings_list =[pipe.text_encoder.text_model.embeddings.token_embedding, pipe.text_encoder_2.text_model.embeddings.token_embedding]


preprocess = transforms.Compose(
    [
        transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(1024),
        transforms.ToTensor(),
    ]
)

urls = [
        "https://www.parkwestgallery.com/wp-content/uploads/2017/10/im811243-e1507918728745.jpg",
       ]

orig_images = list(filter(None,[download_image(url) for url in urls]))

SDXL_VAE_SCALE_FACTOR = 0.13025

with torch.no_grad():
    curr_images = [preprocess(i).unsqueeze(0) for i in orig_images]
    curr_images = torch.concatenate(curr_images).to(device)
    all_latents = pipe.vae.encode(curr_images.to(weight_dtype)).latent_dist.sample()
    all_latents = all_latents * SDXL_VAE_SCALE_FACTOR


#initialize random prompt 
prompt_embeds_list, dummy_embeds_list, dummy_ids_list = initialize_prompt(tokenizers_list, token_embeddings_list, args, device)
# input_optimizer = Adafactor(prompt_embeds_list, scale_parameter=False, relative_step=False, warmup_init=False, lr=0.2)
input_optimizer = torch.optim.AdamW(prompt_embeds_list, lr=args.lr, weight_decay=args.weight_decay)
input_optim_scheduler = None

for step in range(args.opt_iters):
    padded_embeds_list = []
    padded_dummy_ids_list = []
    tmp_embeds_list = []
    nn_indices_list = []

    # forward projection (top1 semantic_search(prompt_embeds, token_embedding))
    for prompt_embeds, dummy_embeds, dummy_ids, tokenizer, token_embeddings in zip(prompt_embeds_list, dummy_embeds_list, dummy_ids_list, tokenizers_list, token_embeddings_list):    
        projected_embeds, nn_indices = nn_project(prompt_embeds, token_embeddings)

        tmp_embeds = copy.deepcopy(prompt_embeds)
        tmp_embeds.data = projected_embeds.data
        tmp_embeds.requires_grad = True
            
        # padding and repeat
        padded_embeds = copy.deepcopy(dummy_embeds)
        padded_embeds[:, 1:args.prompt_len+1] = tmp_embeds
        padded_embeds = padded_embeds.repeat(args.batch_size, 1, 1)
        padded_dummy_ids = dummy_ids.repeat(args.batch_size, 1)
        nn_indices_list.append(nn_indices)
        padded_embeds_list.append(padded_embeds)
        padded_dummy_ids_list.append(padded_dummy_ids)
        tmp_embeds_list.append(tmp_embeds)
    
    # randomly sample images and get their latents
    if args.batch_size is None:
        latents = all_latents
    else:
        perm = torch.randperm(len(all_latents))
        idx = perm[:args.batch_size]
        latents = all_latents[idx]
        
    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    # Sample a random timestep for each image
    timesteps = torch.randint(0, 1000, (bsz,), device=latents.device)
    timesteps = timesteps.long()
        
    # Add noise to the latents according to the noise magnitude at each timestep
    noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

    # Get the target for loss depending on the prediction type
    if pipe.scheduler.config.prediction_type == "epsilon":
        target = noise
    elif pipe.scheduler.config.prediction_type == "v_prediction":
        target = pipe.scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {pipe.scheduler.config.prediction_type}")
        
    # get text embeddings
    text_embeddings, pooled_prompt_embeds = pipe._get_text_embedding_with_embeddings(padded_dummy_ids_list, padded_embeds_list)
    
    add_time_ids = pipe._get_add_time_ids(
        (image_length, image_length), (0,0), (image_length, image_length), dtype=prompt_embeds.dtype
    ).to(device)
    add_text_embeds = pooled_prompt_embeds
    # Predict the noise residual and compute loss
    model_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings, added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}).sample
    loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
  
    # write gradients back onto the original prompt embeddings so the optimizer updates them (PEZ-style)
    grads = torch.autograd.grad(loss, tmp_embeds_list)
    for prompt_embeds, grad in zip(prompt_embeds_list, grads):
        prompt_embeds.grad = grad
    input_optimizer.step()
    input_optimizer.zero_grad()
    
    curr_lr = input_optimizer.param_groups[0]["lr"]
    
    ### eval
    if step % args.eval_step == 0:
        prompt_1 = decode_ids(nn_indices_list[0], tokenizers_list[0])[0]
        prompt_2 = decode_ids(nn_indices_list[1], tokenizers_list[1])[0]
        print(f"step: {step}, lr: {curr_lr}, cosim: {eval_loss:.3f}, best_cosim: {best_loss:.3f}, best prompt: {best_text}")

        with torch.no_grad():
            pred_imgs = pipe(
                prompt_1,
                prompt_2,
                num_images_per_prompt=4,
                guidance_scale=9,
                num_inference_steps=50,
                height=image_length,
                width=image_length,
                output_type='pil'
                ).images
            eval_loss = measure_similarity(orig_images, pred_imgs, clip_model, clip_preprocess, device)

        if best_loss < eval_loss:
            best_loss = eval_loss
            best_text = f'{prompt_1} {prompt_2}'   
            
print()
print(f"Best shot: consine similarity: {best_loss:.3f}")
print(f"text: {best_text}")
# you can customize the learned prompt here
prompt = best_text

num_images = 4
guidance_scale = 9
num_inference_steps = 25

images = pipe(
    prompt,
    num_images_per_prompt=num_images,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    height=image_length,
    width=image_length,
    ).images

timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
for i, img in enumerate(images):
    img.save(os.path.join('output/', f"sd2_result_{timestamp}_{i:03d}.png"))

print("Save images.")

Here is the modified pipeline:

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.utils import logging
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from transformers.modeling_outputs import BaseModelOutputWithPooling

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class ModifiedStableDiffusionPipelineXL(StableDiffusionXLPipeline):
    def __init__(self,
        vae,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet,
        scheduler,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None
    ):
        super(ModifiedStableDiffusionPipelineXL, self).__init__(vae,
                text_encoder,
                text_encoder_2,
                tokenizer,
                tokenizer_2,
                unet,
                scheduler,
                force_zeros_for_empty_prompt,
                add_watermarker)
        
    def _build_causal_attention_mask(self,bsz, seq_len, dtype):
        # lazily create a causal attention mask for the text tokens
        # pytorch uses an additive attention mask; fill with -inf above the diagonal
        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
        mask.fill_(torch.tensor(torch.finfo(dtype).min))
        mask.triu_(1)  # zero out the lower diagonal
        mask = mask.unsqueeze(1)  # expand mask
        return mask
    

    def _encode_embeddings(self, text_encoder, input_ids, prompt_embeddings, attention_mask=None):
        output_attentions = text_encoder.text_model.config.output_attentions
        output_hidden_states = True
        return_dict = text_encoder.text_model.config.use_return_dict

        hidden_states = text_encoder.text_model.embeddings(inputs_embeds=prompt_embeddings)

        bsz, seq_len = input_ids.shape[0], input_ids.shape[1]
        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
            hidden_states.device
        )

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = text_encoder.text_model._expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = text_encoder.text_model.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = text_encoder.text_model.final_layer_norm(last_hidden_state)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
        ]

        text_outputs = BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
        if isinstance(text_encoder, CLIPTextModelWithProjection):
            pooled_output = text_outputs[1]

            text_embeds = text_encoder.text_projection(pooled_output)

            if not return_dict:
                outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
                return tuple(output for output in outputs if output is not None)

            return CLIPTextModelOutput(
                text_embeds=text_embeds,
                last_hidden_state=text_outputs.last_hidden_state,
                hidden_states=text_outputs.hidden_states,
                attentions=text_outputs.attentions,
            )

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return text_outputs

    def _get_text_embedding_with_embeddings(self, text_input_ids_list, prompt_embeddings_list):
        text_encoders_list = (
            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
        )
        prompt_embeds_list = []
        for text_input_ids, prompt_embeddings, text_encoder in zip(text_input_ids_list, prompt_embeddings_list, text_encoders_list):
            text_embeddings = self._encode_embeddings(
                text_encoder,
                text_input_ids,
                prompt_embeddings
            )
             # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = text_embeddings[0]
            text_embeddings = text_embeddings.hidden_states[-2]
            prompt_embeds_list.append(text_embeddings)

        prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
        return prompt_embeds, pooled_prompt_embeds

YuxinWenRick commented on July 23, 2024

Hi @manzonif , sorry about the late response. Not sure if you have any progress on this, but I recently tried to optimize two independent prompts for the two text encoders. However, it doesn't work very well. I am going to double-check the code and also see if optimizing a universal prompt with an ensemble of two text encoders works.

Thanks for your patience!

slashedstar commented on July 23, 2024

So, did anyone have any luck making this work for XL? 🤔
