Stitchable Neural Networks

TL;DR: the Stitchable Neural Networks (SN-Net) paper introduces some interesting concepts. It allows the creation of multiple neural networks with varying complexity and performance trade-offs from a family of pretrained models.

Key Principles

A key question about combining sub-networks from different pretrained models is how to maintain accuracy. The paper concludes that the final performance of these combinations is nearly predictable due to an interpolation-like performance curve between anchors. This predictability allows for selective pre-training of stitches based on various deployment scenarios.

The Choice of Anchors

Anchors that are pretrained on different tasks can learn very different representations due to the large distribution gap between domains. Therefore, the selected anchors should be consistent in terms of their pretraining domain.

The Stitching Layer and its Initialization

SN-Net is built upon pretrained models. Since the anchors have already learned good representations, an accurate transformation matrix can be obtained directly by solving the least-squares problem:

$$\|AM_o - B\|_F = \min_M \|AM - B\|_F$$

where $A \in R^{N \times D_1}$ and $B \in R^{N \times D_2}$ are two feature maps of the same spatial size but with different numbers of hidden dimensions.

This least-squares problem has a closed-form solution based on singular value decomposition, in which case the optimum is obtained through an orthogonal projection in the space of matrices:

$$M_o = A^\dagger B$$

where $A^\dagger$ denotes the Moore-Penrose pseudoinverse of $A$.
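
As a rough illustration, this least-squares initialization can be sketched in PyTorch as below. The tensor shapes and the helper name are assumptions for illustration, not the paper's reference implementation:

```python
import torch

def init_stitching_layer(feats_a: torch.Tensor, feats_b: torch.Tensor) -> torch.nn.Linear:
    """Initialize a stitching layer by least-squares matching.

    feats_a: [N, D1] activations collected from the front anchor.
    feats_b: [N, D2] activations from the rear anchor at matching positions.
    Returns a Linear(D1, D2) whose weight solves min_M ||A M - B||_F.
    """
    # M_o = A^+ B, with A^+ the Moore-Penrose pseudoinverse of A.
    m_o = torch.linalg.pinv(feats_a) @ feats_b          # shape [D1, D2]

    stitch = torch.nn.Linear(feats_a.shape[1], feats_b.shape[1], bias=False)
    with torch.no_grad():
        # nn.Linear computes x @ W.T, so store the transpose of M_o.
        stitch.weight.copy_(m_o.T)
    return stitch
```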

Where to Stitch

SN-Net takes Fast-to-Slow as the default stitching direction, meaning it stitches a bigger and slower network after a smaller and faster one to achieve better model performance. In addition, it proposes a nearest stitching strategy, which limits stitching to two anchors with the nearest model complexity/performance.
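
A minimal sketch of a Fast-to-Slow stitched forward pass, assuming both anchors expose their blocks as an `nn.ModuleList` (the class name, split indices, and the `stitch` layer are illustrative assumptions):

```python
import torch.nn as nn

class StitchedForward(nn.Module):
    """Fast-to-Slow stitch: run the first blocks of the small, fast anchor,
    project features with a stitching layer, then finish with the later
    blocks of the large, slow anchor."""

    def __init__(self, fast_blocks: nn.ModuleList, slow_blocks: nn.ModuleList,
                 stitch: nn.Module, split_fast: int, split_slow: int):
        super().__init__()
        self.front = fast_blocks[:split_fast]   # small/fast anchor first
        self.back = slow_blocks[split_slow:]    # ... then the big/slow anchor
        self.stitch = stitch                    # maps D1-dim tokens to D2-dim tokens

    def forward(self, x):
        for blk in self.front:
            x = blk(x)
        x = self.stitch(x)
        for blk in self.back:
            x = blk(x)
        return x
```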

Way to Stitch

Prior work shows that neighboring layers dealing with feature maps of the same scale share similar representations. Therefore, SN-Net uses a sliding-window strategy, where all layers within the same window share a common stitching layer.

[Figure: sliding-window stitching strategy]

Stitching Space

The stitching space is controlled by configuring the sliding-window kernel size $k$ and step size $s$.
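
For intuition, the sliding-window enumeration could look like the sketch below (the block count and helper name are assumptions; only the $k$/$s$ mechanics follow the text):

```python
def stitch_windows(num_blocks: int, k: int, s: int):
    """Slide a window of size k with stride s over the front anchor's blocks.

    All blocks inside one window share a single stitching layer, so the
    number of windows equals the number of stitching layers to train.
    """
    return [list(range(start, start + k))
            for start in range(0, num_blocks - k + 1, s)]

# Example: a 12-block anchor with k=2 and s=1 gives 11 windows,
# i.e. 11 shared stitching layers.
print(stitch_windows(12, k=2, s=1))
```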

Training Strategy

The training algorithm of SN-Net can be described as:

[Figure: SN-Net training algorithm]

The training algorithm can be summarized as:

  1. First, define a configuration set that contains all possible stitches.
  2. Initialize all stitching layers with least-squares matching.
  3. At each training iteration, randomly sample a stitch and follow the standard training process as in common practice.
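
A condensed sketch of this loop in PyTorch; `stitch_configs` and the `model(images, cfg)` routing are illustrative assumptions rather than the paper's actual API:

```python
import random
import torch

def train_sn_net(model, stitch_configs, loader, epochs=30, lr=1e-4):
    """Jointly train all stitches by sampling one configuration per iteration."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    for _ in range(epochs):
        for images, targets in loader:
            cfg = random.choice(stitch_configs)   # randomly sample a stitch
            optimizer.zero_grad()
            logits = model(images, cfg)           # forward through the sampled stitch
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()
```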