def sds_loss(self, pred_depth, rgb_in, bs, view_num, guidance_scale=100, as_latent=False, grad_scale=1,
             save_guidance_path=None):
    """
    Score Distillation Sampling (SDS) loss on the predicted depth.

    pred_depth: the predicted depth, normalized to [-1, 1] (nearer is smaller),
        of size (bs * view_num, 1, h, w)
    rgb_in: the conditioning image, normalized to [-1, 1], of size (bs * view_num, 3, h, w)
    bs: batch size, i.e. the number of scenes
    view_num: number of views per scene; all views of a scene share one timestep
    guidance_scale: classifier-free guidance weight
    grad_scale: scaling factor applied to the SDS gradient
    as_latent, save_guidance_path: accepted for interface compatibility, unused here
    """
    # Lazily cache the cumulative alpha schedule on first use.
    if self.alphas is None:
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)
    device = self.device

    # Encode the image and the depth into latent space; the single-channel
    # depth map is repeated to 3 channels so the RGB encoder accepts it.
    pred_depth = pred_depth.repeat(1, 3, 1, 1)
    rgb_latent = self.encode_rgb(rgb_in)
    depth_latent = self.encode_rgb(pred_depth)
    # Sample one diffusion timestep per scene, then repeat it so that all
    # view_num views of the same scene share that timestep.
    t = torch.randint(self.min_step, self.max_step + 1, (bs,), dtype=torch.long,
                      device=self.device)
    t = t.unsqueeze(-1).repeat(1, view_num).view(-1)  # (bs * view_num,)
    with torch.no_grad():
        # Sample Gaussian noise and diffuse the depth latent to timestep t.
        latent_noise = torch.randn(
            rgb_latent.shape,
            device=device,
            dtype=self.dtype,
            generator=None,
        )  # [B, 4, h, w]
        latents_noisy = self.scheduler.add_noise(depth_latent, latent_noise, t)
        # Predict the noise. For classifier-free guidance the unconditional
        # branch zeroes out the RGB latent; both branches are stacked into a
        # single batch so the UNet runs one forward pass.
        uncond_latent = torch.cat([torch.zeros_like(rgb_latent), latents_noisy], dim=1)
        cond_latent = torch.cat([rgb_latent, latents_noisy], dim=1)
        latent_model_input = torch.cat([uncond_latent, cond_latent], dim=0)
        tt = torch.cat([t] * 2)
        # Batched empty text embedding (the prompt is always empty; the
        # conditioning comes from the RGB latent instead).
        if self.empty_text_embed is None:
            self.encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (latent_model_input.shape[0], 1, 1)
        ).to(device)  # [2B, 2, 1024]
        noise_pred = self.unet(
            latent_model_input, tt, encoder_hidden_states=batch_empty_text_embed
        ).sample  # [2B, 4, h, w]
        # Perform classifier-free guidance (high scale, as in the paper)
        noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)
    # SDS weighting w(t) = 1 - alpha_bar_t, i.e. sigma_t^2
    w = (1 - self.alphas[t])
    grad = grad_scale * w[:, None, None, None] * (noise_pred - latent_noise)
    grad = torch.nan_to_num(grad)
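    # Surrogate-objective trick: `targets` below is detached, so
    # d(loss)/d(depth_latent) = (depth_latent - targets) / B = grad / B.
    # Backpropagating this MSE therefore injects exactly `grad` into the
    # depth latent without differentiating through the UNet.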
    targets = (depth_latent - grad).detach()
    loss = 0.5 * F.mse_loss(depth_latent.float(), targets, reduction='sum') / depth_latent.shape[0]
    return loss
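
# Minimal usage sketch (illustrative only): `guidance` is a hypothetical
# instance of the surrounding class, with its scheduler, UNet, latent encoder
# and min_step/max_step already set up; shapes follow the docstring above.
#
#   import torch
#
#   bs, view_num, h, w = 2, 4, 384, 384
#   rgb_in = torch.rand(bs * view_num, 3, h, w, device=guidance.device) * 2 - 1
#   # pred_depth comes from the model being optimized, so it carries gradients
#   pred_depth = torch.rand(bs * view_num, 1, h, w, device=guidance.device) * 2 - 1
#   pred_depth.requires_grad_(True)
#
#   loss = guidance.sds_loss(pred_depth, rgb_in, bs, view_num, guidance_scale=100)
#   loss.backward()  # pred_depth.grad now holds the (scaled) SDS gradient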