Masked Autoencoders (MAE) have revolutionized self-supervised learning for vision transformers [1]. In this tutorial, we'll build a working MAE from scratch in PyTorch, understanding each component along the way, and train it on the MNIST dataset as a toy exercise.
Most of the vision-transformer (ViT) backbones widely used in SOTA vision applications are pretrained with MAE or one of its variants. For example, Segment-Anything-V1 started from a pretrained ViT encoder backbone and built on top of it. See my mini blog series on Segment-Anything-V1 here if you are interested.
Overview: What Makes MAE Special?
MAE uses an asymmetric encoder-decoder architecture where:
- Random masking: 75% of image patches are masked during training
- Efficient encoding: Only visible patches are processed by the encoder
- Lightweight decoding: A small decoder reconstructs masked patches
- Simple objective: Pixel-level MSE loss on masked regions only
Let’s build this step by step.
Setup and Dependencies
First, let’s import our required libraries:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import math
Step 1: Patch Embedding
MAE starts by dividing images into non-overlapping patches. We’ll create a simple patch embedding layer that projects these patches into our embedding dimension.
class PatchEmbedding(nn.Module):
def __init__(self, image_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size) ** 2
# Linear projection of flattened patches
# ------------------------------------------------------------------
# DETAILED EXPLANATION OF THE REARRANGE LAYER:
# ------------------------------------------------------------------
# The input tensor (x) has shape (B, C, H_img, W_img).
# The goal is to reshape it to (B, N_patches, P_flat_dim).
# Rearrange operation: 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)'
# 1. Input Pattern (Left side): 'b c (h p1) (w p2)'
# - 'b': Batch size (unaffected).
# - 'c': Number of input channels (e.g., 3 for RGB).
# - '(h p1)': The total height (H_img) is explicitly factored into
# (h) number of patches vertically, each having (p1) height.
# (h = H_img / p1, where p1 is 'patch_size').
# - '(w p2)': The total width (W_img) is factored into (w) number
# of patches horizontally, each having (p2) width.
# (w = W_img / p2, where p2 is 'patch_size').
# - p1 and p2 are fixed by the keyword arguments (p1=16, p2=16).
# 2. Output Pattern (Right side): 'b (h w) (p1 p2 c)'
# - 'b': Batch size remains the first dimension.
# - '(h w)': These two dimensions (h=vertical patch count, w=horizontal patch count)
# are multiplied together to form the total *Number of Patches* (N_patches = h * w).
# This becomes the second dimension: the sequence length.
# - '(p1 p2 c)': These three dimensions (p1=patch height, p2=patch width, c=channels)
# are multiplied together to form the *Flattened Patch Vector*
# (P_flat_dim = p1 * p2 * c). This is the embedding dimension before Linear projection.
# This step effectively flattens each individual patch.
# Final Result: The tensor is reshaped from (B, C, H_img, W_img)
# to (B, N_patches, P_flat_dim). For the defaults (224, 16, 3),
# the shape is (B, 196, 768).
# ------------------------------------------------------------------
self.projection = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
p1=patch_size, p2=patch_size),
nn.Linear(patch_size * patch_size * in_channels, embed_dim)
)
def forward(self, x):
# x: (batch, channels, height, width)
# output: (batch, num_patches, embed_dim)
return self.projection(x)
For a 224×224 image with 16×16 patches, we get 196 patches. Each patch is projected to our embedding dimension (typically 768 for ViT-Base).
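As a quick sanity check (a minimal sketch that assumes the imports from the setup section and the PatchEmbedding class above are in scope), we can verify the output shape on a dummy batch:
patch_embed = PatchEmbedding(image_size=224, patch_size=16, in_channels=3, embed_dim=768)
dummy_images = torch.randn(2, 3, 224, 224)    # a batch of 2 fake RGB images
patch_tokens = patch_embed(dummy_images)
print(patch_tokens.shape)                     # expected: torch.Size([2, 196, 768])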
Step 2: Positional Encoding
Since transformers have no inherent notion of spatial position, we add non-learnable sinusoidal positional embeddings to the patches.
class SinusoidalPositionalEncoding(nn.Module):
"""
Non-learnable positional encoding from the original Transformer paper.
Calculated based on index and dimension of the embedding.
"""
def __init__(self, num_patches, embed_dim):
super().__init__()
# Create a buffer (not a parameter) that holds the constant P.E. matrix
pe = torch.zeros(num_patches, embed_dim)
position = torch.arange(0, num_patches, dtype=torch.float).unsqueeze(1)
# Calculate the denominator (10000^(2i / d_model))
div_term = torch.exp(
torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim)
)
# ----------------------------------------------------------------------
# SINUSOIDAL P.E. AND SIGNAL PROCESSING CONNECTION
# ----------------------------------------------------------------------
# The argument for the sine/cosine function is (position * div_term).
# This mirrors the standard signal processing form: A = t * (2 * pi * f).
#
# 1. 'position' is analogous to 't' (time or spatial index).
# 2. 'div_term' is analogous to '2 * pi * f' (the angular frequency).
#
# By iterating through the embedding dimension (2k), 'div_term' is varied:
#
# - **When 2k is small (beginning of embed_dim):**
# The div_term is large (closer to 1). This corresponds to a **High Frequency**.
# High frequency waves oscillate quickly and are responsible for
# encoding *fine, local details* about the patch's exact position.
#
# - **When 2k is large (end of embed_dim):**
# The div_term is small (closer to 1/10000). This corresponds to a **Low Frequency**.
# Low frequency waves oscillate slowly and are responsible for
# encoding *macro, global information* about the patch's relative location.
#
# This design ensures every patch is assigned a unique location code
# across a spectrum of frequencies.
# ----------------------------------------------------------------------
# Apply sine to even indices
pe[:, 0::2] = torch.sin(position * div_term)
# Apply cosine to odd indices
pe[:, 1::2] = torch.cos(position * div_term)
# Add batch dimension and register as a buffer
pe = pe.unsqueeze(0) # Shape (1, num_patches, embed_dim)
self.register_buffer('pe', pe)
def forward(self, x):
# x: (batch, num_patches, embed_dim)
# Truncate or extend the encoding if needed (for varying resolution)
current_num_patches = x.shape[1]
# Note: If the actual number of patches changes, you'd typically need
# to re-calculate or handle the indexing here. For a fixed input size,
# simply adding the pre-calculated tensor is enough.
return x + self.pe[:, :current_num_patches, :]
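A minimal usage sketch (assuming the class above and the setup imports) shows that the encoding preserves the token shape and introduces no learnable parameters:
pos_enc = SinusoidalPositionalEncoding(num_patches=196, embed_dim=768)
tokens = torch.zeros(2, 196, 768)
print(pos_enc(tokens).shape)                          # torch.Size([2, 196, 768])
print(sum(p.numel() for p in pos_enc.parameters()))   # 0 -- 'pe' is a buffer, not a parameter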
Step 3: Transformer Block
The core building block of our encoder and decoder is the standard transformer block with multi-head self-attention and feed-forward networks.
class TransformerBlock(nn.Module):
def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(
embed_dim, num_heads, dropout=dropout, batch_first=True
)
self.norm2 = nn.LayerNorm(embed_dim)
# MLP block
mlp_hidden_dim = int(embed_dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, mlp_hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_hidden_dim, embed_dim),
nn.Dropout(dropout)
)
def forward(self, x):
    # Pre-norm architecture: normalize before attention and before the MLP
    x_norm = self.norm1(x)
    x = x + self.attn(x_norm, x_norm, x_norm)[0]
    x = x + self.mlp(self.norm2(x))
    return x
This follows the standard Vision Transformer architecture with pre-normalization, which has shown better training stability.
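A quick shape check (a sketch assuming the block defined above) confirms that the block maps a token sequence to a sequence of the same shape, which is what lets us stack it freely in both the encoder and the decoder:
block = TransformerBlock(embed_dim=768, num_heads=12)
tokens = torch.randn(2, 196, 768)
print(block(tokens).shape)    # torch.Size([2, 196, 768]) -- same shape in, same shape out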
Step 4: Random Masking
The heart of MAE is its masking strategy. We randomly select which patches to keep (visible) and which to mask.
def random_masking(x, mask_ratio=0.75):
"""
Perform random masking by shuffling.
x: (batch, num_patches, embed_dim)
Returns:
x_masked: visible patches only
mask: binary mask (1 = masked, 0 = kept)
ids_restore: indices to restore original order
"""
batch, num_patches, dim = x.shape
num_keep = int(num_patches * (1 - mask_ratio))
# Generate random noise for shuffling
noise = torch.rand(batch, num_patches, device=x.device)
# Sort by noise to get random permutation
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
# Keep only the first subset
ids_keep = ids_shuffle[:, :num_keep]
x_masked = torch.gather(
x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, dim)
)
# Generate binary mask: 1 is masked/removed, 0 is kept/shown to the model
mask = torch.ones([batch, num_patches], device=x.device)
mask[:, :num_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
This function is crucial: it randomly shuffles patches, keeps only 25% (with default 75% masking), and provides indices to restore the original order later.
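Here is a tiny demo on toy shapes (hypothetical sizes, assuming random_masking as defined above):
x = torch.randn(2, 8, 4)                            # 2 images, 8 patches, 4-dim embeddings
x_masked, mask, ids_restore = random_masking(x, mask_ratio=0.75)
print(x_masked.shape)      # torch.Size([2, 2, 4]) -- only 25% of the patches survive
print(mask[0])             # e.g. tensor([1., 0., 1., 1., 1., 0., 1., 1.]) -- six masked, two kept
print(ids_restore.shape)   # torch.Size([2, 8]) -- used by the decoder to unshuffle later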
Step 5: MAE Encoder
The encoder processes only the visible patches, making it highly efficient.
class MAEEncoder(nn.Module):
def __init__(self, image_size=224, patch_size=16, in_channels=3,
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0):
super().__init__()
num_patches = (image_size // patch_size) ** 2
self.patch_embed = PatchEmbedding(
image_size, patch_size, in_channels, embed_dim
)
self.pos_embed = SinusoidalPositionalEncoding(num_patches, embed_dim)
# Transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads, mlp_ratio)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x, mask_ratio=0.75):
# Patch embedding
x = self.patch_embed(x) # (B, num_patches, embed_dim)
x = self.pos_embed(x)
# Random masking
x, mask, ids_restore = random_masking(x, mask_ratio)
# Apply transformer blocks
for block in self.blocks:
x = block(x)
x = self.norm(x)
return x, mask, ids_restore
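With the default 75% mask ratio, only 49 of the 196 patch tokens ever enter the transformer blocks. A small sketch (with a hypothetical reduced depth so it runs quickly) makes the shapes concrete:
encoder = MAEEncoder(image_size=224, patch_size=16, in_channels=3,
                     embed_dim=768, depth=2, num_heads=12)
images = torch.randn(2, 3, 224, 224)
latent, mask, ids_restore = encoder(images, mask_ratio=0.75)
print(latent.shape)   # torch.Size([2, 49, 768]) -- only the visible 25% is encoded
print(mask.shape)     # torch.Size([2, 196])     -- binary mask over all patches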
Step 6: MAE Decoder
The decoder is lightweight (fewer layers, smaller dimension) and reconstructs the full image from encoded visible patches and mask tokens.
class MAEDecoder(nn.Module):
def __init__(self, num_patches, patch_size=16, in_channels=3,
embed_dim=768, decoder_embed_dim=512,
decoder_depth=8, decoder_num_heads=16, mlp_ratio=4.0):
super().__init__()
self.num_patches = num_patches
self.patch_size = patch_size
self.in_channels = in_channels
# Project from encoder dimension to decoder dimension
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
# Learnable mask token
self.mask_token = nn.Parameter(torch.randn(1, 1, decoder_embed_dim) * 0.02)
# Positional encoding for decoder
self.decoder_pos_embed = SinusoidalPositionalEncoding(num_patches, decoder_embed_dim)
# Transformer blocks
self.decoder_blocks = nn.ModuleList([
TransformerBlock(decoder_embed_dim, decoder_num_heads, mlp_ratio)
for _ in range(decoder_depth)
])
self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
# Prediction head to reconstruct pixels
self.decoder_pred = nn.Linear(
decoder_embed_dim,
patch_size * patch_size * in_channels
)
def forward(self, x, ids_restore):
# x: encoded visible patches (B, num_visible, embed_dim)
batch_size = x.shape[0]
# Project to decoder dimension
x = self.decoder_embed(x)
# Append mask tokens
mask_tokens = self.mask_token.repeat(
batch_size, ids_restore.shape[1] - x.shape[1], 1
)
x_full = torch.cat([x, mask_tokens], dim=1)
# Unshuffle to restore original order
x_full = torch.gather(
x_full, dim=1,
index=ids_restore.unsqueeze(-1).expand(-1, -1, x_full.shape[2])
)
# Add positional encoding
x_full = self.decoder_pos_embed(x_full)
# Apply transformer blocks
for block in self.decoder_blocks:
x_full = block(x_full)
x_full = self.decoder_norm(x_full)
# Predict pixel values
x_full = self.decoder_pred(x_full) # (B, num_patches, patch_size^2 * channels)
return x_full
The decoder inserts learnable mask tokens at masked positions, applies positional embeddings, processes everything through transformer blocks, and finally predicts pixel values for each patch.
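A shape sketch makes this concrete (toy decoder sizes chosen only so it runs quickly; the ids_restore below is a stand-in permutation rather than one produced by random_masking): the decoder always emits a full-length sequence of per-patch pixel predictions.
decoder = MAEDecoder(num_patches=196, patch_size=16, in_channels=3,
                     embed_dim=768, decoder_embed_dim=128,
                     decoder_depth=1, decoder_num_heads=4)
latent = torch.randn(2, 49, 768)                        # 49 visible patches from the encoder
ids_restore = torch.argsort(torch.rand(2, 196), dim=1)  # stand-in for the real restore indices
pred = decoder(latent, ids_restore)
print(pred.shape)   # torch.Size([2, 196, 768]) -- 16*16*3 = 768 pixel values per patch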
Step 7: Complete MAE Model
Now we combine encoder and decoder into the full MAE model with the loss function.
class MAE(nn.Module):
def __init__(self, image_size=224, patch_size=16, in_channels=3,
embed_dim=768, depth=12, num_heads=12,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4.0, mask_ratio=0.75):
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.mask_ratio = mask_ratio
num_patches = (image_size // patch_size) ** 2
self.encoder = MAEEncoder(
image_size, patch_size, in_channels,
embed_dim, depth, num_heads, mlp_ratio
)
self.decoder = MAEDecoder(
num_patches, patch_size, in_channels,
embed_dim, decoder_embed_dim,
decoder_depth, decoder_num_heads, mlp_ratio
)
def patchify(self, imgs):
"""Convert images to patches for loss computation"""
p = self.patch_size
h = w = imgs.shape[2] // p
x = rearrange(imgs, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
p1=p, p2=p, h=h, w=w)
return x
def forward(self, imgs):
# Encode (with random masking)
latent, mask, ids_restore = self.encoder(imgs, self.mask_ratio)
# Decode
pred = self.decoder(latent, ids_restore)
# Compute loss
target = self.patchify(imgs)
loss = self.compute_loss(pred, target, mask)
return loss, pred, mask
def compute_loss(self, pred, target, mask):
"""
Compute MSE loss only on masked patches.
pred: (B, num_patches, patch_size^2 * channels)
target: (B, num_patches, patch_size^2 * channels)
mask: (B, num_patches), 1 is masked, 0 is visible
"""
loss = (pred - target) ** 2 # squared error
loss = loss.mean(dim=-1) # Mean per patch
# Compute loss only on masked patches
loss = (loss * mask).sum() / mask.sum()
return loss
The key insight in the loss function: we only penalize reconstruction error on masked patches. This forces the model to predict missing content rather than memorize visible patches.
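To make the masked loss concrete, here is a tiny worked example with hypothetical numbers (one image, four patches, two values per patch):
pred   = torch.tensor([[[1.0, 1.0], [0.0, 0.0], [2.0, 2.0], [3.0, 3.0]]])
target = torch.tensor([[[1.0, 1.0], [1.0, 1.0], [0.0, 0.0], [3.0, 3.0]]])
mask   = torch.tensor([[0.0, 1.0, 1.0, 0.0]])        # patches 1 and 2 were masked
per_patch = ((pred - target) ** 2).mean(dim=-1)      # tensor([[0., 1., 4., 0.]])
loss = (per_patch * mask).sum() / mask.sum()         # (1 + 4) / 2 = 2.5
print(loss)                                          # tensor(2.5000)
# The visible patches (0 and 3) contribute nothing to the loss, even if reconstructed perfectly.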
Training Example
Here’s how you would set up training on the MNIST dataset as a toy exercise, where the model converges quickly given the low dimensionality of the input images (28 x 28 x 1):
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import math
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
def unpatchify(x, image_size, patch_size, in_channels):
"""
Converts predicted patches back to image format.
x: (B, num_patches, P*P*C)
Returns: (B, C, H, W)
"""
p = patch_size
h = w = image_size // p
# x shape: (B, h*w, p*p*C)
# Rearrange: (B, h*w, p1*p2*C) -> (B, C, h*p1, w*p2)
imgs = rearrange(x, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
h=h, w=w, p1=p, p2=p, c=in_channels)
return imgs
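# Optional sanity check (a minimal sketch with hypothetical shapes): the patchify and
# unpatchify rearranges are exact inverses, so a round trip should recover the image.
_check_imgs = torch.randn(2, 1, 28, 28)
_check_patches = rearrange(_check_imgs, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=4, p2=4)
assert torch.allclose(_check_imgs, unpatchify(_check_patches, image_size=28, patch_size=4, in_channels=1))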
def train_mae():
# --- Configuration for MNIST (28x28, 1 Channel) ---
MNIST_IMG_SIZE = 28
PATCH_SIZE = 4
# Training Hyperparameters
NUM_EPOCHS = 64
BATCH_SIZE = 128
LEARNING_RATE = 1.5e-4
# Using small dimensions for fast training demo
MODEL_CONFIG = {
'image_size': MNIST_IMG_SIZE,
'patch_size': PATCH_SIZE,
'in_channels': 1, # MNIST is grayscale
'embed_dim': 64,
'depth': 4,
'num_heads': 4,
'decoder_embed_dim': 32,
'decoder_depth': 2,
'decoder_num_heads': 4,
'mlp_ratio': 2.0,
'mask_ratio': 0.75
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# --- Data Loading ---
transform = transforms.ToTensor()
# 1. Training Dataset/DataLoader
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# 2. Test Dataset/DataLoader
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# --- Model Initialization ---
model = MAE(**MODEL_CONFIG).to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)
print(f"Model initialized. Total patches: {(MNIST_IMG_SIZE // PATCH_SIZE) ** 2}. Mask ratio: {MODEL_CONFIG['mask_ratio'] * 100:.0f}%.")
# --- Training Loop ---
model.train()
for epoch in range(NUM_EPOCHS):
total_loss = 0
for batch_idx, (images, _) in enumerate(train_dataloader):
images = images.to(device)
optimizer.zero_grad()
loss, _, _ = model(images)
loss.backward()
optimizer.step()
total_loss += loss.item()
if batch_idx % 200 == 0:
print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Batch {batch_idx}/{len(train_dataloader)} | Loss: {loss.item():.6f}")
avg_loss = total_loss / len(train_dataloader)
print(f"--- Epoch {epoch+1} Finished | Average Loss: {avg_loss:.6f} ---")
print("\nTraining complete.")
return model, MODEL_CONFIG, test_dataloader
# --- Visualization Function ---
def visualize_reconstruction(model, dataloader, config, num_samples=5):
device = next(model.parameters()).device
model.eval()
# Get one batch of test images
test_images, _ = next(iter(dataloader))
test_images = test_images.to(device)
# Run the model
with torch.no_grad():
loss, pred_patches, mask = model(test_images)
# Convert predicted patches back to image format
# The output is (B, N_patches, P*P*C)
reconstructed_images = unpatchify(
pred_patches,
config['image_size'], config['patch_size'], config['in_channels']
)
# Create the masked input (for visualization)
# 1. Get the target patches (Original, unmasked)
target_patches = model.patchify(test_images)
# 2. Fill the masked areas of the target patches with a gray placeholder (0.5)
# The mask is 1 for MASKED patches, 0 for VISIBLE patches.
# Create a tensor of average gray value (0.5 in [0, 1] range) for the placeholder
placeholder = torch.full_like(target_patches, 0.5)
# Create the "visible only" patch sequence by selecting original for visible,
# and placeholder for masked.
visible_patches = target_patches * (1 - mask.unsqueeze(-1)) + placeholder * mask.unsqueeze(-1)
# Convert the visible patches back to image format
visible_images = unpatchify(
visible_patches,
config['image_size'], config['patch_size'], config['in_channels']
)
# --- Plotting ---
fig, axes = plt.subplots(num_samples, 3, figsize=(9, 3 * num_samples))
plt.suptitle(f"MAE Reconstruction (Mask Ratio: {config['mask_ratio']:.2f})", y=1.02)
for i in range(num_samples):
# Original Image (Col 1)
orig_img = test_images[i].cpu().squeeze().numpy()
axes[i, 0].imshow(orig_img, cmap='gray')
axes[i, 0].set_title(f"Original {i+1}")
axes[i, 0].axis('off')
# Masked Image (Visible Only) (Col 2)
masked_img = visible_images[i].cpu().squeeze().numpy()
axes[i, 1].imshow(masked_img, cmap='gray')
axes[i, 1].set_title("Masked Input")
axes[i, 1].axis('off')
# Reconstructed Image (Col 3)
recon_img = reconstructed_images[i].cpu().squeeze().numpy()
# Clip pixel values to [0, 1] range for visualization
recon_img = np.clip(recon_img, 0, 1)
axes[i, 2].imshow(recon_img, cmap='gray')
axes[i, 2].set_title("Reconstruction")
axes[i, 2].axis('off')
plt.tight_layout(rect=[0, 0, 1, 0.98])
plt.show() # In a real environment, this would display the plot.
print("")
Now call the training and visualization functions. Note that the hyperparameters are hardcoded inside the train_mae() function, where NUM_EPOCHS is set to 64. If you want a shorter training run, lower this value:
# 1. Train the model
trained_model, config, test_dataloader = train_mae()
# 2. Visualize the results on unseen test data
print("\nStarting visualization of test set reconstruction...")
visualize_reconstruction(trained_model, test_dataloader, config, num_samples=4)
Output of the visualization showcasing how the model reconstructs the masked patches
Key Design Choices Explained
Why 75% masking? High masking ratios force the model to learn semantic representations rather than relying on local interpolation. The original paper found 75% optimal.
Why asymmetric encoder-decoder? The encoder only processes 25% of patches, making training 3-4× faster. The decoder is lightweight since reconstruction is easier than encoding.
Why MSE in pixel space? Despite seeming naive, pixel-level reconstruction works remarkably well and is simpler than perceptual losses or other alternatives.
Why mask tokens? The decoder needs placeholders for masked positions. A shared learnable token is efficient and allows the model to learn a good initialization for “unknown” content.
Using the Pretrained Encoder
After pretraining, you can extract the encoder for downstream tasks and follow the high-level steps below:
# Construct the dataloader
# train_dataloader = ...
# Extract encoder for fine-tuning
encoder = model.encoder
# Add a classification head
num_classes = 10
classifier = nn.Linear(64, num_classes)  # input dim must match the encoder's embed_dim (64 in the MNIST config above)
# Fine-tune on MNIST classification
for images, labels in train_dataloader:
# No masking during fine-tuning
features = encoder(images, mask_ratio=0.0)[0]
# Global average pooling over the patches, practically dropping the patch axis
features = features.mean(dim=1)
logits = classifier(features)
# ... compute cross-entropy loss and optimize
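Below is a slightly fuller version of that loop, written against the MNIST toy configuration from the training example. It is a sketch rather than a tuned recipe: it assumes a train_dataloader built the same way as inside train_mae(), and the optimizer settings and names (ft_optimizer, etc.) are illustrative.
# Minimal fine-tuning sketch for the MNIST toy model (illustrative hyperparameters).
# Assumes `trained_model` from train_mae() and a `train_dataloader` built the same way.
device = next(trained_model.parameters()).device
encoder = trained_model.encoder
classifier = nn.Linear(64, 10).to(device)           # 64 matches embed_dim in MODEL_CONFIG
ft_optimizer = optim.AdamW(
    list(encoder.parameters()) + list(classifier.parameters()), lr=1e-4
)
encoder.train()
classifier.train()
for images, labels in train_dataloader:
    images, labels = images.to(device), labels.to(device)
    features, _, _ = encoder(images, mask_ratio=0.0)  # no masking during fine-tuning
    features = features.mean(dim=1)                   # global average pool over patches
    logits = classifier(features)
    loss = F.cross_entropy(logits, labels)
    ft_optimizer.zero_grad()
    loss.backward()
    ft_optimizer.step()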
Conclusion
This implementation covers all essential components of MAE:
- Efficient patch embedding and positional encoding
- Random masking that processes only visible patches
- Asymmetric encoder-decoder architecture
- Pixel-level MSE loss on masked regions
The beauty of MAE lies in its simplicity. By masking aggressively and reconstructing in pixel space, it learns powerful visual representations without requiring any labels. This makes it perfect for pretraining on massive unlabeled datasets before fine-tuning on specific tasks.
References
[1] He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). Masked autoencoders are scalable vision learners. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 16000-16009).
