Hierarchical Attention in Vision Transformers
Explore how hierarchical attention enables Vision Transformers (ViT) to process images efficiently at multiple scales by restricting attention to local windows and merging features across stages.
Hierarchical Attention: Efficient Multi-Scale Processing
Hierarchical attention mechanisms enable transformers to efficiently process data at multiple scales, crucial for vision tasks where both local details and global context matter. This approach, pioneered by models like Swin Transformer, revolutionizes how transformers handle high-resolution images.
Interactive Hierarchical Attention Visualization
Explore how attention operates at different scales and merges information hierarchically:
Hierarchical Attention Visualization
Illustrating multi-stage processing with local attention and feature merging.
Processing Steps
Input Image
The process starts with the input image. It's conceptually divided into patches (like ViT), but attention will operate on windows of these patches.
Technical Concepts
Local Window Attention
Instead of computing attention across all patches (like standard ViT), attention is restricted to non-overlapping local windows (e.g., 7x7 patches). This significantly reduces computational complexity from quadratic to linear with respect to the number of patches. It captures local interactions effectively. Models like Swin Transformer also use shifted windows in alternating layers to allow cross-window connections (Shifted windows are conceptually important but not explicitly drawn in this simplified visualization).
Hierarchical Structure & Merging
As the network goes deeper (stages), patch merging layers reduce the number of tokens (spatial resolution) while increasing the feature dimension. For example, features from a 2x2 group of neighboring patches/tokens can be concatenated and then linearly projected to a smaller dimension. This creates a hierarchical representation, similar to CNNs, allowing the model to learn features at different scales. The receptive field of attention windows effectively increases at deeper stages.
Benefits
- Linear computational complexity w.r.t. image size (vs. quadratic for ViT).
- Suitable for high-resolution images and dense prediction tasks (segmentation, detection).
- Captures multi-scale features naturally through hierarchy.
Why Hierarchical Attention?
The Challenge with Standard Attention
- Quadratic complexity: O(N²) for N tokens
- Memory explosion: Infeasible for high-resolution images
- Single scale: Misses multi-scale nature of visual data
The Hierarchical Solution
- Local windows: Compute attention within small regions
- Progressive merging: Combine windows at higher levels
- Multi-scale features: Capture both fine details and global context
- Linear complexity: O(N) with respect to image size
How Hierarchical Attention Works
1. Window Partitioning
Divide the input into non-overlapping windows:
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size: int
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = windows.view(-1, window_size, window_size, C)
    return windows
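As a quick shape check (tensor sizes here are illustrative, not prescribed by the text):

import torch

# A 56x56 feature map with 96 channels, partitioned into 7x7 windows
x = torch.randn(2, 56, 56, 96)                 # (B, H, W, C)
windows = window_partition(x, window_size=7)
print(windows.shape)                           # torch.Size([128, 7, 7, 96]): 2 * (56/7)**2 = 128 windows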
2. Local Window Attention
Apply self-attention within each window independently:
def window_attention(windows, window_size):
    """
    Apply self-attention within each window.
    Total cost: O(N × M²) across all windows, where M is the window size
    and N the total number of tokens.
    """
    B_W, W_h, W_w, C = windows.shape
    windows = windows.view(B_W, W_h * W_w, C)
    # self_attention stands in for any standard multi-head self-attention
    # (e.g., the WindowAttention module shown in the implementation tips below)
    attn_output = self_attention(windows)
    return attn_output.view(B_W, W_h, W_w, C)
3. Shifted Windows (Swin Transformer)
Create connections between windows through shifting:
def shifted_window_attention(x, window_size, shift_size):
    """
    Shift windows to create cross-window connections
    """
    if shift_size > 0:
        # Cyclically shift the feature map so that window boundaries move
        shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
    else:
        shifted_x = x
    # Apply window attention to the shifted configuration
    # (the full Swin block rolls the result back by +shift_size afterwards)
    return window_attention(shifted_x, window_size)
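After attention, the cyclic shift has to be undone so features return to their original spatial positions. A minimal sketch of that step (reverse_cyclic_shift is an illustrative helper, not Swin's official API; in the full block the attention output is also mapped back from windows to the (B, H, W, C) layout before un-shifting):

import torch

def reverse_cyclic_shift(x, shift_size):
    # Undo the cyclic shift applied before attention
    if shift_size > 0:
        return torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))
    return x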
4. Hierarchical Merging
Progressively merge patches to create hierarchy:
def patch_merging(x, reduction, merge_size=2):
    """
    Merge patches to reduce spatial resolution.
    reduction: a linear projection, e.g. nn.Linear(merge_size**2 * C, 2 * C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // merge_size, merge_size, W // merge_size, merge_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H // merge_size, W // merge_size, merge_size**2 * C)
    # Linear projection to reduce channels (typically 4C -> 2C)
    x = reduction(x)
    return x
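A quick usage check (sizes chosen for illustration): spatial resolution halves while channels double.

import torch
import torch.nn as nn

x = torch.randn(2, 56, 56, 96)             # (B, H, W, C)
reduction = nn.Linear(4 * 96, 2 * 96)      # 384 -> 192, the typical 4C -> 2C projection
merged = patch_merging(x, reduction)
print(merged.shape)                        # torch.Size([2, 28, 28, 192])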
Hierarchical Attention Architectures
Swin Transformer
- Window-based attention: 7×7 or 14×14 windows
- Shifted windows: Alternate between regular and shifted
- 4 stages: Progressively downsample like CNNs
- Patch merging: 2×2 patches → 1 patch between stages
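A rough sketch of how blocks inside one stage alternate regular and shifted windows (run_stage and block are illustrative names, not the official Swin API):

def run_stage(x, blocks, window_size):
    # Even-indexed blocks use regular windows, odd-indexed blocks use shifted windows
    for i, block in enumerate(blocks):
        shift = 0 if i % 2 == 0 else window_size // 2
        x = block(x, shift_size=shift)
    return x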
Pyramid Vision Transformer (PVT)
- Progressive shrinking: Reduce spatial resolution gradually
- Spatial reduction attention: Downsample K, V for efficiency
- Multi-scale features: Different resolutions at each stage
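A minimal sketch of the spatial-reduction idea, where keys and values come from a downsampled token grid. Here nn.MultiheadAttention stands in for PVT's own attention implementation, and the layer names are illustrative:

import torch.nn as nn

class SpatialReductionAttention(nn.Module):
    def __init__(self, dim, num_heads, sr_ratio):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        # Strided conv shrinks the K/V token grid by sr_ratio along each spatial axis
        self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        # x: (B, N, C) with N = H * W
        B, N, C = x.shape
        kv = x.transpose(1, 2).reshape(B, C, H, W)
        kv = self.sr(kv).flatten(2).transpose(1, 2)   # (B, N / sr_ratio**2, C)
        kv = self.norm(kv)
        out, _ = self.attn(query=x, key=kv, value=kv)
        return out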
Focal Transformer
- Focal attention: Both fine-grained and coarse-grained
- Multi-level aggregation: Combine multiple window sizes
- Adaptive granularity: Adjust based on content
Mathematical Formulation
Complexity Analysis
Standard Attention:
- Complexity: O(N² × d) where N = H × W
- Memory: O(N²)
Hierarchical Attention (with windows of size M):
- Complexity: O(N × M² × d)
- Memory: O(N × M²)
- Reduction factor: N/M² (e.g., ≈64× for a 56×56 token map with M = 7)
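A quick sanity check of the reduction, with token counts matching a 224×224 image split into 4×4 patches (values chosen for illustration):

# Rough attention-cost comparison for the QK^T step (per-head constants ignored)
N, M, d = 56 * 56, 7, 96          # tokens, window size, channel dim
global_cost = N**2 * d            # standard attention over all tokens
window_cost = N * M**2 * d        # window attention: N/M**2 windows, each (M**2)**2 * d
print(global_cost / window_cost)  # N / M**2 = 64.0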
Multi-Scale Feature Maps
At stage s with downsampling factor 2^s:
- Resolution: H/2^s × W/2^s
- Channels: C × 2^s
- Window size: M (constant)
- Number of windows: (H × W) / (M² × 4^s)
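Plugging in Swin-T-like settings (224×224 input, 4×4 patch embedding, C = 96, M = 7) as an illustration:

H = W = 224 // 4                  # 56 tokens per side after 4x4 patch embedding
C, M = 96, 7
for s in range(4):
    res = H // 2**s
    print(f"stage {s}: {res}x{res} tokens, {C * 2**s} channels, {(res // M)**2} windows")
# stage 0: 56x56 tokens, 96 channels, 64 windows
# stage 1: 28x28 tokens, 192 channels, 16 windows
# stage 2: 14x14 tokens, 384 channels, 4 windows
# stage 3: 7x7 tokens, 768 channels, 1 windows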
Implementation Tips
1. Efficient Window Operations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        # Learnable relative position bias: one entry per relative offset within a window
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) ** 2, num_heads))
        coords = torch.stack(torch.meshgrid(
            torch.arange(window_size), torch.arange(window_size), indexing='ij')).flatten(1)
        rel = (coords[:, :, None] - coords[:, None, :]).permute(1, 2, 0) + (window_size - 1)
        self.register_buffer('relative_position_index',
                             rel[..., 0] * (2 * window_size - 1) + rel[..., 1])

    def get_relative_position_bias(self):
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(self.window_size ** 2, self.window_size ** 2, -1)
        return bias.permute(2, 0, 1).unsqueeze(0)      # (1, num_heads, N, N)

    def forward(self, x, mask=None):
        # Window-based multi-head attention: x is (num_windows*B, N, C), N = window_size**2
        B_W, N, C = x.shape
        head_dim = C // self.num_heads
        qkv = self.qkv(x).reshape(B_W, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)                        # each (B_W, num_heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
        attn = attn + self.get_relative_position_bias()
        if mask is not None:
            # Block attention across shifted-window boundaries (mask is nonzero where allowed)
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B_W, N, C)
        return out
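A hypothetical end-to-end check, reusing the window_partition helper from earlier (sizes chosen for illustration):

import torch

x = torch.randn(2, 56, 56, 96)                   # (B, H, W, C)
windows = window_partition(x, window_size=7)     # (128, 7, 7, 96)
windows = windows.view(-1, 7 * 7, 96)            # (num_windows*B, N, C)
attn = WindowAttention(dim=96, window_size=7, num_heads=4)
out = attn(windows)
print(out.shape)                                 # torch.Size([128, 49, 96])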
2. Handling Window Boundaries
def create_mask_for_shifted_windows(H, W, window_size, shift_size):
    """
    Label image regions for shifted window attention.
    Returns a (1, H, W, 1) map of region ids; after the cyclic shift, tokens
    from different regions must not attend to each other.
    """
    mask = torch.zeros((1, H, W, 1))
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            mask[:, h, w, :] = cnt
            cnt += 1
    return mask
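The region-id map still has to be converted into a per-window pairwise mask before it can be passed to the attention layer. A minimal sketch of that step, following the Swin recipe (variable names are mine):

# Turn the (1, H, W, 1) region-id map into per-window pairwise masks
region_ids = create_mask_for_shifted_windows(H=56, W=56, window_size=7, shift_size=3)
id_windows = window_partition(region_ids, window_size=7).view(-1, 7 * 7)   # (num_windows, N)
pairwise = id_windows.unsqueeze(1) - id_windows.unsqueeze(2)               # (num_windows, N, N)
# Attention is allowed only between tokens from the same region
attn_mask = (pairwise == 0).unsqueeze(1)   # (num_windows, 1, N, N); nonzero where allowed,
                                           # matching the mask == 0 check above (tile over batch if B > 1)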
Advantages and Trade-offs
Advantages
- ✅ Linear complexity with image size
- ✅ Multi-scale features captured naturally
- ✅ CNN-like inductive bias through locality
- ✅ Memory efficient for high-resolution inputs
- ✅ Strong performance on vision tasks
Trade-offs
- ⚠️ Limited global context in early layers
- ⚠️ Implementation complexity with shifting/masking
- ⚠️ Fixed window sizes may not be optimal for all content
- ⚠️ Boundary artifacts from window partitioning
Best Practices
- Window Size Selection:
  - Typically 7×7 or 14×14 for images
  - Larger windows for higher resolution
  - Consider content characteristics
- Shift Strategy:
  - Shift by window_size // 2 for maximum overlap
  - Alternate shifted and non-shifted blocks
- Position Encodings:
  - Use relative position bias within windows
  - Absolute position encodings between stages
- Training Tips:
  - Learning-rate warm-up is important
  - Gradient clipping helps stability
  - Data augmentation is crucial for small datasets
Comparison with Other Approaches
Approach | Complexity | Global Context | Multi-Scale | Memory
---|---|---|---|---
Standard Attention | O(N²) | ✅ From start | ❌ Single scale | High
Hierarchical | O(N) | ✅ At higher levels | ✅ Built-in | Low
Sparse Attention | O(N√N) | ⚠️ Limited | ❌ Single scale | Medium
Axial Attention | O(N^1.5) | ⚠️ Along axes | ❌ Single scale | Medium
Applications
Computer Vision
- Image Classification: State-of-the-art on ImageNet
- Object Detection: Backbone for detectors
- Semantic Segmentation: Multi-scale features crucial
- Video Understanding: Temporal hierarchies
Other Domains
- 3D Vision: Hierarchical voxel processing
- Point Clouds: Multi-resolution point attention
- Medical Imaging: Multi-scale tissue analysis