# HRM-Text1 Amharic Model
This is a custom text-generation model based on the Hierarchical Recurrent Memory (HRM) architecture, trained from scratch on the `amanuelbyte/Amharic_dataset` dataset.

Because the architecture is custom, loading it through the `transformers` Auto classes requires `trust_remote_code=True`.
## How to Use
Because this is a custom architecture, you need to load the model by importing the `HRMText1` class from the `hrm_model.py` file (download it from this repository and place it next to your script so it can be imported).
```python
import json

import torch
from transformers import T5Tokenizer
from huggingface_hub import hf_hub_download

from hrm_model import HRMText1  # Import the custom model class

# Replace with your repo ID
repo_id = "amanuelbyte/HRM-amharic"
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load the tokenizer
tokenizer = T5Tokenizer.from_pretrained(repo_id)

# 2. Load the model's configuration
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
with open(config_path, "r") as f:
    config = json.load(f)

# 3. Instantiate the model with the config
# trust_remote_code=True is not needed here because we import HRMText1 manually;
# it only matters when loading through the transformers Auto classes.
model = HRMText1(config)

# 4. Load the model weights
weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

print("Model loaded successfully!")

# Now you can use the model for generation...
prompt = "የኢትዮጵያ ዋና ከተማ"  # "The capital city of Ethiopia"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

with torch.inference_mode():
    output_ids = model.generate(input_ids, max_new_tokens=50)  # Assumes HRMText1 implements a generate method

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```
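The call above assumes `HRMText1` exposes a Hugging Face-style `generate` method. If it does not, a plain greedy-decoding loop can stand in for it; the sketch below is only an illustration and assumes the model's forward pass returns logits of shape `[batch, seq_len, vocab_size]` (either directly or via a `.logits` attribute).

```python
# Hypothetical fallback: minimal greedy decoding, assuming the forward pass
# yields logits of shape [batch, seq_len, vocab_size].
def greedy_generate(model, input_ids, max_new_tokens=50, eos_token_id=None):
    for _ in range(max_new_tokens):
        with torch.inference_mode():
            outputs = model(input_ids)
            # Accept either a raw logits tensor or an output object with .logits
            logits = outputs.logits if hasattr(outputs, "logits") else outputs
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if eos_token_id is not None and next_token.item() == eos_token_id:
            break
    return input_ids

output_ids = greedy_generate(model, input_ids, eos_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```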
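If the repository's `config.json` also registers `HRMText1` under an `auto_map` entry (not verified here, so treat this as an assumption), the model can instead be loaded through the `transformers` Auto classes, which is where `trust_remote_code=True` comes into play:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hypothetical Auto-class path; it only works if config.json maps HRMText1 via auto_map.
tokenizer = AutoTokenizer.from_pretrained("amanuelbyte/HRM-amharic")
model = AutoModelForCausalLM.from_pretrained(
    "amanuelbyte/HRM-amharic",
    trust_remote_code=True,  # lets transformers execute hrm_model.py from the repo
)
```

The choice of `AutoModelForCausalLM` here is itself an assumption; the correct Auto class depends on how the repository registers the model.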