Hierarchical Attention in Vision Transformers

Explore how hierarchical attention enables Vision Transformers to process images efficiently at multiple scales by restricting attention to local windows and progressively merging features.


Hierarchical Attention: Efficient Multi-Scale Processing

Hierarchical attention mechanisms let transformers process data efficiently at multiple scales, which is crucial for vision tasks where both local details and global context matter. This approach, pioneered by models such as the Swin Transformer, reshaped how transformers handle high-resolution images.

Interactive Hierarchical Attention Visualization

Explore how attention operates at different scales and merges information hierarchically:

[Interactive visualization: multi-stage processing with local window attention and feature merging. Processing steps: Input Image → Stage 1 Attention → Merge 1 → Stage 2 Attention → Merge 2 → Stage 3 Attention. Click a window on the canvas to select it.]

The process starts with the input image. It is conceptually divided into patches (as in ViT), but attention operates on windows of these patches rather than across the whole image.

Technical Concepts

Local Window Attention

Instead of computing attention across all patches (as in standard ViT), attention is restricted to non-overlapping local windows (e.g., 7×7 patches). This reduces the computational complexity from quadratic to linear in the number of patches while still capturing local interactions effectively. Models like the Swin Transformer additionally shift the windows in alternating layers to allow cross-window connections (shifted windows are conceptually important but not explicitly drawn in this simplified visualization).

Hierarchical Structure & Merging

As the network goes deeper (stages), patch merging layers reduce the number of tokens (spatial resolution) while increasing the feature dimension. For example, features from a 2x2 group of neighboring patches/tokens can be concatenated and then linearly projected to a smaller dimension. This creates a hierarchical representation, similar to CNNs, allowing the model to learn features at different scales. The receptive field of attention windows effectively increases at deeper stages.

Benefits

  • Linear computational complexity w.r.t. image size (vs. quadratic for ViT).
  • Suitable for high-resolution images and dense prediction tasks (segmentation, detection).
  • Captures multi-scale features naturally through hierarchy.

Why Hierarchical Attention?

The Challenge with Standard Attention

  • Quadratic complexity: O(N²) for N tokens
  • Memory explosion: Unfeasible for high-resolution images
  • Single scale: Misses multi-scale nature of visual data

The Hierarchical Solution

  • Local windows: Compute attention within small regions
  • Progressive merging: Combine windows at higher levels
  • Multi-scale features: Capture both fine details and global context
  • Linear complexity: O(N) with respect to image size (a quick numeric comparison follows)
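To make the complexity gap concrete, here is a rough count of attention-score entries per head, assuming a 56×56 token grid (roughly a 224×224 image with a 4×4 patch embedding) and 7×7 windows:

# Global attention scores grow as N^2; windowed attention as N * M^2.
H = W = 56   # token grid, e.g. a 224x224 image with 4x4 patch embedding
M = 7        # window side length, in tokens
N = H * W

print(f"global attention:   {N * N:,} score entries")      # 9,834,496
print(f"windowed attention: {N * M * M:,} score entries")  # 153,664
print(f"reduction factor:   {N // (M * M)}x")              # 64x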

How Hierarchical Attention Works

1. Window Partitioning

Divide the input into non-overlapping windows:

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size: int
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = windows.view(-1, window_size, window_size, C)
    return windows
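A quick shape check, together with the inverse operation (commonly called window_reverse in Swin-style code) that stitches the windows back into a full feature map; the sizes below are purely illustrative:

import torch

def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: (num_windows*B, ws, ws, C) -> (B, H, W, C)."""
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(B, H, W, -1)

x = torch.randn(2, 56, 56, 96)      # (B, H, W, C) token grid
windows = window_partition(x, 7)    # -> (128, 7, 7, 96): 64 windows per image
assert torch.equal(window_reverse(windows, 7, 56, 56), x)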

2. Local Window Attention

Apply self-attention within each window independently:

def window_attention(windows, window_size):
    """
    Apply self-attention within each window.
    Complexity: O(W² × N) where W is the window side length.
    """
    B_W, W_h, W_w, C = windows.shape
    # Flatten each window into a sequence of W_h * W_w tokens
    windows = windows.view(B_W, W_h * W_w, C)
    # Standard self-attention within the window (self_attention is a placeholder helper)
    attn_output = self_attention(windows)
    return attn_output.view(B_W, W_h, W_w, C)
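The snippet above leaves self_attention unspecified. A minimal single-head stand-in (no learned projections or multi-head split, purely for illustration) could look like this:

import math
import torch.nn.functional as F

def self_attention(tokens):
    """Minimal single-head scaled dot-product attention over (B, N, C) tokens."""
    q = k = v = tokens                 # no learned projections in this sketch
    scores = q @ k.transpose(-2, -1) / math.sqrt(tokens.shape[-1])
    return F.softmax(scores, dim=-1) @ v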

3. Shifted Windows (Swin Transformer)

Create connections between windows through shifting:

def shifted_window_attention(x, window_size, shift_size):
    """
    Shift the feature map so that subsequent window attention
    creates cross-window connections.
    """
    if shift_size > 0:
        # Cyclically shift the feature map so windows straddle old boundaries
        shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
    else:
        shifted_x = x
    # Apply window attention to the shifted configuration
    # (in a full implementation, shifted_x is first partitioned into windows)
    return window_attention(shifted_x, window_size)
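The simplified version above skips the partition and merge steps. A fuller sketch (using window_partition from step 1 and the window_reverse helper shown earlier; the function name here is illustrative) would also undo the shift after attention:

def shifted_window_block(x, window_size, shift_size):
    """Shift -> partition -> window attention -> merge -> un-shift, for x of shape (B, H, W, C)."""
    B, H, W, C = x.shape
    if shift_size > 0:
        x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
    windows = window_partition(x, window_size)
    windows = window_attention(windows, window_size)
    x = window_reverse(windows, window_size, H, W)
    if shift_size > 0:
        # Roll back so features return to their original spatial positions
        x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))
    return x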

4. Hierarchical Merging

Progressively merge patches to create hierarchy:

def patch_merging(x, reduction, merge_size=2):
    """
    Merge neighboring patches to reduce spatial resolution.
    reduction: linear layer mapping merge_size**2 * C -> 2 * C
               (typically kept as self.reduction in module-based implementations)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // merge_size, merge_size, W // merge_size, merge_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H // merge_size, W // merge_size, merge_size**2 * C)
    # Linear projection to reduce channels (typically 4C -> 2C)
    x = reduction(x)
    return x
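A shape check under a typical Swin-style setting (hypothetical sizes; the reduction layer maps the concatenated 4C features down to 2C):

import torch
import torch.nn as nn

x = torch.randn(2, 56, 56, 96)                      # stage-1 token grid
reduction = nn.Linear(4 * 96, 2 * 96, bias=False)   # 4C -> 2C
y = patch_merging(x, reduction)
print(y.shape)                                      # torch.Size([2, 28, 28, 192])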

Hierarchical Attention Architectures

Swin Transformer

  • Window-based attention: 7×7 or 14×14 windows
  • Shifted windows: Alternate between regular and shifted
  • 4 stages: Progressively downsample like CNNs
  • Patch merging: 2×2 patches → 1 patch between stages
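As a rough reference point, a Swin-T-style configuration (hyper-parameters as reported in the Swin Transformer paper) can be summarized as:

swin_t_config = {
    "patch_size": 4,              # 4x4 pixel patches at the input
    "embed_dim": 96,              # channels at stage 1, doubled after each merge
    "depths": [2, 2, 6, 2],       # transformer blocks per stage
    "num_heads": [3, 6, 12, 24],  # attention heads per stage
    "window_size": 7,             # window side length, in tokens
}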

Pyramid Vision Transformer (PVT)

  • Progressive shrinking: Reduce spatial resolution gradually
  • Spatial reduction attention: Downsample K, V for efficiency
  • Multi-scale features: Different resolutions at each stage
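A minimal single-head sketch of PVT's spatial-reduction idea (shapes and module structure are illustrative, not PVT's exact API): keys and values are downsampled by a reduction ratio R, so the score matrix becomes N × N/R² instead of N × N:

import math
import torch.nn as nn
import torch.nn.functional as F

class SpatialReductionAttention(nn.Module):
    """Queries keep full resolution; K and V are spatially reduced before attention."""
    def __init__(self, dim, reduction_ratio):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, 2 * dim)
        # Strided conv shrinks the (H, W) grid by reduction_ratio in each direction
        self.sr = nn.Conv2d(dim, dim, kernel_size=reduction_ratio, stride=reduction_ratio)

    def forward(self, x, H, W):
        B, N, C = x.shape                                    # N = H * W tokens
        q = self.q(x)                                        # (B, N, C)
        x_ = x.transpose(1, 2).reshape(B, C, H, W)
        x_ = self.sr(x_).reshape(B, C, -1).transpose(1, 2)   # (B, N / R^2, C)
        k, v = self.kv(x_).chunk(2, dim=-1)
        attn = F.softmax(q @ k.transpose(-2, -1) / math.sqrt(C), dim=-1)
        return attn @ v                                      # (B, N, C)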

Focal Transformer

  • Focal attention: Both fine-grained and coarse-grained
  • Multi-level aggregation: Combine multiple window sizes
  • Adaptive granularity: Adjust based on content

Mathematical Formulation

Complexity Analysis

Standard Attention:

  • Complexity: O(N² × d) where N = H × W
  • Memory: O(N²)

Hierarchical Attention (with windows of size M):

  • Complexity: O(N × M² × d)
  • Memory: O(N × M²)
  • Reduction factor: N/M² (e.g., 64× for a 56×56 token grid with M = 7)

Multi-Scale Feature Maps

At stage s, with downsampling factor 2^s relative to the initial token grid (evaluated numerically in the example after this list):

  • Resolution: H/2^s × W/2^s
  • Channels: C × 2^s
  • Window size: M (constant)
  • Number of windows: (H × W) / (M² × 4^s)
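Evaluating these formulas for a 224×224 input with a 4×4 patch embedding, C = 96, and M = 7 (illustrative values) gives the familiar feature pyramid:

H = W = 224 // 4      # token grid after 4x4 patch embedding
C, M = 96, 7          # base channels, window size

for s in range(4):
    h, w = H // 2**s, W // 2**s
    channels = C * 2**s
    n_windows = (h * w) // (M * M)
    print(f"stage {s}: {h}x{w} tokens, {channels} channels, {n_windows} windows of {M}x{M}")
# stage 0: 56x56 tokens, 96 channels, 64 windows of 7x7
# stage 1: 28x28 tokens, 192 channels, 16 windows of 7x7
# stage 2: 14x14 tokens, 384 channels, 4 windows of 7x7
# stage 3: 7x7 tokens, 768 channels, 1 windows of 7x7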

Implementation Tips

1. Efficient Window Operations

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        # Relative position bias: one learnable value per (dy, dx) offset per head
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) ** 2, num_heads)
        )
        # Precompute, for every pair of tokens in a window, the index into the bias table
        coords = torch.stack(torch.meshgrid(
            torch.arange(window_size), torch.arange(window_size), indexing="ij"))
        coords = coords.flatten(1)                               # (2, N)
        rel = coords[:, :, None] - coords[:, None, :]            # (2, N, N) pairwise offsets
        rel = rel.permute(1, 2, 0) + (window_size - 1)           # shift to non-negative
        self.register_buffer(
            "relative_position_index",
            rel[:, :, 0] * (2 * window_size - 1) + rel[:, :, 1])

    def get_relative_position_bias(self):
        N = self.window_size ** 2
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        return bias.view(N, N, self.num_heads).permute(2, 0, 1)  # (heads, N, N)

    def forward(self, x, mask=None):
        # Window-based multi-head attention; x: (num_windows*B, N, C), N = window_size**2
        B_W, N, C = x.shape
        head_dim = C // self.num_heads
        qkv = self.qkv(x).reshape(B_W, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)                                  # each (B_W, heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
        attn = attn + self.get_relative_position_bias().unsqueeze(0)
        if mask is not None:
            # mask: broadcastable to (B_W, heads, N, N); 0/False marks blocked pairs
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B_W, N, C)
        return out
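A quick smoke test of the module, reusing the imports above (hypothetical sizes: 64 windows of 7×7 tokens, 96 channels, 3 heads):

attn = WindowAttention(dim=96, window_size=7, num_heads=3)
x = torch.randn(64, 49, 96)   # (num_windows * B, tokens per window, channels)
print(attn(x).shape)          # torch.Size([64, 49, 96])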

2. Handling Window Boundaries

def create_mask_for_shifted_windows(H, W, window_size, shift_size):
    """
    Label each spatial location with the region it belongs to after the
    cyclic shift; tokens from different regions that end up in the same
    window should not attend to each other.
    """
    mask = torch.zeros((1, H, W, 1))
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            mask[:, h, w, :] = cnt
            cnt += 1
    return mask
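The region labels still need to be turned into a per-window attention mask. One common approach (mirroring the reference Swin implementation; the helper name here is illustrative) partitions the label map exactly like the features and compares labels pairwise:

def build_attention_mask(H, W, window_size, shift_size):
    """Boolean mask of shape (num_windows, N, N): True where attention is allowed."""
    region = create_mask_for_shifted_windows(H, W, window_size, shift_size)  # (1, H, W, 1)
    region_windows = window_partition(region, window_size).view(-1, window_size * window_size)
    # Tokens may attend to each other only if they share a region label
    same_region = region_windows.unsqueeze(1) == region_windows.unsqueeze(2)
    return same_region

# For batch size B: mask = build_attention_mask(...).repeat(B, 1, 1).unsqueeze(1)
# has shape (B * num_windows, 1, N, N) and broadcasts over heads in WindowAttention.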

Advantages and Trade-offs

Advantages

  • ✅ Linear complexity with image size
  • ✅ Multi-scale features captured naturally
  • ✅ CNN-like inductive bias through locality
  • ✅ Memory efficient for high-resolution inputs
  • ✅ Strong performance on vision tasks

Trade-offs

  • ⚠️ Limited global context in early layers
  • ⚠️ Implementation complexity with shifting/masking
  • ⚠️ Fixed window sizes may not be optimal for all content
  • ⚠️ Boundary artifacts from window partitioning

Best Practices

  1. Window Size Selection:

    • Typically 7×7 or 14×14 for images
    • Larger windows for higher resolution
    • Consider content characteristics
  2. Shift Strategy:

    • Shift by window_size // 2 for maximum overlap
    • Alternate shifted and non-shifted blocks (see the sketch after this list)
  3. Position Encodings:

    • Use relative position bias within windows
    • Absolute position encodings between stages
  4. Training Tips:

    • Warm-up learning rate important
    • Gradient clipping helpful for stability
    • Data augmentation crucial for small datasets
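A minimal sketch of the alternating shift pattern within one stage, reusing the shifted_window_block sketch from earlier (residual connections shown; MLP and normalization omitted):

def run_stage(x, depth, window_size):
    """Alternate regular (shift 0) and shifted (shift window_size // 2) attention blocks."""
    for i in range(depth):
        shift = 0 if i % 2 == 0 else window_size // 2
        x = x + shifted_window_block(x, window_size, shift)
    return x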

Comparison with Other Approaches

| Approach | Complexity | Global Context | Multi-Scale | Memory |
|---|---|---|---|---|
| Standard Attention | O(N²) | ✅ From start | ❌ Single scale | High |
| Hierarchical | O(N) | ✅ At higher levels | ✅ Built-in | Low |
| Sparse Attention | O(N√N) | ⚠️ Limited | ❌ Single scale | Medium |
| Axial Attention | O(N^1.5) | ⚠️ Along axes | ❌ Single scale | Medium |

Applications

Computer Vision

  • Image Classification: State-of-the-art on ImageNet
  • Object Detection: Backbone for detectors
  • Semantic Segmentation: Multi-scale features crucial
  • Video Understanding: Temporal hierarchies

Other Domains

  • 3D Vision: Hierarchical voxel processing
  • Point Clouds: Multi-resolution point attention
  • Medical Imaging: Multi-scale tissue analysis

If you found this explanation helpful, consider sharing it with others.
