Quantization Effects Simulator


Explore memory-accuracy trade-offs in embedding quantization from float32 to binary representations.



Quantization reduces the precision of embedding values to save memory and accelerate computation, with controllable trade-offs in accuracy.

Interactive Quantization Simulator


Example output from the simulator with Int8 quantization selected (8-bit integers, a good balance of memory savings and accuracy):

Value-Level Comparison

| Dimension | Original | Quantized | Error |
|---|---|---|---|
| Dim 0 | -0.683 | -0.685 | 0% |
| Dim 1 | -0.977 | -0.976 | 0% |
| Dim 2 | -0.931 | -0.929 | 0% |
| Dim 3 | -0.892 | -0.890 | 0% |
| Dim 4 | 0.273 | 0.276 | 1% |
| Dim 5 | 0.250 | 0.252 | 1% |
| Dim 6 | -0.421 | -0.417 | 1% |
| Dim 7 | -0.341 | -0.339 | 1% |

Memory Savings Analysis

| Metric | Value |
|---|---|
| Original size | 2.25 MB |
| Quantized size | 0.56 MB |
| Total savings | 1.69 MB (75.0%) |

Performance Impact

| Metric | Value (Int8) |
|---|---|
| Memory usage | 0.6 MB |
| Model accuracy | 98.2% |
| Inference speed | 3.2× |

Bit-Level Representation

Int8 bit pattern for the value 0.7265: 01011100 (8 bits per value).

Quantization Methods Comparison

| Method | Bits | Memory | Accuracy | Speed | Use Case |
|---|---|---|---|---|---|
| Float32 (original) | 32 | 100% | 100% | 1× | Research, training |
| Float16 (half precision) | 16 | 50% | 99.5% | 1.8× | GPU inference |
| Int8 quantization | 8 | 25% | 98.2% | 3.2× | Production, edge |
| Int4 quantization | 4 | 13% | 95.5% | 5.5× | Mobile, IoT |
| Binary quantization | 1 | 3% | 85% | 20× | Extreme edge |

Quantization Best Practices

When to Use

  • Edge deployment with memory constraints
  • Large-scale serving to reduce costs
  • Mobile and embedded applications
  • Batch inference on CPU

Considerations

  • Test accuracy on your specific task
  • Consider quantization-aware training
  • Profile the actual speedup on your target hardware
  • May require a calibration dataset

Understanding Quantization

Quantization maps continuous values to discrete levels (a worked example follows the definitions below):

Q(x) = \text{round}\left(\frac{x - \text{min}}{\text{scale}}\right) \times \text{scale} + \text{min}

Where:

  • \text{scale} = \dfrac{\text{max} - \text{min}}{2^{\text{bits}} - 1}
  • Lower bits = fewer discrete levels
  • Higher compression = more information loss
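
To make the formula concrete, here is a minimal sketch (assuming a value range of [-1, 1] and 8 bits) that quantizes a single value and maps it back:

```python
import numpy as np

def quantize_dequantize(x, x_min=-1.0, x_max=1.0, bits=8):
    """Round x to the nearest representable level, then map it back to a float."""
    scale = (x_max - x_min) / (2**bits - 1)
    level = np.round((x - x_min) / scale)   # integer level in [0, 2**bits - 1]
    return level, level * scale + x_min     # (level, reconstructed value)

level, approx = quantize_dequantize(0.7265)
print(level, approx)   # 220.0, ~0.7255 -- a small rounding error remains
```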

Quantization Methods

1. Float16 (Half Precision)

16 bits: 1 sign + 5 exponent + 10 mantissa

```
Original:  0.123456789 (float32)
Quantized: 0.1235 (float16)
Memory:    50% reduction
Accuracy:  ~99.5% preserved
```
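
A quick NumPy check (a small sketch; only NumPy is assumed) reproduces this rounding:

```python
import numpy as np

x32 = np.float32(0.123456789)
x16 = np.float16(x32)                 # keep only 10 mantissa bits
print(x32, x16)                       # 0.12345679 0.1235
print(abs(float(x32) - float(x16)))   # error on the order of 1e-5
```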

2. Int8 Quantization

8 bits: Maps to [-128, 127]

```python
import numpy as np

def quantize_int8(x, scale, zero_point):
    # Affine quantization: map floats onto the int8 grid
    q = np.round(x / scale + zero_point)
    q = np.clip(q, -128, 127).astype(np.int8)
    return q

def dequantize_int8(q, scale, zero_point):
    # Map int8 codes back to (approximate) float values
    return scale * (q - zero_point)
```

3. Int4 Quantization

4 bits: Maps to [-8, 7]; see the sketch after the list below.

  • 93.75% memory reduction
  • Good for inference on edge devices
  • Requires careful calibration
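
A minimal sketch of symmetric int4 quantization (codes are stored one per byte here for clarity; production kernels pack two 4-bit codes per byte):

```python
import numpy as np

def quantize_int4(x):
    # Symmetric int4: scale so the largest magnitude maps to +/-7
    scale = np.abs(x).max() / 7.0
    q = np.clip(np.round(x / scale), -8, 7).astype(np.int8)
    return q, scale

def dequantize_int4(q, scale):
    return q.astype(np.float32) * scale

x = np.array([-0.98, -0.42, 0.25, 0.73], dtype=np.float32)
q, scale = quantize_int4(x)
print(q, dequantize_int4(q, scale))
```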

4. Binary Quantization

1 bit: Only sign matters

```python
import numpy as np

def binary_quantize(x):
    return np.sign(x)  # Returns -1 or 1 (0 stays 0)

# Similarity in binary space
def binary_similarity(b1, b2):
    # Fraction of matching signs (equivalently, 1 - normalized Hamming distance)
    return np.sum(b1 == b2) / len(b1)
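```

Note that np.sign still returns floats, so the 32× storage saving only materializes once the signs are packed into actual bits. One way to do that (a sketch using NumPy's bit-packing helpers, not part of the simulator) is:

```python
import numpy as np

def pack_binary(x):
    # 1 bit per dimension: store only whether each value is positive
    return np.packbits(x > 0)

def hamming_similarity(p1, p2, dim):
    # Count matching bits via XOR on the packed representation
    diff_bits = np.unpackbits(np.bitwise_xor(p1, p2))[:dim].sum()
    return 1.0 - diff_bits / dim

a, b = np.random.randn(256), np.random.randn(256)
print(hamming_similarity(pack_binary(a), pack_binary(b), 256))
```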

Quantization Schemes

Symmetric vs Asymmetric

Symmetric Quantization:

```python
# Zero point at origin
scale = max(abs(x_min), abs(x_max)) / (2**(bits - 1) - 1)
q = round(x / scale)
```

Asymmetric Quantization:

```python
# Arbitrary zero point
scale = (x_max - x_min) / (2**bits - 1)
zero_point = round(-x_min / scale)
q = round(x / scale) + zero_point
```

Per-Tensor vs Per-Channel

```python
# Per-tensor: a single scale for the entire tensor
scale = compute_scale(tensor)
quantized = quantize(tensor, scale)

# Per-channel: a different scale per dimension/channel
scales = [compute_scale(tensor[i]) for i in range(channels)]
quantized = [quantize(tensor[i], scales[i]) for i in range(channels)]
```
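
compute_scale and quantize above are stand-ins. The hypothetical NumPy sketch below shows why per-channel scales help when channels have very different ranges: a shared per-tensor scale rounds the small channel almost entirely to zero.

```python
import numpy as np

def mean_abs_error(x, scale):
    q = np.clip(np.round(x / scale), -128, 127)
    return np.abs(x - q * scale).mean(axis=-1)   # mean absolute error per channel

# Two channels with very different magnitudes
x = np.stack([0.01 * np.random.randn(1000), 10.0 * np.random.randn(1000)])

per_tensor_scale = np.abs(x).max() / 127
per_channel_scale = np.abs(x).max(axis=1, keepdims=True) / 127

print(mean_abs_error(x, per_tensor_scale))    # low-magnitude channel is rounded to ~0
print(mean_abs_error(x, per_channel_scale))   # each channel keeps its own resolution
```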

Implementation Examples

Post-Training Quantization

```python
import torch

def quantize_model_weights(model, bits=8):
    """Quantize model weights after training (quantize, then dequantize in place)."""
    for name, param in model.named_parameters():
        if 'weight' in name:
            # Calculate quantization parameters from the observed range
            min_val = param.min()
            max_val = param.max()
            scale = (max_val - min_val) / (2**bits - 1)
            zero_point = -min_val / scale

            # Quantize and dequantize
            quantized = torch.round(param / scale + zero_point)
            quantized = torch.clamp(quantized, 0, 2**bits - 1)
            dequantized = (quantized - zero_point) * scale

            # Replace weights with their quantized approximation
            param.data = dequantized
```

Quantization-Aware Training

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QuantizedLinear(nn.Module):
    def __init__(self, in_features, out_features, bits=8):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bits = bits

    def forward(self, x):
        # Fake quantization during training
        if self.training:
            # Compute scale from the current weight range
            w_min, w_max = self.weight.min(), self.weight.max()
            scale = (w_max - w_min) / (2**self.bits - 1)

            # Quantize and dequantize
            w_quant = torch.round(self.weight / scale) * scale

            # Straight-through estimator for gradients
            w_quant = self.weight + (w_quant - self.weight).detach()
        else:
            w_quant = self.weight
        return F.linear(x, w_quant)
```

Performance Analysis

Memory Savings

| Method | Bits | Memory | Relative Size |
|---|---|---|---|
| Float32 | 32 | 100% | 1.00× |
| Float16 | 16 | 50% | 0.50× |
| Int8 | 8 | 25% | 0.25× |
| Int4 | 4 | 12.5% | 0.125× |
| Binary | 1 | 3.125% | 0.03125× |
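
These ratios follow directly from bytes = num_vectors × dim × bits / 8. A quick sanity check with hypothetical corpus sizes:

```python
def embedding_memory_mb(num_vectors, dim, bits):
    # Storage for the codes only; scales/zero-points add a small overhead
    return num_vectors * dim * bits / 8 / 1024**2

for bits in (32, 16, 8, 4, 1):
    print(f"{bits:>2}-bit: {embedding_memory_mb(1_000_000, 384, bits):8.1f} MB")
# e.g. 32-bit: 1464.8 MB, 8-bit: 366.2 MB, 1-bit: 45.8 MB
```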

Accuracy Impact

Typical accuracy retention:

  • Float32 → Float16: 99.5%
  • Float32 → Int8: 98-99%
  • Float32 → Int4: 95-97%
  • Float32 → Binary: 85-90%

Speed Improvements

```python
# Benchmark example
import time
import torch

def benchmark_inference(model, input_data, quantized=False):
    if quantized:
        model = quantize_model(model)  # placeholder for your quantization routine
    start = time.time()
    with torch.no_grad():
        for _ in range(1000):
            output = model(input_data)
    return time.time() - start

# Results (typical)
# Float32: 1.0s
# Int8:    0.3s (3.3× faster)
# Int4:    0.2s (5× faster)
```

Advanced Techniques

1. Mixed Precision

Different precision for different layers:

```python
config = {
    'attention': 8,     # Int8 for attention
    'ffn': 4,           # Int4 for feed-forward
    'embeddings': 16,   # Float16 for embeddings
}
```

2. Dynamic Quantization

Quantize activations on-the-fly:

```python
import torch
import torch.nn as nn

model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear},        # Layers to quantize
    dtype=torch.qint8,
)
```

3. Learned Quantization

Learn optimal quantization parameters:

```python
import torch
import torch.nn as nn

class LearnedQuantizer(nn.Module):
    def __init__(self, bits=8):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.zero_point = nn.Parameter(torch.zeros(1))
        self.bits = bits

    def forward(self, x):
        # Learned affine transformation (round is non-differentiable, so
        # training typically pairs this with a straight-through estimator)
        q = torch.round(x / self.scale + self.zero_point)
        q = torch.clamp(q, 0, 2**self.bits - 1)
        return (q - self.zero_point) * self.scale
```

Quantization for Embeddings

Embedding Table Quantization

```python
import torch
import torch.nn as nn

class QuantizedEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, bits=8):
        super().__init__()
        # Store quantized codes as a uint8 buffer (gradients are not defined
        # for integer tensors, so a buffer is used rather than a Parameter)
        self.register_buffer(
            'embeddings',
            torch.randint(0, 2**bits, (num_embeddings, embedding_dim), dtype=torch.uint8)
        )
        self.scale = nn.Parameter(torch.ones(embedding_dim))
        self.zero_point = nn.Parameter(torch.zeros(embedding_dim))

    def forward(self, indices):
        # Lookup and dequantize
        quantized = self.embeddings[indices].float()
        return (quantized - self.zero_point) * self.scale
```
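
Lookup works like a regular nn.Embedding. A quick check of the storage footprint (sizes here are purely illustrative):

```python
emb = QuantizedEmbedding(num_embeddings=50_000, embedding_dim=384)
vecs = emb(torch.tensor([1, 42, 7]))         # dequantized float vectors, shape (3, 384)

code_bytes = emb.embeddings.numel()          # 1 byte per code with uint8 storage
float_bytes = emb.embeddings.numel() * 4     # what float32 storage would cost
print(vecs.shape, code_bytes / float_bytes)  # torch.Size([3, 384]) 0.25
```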

Product Quantization

Split vectors and quantize separately:

```python
from sklearn.cluster import KMeans

def product_quantization(vectors, num_subvectors=8, bits=8):
    """Quantize vectors using product quantization"""
    D = vectors.shape[1]
    d = D // num_subvectors
    quantized = []
    codebooks = []
    for i in range(num_subvectors):
        # Extract subvector
        subvecs = vectors[:, i*d:(i+1)*d]
        # Learn codebook (k-means)
        kmeans = KMeans(n_clusters=2**bits)
        labels = kmeans.fit_predict(subvecs)
        quantized.append(labels)
        codebooks.append(kmeans.cluster_centers_)
    return quantized, codebooks
```
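
Each vector is then represented by num_subvectors small codes. A sketch (reusing the return values above) of reconstructing approximate vectors from those codes and codebooks:

```python
import numpy as np

def pq_reconstruct(quantized, codebooks):
    # Stitch each subvector back together from its nearest centroid
    parts = [codebooks[i][quantized[i]] for i in range(len(codebooks))]
    return np.concatenate(parts, axis=1)

vectors = np.random.randn(1000, 64).astype(np.float32)
codes, books = product_quantization(vectors, num_subvectors=8, bits=4)
approx = pq_reconstruct(codes, books)
print(np.mean((vectors - approx) ** 2))   # small reconstruction MSE
```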

Best Practices

1. Calibration

Determine optimal scale/zero-point:

```python
import torch

def calibrate_quantization(data_loader, model):
    """Track min/max ranges of the model weights for quantization.

    Note: calibrating activation ranges would additionally require forward
    hooks on the layers of interest; only weight ranges are tracked here.
    """
    min_vals, max_vals = {}, {}
    for batch in data_loader:
        output = model(batch)
        for name, tensor in model.named_parameters():
            if name not in min_vals:
                min_vals[name] = tensor.min()
                max_vals[name] = tensor.max()
            else:
                min_vals[name] = min(min_vals[name], tensor.min())
                max_vals[name] = max(max_vals[name], tensor.max())
    return min_vals, max_vals
```

2. Outlier Handling

```python
import numpy as np

def clip_outliers(tensor, percentile=99.9):
    """Clip outliers before quantization"""
    threshold = np.percentile(np.abs(tensor), percentile)
    return np.clip(tensor, -threshold, threshold)
```

3. Error Compensation

```python
def quantize_with_error_compensation(weights, bits=8):
    """Accumulate and compensate quantization errors across values."""
    error = 0
    quantized = []
    for w in weights:
        # Add the error accumulated from previous values
        w_compensated = w + error
        # Quantize (`quantize` is a placeholder for any scalar quantizer)
        q = quantize(w_compensated, bits)
        # Compute the new residual error
        error = w_compensated - q
        quantized.append(q)
    return quantized
```

Deployment Considerations

Mobile/Edge Deployment

```python
# TensorFlow Lite example
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
```

Hardware Acceleration

  • ARM: Int8 with NEON
  • x86: Int8 with AVX512 VNNI
  • GPU: Int8 Tensor Cores
  • TPU: Bfloat16 native

