从 Non-local Means 看 PyTorch 显存优化的奇技淫巧

背景

最近的一个项目中需要用 PyTorch 实现一套可微分的 Non-local Means (NLM) 降噪算法,并通过样本学习建立一个图像内容与降噪强度参数 $h$ 的映射模型。

在 PyTorch 框架下,实现传统图像处理算法往往有着不止一种的写法,例如,为了实现 $3\times3$ 的均值滤波,我们既可以使用一个归一化的 box filter 对图像进行卷积,也可以对图像进行9次位移(shift/roll)并逐像素求和,最后再将每个像素值除以9。在模型进行正向推断(forward)时,不同的写法在显存开销上或许不存在太大的区别,但是当我们使用 Autograd 进行逐层的梯度反向传播(backward)时,不同的写法往往对应了截然不同的显存占用。以 NLM 为例,在我的实验中,对于同一张 $512\times{}512$ 的输入图像,不同的 tensor 操作写法可以带来180%以上的训练显存差异。当然,在代码设计没有明显缺陷的情况下,节省显存必然意味着运算效率的下降,不过对于我的这个项目来说,训练阶段的耗时并不是瓶颈所在,batch size 才是影响模型性能的主要因素,因此利用时间换取空间仍然是一个非常划算的选择。

先简单地回顾一下 Non-local Means 算法:给定一幅噪声图像,对于图像中的每个像素 $x$,以其为中心,在一个 $w\times{}w$ 的搜索窗口(下文代码中以 search_window 命名)内,遍历所有的可能的图像块(一般是一个 $k\times{}k$ 的方形区域,且有 $k<w$,下文代码中以 patch 命名),并计算该图像块与以 $x$ 为中心的那个图像块之间的相似度。

在 $w\times{}w$ 的搜索窗口内,一共存在 $w^2$ 个可重叠的图像块,对于第 $i$ 个图像块,假设其与 $x$ 所在图像块之间的相似度为 $s_i$(注意:是图像块之间的相似度,而不是像素与像素之间的相似度!),且该图像块中心像素的像素值为 $y_i$,那么对于 $x$ 像素,其经过 NLM 降噪后的像素值可以表示为:

\begin{equation*} \hat{x}_\mathrm{\scriptstyle NLM}=\sum_{i=1}^{w^2} f(s_i)y_i \end{equation*}

其中 $f(\cdot)$ 为一单调递增函数,即,若第 $i$ 个图像块与 $x$ 所在图像块的相似度越高,其在加权平均时所占的比重就越大。实际计算中一般会采用两个图像块之间的欧式距离 $d$ 作为相似度度量,即 $f(\cdot)=\exp\left(-\frac{d}{h}\right)$,其中 $h$ 是一个用于控制降噪强度的超参数

一点准备工作

为简化起见,假设现在给定了一张噪声图像和一张与之对应的无噪声参考图像(ground-truth),我们希望对该噪声图像求得一个最佳的 $h$ 参数,使之经过 NLM 降噪后与参考图像之间的 PSNR 最大化。

在这个全监督任务中,我们可以直接使用 MSE loss 作为目标函数进行迭代优化:

# train.py

import torch
import torch.nn as nn

from .nlm import NonLocalMeans
from .utils import gen_image_pair


dev = torch.device('cuda')
clean_rgb, noisy_rgb = gen_image_pair('Lenna.png', device=dev, sigma=0.05)

denoiser = NonLocalMeans(h0=0.05).to(dev)
denoiser.train()
optimizer = torch.optim.Adam(params=denoiser.parameters(), lr=0.0001, weight_decay=0.0001)
loss = nn.MSELoss()

for iter_num in range(100):
    denoised_rgb = denoiser(noisy_rgb)
    mse = loss(clean_rgb, denoised_rgb)

    mse.backward()
    optimizer.step()

    psnr = 10 * torch.log10(1.0 / mse)
    print('step{}: PSNR={:.3f} (h={:.5f})'.format(iter_num, psnr.item(), float(denoiser.h.item())))

其中 gen_image_pair 用于生成一对干净图像和噪声图像,为了不影响阅读的连续性,我把它的具体实现放在了文末的附录中;NonLocalMeans 模块接收一个 $N\times{}3\times{}H\times{}W$ 的张量作为输入,返回一个同尺寸的降噪后的张量。NonLocalMeans 的具体实现将放在下文中展开,这里我们只需确保该模块中包含一个可训练的 self.h 参数:

# nlm.py

import torch
import torch.nn as nn


