MABe 2022: Mouse-Triplets - Video Data
Unsupervised model - SimCLR - Mouse Video Data
Unsupervised model training using contrastive learning with modified SimCLR
How to use this notebook 📝
- Copy the notebook. This is a shared template and any edits you make here will not be saved. You should copy it into your own drive folder. For this, click the "File" menu (top-left), then "Save a Copy in Drive". You can edit your copy however you like.
- Link it to your AIcrowd account. In order to submit your predictions to AIcrowd, you need to provide your account's API key.
Setup AIcrowd Utilities 🛠
!pip install -U aicrowd-cli
%load_ext aicrowd.magic
Login to AIcrowd
%aicrowd login
Install packages 🗃
Please add all package installations in this section
!pip install torch==1.10.2 torchvision==0.11.3 simclr
Import necessary modules and packages 📚
import os
import cv2
import numpy as np
from tqdm.auto import tqdm
import torch
import torchvision
import torchvision.transforms as T
from simclr import SimCLR
from simclr.modules import NT_Xent
from simclr.modules import LARS
Download and prepare the dataset 🔍
aicrowd_challenge_name = "mabe-2022-mouse-triplets-video-data"
if not os.path.exists('data'):
    os.mkdir('data')
datafolder = 'data/'
## If the data is already downloaded and stored on Google Drive, skip the download and point to the prepared directory
# datafolder = '/content/drive/MyDrive/mabe-2022-mouse-video/data/'
video_folder = f'{datafolder}video_clips/'
## The download might take a while; we recommend copying the data to Google Drive if you want to run this notebook multiple times.
%aicrowd ds dl -c {aicrowd_challenge_name} -o data *.npy* # Download all files
# We'll download the 224x224 videos since they load quickly in the dataloader, but you can use the full-size videos if you want
%aicrowd ds dl -c {aicrowd_challenge_name} -o data *resized_224* # Download the 224x224 videos
# %aicrowd ds dl -c {aicrowd_challenge_name} -o data *videos.zip* # Download the 512x512 videos
!unzip -q data/submission_videos_resized_224.zip -d {video_folder}
!unzip -q data/userTrain_videos_resized_224.zip -d {video_folder}
## Be careful when running the commands below: they delete the zip files and copy the data to Google Drive
# !rm data/submission_videos.zip data/userTrain_videos.zip
# !cp -r data/ '/content/drive/MyDrive/mabe-2022-mouse-video/data/'
Train Unsupervised Baseline - SimCLR 🏋️
We use contrastive learning as the baseline for the MABe video dataset. The code uses SimCLR (A Simple Framework for Contrastive Learning of Visual Representations, https://arxiv.org/abs/2002.05709), a popular and "simple" contrastive learning algorithm.
We modify SimCLR to exploit two properties of this dataset: frame stacking and cropping around the animal subjects.
We use a ResNet50 encoder with the unofficial PyTorch simclr package (https://github.com/sthalles/SimCLR) to do unsupervised learning on the video data.
We also stack past and future frames, with frame skipping, so that each embedding incorporates temporal information.
Dataloader 📚
The dataloader reads the video files and seeks the required past and future frames, skipping a fixed number of frames between them.
Additionally, it crops the images around the animals to keep them in focus.
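The index arithmetic behind frame stacking can be sketched in plain Python. This is an illustration, not the notebook's code: `stacked_frame_indices` is a hypothetical helper, and it assumes `frame_skip` counts the frames skipped between two consecutive stacked frames (a stride of `frame_skip + 1`), so the centre frame always lands in the middle channel. The `num_frames=1800` value below is likewise only illustrative.

```python
def stacked_frame_indices(center, num_prev, num_next, frame_skip, num_frames):
    """Return the frame indices stacked around `center` (hypothetical helper).

    Assumption: `frame_skip` frames are skipped between consecutive stacked
    frames, so the stride is frame_skip + 1. Indices outside [0, num_frames)
    are returned as None; a dataset would leave those channels as zeros.
    """
    stride = frame_skip + 1
    indices = []
    for offset in range(-num_prev, num_next + 1):
        fnum = center + offset * stride
        indices.append(fnum if 0 <= fnum < num_frames else None)
    return indices


# 3 past frames + current frame + 3 future frames, skipping 5 frames between each
print(stacked_frame_indices(100, 3, 3, 5, num_frames=1800))
# [82, 88, 94, 100, 106, 112, 118]
print(stacked_frame_indices(10, 3, 3, 5, num_frames=1800))
# [None, None, 4, 10, 16, 22, 28]
```

The channel count is always `num_prev + num_next + 1` (7 with the settings used later), regardless of how many indices fall outside the clip.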
class MouseVideoDataset(torch.utils.data.Dataset):
    """
    Reads stacks of grayscale frames from video files
    """
    def __init__(self,
                 datafolder,
                 frame_number_map,
                 keypoints,
                 frame_skip,
                 num_prev_frames,
                 num_next_frames,
                 frame_size=(224, 224),
                 transform=None):
        """
        Initialize the dataset with the video folder, frame number map and keypoints
        """
        self.datafolder = datafolder
        self.transform = transform
        self.frame_number_map = frame_number_map
        self.num_prev_frames = num_prev_frames
        self.num_next_frames = num_next_frames
        self.frame_skip = frame_skip
        self.frame_size = frame_size
        self.keypoints = keypoints
        self._setup_frame_map()

    def set_transform(self, transform):
        self.transform = transform

    def _setup_frame_map(self):
        self._video_names = np.array(list(self.frame_number_map.keys()))
        # IMPORTANT: the frame number map should be sorted for self.get_frame_info to work
        frame_nums = np.array([self.frame_number_map[k] for k in self._video_names])
        self._frame_numbers = frame_nums[:, 0] - 1  # start values
        assert np.all(np.diff(self._frame_numbers) > 0), "Frame number map is not sorted"
        self.length = frame_nums[-1, 1]  # last value is the total number of frames

    def get_frame_info(self, global_index):
        """ Returns corresponding video name and frame number"""
        video_idx = np.searchsorted(self._frame_numbers, global_index) - 1
        frame_index = global_index - (self._frame_numbers[video_idx] + 1)
        return self._video_names[video_idx], frame_index

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        video_name, frame_index = self.get_frame_info(idx)
        video_path = os.path.join(self.datafolder, video_name + '.avi')
        nf = self.num_next_frames + self.num_prev_frames + 1
        frames_array = np.zeros((*self.frame_size, nf), dtype=np.float32)
        if not os.path.exists(video_path):
            # Missing video: return an all-zero frame stack instead of raising
            # raise FileNotFoundError(video_path)
            if self.transform is not None:
                frames_array = self.transform(frames_array)
            return {"idx": idx,
                    "image": frames_array,
                    }
        cap = cv2.VideoCapture(video_path)
        num_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Stacked frames are frame_skip + 1 frames apart and centred on frame_index,
        # so the range yields exactly nf indices, including frame_index itself
        stride = self.frame_skip + 1
        for arridx, fnum in enumerate(range(frame_index - self.num_prev_frames * stride,
                                            frame_index + self.num_next_frames * stride + 1,
                                            stride)):
            if fnum < 0 or fnum >= num_video_frames:
                continue  # outside the clip: leave this channel as zeros
            cap.set(cv2.CAP_PROP_POS_FRAMES, fnum)
            success, frame = cap.read()
            if success:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                # Scale to [0, 1]: T.ToTensor does not rescale float32 arrays
                frames_array[:, :, arridx] = frame / 255.0
        cap.release()
        if video_name in self.keypoints['sequences']:
            bbox = self.keypoints['sequences'][video_name]['bbox']
            if bbox.shape[0] > frame_index:
                bbox = bbox[frame_index]
                frames_array = frames_array[bbox[0]:bbox[2], bbox[1]:bbox[3]]  # Crop the image so random crop is more useful
        if self.transform is not None:
            frames_array = self.transform(frames_array)
        return {"idx": idx,
                "image": frames_array,
                }
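`get_frame_info` treats all videos as one long sequence with a single global frame index. A pure-Python sketch of that lookup follows, assuming each `frame_number_map` value is a `(start, end)` pair with `end` exclusive (which is what `self.length = frame_nums[-1, 1]` implies). The helper names and the toy map are hypothetical.

```python
import bisect


def build_index(frame_number_map):
    """Sort videos by start frame; return (names, starts, total_length)."""
    items = sorted(frame_number_map.items(), key=lambda kv: kv[1][0])
    names = [name for name, _ in items]
    starts = [span[0] for _, span in items]
    total_length = items[-1][1][1]  # exclusive end of the last video
    return names, starts, total_length


def global_to_local(global_index, names, starts):
    """Map a global frame index to (video_name, frame_index_in_video)."""
    # bisect_right finds the last video whose start is <= global_index
    video_idx = bisect.bisect_right(starts, global_index) - 1
    return names[video_idx], global_index - starts[video_idx]


# Toy map with hypothetical names: 'clip_a' has 10 frames, 'clip_b' has 15
fnm = {"clip_b": (10, 25), "clip_a": (0, 10)}
names, starts, total = build_index(fnm)
print(total)                               # 25
print(global_to_local(9, names, starts))   # ('clip_a', 9)
print(global_to_local(10, names, starts))  # ('clip_b', 0)
```

Sorting the entries first removes the ordering requirement that the dataset enforces with its assertion.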
Utilities - Optimizer, Transforms and Augmentations 🔧
def load_optimizer(optimizer, epochs, weight_decay, batch_size, model):
    scheduler = None
    if optimizer == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # TODO: LARS
    elif optimizer == "LARS":
        # optimized using LARS with linear learning rate scaling
        # (i.e. LearningRate = 0.3 × BatchSize/256) and weight decay of 1e-6.
        learning_rate = 0.3 * batch_size / 256
        optimizer = LARS(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            exclude_from_weight_decay=["batch_normalization", "bias"],
        )
        # "decay the learning rate with the cosine decay schedule without restarts"
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, epochs, eta_min=0, last_epoch=-1
        )
    else:
        raise NotImplementedError
    return optimizer, scheduler
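The LARS branch combines two schedules: linear scaling of the base learning rate with batch size, and a per-epoch cosine decay. The sketch below just evaluates both, using the 0.3 × BatchSize/256 rule quoted in `load_optimizer` and the closed form of cosine annealing without restarts given in the PyTorch `CosineAnnealingLR` documentation. The helper names are hypothetical. Note that the notebook's default `optimizer_type = 'Adam'` uses neither.

```python
import math


def lars_learning_rate(batch_size, base_lr=0.3):
    """Linear scaling rule quoted in load_optimizer: lr = 0.3 * batch_size / 256."""
    return base_lr * batch_size / 256


def cosine_lr(epoch, total_epochs, lr_max, lr_min=0.0):
    """Closed form of cosine annealing without restarts, stepped once per epoch."""
    return lr_min + (lr_max - lr_min) * (1 + math.cos(math.pi * epoch / total_epochs)) / 2


lr0 = lars_learning_rate(32)  # batch_size = 32 as in the config below
print(lr0)  # 0.0375
for epoch in (0, 5, 10):
    print(epoch, round(cosine_lr(epoch, 10, lr0), 6))
# 0 0.0375
# 5 0.01875
# 10 0.0
```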
def save_model(epoch, model_path, model, optimizer):
    out = os.path.join(model_path, "checkpoint_{}.tar".format(epoch))
    # To save a DataParallel model generically, save the model.module.state_dict().
    # This way, you have the flexibility to load the model any way you want to any device you want.
    if isinstance(model, torch.nn.DataParallel):
        torch.save(model.module.state_dict(), out)
    else:
        torch.save(model.state_dict(), out)
class TransformsSimCLR:
    def __init__(self, size, pretrained=True, n_channel=3, validation=False) -> None:
        self.train_transforms = T.Compose([
            T.ToTensor(),
            T.RandomResizedCrop(size=size, scale=(0.25, 1.0)),
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            # Taking the means of the normal distributions of the 3 channels
            # since we are moving to grayscale
            T.Normalize(mean=np.mean([0.485, 0.456, 0.406]).repeat(n_channel),
                        std=np.sqrt(
                            (np.array([0.229, 0.224, 0.225])**2).sum()/9).repeat(n_channel)
                        ) if pretrained is True else T.Lambda(lambda x: x)
        ])
        self.validation_transforms = T.Compose([
            T.ToTensor(),
            T.Resize(size=size),
            # Taking the means of the normal distributions of the 3 channels
            # since we are moving to grayscale
            T.Normalize(mean=np.mean([0.485, 0.456, 0.406]).repeat(n_channel),
                        std=np.sqrt(
                            (np.array([0.229, 0.224, 0.225])**2).sum()/9).repeat(n_channel)
                        ) if pretrained is True else T.Lambda(lambda x: x)
        ])
        self.validation = validation

    def __call__(self, x):
        # Training: two independently augmented views of x form a positive pair
        if not self.validation:
            return self.train_transforms(x), self.train_transforms(x)
        else:
            return self.validation_transforms(x)
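The grayscale normalization constants in `TransformsSimCLR` follow from treating a gray pixel as the average of the three ImageNet channels. Like the notebook, this assumes the channels are independent: the mean of the average is the average of the means, and Var((R + G + B) / 3) = (σR² + σG² + σB²) / 9. The ImageNet statistics apply to pixel values scaled to [0, 1]. A worked check:

```python
import math

# ImageNet per-channel statistics (R, G, B), as used in TransformsSimCLR
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]

# Average of three channels, assuming independent channels
gray_mean = sum(means) / 3
gray_std = math.sqrt(sum(s ** 2 for s in stds) / 9)

print(round(gray_mean, 4))  # 0.449
print(round(gray_std, 4))   # 0.1305
```

The same scalar pair is repeated `n_channel` times because every stacked frame is a grayscale image.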
Bounding box creation 📦
Since most of each frame is empty, it is important to crop around the mouse triplets before applying the SimCLR augmentations. We use the keypoints to create rough bounding boxes around them.
Note that these bounding boxes use a simple fixed-padding scheme (50 pixels around the keypoints); feel free to change how they are generated.
######## Prepare bounding boxes from keypoints ##########
# Preparing some bounding box information to be used for cropping frames during training
keypoints = np.load(os.path.join(datafolder, 'submission_keypoints.npy'), allow_pickle=True).item()
padbbox = 50
crop_size = 512
for sk in tqdm(keypoints['sequences'].keys()):
    kp = keypoints['sequences'][sk]['keypoints']
    bboxes = []
    for frame_idx in range(len(kp)):
        allcoords = np.int32(kp[frame_idx].reshape(-1, 2))
        minvals = max(np.min(allcoords[:, 0]) - padbbox, 0), max(np.min(allcoords[:, 1]) - padbbox, 0)
        maxvals = min(np.max(allcoords[:, 0]) + padbbox, crop_size), min(np.max(allcoords[:, 1]) + padbbox, crop_size)
        bbox = (*minvals, *maxvals)
        bbox = np.array(bbox)
        bbox = np.int32(bbox * 224 / 512)  # rescale from 512x512 keypoint coordinates to the 224x224 videos
        bboxes.append(bbox)
    keypoints['sequences'][sk]['bbox'] = np.array(bboxes)
# You can save it if you want and load it later
# np.save(os.path.join(datafolder, 'submission_keypoints_bbox.npy'), keypoints)
# keypoints = np.load(os.path.join(datafolder, 'submission_keypoints_bbox.npy'), allow_pickle=True).item()
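The per-frame loop above can also be written as a single vectorized NumPy pass over all frames of a sequence. This is a sketch: `keypoints_to_bboxes` is a hypothetical helper that reproduces the loop's padding, clipping to the 512-pixel frame, rescaling to 224 pixels and truncation to int32. Like the loop, it assumes the keypoints contain no NaNs.

```python
import numpy as np


def keypoints_to_bboxes(kp, pad=50, frame_size=512, target_size=224):
    """Vectorized bounding boxes (hypothetical helper mirroring the loop above).

    kp: array of shape (num_frames, ...) whose values are coordinate pairs;
    each frame is flattened to (-1, 2). Returns an int32 array of shape
    (num_frames, 4): (min0, min1, max0, max1) in target_size pixels.
    """
    coords = np.asarray(kp).reshape(len(kp), -1, 2).astype(np.int32)
    mins = np.maximum(coords.min(axis=1) - pad, 0)           # (num_frames, 2)
    maxs = np.minimum(coords.max(axis=1) + pad, frame_size)  # (num_frames, 2)
    bboxes = np.concatenate([mins, maxs], axis=1)            # (num_frames, 4)
    # Rescale to the resized videos; astype truncates like np.int32 in the loop
    return (bboxes * target_size / frame_size).astype(np.int32)


# Two toy frames with two (x, y) points each; the second frame touches the border
kp = np.array([[[100, 200], [150, 250]],
               [[10, 20], [500, 505]]])
print(keypoints_to_bboxes(kp))
# [[ 21  65  87 131]
#  [  0   0 224 224]]
```

With this helper, the loop body would reduce to `keypoints['sequences'][sk]['bbox'] = keypoints_to_bboxes(keypoints['sequences'][sk]['keypoints'])`.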
Training ☑️
Below are hyperparameters you can play around with. The runs are quite slow, so you can reduce the number of epochs and steps per epoch to find the parameters you want to use.
Note that an "epoch" here does not cover the entire dataset, because the whole dataset is huge.
This code uses only the submission videos for unsupervised training, but you can change it to use all the videos.
################### CONFIG #########################
IS_PRETRAINED = True
batch_size = 32
epochs = 10
# Stack frames with frame skip from the video sequences
LEFT_WINDOW = 3
RIGHT_WINDOW = 3
IMG_SIZE = 224
FRAME_SKIP = 5
n_channel = LEFT_WINDOW + RIGHT_WINDOW + 1
# If you change this, check which batch size still fits in memory
embedding_size = 128
# The full dataset is huge, so here we limit the number of steps per epoch
steps_per_epoch = 1000
videos_folder = os.path.join(datafolder, 'video_clips')  # TODO: Change this to combined folder
frame_number_map = np.load(os.path.join(datafolder, 'frame_number_map.npy'), allow_pickle=True).item()
checkpoint_folder = "mouse_video_checkpoints/"  # Can change this to a Google Drive folder
if not os.path.exists(checkpoint_folder):
    os.mkdir(checkpoint_folder)
train_dataset = MouseVideoDataset(datafolder=videos_folder,
frame_number_map=frame_number_map,
keypoints=keypoints,
frame_skip=FRAME_SKIP,
num_prev_frames=LEFT_WINDOW,
num_next_frames=RIGHT_WINDOW,
frame_size=(224, 224),
transform=TransformsSimCLR(size=(IMG_SIZE, IMG_SIZE),
pretrained=IS_PRETRAINED,
n_channel=n_channel))
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True,
pin_memory=True,
num_workers=2,
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
##################### MODEL ############################
def get_simclr_model():
    resnet_encoder = torchvision.models.resnet50(pretrained=IS_PRETRAINED)
    ## Experimental setup: adapt the pretrained RGB conv1 to n_channel stacked grayscale frames
    ## https://stackoverflow.com/a/54777347
    weight = resnet_encoder.conv1.weight.clone()
    resnet_encoder.conv1 = torch.nn.Conv2d(n_channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Sum the RGB kernels, tile them over n_channel inputs, then normalize back by n_channel
    resnet_encoder.conv1.weight.data = weight.sum(dim=1, keepdim=True).tile(1, n_channel, 1, 1)/n_channel
    n_features = resnet_encoder.fc.in_features
    model = SimCLR(resnet_encoder, embedding_size, n_features)
    model = model.to(device)
    return model
model = get_simclr_model()
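The conv1 re-initialization has a useful property. If all `n_channel` stacked frames were the same grayscale image g, the new first layer would produce exactly the response of the pretrained RGB layer on (g, g, g). The NumPy sketch below checks this at a single output location, where a convolution reduces to a dot product over (channel, kernel height, kernel width). Random weights stand in for the pretrained kernel, which is an assumption for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
n_channel = 7                           # LEFT_WINDOW + RIGHT_WINDOW + 1
w_rgb = rng.normal(size=(64, 3, 7, 7))  # stand-in for a pretrained conv1 kernel

# Same recipe as get_simclr_model: sum over RGB, tile to n_channel, divide by n_channel
w_stacked = np.tile(w_rgb.sum(axis=1, keepdims=True), (1, n_channel, 1, 1)) / n_channel

# One 7x7 grayscale patch, fed once as RGB (g, g, g) and once as n_channel identical frames
g = rng.normal(size=(7, 7))
x_rgb = np.broadcast_to(g, (3, 7, 7))
x_stacked = np.broadcast_to(g, (n_channel, 7, 7))

# A convolution at one output location is a dot product over (channel, kh, kw)
out_rgb = np.einsum("ocij,cij->o", w_rgb, x_rgb)
out_stacked = np.einsum("ocij,cij->o", w_stacked, x_stacked)
print(np.allclose(out_rgb, out_stacked))  # True
```

When the frames differ over time, the layer starts out as a per-pixel average of the stacked frames fed through the pretrained grayscale response, and training can then specialize the per-frame weights.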
##################### UTILS ############################
optimizer_type = 'Adam'
weight_decay = 1e-6
optimizer, scheduler = load_optimizer(optimizer_type, epochs, weight_decay, batch_size, model)
world_size = 1
temperature = 0.5
criterion = NT_Xent(batch_size, temperature, world_size)
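`NT_Xent` is the normalized temperature-scaled cross-entropy loss from the SimCLR paper. Each of the 2N projected views in a batch must identify its augmented partner among the other 2N - 1 views, using cosine similarity divided by `temperature`. The NumPy sketch below follows the paper's definition and averages over all 2N views. It is not the `simclr` package's implementation, which may differ in details.

```python
import numpy as np


def nt_xent(z_i, z_j, temperature=0.5):
    """NT-Xent loss for N positive pairs (z_i[k], z_j[k]); z_i, z_j have shape (N, d)."""
    n = z_i.shape[0]
    z = np.concatenate([z_i, z_j], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit vectors -> cosine similarity
    sim = z @ z.T / temperature                        # (2N, 2N) logits
    np.fill_diagonal(sim, -np.inf)                     # a view is never its own candidate
    partner = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    # Numerically stable log-softmax of each row, evaluated at the partner column
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    log_prob = sim[np.arange(2 * n), partner] - log_denom
    return -log_prob.mean()


rng = np.random.default_rng(0)
views = rng.normal(size=(8, 16))
noisy_views = views + 0.01 * rng.normal(size=views.shape)
unrelated = rng.normal(size=(8, 16))
print(nt_xent(views, noisy_views))  # matching views: lower loss
print(nt_xent(views, unrelated))    # unrelated views: higher loss
```

Lowering `temperature` sharpens the softmax and weights hard negatives more heavily.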
# Basic training loop
def train(epoch, train_loader, model, criterion, optimizer):
    model.train()
    loss_epoch = 0
    # The full train loader is huge, so we stop after steps_per_epoch batches
    tqdm_iter = tqdm(train_loader, total=steps_per_epoch)
    tqdm_iter.set_description(f"Epoch {epoch}")
    for step, batch in enumerate(tqdm_iter):
        optimizer.zero_grad()
        # The two augmented views of each frame stack form a positive pair
        x_i = batch['image'][0].to(device, non_blocking=True)
        x_j = batch['image'][1].to(device, non_blocking=True)
        # positive pair, with encoding
        h_i, h_j, z_i, z_j = model(x_i, x_j)
        loss = criterion(z_i, z_j)
        loss.backward()
        optimizer.step()
        tqdm_iter.set_postfix(iter_loss=loss.item())
        loss_epoch += loss.item()
        if step + 1 >= steps_per_epoch:
            break
    return loss_epoch
# The baseline submission on the leaderboard was trained for 100 epochs; adjust this to your needs
for epoch in range(epochs):
    lr = optimizer.param_groups[0]['lr']
    loss_epoch = train(epoch, train_loader, model, criterion, optimizer)
    print(f"Epoch {epoch}: mean loss {loss_epoch / steps_per_epoch:.4f}, lr {lr:.2e}")
    if scheduler:
        scheduler.step()
    if (epoch % 3) == 0:
        save_model(epoch, checkpoint_folder, model, optimizer)
save_model(epochs, checkpoint_folder, model, optimizer)
# Cleanup RAM
del model, optimizer
del train_loader, train_dataset
Predict Embeddings 🔮
Here we predict an embedding for every frame; this may take a long time.
# Load the final checkpoint, which the training loop saves as checkpoint_{epochs}.tar
model = get_simclr_model()
checkpoint_path = os.path.join(checkpoint_folder, f'checkpoint_{epochs}.tar')
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)
model.eval();
prediction_dataset = MouseVideoDataset(datafolder=videos_folder,
frame_number_map=frame_number_map,
keypoints=keypoints,
frame_skip=FRAME_SKIP,
num_prev_frames=LEFT_WINDOW,
num_next_frames=RIGHT_WINDOW,
frame_size=(224, 224),
transform=TransformsSimCLR(size=(IMG_SIZE, IMG_SIZE),
pretrained=IS_PRETRAINED,
n_channel=n_channel,
validation=True))
prediction_batch_size=128
prediction_loader = torch.utils.data.DataLoader(
prediction_dataset,
batch_size=prediction_batch_size,
shuffle=False,
drop_last=False,
pin_memory=True,
num_workers=2,
)
sample_submission = np.load(datafolder + 'sample_submission.npy')
submission = np.empty((sample_submission.shape[0], embedding_size), dtype=np.float32)
idx = 0
# This may take a long time, since it predicts on every frame
for data in tqdm(prediction_loader, total=len(prediction_loader)):
    with torch.no_grad():
        images = data['image'].to(device)
        # The projector maps the 2048-d encoder features to embedding_size (128) dimensions
        output = model.projector(model.encoder(images))
    output = output.cpu().numpy()
    submission[idx:idx+len(output)] = output
    idx += len(output)
Submission 🚀
print("Embedding shape:", submission.shape)
Validate the submission ✅
The submission should follow these constraints:
- It should be a numpy array
- The embeddings must be a 2D numpy array of dtype float32
- The embedding size shouldn't exceed 128
- The number of rows must match the total clip length given by the frame number map
- The embeddings must not contain NaN or infinity
You can use the helper function below to check these.
def validate_submission(submission, frame_number_map):
    if not isinstance(submission, np.ndarray):
        print("Embeddings should be a numpy array")
        return False
    elif not len(submission.shape) == 2:
        print("Embeddings should be 2D array")
        return False
    elif not submission.shape[1] <= 128:
        print("Embeddings too large, max allowed is 128")
        return False
    elif not submission.dtype == np.float32:
        print("Embeddings are not float32")
        return False
    total_clip_length = frame_number_map[list(frame_number_map.keys())[-1]][1]
    if not len(submission) == total_clip_length:
        print("Embeddings length doesn't match submission clips total length")
        return False
    if not np.isfinite(submission).all():
        print("Embeddings contain NaN or infinity")
        return False
    print("All checks passed")
    return True
validate_submission(submission, frame_number_map)
np.save('submission_mouse_simclr.npy', submission)
## Uploads may take time; you can also run aicrowd-cli on your local machine with the prepared submission file
%aicrowd submission create --description "Mouse SimCLR Baseline" -c {aicrowd_challenge_name} -f submission_mouse_simclr.npy