MABe 2022: Fruit Fly Groups
Getting Started - MABe Challenge 2022: Fruit Flies v 0.2.2kb
Explore the fly tracking dataset and make your first submission with a simple PCA embedding.
20220220: Fixed mistake in description of dimensionality limit, it is 256, not 100.
How to use this notebook 📝
- Copy the notebook. This is a shared template and any edits you make here will not be saved. You should copy it into your own drive folder. For this, click the "File" menu (top-left), then "Save a Copy in Drive". You can edit your copy however you like.
- Link it to your AIcrowd account. In order to submit your predictions to AIcrowd, you need to provide your account's API key.
Setup AIcrowd Utilities 🛠
!pip install -U aicrowd-cli
%load_ext aicrowd.magic
Login to AIcrowd ㊗
%aicrowd login
Install packages 🗃
Please add all package installations in this section.
!pip install scikit-learn
Import necessary modules and packages 📚
import os
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
Download the dataset 📲
aicrowd_challenge_name = "mabe-2022-fruit-fly-groups"
if not os.path.exists('data'):
    # %aicrowd ds dl -c {aicrowd_challenge_name} -o data # Download all files
    # %aicrowd ds dl -c {aicrowd_challenge_name} -o data *submission_data* # download only the submission keypoint data
    %aicrowd ds dl -c {aicrowd_challenge_name} -o data *user_train* # download only the training keypoint data
Load Data
user_train = np.load('data/user_train.npy',allow_pickle=True).item()
# sample_submission = np.load('data/sample_submission.npy',allow_pickle=True).item()
Dataset Specifications 💾
We provide frame-by-frame animal pose estimates extracted from top-view videos of 9-11 interacting flies filmed at 150Hz; raw videos will not be provided. Flies come from a variety of genetically engineered lines, and exhibit both naturally occurring and optogenetically or thermogenetically evoked behaviors.
The following files are available in the resources section. A "sequence" is a continuous recording of social interactions between animals; in the fly dataset, sequences are 30 seconds long (4500 frames at 150Hz). The sequence_id is a random hash that anonymizes experiment details. NaNs indicate missing data: videos contain between 9 and 11 flies, so the data are padded with NaNs to make every sequence the same size.
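Because absent flies are padded with NaNs, the number of flies actually present in a sequence can be recovered from the keypoints array itself. The sketch below assumes an array shaped like the `keypoints` entries described here, (frames, 11 fly slots, 24, 2); the helper name `count_real_flies` and the synthetic data are illustrative only.

```python
import numpy as np

FPS = 150  # frame rate stated for the fly dataset

def count_real_flies(keypoints):
    """Hypothetical helper: number of fly slots holding any non-NaN value in any frame.

    keypoints: array of shape (frames, fly_slots, 24, 2), NaN-padded.
    """
    # a slot is real in a frame if not every coordinate in it is NaN
    real_per_frame = ~np.all(np.isnan(keypoints), axis=(2, 3))  # (frames, fly_slots)
    return int(np.any(real_per_frame, axis=0).sum())

# synthetic sequence: 4500 frames, 11 slots, only the first 9 slots hold flies
rng = np.random.default_rng(0)
kp = np.full((4500, 11, 24, 2), np.nan)
kp[:, :9] = rng.normal(size=(4500, 9, 24, 2))

print(count_real_flies(kp))  # 9
print(kp.shape[0] / FPS)     # 30.0 (seconds per sequence)
```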
- user_train.npy: training set for the task, which follows this schema:
{
    "sequences" : {
        "<sequence_id>" : {
            "keypoints" : a ndarray of shape (4500, 11, 24, 2)
        }
    },
    "vocabulary" : a list of strings identifying sample classification tasks,
    "keypoint_vocabulary" : names of the 24 x 2 keypoints, list of pairs of strings
}
- Test set for the task, which follows this schema:
{
    "<sequence_id>" : {
        "keypoints" : a ndarray of shape (4500, 11, 24, 2)
    }
}
- sample_submission.npy: template for a sample submission for this task, which follows this schema:
{
    "frame_number_map" : {
        "<sequence_id-1>" : (start_frame_index, end_frame_index),
        "<sequence_id-2>" : (start_frame_index, end_frame_index),
        ...
        "<sequence_id-n>" : (start_frame_index, end_frame_index)
    },
    "embeddings" : [
        [0.321, 0.234, 0.186, 0.857, 0.482, 0.185], ...
        [0.184, 0.583, 0.475, 0.485, 0.275, 0.958], ...
    ]
}
In sample_submission, each key in the frame_number_map dictionary refers to the unique sequence id of a video in the test set. The item for each key is expected to be the start and end index for slicing the embeddings numpy array to get the corresponding embeddings. The embeddings array is a 2D ndarray of floats of size total_frames by X, where X is the dimension of your learned embedding (6 in the above example; the maximum permitted embedding dimension is 256), representing the embedded value of each frame in the sequence. total_frames is the sum of all the frames of the sequences; the array should be the concatenation of the embeddings of all the clips.
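A rough sketch of a simple PCA embedding and of how `frame_number_map` and `embeddings` fit together: each frame's keypoints for all fly slots are flattened into one feature vector, NaNs are replaced with 0, features are standardized and projected with PCA, and the per-sequence blocks are concatenated. The NaN-to-zero fill, the feature choice, the helper name `build_submission`, and the exclusive end index (as in Python slicing, `embeddings[start:end]`) are assumptions for illustration, not the official baseline.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

def build_submission(sequences, n_components=20):
    """Hypothetical helper: PCA embedding of raw keypoints plus a frame_number_map.

    sequences: dict mapping sequence_id -> {"keypoints": array (frames, flies, 24, 2)}
    """
    assert n_components <= 256  # challenge limit on the embedding dimension
    feats, frame_number_map, start = [], {}, 0
    for seq_id, seq in sequences.items():
        kp = seq["keypoints"]
        x = kp.reshape(kp.shape[0], -1)  # one row per frame
        x = np.nan_to_num(x, nan=0.0)    # assumption: missing flies filled with 0
        feats.append(x)
        end = start + x.shape[0]
        frame_number_map[seq_id] = (start, end)  # embeddings[start:end] -> this sequence
        start = end
    x_all = StandardScaler().fit_transform(np.concatenate(feats, axis=0))
    embeddings = PCA(n_components=n_components).fit_transform(x_all).astype(np.float32)
    return {"frame_number_map": frame_number_map, "embeddings": embeddings}

# tiny synthetic example: two sequences of 100 frames each
rng = np.random.default_rng(0)
demo = {f"seq{i}": {"keypoints": rng.normal(size=(100, 11, 24, 2))} for i in range(2)}
sub = build_submission(demo, n_components=6)
print(sub["embeddings"].shape)          # (200, 6)
print(sub["frame_number_map"]["seq1"])  # (100, 200)
```

A dictionary like this can be written with `np.save('submission.npy', sub)`; NumPy pickles it, so it is read back with `np.load(..., allow_pickle=True).item()`, the same way user_train.npy is loaded above.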
print("Dataset keys - ", user_train.keys())
print("Number of train data sequences - ", len(user_train['sequences']))
Sample overview
sequence_names = list(user_train["sequences"].keys())
sequence_key = sequence_names[0]
single_sequence = user_train["sequences"][sequence_key]
print("Sequence name - ", sequence_key)
print("Single Sequence shape ", single_sequence['keypoints'].shape)
print(f"Number of elements in {sequence_key} - ", len(single_sequence))
Data representation
Animal poses are characterized by the tracked locations of body parts on each animal, termed "keypoints." Keypoints are stored in an ndarray with the following properties:
- Dimensions: (# frames) × (animal ID) × (body part) × (x, y coordinate).
- Units: millimeters; coordinates are relative to the center of the circular arena (radius 26.689 mm) that contains the flies. Original image dimensions are 1024 × 1024 pixels (18.8825 pixels/mm) for the fly dataset.
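Since coordinates are in millimeters relative to the arena center, converting a length to pixels is a single multiplication by the stated scale. The pixel location of the arena center is not given here, so this sketch converts lengths only, not absolute positions. As a sanity check, the arena radius comes out at about 504 pixels, which fits inside half of the 1024-pixel frame. The helper names `mm_to_px` and `inside_arena` are hypothetical.

```python
import numpy as np

ARENA_RADIUS_MM = 26.689  # arena radius stated above
PX_PER_MM = 18.8825       # image scale stated above

def mm_to_px(d_mm):
    """Convert a length in millimeters to pixels (scale only, no offset)."""
    return d_mm * PX_PER_MM

def inside_arena(xy_mm):
    """True where an (x, y) point in mm lies within the arena; NaN points give False."""
    r = np.hypot(xy_mm[..., 0], xy_mm[..., 1])
    return r <= ARENA_RADIUS_MM

print(round(mm_to_px(ARENA_RADIUS_MM), 2))  # 503.96, less than 512 = 1024 / 2
print(inside_arena(np.array([[0.0, 0.0], [30.0, 0.0], [np.nan, np.nan]])))  # [ True False False]
```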
Body parts are ordered: 1) left wing tip, 2) right wing tip, 3) antennae midpoint, 4) right eye, 5) left eye, 6) left front of thorax, 7) right front of thorax, 8) base of thorax, 9) tip of abdomen, 10) right middle femur base, 11) right middle femur-tibia joint, 12) left middle femur base, 13) left middle femur-tibia joint, 14) right front leg tip, 15) right middle leg tip, 16) right rear leg tip, 17) left front leg tip, 18) left middle leg tip, 19) left rear leg tip.
The placement of these keypoints is illustrated below:
In addition, 10 other features are included in the keypoints array, increasing its dimension from 19 × 2 to 24 × 2. These features are: 20) Ellipse-fit center (x and y coordinates), 21) Ellipse-fit orientation (cosine and sine), 22) Ellipse-fit axis lengths (major and minor), 23) Area (body, foreground), and 24) Appearance (foreground/background contrast, minimum neighbor distance).
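The numbered list above is 1-based, while array indices are 0-based, so body parts occupy indices 0-18 and the extra features occupy 19-23 along the third axis. The sketch below slices them apart and recovers an orientation angle from the stored (cosine, sine) pair with `arctan2`. The constant names and `heading_angle` are hypothetical, and the reference direction and sign convention of the angle are not stated here, so treat the result as relative.

```python
import numpy as np

# 0-based indices along the 24-entry "body part" axis, derived from the 1-based list above
BODY_PARTS = slice(0, 19)  # entries 1-19: tracked keypoints
ELLIPSE_CENTER = 19        # entry 20: (x, y) of the ellipse-fit center
ELLIPSE_ORIENT = 20        # entry 21: (cos, sin) of the ellipse-fit orientation
ELLIPSE_AXES = 21          # entry 22: (major, minor) axis lengths

def heading_angle(keypoints):
    """Orientation in radians from the (cos, sin) feature; shape (frames, flies)."""
    c = keypoints[..., ELLIPSE_ORIENT, 0]
    s = keypoints[..., ELLIPSE_ORIENT, 1]
    return np.arctan2(s, c)

# synthetic frame with one fly whose orientation is 90 degrees
kp = np.zeros((1, 1, 24, 2))
kp[0, 0, ELLIPSE_ORIENT] = [np.cos(np.pi / 2), np.sin(np.pi / 2)]
print(kp[..., BODY_PARTS, :].shape)   # (1, 1, 19, 2)
print(np.degrees(heading_angle(kp)))  # approximately 90
```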
Helper function for visualization 💁
Useful functions for interacting with the fly tracking sequences
Don't forget to run the cell 😉
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm,animation,colors,rc
rc('animation', html='jshtml')
from tqdm import tqdm
# data frame rate
FPS = 150.
# radius (mm) of the circular arena the flies are enclosed in
ARENA_RADIUS_MM = 26.689
# hard-code indices of keypoints and skeleton edges
keypointidx = np.arange(18,dtype=int)
skeleton_edges = np.array([
[ 7, 8],
[10, 14],
[11, 12],
[12, 17],
[ 7, 11],
[ 9, 10],
[ 7, 9],
[ 5, 7],
[ 2, 3],
[ 2, 7],
[ 5, 18],
[ 6, 13],
[ 7, 16],
[ 7, 15],
[ 2, 4],
[ 6, 7],
[ 7, 0],
[ 7, 1]
])
# keypoints for computing distances between pairs of flies
fidxdist = np.array([2,7,8])
"""
d = get_fly_dists(x,tgt=0)
Compute the distance between fly tgt and all other flies. This is defined as the
minimum distance between any pair of the following keypoints:
'antennae','end_notum_x','end_abdomen', hard-coded with fidxdist
at middle frame data.ctrf
x: ndarray of size maxnflies x nkeypoints x 2 data sample, sequence of data for all flies
tgt: (optional) which fly to compute distances to. Default is 0.
d: Array of length nflies with the squared distance to the selected target.
"""
def get_fly_dists(x, tgt=0):
    nkpts = len(fidxdist)
    ntgts = x.shape[0]
    ndim = x.shape[2]
    # squared distances between every selected keypoint of every fly and every
    # selected keypoint of fly tgt, minimized over all keypoint pairs
    d = np.min(np.sum((x[:,fidxdist,:].reshape(ntgts,1,nkpts,ndim)-
                       x[tgt,fidxdist,:].reshape(1,nkpts,1,ndim))**2., axis=3),
               axis=(1,2))
    return d
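Because `get_fly_dists` returns squared distances, includes the target itself (distance 0), and yields NaN for padded fly slots, a nearest-neighbor lookup has to mask those out. The sketch below re-implements the same rule so it runs on its own; `fly_sq_dists` and `nearest_neighbor` are hypothetical helpers, and the keypoint indices 2, 7, 8 mirror `fidxdist` above.

```python
import numpy as np

FIDXDIST = np.array([2, 7, 8])  # same keypoints as fidxdist above

def fly_sq_dists(x, tgt=0):
    """Min squared distance between fly tgt and every fly over keypoint pairs in FIDXDIST.

    x: array (flies, nfeatures, 2) for a single frame.
    """
    k = len(FIDXDIST)
    n, _, ndim = x.shape
    a = x[:, FIDXDIST, :].reshape(n, 1, k, ndim)
    b = x[tgt, FIDXDIST, :].reshape(1, k, 1, ndim)
    return np.min(np.sum((a - b) ** 2, axis=3), axis=(1, 2))

def nearest_neighbor(x, tgt=0):
    """Index of the closest other fly and its Euclidean distance (same units as x)."""
    d = fly_sq_dists(x, tgt).astype(float)
    d[tgt] = np.inf            # exclude the target itself
    d[np.isnan(d)] = np.inf    # exclude NaN-padded fly slots
    nn = int(np.argmin(d))
    return nn, float(np.sqrt(d[nn]))

# three flies spaced along x, plus one NaN-padded slot
x = np.full((4, 24, 2), np.nan)
for i, offset in enumerate([0.0, 3.0, 10.0]):
    x[i] = 0.0
    x[i, :, 0] = offset
print(nearest_neighbor(x, tgt=0))  # (1, 3.0)
```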
"""
dark3cm = get_Dark3_cmap()
Returns a new matplotlib colormap based on the Dark2 colormap.
I didn't have quite enough unique colors in Dark2, so I made Dark3 which
is Dark2 followed by all Dark2 colors with the same hue and saturation but
half the brightness.
"""
def get_Dark3_cmap():
    # plt.get_cmap works on current matplotlib; cm.get_cmap was removed in 3.9
    dark2 = list(plt.get_cmap('Dark2').colors)
    dark3 = dark2.copy()
    for c in dark2:
        chsv = colors.rgb_to_hsv(c)
        chsv[2] = chsv[2]/2.  # halve the brightness (HSV value channel)
        crgb = colors.hsv_to_rgb(chsv)
        dark3.append(crgb)
    dark3cm = colors.ListedColormap(tuple(dark3))
    return dark3cm
"""
isreal = get_real_flies(x)
Returns which flies in the input ndarray x correspond to real data (are not nan).
x: ndarray of arbitrary dimensions, as long as the last two dimensions are nfeatures x 2,
and correspond to the keypoints and x,y coordinates.
"""
def get_real_flies(x):
    # x is ntgts x nfeatures x 2; a fly is real unless all of its values are nan
    isreal = np.all(np.isnan(x),axis=(-1,-2))==False
    return isreal
"""
fig,ax,isnewaxis = set_fig_ax(fig=None,ax=None)
Create new figure and/or axes if those are not input.
Returns the handles to those figures and axes.
isnewaxis is whether a new set of axes was created.
"""
def set_fig_ax(fig=None,ax=None):
    if ax is None:
        if fig is None:
            fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(111)
        isnewaxis = True
    else:
        isnewaxis = False
    return fig, ax, isnewaxis
"""
hkpt,hedge,fig,ax = plot_fly(pose=None, kptidx=None, skelidx=None,
    fig=None, ax=None, kptcolors=None, color=None, name=None,
    plotskel=True, plotkpts=True, hedge=None, hkpt=None)
Visualize the single fly position specified by pose
pose: Required. nfeatures x 2 ndarray.
kptidx: Optional. 1-dimensional array specifying which keypoints to plot. If None,
uses keypointidx. Default: None.
skelidx: Optional. nedges x 2 ndarray specifying which keypoints to connect with edges.
If None, uses skeleton_edges. Default: None.
fig: Optional. Handle to figure to plot in. Only used if ax is not specified. Default = None.
If None, a new figure is created.
ax: Optional. Handle to axes to plot in. Default = None. If None, new axes are created.
kptcolors: Optional. Color scheme for each keypoint. Can be a string defining a matplotlib
colormap (e.g. 'hsv'), a matplotlib colormap, or a single color. If None, it is set to 'hsv'.
Default: None
color: Optional. Color for edges plotted. If None, it is set to [.6,.6,.6]. Default: None.
name: Optional. String defining an identifying label for these plots. Default None.
plotskel: Optional. Whether to plot skeleton edges. Default: True.
plotkpts: Optional. Whether to plot key points. Default: True.
hedge: Optional. Handle of edges to update instead of plot new edges. Default: None.
hkpt: Optional. Handle of keypoints to update instead of plot new key points. Default: None.
"""
def plot_fly(pose=None, kptidx=None, skelidx=None, fig=None, ax=None, kptcolors=None, color=None, name=None,
             plotskel=True, plotkpts=True, hedge=None, hkpt=None):
    # x is nfeatures x 2
    assert(pose is not None)
    if kptidx is None:
        kptidx = keypointidx
    if skelidx is None:
        skelidx = skeleton_edges

    isnewaxis = False
    if ((hedge is None) and plotskel) or ((hkpt is None) and plotkpts):
        fig,ax,isnewaxis = set_fig_ax(fig=fig,ax=ax)
    isreal = get_real_flies(pose)

    if plotkpts:
        if isreal:
            xc = pose[kptidx,0]
            yc = pose[kptidx,1]
        else:
            # padded (all-nan) fly: nothing to draw
            xc = []
            yc = []
        if hkpt is None:
            if kptcolors is None:
                kptcolors = 'hsv'
            if (type(kptcolors) == list or type(kptcolors) == np.ndarray) and len(kptcolors) == 3:
                # a single RGB color: draw all keypoints as one Line2D
                kptname = 'keypoints'
                if name is not None:
                    kptname = name + ' ' + kptname
                hkpt = ax.plot(xc,yc,'.',color=kptcolors,label=kptname,zorder=10)[0]
            else:
                # a colormap: color each keypoint by its index
                if type(kptcolors) == str:
                    kptcolors = plt.get_cmap(kptcolors)
                hkpt = ax.scatter(xc,yc,c=np.arange(len(kptidx)),marker='.',cmap=kptcolors,zorder=10)
        else:
            # update the existing keypoint handle in place (used for animation)
            if type(hkpt) == matplotlib.lines.Line2D:
                hkpt.set_data(xc,yc)
            else:
                hkpt.set_offsets(np.column_stack((xc,yc)))

    if plotskel:
        nedges = skelidx.shape[0]
        if isreal:
            # a nan column after each edge breaks the line between consecutive edges
            xc = np.concatenate((pose[skelidx,0],np.zeros((nedges,1))+np.nan),axis=1)
            yc = np.concatenate((pose[skelidx,1],np.zeros((nedges,1))+np.nan),axis=1)
        else:
            xc = np.array([])
            yc = np.array([])
        if hedge is None:
            edgename = 'skeleton'
            if name is not None:
                edgename = name + ' ' + edgename
            if color is None:
                color = [.6,.6,.6]
            hedge = ax.plot(xc.flatten(),yc.flatten(),'-',color=color,label=edgename,zorder=0)[0]
        else:
            hedge.set_data(xc.flatten(),yc.flatten())

    if isnewaxis:
        ax.axis('equal')

    return hkpt,hedge,fig,ax
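`plot_fly` draws the whole skeleton with a single `Line2D` by appending a NaN column to the (nedges, 2) coordinate array and flattening it; matplotlib breaks a line wherever a coordinate is NaN, so each edge becomes its own segment. The pure-NumPy sketch below (hypothetical helper `edges_to_polyline`) reproduces the arrays that step builds for a toy three-keypoint pose.

```python
import numpy as np

def edges_to_polyline(pose, edges):
    """Flatten skeleton edges into x, y arrays with a NaN after each edge.

    pose: (nkeypoints, 2) array; edges: (nedges, 2) integer array of keypoint indices.
    """
    nedges = edges.shape[0]
    pad = np.full((nedges, 1), np.nan)
    xc = np.concatenate((pose[edges, 0], pad), axis=1).flatten()
    yc = np.concatenate((pose[edges, 1], pad), axis=1).flatten()
    return xc, yc

pose = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])
edges = np.array([[0, 1], [1, 2]])
xc, yc = edges_to_polyline(pose, edges)
# xc is 0, 1, nan, 1, 1, nan and yc is 0, 0, nan, 0, 1, nan:
# ax.plot(xc, yc) would draw two separate segments, 0->1 and 1->2
```

Updating one Line2D with `set_data` per frame is much cheaper than creating one artist per edge, which is why the animation reuses these handles.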
"""
hkpts,hedges,fig,ax = plot_flies(poses=None, fig=None, ax=None, colors=None, kptcolors=None,
    hedges=None, hkpts=None, **kwargs)
Visualize all flies for a single frame specified by poses.
poses: Required. nflies x nfeatures x 2 ndarray.
colors: Optional. Color scheme for edges plotted for each fly. Can be a string defining a matplotlib
colormap (e.g. 'hsv'), a matplotlib colormap, or a list/array with one color per fly. If None, it is set to the Dark3
colormap I defined in get_Dark3_cmap(). Default: None.
kptcolors: Optional. Color scheme for each keypoint. Can be a string defining a matplotlib
colormap (e.g. 'hsv'), a matplotlib colormap, or a single color. If None, it is set to [0,0,0].
Default: None
hedges: Optional. List of handles of edges, one per fly, to update instead of plot new edges. Default: None.
hkpts: Optional. List of handles of keypoints, one per fly, to update instead of plot new key points.
Default: None.
Extra arguments: All other arguments will be passed directly to plot_fly.
"""
def plot_flies(poses=None,fig=None,ax=None,colors=None,kptcolors=None,hedges=None,hkpts=None,**kwargs):
    if hedges is None or hkpts is None:
        fig,ax,isnewaxis = set_fig_ax(fig=fig,ax=ax)
    else:
        isnewaxis = False
    if colors is None:
        colors = get_Dark3_cmap()
    if kptcolors is None:
        kptcolors = [0,0,0]
    nflies = poses.shape[0]
    if not (type(colors) == list or type(colors) == np.ndarray):
        # colors is a colormap (or its name): sample one color per fly
        if type(colors) == str:
            cmap = plt.get_cmap(colors)
        else:
            cmap = colors
        colors = cmap(np.linspace(0.,1.,nflies))
    colors = np.asarray(colors)

    if hedges is None:
        hedges = [None,]*nflies
    if hkpts is None:
        hkpts = [None,]*nflies

    for fly in range(nflies):
        hkpts[fly],hedges[fly],fig,ax = plot_fly(poses[fly,...],fig=fig,ax=ax,color=colors[fly,...],
                                                 kptcolors=kptcolors,hedge=hedges[fly],hkpt=hkpts[fly],**kwargs)
    if isnewaxis:
        ax.axis('equal')

    return hkpts,hedges,fig,ax
"""
ani = animate_pose_sequence(seq=None, start_frame=0, stop_frame=None, skip=1,
    fig=None, ax=None, savefile=None, **kwargs)
Visualize all flies for the input sequence of frames seq.
seq: Required. seql x nflies x nfeatures x 2 ndarray.
start_frame: Which frame of the sequence to start plotting at. Default: 0.
stop_frame: Which frame of the sequence to end plotting on. Default: None. If None, the
sequence length (seq.shape[0]) is used.
skip: How many frames to skip between plotting. Default: 1.
fig: Optional. Handle to figure to plot in. Only used if ax is not specified. Default = None.
If None, a new figure is created.
ax: Optional. Handle to axes to plot in. Default = None. If None, new axes are created.
savefile: Optional. Name of the file (e.g. a .gif) to save the animation to using PillowWriter.
If None, the animation is not saved. The animation object is returned in either case.
Extra arguments: All other arguments will be passed directly to plot_flies.
"""
def animate_pose_sequence(seq=None,start_frame=0,stop_frame=None,skip=1,
                          fig=None,ax=None,savefile=None,**kwargs):
    if stop_frame is None:
        stop_frame = seq.shape[0]
    fig,ax,isnewaxis = set_fig_ax(fig=fig,ax=ax)

    # keep only flies that are real (not nan-padded) in at least one frame
    isreal = get_real_flies(seq)
    idxreal = np.where(np.any(isreal,axis=0))[0]
    seq = seq[:,idxreal,...]

    # plot the arena wall
    theta = np.linspace(0,2*np.pi,360)
    ax.plot(ARENA_RADIUS_MM*np.cos(theta),ARENA_RADIUS_MM*np.sin(theta),'k-',zorder=-10)
    minv = -ARENA_RADIUS_MM*1.01
    maxv = ARENA_RADIUS_MM*1.01

    # first frame
    f = start_frame
    h = {}
    h['kpts'],h['edges'],fig,ax = plot_flies(poses=seq[f,...],fig=fig,ax=ax,**kwargs)
    h['frame'] = ax.text(-ARENA_RADIUS_MM*.99,ARENA_RADIUS_MM*.99,'Frame %d (%.2f s)'%(f,float(f)/FPS),
                         va='top',ha='left')
    ax.set_xlim(minv,maxv)
    ax.set_ylim(minv,maxv)
    ax.set_aspect('equal')

    def update(f):
        # reuse the existing artists instead of plotting new ones
        h['kpts'],h['edges'],_,_ = plot_flies(poses=seq[f,...],
                                              hedges=h['edges'],hkpts=h['kpts'],**kwargs)
        h['frame'].set_text('Frame %d (%.2f s)'%(f,float(f)/FPS))
        return h['edges']+h['kpts']

    ani = animation.FuncAnimation(fig, update, frames=np.arange(start_frame,stop_frame,skip,dtype=int))
    if savefile is not None:
        print('Saving animation to file %s...'%savefile)
        writer = animation.PillowWriter(fps=30)
        ani.save(savefile,writer=writer)
        print('Finished writing.')
    return ani
Visualize the fly movements🎥
Sample visualization for plotting pose gifs.
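from IPython.display import HTML

seqid = next(iter(user_train['sequences']))
seq = user_train['sequences'][seqid]['keypoints']
# animate frames from the sequence (example range: every 10th of the first 1000 frames)
ani = animate_pose_sequence(seq=seq,
                            start_frame=0,
                            stop_frame=1000,
                            skip=10)
# prepare and show the animation - this could take a few seconds
HTML(ani.to_jshtml())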