import dataclasses
import datetime
import os
import datasets
import tokenizers
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import tqdm
from torch import Tensor
from torch.distributed.checkpoint import load, save
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.fsdp import FSDPModule, fully_shard
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    loss_parallel,
    parallelize_module,
)
from torch.utils.data.distributed import DistributedSampler
# Set default to bfloat16
torch.set_default_dtype(torch.bfloat16)
print("NCCL version:", torch.cuda.nccl.version())
# Build the model
@dataclasses.dataclass
class LlamaConfig:
    """Define Llama model hyperparameters."""
    vocab_size: int = 50000              # Size of the tokenizer vocabulary
    max_position_embeddings: int = 2048  # Maximum sequence length
    hidden_size: int = 768               # Dimension of hidden layers
    intermediate_size: int = 4 * 768     # Dimension of the MLP's hidden layer
    num_hidden_layers: int = 12          # Number of transformer layers
    num_attention_heads: int = 12        # Number of attention heads
    num_key_value_heads: int = 3         # Number of key-value heads for GQA
class RotaryPositionEncoding(nn.Module):
    """Rotary position encoding."""
    def __init__(self, dim: int, max_position_embeddings: int) -> None:
        """Initialize the RotaryPositionEncoding module.
        Args:
            dim: The hidden dimension of the input tensor to which RoPE is applied
            max_position_embeddings: The maximum sequence length of the input tensor
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        # save cosine and sine matrices as buffers, not parameters; they are registered
        # empty here and filled in reset_parameters() so that reset_all_weights() can
        # re-materialize them after the model is moved off the meta device
        self.register_buffer("cos", torch.empty(max_position_embeddings, dim))
        self.register_buffer("sin", torch.empty(max_position_embeddings, dim))
        self.reset_parameters()
    def reset_parameters(self) -> None:
        """Compute the cosine and sine tables in place."""
        if self.cos.device.type == "meta":
            return  # values cannot be materialized on the meta device
        device = self.cos.device
        # compute a matrix of n * theta_i
        N = 10_000.0
        inv_freq = 1.0 / (N ** (torch.arange(0, self.dim, 2, device=device) / self.dim))
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
        position = torch.arange(self.max_position_embeddings, device=device)
        sinusoid_inp = torch.outer(position, inv_freq)
        self.cos.copy_(sinusoid_inp.cos())
        self.sin.copy_(sinusoid_inp.sin())
    def forward(self, x: Tensor) -> Tensor:
        """Apply RoPE to tensor x.
        Args:
            x: Input tensor of shape (batch_size, seq_length, num_heads, head_dim)
        Returns:
            Output tensor of shape (batch_size, seq_length, num_heads, head_dim)
        """
        batch_size, seq_len, num_heads, head_dim = x.shape
        device = x.device
        dtype = x.dtype
        # reshape the cosine and sine matrices to 4D tensors with the same dtype as x
        cos = self.cos.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
        sin = self.sin.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
        # apply RoPE to x using the rotate-half formulation
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        output = (x * cos) + (rotated * sin)
        return output
class LlamaAttention(nn.Module):
    """Grouped-query attention with rotary embeddings."""
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_heads = config.num_key_value_heads  # GQA: H_kv < H_q
        # hidden_size must be divisible by num_heads
        assert (self.head_dim * self.num_heads) == self.hidden_size
        # Linear layers for Q, K, V projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        bs, seq_len, dim = hidden_states.size()
        # Project inputs to Q, K, V
        query_states = self.q_proj(hidden_states).view(bs, seq_len, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
        # Apply rotary position embeddings
        query_states = rope(query_states)
        key_states = rope(key_states)
        # Transpose tensors from BSHD to BHSD layout for scaled_dot_product_attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        # Use PyTorch's optimized attention implementation;
        # is_causal=True cannot be combined with an explicit attention mask
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attn_mask,
            dropout_p=0.0,
            enable_gqa=True,
        )
        # Transpose output from BHSD back to BSHD, reshape to 3D, then apply the output projection
        attn_output = attn_output.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output
class LlamaMLP(nn.Module):
    """Feed-forward network with SwiGLU activation."""
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        # Two parallel projections for SwiGLU
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.act_fn = F.silu  # SiLU gate activation used by SwiGLU
        # Project back to hidden size
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
    def forward(self, x: Tensor) -> Tensor:
        # SwiGLU: multiply the activated gate projection with the up projection
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
class LlamaDecoderLayer(nn.Module):
    """Single transformer layer for a Llama model."""
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.self_attn = LlamaAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.mlp = LlamaMLP(config)
    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        # First residual block: self-attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(hidden_states, rope=rope, attn_mask=attn_mask)
        hidden_states = attn_outputs + residual
        # Second residual block: MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states) + residual
        return hidden_states
class LlamaModel(nn.Module):
    """The full Llama model without any pretraining heads."""
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.rotary_emb = RotaryPositionEncoding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings,
        )
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)
    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        # Convert input token IDs to embeddings
        hidden_states = self.embed_tokens(input_ids)
        # Process through all transformer layers, then the final norm layer
        for layer in self.layers:
            hidden_states = layer(hidden_states, rope=self.rotary_emb, attn_mask=attn_mask)
        hidden_states = self.norm(hidden_states)
        # Return the final hidden states
        return hidden_states
class LlamaForPretraining(nn.Module):
    """Llama model with a language-modeling head for pretraining."""
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.base_model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        hidden_states = self.base_model(input_ids, attn_mask)
        return self.lm_head(hidden_states)
def create_causal_mask(batch: Tensor, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create a causal mask for self-attention.
    Args:
        batch: Batch of sequences, shape (batch_size, seq_len)
        dtype: Data type of the mask
    Returns:
        Causal mask of shape (seq_len, seq_len)
    """
    batch_size, seq_len = batch.shape
    mask = torch.full((seq_len, seq_len), float("-inf"), device=batch.device, dtype=dtype) \
        .triu(diagonal=1)
    return mask
def create_padding_mask(batch: Tensor, padding_token_id: int, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create a padding mask for a batch of sequences for self-attention.
    Args:
        batch: Batch of sequences, shape (batch_size, seq_len)
        padding_token_id: ID of the padding token
        dtype: Data type of the mask
    Returns:
        Padding mask of shape (batch_size, 1, seq_len, seq_len)
    """
    padded = torch.zeros_like(batch, device=batch.device, dtype=dtype) \
        .masked_fill(batch == padding_token_id, float("-inf"))
    mask = padded[:, :, None] + padded[:, None, :]
    return mask[:, None, :, :]
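# Illustrative example (not part of the training flow): for a batch [[12, 99, PAD]],
# create_padding_mask puts -inf in the row and column of the PAD position, while
# create_causal_mask puts -inf above the diagonal; the training loop below adds the two
# masks together so both constraints are applied in a single additive attention mask.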
# Generator function to create padded sequences of fixed length
class PretrainingDataset(torch.utils.data.Dataset):
    def __init__(self, dataset: datasets.Dataset, tokenizer: tokenizers.Tokenizer,
                 seq_length: int):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.bot = tokenizer.token_to_id("[BOT]")
        self.eot = tokenizer.token_to_id("[EOT]")
        self.pad = tokenizer.token_to_id("[PAD]")
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
        """Get a sequence of token IDs from the dataset. [BOT] and [EOT] tokens
        are added, and the sequence is clipped and padded to the target length.
        """
        seq = self.dataset[index]["text"]
        tokens: list[int] = [self.bot] + self.tokenizer.encode(seq).ids + [self.eot]
        # pad to target sequence length
        toklen = len(tokens)
        if toklen < self.seq_length + 1:
            pad_length = self.seq_length + 1 - toklen
            tokens += [self.pad] * pad_length
        # return the input sequence and the target sequence shifted by one token
        x = torch.tensor(tokens[:self.seq_length], dtype=torch.int64)
        y = torch.tensor(tokens[1:self.seq_length + 1], dtype=torch.int64)
        return x, y
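# Quick sanity check (illustrative only; assumes a tokenizer that defines [BOT]/[EOT]/[PAD],
# such as the one loaded further below):
#   ds = PretrainingDataset(dataset, tokenizer, seq_length=512)
#   x, y = ds[0]                     # two int64 tensors of shape (512,)
#   assert (x[1:] == y[:-1]).all()   # y is x shifted left by one position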
def load_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, scheduler: lr_scheduler.SequentialLR) -> None:
    dist.barrier()
    load(
        {"model": model, "optimizer": optimizer},
        checkpoint_id="checkpoint-dist",
        planner=DefaultLoadPlanner(allow_partial_load=True),  # ignore missing keys such as the RoPE buffers
    )
    scheduler.load_state_dict(
        torch.load("checkpoint-dist/lrscheduler.pt", map_location=device),
    )
    dist.barrier()
def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, scheduler: lr_scheduler.SequentialLR) -> None:
    dist.barrier()
    save(
        {"model": model, "optimizer": optimizer},
        checkpoint_id="checkpoint-dist",
    )
    if dist.get_rank() == 0:
        torch.save(scheduler.state_dict(), "checkpoint-dist/lrscheduler.pt")
    dist.barrier()
# Load the tokenizer and dataset
tokenizer = tokenizers.Tokenizer.from_file("bpe_50K.json")
dataset = datasets.load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train")
# Initialize the distributed environment
dist.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=60))
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"World size {world_size}, rank {rank}, local rank {local_rank}. Using {device}")
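# This script is intended to be launched with torchrun, which sets RANK, LOCAL_RANK, and
# WORLD_SIZE for init_process_group. A typical single-node launch (the script name is only
# a placeholder) would look like:
#   torchrun --nnodes=1 --nproc_per_node=4 pretrain_llama.py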
# Initialize the mesh for tensor parallelism
n_tensor_parallel = 2
assert world_size % n_tensor_parallel == 0, "Expect world size to be divisible by the number of tensor-parallel GPUs"
mesh = dist.device_mesh.init_device_mesh(
    "cuda",
    (world_size // n_tensor_parallel, n_tensor_parallel),
    mesh_dim_names=("dp", "tp"),
)
print(f"({rank}) Mesh: {mesh}, DP size: {mesh['dp'].size()}, TP size: {mesh['tp'].size()}, "
      f"DP local rank: {mesh['dp'].get_local_rank()}, TP local rank: {mesh['tp'].get_local_rank()}")
# Create pretraining model on meta device, on all ranks
with torch.device("meta"):
    model_config = LlamaConfig()
    model = LlamaForPretraining(model_config)
# Set up tensor parallelism on each transformer block in the base model
tp_plan = {
    "input_layernorm": SequenceParallel(),
    "self_attn": PrepareModuleInput(
        input_layouts=Shard(dim=1),  # only the single positional arg is prepared
        desired_input_layouts=Replicate(),
    ),
    # Q/K projection outputs are used with RoPE and need to be replicated;
    # Q/K/V outputs are also reshaped for GQA, which likewise requires replication
    "self_attn.q_proj": ColwiseParallel(output_layouts=Replicate()),
    "self_attn.k_proj": ColwiseParallel(output_layouts=Replicate()),
    "self_attn.v_proj": ColwiseParallel(output_layouts=Replicate()),
    "self_attn.o_proj": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1)),
    "post_attention_layernorm": SequenceParallel(),
    "mlp": PrepareModuleInput(
        input_layouts=Shard(dim=1),
        desired_input_layouts=Replicate(),
    ),
    "mlp.gate_proj": ColwiseParallel(),
    "mlp.up_proj": ColwiseParallel(),
    "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
}
for layer in model.base_model.layers:
    parallelize_module(layer, mesh["tp"], tp_plan)
# Set up tensor parallelism on the embedding and output norm layers in the base model
# and the prediction head in the top-level model
tp_plan = {
    "base_model.embed_tokens": RowwiseParallel(
        input_layouts=Replicate(),
        output_layouts=Shard(1),
    ),
    "base_model.norm": SequenceParallel(),
    "lm_head": ColwiseParallel(
        input_layouts=Shard(1),
        # output_layouts=Replicate(),  # only if not using loss parallel
        use_local_output=False,  # keep DTensor output for loss parallel
    ),
}
parallelize_module(model, mesh["tp"], tp_plan)
# Convert tensor-parallelized model to FSDP2, must shard every component
# shard across the “dp” dimension of the mesh
for layer in model.base_model.layers:
    fully_shard(layer, mesh=mesh["dp"])
fully_shard(model.base_model, mesh=mesh["dp"])
fully_shard(model, mesh=mesh["dp"])
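# Each fully_shard() call above creates its own FSDP parameter group: wrapping the
# transformer layers first and the root module last lets FSDP gather parameters layer by
# layer during the forward and backward passes rather than all at once.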
def reset_all_weights(model: nn.Module) -> None:
    """Initialize all weights of the model after moving it away from the meta device."""
    @torch.no_grad()
    def weight_reset(m: nn.Module):
        reset_parameters = getattr(m, "reset_parameters", None)
        if callable(reset_parameters):
            m.reset_parameters()
    # Applies fn recursively to the model itself and all of model.children()
    model.apply(fn=weight_reset)
torch.manual_seed(42)
model.to_empty(device=device)
reset_all_weights(model)
assert isinstance(model, FSDPModule), f"Expected FSDPModule, got {type(model)}"
# Training parameters
epochs = 3
learning_rate = 1e-3
batch_size = 64 // mesh["dp"].size()  # per data-parallel rank (global batch size of 64)
seq_length = 512
num_warmup_steps = 1000
PAD_TOKEN_ID = tokenizer.token_to_id("[PAD]")
model.train()
# DataLoader, optimizer, scheduler, and loss function
# Sampler is needed to shard the dataset across world size
dataset = PretrainingDataset(dataset, tokenizer, seq_length)
sampler = DistributedSampler(
    dataset, shuffle=False, drop_last=True,
    num_replicas=mesh["dp"].size(),
    rank=mesh["dp"].get_local_rank(),
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    sampler=sampler,
    batch_size=batch_size,
    pin_memory=True,  # optional
    shuffle=False,
    num_workers=2,
    prefetch_factor=2,
)
num_training_steps = len(dataloader) * epochs
optimizer = torch.optim.AdamW(
    model.parameters(), lr=learning_rate, betas=(0.9, 0.99), eps=1e-8, weight_decay=0.1,
)
warmup_scheduler = lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1, end_factor=1.0, total_iters=num_warmup_steps,
)
cosine_scheduler = lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_training_steps - num_warmup_steps,
    eta_min=0,
)
scheduler = lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[num_warmup_steps],
)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_ID)
# if checkpoint-dist dir exists, load the checkpoint to model and optimizer
if os.path.exists("checkpoint-dist"):
    load_checkpoint(model, optimizer, scheduler)
# start training
print(f"({rank}) Starting training")
for epoch in range(epochs):
    pbar = tqdm.tqdm(dataloader, desc=f"({rank}) Epoch {epoch+1}/{epochs}")
    for batch_id, batch in enumerate(pbar):
        if batch_id % 1000 == 0:
            save_checkpoint(model, optimizer, scheduler)
        # Explicit prefetching before sending any data to the model
        model.unshard()
        # Get batched data, move from CPU to GPU
        input_ids, target_ids = batch
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)
        # Create the attention mask (causal mask + padding mask); the mask dtype must match
        # the model's compute dtype for scaled_dot_product_attention
        mask_dtype = torch.get_default_dtype()
        attn_mask = create_causal_mask(input_ids, dtype=mask_dtype) + \
            create_padding_mask(input_ids, PAD_TOKEN_ID, dtype=mask_dtype)
        # Extract output from the model
        logits = model(input_ids, attn_mask)
        optimizer.zero_grad()
        with loss_parallel():
            # Compute loss: cross-entropy between logits and targets, ignoring padding tokens
            loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
            # Backward with the loss on a DTensor
            loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        pbar.set_postfix(loss=loss.item())
    pbar.close()
# Save the model
save_checkpoint(model, optimizer, scheduler)
# Clean up the distributed environment
dist.destroy_process_group()