Abstract

This repository provides a domain-adapted Turkish legal instruction-tuned model derived from Meta’s Llama-3.1-8B-Instruct. As part of the “Harnessing Fully Sharded Data Parallelism v2 with Float8 Precision for Faster Training” study, this configuration is the BF16 variant trained on 4 nodes with a global batch size of 32. In this scaling regime, FP8 mixed precision did not yield a runtime improvement over BF16, highlighting how FP8 efficiency varies with batch size, sequence parallelism, and multi-node communication overhead. This model provides a strong BF16 baseline for comparison across all batch-size and node-scaling experiments in the study.

Experiment Context

This model was trained as part of our study comparing FSDP2 with bfloat16 precision against FSDP2 with FP8 mixed precision (bf16-fp8). We used meta-llama/Llama-3.1-8B-Instruct as the base model. The model was loaded with torch_dtype = bfloat16 and wrapped with FSDP2 in a single fully_shard call; bfloat16 was also used for computation during the forward and backward passes.

import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import DeviceMesh

# Assumes the process group is already initialized (e.g., via torchrun).
use_cuda = torch.cuda.is_available()
world_size = dist.get_world_size()
mesh_device_type = "cuda" if use_cuda else "cpu"
mesh = DeviceMesh(mesh_device_type, list(range(world_size)))  # 1-D mesh over all ranks

fsdp_kwargs = {
    "mesh": mesh,
    "reshard_after_forward": True,  # re-shard parameters after forward to save memory
}
model = fully_shard(model, **fsdp_kwargs)

Base Model Technical Specifications

  • Parameters: 8 Billion
  • Architecture Family: Llama 3.1
  • Maximum Position Embeddings: 131,072
  • Attention Heads: 32 (num_attention_heads)
  • Key-Value Heads: 8 (num_key_value_heads)
  • Hidden Layers: 32 (num_hidden_layers)
  • Hidden Size: 4,096 (hidden_size)
  • Intermediate Size: 14,336
  • Vocabulary Size: 128,256
  • Precision: bfloat16
  • RoPE Scaling: type llama3, factor = 8.0
  • RMS Norm Epsilon: 1e-05
  • Activation: SiLU
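
The specifications above can be read directly from the published configuration; a quick sanity check (attribute names follow the Hugging Face LlamaConfig):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("newmindai/Llama-3.1-8B-Instruct-w16a16-4nodes")
print(cfg.num_hidden_layers, cfg.num_attention_heads, cfg.num_key_value_heads)  # 32 32 8
print(cfg.hidden_size, cfg.intermediate_size, cfg.vocab_size)                   # 4096 14336 128256
print(cfg.rope_scaling)  # {'rope_type': 'llama3', 'factor': 8.0, ...}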

Training Methodology

Training Configuration

  • Model: meta-llama/Llama-3.1-8B-Instruct
  • Sequence Length: 4,096 (seq_len)
  • Epochs: 2
  • Max Steps: 1,200
  • Per-Device Micro Batch Size: 4
  • Gradient Accumulation: 8
  • dtype: bf16 (fp8 = false)
    • Weights: bfloat16
    • Activations: bfloat16
  • Optimizer: AdamW
    • Learning Rate: 2e-5
    • Weight Decay: 0.01
    • Betas: (0.9, 0.95)
    • Epsilon: 1e-8
  • LR Scheduler: cosine with 10% warmup (warmup_ratio=0.1, warmup_steps=100); see the optimizer/scheduler sketch after this list
  • Max Grad Norm: 1.0
  • Gradient Checkpointing: Enabled
  • Checkpointing: every 10 steps; keep last 5; select best by eval_loss
  • Logging: every step to file; Weights & Biases in offline mode
  • Seed: 100
  • Distributed Training: torch.distributed.run (multi-node, multi-GPU)
    • FSDP2 (Fully Sharded Data Parallel v2)
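
For reference, a minimal sketch of how the optimizer and scheduler settings above fit together. This is not the actual training script: `model` refers to the FSDP2-wrapped model from the snippet above, and the training loop itself is omitted.

import torch
from transformers import get_cosine_schedule_with_warmup

# Hyperparameters from the configuration above.
max_steps = 1200
warmup_steps = 100  # ~10% of max_steps

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    weight_decay=0.01,
    betas=(0.9, 0.95),
    eps=1e-8,
)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=max_steps,
)

# Inside the training loop, gradients are clipped before each optimizer step:
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)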

Setups

  • Precision: half-precision bfloat16 for both storage and computation.
  • Hardware: HPC (EuroHPC/BSC-class), 4 nodes with 4 × NVIDIA H100 GPUs each.
  • Framework: PyTorch with torchrun for distributed training.

Dependencies

| Package | Version |
| --- | --- |
| Transformers | 4.57.1 |
| torch | 2.9.0+cu128 |
| accelerate | 0.14.1 |
| datasets | 4.3.0 |
| huggingface-hub | 0.36.0 |
| tensorboard | 2.20.0 |
| tensorboard-data-server | 0.7.2 |
| wandb | 0.22.1 |

Job Details

| Model | Job ID | Runtime (min) | Nodes | GPUs/node | Node-hours | GPU-hours | Micro-batch | Batch size | Gradient accumulation | Total batch size |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Llama-3.1-8B-Instruct-w16a16-1node | 31472940 | 51.50 | 1 | 4 | 0.858 | 3.433 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a8-1node | 31473092 | 47.25 | 1 | 4 | 0.788 | 3.151 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a16-4nodes | 31478433 | 31.75 | 4 | 4 | 2.117 | 8.467 | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a8-4nodes | 31478468 | 39.75 | 4 | 4 | 2.650 | 10.600 | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a16-8nodes | 31476914 | 22.00 | 8 | 4 | 2.933 | 11.733 | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-8nodes | 31476844 | 23.50 | 8 | 4 | 3.133 | 12.533 | 4 | 4 | 8 | 1024 |
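
The node-hour and GPU-hour columns follow directly from the runtime and node count at 4 GPUs per node; for example, for the 4-node BF16 run:

# Job 31478433: 31.75 minutes on 4 nodes with 4 GPUs each.
runtime_min, nodes, gpus_per_node = 31.75, 4, 4
node_hours = runtime_min / 60 * nodes     # 2.117
gpu_hours = node_hours * gpus_per_node    # 8.467
print(f"{node_hours:.3f} node-hours, {gpu_hours:.3f} GPU-hours")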

Computational Infrastructure

  • Platform: HPC
  • GPUs: NVIDIA H100 (32)

All 6 models were trained on 1, 4, and 8 nodes, in both bf16 and bf16-fp8 configurations.

Figures (metric plots, not reproduced here): perplexity (prep_train), accuracy (acc_train), loss (loss_train), memory allocation (mem_al), and GPU utilization (utils) for the bf16 and bf16-fp8 configurations.
| Model | Batch Size | Max Loss (train) | Min Loss (train) | Avg Loss (train) | ± Std (train) | Final Loss (train) | Max Loss (val) | Min Loss (val) | Avg Loss (val) | ± Std (val) | Final Loss (val) | Total Steps | Best Step |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Llama-3.1-8B-Instruct-w16a16-1node | 8 | 3.1235 | 0.7203 | 0.9750 | 0.3344 | 0.7612 | 1.9113 | 0.8907 | 0.9831 | 0.1897 | 0.8907 | 312 | |
| Llama-3.1-8B-Instruct-w16a8-1node | 8 | 3.1661 | 0.7261 | 0.9804 | 0.3374 | 0.7672 | 1.9230 | 0.8948 | 0.9867 | 0.1906 | 0.8951 | 312 | |
| Llama-3.1-8B-Instruct-w16a16-4nodes | 32 | 3.2452 | 0.7414 | 0.9665 | 0.4844 | 0.7504 | 1.0538 | 0.8382 | 0.8844 | 0.0725 | 0.8382 | 70 | |
| Llama-3.1-8B-Instruct-w16a8-4nodes | 32 | 3.2840 | 0.7478 | 0.9748 | 0.4905 | 0.7581 | 1.0701 | 0.8430 | 0.8922 | 0.0764 | 0.8430 | 70 | |
| Llama-3.1-8B-Instruct-w16a16-8nodes | 32 | 3.2311 | 0.8448 | 1.1856 | 0.6434 | 0.8448 | 1.0257 | 0.8977 | 0.9460 | 0.0568 | 0.8977 | 35 | |
| Llama-3.1-8B-Instruct-w16a8-8nodes | 32 | 3.3003 | 0.8473 | 1.1866 | 0.6481 | 0.8473 | 1.0203 | 0.8992 | 0.9445 | 0.0539 | 0.8992 | 35 | |
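
Assuming the perplexity curves are computed in the usual way as the exponential of the mean cross-entropy loss, they can be reproduced from the loss values above; for example, from the final validation loss of this 4-node BF16 run:

import math

final_val_loss = 0.8382          # Final Loss (val), Llama-3.1-8B-Instruct-w16a16-4nodes
print(math.exp(final_val_loss))  # ≈ 2.31 validation perplexity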

Implementation

Usage

Note: the final model has been saved in bfloat16 format. For inference, load the model in bfloat16 as shown below:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "newmindai/Llama-3.1-8B-Instruct-w16a16-4nodes"
dtype = torch.bfloat16  # the checkpoint is stored in bfloat16

tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto"
)

# Turkish legal QA prompt: "Question: Under the Personal Data Protection Law,
# in which cases is explicit consent not required? Answer:"
prompt = "Soru: Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz? Cevap:"
inputs = tok(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False  # greedy decoding for a deterministic answer
    )

print(tok.decode(out[0], skip_special_tokens=True))
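
If the tokenizer keeps the base model's chat template (the usual case when fine-tuning from an Instruct checkpoint), the same question can also be posed in chat format; a minimal sketch where `tok` and `model` come from the snippet above and the generation settings are illustrative:

messages = [
    {"role": "user",
     "content": "Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz?"},
]
input_ids = tok.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

with torch.no_grad():
    out = model.generate(input_ids, max_new_tokens=256, do_sample=False)

# Decode only the newly generated tokens.
print(tok.decode(out[0][input_ids.shape[-1]:], skip_special_tokens=True))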

Ethical Considerations and Disclaimers

  • Research & development purposes only; not a substitute for professional legal counsel.
  • Users must ensure compliance with data protection and sector regulations.
  • Potential biases may exist in domain data and model outputs.

Model & Data Card Metadata

  • Total Parameters: 8,030,261,248
  • Serialized Size (approx.): 16,060,522,496 bytes
  • Config precision: bfloat16
  • RoPE: llama3 scaling, factor 8.0
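
The serialized size is consistent with the parameter count at 2 bytes per bfloat16 parameter:

params = 8_030_261_248
print(params * 2)  # 16,060,522,496 bytes, matching the figure above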

References and Citations

Base Model

@misc{meta_llama31_8b_instruct,
  title={Llama 3.1 8B Instruct},
  author={Meta AI},
  year={2024},
  howpublished={\url{https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct}}
}

Training Dataset

@misc{euro_hpc_legal,
  title={EuroHPC-Legal},
  author={newmindai},
  year={2025},
  howpublished={\url{https://huggingface.co/datasets/newmindai/EuroHPC-Legal}}
}