# Entropy_final_reward_model / modeling_custom.py
# Uploaded by HFXM using huggingface_hub (commit 14b495b, verified).
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers import LlamaPreTrainedModel, LlamaModel
from transformers.utils import ModelOutput
@dataclass
class MultiAspectRewardOutput(ModelOutput):
    """Model output carrying per-aspect predictions plus the aggregated reward.

    Every field defaults to ``None``; the field order is significant because
    ``ModelOutput`` exposes the non-``None`` fields positionally as a tuple.

    Args:
        aspect_scores (torch.FloatTensor): per-aspect predictions, shape
            ``(batch, num_aspects)`` (5 aspects in the documented setup).
        final_reward (torch.FloatTensor): aggregated scalar reward per
            sequence, shape ``(batch,)``.
        logits (torch.FloatTensor): shape ``(batch,)``; same values as
            ``final_reward``.
        loss (torch.FloatTensor): optional scalar loss.
        hidden_states (tuple(torch.FloatTensor)): optional backbone hidden
            states.
        attentions (tuple(torch.FloatTensor)): optional backbone attentions.
    """

    aspect_scores: Optional[torch.FloatTensor] = None
    final_reward: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class LlamaFixedWeightReward(LlamaPreTrainedModel):
    """Llama backbone with a multi-aspect head and fixed-weight aggregation.

    The model:
      1) optionally re-uses an already loaded Llama backbone (``base_llama``),
      2) predicts ``config.num_labels`` aspect scores (passed through a
         sigmoid) and computes an MSE loss when labels are provided,
      3) aggregates the aspect scores with fixed, non-trainable weights into
         one scalar reward per sequence,
      4) returns a ``MultiAspectRewardOutput`` whose ``final_reward`` and
         ``logits`` both hold that reward with shape ``[batch]``.
    """

    def __init__(self, config, base_llama=None, rule_weights=None):
        """Build the backbone, the aspect head and the aggregation weights.

        Args:
            config: LlamaConfig; ``config.num_labels`` is the number of
                aspects (5 in the documented setup).
            base_llama: (optional) an already loaded ``LlamaModel`` to re-use
                instead of instantiating a new one from ``config``.
            rule_weights: (optional) list or ``torch.Tensor`` holding one
                aggregation weight per aspect. If ``None``, every aspect gets
                the uniform weight ``1 / config.num_labels``.

        Raises:
            ValueError: if ``rule_weights`` does not contain exactly
                ``config.num_labels`` values.
        """
        super().__init__(config)

        # 1) Re-use the supplied backbone if given, otherwise build from config.
        self.llama = base_llama if base_llama is not None else LlamaModel(config)

        # 2) Linear head predicting one score per aspect.
        num_aspects = config.num_labels
        self.aspect_head = nn.Linear(config.hidden_size, num_aspects)

        # 3) Fixed aggregation weights, stored as a (1, num_aspects) buffer so
        #    they are saved with the checkpoint but never trained.
        if rule_weights is None:
            weights = torch.full(
                (num_aspects,), 1.0 / num_aspects, dtype=torch.float
            )
        else:
            # as_tensor + clone accepts lists and tensors alike without the
            # copy-construct warning torch.tensor() emits for tensor inputs.
            weights = (
                torch.as_tensor(rule_weights, dtype=torch.float)
                .detach()
                .clone()
                .reshape(-1)
            )
            if weights.numel() != num_aspects:
                raise ValueError(
                    f"rule_weights must contain {num_aspects} values "
                    f"(config.num_labels), got {weights.numel()}"
                )
        self.register_buffer("rule_weights", weights.view(1, -1), persistent=True)

        # NOTE(review): depending on the transformers version, post_init() may
        # re-initialise the weights of a pretrained `base_llama` passed in
        # above -- confirm against the installed version.
        self.post_init()

    @staticmethod
    def _pool_last_token(hidden, attention_mask):
        """Return the hidden state of the last attended token of each sequence.

        Falls back to the final sequence position when no usable 2-D mask is
        available, and for rows whose mask is entirely zero.
        """
        if attention_mask is None or attention_mask.shape != hidden.shape[:2]:
            return hidden[:, -1, :]

        batch_size, seq_len = hidden.shape[:2]
        attended = (attention_mask != 0).to(device=hidden.device, dtype=torch.long)
        positions = torch.arange(seq_len, device=hidden.device)
        # Largest attended position per row; correct for both left and right
        # padding (unpadded / left-padded rows resolve to seq_len - 1).
        last_index = (positions * attended).argmax(dim=-1)
        has_tokens = attended.any(dim=-1)
        last_index = torch.where(
            has_tokens, last_index, torch.full_like(last_index, seq_len - 1)
        )
        rows = torch.arange(batch_size, device=hidden.device)
        return hidden[rows, last_index]

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,  # shape: (batch, num_labels), optional
        **kwargs
    ):
        """Score a batch of sequences.

        Args:
            input_ids: token ids forwarded to the Llama backbone.
            attention_mask: optional padding mask; also used to locate the
                last real token of every sequence for pooling.
            labels: optional per-aspect targets; when given, an MSE loss
                against the sigmoid aspect scores is returned.
            **kwargs: forwarded unchanged to the backbone. ``return_dict`` is
                intercepted: the backbone always runs with
                ``return_dict=True`` and an explicit ``return_dict=False``
                converts the result to a plain tuple.

        Returns:
            ``MultiAspectRewardOutput`` (or its tuple form when
            ``return_dict=False`` was passed).
        """
        # Attribute access on the backbone output below needs a ModelOutput.
        want_tuple = kwargs.pop("return_dict", None) is False

        # 1) Forward pass through Llama.
        outputs = self.llama(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            **kwargs
        )
        last_hidden = outputs.last_hidden_state  # [batch, seq_len, hidden_size]

        # 2) Pool the representation of the last real (non-padded) token.
        pooled = self._pool_last_token(last_hidden, attention_mask)

        # 3) Predict aspect scores; the sigmoid keeps them in (0, 1).
        aspect_scores = torch.sigmoid(self.aspect_head(pooled))

        # 4) Optional MSE loss, computed in float32 on the prediction device.
        loss = None
        if labels is not None:
            targets = labels.to(device=aspect_scores.device, dtype=torch.float)
            loss = nn.MSELoss()(aspect_scores.float(), targets)

        # 5) Aggregate via the fixed weights => final scalar, shape [batch].
        reward = (aspect_scores * self.rule_weights).sum(dim=-1)

        result = MultiAspectRewardOutput(
            loss=loss,
            aspect_scores=aspect_scores,  # shape: [batch, num_labels]
            final_reward=reward,  # shape: [batch]
            logits=reward,  # same as final_reward
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        return result.to_tuple() if want_tuple else result