CLS Token in Vision Transformers

Learn how the CLS token acts as a global information aggregator in Vision Transformers, enabling whole-image classification through attention mechanisms.


Understanding the CLS Token in Vision Transformers

The CLS (Classification) token is a special learnable token that acts as a global information aggregator in Vision Transformers. It's the key to transforming patch-level visual features into a single representation for whole-image tasks.

Interactive CLS Token Visualization

Explore how the CLS token gathers information from all image patches through attention:

ViT Explained: The Role of the CLS Token

[Interactive diagram: a Vision Transformer (ViT) with a CLS token. It shows the CLS token being added to the image patches, gathering information through attention, and finally predicting the class of the image.]

Step 1: Adding the CLS Token
Inspired by BERT, a special [CLS] token is added to the start of the image patch sequence. Its goal is to aggregate information from all patches and represent the entire image for classification.

What is the CLS Token?

The CLS token is:

  • A learnable embedding added to the beginning of the patch sequence
  • A global aggregator that attends to all image patches
  • The classification readout whose final representation is fed to the prediction head
  • Position-aware through positional embeddings

How CLS Token Works

1. Initialization

The CLS token starts as a randomly initialized learnable parameter, separate from image content:

cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
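
Many PyTorch ViT implementations then re-initialize this parameter with a truncated normal distribution of small standard deviation (compare "Best Practices" below); continuing the line above, a minimal sketch:

nn.init.trunc_normal_(cls_token, std=0.02)  # truncated normal init, std = 0.02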

2. Position Encoding

Like patch embeddings, the CLS token receives its own positional encoding:

pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
x = torch.cat([cls_token.expand(patch_embeddings.shape[0], -1, -1), patch_embeddings], dim=1)  # prepend CLS token to each sample
x = x + pos_embed  # Add positional information
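
For a 224×224 image split into 16×16 patches this gives 14 × 14 = 196 patch tokens, so the sequence grows to 197 tokens once the CLS token is prepended. A quick shape check (embed_dim = 768 and a batch of 2 are illustrative):

import torch
import torch.nn as nn

embed_dim, num_patches, batch = 768, 14 * 14, 2   # 224 / 16 = 14 patches per side
cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
patch_embeddings = torch.randn(batch, num_patches, embed_dim)
pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))

x = torch.cat([cls_token.expand(batch, -1, -1), patch_embeddings], dim=1)
x = x + pos_embed
print(x.shape)  # torch.Size([2, 197, 768]), with the CLS token at index 0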

3. Attention Mechanism

Through self-attention layers, the CLS token:

  • Attends to all patches to gather global information
  • Is attended to by the patches, allowing bidirectional information flow (see the attention sketch after this list)
  • Progressively refines its representation through multiple layers
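
To make this concrete, here is a minimal sketch using a single nn.MultiheadAttention layer as a stand-in for a full ViT block (sizes are illustrative, not the actual ViT architecture); it reads out how the CLS token attends to the patches and vice versa:

import torch
import torch.nn as nn

embed_dim, num_heads = 768, 12
attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

x = torch.randn(1, 197, embed_dim)  # [CLS] token followed by 196 patch tokens
out, weights = attn(x, x, x, need_weights=True, average_attn_weights=True)

cls_to_patches = weights[:, 0, 1:]  # how strongly the CLS token attends to each patch
patches_to_cls = weights[:, 1:, 0]  # how strongly each patch attends to the CLS token
print(cls_to_patches.shape, patches_to_cls.shape)  # torch.Size([1, 196]) twice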

4. Classification

The final CLS token representation is used for classification:

cls_output = x[:, 0]  # Extract CLS token (first position)
logits = classification_head(cls_output)

Why Use a CLS Token?

Advantages

  • Flexibility: A single extra token works for any number of patches (and thus any image size)
  • Efficiency: Single token for classification vs. pooling all patches
  • Interpretability: Attention weights show which patches influence decisions
  • Consistency: Aligns with NLP transformer architectures

Alternative Approaches

  • Global Average Pooling: Average all patch representations (see the sketch after this list)
  • Direct Patch Classification: Use all patches for classification
  • Learnable Pooling: Learn weighted combination of patches
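
For comparison, a minimal sketch of the global-average-pooling readout (tensor shapes are illustrative): instead of reading out position 0, the patch tokens are simply averaged:

import torch

# x: transformer output of shape (batch, 1 + num_patches, embed_dim)
x = torch.randn(2, 197, 768)

gap_output = x[:, 1:].mean(dim=1)   # average the patch tokens, skipping the CLS slot
print(gap_output.shape)             # torch.Size([2, 768]), fed to the classification head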

CLS Token Attention Patterns

The CLS token learns different attention patterns across layers:

Early Layers:

  • Broad, uniform attention across patches
  • Gathering basic visual features

Middle Layers:

  • Focused attention on semantic regions
  • Building object representations

Final Layers:

  • Task-specific attention patterns
  • Emphasizing discriminative features

Implementation Details

PyTorch Implementation

import torch
import torch.nn as nn

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, embed_dim=768,
                 depth=12, num_heads=12, num_classes=1000):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2

        # Patch embedding: non-overlapping patches projected to embed_dim
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

        # CLS token and positional embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        # Transformer encoder and classification head
        # (nn.TransformerEncoder is a compact stand-in for ViT's pre-norm blocks)
        layer = nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, depth)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]
        # Patchify: (B, 3, H, W) -> (B, num_patches, embed_dim)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)

        # Add CLS token to patch embeddings
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embeddings
        x = x + self.pos_embed

        # Apply transformer blocks
        x = self.transformer(x)

        # Extract CLS token for classification
        cls_output = x[:, 0]
        return self.head(cls_output)
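
A quick usage sketch with the class above:

model = VisionTransformer(img_size=224, patch_size=16, num_classes=1000)
images = torch.randn(4, 3, 224, 224)   # a batch of 4 RGB images
logits = model(images)
print(logits.shape)                    # torch.Size([4, 1000])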

CLS Token in Different Architectures

Vision Transformer (ViT)

  • Standard CLS token approach
  • Used for image classification

DINO (Self-Supervised ViT)

  • CLS token for self-supervised learning
  • Learns without labels through self-distillation

CLIP (Vision-Language)

  • CLS token represents entire image
  • Aligned with text representations

DeiT (Data-Efficient ViT)

  • Uses both a CLS token and a distillation token (sketched after this list)
  • Improved training efficiency
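
A hypothetical sketch of this two-token setup (not the official DeiT code; shapes are illustrative): a second learnable token is prepended next to the CLS token, and its output is supervised by a teacher model during training:

import torch
import torch.nn as nn

B, num_patches, embed_dim = 2, 196, 768
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))   # extra token supervised by a teacher
patch_tokens = torch.randn(B, num_patches, embed_dim)

x = torch.cat([cls_token.expand(B, -1, -1),
               dist_token.expand(B, -1, -1),
               patch_tokens], dim=1)
print(x.shape)  # torch.Size([2, 198, 768]): CLS at index 0, distillation token at index 1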

Best Practices

  1. Initialization: Use proper initialization (truncated normal)
  2. Learning Rate: Often benefits from a different learning rate than the patch parameters (see the sketch after this list)
  3. Regularization: Apply dropout to CLS token output
  4. Fine-tuning: CLS token adapts quickly to new tasks
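
A hedged sketch of how the first two practices might look in code, assuming the VisionTransformer class from the implementation section above (learning rates and std values are illustrative):

model = VisionTransformer()

# 1. Truncated-normal initialization for the CLS token and positional embeddings
nn.init.trunc_normal_(model.cls_token, std=0.02)
nn.init.trunc_normal_(model.pos_embed, std=0.02)

# 2. Separate learning rate (and no weight decay) for the token/position parameters
optimizer = torch.optim.AdamW([
    {"params": [model.cls_token, model.pos_embed], "lr": 1e-4, "weight_decay": 0.0},
    {"params": [p for n, p in model.named_parameters()
                if n not in ("cls_token", "pos_embed")], "lr": 1e-3},
])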

Common Misconceptions

"CLS token is just for classification" ✅ Can be used for any global task (detection, segmentation with modifications)

"CLS token sees the original image" ✅ Only sees patch embeddings, not raw pixels

"CLS token is necessary for ViT" ✅ Alternative pooling strategies exist and can work well

If you found this explanation helpful, consider sharing it with others.
