
Face De-Blurring

U2Net

Face deblurring using the Nested U-Structure for Salient Object Detection (U2-Net); top-1 with a 0.798 SSIM score

anna_mrukwa

In this task we use the U2-Net model architecture. It is typically used for tasks such as salient object detection, background removal, and art generation. The network should therefore be able to capture facial characteristics as well.

Download the files

Install AIcrowd CLI

In [ ]:
!pip install aicrowd-cli
%load_ext aicrowd.magic

Login to AIcrowd

In [ ]:
%aicrowd login

Download Dataset

We will create a folder named data and download the files there.

In [ ]:
!rm -rf data
!mkdir data
%aicrowd ds dl -c face-de-blurring -o data
In [ ]:
!unzip data/train.zip -d data/train > /dev/null
!unzip data/val.zip -d data/val > /dev/null
!unzip data/test.zip -d data/test > /dev/null

Importing Libraries
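
kornia (for the SSIM loss) and albumentations (for the augmentations) are used below but are not covered by the install cell above; if they are missing from your environment, install them first:

In [ ]:
!pip install kornia albumentations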

In [ ]:
import pandas as pd
import numpy as np
import os
import torch
import torch.optim as optim
import torch.nn as nn
import math
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as F
from PIL import Image
from tqdm import tqdm
from kornia.losses import ssim_loss
import albumentations as A  # used by the augmentation pipeline below
import typing
import random
import cv2 as cv
import matplotlib.pyplot as plt

Creating helper classes

We treat the blurred images as the input and the original sharp images as the target.

In [ ]:
class BlurDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root, "blur"))))
        self.targets = list(sorted(os.listdir(os.path.join(root, "original"))))

    def __getitem__(self, idx):
        # load images
        img_path = os.path.join(self.root, "blur", self.imgs[idx])
        target_path = os.path.join(self.root, "original", self.targets[idx])
        img = Image.open(img_path).convert("RGB")
        width, height = img.size
        if width != 512 or height != 512:
            img = img.resize((512, 512))
        target = Image.open(target_path).convert("RGB")
        width, height = target.size
        if width != 512 or height != 512:
            target = target.resize((512, 512))
            
        
        if self.transforms is not None:
            img = np.array(img)
            target = np.array(target)
            transformed = self.transforms(image=img, mask=target)
            img, target = transformed['image'], transformed['mask']
            img = Image.fromarray(img)
            target = Image.fromarray(target)
        
        img = T.ToTensor()(img)
        target = T.ToTensor()(target)
        return img, target

    def __len__(self):
        return len(self.imgs)
In [ ]:
class TestBlurDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # load all image files
        self.imgs = list(sorted(os.listdir(os.path.join(root, "blur"))))

    def __getitem__(self, idx):
        # load images
        img_path = os.path.join(self.root, "blur", self.imgs[idx])
        img = Image.open(img_path).convert("RGB")
        width, height = img.size
        if width != 512 or height != 512:
            img = img.resize((512, 512))
        if self.transforms is not None:
            img = self.transforms(img)

        return img, self.imgs[idx] # for saving purposes

    def __len__(self):
        return len(self.imgs)

Diving into the dataset

In [ ]:
train_images = 'data/train'
val_images = 'data/val'
test_images = "data/test"

Preprocessing

In [ ]:
train_transform = A.Compose([
    A.RandomResizedCrop(512, 512, p=0.5),
    A.Perspective(),
    A.HorizontalFlip(),
    A.ShiftScaleRotate(border_mode=0)])
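
All four transforms above are geometric, and albumentations applies geometric transforms identically to the image and the mask argument; this is how the blurred input and the sharp target stay pixel-aligned after augmentation (note that masks are warped with nearest-neighbour interpolation by default). A minimal sanity check on dummy arrays:

In [ ]:
dummy = np.zeros((512, 512, 3), dtype=np.uint8)
out = train_transform(image=dummy, mask=dummy.copy())
print(out['image'].shape, out['mask'].shape)  # both remain (512, 512, 3)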
In [ ]:
train_ds = BlurDataset(train_images, train_transform)
val_ds = BlurDataset(val_images, None)
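
A quick sanity check that input and target line up as intended (assuming the archives were extracted as above):

In [ ]:
img, target = train_ds[0]
print(img.shape, target.shape)  # both torch.Size([3, 512, 512])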

Architecture

In [ ]:
def _upsample_like(x, size):
    return nn.Upsample(size=size, mode='bilinear', align_corners=False)(x)


def _size_map(x, height):
    # {height: size} for Upsample
    size = list(x.shape[-2:])
    sizes = {}
    for h in range(1, height):
        sizes[h] = size
        size = [math.ceil(w / 2) for w in size]
    return sizes


class REBNCONV(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, dilate=1):
        super(REBNCONV, self).__init__()

        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))


class RSU(nn.Module):
    def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
        super(RSU, self).__init__()
        self.name = name
        self.height = height
        self.dilated = dilated
        self._make_layers(height, in_ch, mid_ch, out_ch, dilated)

    def forward(self, x):
        sizes = _size_map(x, self.height)
        x = self.rebnconvin(x)

        # U-Net like symmetric encoder-decoder structure
        def unet(x, height=1):
            if height < self.height:
                x1 = getattr(self, f'rebnconv{height}')(x)
                if not self.dilated and height < self.height - 1:
                    x2 = unet(getattr(self, 'downsample')(x1), height + 1)
                else:
                    x2 = unet(x1, height + 1)

                x = getattr(self, f'rebnconv{height}d')(torch.cat((x2, x1), 1))
                return _upsample_like(x, sizes[height - 1]) if not self.dilated and height > 1 else x
            else:
                return getattr(self, f'rebnconv{height}')(x)

        return x + unet(x)

    def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
        self.add_module('rebnconvin', REBNCONV(in_ch, out_ch))
        self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))

        self.add_module(f'rebnconv1', REBNCONV(out_ch, mid_ch))
        self.add_module(f'rebnconv1d', REBNCONV(mid_ch * 2, out_ch))

        for i in range(2, height):
            dilate = 1 if not dilated else 2 ** (i - 1)
            self.add_module(f'rebnconv{i}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
            self.add_module(f'rebnconv{i}d', REBNCONV(mid_ch * 2, mid_ch, dilate=dilate))

        dilate = 2 if not dilated else 2 ** (height - 1)
        self.add_module(f'rebnconv{height}', REBNCONV(mid_ch, mid_ch, dilate=dilate))


class U2NET(nn.Module):
    def __init__(self, cfgs, out_ch):
        super(U2NET, self).__init__()
        self.out_ch = out_ch
        self._make_layers(cfgs)

    def forward(self, x):
        sizes = _size_map(x, self.height)
        maps = []  # storage for maps

        # side saliency map
        def unet(x, height=1):
            if height < 6:
                x1 = getattr(self, f'stage{height}')(x)
                x2 = unet(getattr(self, 'downsample')(x1), height + 1)
                x = getattr(self, f'stage{height}d')(torch.cat((x2, x1), 1))
                side(x, height)
                return _upsample_like(x, sizes[height - 1]) if height > 1 else x
            else:
                x = getattr(self, f'stage{height}')(x)
                side(x, height)
                return _upsample_like(x, sizes[height - 1])

        def side(x, h):
            # side output saliency map (before sigmoid)
            x = getattr(self, f'side{h}')(x)
            x = _upsample_like(x, sizes[1])
            maps.append(x)

        def fuse():
            # fuse saliency probability maps
            maps.reverse()
            x = torch.cat(maps, 1)
            x = getattr(self, 'outconv')(x)
            maps.insert(0, x)
            return [torch.sigmoid(x) for x in maps]

        unet(x)
        maps = fuse()
        return maps

    def _make_layers(self, cfgs):
        self.height = int((len(cfgs) + 1) / 2)
        self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
        for k, v in cfgs.items():
            # build rsu block
            self.add_module(k, RSU(v[0], *v[1]))
            if v[2] > 0:
                # build side layer
                self.add_module(f'side{v[0][-1]}', nn.Conv2d(v[2], self.out_ch, 3, padding=1))
        # build fuse layer
        self.add_module('outconv', nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1))


def U2NET_full():
    full = {
        # cfgs for building RSUs and sides
        # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
        'stage1': ['En_1', (7, 3, 32, 64), -1],
        'stage2': ['En_2', (6, 64, 32, 128), -1],
        'stage3': ['En_3', (5, 128, 64, 256), -1],
        'stage4': ['En_4', (4, 256, 128, 512), -1],
        'stage5': ['En_5', (4, 512, 256, 512, True), -1],
        'stage6': ['En_6', (4, 512, 256, 512, True), 512],
        'stage5d': ['De_5', (4, 1024, 256, 512, True), 512],
        'stage4d': ['De_4', (4, 1024, 128, 256), 256],
        'stage3d': ['De_3', (5, 512, 64, 128), 128],
        'stage2d': ['De_2', (6, 256, 32, 64), 64],
        'stage1d': ['De_1', (7, 128, 16, 64), 64],
    }
    return U2NET(cfgs=full, out_ch=3)


def U2NET_lite():
    lite = {
        # cfgs for building RSUs and sides
        # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
        'stage1': ['En_1', (7, 3, 16, 64), -1],
        'stage2': ['En_2', (6, 64, 16, 64), -1],
        'stage3': ['En_3', (5, 64, 16, 64), -1],
        'stage4': ['En_4', (4, 64, 16, 64), -1],
        'stage5': ['En_5', (4, 64, 16, 64, True), -1],
        'stage6': ['En_6', (4, 64, 16, 64, True), 64],
        'stage5d': ['De_5', (4, 128, 16, 64, True), 64],
        'stage4d': ['De_4', (4, 128, 16, 64), 64],
        'stage3d': ['De_3', (5, 128, 16, 64), 64],
        'stage2d': ['De_2', (6, 128, 16, 64), 64],
        'stage1d': ['De_1', (7, 128, 16, 64), 64],
    }
    return U2NET(cfgs=lite, out_ch=3)
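
The recursion in RSU.forward and U2NET.forward is driven by _size_map, which precomputes the spatial size at every level so each decoder stage can be upsampled back to the resolution of its matching encoder stage. For the 512x512 inputs used here it yields:

In [ ]:
print(_size_map(torch.randn(1, 3, 512, 512), 6))
# {1: [512, 512], 2: [256, 256], 3: [128, 128], 4: [64, 64], 5: [32, 32]}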

Training model

In [ ]:
model = U2NET_lite()
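
A quick smoke test of the untrained network (a sketch, assuming the 512x512 inputs produced by the datasets above): the forward pass returns seven maps, the fused output first, followed by the six side outputs; training below uses the fused map.

In [ ]:
with torch.no_grad():
    outputs = model(torch.randn(1, 3, 512, 512))
print(len(outputs), outputs[0].shape)  # 7 torch.Size([1, 3, 512, 512])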
In [ ]:
optimizer = optim.Adam(model.parameters(), lr=0.005)
In [ ]:
def train(model, train_ds, val_ds, optimizer, epochs_no=50, patience=5, batch_size=8):
    history = {"train_loss": [], "val_loss": []}
    cooldown = 0
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=8)
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=8)

    for epoch in tqdm(range(epochs_no)):
        model.train()
        epoch_train_loss = 0
        epoch_val_loss = 0
        for img, target in train_loader:
            img, target = img.to(device), target.to(device)
            pred = model(img)[0]  # train against the fused output map
            loss = ssim_loss(target, pred, 5)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() detaches the loss, so we do not keep every batch's graph in memory
            epoch_train_loss += loss.item()

        with torch.no_grad():
            model.eval()
            for img, target in val_loader:
                img, target = img.to(device), target.to(device)
                pred = model(img)[0]
                epoch_val_loss += ssim_loss(target, pred, 5).item()

        epoch_train_loss /= len(train_loader)
        epoch_val_loss /= len(val_loader)
        history["train_loss"].append(epoch_train_loss)
        history["val_loss"].append(epoch_val_loss)

        print("EPOCH: {}/{}".format(epoch + 1, epochs_no))
        print("Train loss: {:.6f}, Validation loss: {:.6f}".format(epoch_train_loss, epoch_val_loss))

        # early stopping: break after `patience` consecutive epochs without improvement
        if epoch != 0 and history["val_loss"][epoch] > history["val_loss"][epoch - 1]:
            cooldown += 1
            if cooldown == patience:
                break
        else:
            cooldown = 0
            if epoch == 0 or history["val_loss"][epoch] <= min_val_loss:
                min_val_loss = history["val_loss"][epoch]
                # checkpoint on the first epoch too, so the weights file
                # always exists when we reload it below
                torch.save(model.state_dict(), 'model_weights.pth')
    model.load_state_dict(torch.load('model_weights.pth'))
    return model, history
In [ ]:
model, history = train(model, train_ds, val_ds, optimizer)
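
With the returned history we can plot the SSIM-loss curves to inspect convergence (a quick sketch using the matplotlib import from earlier):

In [ ]:
plt.plot(history["train_loss"], label="train")
plt.plot(history["val_loss"], label="val")
plt.xlabel("epoch")
plt.ylabel("SSIM loss")
plt.legend()
plt.show()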

Generating Prediction Files

In [ ]:
!mkdir original
In [ ]:
def get_test_transform():
    transforms = []
    transforms.append(T.ToTensor())
    return T.Compose(transforms)
In [ ]:
test_ds = TestBlurDataset(test_images, get_test_transform())
In [ ]:
test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=1, shuffle=False, num_workers=8)
In [ ]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.eval()
In [ ]:
solution_directory = "original"
to_pil = T.ToPILImage()

def save_result(img, name):
    img = torch.squeeze(img)
    img = to_pil(img)
    img_path = os.path.join(solution_directory, name)
    img.save(img_path)
In [ ]:
with torch.no_grad():  # inference only, no gradients needed
    for img, (name,) in tqdm(test_loader):
        img = img.to(device)
        res = model(img)[0]  # fused output map
        save_result(res, name)
In [ ]:
def show_result(img_name):
    # OpenCV loads images as BGR; convert to RGB for matplotlib
    blur = cv.imread("data/test/blur/"+img_name)
    blur = cv.cvtColor(blur, cv.COLOR_BGR2RGB)
    original = cv.imread("original/"+img_name)
    original = cv.cvtColor(original, cv.COLOR_BGR2RGB)
    
    fig = plt.figure(figsize=(10, 20))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    ax1.axis('off')
    ax2.axis('off')
    ax1.imshow(blur)
    ax2.imshow(original)
    ax1.title.set_text('Blurred image')
    ax2.title.set_text('Reconstructed')
    plt.show()
In [24]:
show_result("01d44.jpg")

Submitting our Predictions

Note: Please save the notebook before submitting it (Ctrl + S)

In [ ]:
%aicrowd notebook submit -c face-de-blurring -a original   --no-verify

Comments

jinoooooooooo
Over 2 years ago

great solution, but perhaps could you kindly explain the architecture a little more? thanks!
