SDG Classifier: A Fine-Tuned LUKE Model for Multi-Label SDG Classification

This repository contains the fine-tuned model weights (best_model.pt) for the paper: "Bridging the Sustainable Development Goals: A Multi-Label Text Classification Approach for Mapping and Visualizing Nexuses in Sustainability Research".

➡️ GitHub Repository (Code): [https://github.com/Green-Engineers-Lab/SDGs-classifier/]
➡️ Paper Link: [Link to Published Paper will be added upon publication]

πŸ“ Model Description

This model is a fine-tuned version of studio-ousia/luke-large-lite for multi-label text classification of the 17 UN Sustainable Development Goals (SDGs). It was trained on a diverse, multi-sectoral, and multilingual corpus designed to generalize well across domains (academic, policy, civil society, etc.).

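The model takes a text input (up to 512 tokens; longer inputs are truncated) and outputs one logit per SDG. Applying a sigmoid to each logit gives an independent probability between 0 and 1 indicating the relevance of the text to that goal, so a single text can be linked to several SDGs at once.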
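Because the 17 goals are not mutually exclusive, each logit is passed through its own sigmoid (as `torch.sigmoid` does in the usage example) rather than a softmax across goals, and every goal whose probability exceeds the 0.5 threshold is reported. The following is a minimal pure-Python sketch of that decoding step. The helper names `sigmoid` and `decode_multilabel` are hypothetical, and the logits are made up, not real model output:

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function: maps a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def decode_multilabel(logits, threshold=0.5):
    """Return per-goal probabilities and the 1-based SDG numbers above the threshold."""
    probs = [sigmoid(z) for z in logits]
    selected = [i + 1 for i, p in enumerate(probs) if p > threshold]
    return probs, selected


# Hypothetical logits for one text: 17 values, one per SDG (illustrative only).
example_logits = [-3.0] * 17
example_logits[6] = 2.2   # SDG 7
example_logits[10] = 0.1  # SDG 11, just above the 0.5 boundary
example_logits[12] = 1.5  # SDG 13

probs, selected = decode_multilabel(example_logits)
print(selected)  # [7, 11, 13]
```

Unlike softmax outputs, these probabilities do not have to sum to 1: each goal is scored independently, so several goals can be selected for the same text, or none at all.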

🚀 How to Use

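This model uses a custom classification head, implemented in PyTorch, on top of the LUKE encoder. The head takes a masked mean over the token embeddings, passes it through a dense layer with tanh activation, and ends in a linear output layer with one unit per SDG. To use the model, first define this architecture, then load the downloaded weights (best_model.pt).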
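The head defined in `SDGClassifier` can also be written as a sequence of equations. The notation is ours, not the paper's: $\mathbf{h}_t$ is the encoder's last hidden state for token $t$, $m_t \in \{0, 1\}$ is its attention-mask value, $T = 512$ is the padded length, $W_p \in \mathbb{R}^{d \times d}$ is `self.pooler`, and $W_c \in \mathbb{R}^{17 \times d}$ is `self.cls`, where $d$ is the encoder's hidden size. Dropout is active only during training.

```latex
\bar{\mathbf{h}} = \frac{\sum_{t=1}^{T} m_t \, \mathbf{h}_t}{\sum_{t=1}^{T} m_t}, \qquad
\mathbf{u} = \tanh\!\left( W_p \, \operatorname{Dropout}(\bar{\mathbf{h}}) + \mathbf{b}_p \right), \qquad
\mathbf{z} = W_c \, \mathbf{u} + \mathbf{b}_c, \qquad
p_k = \sigma(z_k) = \frac{1}{1 + e^{-z_k}}, \quad k = 1, \dots, 17
```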

Below is a complete example of how to load the model and perform a prediction.

```python
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download

# --- 1. Define the Model Architecture ---
# This class must match the architecture used during training, including the
# attribute names (bert, pooler, cls, ...), because load_state_dict() matches
# weights by name. You can copy this class from the original training script.
class SDGClassifier(nn.Module):
    def __init__(self, model_path, pooler_dropout, class_number):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_path)
        self.dropout = nn.Dropout(pooler_dropout)
        # Keep this as nn.Sequential: its weights are stored under "pooler.0.*",
        # and replacing it with a bare nn.Linear would change the key names.
        self.pooler = nn.Sequential(nn.Linear(in_features=self.bert.config.hidden_size, out_features=self.bert.config.hidden_size))
        self.tanh = nn.Tanh()
        self.cls = nn.Linear(in_features=self.bert.config.hidden_size, out_features=class_number)

    def forward(self, input_ids, attention_mask, token_type_ids, position, labels):
        # Note: 'position' and 'labels' are required by the forward signature
        # but are not used in the forward pass; dummy values suffice for inference.
        bert_output = self.bert(input_ids, attention_mask, token_type_ids=token_type_ids, output_attentions=True, output_hidden_states=True)
        # Mean-pool the token embeddings, ignoring padding positions (attention_mask == 0).
        average_hidden_state = (bert_output.last_hidden_state * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(1, keepdim=True)
        pooler_output = self.tanh(self.pooler(self.dropout(average_hidden_state)))
        logits = self.cls(pooler_output)
        return logits, average_hidden_state, bert_output.attentions

# --- 2. Setup and Load Model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model configuration
BASE_MODEL = 'studio-ousia/luke-large-lite'
NUM_CLASSES = 17
DROPOUT_RATE = 0.26  # Optimized dropout rate from the paper's training (inactive in eval mode)

# Instantiate the model
model = SDGClassifier(model_path=BASE_MODEL, pooler_dropout=DROPOUT_RATE, class_number=NUM_CLASSES).to(device)
model.eval()  # Set to evaluation mode (disables dropout)

# Download the fine-tuned weights from this Hub repository
model_weights_path = hf_hub_download(
    repo_id="GE-Lab/SDGs-classifier",
    filename="best_model.pt"
)

# Load the weights into the model. weights_only=True restricts unpickling to
# tensors and plain containers, which is all a state_dict needs.
model.load_state_dict(torch.load(model_weights_path, map_location=device, weights_only=True))

print("Model loaded successfully!")

# --- 3. Prepare Input ---
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
text = "Our research focuses on renewable energy solutions to combat climate change and ensure a sustainable future for all."

inputs = tokenizer.encode_plus(
    text,
    None,  # no second text segment (text_pair)
    add_special_tokens=True,
    max_length=512,
    padding='max_length',
    return_token_type_ids=True,
    truncation=True,
    return_tensors='pt'
).to(device)

# The model's forward pass requires these additional dummy inputs
inputs['position'] = torch.arange(0, inputs['input_ids'].shape[1]).unsqueeze(0).to(device)
inputs['labels'] = torch.zeros(1, NUM_CLASSES).to(device)  # Dummy labels for inference

# --- 4. Get Predictions ---
with torch.no_grad():
    logits, _, _ = model(**inputs)
    probabilities = torch.sigmoid(logits).cpu().numpy()[0]  # independent probability per SDG
    predictions = (probabilities > 0.5).astype(int)

# --- 5. Interpret the Results ---
goal_contents = [
    'Goal 1: No Poverty',
    'Goal 2: Zero Hunger',
    'Goal 3: Good Health and Well-being',
    'Goal 4: Quality Education',
    'Goal 5: Gender Equality',
    'Goal 6: Clean Water and Sanitation',
    'Goal 7: Affordable and Clean Energy',
    'Goal 8: Decent Work and Economic Growth',
    'Goal 9: Industry, Innovation and Infrastructure',
    'Goal 10: Reduced Inequalities',
    'Goal 11: Sustainable Cities and Communities',
    'Goal 12: Responsible Consumption and Production',
    'Goal 13: Climate Action',
    'Goal 14: Life Below Water',
    'Goal 15: Life on Land',
    'Goal 16: Peace, Justice and Strong Institutions',
    'Goal 17: Partnerships for the Goals',
]

print(f"\nText: '{text}'")
print("\n--- Predicted SDGs (Threshold > 0.5) ---")
predicted_goals = [goal_contents[i] for i, pred in enumerate(predictions) if pred == 1]
if predicted_goals:
    for goal in predicted_goals:
        print(goal)
else:
    print("No SDGs detected with a probability > 0.5")

print("\n--- All SDG Probabilities ---")
for i, prob in enumerate(probabilities):
    print(f"{goal_contents[i]:<55}: {prob:.2%}")
```
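The `average_hidden_state` line in `forward` computes a masked mean. Every input is padded to 512 tokens (`padding='max_length'`), so the padding positions have to be excluded or they would dilute the text representation. The following sketch shows the same computation without tensors. The helper `masked_mean` is hypothetical, and the toy vectors stand in for real LUKE outputs:

```python
def masked_mean(hidden_states, attention_mask):
    """Average token vectors, counting only positions where the mask is 1.

    hidden_states: list of T token vectors, each a list of d floats
    attention_mask: list of T ints (1 = real token, 0 = padding)
    Assumes at least one real token; with add_special_tokens=True the
    tokenizer always emits start/end tokens, so this holds in practice.
    """
    d = len(hidden_states[0])
    total = [0.0] * d
    for vec, m in zip(hidden_states, attention_mask):
        for j in range(d):
            total[j] += m * vec[j]  # padding rows contribute zero
    n_real = sum(attention_mask)
    return [x / n_real for x in total]


# Toy example: 2 real tokens followed by 2 padding tokens, hidden size d = 2.
hidden = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0], [100.0, 100.0]]
mask = [1, 1, 0, 0]
print(masked_mean(hidden, mask))  # [2.0, 3.0]
```

A plain mean over all four positions would give `[51.0, 51.5]`, driven almost entirely by padding. This is why the mask appears in both the numerator and the denominator.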

📈 Training and Evaluation

Training Data

The model was trained on a novel, heterogeneous corpus of 23,969 multi-labeled documents from 11 diverse sources, including government, academia, industry, and civil society, with some sources translated from Japanese. This approach was designed to address the "interpretive diversity" of SDG-related language.

For full details on reconstructing the training corpus, please refer to Supplementary Information S4 in our paper.

Evaluation

This model was selected for its stronger generalization performance (especially recall) on external datasets such as the OSDG Community Dataset and the SDGi Corpus. On a human-coded sample of scientific articles, it achieved a macro-averaged F1-score of 0.623. For a full breakdown of performance metrics, please see the paper.
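Macro-averaging computes F1 separately for each goal and then takes the unweighted mean, so rarely coded goals count as much as common ones. The following sketch illustrates the metric on toy multi-label data. It is not the paper's evaluation code, the helper `macro_f1` is hypothetical, and scoring a goal with no true and no predicted positives as 0 is an assumption, since tools differ on this convention:

```python
def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1 scores for multi-label 0/1 vectors.

    y_true, y_pred: lists of label vectors, one vector per document.
    Uses F1 = 2*TP / (2*TP + FP + FN); a class with TP = FP = FN = 0
    is scored 0.0 here (assumed convention).
    """
    f1_scores = []
    for k in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t[k] == 1 and p[k] == 1)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t[k] == 0 and p[k] == 1)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t[k] == 1 and p[k] == 0)
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / num_classes


# Toy data: 3 documents, 3 classes (made up for illustration).
y_true = [[1, 0, 1], [0, 1, 0], [1, 1, 0]]
y_pred = [[1, 0, 0], [0, 1, 1], [1, 0, 0]]
print(round(macro_f1(y_true, y_pred, 3), 3))  # 0.556
```

In this toy data, class 0 scores 1.0, class 1 scores 2/3, and class 2 scores 0, giving a mean of 5/9. For real evaluations, scikit-learn's `f1_score(..., average='macro')` accepts the same label-indicator format.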

📜 Citation

If you use this model in your research, please cite our paper:

```bibtex
@article{Miyashita2025,
  author    = {Naoto Miyashita and Takanori Matsui and Chihiro Haga and Naoki Masuhara and Shun Kawakubo},
  title     = {Bridging the Sustainable Development Goals: A Multi-Label Text Classification Approach for Mapping and Visualizing Nexuses in Sustainability Research},
  journal   = {Sustainability Science},
  year      = {2025},
  % TODO: Add Volume, Pages, DOI upon publication
}
```