import numpy as np
import torch

#----------------------------------------------------------------------------
# Preconditioning corresponding to the variance exploding (VE) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

class VEPrecond(torch.nn.Module):
    """Denoiser wrapper using VE preconditioning.

    ``forward`` returns ``D(x; sigma) = x + sigma * F(x, log(sigma / 2))``,
    where ``F`` is the wrapped ``model``.
    """

    def __init__(self,
        model,                  # Wrapped network, called as model(x, c_noise, [class_labels=...], **kwargs).
        label_dim   = 0,        # Number of class labels, 0 = unconditional.
        use_fp16    = False,    # Execute the underlying model at FP16 precision?
        sigma_min   = 0.02,     # Minimum supported noise level.
        sigma_max   = 100,      # Maximum supported noise level.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.model = model

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        """Return the float32 denoised estimate of ``x`` at noise level ``sigma``.

        Args:
            x: Noisy input tensor; cast to float32. Presumably (N, C, H, W),
                since ``sigma`` is reshaped to (-1, 1, 1, 1) -- TODO confirm.
            sigma: Noise-level tensor; moved to ``x``'s device, cast to
                float32 and reshaped to (-1, 1, 1, 1) to broadcast against x.
            class_labels: Optional labels, reshaped to (-1, label_dim) and
                passed to the model as ``class_labels``. Dropped when
                ``label_dim == 0``; replaced by one all-zero row when
                ``label_dim > 0`` and no labels are given.
            force_fp32: Run the model in float32 even if ``use_fp16`` is set.
            **model_kwargs: Forwarded to the model unchanged.
        """
        x = x.to(torch.float32)
        # Put sigma on x's device once so every coefficient below broadcasts
        # against x without a cross-device error.
        sigma = sigma.to(device=x.device, dtype=torch.float32).reshape(-1, 1, 1, 1)
        if self.label_dim == 0:
            class_labels = None
        elif class_labels is None:
            class_labels = torch.zeros([1, self.label_dim], device=x.device)
        else:
            class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)
        # FP16 is only used on CUDA devices.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = 1
        c_out = sigma
        c_in = 1
        c_noise = (0.5 * sigma).log()

        if class_labels is not None:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        # Sanity check: the model must return the precision it was run at.
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        """Return ``sigma`` as a tensor; VE supports a continuous range of noise levels."""
        return torch.as_tensor(sigma)

#----------------------------------------------------------------------------
# Preconditioning corresponding to improved DDPM (iDDPM) formulation from
# the paper "Improved Denoising Diffusion Probabilistic Models".

