File size: 2,136 Bytes
889f722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# ------------------------------------
# Modular Steering Hook Logic
# ------------------------------------

class SteeringHook:
    """Add scaled SAE decoder directions to a module's output via forward hooks.

    Each steered feature contributes ``coeff * sae.W_dec[feature_index]`` to
    the hooked module's primary output. Steering is off until
    ``enable_steering`` is called, and hooks only exist on a model after
    ``register_hooks`` has been called.
    """

    def __init__(self, sae, device):
        """
        Initialize the SteeringHook with an SAE and device.
        Args:
            sae: Object exposing ``W_dec``, indexable by feature index.
                Presumably a (n_features, d_model) decoder matrix; verify
                against the SAE implementation in use.
            device: Stored on the instance for callers. The hook itself
                places steering vectors on the device of the activation it
                is modifying.
        """
        self.sae = sae
        self.device = device
        self.feature_coeffs = []  # List of (feature_index, coefficient)
        self.hooks = []  # Store hook handles
        self.steering_enabled = False

    def enable_steering(self, feature_coeffs):
        """
        Enable steering by specifying feature coefficients.
        Args:
            feature_coeffs (iterable): Iterable of (feature_index, coefficient).
        """
        # Materialize a private copy: the hook re-reads this on every forward
        # pass, so a one-shot iterator would be exhausted after the first call
        # and later mutation of the caller's list would silently change steering.
        self.feature_coeffs = list(feature_coeffs)
        self.steering_enabled = True

    def disable_steering(self):
        """
        Disable steering and clear hooks.

        This also removes every registered hook, so ``register_hooks`` must
        be called again before steering can take effect.
        """
        self.steering_enabled = False
        self.feature_coeffs = []
        self.remove_hooks()

    def generate_hook(self):
        """
        Create a steering hook function that modifies the residual output.
        Returns:
            A callable with the ``(module, inputs, outputs)`` forward-hook
            signature. While steering is disabled it returns ``outputs``
            unchanged. Otherwise it returns a tuple whose first element is
            the steered activation when ``outputs`` is a tuple/list, or the
            steered tensor itself when ``outputs`` is a bare tensor.
        """
        def hook_fn(module, inputs, outputs):
            if not self.steering_enabled:
                return outputs

            # Some modules return (hidden_states, ...), others a bare tensor.
            # Indexing a bare tensor with [0] would slice the batch dimension.
            is_sequence = isinstance(outputs, (tuple, list))
            residual = outputs[0] if is_sequence else outputs

            for feature_index, coeff in self.feature_coeffs:
                # Match the activation's device and dtype so the addition
                # neither fails on a device mismatch nor promotes the dtype
                # (e.g. a float32 decoder row with a half-precision model).
                steering_vector = self.sae.W_dec[feature_index].to(
                    device=residual.device, dtype=residual.dtype
                )
                # The vector broadcasts over all leading dimensions.
                residual = residual + coeff * steering_vector

            if is_sequence:
                return (residual, *outputs[1:])
            return residual

        return hook_fn

    def register_hooks(self, model, block_idx):
        """
        Register the steering hook to the specified block.
        Args:
            model (nn.Module): The target model. Must expose its blocks as
                ``model.transformer.h``.
            block_idx (int): The block index to attach the hook.

        Each call adds another hook; registering the same block twice applies
        the steering twice.
        """
        handle = model.transformer.h[block_idx].register_forward_hook(self.generate_hook())
        self.hooks.append(handle)

    def remove_hooks(self):
        """
        Remove all registered hooks.
        """
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()