class NonLocalMeans(nn.Module):
    def __init__(self, h0, search_window_size, patch_size):
        super().__init__()
        self.h = nn.Parameter(torch.tensor([float(h0)]), requires_grad=True)
        pass

    def forward(self, rgb):
        pass

NonLocalMeans v1

在第一版的 PyTorch NLM 中,我很自然地想到了空间换时间的方法,即,对于每个像素,将其搜索窗口中的 $w\times{}w$ 个像素拍扁并放置到一个新的维度中。以单通道图像为例,假设输入图像尺寸为 $1\times{}H\times{}W$,搜索窗口宽度 $w=11$,那么可以创建一个尺寸为 $1\times{}H\times{}W\times{}121$ 的张量,其中第 $(1,i,j,k)$ 个元素表示的是在以原图 $(i,j)$ 像素为中心的搜索窗口中第 $k$ 个像素的像素值。

有了这个张量之后,我们就可以将其与原图相减得到一个差值张量(相减之前需要将原图 unsqueeze 出一个新的维度用于 broadcast),并在这个差值张量上利用局部求和+开方的方法得到图像块之间的欧氏距离(出于偷懒的原因,图像块之间相似度的计算仅在亮度通道进行,下同。rgb_to_luminance 的实现见附录):

# nlm_v1.py

from .utils import rgb_to_luminance, ShiftStack, BoxFilter

EPSILON = 1E-12


class NonLocalMeans(nn.Module):
    def __init__(self, h0, search_window_size=11, patch_size=5):
        super().__init__()
        self.h = nn.Parameter(torch.tensor([float(h0)]), requires_grad=True)

        self.gen_window_stack = ShiftStack(window_size=search_window_size)
        self.box_sum = BoxFilter(window_size=patch_size, reduction='sum')

    def forward(self, rgb):
        y = rgb_to_luminance(rgb)  # (N, 1, H, W)

        rgb_window_stack = self.gen_window_stack(rgb)  # (N, 3, H, W, w*y)
        y_window_stack = self.gen_window_stack(y)  # (N, 1, H, W, w*y)

        distances = torch.sqrt(self.box_sum((y.unsqueeze(-1) - y_window_stack) ** 2))  # (N, 1, H, W, w*y)
        weights = torch.exp(-distances / (torch.relu(self.h) + EPSILON))  # (N, 1, H, W, w*y)

        denoised_rgb = (weights * rgb_window_stack).sum(dim=-1) / weights.sum(dim=-1)  # (N, 3, H, W)

        return torch.clamp(denoised_rgb, 0, 1)  # (N, 3, H, W)

其中的 self.gen_window_stack 方法用来将搜索窗口内的像素拍扁并放到一个新的维度中(ShiftStacknn.Unfold 类似,但是其实现比 Unfold 更省显存);self.box_sum 方法用来计算 $k\times{}k$ 窗口内的各像素值之和。它们的具体实现同样放在了文末附录中。

让我们先看下迭代过程中的 PSNR 和 $h$ 参数的变化,确保训练过程可以正常收敛:

step0: PSNR=28.600 (h=0.05050)
step1: PSNR=28.675 (h=0.05098)
step2: PSNR=28.747 (h=0.05146)
step3: PSNR=28.818 (h=0.05194)
...
step97: PSNR=32.882 (h=0.10972)
step98: PSNR=32.888 (h=0.11034)
step99: PSNR=32.894 (h=0.11096)

再看下降噪效果,确保这段代码能够正常实现 NLM 算法:

Ground-truth

+AWGN ($\sigma=0.05$)

NLM with $h=0.11096$ (PSNR=32.9)

在这个版本的 NLM 中,对于单张 $512\times{}512$ 的 RGB 图像,一次 forward 大概需要占用 2.3G 显存(已使用 with torch.no_grad() 关闭梯度计算),且训练阶段(forward+backward,见上文第一段代码块)的显存占用与 forward 基本持平:

forward:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.81                 Driver Version: 384.81                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 106...  Off  | 00000000:0D:00.0  On |                  N/A |
| 39%   57C    P2    90W / 120W |   2305MiB /  6069MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+

forward + backward:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.81                 Driver Version: 384.81                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 106...  Off  | 00000000:0D:00.0  On |                  N/A |
| 43%   59C    P2    96W / 120W |   2307MiB /  6069MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+

显然,这个版本中,中间变量 rgb_window_stacky_window_stackdistancesweights 均是5维张量,当使用默认的搜索窗口宽度 $w=11$ 时,为了存下它们,PyTorch 至少需要额外申请121倍图像大小的显存。

NonLocalMeans v2

