Skip to content

Commit

Permalink
clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
arthur-qiu committed Dec 23, 2024
1 parent f4373dd commit 4cd8bdc
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 364 deletions.
121 changes: 0 additions & 121 deletions pipeline_freescale.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,127 +55,6 @@
```
"""

def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d

def exists(val):
return val is not None

def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))

def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
)
elif schedule == "cosine":
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()

to_torch = partial(torch.tensor, dtype=torch.float16)
betas = make_beta_schedule("linear", 1000, linear_start=0.00085, linear_end=0.012)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = to_torch(np.sqrt(1. - alphas_cumprod))

def q_sample(x_start, t, init_noise_sigma = 1.0, noise=None, device=None):
noise = default(noise, lambda: torch.randn_like(x_start)).to(device) * init_noise_sigma
return (extract_into_tensor(sqrt_alphas_cumprod.to(device), t, x_start.shape) * x_start +
extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(device), t, x_start.shape) * noise)

def get_views(height, width, h_window_size=128, w_window_size=128, h_window_stride=64, w_window_stride=64, vae_scale_factor=8):
height //= vae_scale_factor
width //= vae_scale_factor
num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1
total_num_blocks = int(num_blocks_height * num_blocks_width)
views = []
for i in range(total_num_blocks):
h_start = int((i // num_blocks_width) * h_window_stride)
h_end = h_start + h_window_size
w_start = int((i % num_blocks_width) * w_window_stride)
w_end = w_start + w_window_size

if h_end > height:
h_start = int(h_start + height - h_end)
h_end = int(height)
if w_end > width:
w_start = int(w_start + width - w_end)
w_end = int(width)
if h_start < 0:
h_end = int(h_end - h_start)
h_start = 0
if w_start < 0:
w_end = int(w_end - w_start)
w_start = 0

random_jitter = True
if random_jitter:
h_jitter_range = (h_window_size - h_window_stride) // 4
w_jitter_range = (w_window_size - w_window_stride) // 4
h_jitter = 0
w_jitter = 0

if (w_start != 0) and (w_end != width):
w_jitter = random.randint(-w_jitter_range, w_jitter_range)
elif (w_start == 0) and (w_end != width):
w_jitter = random.randint(-w_jitter_range, 0)
elif (w_start != 0) and (w_end == width):
w_jitter = random.randint(0, w_jitter_range)
if (h_start != 0) and (h_end != height):
h_jitter = random.randint(-h_jitter_range, h_jitter_range)
elif (h_start == 0) and (h_end != height):
h_jitter = random.randint(-h_jitter_range, 0)
elif (h_start != 0) and (h_end == height):
h_jitter = random.randint(0, h_jitter_range)
h_start += (h_jitter + h_jitter_range)
h_end += (h_jitter + h_jitter_range)
w_start += (w_jitter + w_jitter_range)
w_end += (w_jitter + w_jitter_range)

views.append((h_start, h_end, w_start, w_end))
return views

def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
x_coord = torch.arange(kernel_size)
gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
gaussian_1d = gaussian_1d / gaussian_1d.sum()
gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)

return kernel

def gaussian_filter(latents, kernel_size=3, sigma=1.0):
channels = latents.shape[1]
kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
if len(latents.shape) == 5:
b = latents.shape[0]
latents = rearrange(latents, 'b c t i j -> (b t) c i j')
blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
blurred_latents = rearrange(blurred_latents, '(b t) c i j -> b c t i j', b=b)
else:
blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)

return blurred_latents

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
Expand Down
123 changes: 1 addition & 122 deletions pipeline_freescale_imgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
```
"""


def process_image_to_tensor(image_path):
image = Image.open(image_path).convert("RGB")
transform = transforms.Compose(
Expand All @@ -76,128 +77,6 @@ def process_image_to_bitensor(image_path):
binary_tensor = torch.where(image_tensor != 0, torch.tensor(1.0), torch.tensor(0.0))
return binary_tensor

def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d

def exists(val):
return val is not None

def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))

def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
)
elif schedule == "cosine":
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()

to_torch = partial(torch.tensor, dtype=torch.float16)
betas = make_beta_schedule("linear", 1000, linear_start=0.00085, linear_end=0.012)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = to_torch(np.sqrt(1. - alphas_cumprod))

def q_sample(x_start, t, init_noise_sigma = 1.0, noise=None, device=None):
noise = default(noise, lambda: torch.randn_like(x_start)).to(device) * init_noise_sigma
return (extract_into_tensor(sqrt_alphas_cumprod.to(device), t, x_start.shape) * x_start +
extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(device), t, x_start.shape) * noise)

def get_views(height, width, h_window_size=128, w_window_size=128, h_window_stride=64, w_window_stride=64, vae_scale_factor=8):
height //= vae_scale_factor
width //= vae_scale_factor
num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1
total_num_blocks = int(num_blocks_height * num_blocks_width)
views = []
for i in range(total_num_blocks):
h_start = int((i // num_blocks_width) * h_window_stride)
h_end = h_start + h_window_size
w_start = int((i % num_blocks_width) * w_window_stride)
w_end = w_start + w_window_size

if h_end > height:
h_start = int(h_start + height - h_end)
h_end = int(height)
if w_end > width:
w_start = int(w_start + width - w_end)
w_end = int(width)
if h_start < 0:
h_end = int(h_end - h_start)
h_start = 0
if w_start < 0:
w_end = int(w_end - w_start)
w_start = 0

random_jitter = True
if random_jitter:
h_jitter_range = (h_window_size - h_window_stride) // 4
w_jitter_range = (w_window_size - w_window_stride) // 4
h_jitter = 0
w_jitter = 0

if (w_start != 0) and (w_end != width):
w_jitter = random.randint(-w_jitter_range, w_jitter_range)
elif (w_start == 0) and (w_end != width):
w_jitter = random.randint(-w_jitter_range, 0)
elif (w_start != 0) and (w_end == width):
w_jitter = random.randint(0, w_jitter_range)
if (h_start != 0) and (h_end != height):
h_jitter = random.randint(-h_jitter_range, h_jitter_range)
elif (h_start == 0) and (h_end != height):
h_jitter = random.randint(-h_jitter_range, 0)
elif (h_start != 0) and (h_end == height):
h_jitter = random.randint(0, h_jitter_range)
h_start += (h_jitter + h_jitter_range)
h_end += (h_jitter + h_jitter_range)
w_start += (w_jitter + w_jitter_range)
w_end += (w_jitter + w_jitter_range)

views.append((h_start, h_end, w_start, w_end))
return views

def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
x_coord = torch.arange(kernel_size)
gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
gaussian_1d = gaussian_1d / gaussian_1d.sum()
gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)

return kernel

def gaussian_filter(latents, kernel_size=3, sigma=1.0):
channels = latents.shape[1]
kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
if len(latents.shape) == 5:
b = latents.shape[0]
latents = rearrange(latents, 'b c t i j -> (b t) c i j')
blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
blurred_latents = rearrange(blurred_latents, '(b t) c i j -> b c t i j', b=b)
else:
blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)

return blurred_latents

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Expand Down
121 changes: 0 additions & 121 deletions pipeline_freescale_turbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,127 +55,6 @@
```
"""

