from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers import LlamaPreTrainedModel, LlamaModel
from transformers.utils import ModelOutput

@dataclass
class MultiAspectRewardOutput(ModelOutput):
    """
    Custom output class to return multi-aspect predictions plus final reward.
    
    Args:
        aspect_scores (torch.FloatTensor): shape (batch, 5)
        final_reward (torch.FloatTensor): shape (batch,)
        logits (torch.FloatTensor): shape (batch,) same as final_reward
        loss (torch.FloatTensor): optional scalar
        hidden_states (tuple(torch.FloatTensor)): optional hidden states
        attentions (tuple(torch.FloatTensor)): optional attentions
    """
    aspect_scores: torch.FloatTensor = None
    final_reward: torch.FloatTensor = None
    logits: torch.FloatTensor = None
    loss: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

class LlamaFixedWeightReward(LlamaPreTrainedModel):
    """
    A single final class that:
      1) Optionally takes a pretrained Llama backbone (base_llama),
      2) Predicts 5 aspect scores, computing MSE if 5-dim labels are provided,
      3) Aggregates the 5 aspect scores via fixed weights -> 1 scalar reward,
      4) Returns MultiAspectRewardOutput with shape [batch] in 'final_reward' and 'logits'.
    """
    def __init__(self, config, base_llama=None, rule_weights=None):
        """
        Args:
          config: LlamaConfig with num_labels=5 for multi-aspect predictions.
          base_llama: (optional) an already loaded LlamaModel
          rule_weights: (optional) A list or torch.Tensor of shape (5,) for aggregation.
                        If None, defaults to [0.2, 0.2, 0.2, 0.2, 0.2].
        """
        super().__init__(config)
        
        # 1) If base_llama is given, re-use that. Otherwise instantiate from config
        if base_llama is not None:
            self.llama = base_llama
        else:
            self.llama = LlamaModel(config)

        # 2) Linear head to predict 5 aspect scores
        # Expect config.num_labels=5
        self.aspect_head = nn.Linear(config.hidden_size, config.num_labels)

        # 3) Register the fixed aggregator weights (uniform by default)
        if rule_weights is not None:
            w = torch.as_tensor(rule_weights, dtype=torch.float)
            if w.numel() != config.num_labels:
                raise ValueError(
                    f"rule_weights must have {config.num_labels} entries, got {w.numel()}"
                )
        else:
            w = torch.full((config.num_labels,), 1.0 / config.num_labels)
        self.register_buffer("rule_weights", w.view(1, -1), persistent=True)

        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,  # shape: (batch, 5), optional
        **kwargs
    ):
        # 1) Forward pass through Llama
        outputs = self.llama(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        # last hidden state: [batch, seq_len, hidden_size]
        last_hidden = outputs.last_hidden_state

        # 2) Pool by taking each sequence's last non-padding token.
        # With right padding, last_hidden[:, -1, :] would read a pad token,
        # so use the attention mask to find the true last position.
        if attention_mask is not None:
            last_idx = attention_mask.long().sum(dim=1) - 1  # [batch]
            batch_idx = torch.arange(last_hidden.size(0), device=last_hidden.device)
            pooled = last_hidden[batch_idx, last_idx]  # [batch, hidden_size]
        else:
            pooled = last_hidden[:, -1, :]  # [batch, hidden_size]

        # 3) Predict 5 aspect scores
        aspect_scores = self.aspect_head(pooled)  # [batch, 5]
        # Labels are assumed to lie in [0, 1], so bound the predictions
        # to (0, 1) with a sigmoid (a squash, not a hard clamp)
        aspect_scores = torch.sigmoid(aspect_scores)

        # 4) optional MSE loss
        loss = None
        if labels is not None:
            mse_fn = nn.MSELoss()
            loss = mse_fn(aspect_scores, labels.float())

        # 5) aggregate via fixed weights => final scalar: shape [batch]
        reward = (aspect_scores * self.rule_weights).sum(dim=-1)

        # Return a custom output
        return MultiAspectRewardOutput(
            loss=loss,
            aspect_scores=aspect_scores,       # shape: [batch, 5]
            final_reward=reward,              # shape: [batch]
            logits=reward,                    # same as final_reward
            hidden_states=outputs.hidden_states, 
            attentions=outputs.attentions
        )
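
# ---------------------------------------------------------------------------
# Minimal smoke test: a hedged usage sketch, not part of the model itself.
# It builds a tiny random-weight LlamaConfig (the dimensions below are
# illustrative assumptions chosen so the example runs quickly; no pretrained
# checkpoint is loaded) and checks output shapes and the fixed-weight
# aggregation end to end.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import LlamaConfig

    config = LlamaConfig(
        vocab_size=128,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_labels=5,
    )
    model = LlamaFixedWeightReward(config, rule_weights=[0.4, 0.3, 0.1, 0.1, 0.1])
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 10))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.rand(2, 5)  # synthetic aspect labels in [0, 1]

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    print("aspect_scores:", tuple(out.aspect_scores.shape))  # (2, 5)
    print("final_reward:", tuple(out.final_reward.shape))    # (2,)
    print("loss:", out.loss.item())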