
Comments (8)

skywalker00001 commented on August 26, 2024

If you just discard some frames and keep the others during training, it's like MAE: the model learns to generate the next frame from corrupted previous frames. Say we are generating frame 3 and your classifier-free index lands on frame 1; then the stable diffusion model is trying to learn how to denoise from the (CLIP[3] + BLIP[0, 2]) context. But this method doesn't align with how you do sampling.


xichenpan commented on August 26, 2024

Hi @skywalker00001, thank you for your comment. Actually, you can treat our batch size as B * V, which means we generate every single frame according to the previous frames in its story. We do drop all CLIP and BLIP conditions at the classifier_free_id frames. For the other frames, the conditions are not corrupted, as you can see we save a copy in

ARLDM/main.py

Line 200 in b8c1db4

source_embeddings = source_embeddings.repeat_interleave(V, dim=0)

You can debug our code to better understand the shapes; they are quite confusing.
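To make the shapes concrete, here is a minimal sketch of the repeat_interleave bookkeeping, with hypothetical sizes (B stories per batch, V frames per story, L context tokens, D embedding dim; not the repo's actual values):

import torch

B, V, L, D = 2, 5, 77, 768

# One text/image context per story.
source_embeddings = torch.randn(B, L, D)

# repeat_interleave(V, dim=0) copies each story's context V times, so
# every frame becomes its own sample and the effective batch is B * V.
source_embeddings = source_embeddings.repeat_interleave(V, dim=0)
print(source_embeddings.shape)  # torch.Size([10, 77, 768])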

And the code

ARLDM/main.py

Line 355 in b8c1db4

noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

you mentioned actually follows the implementation in Diffusers:

https://github.com/huggingface/diffusers/blob/b9b891621e8ed5729761cc6a31b23072315d2df0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L686

Actually, with $scale = 1 + w$, the two formulas are the same.
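Spelling it out, with $\epsilon_c$ the conditional and $\epsilon_u$ the unconditional noise prediction, the Diffusers form with $scale = 1 + w$ expands to the paper's form:

$\epsilon_u + (1 + w)(\epsilon_c - \epsilon_u) = (1 + w)\,\epsilon_c - w\,\epsilon_u$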


skywalker00001 commented on August 26, 2024

Hi @xichenpan, thank you for your answer!
I see. Does that mean you actually trained 4 extra unconditional models (let's say our task is continuation), one for each frame?
unconditional model 1: [null CLIP + null BLIP frame 0]
unconditional model 2: [null CLIP + null BLIP frame 0 + null BLIP frame 1]
unconditional model 3: [null CLIP + null BLIP frame 0 + null BLIP frame 1 + null BLIP frame 2]
unconditional model 4: [null CLIP + null BLIP frame 0 + null BLIP frame 1 + null BLIP frame 2 + null BLIP frame 3]
And you use attention mask = 1 to mask out the future frames in each model.
Is my interpretation right?

And for the second question, yeah, you are right, the two formulas are the same.
Thanks again for your patience!


skywalker00001 commented on August 26, 2024

Besides, I encountered another problem.
How do you calculate the FID score?
I debugged and printed the shapes of "original_images" and "images":

ARLDM/main.py

Lines 313 to 323 in b8c1db4

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    original_images, images = self.sample(batch)
    if self.args.calculate_fid:
        original_images = original_images.cpu().numpy().astype('uint8')
        original_images = [Image.fromarray(im, 'RGB') for im in original_images]
        ori = self.inception_feature(original_images).cpu().numpy()
        gen = self.inception_feature(images).cpu().numpy()
    else:
        ori = None
        gen = None
    return images, ori, gen

original_images has the shape (4, 3, 128, 128), but don't we need to permute it before turning it into a PIL image?
Because "original_images" now yields PIL images of size 128 × 3, while "images" yields PIL images of size 512 × 512 (the stable diffusion output size).

I know in the code

ARLDM/main.py

Lines 381 to 390 in b8c1db4

def inception_feature(self, images):
    images = torch.stack([self.fid_augment(image) for image in images])
    images = images.type(torch.FloatTensor).to(self.device)
    images = (images + 1) / 2
    images = F.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)
    pred = self.inception(images)[0]
    if pred.shape[2] != 1 or pred.shape[3] != 1:
        pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))
    return pred.reshape(-1, 2048)
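For context, the 2048-d features returned here are what a standard Fréchet distance is computed from. A minimal sketch of that final step, assuming the usual FID definition (this helper is hypothetical, not from the repo):

import numpy as np
from scipy import linalg

def fid_from_features(ori, gen):
    # Fit a Gaussian to each set of Inception features and compute
    # the Frechet distance between the two Gaussians.
    mu1, mu2 = ori.mean(axis=0), gen.mean(axis=0)
    sigma1 = np.cov(ori, rowvar=False)
    sigma2 = np.cov(gen, rowvar=False)
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)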

Line 382 will turn the images into (3, 64, 64), but I suspect the values will be wrong.

For example, applying transforms.ToTensor()(images[0]) to "original_images" gives the shape (3, 3, 128).

And when I run [transforms.ToPILImage()(self.fid_augment(im)).save("faked_pics/ori_{:02}.png".format(idx)) for idx, im in enumerate(images)], the saved images are all corrupted for "original_images". For the generated "images" there is no error: transforms.ToTensor()(images[0]) has the shape (3, 512, 512) and PIL saves the correct image.


Therefore, I suppose it should be

original_images = original_images.permute(0, 2, 3, 1).cpu().numpy().astype('uint8')

at line 316?

ARLDM/main.py

Line 316 in b8c1db4

original_images = original_images.cpu().numpy().astype('uint8')

If that is the case, the reported FID scores may change too...
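A quick check of the layout issue with a hypothetical stand-in tensor:

import torch
from PIL import Image

# Stand-in for the BCHW tensor predict_step receives.
original_images = torch.randint(0, 256, (4, 3, 128, 128), dtype=torch.uint8)

# Without permuting, Image.fromarray reads each (3, 128, 128) array as a
# 3-row, 128-column image -- the corrupted 128 x 3 PILs observed above.
bad = Image.fromarray(original_images[0].numpy(), 'RGB')
print(bad.size)  # (128, 3)

# Permuting to BHWC first gives proper 128 x 128 RGB images.
good = Image.fromarray(original_images.permute(0, 2, 3, 1).numpy()[0], 'RGB')
print(good.size)  # (128, 128)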


skywalker00001 commented on August 26, 2024

And I think multi-GPU inference is supported by PyTorch Lightning.

I only added strategy="ddp" to the Trainer, set args.gpu_ids to [0, 1, 2, 3], and commented out the assertion at

ARLDM/main.py

Line 425 in b8c1db4

assert args.gpu_ids == 1 or len(args.gpu_ids) == 1, "Only one GPU is supported in test mode"

It succeeded, so I think it may help others. The change applies to

ARLDM/main.py

Lines 430 to 435 in b8c1db4

predictor = pl.Trainer(
    accelerator='gpu',
    devices=args.gpu_ids,
    max_epochs=-1,
    benchmark=True
)
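With those changes, the construction becomes something like this (a sketch of the modification described above; strategy='ddp' is a standard PyTorch Lightning Trainer argument):

import pytorch_lightning as pl

# args.gpu_ids set to [0, 1, 2, 3], and the single-GPU assertion at
# line 425 commented out.
predictor = pl.Trainer(
    accelerator='gpu',
    devices=[0, 1, 2, 3],
    strategy='ddp',
    max_epochs=-1,
    benchmark=True
)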


xichenpan commented on August 26, 2024

@skywalker00001 Hi, for the first issue, it is one single unconditional model (with varied context length), because all parameters are shared.
For the second issue, thanks for pointing that out. I found that in our original implementation we do not permute original_images, so it has a shape of BHWC:

ARLDM/ARLDM.py

Line 146 in a24e2e9

features['img'] = torch.from_numpy(images)

While in our current implementation we do, so it has a shape of BCHW:

images = torch.stack([self.augment(im) for im in images]) \
    if self.subset in ['train', 'val'] else torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)

This causes an inconsistency, because we copied the code from the old implementation:

ARLDM/ARLDM.py

Lines 353 to 357 in a24e2e9

original_images = input_data.img[:, 1:] if self.task == 'continuation' else input_data.img
original_images = torch.flatten(original_images, start_dim=0, end_dim=1)
original_images = original_images.cpu().numpy().astype('uint8')
original_images = [Image.fromarray(im, 'RGB') for im in original_images]
kit_output.add_original_feature_output(self.inception_feature(original_images, device))

We will remove the permute in the dataset code so that the shape is correct. I believe the FID score reported in our original paper is correct, while the current repo does not correctly migrate the original implementation. Another user reported this issue in #10 (comment); I am sorry that I forgot to correct it.
For the final issue, you can do so, but running on multiple GPUs may drop the last samples or assign the same sample to multiple GPUs, which is why we only use a single GPU for inference. If you just want results quickly, it is acceptable.
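To illustrate, the duplication comes from DistributedSampler padding the index list to a multiple of the world size; a minimal sketch, independent of our repo:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# 10 samples split across 4 ranks: the sampler pads to 12 indices,
# so samples 0 and 1 are assigned to two ranks each.
dataset = TensorDataset(torch.arange(10))
for rank in range(4):
    sampler = DistributedSampler(dataset, num_replicas=4, rank=rank, shuffle=False)
    print(rank, list(sampler))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]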


skywalker00001 commented on August 26, 2024

Got it! Thanks again.