在 NLM 算法中,用来计算相似度的图像块的尺寸 $k$ 通常要小于搜索窗口的尺寸 $w$,因此一个很容易想到的显存优化策略就是,将额外多出来的那个维度用来存储图像块内的各个像素值,而非搜索窗口内的各个像素值:

# nlm_v2.py

class NonLocalMeans(nn.Module):
    def __init__(self, h0, search_window_size=11, patch_size=5):
        super().__init__()
        self.h = nn.Parameter(torch.tensor([float(h0)]), requires_grad=True)

        self.gen_patch_stack = ShiftStack(window_size=patch_size)
        self.r = search_window_size // 2

    def forward(self, rgb):
        batch_size, _, height, width = rgb.shape
        weights = torch.zeros((batch_size, 1, height, width)).float().to(rgb.device)  # (N, 1, H, W)
        denoised_rgb = torch.zeros_like(rgb)  # (N, 3, H, W)

        y = rgb_to_luminance(rgb)  # (N, 1, H, W)

        y_patch_stack = self.gen_patch_stack(y)  # (N, 1, H, W, k*k)

        for x_shift in range(-self.r, self.r + 1):
            for y_shift in range(-self.r, self.r + 1):
                shifted_rgb = torch.roll(rgb, shifts=(y_shift, x_shift), dims=(2, 3))  # (N, 3, H, W)

                shifted_y_patch_stack = \
                    torch.roll(y_patch_stack, (y_shift, x_shift), dims=(2, 3))  # (N, 1, H, W, k*k)

                distance = torch.sqrt(((y_patch_stack - shifted_y_patch_stack) ** 2).sum(dim=-1))  # (N, 1, H, W)
                weight = torch.exp(-distance / (torch.relu(self.h) + EPSILON))  # (N, 1, H, W)

                denoised_rgb += shifted_rgb * weight  # (N, 3, H, W)
                weights += weight  # (N, 1, H, W)

        return torch.clamp(denoised_rgb / weights, 0, 1)  # (N, 3, H, W)

这个版本中,我使用两层循环完成对搜索窗口中每个像素的遍历,因此无论搜索窗口设定得多大,其影响的只是运行速度,而非显存。当图像块宽度 $k=5$,搜索窗口宽度 $w=11$ 时,forward 阶段显存开销理论上可以降低至 v1 版本的 $\frac{25}{121}\approx{}21\%$。

实测发现,对于同样尺寸的输入图像,这个版本的 NLM 一次 forward 只需要占用大约 600M 显存,而训练阶段则是 1.2G:

forward:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.81                 Driver Version: 384.81                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 106...  Off  | 00000000:0D:00.0  On |                  N/A |
| 45%   61C    P2    92W / 120W |    603MiB /  6069MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+

forward + backward:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.81                 Driver Version: 384.81                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 106...  Off  | 00000000:0D:00.0  On |                  N/A |
| 40%   57C    P2    96W / 120W |   1185MiB /  6069MiB |     99%      Default |
+-------------------------------+----------------------+----------------------+

NonLocalMeans v3

再极端一点,如果我们连用来存储图像块内的各个像素值的中间变量都不想创建呢?显然,可以在 v2 版本的基础上,再加两层循环完成图像块相似度计算(BoxFilter 实际上也是利用两层循环来实现卷积的效果),这个时候,没有任何新的维度需要添加到中间过程的张量中,所有操作都可以原址完成:

# nlm_v3.py

class NonLocalMeans(nn.Module):
    def __init__(self, h0, search_window_size=11, patch_size=5):
        super().__init__()
        self.h = nn.Parameter(torch.tensor([float(h0)]), requires_grad=True)

        self.box_sum = BoxFilter(window_size=patch_size, reduction='sum')
        self.r = search_window_size // 2

    def forward(self, rgb):
        batch_size, _, height, width = rgb.shape
        weights = torch.zeros((batch_size, 1, height, width)).float().to(rgb.device)  # (N, 1, H, W)
        denoised_rgb = torch.zeros_like(rgb)  # (N, 3, H, W)

        y = rgb_to_luminance(rgb)  # (N, 1, H, W)

        for x_shift in range(-self.r, self.r + 1):
            for y_shift in range(-self.r, self.r + 1):
                shifted_rgb = torch.roll(rgb, shifts=(y_shift, x_shift), dims=(2, 3))  # (N, 3, H, W)
                shifted_y = torch.roll(y, shifts=(y_shift, x_shift), dims=(2, 3))  # (N, 1, H, W)

                distance = torch.sqrt(self.box_sum((y - shifted_y) ** 2))  # (N, 1, H, W)
                weight = torch.exp(-distance / (torch.relu(self.h) + EPSILON))  # (N, 1, H, W)

                denoised_rgb += shifted_rgb * weight  # (N, 3, H, W)
                weights += weight  # (N, 1, H, W)

        return torch.clamp(denoised_rgb / weights, 0, 1)  # (N, 3, H, W)

