Quantization Effects Simulator
Explore memory-accuracy trade-offs in embedding quantization from float32 to binary representations.
Quantization reduces the precision of embedding values to save memory and accelerate computation, with controllable trade-offs in accuracy.
Interactive Quantization Simulator
[Interactive demo: select a quantization method (e.g., Int8: 8-bit integers, a good balance) and explore the value-level comparison, memory savings analysis, performance impact, and the bit-level representation of an example value.]
Quantization Methods Comparison
| Method | Bits | Memory | Accuracy | Speed | Use Case |
|---|---|---|---|---|---|
| Float32 (Original) | 32 | 100% | 100% | 1× | Research, training |
| Float16 (Half Precision) | 16 | 50% | 99.5% | 1.8× | GPU inference |
| Int8 Quantization | 8 | 25% | 98.2% | 3.2× | Production, edge |
| Int4 Quantization | 4 | 12.5% | 95.5% | 5.5× | Mobile, IoT |
| Binary Quantization | 1 | 3.125% | 85% | 20× | Extreme edge |
Quantization Best Practices
When to Use
- Edge deployment with memory constraints
- Large-scale serving to reduce costs
- Mobile and embedded applications
- Batch inference on CPU
Considerations
- Test accuracy on your specific task
- Consider quantization-aware training
- Profile actual speedup on target hardware
- May need a calibration dataset
Understanding Quantization
Quantization maps continuous values to discrete levels:

$$q = \text{round}\left(\frac{x - \text{min}}{\text{scale}}\right)$$

Where:

- $\text{scale} = \dfrac{\text{max} - \text{min}}{2^{\text{bits}} - 1}$
- Lower bits = fewer discrete levels
- Higher compression = more information loss
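A minimal NumPy sketch of this mapping, assuming simple per-tensor min/max (asymmetric) quantization:

```python
import numpy as np

def quantize(x, bits=8):
    """Map float values to integer levels using min/max (asymmetric) quantization."""
    x_min, x_max = x.min(), x.max()
    scale = (x_max - x_min) / (2**bits - 1)              # width of one quantization level
    q = np.round((x - x_min) / scale).astype(np.int64)   # level index in [0, 2^bits - 1]
    return q, scale, x_min

def dequantize(q, scale, x_min):
    """Map integer levels back to approximate float values."""
    return q * scale + x_min

x = np.random.randn(1024).astype(np.float32)
q, scale, x_min = quantize(x, bits=8)
x_hat = dequantize(q, scale, x_min)
print("max reconstruction error:", np.abs(x - x_hat).max())  # bounded by scale / 2
```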
Quantization Methods
1. Float16 (Half Precision)
16 bits: 1 sign + 5 exponent + 10 mantissa
```
Original:  0.123456789 (float32)
Quantized: 0.1235      (float16)
Memory:    50% reduction
Accuracy:  ~99.5% preserved
```
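To see the effect directly, here is a small NumPy sketch; the embedding is random data standing in for a real float32 embedding, and `astype(np.float16)` simply rounds each value to the nearest representable half-precision number:

```python
import numpy as np

emb = np.random.randn(384).astype(np.float32)    # stand-in for a real embedding
emb_fp16 = emb.astype(np.float16)                # round to half precision

# Relative error introduced by the reduced 10-bit mantissa
rel_err = np.abs(emb - emb_fp16.astype(np.float32)) / (np.abs(emb) + 1e-12)
print("mean relative error:", rel_err.mean())    # typically on the order of 1e-4
print("memory:", emb_fp16.nbytes, "bytes vs", emb.nbytes, "bytes")  # 50% reduction
```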
2. Int8 Quantization
8 bits: Maps to [-128, 127]
```python
import numpy as np

def quantize_int8(x, scale, zero_point):
    # Affine (asymmetric) quantization to 8-bit integers
    q = np.round(x / scale + zero_point)
    q = np.clip(q, -128, 127).astype(np.int8)
    return q

def dequantize_int8(q, scale, zero_point):
    # Map int8 levels back to approximate float values
    return scale * (q - zero_point)
```
3. Int4 Quantization
4 bits: Maps to [-8, 7]
- 93.75% memory reduction
- Good for inference on edge devices
- Requires careful calibration (a minimal sketch follows this list)
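A minimal symmetric int4 sketch, for illustration only; production int4 kernels pack two 4-bit values per byte and often use per-group or per-channel scales, which this example omits:

```python
import numpy as np

def quantize_int4(x):
    # Symmetric quantization to the 4-bit signed range [-8, 7]
    scale = np.abs(x).max() / 7.0                              # map largest magnitude to level 7
    q = np.clip(np.round(x / scale), -8, 7).astype(np.int8)    # stored in int8; bit packing omitted
    return q, scale

def dequantize_int4(q, scale):
    return q.astype(np.float32) * scale

x = np.random.randn(384).astype(np.float32)
q, scale = quantize_int4(x)
print("max error:", np.abs(x - dequantize_int4(q, scale)).max())  # bounded by scale / 2
```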
4. Binary Quantization
1 bit: Only sign matters
```python
import numpy as np

def binary_quantize(x):
    return np.sign(x)  # Returns -1 or 1 (0 for exact zeros)

# Similarity in binary space
def binary_similarity(b1, b2):
    # Fraction of matching signs (1 - normalized Hamming distance)
    return np.sum(b1 == b2) / len(b1)
```
Quantization Schemes
Symmetric vs Asymmetric
Symmetric Quantization:
```python
# Zero point at origin
scale = max(abs(x_min), abs(x_max)) / (2**(bits - 1) - 1)
q = round(x / scale)
```
Asymmetric Quantization:
```python
# Arbitrary zero point
scale = (x_max - x_min) / (2**bits - 1)
zero_point = round(-x_min / scale)
q = round(x / scale) + zero_point
```
Per-Tensor vs Per-Channel
```python
# Per-tensor: a single scale for the entire tensor
scale = compute_scale(tensor)
quantized = quantize(tensor, scale)

# Per-channel: a different scale per dimension
scales = [compute_scale(tensor[i]) for i in range(channels)]
quantized = [quantize(tensor[i], scales[i]) for i in range(channels)]
```
Implementation Examples
Post-Training Quantization
```python
import torch
import torch.nn as nn

def quantize_model_weights(model, bits=8):
    """Quantize model weights after training (post-training quantization)."""
    for name, param in model.named_parameters():
        if 'weight' in name:
            # Calculate quantization parameters
            min_val = param.min()
            max_val = param.max()
            scale = (max_val - min_val) / (2**bits - 1)
            zero_point = -min_val / scale

            # Quantize and dequantize
            quantized = torch.round(param / scale + zero_point)
            quantized = torch.clamp(quantized, 0, 2**bits - 1)
            dequantized = (quantized - zero_point) * scale

            # Replace weights with their quantized-then-dequantized values
            param.data = dequantized.detach()
```
Quantization-Aware Training
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QuantizedLinear(nn.Module):
    def __init__(self, in_features, out_features, bits=8):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bits = bits

    def forward(self, x):
        # Fake quantization during training
        if self.training:
            # Compute scale
            w_min, w_max = self.weight.min(), self.weight.max()
            scale = (w_max - w_min) / (2**self.bits - 1)

            # Quantize and dequantize
            w_quant = torch.round(self.weight / scale) * scale

            # Straight-through estimator for gradients
            w_quant = self.weight + (w_quant - self.weight).detach()
        else:
            w_quant = self.weight
        return F.linear(x, w_quant)
```
Performance Analysis
Memory Savings
| Method | Bits | Memory | Relative Size |
|---|---|---|---|
| Float32 | 32 | 100% | 1.00× |
| Float16 | 16 | 50% | 0.50× |
| Int8 | 8 | 25% | 0.25× |
| Int4 | 4 | 12.5% | 0.125× |
| Binary | 1 | 3.125% | 0.03125× |
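To make these ratios concrete, here is a quick back-of-the-envelope sketch for a hypothetical index of 1M embeddings at 768 dimensions; both numbers are assumptions for illustration, and scales/codebooks are ignored:

```python
def index_size_bytes(num_vectors, dim, bits):
    """Raw storage for a flat embedding index at the given bit width."""
    return num_vectors * dim * bits / 8

for name, bits in [("float32", 32), ("float16", 16), ("int8", 8), ("int4", 4), ("binary", 1)]:
    gb = index_size_bytes(1_000_000, 768, bits) / 1e9
    print(f"{name:>8}: {gb:.3f} GB")
# float32 ≈ 3.07 GB, int8 ≈ 0.77 GB, binary ≈ 0.10 GB
```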
Accuracy Impact
Typical accuracy retention:
```
Float32 → Float16: 99.5%
Float32 → Int8:    98-99%
Float32 → Int4:    95-97%
Float32 → Binary:  85-90%
```
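These figures are task-dependent, so it is worth measuring retention on your own data. One simple proxy, sketched below with random vectors standing in for real embeddings, is how well per-vector cosine similarity survives an int8 round-trip:

```python
import numpy as np

def int8_roundtrip(emb):
    """Quantize embeddings to 8 bits and back (per-matrix min/max scaling)."""
    lo, hi = emb.min(), emb.max()
    scale = (hi - lo) / 255.0
    q = np.round((emb - lo) / scale).astype(np.uint8)
    return q.astype(np.float32) * scale + lo

def normalize(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)

emb = np.random.randn(1000, 384).astype(np.float32)   # stand-in for real embeddings
emb_q = int8_roundtrip(emb)

# Cosine similarity of each vector before vs. after quantization
cos = np.sum(normalize(emb) * normalize(emb_q), axis=1)
print("mean cosine similarity after int8 round-trip:", cos.mean())
```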
Speed Improvements
```python
# Benchmark example
import time
import torch

def benchmark_inference(model, input_data, quantized=False):
    if quantized:
        model = quantize_model(model)  # apply quantization (e.g., dynamic int8)
    start = time.time()
    with torch.no_grad():
        for _ in range(1000):
            output = model(input_data)
    return time.time() - start

# Results (typical)
# Float32: 1.0s
# Int8:    0.3s (3.3× faster)
# Int4:    0.2s (5× faster)
```
Advanced Techniques
1. Mixed Precision
Different precision for different layers:
```python
config = {
    'attention':  8,    # Int8 for attention
    'ffn':        4,    # Int4 for feed-forward
    'embeddings': 16,   # Float16 for embeddings
}
```
2. Dynamic Quantization
Quantize activations on-the-fly:
```python
import torch
import torch.nn as nn

# model: an existing float32 nn.Module
model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear},        # Layers to quantize
    dtype=torch.qint8
)
```
3. Learned Quantization
Learn optimal quantization parameters:
```python
import torch
import torch.nn as nn

class LearnedQuantizer(nn.Module):
    def __init__(self, bits=8):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.zero_point = nn.Parameter(torch.zeros(1))
        self.bits = bits

    def forward(self, x):
        # Learned affine transformation
        q = torch.round(x / self.scale + self.zero_point)
        q = torch.clamp(q, 0, 2**self.bits - 1)
        return (q - self.zero_point) * self.scale
```
Quantization for Embeddings
Embedding Table Quantization
```python
import torch
import torch.nn as nn

class QuantizedEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, bits=8):
        super().__init__()
        # Store quantized embeddings as an integer buffer (not a trainable Parameter)
        self.register_buffer(
            'embeddings',
            torch.randint(0, 2**bits, (num_embeddings, embedding_dim), dtype=torch.uint8)
        )
        self.scale = nn.Parameter(torch.ones(embedding_dim))
        self.zero_point = nn.Parameter(torch.zeros(embedding_dim))

    def forward(self, indices):
        # Lookup and dequantize
        quantized = self.embeddings[indices].float()
        return (quantized - self.zero_point) * self.scale
```
Product Quantization
Split vectors and quantize separately:
```python
import numpy as np
from sklearn.cluster import KMeans

def product_quantization(vectors, num_subvectors=8, bits=8):
    """Quantize vectors using product quantization."""
    D = vectors.shape[1]
    d = D // num_subvectors
    quantized = []
    codebooks = []
    for i in range(num_subvectors):
        # Extract subvector
        subvecs = vectors[:, i*d:(i+1)*d]

        # Learn codebook (k-means)
        kmeans = KMeans(n_clusters=2**bits)
        labels = kmeans.fit_predict(subvecs)

        quantized.append(labels)
        codebooks.append(kmeans.cluster_centers_)
    return quantized, codebooks
```
Best Practices
1. Calibration
Determine optimal scale/zero-point:
```python
import torch

def calibrate_quantization(data_loader, model):
    """Find quantization parameters by tracking observed value ranges."""
    min_vals, max_vals = {}, {}
    for batch in data_loader:
        output = model(batch)  # run calibration data through the model
        # Track weight ranges (in practice, calibration usually also tracks activation ranges)
        for name, tensor in model.named_parameters():
            if name not in min_vals:
                min_vals[name] = tensor.min()
                max_vals[name] = tensor.max()
            else:
                min_vals[name] = torch.minimum(min_vals[name], tensor.min())
                max_vals[name] = torch.maximum(max_vals[name], tensor.max())
    return min_vals, max_vals
```
2. Outlier Handling
```python
import numpy as np

def clip_outliers(tensor, percentile=99.9):
    """Clip outliers before quantization."""
    threshold = np.percentile(np.abs(tensor), percentile)
    return np.clip(tensor, -threshold, threshold)
```
3. Error Compensation
```python
def quantize_with_error_compensation(weights, bits=8):
    """Accumulate and compensate quantization errors."""
    error = 0
    quantized = []
    for w in weights:
        # Add accumulated error
        w_compensated = w + error

        # Quantize (assumes quantize() returns the dequantized value)
        q = quantize(w_compensated, bits)

        # Compute new error
        error = w_compensated - q
        quantized.append(q)
    return quantized
```
Deployment Considerations
Mobile/Edge Deployment
```python
# TensorFlow Lite example
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
```
Hardware Acceleration
- ARM: Int8 with NEON
- x86: Int8 with AVX512 VNNI
- GPU: Int8 Tensor Cores
- TPU: Bfloat16 native
Related Concepts
- Dense Embeddings - Full precision representations
- Matryoshka Embeddings - Dimension reduction alternative
- Sparse vs Dense - Sparsity as compression
References
- Jacob et al. "Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference"
- Gholami et al. "A Survey of Quantization Methods for Efficient Neural Network Inference"
- Dettmers et al. "8-bit Optimizers via Block-wise Quantization"
- Zafrir et al. "Q8BERT: Quantized 8Bit BERT"