# MedSLM: Medical Small Language Model (~381M Parameters)
A 381M parameter transformer language model pre-trained on curated medical text from PubMed abstracts, PMC full-text articles, and clinical guidelines.
## Architecture
MedSLM uses a modern GPT-style transformer with several architectural improvements over the standard GPT-2 design:
| Component | Detail |
|---|---|
| Normalization | RMSNorm (faster than LayerNorm, used in LLaMA/Mistral) |
| Positional Encoding | Rotary Positional Embeddings (RoPE) for better length generalization |
| Feed-Forward | SwiGLU activation (gated FFN, outperforms GELU) |
| Attention | Grouped-Query Attention (GQA) with shared KV heads for efficiency |
| Layers | 24 transformer blocks |
| Attention Heads | 16 query heads, 8 KV heads |
| Embedding Dim | 1024 |
| Context Length | 1024 tokens |
| Vocab Size | 50,257 (GPT-2 BPE tokenizer) |
| Parameters | 381,373,440 (~381M) |
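The stated total can be sanity-checked from the table. The sketch below reproduces 381,373,440 exactly under two assumptions not stated above: an SwiGLU hidden size of 2752 (roughly 8/3 × the embedding dim, rounded to a multiple of 64) and untied input/output embeddings.

```python
# Back-of-the-envelope parameter count for the architecture table above.
# ASSUMPTIONS: ffn_hidden = 2752 and untied embeddings -- these are the
# values that reproduce the stated total, not documented facts.
d_model, n_layers = 1024, 24
n_q_heads, n_kv_heads = 16, 8
head_dim = d_model // n_q_heads              # 64
ffn_hidden = 2752                            # assumed (~8/3 * d_model)
vocab = 50257

embed = vocab * d_model                      # token embedding
lm_head = vocab * d_model                    # assumed untied output projection
attn = d_model * (n_q_heads * head_dim)      # Q projection
attn += 2 * d_model * (n_kv_heads * head_dim)  # K, V (GQA: half the heads)
attn += d_model * d_model                    # output projection
ffn = 3 * d_model * ffn_hidden               # SwiGLU: gate, up, down matrices
norms = n_layers * 2 * d_model + d_model     # RMSNorm weights + final norm

total = embed + lm_head + n_layers * (attn + ffn) + norms
print(f"{total:,}")  # 381,373,440
```

Note how GQA shows up in the count: the K and V projections are half the size of Q because only 8 KV heads are shared across the 16 query heads.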
## Training
- Dataset: `Saminx22/medical_data_for_slm` (~44M tokens)
  - Sources: PubMed abstracts, PMC Open Access full-text, clinical guidelines
- Tokenizer: GPT-2 BPE tokenizer (50,257 vocab)
- Optimizer: AdamW (betas=0.9/0.95, weight_decay=0.1)
- LR Schedule: Linear warmup (1000 steps) + Cosine decay
- Peak LR: 0.0003
- Precision: bfloat16
- Effective Batch Size: 256
- Max Steps: 20,000
- Best Val Loss: 3.2198 (at step 19500)
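The schedule above can be sketched as a function of the step number. This is a minimal illustration assuming the cosine phase decays from the peak LR toward zero; the actual training code may use a nonzero LR floor.

```python
import math

# Sketch of the LR schedule described above: linear warmup for 1,000 steps,
# then cosine decay from the peak LR toward zero over the remaining steps.
def lr_at(step, peak_lr=3e-4, warmup=1000, max_steps=20000):
    if step < warmup:
        return peak_lr * step / warmup          # linear ramp from 0 to peak
    progress = (step - warmup) / (max_steps - warmup)
    return 0.5 * peak_lr * (1 + math.cos(math.pi * progress))

print(lr_at(500))    # mid-warmup: 1.5e-4
print(lr_at(1000))   # peak: 3e-4
print(lr_at(20000))  # end of training: ~0
```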
## Usage

### Loading the Model
```python
import torch
import json
from safetensors.torch import load_file
from transformers import AutoTokenizer

# Load config
with open("config.json") as f:
    config_dict = json.load(f)

# Reconstruct model (requires the MedSLM class definition)
config = MedSLMConfig(**{k: v for k, v in config_dict.items()
                         if k in MedSLMConfig.__dataclass_fields__})
model = MedSLM(config)

# Load weights
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("tokenizer/")
```
### Generating Text
```python
prompt = "The patient presented with acute myocardial infarction"
input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
output = model.generate(input_ids, max_new_tokens=200,
                        temperature=0.8, top_k=50, top_p=0.9)
print(tokenizer.decode(output.squeeze().tolist()))
```
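For readers unfamiliar with the sampling knobs, here is an illustrative sketch of what `temperature`, `top_k`, and `top_p` do for a single decoding step. This is not the model's actual `generate` implementation, which may differ in details.

```python
import torch

# Illustrative single-step sampler showing the effect of the three knobs.
def sample_next(logits, temperature=0.8, top_k=50, top_p=0.9):
    logits = logits / temperature                 # sharpen/flatten distribution
    # Top-k: keep only the k highest-scoring tokens
    topk_vals, topk_idx = torch.topk(logits, top_k)
    probs = torch.softmax(topk_vals, dim=-1)
    # Top-p (nucleus): restrict to the smallest prefix with cumulative prob >= p
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    keep = cum - sorted_probs < top_p             # always keeps the top token
    sorted_probs[~keep] = 0.0
    sorted_probs /= sorted_probs.sum()
    choice = torch.multinomial(sorted_probs, 1)   # sample from the nucleus
    return topk_idx[sorted_idx[choice]]

# A token with overwhelming probability is always selected
logits = torch.zeros(100)
logits[7] = 100.0
print(sample_next(logits).item())  # 7
```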
### Resuming Training
```python
# Recreate the optimizer with the training hyperparameters, then restore its state
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4,
                              betas=(0.9, 0.95), weight_decay=0.1)
optimizer_state = torch.load("optimizer.pt")
optimizer.load_state_dict(optimizer_state)
```
## Files

| File | Description |
|---|---|
| `model.safetensors` | Model weights (safetensors format) |
| `optimizer.pt` | Optimizer state dict for resuming training |
| `config.json` | Model architecture configuration |
| `training_config.json` | Training hyperparameters and loss history |
| `tokenizer/` | GPT-2 tokenizer files |
| `loss_curves.png` | Training/validation loss plot |
## Intended Use
This model is intended for research purposes in medical NLP. It can be used as:
- A foundation model for downstream medical NLP tasks (NER, classification, QA)
- A starting point for medical instruction tuning
- A baseline for comparing medical language model architectures
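As a sketch of the first use case, a classification head can be attached on top of the backbone's hidden states. Everything here is a hypothetical illustration: the class name, the pooling strategy, and the assumption that the trunk returns per-token hidden states of shape `(batch, seq, d_model)`.

```python
import torch
import torch.nn as nn

# Hypothetical sketch: mean-pool MedSLM's final hidden states and train a
# small linear head for sequence classification. `backbone` stands in for
# the MedSLM trunk; its output shape is an assumption.
class MedClassifier(nn.Module):
    def __init__(self, backbone, d_model=1024, num_labels=2):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(d_model, num_labels)

    def forward(self, input_ids):
        hidden = self.backbone(input_ids)   # (batch, seq, d_model) assumed
        pooled = hidden.mean(dim=1)         # mean-pool over the sequence
        return self.head(pooled)

# Smoke test with a dummy backbone in place of the real model
dummy = lambda ids: torch.randn(ids.shape[0], ids.shape[1], 1024)
clf = MedClassifier(dummy)
logits = clf(torch.zeros(4, 16, dtype=torch.long))
print(logits.shape)  # torch.Size([4, 2])
```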
## Limitations
- Not for clinical use: This model should NOT be used for clinical decision-making
- Small scale: ~381M parameters is relatively small; larger models will perform better
- Limited data: Trained on ~44M tokens (production models use trillions)
- No alignment: This is a base model without instruction tuning or RLHF
- English only: Trained exclusively on English medical text
- Potential biases: May reflect biases present in the medical literature
## License
Apache 2.0