在这个版本里,一次 forward 和一次 backward 分别只占用了大约 470M 和 1.1G 显存:

forward:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.81                 Driver Version: 384.81                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 106...  Off  | 00000000:0D:00.0  On |                  N/A |
| 43%   60C    P2    92W / 120W |    473MiB /  6069MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+

forward + backward:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.81                 Driver Version: 384.81                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 106...  Off  | 00000000:0D:00.0  On |                  N/A |
| 43%   59C    P2    72W / 120W |   1113MiB /  6069MiB |     99%      Default |
+-------------------------------+----------------------+----------------------+

这也是我能想到的针对 Non-local Means 算法最显存友好的一种实现了。

Reference

附录

gen_image_pair

# utils.py

import numpy as np
import skimage.io
import torch
import torch.nn as nn


def gen_image_pair(image_path, device, sigma):
    # :return: two tensors with shape (1, 3, H, W) in [0, 1] range
    clean_rgb = skimage.io.imread(image_path).astype(np.float32) / 255.0
    # additive white gaussian noise
    awgn = np.random.normal(0, scale=sigma, size=rgb.shape).astype(np.float32)
    noisy_rgb = np.clip(clean_rgb + awgn, 0, 1)

    clean_rgb = torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0).to(device)
    noisy_rgb = torch.from_numpy(noisy_rgb.transpose(2, 0, 1)).unsqueeze(0).to(device)
    return clean_rgb, noisy_rgb

rgb_to_luminance

# utils.py
    
def rgb_to_luminance(rgb_tensor):
    # :param rgb_tensor: torch.Tensor(N, 3, H, W, ...) in [0, 1] range
    # :return: torch.Tensor(N, 1, H, W, ...) in [0, 1] range
    assert rgb_tensor.min() >= 0.0 and rgb_tensor.max() <= 1.0
    return 0.299 * rgb_tensor[:, :1, ...] + 0.587 * rgb_tensor[:, 1:2, ...] + 0.114 * rgb_tensor[:, 2:, ...]

ShiftStack

# utils.py
    
class ShiftStack(nn.Module):
    """
    Shift n-dim tensor in a local window and generate a stacked
    (n+1)-dim tensor with shape (*orig_shapes, w*y), where wx
    and wy are width and height of the window
    """
    def __init__(self, window_size):
        # :param window_size: Int or Tuple(Int, Int) in (win_width, win_height) order
        super().__init__()
        wx, wy = window_size if isinstance(window_size, (list, tuple)) else (window_size, window_size)
        assert wx % 2 == 1 and wy % 2 == 1, 'window size must be odd'
        self.rx, self.ry = wx // 2, wy // 2

    def forward(self, tensor):
        # :param tensor: torch.Tensor(N, C, H, W, ...)
        # :return: torch.Tensor(N, C, H, W, ..., w*y)
        shifted_tensors = []
        for x_shift in range(-self.rx, self.rx + 1):
            for y_shift in range(-self.ry, self.ry + 1):
                shifted_tensors.append(
                    torch.roll(tensor, shifts=(y_shift, x_shift), dims=(2, 3))
                )

        return torch.stack(shifted_tensors, dim=-1)

BoxFilter

# utils.py
    
class BoxFilter(nn.Module):
    def __init__(self, window_size, reduction='mean'):
        # :param window_size: Int or Tuple(Int, Int) in (win_width, win_height) order
        # :param reduction: 'mean' | 'sum'
        super().__init__()
        wx, wy = window_size if isinstance(window_size, (list, tuple)) else (window_size, window_size)
        assert wx % 2 == 1 and wy % 2 == 1, 'window size must be odd'
        self.rx, self.ry = wx // 2, wy // 2
        self.area = wx * wy
        self.reduction = reduction

    def forward(self, tensor):
        # :param tensor: torch.Tensor(N, C, H, W, ...)
        # :return: torch.Tensor(N, C, H, W, ...)
        local_sum = torch.zeros_like(tensor)
        for x_shift in range(-self.rx, self.rx + 1):
            for y_shift in range(-self.ry, self.ry + 1):
                local_sum += torch.roll(tensor, shifts=(y_shift, x_shift), dims=(2, 3))

        return local_sum if self.reduction == 'sum' else local_sum / self.area

About the author

Jueqin

本作品以 CC BY-NC-ND 许可协议进行发布。

如果您认为文章对您有用的话,不妨请我喝一杯咖啡?

Add comment