class iDDPMPrecond(torch.nn.Module):
    """Denoiser wrapper using iDDPM preconditioning over a discrete sigma table.

    At construction a buffer ``u`` of ``M + 1`` noise levels is built backwards
    from ``alpha_bar`` (with ``u[M] = 0``). ``forward`` snaps sigma to the
    nearest table entry to derive the timestep passed to the model, and returns
    ``D(x; sigma) = x - sigma * F(x / sqrt(sigma^2 + 1), t)``.
    """

    def __init__(self,
        model,                  # Wrapped network, called as model(x, [lamb,] c_noise, **kwargs).
        label_dim   = 0,        # Number of class labels, 0 = unconditional.
        use_fp16    = False,    # Execute the underlying model at FP16 precision?
        C_1         = 0.001,    # Timestep adjustment at low noise levels.
        C_2         = 0.008,    # Timestep adjustment at high noise levels.
        M           = 1000,     # Original number of timesteps in the DDPM formulation.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.C_1 = C_1
        self.C_2 = C_2
        self.M = M
        self.model = model
        # Build the noise-level table from j = M (u[M] = 0) down to j = 0.
        # The alpha_bar ratio is clipped below at C_1 to keep the step finite.
        u = torch.zeros(M + 1)
        for j in range(M, 0, -1): # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        self.register_buffer('u', u)
        self.sigma_min = float(u[M - 1])
        self.sigma_max = float(u[0])

    def forward(self, x, sigma, class_labels=None, lamb=None, force_fp32=False, **model_kwargs):
        """Return the float32 denoised estimate of ``x`` at noise level ``sigma``.

        Args:
            x: Noisy input tensor; cast to float32. Presumably (N, C, H, W),
                since ``sigma`` is reshaped to (-1, 1, 1, 1) -- TODO confirm.
            sigma: Noise-level tensor; moved to ``x``'s device, cast to
                float32 and reshaped to (-1, 1, 1, 1). Each value is snapped
                to the nearest entry of ``u`` to compute the timestep.
            class_labels: Accepted for interface compatibility with the other
                preconditioners but NOT forwarded to the model.
            lamb: Optional extra model input; when given the model is called
                as ``model(x, lamb, c_noise, **model_kwargs)``.
            force_fp32: Run the model in float32 even if ``use_fp16`` is set.
            **model_kwargs: Forwarded to the model unchanged.
        """
        x = x.to(torch.float32)
        # Put sigma on x's device once so every coefficient below broadcasts
        # against x without a cross-device error.
        sigma = sigma.to(device=x.device, dtype=torch.float32).reshape(-1, 1, 1, 1)
        # FP16 is only used on CUDA devices.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = 1
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + 1).sqrt()
        # Timestep = M - 1 - (index of the nearest entry in u).
        c_noise = self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32)

        if lamb is not None:
            F_x = self.model((c_in * x).to(dtype), lamb, c_noise.flatten(), **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        # Sanity check: the model must return the precision it was run at.
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def alpha_bar(self, j):
        """Return ``sin(0.5 * pi * j / M / (C_2 + 1)) ** 2`` for timestep(s) ``j``."""
        j = torch.as_tensor(j)
        return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2

    def round_sigma(self, sigma, return_index=False):
        """Snap each sigma to the nearest entry of ``u``.

        Returns the snapped values (in sigma's dtype), or the int64 indices
        into ``u`` when ``return_index`` is True. The result has sigma's shape
        and device.
        """
        sigma = torch.as_tensor(sigma)
        index = torch.cdist(sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2)
        result = index if return_index else self.u[index.flatten()].to(sigma.dtype)
        return result.reshape(sigma.shape).to(sigma.device)

#----------------------------------------------------------------------------
# Improved preconditioning proposed in the paper "Elucidating the Design
# Space of Diffusion-Based Generative Models" (EDM).

class EDMPrecond(torch.nn.Module):
    """Denoiser wrapper using EDM preconditioning.

    ``forward`` returns ``D(x; sigma) = c_skip * x + c_out * F(c_in * x, log(sigma) / 4)``
    with the sigma_data-dependent coefficients computed below.
    """

    def __init__(self,
        model,                          # Wrapped network, called as model(x, c_noise, [c_latent=...], **kwargs).
        label_dim       = 0,            # Number of class labels, 0 = unconditional.
        use_fp16        = False,        # Execute the underlying model at FP16 precision?
        sigma_min       = 0,            # Minimum supported noise level.
        sigma_max       = float('inf'), # Maximum supported noise level.
        sigma_data      = 0.5,          # Expected standard deviation of the training data.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.model = model

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        """Return the float32 denoised estimate of ``x`` at noise level ``sigma``.

        Args:
            x: Noisy input tensor; cast to float32. Presumably (N, C, H, W),
                since ``sigma`` is reshaped to (-1, 1, 1, 1) -- TODO confirm.
            sigma: Noise-level tensor; moved to ``x``'s device, cast to
                float32 and reshaped to (-1, 1, 1, 1) to broadcast against x.
            class_labels: Optional conditioning, reshaped to (-1, label_dim)
                and passed to the model as ``c_latent``. Dropped when
                ``label_dim == 0``; when None the model gets no ``c_latent``.
            force_fp32: Run the model in float32 even if ``use_fp16`` is set.
            **model_kwargs: Forwarded to the model unchanged.
        """
        x = x.to(torch.float32)
        # Put sigma on x's device once: previously only c_in was moved, so
        # c_skip / c_out / c_noise could still live on a different device.
        sigma = sigma.to(device=x.device, dtype=torch.float32).reshape(-1, 1, 1, 1)
        if class_labels is not None:
            if self.label_dim == 0:
                class_labels = None
            else:
                class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)
        # FP16 is only used on CUDA devices.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
        c_noise = sigma.log() / 4

        if class_labels is not None:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), c_latent=class_labels, **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        # Sanity check: the model must return the precision it was run at.
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        """Return ``sigma`` as a tensor; EDM supports a continuous range of noise levels."""
        return torch.as_tensor(sigma)

#----------------------------------------------------------------------------