def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d

def exists(val):
return val is not None

def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))

def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
)
elif schedule == "cosine":
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()

to_torch = partial(torch.tensor, dtype=torch.float16)
betas = make_beta_schedule("linear", 1000, linear_start=0.00085, linear_end=0.012)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = to_torch(np.sqrt(1. - alphas_cumprod))

def q_sample(x_start, t, init_noise_sigma = 1.0, noise=None, device=None):
noise = default(noise, lambda: torch.randn_like(x_start)).to(device) * init_noise_sigma
return (extract_into_tensor(sqrt_alphas_cumprod.to(device), t, x_start.shape) * x_start +
extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(device), t, x_start.shape) * noise)

def get_views(height, width, h_window_size=128, w_window_size=128, h_window_stride=64, w_window_stride=64, vae_scale_factor=8):
height //= vae_scale_factor
width //= vae_scale_factor
num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1
total_num_blocks = int(num_blocks_height * num_blocks_width)
views = []
for i in range(total_num_blocks):
h_start = int((i // num_blocks_width) * h_window_stride)
h_end = h_start + h_window_size
w_start = int((i % num_blocks_width) * w_window_stride)
w_end = w_start + w_window_size

if h_end > height:
h_start = int(h_start + height - h_end)
h_end = int(height)
if w_end > width:
w_start = int(w_start + width - w_end)
w_end = int(width)
if h_start < 0:
h_end = int(h_end - h_start)
h_start = 0
if w_start < 0:
w_end = int(w_end - w_start)
w_start = 0

random_jitter = True
if random_jitter:
h_jitter_range = (h_window_size - h_window_stride) // 4
w_jitter_range = (w_window_size - w_window_stride) // 4
h_jitter = 0
w_jitter = 0

if (w_start != 0) and (w_end != width):
w_jitter = random.randint(-w_jitter_range, w_jitter_range)
elif (w_start == 0) and (w_end != width):
w_jitter = random.randint(-w_jitter_range, 0)
elif (w_start != 0) and (w_end == width):
w_jitter = random.randint(0, w_jitter_range)
if (h_start != 0) and (h_end != height):
h_jitter = random.randint(-h_jitter_range, h_jitter_range)
elif (h_start == 0) and (h_end != height):
h_jitter = random.randint(-h_jitter_range, 0)
elif (h_start != 0) and (h_end == height):
h_jitter = random.randint(0, h_jitter_range)
h_start += (h_jitter + h_jitter_range)
h_end += (h_jitter + h_jitter_range)
w_start += (w_jitter + w_jitter_range)
w_end += (w_jitter + w_jitter_range)

views.append((h_start, h_end, w_start, w_end))
return views

def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
x_coord = torch.arange(kernel_size)
gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
gaussian_1d = gaussian_1d / gaussian_1d.sum()
gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)

return kernel

def gaussian_filter(latents, kernel_size=3, sigma=1.0):
channels = latents.shape[1]
kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
if len(latents.shape) == 5:
b = latents.shape[0]
latents = rearrange(latents, 'b c t i j -> (b t) c i j')
blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
blurred_latents = rearrange(blurred_latents, '(b t) c i j -> b c t i j', b=b)
else:
blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)

return blurred_latents

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
Expand Down

0 comments on commit 4cd8bdc

Please sign in to comment.