Abstract
This repository provides a domain-adapted Turkish legal instruction-tuned model derived from meta-llama/Llama-3.1-8B-Instruct. The model corresponds to the FP8 mixed-precision configuration trained on 8 nodes with a per-device batch size of 32 (micro-batch 4 × gradient accumulation 8; global batch 1,024 across 32 GPUs), and is part of the “Harnessing Fully Sharded Data Parallelism v2 with Float8 Precision for Faster Training” study. This configuration was designed to evaluate the behavior of FSDP2 + FP8 mixed precision at large scale. In this setup, the FP8 run completed in 23.50 minutes versus 22.00 minutes for the BF16 baseline, making FP8 approximately 6.8% slower at this batch size and node count. This result shows that FP8 speed benefits depend strongly on batch size, communication scaling, and numerical-precision overhead, and that FP8 does not universally outperform BF16 in all distributed settings.
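The slowdown figure follows directly from the two runtimes; a quick check:

```python
# Relative FP8 slowdown vs. the BF16 baseline at 8 nodes (runtimes in minutes)
bf16_min, fp8_min = 22.00, 23.50
slowdown = (fp8_min - bf16_min) / bf16_min
print(f"{slowdown:.1%}")  # 6.8%
```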
Setups
- Precision: bfloat16 model weights, with float8 (FP8) used for linear-layer computation.
- Hardware: HPC (EuroHPC/BSC-class), 8 nodes with 4 × NVIDIA H100 GPUs each.
- Framework: PyTorch with `torchrun` for distributed training.
Experiment Context
This model was trained as part of our study comparing FSDP2 with bfloat16 precision against FSDP2 with FP8 mixed precision (bf16 weights, FP8 compute).
We used meta-llama/Llama-3.1-8B-Instruct, loaded with `torch_dtype=bfloat16`. For FP8 + FSDP2 compatibility, the model was wrapped per layer instead of as a whole, which avoids dimension-misalignment issues. During the forward and backward passes, float8 variants are used for computation: FP8 E4M3 for activations (forward pass) and FP8 E5M2 for gradients (backward pass, which needs a wider dynamic range). We also set `pad_inner_dim=True` to automatically pad inner dimensions to be divisible by 16, as required by FP8 kernels.
```python
import torch
from torch.distributed.fsdp import fully_shard  # FSDP2 per-module API
from torchao.float8 import (
    convert_to_float8_training,
    Float8LinearConfig,
    precompute_float8_dynamic_scale_for_fsdp,
)

# Swap nn.Linear modules for Float8Linear before sharding.
config = Float8LinearConfig(
    pad_inner_dim=True,                  # pad inner dims to multiples of 16 (FP8 requirement)
    enable_fsdp_float8_all_gather=True,  # all-gather weights in float8 to reduce communication
)
model = convert_to_float8_training(model, config=config)

if use_fp8:
    # Wrap each decoder layer separately (not the whole model)
    # for FP8 + FSDP2 compatibility.
    for layer in model.model.layers:
        fully_shard(layer, **fsdp_kwargs)
    fully_shard(model.model.embed_tokens, **fsdp_kwargs)
    fully_shard(model.lm_head, **fsdp_kwargs)
```
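The `precompute_float8_dynamic_scale_for_fsdp` import above is used with dynamic scaling: torchao recommends recomputing the float8 scales once per step, after the optimizer update, so a single kernel amortizes the scale computation across all FP8 parameters. A minimal training-step sketch (the `dataloader`, `optimizer`, and loss computation here are placeholders, not the study's actual loop):

```python
for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # max_grad_norm = 1.0
    optimizer.step()
    optimizer.zero_grad()
    # Recompute float8 scales for the FP8 all-gather path (once per step)
    precompute_float8_dynamic_scale_for_fsdp(model)
```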
Base Model Technical Specifications
- Parameters: 8 Billion
- Architecture Family: Llama 3.1
- Maximum Position Embeddings: 131,072
- Attention Heads: 32 (`num_attention_heads`)
- Key-Value Heads: 8 (`num_key_value_heads`)
- Hidden Layers: 32 (`num_hidden_layers`)
- Hidden Size: 4,096 (`hidden_size`)
- Intermediate Size: 14,336
- Vocabulary Size: 128,256
- Precision: bfloat16
- RoPE Scaling: type `llama3`, factor = 8.0
- RMS Norm Epsilon: 1e-05
- Activation: SiLU
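These values can be verified against the base model's configuration; a small check, assuming `transformers` is installed and the (gated) base repo is accessible:

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
print(cfg.num_attention_heads)  # 32
print(cfg.num_key_value_heads)  # 8
print(cfg.num_hidden_layers)    # 32
print(cfg.hidden_size)          # 4096
print(cfg.intermediate_size)    # 14336
print(cfg.vocab_size)           # 128256
print(cfg.rope_scaling)         # llama3 scaling, factor 8.0
print(cfg.rms_norm_eps)         # 1e-05
```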
Training Methodology
Training Configuration
- Model: `meta-llama/Llama-3.1-8B-Instruct`
- Sequence Length: 4,096 (`seq_len`)
- Epochs: 2
- Per-Device Micro Batch Size: 4
- Gradient Accumulation: 8
- GPUs per node: 4 (via `CUDA_VISIBLE_DEVICES=0,1,2,3`)
- dtype: `bf16` with `fp8=true`
  - Weights: bfloat16
  - Activations: float8
- Optimizer: AdamW (see the sketch after this list)
  - Learning Rate: 2e-5
  - Weight Decay: 0.01
  - Betas: (0.9, 0.95)
  - Epsilon: 1e-8
- LR Scheduler: Cosine; warmup = 10% (`warmup_ratio=0.1`), with `warmup_steps=100`
- Max Grad Norm: 1.0
- Gradient Checkpointing: Enabled
- Checkpointing: every 10 steps; keep last 5; select best by `eval_loss`
- Logging: every step to file; Weights & Biases in offline mode
- Seed: 100
- Distributed Training: `torch.distributed.run` (8 nodes, multi-GPU) with FSDP2 (Fully Sharded Data Parallel v2)
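As referenced in the optimizer item above, a minimal sketch of how these hyperparameters map onto AdamW and a cosine schedule (using the `transformers` scheduler helper; `total_steps` is a placeholder for the run's actual step count):

```python
import torch
from transformers import get_cosine_schedule_with_warmup

# AdamW with the hyperparameters listed above
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    weight_decay=0.01,
    betas=(0.9, 0.95),
    eps=1e-8,
)

# Cosine decay with the fixed warmup_steps=100 from the configuration above;
# total_steps is a placeholder (epochs * steps_per_epoch).
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=total_steps,
)
```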
Dependencies
| Package | Version |
|---|---|
| Transformers | 4.57.1 |
| torch | 2.9.0+cu128 |
| accelerate | 0.14.1 |
| datasets | 4.3.0 |
| huggingface-hub | 0.36.0 |
| tensorboard | 2.20.0 |
| tensorboard-data-server | 0.7.2 |
| wandb | 0.22.1 |
Job Details
| Model | Job ID | Runtime (mins) | Nodes | GPUs/node | Node-hours | GPU-hours | Micro-batch | Batch size | Grad. accum. | Total batch size |
|---|---|---|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct-w16a16-1node | 31472940 | 51.50 | 1 | 4 | 0.858 | 3.433 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a8-1node | 31473092 | 47.25 | 1 | 4 | 0.788 | 3.151 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a16-4nodes | 31478433 | 31.75 | 4 | 4 | 2.117 | 8.467 | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a8-4nodes | 31478468 | 39.75 | 4 | 4 | 2.650 | 10.600 | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a16-8nodes | 31476914 | 22.00 | 8 | 4 | 2.933 | 11.733 | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-8nodes | 31476844 | 23.50 | 8 | 4 | 3.133 | 12.533 | 4 | 4 | 8 | 1024 |
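The derived columns follow directly from runtime and topology; for example, for the 8-node FP8 run (job 31476844):

```python
runtime_min, nodes, gpus_per_node = 23.50, 8, 4
micro_batch, grad_accum = 4, 8

node_hours  = runtime_min / 60 * nodes                           # 3.133
gpu_hours   = node_hours * gpus_per_node                         # 12.533
total_batch = micro_batch * grad_accum * nodes * gpus_per_node   # 1024
```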
Computational Infrastructure
- Platform: HPC
- GPUs: NVIDIA H100 (32 for this run: 8 nodes × 4 GPUs)

All 6 models were trained on 1, 4, and 8 nodes, each in both bf16 and bf16 + FP8 configurations.
| Model | Batch Size (per device) | Max Loss (train) | Min Loss (train) | Avg Loss (train) | ± Std (train) | Final Loss (train) | Max Loss (val) | Min Loss (val) | Avg Loss (val) | ± Std (val) | Final Loss (val) | Total Steps | Best Step |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct-w16a16-1node | 8 | 3.1235 | 0.7203 | 0.9750 | 0.3344 | 0.7612 | 1.9113 | 0.8907 | 0.9831 | 0.1897 | 0.8907 | 312 | — |
| Llama-3.1-8B-Instruct-w16a8-1node | 8 | 3.1661 | 0.7261 | 0.9804 | 0.3374 | 0.7672 | 1.9230 | 0.8948 | 0.9867 | 0.1906 | 0.8951 | 312 | — |
| Llama-3.1-8B-Instruct-w16a16-4nodes | 32 | 3.2452 | 0.7414 | 0.9665 | 0.4844 | 0.7504 | 1.0538 | 0.8382 | 0.8844 | 0.0725 | 0.8382 | 70 | — |
| Llama-3.1-8B-Instruct-w16a8-4nodes | 32 | 3.2840 | 0.7478 | 0.9748 | 0.4905 | 0.7581 | 1.0701 | 0.8430 | 0.8922 | 0.0764 | 0.8430 | 70 | — |
| Llama-3.1-8B-Instruct-w16a16-8nodes | 32 | 3.2311 | 0.8448 | 1.1856 | 0.6434 | 0.8448 | 1.0257 | 0.8977 | 0.9460 | 0.0568 | 0.8977 | 35 | — |
| Llama-3.1-8B-Instruct-w16a8-8nodes | 32 | 3.3003 | 0.8473 | 1.1866 | 0.6481 | 0.8473 | 1.0203 | 0.8992 | 0.9445 | 0.0539 | 0.8992 | 35 | — |
Usage
Note: the final model is saved in bfloat16. For inference, load it in bfloat16 or float16 as shown below:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "newmindai/Llama-3.1-8B-Instruct-w16a8-8nodes"
dtype = torch.bfloat16

tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto",
)

# Turkish legal QA prompt: "Question: Under the Personal Data Protection Law,
# in which cases is explicit consent not required? Answer:"
prompt = "Soru: Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz? Cevap:"
inputs = tok(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,  # greedy decoding
    )
print(tok.decode(out[0], skip_special_tokens=True))
```
Ethical Considerations and Disclaimers
- Research & development purposes only; not a substitute for professional legal counsel.
- Users must ensure compliance with data protection and sector regulations.
- Potential biases may exist in domain data and model outputs.
Model & Data Card Metadata
- Total Parameters: 8,030,261,248
- Serialized Size (approx.): 16,060,522,496 bytes
- Config precision: bfloat16
- RoPE: llama3 scaling, factor 8.0
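The serialized size is consistent with bfloat16's 2 bytes per parameter:

```python
params = 8_030_261_248
print(params * 2)  # 16060522496 bytes ≈ 16.06 GB, matching the size above
```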
References and Citations
Base Model
```bibtex
@misc{meta_llama31_8b_instruct,
  title        = {Llama 3.1 8B Instruct},
  author       = {Meta AI},
  year         = {2024},
  howpublished = {\url{https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct}}
}
```
Training Dataset
```bibtex
@misc{euro_hpc_legal,
  title        = {EuroHPC-Legal},
  author       = {newmindai},
  year         = {2025},
  howpublished = {\url{https://huggingface.co/datasets/newmindai/EuroHPC-Legal}}
}
```