Comments (7)
Hi, thanks for your interest.
Yeah, it would be cool to apply PEZ to SDXL. I think the straightforward way is to optimize a separate prompt for each text encoder and feed each prompt to its corresponding text encoder. This might require some small modifications to the diffusers pipeline here: https://github.com/huggingface/diffusers/blob/b9feed87958c27074b0618cc543696c05f58e2c9/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L316. Instead of having one universal prompt, we can pass a list of prompts. I don't have the SDXL model weights for now, but I think they will be public this month, so once I have them, I will play around with it.
For now, maybe the simplest way is just to use the prompt optimized only with ViT-bigG, because ViT-bigG is the main text encoder for SDXL (I believe so), so it should kind of work. To do so, you can simply set args.clip_model = "ViT-bigG-14" and args.clip_pretrain = "laion2b_s39b_b160k".
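A minimal sketch of that change, assuming `args` is a plain `argparse.Namespace` like the one used in the script later in this thread. `load_clip` is a hypothetical helper name, not part of the repository; it is defined but not called here because calling it downloads the large ViT-bigG checkpoint.

```python
import argparse

# Settings object; only the two fields mentioned above are set here.
args = argparse.Namespace()
args.clip_model = "ViT-bigG-14"
args.clip_pretrain = "laion2b_s39b_b160k"


def load_clip(args, device="cpu"):
    """Hypothetical helper: load the OpenCLIP model named in args.

    open_clip is imported lazily so the settings can be inspected
    without the package (or the weights) being available.
    """
    import open_clip

    model, _, preprocess = open_clip.create_model_and_transforms(
        args.clip_model, pretrained=args.clip_pretrain, device=device
    )
    return model, preprocess
```

The rest of the optimization would then run unchanged, with the CLIP similarity computed by the bigger model.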
from hard-prompts-made-easy.
Hi manzonif, thank you for sharing the details. I have been busy with a conference deadline recently, but I will try my best to test it either this month or the next. I appreciate your understanding and patience.
To delve a bit deeper into the conceptual framework I had in mind earlier, there are two ways I am considering:
- Optimizing two independent prompts for the two text encoders.
- Optimizing a universal prompt using an ensemble of two text encoders.
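To make the difference between the two options concrete, here is a toy sketch in plain Python. It assumes both encoders index the same vocabulary, and it interprets "ensemble" as picking, at each position, the token with the smallest summed squared distance across both embedding tables. That interpretation, the function names, and the tiny 2-D tables are all illustrative assumptions, not the repository's code.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest_token(vec, table):
    """Index of the row of `table` that is closest to `vec`."""
    return min(range(len(table)), key=lambda i: sq_dist(vec, table[i]))


def project_independent(soft_prompts, tables):
    """Option 1: every encoder gets its own hard prompt (one id list each)."""
    return [
        [nearest_token(vec, table) for vec in soft]
        for soft, table in zip(soft_prompts, tables)
    ]


def project_universal(soft_prompts, tables):
    """Option 2: one shared id list, chosen by summed distance over encoders."""
    prompt_len = len(soft_prompts[0])
    vocab_size = len(tables[0])
    ids = []
    for pos in range(prompt_len):
        def total(i):
            return sum(
                sq_dist(soft[pos], table[i])
                for soft, table in zip(soft_prompts, tables)
            )
        ids.append(min(range(vocab_size), key=total))
    return ids


# Toy data: a 3-token vocabulary embedded differently by two "encoders".
table_a = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
table_b = [[0.0, 0.0], [0.0, 2.0], [2.0, 0.0]]
soft_a = [[0.9, 0.1]]  # one-token soft prompt for encoder A
soft_b = [[1.5, 0.1]]  # one-token soft prompt for encoder B

independent = project_independent([soft_a, soft_b], [table_a, table_b])
universal = project_universal([soft_a, soft_b], [table_a, table_b])
```

In this toy case the independent projection picks a different token for each encoder, while the universal projection has to settle on a single token that is a compromise between the two tables.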
Certainly! I wish you a good conference.
Gotcha. Will try it out 😎
Any updates?
Hi, I've given it a try, but it seems that it's not working as expected. It's not learning.
I'm doing my best, but I'm new to Python and Torch, so there might be something I'm overlooking in my code (even some unforgivable mistakes :-) ).
I tried to use the "ViT-bigG-14" CLIP model, but it's too big for my 4090 24GB.
I used the latest diffusers (diffusers==0.20.0).
Here is my prompt inversion script:
import open_clip
import torch
from torchvision import transforms
import argparse
import datetime
import os
import copy
from transformers.optimization import Adafactor, AdafactorSchedule
from optim_utils import *
from diffusers import DDPMScheduler, DPMSolverMultistepScheduler
from modified_stable_diffusion_xl_pipeline import ModifiedStableDiffusionPipelineXL
args = argparse.Namespace()
args.iter = 1000
args.prompt_len = 8
args.lr = 0.1
args.weight_decay = 0.1
args.opt_iters = 3000
args.eval_step = 50
args.prompt_bs = 1
args.loss_weight = 1.0
args.print_step = 100
args.batch_size = 1
# args.clip_model = "ViT-bigG-14"
# args.clip_pretrain = "laion2b_s39b_b160k"
args.clip_model = "ViT-H-14"
args.clip_pretrain = "laion2b_s32b_b79k"
best_loss = -999
eval_loss = -99999
best_text = ""
weight_dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
def initialize_prompt(tokenizers_list, token_embeddings_list, args, device):
    prompt_len = args.prompt_len
    # randomly optimize prompt embeddings
    prompt_embeds_list = []
    dummy_embeds_list = []
    dummy_ids_list = []
    prompt_ids = torch.randint(len(tokenizers_list[0].encoder), (args.prompt_bs, prompt_len)).to(device)
    for tokenizer, token_embeddings in zip(tokenizers_list, token_embeddings_list):
        prompt_embeds = token_embeddings(prompt_ids).detach()
        prompt_embeds.requires_grad = True
        # initialize the template
        # -1 for optimized tokens
        dummy_ids = [tokenizer.bos_token_id] + [-1] * prompt_len + [tokenizer.eos_token_id] + [0] * (75 - prompt_len)
        dummy_ids = torch.tensor([dummy_ids] * args.prompt_bs).to(device)
        # for getting dummy embeds; -1 won't work for token_embedding
        tmp_dummy_ids = [tokenizer.bos_token_id] + [0] * prompt_len + [tokenizer.eos_token_id] + [0] * (75 - prompt_len)
        tmp_dummy_ids = torch.tensor([tmp_dummy_ids] * args.prompt_bs).to(device)
        dummy_embeds = token_embeddings(tmp_dummy_ids).detach()
        dummy_embeds.requires_grad = False
        prompt_embeds_list.append(prompt_embeds)
        dummy_embeds_list.append(dummy_embeds)
        dummy_ids_list.append(dummy_ids)
    return prompt_embeds_list, dummy_embeds_list, dummy_ids_list
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
# scheduler = DDPMScheduler(
#     beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
# )
pipe = ModifiedStableDiffusionPipelineXL.from_pretrained(
    model_id,
    scheduler=scheduler,
    torch_dtype=weight_dtype,
    variant="fp16",
    use_safetensors=True
)
pipe = pipe.to(device)
pipe.vae.requires_grad_(False)
pipe.vae.eval()
pipe.unet.requires_grad_(True)
pipe.unet.train()
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(args.clip_model, pretrained=args.clip_pretrain, device=device)
image_length = 1024
tokenizers_list = [pipe.tokenizer, pipe.tokenizer_2] if pipe.tokenizer is not None else [pipe.tokenizer_2]
token_embeddings_list =[pipe.text_encoder.text_model.embeddings.token_embedding, pipe.text_encoder_2.text_model.embeddings.token_embedding]
preprocess = transforms.Compose(
    [
        transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(1024),
        transforms.ToTensor(),
    ]
)
urls = [
    "https://www.parkwestgallery.com/wp-content/uploads/2017/10/im811243-e1507918728745.jpg",
]
orig_images = list(filter(None,[download_image(url) for url in urls]))
SDXL_VAE_SCALE_FACTOR = 0.13025
with torch.no_grad():
    curr_images = [preprocess(i).unsqueeze(0) for i in orig_images]
    curr_images = torch.concatenate(curr_images).to(device)
    all_latents = pipe.vae.encode(curr_images.to(weight_dtype)).latent_dist.sample()
    all_latents = all_latents * SDXL_VAE_SCALE_FACTOR
#initialize random prompt
prompt_embeds_list, dummy_embeds_list, dummy_ids_list = initialize_prompt(tokenizers_list, token_embeddings_list, args, device)
# input_optimizer = Adafactor(prompt_embeds_list, scale_parameter=False, relative_step=False, warmup_init=False, lr=0.2)
input_optimizer = torch.optim.AdamW(prompt_embeds_list, lr=args.lr, weight_decay=args.weight_decay)
input_optim_scheduler = None
for step in range(args.opt_iters):
    padded_embeds_list = []
    padded_dummy_ids_list = []
    tmp_embeds_list = []
    nn_indices_list = []
    # forward projection (top1 semantic_search(prompt_embeds, token_embedding))
    for prompt_embeds, dummy_embeds, dummy_ids, tokenizer, token_embeddings in zip(prompt_embeds_list, dummy_embeds_list, dummy_ids_list, tokenizers_list, token_embeddings_list):
        projected_embeds, nn_indices = nn_project(prompt_embeds, token_embeddings)
        tmp_embeds = copy.deepcopy(prompt_embeds)
        tmp_embeds.data = projected_embeds.data
        tmp_embeds.requires_grad = True
        # padding and repeat
        padded_embeds = copy.deepcopy(dummy_embeds)
        padded_embeds[:, 1:args.prompt_len+1] = tmp_embeds
        padded_embeds = padded_embeds.repeat(args.batch_size, 1, 1)
        padded_dummy_ids = dummy_ids.repeat(args.batch_size, 1)
        nn_indices_list.append(nn_indices)
        padded_embeds_list.append(padded_embeds)
        padded_dummy_ids_list.append(padded_dummy_ids)
        tmp_embeds_list.append(tmp_embeds)
    # randomly sample images and get features
    if args.batch_size is None:
        latents = all_latents
    else:
        perm = torch.randperm(len(all_latents))
        idx = perm[:args.batch_size]
        latents = all_latents[idx]
    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    # Sample a random timestep for each image
    timesteps = torch.randint(0, 1000, (bsz,), device=latents.device)
    timesteps = timesteps.long()
    # Add noise to the latents according to the noise magnitude at each timestep
    noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)
    # Get the target for loss depending on the prediction type
    if pipe.scheduler.config.prediction_type == "epsilon":
        target = noise
    elif pipe.scheduler.config.prediction_type == "v_prediction":
        target = pipe.scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {pipe.scheduler.config.prediction_type}")
    # get text embeddings
    text_embeddings, pooled_prompt_embeds = pipe._get_text_embedding_with_embeddings(padded_dummy_ids_list, padded_embeds_list)
    add_time_ids = pipe._get_add_time_ids(
        (image_length, image_length), (0,0), (image_length, image_length), dtype=prompt_embeds.dtype
    ).to(device)
    add_text_embeds = pooled_prompt_embeds
    # Predict the noise residual and compute loss
    model_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings, added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}).sample
    loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prompt_embeds_list = torch.autograd.grad(loss, tmp_embeds_list)
    input_optimizer.step()
    input_optimizer.zero_grad()
    curr_lr = input_optimizer.param_groups[0]["lr"]
    ### eval
    if step % args.eval_step == 0:
        prompt_1 = decode_ids(nn_indices_list[0], tokenizers_list[0])[0]
        prompt_2 = decode_ids(nn_indices_list[1], tokenizers_list[1])[0]
        print(f"step: {step}, lr: {curr_lr}, cosim: {eval_loss:.3f}, best_cosim: {best_loss:.3f}, best prompt: {best_text}")
        with torch.no_grad():
            pred_imgs = pipe(
                prompt_1,
                prompt_2,
                num_images_per_prompt=4,
                guidance_scale=9,
                num_inference_steps=50,
                height=image_length,
                width=image_length,
                output_type='pil'
            ).images
        eval_loss = measure_similarity(orig_images, pred_imgs, clip_model, clip_preprocess, device)
        if best_loss < eval_loss:
            best_loss = eval_loss
            best_text = f'{prompt_1} {prompt_2}'
print()
print(f"Best shot: cosine similarity: {best_loss:.3f}")
print(f"text: {best_text}")
# you can customize the learned prompt here
prompt = best_text
num_images = 4
guidance_scale = 9
num_inference_steps = 25
images = pipe(
    prompt,
    num_images_per_prompt=num_images,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    height=image_length,
    width=image_length,
).images
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
for i, img in enumerate(images):
    img.save(os.path.join('output/', f"sd2_result_{timestamp}_{i:03d}.png"))
print("Save images.")
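One thing that may be worth checking in the loop above: the result of `torch.autograd.grad(loss, tmp_embeds_list)` is assigned to the name `prompt_embeds_list`, so the tensors registered with `input_optimizer` never receive a `.grad`, and `input_optimizer.step()` has nothing to apply. In PEZ-style optimization the gradient is taken at the projected (hard) embedding and then applied to the continuous embedding. The toy sketch below shows that pattern in isolation; the identity-matrix vocabulary, the squared-error objective, the SGD optimizer, and `nn_project_toy` are illustrative assumptions and are not taken from the repository.

```python
import torch

# Stand-in "token embedding table": 4 tokens in 4 dimensions.
vocab = torch.eye(4)
# Toy objective: the projected prompt should match token 3.
target = vocab[3]

# Continuous (soft) prompt embedding, starting closest to token 0.
soft = torch.tensor([[0.9, 0.1, 0.0, 0.0]], requires_grad=True)
opt = torch.optim.SGD([soft], lr=0.5)


def nn_project_toy(x, table):
    """Toy nearest-neighbour projection: snap each row of x to a table row."""
    dists = ((x[:, None, :] - table[None, :, :]) ** 2).sum(dim=-1)
    idx = dists.argmin(dim=-1)
    return table[idx], idx


for _ in range(5):
    projected, _ = nn_project_toy(soft.detach(), vocab)
    # Differentiate with respect to the projected copy ...
    tmp = projected.clone().requires_grad_(True)
    loss = ((tmp - target) ** 2).mean()
    (grad,) = torch.autograd.grad(loss, [tmp])
    # ... but hand the gradient to the tensor the optimizer actually owns.
    soft.grad = grad
    opt.step()
    opt.zero_grad()

_, final_idx = nn_project_toy(soft.detach(), vocab)
```

With two encoders, the same idea would be a loop such as `for p, g in zip(prompt_embeds_list, grads): p.grad = g`, where `grads` is the tuple returned by `torch.autograd.grad`. This is only one possible cause of the "not learning" behaviour, and fixing it does not guarantee good prompts.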
Here is the modified pipeline:
from typing import Callable, List, Optional, Union
import torch
from diffusers import StableDiffusionXLPipeline
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.utils import logging
from transformers.modeling_outputs import BaseModelOutputWithPooling
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class ModifiedStableDiffusionPipelineXL(StableDiffusionXLPipeline):
    def __init__(self,
        vae,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet,
        scheduler,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None
    ):
        super(ModifiedStableDiffusionPipelineXL, self).__init__(vae,
            text_encoder,
            text_encoder_2,
            tokenizer,
            tokenizer_2,
            unet,
            scheduler,
            force_zeros_for_empty_prompt,
            add_watermarker)

    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
        mask.fill_(torch.tensor(torch.finfo(dtype).min))
        mask.triu_(1)  # zero out the lower diagonal
        mask = mask.unsqueeze(1)  # expand mask
        return mask

    def _encode_embeddings(self, text_encoder, input_ids, prompt_embeddings, attention_mask=None):
        output_attentions = text_encoder.text_model.config.output_attentions
        output_hidden_states = True
        return_dict = text_encoder.text_model.config.use_return_dict
        hidden_states = text_encoder.text_model.embeddings(inputs_embeds=prompt_embeddings)
        bsz, seq_len = input_ids.shape[0], input_ids.shape[1]
        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
            hidden_states.device
        )
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = text_encoder.text_model._expand_mask(attention_mask, hidden_states.dtype)
        encoder_outputs = text_encoder.text_model.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs[0]
        last_hidden_state = text_encoder.text_model.final_layer_norm(last_hidden_state)
        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
        ]
        text_outputs = BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
        if isinstance(text_encoder, CLIPTextModelWithProjection):
            pooled_output = text_outputs[1]
            text_embeds = text_encoder.text_projection(pooled_output)
            if not return_dict:
                outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
                return tuple(output for output in outputs if output is not None)
            return CLIPTextModelOutput(
                text_embeds=text_embeds,
                last_hidden_state=text_outputs.last_hidden_state,
                hidden_states=text_outputs.hidden_states,
                attentions=text_outputs.attentions,
            )
        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
        return text_outputs

    def _get_text_embedding_with_embeddings(self, text_input_ids_list, prompt_embeddings_list):
        text_encoders_list = (
            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
        )
        prompt_embeds_list = []
        for text_input_ids, prompt_embeddings, text_encoder in zip(text_input_ids_list, prompt_embeddings_list, text_encoders_list):
            text_embeddings = self._encode_embeddings(
                text_encoder,
                text_input_ids,
                prompt_embeddings
            )
            # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = text_embeddings[0]
            text_embeddings = text_embeddings.hidden_states[-2]
            prompt_embeds_list.append(text_embeddings)
        prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
        return prompt_embeds, pooled_prompt_embeds
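The pooled output in `_encode_embeddings` is read from the position of the largest id in `input_ids`, which is meant to be the end-of-text token. The small sketch below (plain Python, no model needed) rebuilds the id template from `initialize_prompt` and checks where that position falls. The ids 49406 and 49407 are the start and end tokens of the standard CLIP vocabulary, hard-coded here as an assumption (the script reads them from the tokenizer), and `build_template` and `pooled_position` are illustrative names.

```python
# Assumed ids from the standard CLIP BPE vocabulary; the script itself uses
# tokenizer.bos_token_id / tokenizer.eos_token_id and pads with 0.
BOS_ID, EOS_ID, PAD_ID = 49406, 49407, 0
MAX_LEN = 77


def build_template(prompt_len):
    """BOS, prompt_len placeholders (-1), EOS, then padding up to MAX_LEN ids."""
    padding = [PAD_ID] * (MAX_LEN - 2 - prompt_len)
    return [BOS_ID] + [-1] * prompt_len + [EOS_ID] + padding


def pooled_position(ids):
    """Position of the largest id, mirroring input_ids.argmax(dim=-1)."""
    return max(range(len(ids)), key=lambda i: ids[i])


ids = build_template(prompt_len=8)
eos_pos = pooled_position(ids)
```

Because the `-1` placeholders and the `0` padding are both smaller than the end-of-text id, the pooled vector is taken right after the optimized tokens, at index `prompt_len + 1`.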
Hi @manzonif, sorry about the late response. Not sure if you have made any progress on this, but I recently tried optimizing two independent prompts for the two text encoders. However, it doesn't work very well. I am going to double-check the code and also see whether optimizing a universal prompt with an ensemble of the two text encoders works.
Thanks for your patience!
So, did anyone have any luck making this work for XL? 🤔