---
language:
- tr
license: llama3.1
base_model: meta-llama/Llama-3.1-8B-Instruct
tags:
- legal
- turkish
- instruction-tuned
- llama-3.1
- fp8
- bfloat16
- mixed-precision
- question-answering
- fsdp-v2
- pytorch
- distributed-training
- torchao
datasets:
- newmindai/EuroHPC-Legal
library_name: transformers
pipeline_tag: text-generation
model-index:
- name: Llama-3.1-8B-Instruct-w16a16-4nodes
  results: []
---

## **Abstract**

This repository provides a domain-adapted Turkish legal instruction-tuned model derived from Meta's Llama-3.1-8B-Instruct. As part of the "Harnessing Fully Sharded Data Parallelism v2 with Float8 Precision for Faster Training" study, this configuration is the BF16 variant trained on 4 nodes with an effective batch size of 32 per device (micro-batch 4 × gradient accumulation 8). In this scaling regime, FP8 mixed precision did not yield a runtime improvement over BF16, highlighting how FP8 efficiency varies with batch size, sequence parallelism, and multi-node communication overhead. This model therefore serves as a strong BF16 baseline for comparison across all batch-size and node-scaling experiments in the study.

## **Experiment Context**

This model was trained as part of our study comparing **FSDP2 with bfloat16 precision (bf16)** against **FSDP2 with FP8 mixed precision (bf16-fp8)**. We used `meta-llama/Llama-3.1-8B-Instruct`. The model was loaded with `torch_dtype=bfloat16` and wrapped with FSDP2 as a whole; `bfloat16` was also used for computation in the forward and backward passes.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard

# One device per rank; the 1D mesh spans all ranks in the job.
use_cuda = torch.cuda.is_available()
world_size = dist.get_world_size()

mesh_device_type = "cuda" if use_cuda else "cpu"
mesh = DeviceMesh(mesh_device_type, list(range(world_size)))

fsdp_kwargs = {
    "mesh": mesh,
    # Free sharded parameters after forward and re-gather them in backward.
    "reshard_after_forward": True,
}
model = fully_shard(model, **fsdp_kwargs)
```

## **Base Model Technical Specifications**

- **Parameters**: 8 Billion
- **Architecture Family**: Llama 3.1
- **Maximum Position Embeddings**: 131,072
- **Attention Heads**: 32 (`num_attention_heads`)
- **Key-Value Heads**: 8 (`num_key_value_heads`)
- **Hidden Layers**: 32 (`num_hidden_layers`)
- **Hidden Size**: 4,096 (`hidden_size`)
- **Intermediate Size**: 14,336
- **Vocabulary Size**: 128,256
- **Precision**: bfloat16
- **RoPE Scaling**: type `llama3`, factor = 8.0
- **RMS Norm Epsilon**: 1e-05
- **Activation**: SiLU

## **Training Methodology**

### *Training Configuration*

- **Model**: `meta-llama/Llama-3.1-8B-Instruct`
- **Sequence Length**: 4,096 (`seq_len`)
- **Epochs**: 2
- **Max Steps**: 1,200
- **Per-Device Micro Batch Size**: 4
- **Gradient Accumulation**: 8
- **dtype**: `bf16` && `fp8=false`
  - Weights: bfloat16
  - Activations: bfloat16
- **Optimizer**: AdamW
  - Learning Rate: 2e-5
  - Weight Decay: 0.01
  - Betas: (0.9, 0.95)
  - Epsilon: 1e-8
- **LR Scheduler**: Cosine; warmup = 10% (`warmup_ratio=0.1`, `warmup_steps=100`)
- **Max Grad Norm**: 1.0
- **Gradient Checkpointing**: Enabled
- **Checkpointing**: every 10 steps; keep last 5; select best by `eval_loss`
- **Logging**: every step to file; Weights & Biases in offline mode
- **Seed**: 100
- **Distributed Training**: `torch.distributed.run` (multi-node, multi-GPU)
  - FSDP2 (Fully Sharded Data Parallel v2)

A minimal sketch of the optimizer and LR-schedule settings is given after the *Setups* list below.

### *Setups*

- **Precision**: half-precision bfloat16 for both parameter storage and computation.
- **Hardware**: HPC (EuroHPC/BSC-class), 4 nodes with 4 × NVIDIA H100 GPUs each.
- **Framework**: PyTorch with `torchrun` for distributed training.
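For reference, the AdamW and cosine-warmup settings listed under *Training Configuration* map onto standard PyTorch / Transformers APIs. The snippet below is an illustrative sketch, not the actual training script: the placeholder `model` stands in for the FSDP2-wrapped Llama model, and since both `warmup_ratio=0.1` and `warmup_steps=100` are listed above, the ratio form is used here.

```python
import torch
from transformers import get_cosine_schedule_with_warmup

max_steps = 1200
model = torch.nn.Linear(8, 8)  # placeholder; the real run uses the FSDP2-wrapped Llama model

# AdamW with the hyperparameters listed above.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.01,
)

# Cosine decay with 10% warmup (warmup_ratio=0.1 of 1,200 steps).
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * max_steps),
    num_training_steps=max_steps,
)

# Gradient clipping at max_grad_norm=1.0 would be applied before each optimizer step:
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```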
### *Dependencies*

| Package | Version |
|---------|---------|
| transformers | 4.57.1 |
| torch | 2.9.0+cu128 |
| accelerate | 0.14.1 |
| datasets | 4.3.0 |
| huggingface-hub | 0.36.0 |
| tensorboard | 2.20.0 |
| tensorboard-data-server | 0.7.2 |
| wandb | 0.22.1 |

## Job Details

| Model | Job ID | Runtime (mins) | Nodes | GPUs / node | Node-hour | GPU-hour | micro-batch | batch-size | gradient_accumulation | total_batch_size |
| ----------------------------------- | -------- | -------------- | ----- | ----------- | --------- | ---------- | ----------- | ---------- | --------------------- | ---------------- |
| Llama-3.1-8B-Instruct-w16a16-1node | 31472940 | 51.50 | 1 | 4 | **0.858** | **3.433** | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a8-1node | 31473092 | 47.25 | 1 | 4 | **0.788** | **3.151** | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a16-4nodes | 31478433 | 31.75 | 4 | 4 | **2.117** | **8.467** | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a8-4nodes | 31478468 | 39.75 | 4 | 4 | **2.650** | **10.600** | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a16-8nodes | 31476914 | 22.00 | 8 | 4 | **2.933** | **11.733** | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-8nodes | 31476844 | 23.50 | 8 | 4 | **3.133** | **12.533** | 4 | 4 | 8 | 1024 |

Node-hours and GPU-hours follow from runtime × node count × GPUs per node; a short sketch of this arithmetic appears after the metrics tables below.

### Computational Infrastructure

- **Platform**: HPC
- **GPUs**: NVIDIA H100 (32)

*All six models were trained on 1, 4, and 8 nodes, in both bf16 and bf16-fp8 configurations.*

| Perplexity metric results for bf16 and bf16-fp8 configurations | Accuracy metric results for bf16 and bf16-fp8 configurations | Loss metric results for bf16 and bf16-fp8 configurations | Memory allocation for bf16 and bf16-fp8 configurations | Utilization for bf16 and bf16-fp8 configurations |
|:--:|:--:|:--:|:--:|:--:|
| ![prep_train](https://cdn-uploads.huggingface.co/production/uploads/683d4880e639f8d647355997/GI7A-nl-gUev5h_wiAfAW.png) | ![acc_train](https://cdn-uploads.huggingface.co/production/uploads/683d4880e639f8d647355997/VRv9cnkYAu-HIychwvRqL.png) | ![loss_train](https://cdn-uploads.huggingface.co/production/uploads/683d4880e639f8d647355997/rLWCDnxTI6i3qArom_CKM.png) | ![mem_al](https://cdn-uploads.huggingface.co/production/uploads/683d4880e639f8d647355997/2ymYZO41W67ZW2I817AvW.png) | ![utils](https://cdn-uploads.huggingface.co/production/uploads/683d4880e639f8d647355997/VFsmGbKLWvlXdhrXjxEGR.png) |

| Model | Batch Size | Max Loss (train) | Min Loss (train) | Avg Loss (train) | ± Std (train) | Final Loss (train) | Max Loss (val) | Min Loss (val) | Avg Loss (val) | ± Std (val) | Final Loss (val) | Total Steps | Best Step |
| ----------------------------------- | ---------- | ---------------- | ---------------- | ---------------- | ------------- | ------------------ | -------------- | -------------- | -------------- | ----------- | ---------------- | ----------- | --------- |
| Llama-3.1-8B-Instruct-w16a16-1node | 8 | 3.1235 | 0.7203 | 0.9750 | 0.3344 | 0.7612 | 1.9113 | 0.8907 | 0.9831 | 0.1897 | 0.8907 | 312 | — |
| Llama-3.1-8B-Instruct-w16a8-1node | 8 | 3.1661 | 0.7261 | 0.9804 | 0.3374 | 0.7672 | 1.9230 | 0.8948 | 0.9867 | 0.1906 | 0.8951 | 312 | — |
| Llama-3.1-8B-Instruct-w16a16-4nodes | 32 | 3.2452 | 0.7414 | 0.9665 | 0.4844 | 0.7504 | 1.0538 | 0.8382 | 0.8844 | 0.0725 | 0.8382 | 70 | — |
| Llama-3.1-8B-Instruct-w16a8-4nodes | 32 | 3.2840 | 0.7478 | 0.9748 | 0.4905 | 0.7581 | 1.0701 | 0.8430 | 0.8922 | 0.0764 | 0.8430 | 70 | — |
| Llama-3.1-8B-Instruct-w16a16-8nodes | 32 | 3.2311 | 0.8448 | 1.1856 | 0.6434 | 0.8448 | 1.0257 | 0.8977 | 0.9460 | 0.0568 | 0.8977 | 35 | — |
| Llama-3.1-8B-Instruct-w16a8-8nodes | 32 | 3.3003 | 0.8473 | 1.1866 | 0.6481 | 0.8473 | 1.0203 | 0.8992 | 0.9445 | 0.0539 | 0.8992 | 35 | — |
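The perplexity curves shown above are assumed here to be the exponential of the cross-entropy loss; as a quick sanity check against the loss table:

```python
import math

# Assumption: reported perplexity = exp(cross-entropy loss).
# Example: final validation loss of this 4-node bf16 run, from the table above.
final_val_loss = 0.8382
print(math.exp(final_val_loss))  # ≈ 2.31
```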
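The Node-hour and GPU-hour columns in the Job Details table follow directly from the runtime and node count, with 4 GPUs per node. A small sketch of that arithmetic (the function name is illustrative only), using this card's 4-node bf16 run and the 1-node bf16 baseline as examples:

```python
GPUS_PER_NODE = 4  # each node provides 4 × H100

def resource_hours(runtime_mins: float, nodes: int) -> tuple[float, float]:
    """Return (node-hours, GPU-hours) for a run."""
    node_hours = runtime_mins / 60 * nodes
    gpu_hours = node_hours * GPUS_PER_NODE
    return round(node_hours, 3), round(gpu_hours, 3)

print(resource_hours(31.75, 4))  # this run: (2.117, 8.467)
print(resource_hours(51.50, 1))  # 1-node bf16 baseline: (0.858, 3.433)
```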
## Implementation

### *Usage*

**Note**: the final model is saved in bfloat16. For inference, load it in bfloat16 as shown below:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "newmindai/Llama-3.1-8B-Instruct-w16a16-4nodes"
dtype = torch.bfloat16

tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto"
)

prompt = "Soru: Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz? Cevap:"
inputs = tok(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False
    )
print(tok.decode(out[0], skip_special_tokens=True))
```

## **Ethical Considerations and Disclaimers**

* Research & development purposes only; not a substitute for professional legal counsel.
* Users must ensure compliance with data protection and sector regulations.
* Potential biases may exist in the domain data and model outputs.

## **Model & Data Card Metadata**

* **Total Parameters**: 8,030,261,248
* **Serialized Size (approx.)**: 16,060,522,496 bytes
* **Config precision**: bfloat16
* **RoPE**: llama3 scaling, factor 8.0

## **References and Citations**

### *Base Model*

```bibtex
@misc{meta_llama31_8b_instruct,
  title={Llama 3.1 8B Instruct},
  author={Meta AI},
  year={2024},
  howpublished={\url{https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct}}
}
```

### *Training Dataset*

```bibtex
@misc{euro_hpc_legal,
  title={EuroHPC-Legal},
  author={newmindai},
  year={2025},
  howpublished={\url{https://huggingface.co/datasets/newmindai/EuroHPC-Legal}}
}
```