"""Hugging Face compatible ESG Hierarchical Multi-Task Learning model."""

from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

|
class ESGHierarchicalConfig(PretrainedConfig):
    """Configuration class for the ESG Hierarchical Multi-Task Learning model."""

    model_type = "esg_hierarchical"

    def __init__(
        self,
        backbone_model_name: str = "Alibaba-NLP/gte-multilingual-base",
        num_esg_classes: int = 4,
        num_e_sub_classes: int = 3,
        num_s_sub_classes: int = 7,
        num_g_sub_classes: int = 5,
        shared_hidden_dim_factor: float = 0.5,
        head_hidden_dim_factor: float = 0.25,
        dropout_rate: float = 0.1,
        freeze_backbone: bool = True,
        esg_categories: list = None,
        sub_factors_e: list = None,
        sub_factors_s: list = None,
        sub_factors_g: list = None,
        final_others_label: str = "Others",
        max_length: int = 4096,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.backbone_model_name = backbone_model_name
        self.num_esg_classes = num_esg_classes
        self.num_e_sub_classes = num_e_sub_classes
        self.num_s_sub_classes = num_s_sub_classes
        self.num_g_sub_classes = num_g_sub_classes
        self.shared_hidden_dim_factor = shared_hidden_dim_factor
        self.head_hidden_dim_factor = head_hidden_dim_factor
        self.dropout_rate = dropout_rate
        self.freeze_backbone = freeze_backbone
        self.max_length = max_length

        self.esg_categories = esg_categories or ["E", "S", "G", "Others_ESG"]
        self.sub_factors_e = sub_factors_e or ["Emission", "Resource Use", "Product Innovation"]
        self.sub_factors_s = sub_factors_s or ["Community", "Diversity", "Employment", "HS", "HR", "PR", "Training"]
        self.sub_factors_g = sub_factors_g or ["BFunction", "BStructure", "Compensation", "Shareholder", "Vision"]
        self.final_others_label = final_others_label

        # Label <-> id mappings for each task head.
        self.esg_label2id = {label: i for i, label in enumerate(self.esg_categories)}
        self.esg_id2label = {i: label for i, label in enumerate(self.esg_categories)}

        self.e_sub_label2id = {label: i for i, label in enumerate(self.sub_factors_e)}
        self.e_sub_id2label = {i: label for i, label in enumerate(self.sub_factors_e)}

        self.s_sub_label2id = {label: i for i, label in enumerate(self.sub_factors_s)}
        self.s_sub_id2label = {i: label for i, label in enumerate(self.sub_factors_s)}

        self.g_sub_label2id = {label: i for i, label in enumerate(self.sub_factors_g)}
        self.g_sub_id2label = {i: label for i, label in enumerate(self.sub_factors_g)}

        # Flat label space over all sub-factors plus the final "Others" label.
        all_sub_factors = self.sub_factors_e + self.sub_factors_s + self.sub_factors_g
        all_final_labels = all_sub_factors + [self.final_others_label]
        self.final_label2id = {label: i for i, label in enumerate(all_final_labels)}
        self.final_id2label = {i: label for i, label in enumerate(all_final_labels)}

|
class ESGHierarchicalForSequenceClassification(PreTrainedModel):
    """ESG Hierarchical Multi-Task Learning Model for the Hugging Face Hub.

    This model performs hierarchical ESG classification:
    1. First predicts the main ESG category (E, S, G, Others_ESG)
    2. Then predicts the corresponding sub-factor based on the main category
    """
|
    config_class = ESGHierarchicalConfig

    def __init__(self, config: ESGHierarchicalConfig, backbone=None):
        super().__init__(config)
        self.config = config

        # Use a pre-loaded backbone if one is provided; otherwise build the GTE
        # backbone from its configuration (randomly initialized until a checkpoint
        # is loaded, e.g. via from_pretrained).
        if backbone is not None:
            self.gte_model = backbone
        else:
            backbone_config = AutoConfig.from_pretrained(config.backbone_model_name, trust_remote_code=True)
            self.gte_model = AutoModel.from_config(backbone_config, trust_remote_code=True)

        # Optionally freeze the backbone so only the projector and heads are trained.
        if config.freeze_backbone:
            for param in self.gte_model.parameters():
                param.requires_grad = False

        embedding_dim = self.gte_model.config.hidden_size
        shared_hidden_dim = int(embedding_dim * config.shared_hidden_dim_factor)
        head_hidden_dim = int(embedding_dim * config.head_hidden_dim_factor)

|
        # Shared projection applied on top of the backbone [CLS] embedding.
        self.shared_projector = nn.Sequential(
            nn.Linear(embedding_dim, shared_hidden_dim),
            nn.LayerNorm(shared_hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate)
        )

        def make_head(num_classes: int) -> nn.Sequential:
            # All task heads share the same architecture and differ only in output size.
            return nn.Sequential(
                nn.Linear(shared_hidden_dim, head_hidden_dim),
                nn.LayerNorm(head_hidden_dim),
                nn.ReLU(),
                nn.Dropout(config.dropout_rate),
                nn.Linear(head_hidden_dim, num_classes)
            )

        # One head for the main ESG category and one per sub-factor family.
        self.esg_head = make_head(config.num_esg_classes)
        self.e_sub_head = make_head(config.num_e_sub_classes)
        self.s_sub_head = make_head(config.num_s_sub_classes)
        self.g_sub_head = make_head(config.num_g_sub_classes)

        self._init_head_weights()

|
    def _init_head_weights(self):
        """Initialize the projector and head weights with Xavier initialization.

        Named differently from PreTrainedModel._init_weights(module) on purpose,
        so the Hugging Face initialization hook is not overridden with an
        incompatible signature.
        """
        for module in [self.shared_projector, self.esg_head,
                       self.e_sub_head, self.s_sub_head, self.g_sub_head]:
            for m in module:
                if isinstance(m, nn.Linear):
                    nn.init.xavier_normal_(m.weight)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

|
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[Dict[str, torch.Tensor]] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> SequenceClassifierOutput:
        """Forward pass of the model.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask for padding tokens
            labels: Dictionary containing target labels for each task
            return_dict: Whether to return dictionary format

        Returns:
            SequenceClassifierOutput containing logits and loss
        """
|
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode the input with the (possibly frozen) GTE backbone.
        backbone_outputs = self.gte_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # [CLS] pooling: use the first token's hidden state as the sequence embedding.
        cls_embedding = backbone_outputs.last_hidden_state[:, 0, :]

        # Cast to float32 and clamp extreme values for numerical stability.
        cls_embedding = cls_embedding.float()
        cls_embedding = torch.clamp(cls_embedding, min=-100, max=100)

        shared_features = self.shared_projector(cls_embedding)

        # Task-specific logits from the shared representation.
        esg_logits = self.esg_head(shared_features)
        e_sub_logits = self.e_sub_head(shared_features)
        s_sub_logits = self.s_sub_head(shared_features)
        g_sub_logits = self.g_sub_head(shared_features)

        logits = {
            'esg': esg_logits,
            'e_sub': e_sub_logits,
            's_sub': s_sub_logits,
            'g_sub': g_sub_logits
        }

|
        loss = None
        if labels is not None:
            # Per-task cross-entropy; targets set to -100 are ignored.
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

            losses = []
            loss_weights = {"esg": 1.0, "e_sub": 0.7, "s_sub": 0.7, "g_sub": 0.7}

            for task_name, task_logits in logits.items():
                if task_name in labels:
                    task_loss = loss_fct(task_logits, labels[task_name])
                    losses.append(loss_weights[task_name] * task_loss)

            # Total loss is the weighted sum of the available task losses.
            if losses:
                loss = sum(losses)

        if not return_dict:
            output = (logits,)
            if loss is not None:
                output = (loss,) + output
            return output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )

|
    def predict(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[str, str]:
        """Hierarchical prediction method compatible with the original implementation.

        Expects a single example (batch size 1).

        Returns:
            Tuple of (main_esg_factor, sub_factor)
        """
        self.eval()
        with torch.no_grad():
            outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # Step 1: predict the main ESG category.
            prob_esg = torch.softmax(logits['esg'], dim=1)
            pred_esg_idx = torch.argmax(prob_esg, dim=1).item()
            pred_esg_label = self.config.esg_id2label[pred_esg_idx]

            # Step 2: predict the sub-factor with the head matching the main category;
            # texts classified as Others_ESG keep the final "Others" label.
            sub_factor_label = self.config.final_others_label

            if pred_esg_label == "E":
                prob_e_sub = torch.softmax(logits['e_sub'][0], dim=0)
                pred_e_sub_idx = torch.argmax(prob_e_sub).item()
                sub_factor_label = self.config.e_sub_id2label[pred_e_sub_idx]
            elif pred_esg_label == "S":
                prob_s_sub = torch.softmax(logits['s_sub'][0], dim=0)
                pred_s_sub_idx = torch.argmax(prob_s_sub).item()
                sub_factor_label = self.config.s_sub_id2label[pred_s_sub_idx]
            elif pred_esg_label == "G":
                prob_g_sub = torch.softmax(logits['g_sub'][0], dim=0)
                pred_g_sub_idx = torch.argmax(prob_g_sub).item()
                sub_factor_label = self.config.g_sub_id2label[pred_g_sub_idx]

        return pred_esg_label, sub_factor_label

|

# Register the custom classes so they can be loaded from the Hub through the Auto API
# (with trust_remote_code=True).
ESGHierarchicalConfig.register_for_auto_class()
ESGHierarchicalForSequenceClassification.register_for_auto_class("AutoModelForSequenceClassification")
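

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model definition). It builds a freshly
# initialized model from the default config and runs one hierarchical prediction.
# Assumptions: the example sentence is illustrative only, and the backbone
# tokenizer (config.backbone_model_name) is the tokenizer used with this model;
# a trained checkpoint would instead be loaded with
# ESGHierarchicalForSequenceClassification.from_pretrained(<hub repo id>, trust_remote_code=True).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = ESGHierarchicalConfig()
    model = ESGHierarchicalForSequenceClassification(config)
    tokenizer = AutoTokenizer.from_pretrained(config.backbone_model_name)

    encoded = tokenizer(
        "The company cut its scope 1 emissions by 12% year over year.",
        truncation=True,
        max_length=config.max_length,
        return_tensors="pt",
    )
    esg_label, sub_label = model.predict(encoded["input_ids"], encoded["attention_mask"])
    print(f"ESG category: {esg_label}, sub-factor: {sub_label}")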