
tokenflow's Introduction

TokenFlow: Consistent Diffusion Features for Consistent Video Editing (ICLR 2024)

arXiv Hugging Face Spaces Pytorch

teaser.mp4

TokenFlow is a framework that enables consistent video editing, using a pre-trained text-to-image diffusion model, without any further training or finetuning.

The generative AI revolution has recently expanded to videos. Nevertheless, current state-of-the-art video models are still lagging behind image models in terms of visual quality and user control over the generated content. In this work, we present a framework that harnesses the power of a text-to-image diffusion model for the task of text-driven video editing. Specifically, given a source video and a target text-prompt, our method generates a high-quality video that adheres to the target text, while preserving the spatial layout and dynamics of the input video. Our method is based on our key observation that consistency in the edited video can be obtained by enforcing consistency in the diffusion feature space. We achieve this by explicitly propagating diffusion features based on inter-frame correspondences, readily available in the model. Thus, our framework does not require any training or fine-tuning, and can work in conjunction with any off-the-shelf text-to-image editing method. We demonstrate state-of-the-art editing results on a variety of real-world videos.

For more details, see the project webpage.

Sample results

Environment

conda create -n tokenflow python=3.9
conda activate tokenflow
pip install -r requirements.txt

Preprocess

Preprocess your video by running the following command:

python preprocess.py --data_path <data/myvideo.mp4> \
                     --inversion_prompt <'' or a string describing the video content>

Additional arguments:

                     --save_dir <latents>
                     --H <video height>
                     --W <video width>
                     --sd_version <Stable-Diffusion version>
                     --steps <number of inversion steps>
                     --save_steps <number of sampling steps that will be used later for editing>
                     --n_frames <number of frames>
                     

More information on the arguments can be found here.
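For example, a minimal invocation on one of the sample videos (the inversion prompt should describe the source content):

python preprocess.py --data_path data/woman-running.mp4 \
                     --inversion_prompt 'a woman running'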

Note:

The video reconstruction will be saved as inverted.mp4. A good reconstruction is required for successful editing with our method.

Editing

  • TokenFlow is designed for structure-preserving video edits.
  • Our method is built on top of an image editing technique (e.g., Plug-and-Play, ControlNet, etc.) - therefore, it is important to ensure that the edit works with the chosen base technique.
  • The LDM decoder may introduce some jitter, depending on the original video.

To edit your video, first create a yaml config as in configs/config_pnp.yaml. Then run

python run_tokenflow_pnp.py
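For reference, a minimal config_pnp.yaml might look like the following (values are illustrative; see configs/config_pnp.yaml in the repo for the authoritative defaults):

seed: 1
device: 'cuda'
output_path: 'tokenflow-results'

data_path: 'data/myvideo.mp4'   # should match the video used in preprocess
latents_path: 'latents'         # should be the same as the 'save_dir' arg used in preprocess
n_inversion_steps: 500          # for retrieving the latents of the inversion
n_frames: 40

sd_version: '2.1'
guidance_scale: 7.5
n_timesteps: 50
prompt: 'a description of the target edit'
negative_prompt: 'ugly, blurry, low res, unrealistic, unaesthetic'
batch_size: 8

pnp_attn_t: 0.5
pnp_f_t: 0.8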

Similarly, if you want to use ControlNet or SDEdit, create a yaml config as in configs/config_controlnet.yaml or configs/config_SDEdit.yaml and run python run_tokenflow_controlnet.py or python run_tokenflow_SDEdit.py, respectively.

Citation

@article{tokenflow2023,
        title = {TokenFlow: Consistent Diffusion Features for Consistent Video Editing},
        author = {Geyer, Michal and Bar-Tal, Omer and Bagon, Shai and Dekel, Tali},
        journal={arXiv preprint arxiv:2307.10373},
        year={2023}
        }

tokenflow's People

Contributors

duongna21, michalgeyer, omerbt


tokenflow's Issues

ValueError: attempt to get argmax of an empty sequence

Any suggestions on how to fix this issue?
I am using the defaults provided with the repository; I only updated the following in \TokenFlow\configs\config_pnp.yaml:

batch_size: 8
to
batch_size: 1

data_path: 'data/woman-running'
to
data_path: 'data/woman-running.mp4'

(tokenflow) C:\tut\TokenFlow>python run_tokenflow_pnp.py --config_path "configs/config_pnp.yaml"
A matching Triton is not available, some optimizations will not be enabled
Traceback (most recent call last):
  File "C:\Users\nitin\miniconda3\envs\tokenflow\lib\site-packages\xformers\__init__.py", line 55, in _is_triton_available
    from xformers.triton.softmax import softmax as triton_softmax  # noqa
  File "C:\Users\nitin\miniconda3\envs\tokenflow\lib\site-packages\xformers\triton\softmax.py", line 11, in <module>
    import triton
ModuleNotFoundError: No module named 'triton'
{'seed': 1, 'device': 'cuda', 'output_path': 'tokenflow-results_pnp_SD_2.1\\woman-running\\a marble sculpture of a woman running, Venus de Milo\\attn_0.5_f_0.8\\batch_size_1\\50', 'data_path': 'data/woman-running.mp4', 'latents_path': 'latents', 'n_inversion_steps': 500, 'n_frames': 40, 'sd_version': '2.1', 'guidance_scale': 7.5, 'n_timesteps': 50, 'prompt': 'a marble sculpture of a woman running, Venus de Milo', 'negative_prompt': 'ugly, blurry, low res, unrealistic, unaesthetic', 'batch_size': 1, 'pnp_attn_t': 0.5, 'pnp_f_t': 0.8}
{'seed': 1, 'device': 'cuda', 'output_path': 'tokenflow-results_pnp_SD_2.1\\woman-running\\a marble sculpture of a woman running, Venus de Milo\\attn_0.5_f_0.8\\batch_size_1\\50', 'data_path': 'data/woman-running.mp4', 'latents_path': 'latents', 'n_inversion_steps': 500, 'n_frames': 40, 'sd_version': '2.1', 'guidance_scale': 7.5, 'n_timesteps': 50, 'prompt': 'a marble sculpture of a woman running, Venus de Milo', 'negative_prompt': 'ugly, blurry, low res, unrealistic, unaesthetic', 'batch_size': 1, 'pnp_attn_t': 0.5, 'pnp_f_t': 0.8}
Loading SD model
model_index.json: 100%|████████████████████████████████████████████████████████████████| 543/543 [00:00<00:00, 673kB/s]
tokenizer/special_tokens_map.json: 100%|██████████████████████████████████████████████| 460/460 [00:00<00:00, 10.6kB/s]
tokenizer/tokenizer_config.json: 100%|████████████████████████████████████████████████| 807/807 [00:00<00:00, 50.9kB/s]
(…)ature_extractor/preprocessor_config.json: 100%|████████████████████████████████████| 342/342 [00:00<00:00, 57.9kB/s]
unet/config.json: 100%|███████████████████████████████████████████████████████████████| 911/911 [00:00<00:00, 83.5kB/s]
tokenizer/merges.txt: 100%|██████████████████████████████████████████████████████████| 525k/525k [00:00<00:00, 822kB/s]
scheduler/scheduler_config.json: 100%|████████████████████████████████████████████████| 346/346 [00:00<00:00, 33.6kB/s]
text_encoder/config.json: 100%|███████████████████████████████████████████████████████| 613/613 [00:00<00:00, 40.8kB/s]
vae/config.json: 100%|█████████████████████████████████████████████████████████████████| 553/553 [00:00<00:00, 217kB/s]
tokenizer/vocab.json: 100%|████████████████████████████████████████████████████████| 1.06M/1.06M [00:01<00:00, 544kB/s]
vae/diffusion_pytorch_model.safetensors: 100%|██████████████████████████████████████| 335M/335M [01:04<00:00, 5.20MB/s]
text_encoder/model.safetensors: 100%|█████████████████████████████████████████████| 1.36G/1.36G [02:50<00:00, 7.98MB/s]
unet/diffusion_pytorch_model.safetensors: 100%|███████████████████████████████████| 3.46G/3.46G [06:49<00:00, 8.45MB/s]
Fetching 13 files: 100%|███████████████████████████████████████████████████████████████| 13/13 [06:51<00:00, 31.69s/it]
Loading pipeline components...: 100%|████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.02s/it]
SD model loadedpytorch_model.safetensors: 100%|███████████████████████████████████| 3.46G/3.46G [06:49<00:00, 12.7MB/s]
Traceback (most recent call last):
  File "C:\tut\TokenFlow\run_tokenflow_pnp.py", line 301, in <module>
    run(config)
  File "C:\tut\TokenFlow\run_tokenflow_pnp.py", line 279, in run
    editor = TokenFlow(config)
  File "C:\tut\TokenFlow\run_tokenflow_pnp.py", line 60, in __init__
    self.latents_path = self.get_latents_path()
  File "C:\tut\TokenFlow\run_tokenflow_pnp.py", line 119, in get_latents_path
    latents_path = latents_path[np.argmax(n_frames)]
  File "<__array_function__ internals>", line 200, in argmax
  File "C:\Users\nitin\miniconda3\envs\tokenflow\lib\site-packages\numpy\core\fromnumeric.py", line 1242, in argmax
    return _wrapfunc(a, 'argmax', axis=axis, out=out, **kwds)
  File "C:\Users\nitin\miniconda3\envs\tokenflow\lib\site-packages\numpy\core\fromnumeric.py", line 54, in _wrapfunc
    return _wrapit(obj, method, *args, **kwds)
  File "C:\Users\nitin\miniconda3\envs\tokenflow\lib\site-packages\numpy\core\fromnumeric.py", line 43, in _wrapit
    result = getattr(asarray(obj), method)(*args, **kwds)
ValueError: attempt to get argmax of an empty sequence
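A likely cause, judging from the traceback: get_latents_path() looks for latents folders matching the video name under the configured latents_path, and the search comes up empty, e.g. when data_path or latents_path in the config does not match what preprocess.py wrote. A quick sanity check (the folder layout is an assumption based on paths seen in other issues; adjust to your setup):

import glob, os

latents_root = 'latents'                                  # 'latents_path' in the config
video_name = os.path.splitext(os.path.basename('data/woman-running.mp4'))[0]

# preprocess.py writes into a per-video subfolder, e.g. latents/sd_<...>/<video_name>
candidates = glob.glob(os.path.join(latents_root, '*', video_name))
print('matching latents folders:', candidates)            # an empty list reproduces the argmax error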

VRAM Usage relative to n_frames

Hi,

I am relatively new to the AI space so apologies if I am missing key information.

I've noticed that I am able to fully process videos that are under a certain number of n_frames.

For example, I can successfully process 30 frames of a video; however, the more I increase n_frames, the more VRAM is required.

Does this mean that all frames need to be loaded into VRAM? Is it true that more frames require more VRAM, or are there optimizations needed for longer videos?

What is the code of 'NN field compute & warp' ?

Hi, thank you for your nice work. Is the 'NN field compute & warp' part of the code the following?

def prepare_depth_maps(self, model_type='DPT_Large', device='cuda'):
    depth_maps = []
    midas = torch.hub.load("intel-isl/MiDaS", model_type)
    midas.to(device)
    midas.eval()

    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")

    if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
        transform = midas_transforms.dpt_transform
    else:
        transform = midas_transforms.small_transform

    for i in range(len(self.paths)):
        img = cv2.imread(self.paths[i])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        latent_h = img.shape[0] // 8
        latent_w = img.shape[1] // 8
        
        input_batch = transform(img).to(device)
        prediction = midas(input_batch)

        depth_map = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=(latent_h, latent_w),
            mode="bicubic",
            align_corners=False,
        )
        depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
        depth_maps.append(depth_map)

    return torch.cat(depth_maps).to(self.device).to(torch.float16)
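The quoted function computes per-frame MiDaS depth maps (for depth-conditioned models). For the NN-field idea itself, here is a minimal sketch of nearest-neighbour matching and feature propagation as described in the abstract; this is an illustration, not the repository's implementation, and shapes/names are assumptions:

import torch
import torch.nn.functional as F

def nn_field(src_feats, key_feats):
    # src_feats, key_feats: (N, C) flattened diffusion-feature tokens of one frame each.
    src = F.normalize(src_feats, dim=-1)
    key = F.normalize(key_feats, dim=-1)
    sim = src @ key.t()            # (N_src, N_key) cosine similarities
    return sim.argmax(dim=-1)      # index of the nearest key-frame token per source token

def propagate(edited_key_feats, nn_idx):
    # Warp: replace each source token with the edited feature of its nearest key-frame token.
    return edited_key_feats[nn_idx]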

"ValueError: max() arg is an empty sequence" when trying to run via jupyterlab env

I'm trying to run TokenFlow via a JupyterLab environment.

On the last step, when running run_tokenflow_pnp.py, I get this error:

Traceback (most recent call last):
  File "/home/jovyan/token-flow/run_tokenflow_pnp.py", line 301, in <module>
    run(config)
  File "/home/jovyan/token-flow/run_tokenflow_pnp.py", line 280, in run
    editor = TokenFlow(config)
  File "/home/jovyan/token-flow/run_tokenflow_pnp.py", line 62, in __init__
    self.paths, self.frames, self.latents, self.eps = self.get_data()
  File "/home/jovyan/token-flow/run_tokenflow_pnp.py", line 183, in get_data
    eps = self.get_ddim_eps(latents, range(self.config["n_frames"])).to(torch.float16).to(self.device)
  File "/home/jovyan/token-flow/run_tokenflow_pnp.py", line 187, in get_ddim_eps
    noisest = max([int(x.split('_')[-1].split('.')[0]) for x in glob.glob(os.path.join(self.latents_path, f'noisy_latents_*.pt'))])
ValueError: max() arg is an empty sequence
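As with the argmax error above, the glob is returning nothing: get_ddim_eps() expects the noisy_latents_*.pt files saved by preprocess.py under the configured latents path. A quick check (the search root is a hedged example; adjust to your config):

import glob
# Search recursively under the save_dir used by preprocess.py.
print(glob.glob('latents/**/noisy_latents_*.pt', recursive=True))   # an empty list reproduces the error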

Requirements?

This looks so promising, I just don't wanna promise myself something I won't be able to afford :-}

What is noisy video Jt

Hi, cool work! In this picture, how do we get Jt? I guess it is the result of per-frame editing, i.e. image-to-image editing with the prompt "colorful painting"? Am I right?

Random images if we use different SD version

@MichalGeyer @duongna21 @omerbt If I use any SD version other than 1.5, 2.0-base, or 2.1-base, I get random images after preprocessing. I used 2.1-unclip with image conditioning and got these results after preprocessing. I also tried SD 2.1 (not base) and still got similar output. May I know what is causing this kind of output?

Code Release

Hi,
This work is very interesting!
Are you planning to release the code? If yes, what is the approximate timeline for that?
Additionally, what is the maximum allowed video length for this method?
Thanks a lot!

confusions between reshape_heads_to_batch_dim and heads_to_batch_dim

There are some confusing parts in tokenflow_utils.py around the usage of these two pairs:
1. reshape_heads_to_batch_dim and head_to_batch_dim
2. reshape_batch_dim_to_heads and batch_dim_to_head

For example, in tokenflow_utils.py, head_to_batch_dim appears in two blocks (at lines 140 and 241, respectively).

To run the PnP example successfully, I added a line before the block at line 241:
self.head_to_batch_dim = self.reshape_heads_to_batch_dim
and it works. But to run the SDEdit example successfully, I need to add this line in a different place (in the block at line 140).

Same for the pair {reshape_batch_dim_to_heads, batch_dim_to_head}. Is there a principled way to tackle this?
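These name pairs come from different diffusers versions exposing the same reshape helpers under different names. A hedged, general-purpose shim, mirroring the manual fix above (apply it to each attention module before registering the TokenFlow forwards):

def alias_attention_helpers(attn):
    # Older diffusers use reshape_heads_to_batch_dim / reshape_batch_dim_to_heads;
    # newer ones use head_to_batch_dim / batch_to_head_dim. Alias whichever is missing.
    if not hasattr(attn, 'head_to_batch_dim') and hasattr(attn, 'reshape_heads_to_batch_dim'):
        attn.head_to_batch_dim = attn.reshape_heads_to_batch_dim
    if not hasattr(attn, 'batch_to_head_dim') and hasattr(attn, 'reshape_batch_dim_to_heads'):
        attn.batch_to_head_dim = attn.reshape_batch_dim_to_heads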

In-painting

Have you tested this for the in-painting task like StableDiffusion?

Is it all code released?

Hi, thanks for great work!

However, I can't find 'compute nn fields' or 'tokenflow propagation'; it just looks like PnP is being used instead.
Is this repo not a full implementation of the original TokenFlow method?

batching pivots allows processing bigger/longer sequences

For those who (like me) wanted to apply this exciting technique to longer videos:
I've integrated this method into my SD repo https://github.com/eps696/SDfu and added batching for pivots, with offloading onto CPU. This allowed processing e.g. 300 frames at 960x540 on a 3090 (24 GB).
As I renamed some variables for my convenience, my code is not directly copy-pastable into this repo, but I hope it's readable enough to apply here. The solution is also pretty clumsy, as I had very little idea about the attention internals and just tried to debug OOMs.

Output is same as input file, why?

(tokenflow) C:\tut\TokenFlow>python preprocess.py --data_path data/woman-running.mp4 --inversion_prompt "a silver sculpture of a woman running"
A matching Triton is not available, some optimizations will not be enabled
Traceback (most recent call last):
  File "C:\Users\nitin\miniconda3\envs\tokenflow\lib\site-packages\xformers\__init__.py", line 55, in _is_triton_available
    from xformers.triton.softmax import softmax as triton_softmax  # noqa
  File "C:\Users\nitin\miniconda3\envs\tokenflow\lib\site-packages\xformers\triton\softmax.py", line 11, in <module>
    import triton
ModuleNotFoundError: No module named 'triton'
C:\Users\nitin\miniconda3\envs\tokenflow\lib\site-packages\torchvision\io\video.py:161: UserWarning: The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.
  warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
[INFO] loading stable diffusion...
C:\Users\nitin\miniconda3\envs\tokenflow\lib\site-packages\diffusers\models\attention_processor.py:1117: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:455.)
  hidden_states = F.scaled_dot_product_attention(
[INFO] loaded stable diffusion!
100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [18:02<00:00,  2.16s/it]
100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [18:03<00:00,  2.17s/it]

pip list


absl-py                 2.1.0
accelerate              0.29.3
av                      12.0.0
Brotli                  1.0.9
certifi                 2024.2.2
chardet                 4.0.0
charset-normalizer      2.0.4
colorama                0.4.6
diffusers               0.20.0
filelock                3.13.1
fsspec                  2024.3.1
ftfy                    6.2.0
gmpy2                   2.1.2
grpcio                  1.62.2
huggingface-hub         0.22.2
idna                    3.7
importlib_metadata      7.1.0
intel-openmp            2021.4.0
Jinja2                  3.1.3
kornia                  0.7.2
kornia_rs               0.1.3
Markdown                3.6
MarkupSafe              2.1.3
mkl                     2021.4.0
mkl-fft                 1.3.1
mkl-random              1.2.2
mkl-service             2.4.0
mpmath                  1.3.0
networkx                3.1
numpy                   1.24.3
opencv-python           4.9.0.80
packaging               24.0
pillow                  10.3.0
pip                     23.3.1
protobuf                5.26.1
psutil                  5.9.8
PySocks                 1.7.1
PyYAML                  6.0.1
regex                   2024.4.28
requests                2.31.0
safetensors             0.4.3
setuptools              68.2.2
six                     1.16.0
sympy                   1.12
tbb                     2021.12.0
tensorboard             2.16.2
tensorboard-data-server 0.7.2
tokenizers              0.19.1
torch                   2.3.0+cu121
torchvision             0.18.0
tqdm                    4.66.2
transformers            4.40.1
typing_extensions       4.11.0
urllib3                 2.1.0
wcwidth                 0.2.13
Werkzeug                3.0.2
wheel                   0.41.2
win-inet-pton           1.1.0
xformers                0.0.26.post1
zipp                    3.18.1

Output
https://github.com/omerbt/TokenFlow/assets/2102186/92f3cb7d-67f8-48a2-bccb-abe062384af8

not compatible with diffusers 0.21+ [with workaround]

Everything runs OK on diffusers 0.20 or below, while I get this error on diffusers 0.21:

File "F:\_neuro\SDfu\lib\tokenflow.py", line 185, in denoise_step
  noise_pred = self.unet(lat_in, t, conds).sample
File "C:\Users\eps\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "C:\Users\eps\AppData\Local\Programs\Python\Python39\lib\site-packages\diffusers\models\unet_2d_condition.py", line 1018, in forward
  sample = upsample_block(
File "C:\Users\eps\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "C:\Users\eps\AppData\Local\Programs\Python\Python39\lib\site-packages\diffusers\models\unet_2d_blocks.py", line 2227, in forward
  hidden_states = resnet(hidden_states, temb, scale=lora_scale)
File "C:\Users\eps\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'scale'

I've got torch 2.0.1 and xformers 0.0.21 (but again, it's only the diffusers version that determines whether this error appears).
UPD: the error only occurs for the PnP method; SDEdit works OK.

There were some similar issues on the diffusers GitHub; maybe they help:
huggingface/diffusers#3348
huggingface/diffusers#5028
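Until the attention wrappers are updated for the newer UNet block signatures, a simple workaround (consistent with the working environment reported above) is to pin diffusers:

pip install "diffusers==0.20.0"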

A problem about the code, thanks

It seems that you change all the BasicTransformerBlocks in the down_blocks, mid_block, and up_blocks. Why change the up_blocks in the UNet again?

def register_extended_attention(model):
    for _, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.attn1.forward = sa_forward(module.attn1)

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)

what is the correct way to run demo?

python preprocess.py --data_path data/woman-running.mp4 --inversion_prompt "a marble sculpture of a woman running, Venus de Milossets"

the "latents" folder is created , but the video file "inverted.mp4" is not changed .

then I run "python run_tokenflow_pnp.py"

Traceback (most recent call last):
  File "/home/xxx/TokenFlow/run_tokenflow_pnp.py", line 300, in <module>
    run(config)
  File "/home/xxx/TokenFlow/run_tokenflow_pnp.py", line 280, in run
    editor.edit_video()
  File "/home/xxx/TokenFlow/run_tokenflow_pnp.py", line 258, in edit_video
    edited_frames = self.sample_loop(noisy_latents, torch.arange(self.config["n_frames"]))
  File "/home/xxx/TokenFlow/run_tokenflow_pnp.py", line 267, in sample_loop
    x = self.batched_denoise_step(x, t, indices)
  File "/home/xxx/anaconda3/envs/tokenflow/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/xxx/TokenFlow/run_tokenflow_pnp.py", line 227, in batched_denoise_step
    self.denoise_step(x[pivotal_idx], t, indices[pivotal_idx])
  File "/home/xxx/anaconda3/envs/tokenflow/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/xxx/TokenFlow/run_tokenflow_pnp.py", line 210, in denoise_step
    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input)['sample']
  File "/home/xxx/anaconda3/envs/tokenflow/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xxx/anaconda3/envs/tokenflow/lib/python3.9/site-packages/diffusers/models/unet_2d_condition.py", line 1018, in forward
    sample = upsample_block(
  File "/home/xxx/anaconda3/envs/tokenflow/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xxx/anaconda3/envs/tokenflow/lib/python3.9/site-packages/diffusers/models/unet_2d_blocks.py", line 2227, in forward
    hidden_states = resnet(hidden_states, temb, scale=lora_scale)
  File "/home/xxx/anaconda3/envs/tokenflow/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'scale'

could you help?

Tried to build a colab notebook, but got "ValueError: attempt to get argmax of an empty sequence" when running "!python /content/TokenFlow/run_tokenflow_pnp.py"

I tried to build the following notebook:
Notebook_colab_.txt

But even though my preprocessed data is correctly in "/content/TokenFlow/data/test5" (the frames) and in "/content/TokenFlow/latents/sd_ControlNet/test5" (frames, latents, etc.), I still get the error when trying both paths. What is wrong?

RuntimeError: CUDA error: out of memory, but this may not be a memory problem

When I run python run_tokenflow_pnp.py:
Traceback (most recent call last):
  File "/mnt/disk_1/kaizhou/TokenFlow/run_tokenflow_pnp.py", line 300, in <module>
    run(config)
  File "/mnt/disk_1/kaizhou/TokenFlow/run_tokenflow_pnp.py", line 279, in run
    editor = TokenFlow(config)
  File "/mnt/disk_1/kaizhou/TokenFlow/run_tokenflow_pnp.py", line 62, in __init__
    self.paths, self.frames, self.latents, self.eps = self.get_data()
  File "/mnt/disk_1/kaizhou/TokenFlow/run_tokenflow_pnp.py", line 183, in get_data
    eps = self.get_ddim_eps(latents, range(self.config["n_frames"])).to(torch.float16).to(self.device)
  File "/mnt/disk_1/kaizhou/TokenFlow/run_tokenflow_pnp.py", line 189, in get_ddim_eps
    noisy_latent = torch.load(latents_path)[indices].to(self.device)

The content of config_pnp.yaml is as follows (and cuda:5 is unoccupied):

seed: 1
device: 'cuda:5'
output_path: 'tokenflow-results'

data_path: 'data/wolf'
latents_path: 'latents' # should be the same as 'save_dir' arg used in preprocess
n_inversion_steps: 500 # for retrieving the latents of the inversion
n_frames: 40

sd_version: '1.5'
guidance_scale: 7.5
n_timesteps: 50
prompt: "A robotic wolf"
negative_prompt: "ugly, blurry, low res, unrealistic, unaesthetic"
batch_size: 1

pnp_attn_t: 0.5
pnp_f_t: 0.8

SD XL Integration

Is it possible to integrate the latest SDXL into the Stable Diffusion options?

ControlNet Usage

The documentation says:

Similarly, if you want to use ControlNet or SDEedit, create a yaml config as in config/config_controlnet.yaml

What should go in config/config_controlnet.yaml? I tried to figure it out by looking at the code, but only saw the ControlNet being used during preprocessing.

Required GPU memory depends on the video length.

I've managed to run run_tokenflow_pnp.py for a small excerpt of my video (5s) - and it looks really cool - but when I run it on the full one (5min) it crashes with CUDA OOM error even when I drop the batch size down to 1.

This scaling with video length, probably caused by the extended attention, seems like a major limitation of the method and is not highlighted in the discussion section or anywhere else in the paper (as far as I can tell).

Is it possible to offload part of the attention computation to the CPU so that the number of frames is not a bottleneck?

Script for Warp-error metric.

Hi,
May I ask for the script used to compute the warp-error metric? Or is there a code base I can refer to?
Thanks a lot.
Best,
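A minimal sketch of one common way to compute a warp error, assuming OpenCV's Farneback optical flow as a stand-in for the flow estimator (the paper's exact implementation is not shown here): each frame t+1 is warped back to frame t using the estimated flow, and the per-pixel MSE is averaged over consecutive pairs.

import cv2
import numpy as np

def warp_error(frames):
    # frames: list of HxWx3 uint8 RGB arrays
    errs = []
    for prev, nxt in zip(frames[:-1], frames[1:]):
        g0 = cv2.cvtColor(prev, cv2.COLOR_RGB2GRAY)
        g1 = cv2.cvtColor(nxt, cv2.COLOR_RGB2GRAY)
        # Dense flow from frame t to frame t+1
        flow = cv2.calcOpticalFlowFarneback(g0, g1, None, 0.5, 3, 15, 3, 5, 1.2, 0)
        h, w = g0.shape
        grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
        map_x = (grid_x + flow[..., 0]).astype(np.float32)
        map_y = (grid_y + flow[..., 1]).astype(np.float32)
        # Warp frame t+1 back toward frame t and compare
        warped = cv2.remap(nxt, map_x, map_y, cv2.INTER_LINEAR)
        errs.append(np.mean((warped.astype(np.float32) - prev.astype(np.float32)) ** 2))
    return float(np.mean(errs))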

Code snippet to reduce VRAM usage when too many frames to process.

Based on #20, I've modified the code to reduce VRAM usage during processing.

Usage:

Replace the register_extended_attention_pnp() function in tokenflow_utils.py with the code snippet below.


def register_extended_attention_pnp(model, injection_schedule):
    def sa_forward(self):
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]
        else:
            to_out = self.to_out
            
        def forward_original(q, k, v):
            n_frames, seq_len, dim = q.shape
            h = self.heads
            head_dim = dim // h
            
            q = self.head_to_batch_dim(q).reshape(n_frames, h, seq_len, head_dim)
            k = self.head_to_batch_dim(k).reshape(n_frames, h, seq_len, head_dim)
            v = self.head_to_batch_dim(v).reshape(n_frames, h, seq_len, head_dim)

            out_all = []
            
            for frame in range(n_frames):
                out = []
                for j in range(h):
                    sim = torch.matmul(q[frame, j], k[frame, j].transpose(-1, -2)) * self.scale # (seq_len, seq_len)                                            
                    out.append(torch.matmul(sim.softmax(dim=-1), v[frame, j])) # h * (seq_len, head_dim)

                out = torch.cat(out, dim=0).reshape(-1, seq_len, head_dim) # (h, seq_len, head_dim)
                out_all.append(out) # n_frames * (h, seq_len, head_dim)
            
            out = torch.cat(out_all, dim=0) # (n_frames * h, seq_len, head_dim)
            out = self.batch_to_head_dim(out) # (n_frames, seq_len, h * head_dim)
            return out
            
        def forward_extended(q, k, v):
            n_frames, seq_len, dim = q.shape
            h = self.heads
            head_dim = dim // h
            
            q = self.head_to_batch_dim(q).reshape(n_frames, h, seq_len, head_dim)
            k = self.head_to_batch_dim(k).reshape(n_frames, h, seq_len, head_dim)
            v = self.head_to_batch_dim(v).reshape(n_frames, h, seq_len, head_dim)

            out_all = []
            window_size = 3
            
            for frame in range(n_frames):
                out = []
                # sliding window to improve speed.
                window = range(max(0, frame-window_size // 2), min(n_frames, frame+window_size//2+1))
                
                for j in range(h):
                    sim_all = []
                    
                    for kframe in window:
                        sim_all.append(torch.matmul(q[frame, j], k[kframe, j].transpose(-1, -2)) * self.scale) # window * (seq_len, seq_len)
                        
                    sim_all = torch.cat(sim_all).reshape(len(window), seq_len, seq_len).transpose(0, 1) # (seq_len, window, seq_len)
                    sim_all = sim_all.reshape(seq_len, len(window) * seq_len) # (seq_len, window * seq_len)
                    out.append(torch.matmul(sim_all.softmax(dim=-1), v[window, j].reshape(len(window) * seq_len, head_dim))) # h * (seq_len, head_dim)

                out = torch.cat(out, dim=0).reshape(-1, seq_len, head_dim) # (h, seq_len, head_dim)
                out_all.append(out) # n_frames * (h, seq_len, head_dim)
            
            out = torch.cat(out_all, dim=0) # (n_frames * h, seq_len, head_dim)
            out = self.batch_to_head_dim(out) # (n_frames, seq_len, h * head_dim)

            return out
            

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            n_frames = batch_size // 3
            
            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x
            q = self.to_q(x)
            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                # inject unconditional
                q[n_frames:2 * n_frames] = q[:n_frames]
                k[n_frames:2 * n_frames] = k[:n_frames]
                # inject conditional
                q[2 * n_frames:] = q[:n_frames]
                k[2 * n_frames:] = k[:n_frames]

            out_source = forward_original(q[:n_frames], k[:n_frames], v[:n_frames])
            out_uncond = forward_extended(q[n_frames:2 * n_frames], k[n_frames:2 * n_frames], v[n_frames:2 * n_frames])
            out_cond = forward_extended(q[2 * n_frames:], k[2 * n_frames:], v[2 * n_frames:])
                            
            out = torch.cat([out_source, out_uncond, out_cond], dim=0) # (3 * n_frames, seq_len, dim)

            return to_out(out)

        return forward

    for _, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.attn1.forward = sa_forward(module.attn1)
            setattr(module.attn1, 'injection_schedule', [])

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            setattr(module, 'injection_schedule', injection_schedule)

Note

The code slightly modifies the extended attention method from the paper: self-attention is extended across only 3 consecutive key frames instead of all key frames.

Missing License

Hi and thank you for sharing your code.

Could you add a license file to the repo?
Thanks,
Best,
D

Add missing requirements

I needed to pip install these libraries:
torchvision
av
kornia

These should be added to requirements.txt
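Until requirements.txt is updated, they can be installed manually:

pip install torchvision av kornia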

Is only the last layer of the edited frame processed?

Thanks for your nice work! I have two questions.
The first question: the paper mentions that each layer of the key frames is processed. So, when editing the original video frames, is every layer also processed, or only the last layer? The second question: I understand that the processing of video frames is carried out step by step, with the result of one step used as input to the next. So, according to my understanding of the paper, all frames should be processed at each step. Is that right?

I look forward to your reply. Thank you again.
