Face De-Blurring
U2-Net
Face deblurring using the Nested U-Structure for Salient Object Detection (U2-Net); top-1 with a 0.798 SSIM score.
In this task we use the U2-Net architecture, which is commonly applied to salient object detection, background removal, and art generation. Its nested encoder-decoder structure should be able to capture facial characteristics and reconstruct sharp faces from the blurred inputs.
In [ ]:
!pip install aicrowd-cli kornia albumentations
%load_ext aicrowd.magic
Login to AIcrowd¶
In [ ]:
%aicrowd login
Download Dataset¶
We will create a folder named data and download the dataset files there.
In [ ]:
!rm -rf data
!mkdir data
%aicrowd ds dl -c face-de-blurring -o data
In [ ]:
!unzip data/train.zip -d data/train > /dev/null
!unzip data/val.zip -d data/val > /dev/null
!unzip data/test.zip -d data/test > /dev/null
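A quick look at the unpacked layout (this assumes each split contains a blur/ folder with the blurred inputs, and that train/val additionally contain an original/ folder with the sharp targets, as the dataset classes below expect):
In [ ]:
!ls data/train data/val data/test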
Importing Libraries¶
In [ ]:
import pandas as pd
import numpy as np
import os
import torch
import torch.optim as optim
import torch.nn as nn
import math
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as F
from PIL import Image
from tqdm import tqdm
from kornia.losses import ssim_loss
import typing
import random
import cv2 as cv
import matplotlib.pyplot as plt
import albumentations as A
Creating helper classes¶
We treat the blurred images as the inputs and the original (sharp) images as the targets.
In [ ]:
class BlurDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root, "blur"))))
        self.targets = list(sorted(os.listdir(os.path.join(root, "original"))))

    def __getitem__(self, idx):
        # load images
        img_path = os.path.join(self.root, "blur", self.imgs[idx])
        target_path = os.path.join(self.root, "original", self.targets[idx])
        img = Image.open(img_path).convert("RGB")
        width, height = img.size
        if width != 512 or height != 512:
            img = img.resize((512, 512))
        target = Image.open(target_path).convert("RGB")
        width, height = target.size
        if width != 512 or height != 512:
            target = target.resize((512, 512))
        if self.transforms is not None:
            img = np.array(img)
            target = np.array(target)
            transformed = self.transforms(image=img, mask=target)
            img, target = transformed['image'], transformed['mask']
            img = Image.fromarray(img)
            target = Image.fromarray(target)
        img = T.ToTensor()(img)
        target = T.ToTensor()(target)
        return img, target

    def __len__(self):
        return len(self.imgs)
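A minimal usage sketch (assuming the train split unpacked as above): with transforms=None each item comes back as a pair of 3×512×512 float tensors scaled to [0, 1] by ToTensor.
In [ ]:
# sketch only; not part of the training pipeline
sample_ds = BlurDataset('data/train', None)
img, target = sample_ds[0]
print(img.shape, target.shape)             # torch.Size([3, 512, 512]) each
print(img.min().item(), img.max().item())  # values in [0, 1]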
In [ ]:
class TestBlurDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # load all image files
        self.imgs = list(sorted(os.listdir(os.path.join(root, "blur"))))

    def __getitem__(self, idx):
        # load images
        img_path = os.path.join(self.root, "blur", self.imgs[idx])
        img = Image.open(img_path).convert("RGB")
        width, height = img.size
        if width != 512 or height != 512:
            img = img.resize((512, 512))
        if self.transforms is not None:
            img = self.transforms(img)
        return img, self.imgs[idx]  # for saving purposes

    def __len__(self):
        return len(self.imgs)
Diving into the dataset¶
In [ ]:
train_images = 'data/train'
val_images = 'data/val'
test_images = "data/test"
Preprocessing¶
In [ ]:
train_transform = A.Compose([
    A.RandomResizedCrop(512, 512, p=0.5),
    A.Perspective(),
    A.HorizontalFlip(),
    A.ShiftScaleRotate(border_mode=0)])
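Because the sharp target is passed to A.Compose as mask, albumentations replays the same geometric parameters on both images, so the pair stays spatially aligned (note that masks are warped with nearest-neighbour interpolation by default). A quick sketch on dummy data:
In [ ]:
# sketch: the same crop/flip/shift/rotation is applied to image and mask alike
dummy = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)
out = train_transform(image=dummy, mask=dummy.copy())
print(out['image'].shape, out['mask'].shape)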
In [ ]:
train_ds = BlurDataset(train_images, train_transform)
val_ds = BlurDataset(val_images, None)
Architecture¶
In [ ]:
def _upsample_like(x, size):
    return nn.Upsample(size=size, mode='bilinear', align_corners=False)(x)


def _size_map(x, height):
    # {height: size} for Upsample
    size = list(x.shape[-2:])
    sizes = {}
    for h in range(1, height):
        sizes[h] = size
        size = [math.ceil(w / 2) for w in size]
    return sizes
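# e.g. for a 512x512 input and height=6, _size_map returns
# {1: [512, 512], 2: [256, 256], 3: [128, 128], 4: [64, 64], 5: [32, 32]},
# i.e. the spatial size each decoder level must be upsampled back to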
class REBNCONV(nn.Module):
    # conv -> batch norm -> ReLU, with optional dilation
    def __init__(self, in_ch=3, out_ch=3, dilate=1):
        super(REBNCONV, self).__init__()
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
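# RSU (Residual U-block): a small U-Net whose encoder-decoder output is added
# back to its input features (`x + unet(x)` below). With dilated=True the
# deepest levels replace pooling/upsampling with dilated convolutions, so the
# spatial resolution stays fixed.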
class RSU(nn.Module):
    def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
        super(RSU, self).__init__()
        self.name = name
        self.height = height
        self.dilated = dilated
        self._make_layers(height, in_ch, mid_ch, out_ch, dilated)

    def forward(self, x):
        sizes = _size_map(x, self.height)
        x = self.rebnconvin(x)

        # U-Net like symmetric encoder-decoder structure
        def unet(x, height=1):
            if height < self.height:
                x1 = getattr(self, f'rebnconv{height}')(x)
                if not self.dilated and height < self.height - 1:
                    x2 = unet(getattr(self, 'downsample')(x1), height + 1)
                else:
                    x2 = unet(x1, height + 1)
                x = getattr(self, f'rebnconv{height}d')(torch.cat((x2, x1), 1))
                return _upsample_like(x, sizes[height - 1]) if not self.dilated and height > 1 else x
            else:
                return getattr(self, f'rebnconv{height}')(x)

        return x + unet(x)

    def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
        self.add_module('rebnconvin', REBNCONV(in_ch, out_ch))
        self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
        self.add_module(f'rebnconv1', REBNCONV(out_ch, mid_ch))
        self.add_module(f'rebnconv1d', REBNCONV(mid_ch * 2, out_ch))
        for i in range(2, height):
            dilate = 1 if not dilated else 2 ** (i - 1)
            self.add_module(f'rebnconv{i}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
            self.add_module(f'rebnconv{i}d', REBNCONV(mid_ch * 2, mid_ch, dilate=dilate))
        dilate = 2 if not dilated else 2 ** (height - 1)
        self.add_module(f'rebnconv{height}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
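# U2NET: an outer U-Net whose stages are themselves RSU blocks (a "U inside a
# U"). Each decoder stage and the deepest encoder stage emit a side output;
# all side maps are upsampled to the input size and fused by a 1x1 conv, and
# the fused map is returned first in the list.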
class U2NET(nn.Module):
    def __init__(self, cfgs, out_ch):
        super(U2NET, self).__init__()
        self.out_ch = out_ch
        self._make_layers(cfgs)

    def forward(self, x):
        sizes = _size_map(x, self.height)
        maps = []  # storage for maps

        # side saliency map
        def unet(x, height=1):
            if height < 6:
                x1 = getattr(self, f'stage{height}')(x)
                x2 = unet(getattr(self, 'downsample')(x1), height + 1)
                x = getattr(self, f'stage{height}d')(torch.cat((x2, x1), 1))
                side(x, height)
                return _upsample_like(x, sizes[height - 1]) if height > 1 else x
            else:
                x = getattr(self, f'stage{height}')(x)
                side(x, height)
                return _upsample_like(x, sizes[height - 1])

        def side(x, h):
            # side output saliency map (before sigmoid)
            x = getattr(self, f'side{h}')(x)
            x = _upsample_like(x, sizes[1])
            maps.append(x)

        def fuse():
            # fuse saliency probability maps
            maps.reverse()
            x = torch.cat(maps, 1)
            x = getattr(self, 'outconv')(x)
            maps.insert(0, x)
            return [torch.sigmoid(x) for x in maps]

        unet(x)
        maps = fuse()
        return maps

    def _make_layers(self, cfgs):
        self.height = int((len(cfgs) + 1) / 2)
        self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
        for k, v in cfgs.items():
            # build rsu block
            self.add_module(k, RSU(v[0], *v[1]))
            if v[2] > 0:
                # build side layer
                self.add_module(f'side{v[0][-1]}', nn.Conv2d(v[2], self.out_ch, 3, padding=1))
        # build fuse layer
        self.add_module('outconv', nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1))


def U2NET_full():
    full = {
        # cfgs for building RSUs and sides
        # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
        'stage1': ['En_1', (7, 3, 32, 64), -1],
        'stage2': ['En_2', (6, 64, 32, 128), -1],
        'stage3': ['En_3', (5, 128, 64, 256), -1],
        'stage4': ['En_4', (4, 256, 128, 512), -1],
        'stage5': ['En_5', (4, 512, 256, 512, True), -1],
        'stage6': ['En_6', (4, 512, 256, 512, True), 512],
        'stage5d': ['De_5', (4, 1024, 256, 512, True), 512],
        'stage4d': ['De_4', (4, 1024, 128, 256), 256],
        'stage3d': ['De_3', (5, 512, 64, 128), 128],
        'stage2d': ['De_2', (6, 256, 32, 64), 64],
        'stage1d': ['De_1', (7, 128, 16, 64), 64],
    }
    return U2NET(cfgs=full, out_ch=3)


def U2NET_lite():
    lite = {
        # cfgs for building RSUs and sides
        # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
        'stage1': ['En_1', (7, 3, 16, 64), -1],
        'stage2': ['En_2', (6, 64, 16, 64), -1],
        'stage3': ['En_3', (5, 64, 16, 64), -1],
        'stage4': ['En_4', (4, 64, 16, 64), -1],
        'stage5': ['En_5', (4, 64, 16, 64, True), -1],
        'stage6': ['En_6', (4, 64, 16, 64, True), 64],
        'stage5d': ['De_5', (4, 128, 16, 64, True), 64],
        'stage4d': ['De_4', (4, 128, 16, 64), 64],
        'stage3d': ['De_3', (5, 128, 16, 64), 64],
        'stage2d': ['De_2', (6, 128, 16, 64), 64],
        'stage1d': ['De_1', (7, 128, 16, 64), 64],
    }
    return U2NET(cfgs=lite, out_ch=3)
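As a quick sanity check (a sketch, not part of the original pipeline), one forward pass confirms that a 6-stage configuration returns 7 maps — the fused map first, then the six side outputs — each at the input resolution:
In [ ]:
with torch.no_grad():
    maps = U2NET_lite()(torch.randn(1, 3, 512, 512))
print(len(maps), maps[0].shape)  # expected: 7, torch.Size([1, 3, 512, 512])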
Training model¶
In [ ]:
model = U2NET_lite()
In [ ]:
optimizer = optim.Adam(model.parameters(), lr=0.005)
In [ ]:
def train(model, train_ds, val_ds, optimizer, epochs_no=50, patience=5, batch_size=8):
    history = {"train_loss": [], "val_loss": []}
    cooldown = 0
    steps_train = int(len(train_ds) / batch_size)
    steps_val = int(len(val_ds) / batch_size)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=8)
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=True, num_workers=8)
    for epoch in tqdm(range(epochs_no)):
        model.train()
        epoch_train_loss = 0
        epoch_val_loss = 0
        for img, target in train_loader:
            img, target = img.to(device), target.to(device)
            pred = model(img)[0]  # train on the fused output map
            loss = ssim_loss(target, pred, 5)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.detach()  # detach so each step's graph can be freed
        with torch.no_grad():
            model.eval()
            for img, target in val_loader:
                img, target = img.to(device), target.to(device)
                pred = model(img)[0]
                epoch_val_loss += ssim_loss(target, pred, 5)
        epoch_train_loss /= steps_train
        epoch_val_loss /= steps_val
        history["train_loss"].append(epoch_train_loss.cpu().detach().numpy())
        history["val_loss"].append(epoch_val_loss.cpu().detach().numpy())
        print("EPOCH: {}/{}".format(epoch + 1, epochs_no))
        print("Train loss: {:.6f}, Validation loss: {:.6f}".format(epoch_train_loss, epoch_val_loss))
        # early stopping: count consecutive epochs where the validation loss got worse
        if epoch != 0 and history["val_loss"][epoch] > history["val_loss"][epoch - 1]:
            cooldown += 1
            if cooldown == patience:
                break
        else:
            cooldown = 0
        # checkpoint the best model seen so far
        if epoch == 0 or history["val_loss"][epoch] <= min_val_loss:
            min_val_loss = history["val_loss"][epoch]
            torch.save(model.state_dict(), 'model_weights.pth')
    model.load_state_dict(torch.load('model_weights.pth'))
    return model
In [ ]:
model = train(model, train_ds, val_ds, optimizer)
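Since kornia's ssim_loss computes DSSIM = (1 - SSIM) / 2, a validation loss l corresponds to an SSIM score of 1 - 2l. A tiny helper to convert (loss_to_ssim is a name introduced here purely for illustration):
In [ ]:
def loss_to_ssim(dssim):
    # invert DSSIM = (1 - SSIM) / 2
    return 1.0 - 2.0 * dssim

print(loss_to_ssim(0.101))  # e.g. a validation loss of 0.101 -> SSIM of 0.798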
Generating Prediction Files¶
In [ ]:
!mkdir original
In [ ]:
def get_test_transform():
    transforms = []
    transforms.append(T.ToTensor())
    return T.Compose(transforms)
In [ ]:
test_ds = TestBlurDataset(test_images, get_test_transform())
In [ ]:
test_loader = torch.utils.data.DataLoader(
test_ds, batch_size=1, shuffle=False, num_workers=8)
In [ ]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.eval()
In [ ]:
solution_directory = "original"
to_pil = T.ToPILImage()
def save_result(img, name):
    img = torch.squeeze(img)  # drop the batch dimension
    img = to_pil(img)
    img_path = os.path.join(solution_directory, name)
    img.save(img_path)
In [ ]:
for img, (name,) in tqdm(test_loader):
    img = img.to(device)
    res = model(img)[0]  # fused output map
    save_result(res, name)
In [ ]:
def show_result(img_name):
    blur = cv.imread("data/test/blur/" + img_name)
    blur = cv.cvtColor(blur, cv.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
    original = cv.imread("original/" + img_name)
    original = cv.cvtColor(original, cv.COLOR_BGR2RGB)
    fig = plt.figure(figsize=(10, 20))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    ax1.axis('off')
    ax2.axis('off')
    ax1.imshow(blur)
    ax2.imshow(original)
    ax1.title.set_text('Blurred image')
    ax2.title.set_text('Reconstructed')
    plt.show()
In [24]:
show_result("01d44.jpg")
Submitting our Predictions¶
Note: Please save the notebook before submitting it (Ctrl + S)
In [ ]:
%aicrowd notebook submit -c face-de-blurring -a original --no-verify