|
|
|
|
|
|
|
|
|
class SteeringHook:
    """Steer a model by adding scaled SAE decoder rows to a block's output.

    The hook produced by :meth:`generate_hook` adds
    ``coeff * sae.W_dec[feature_index]`` to the hooked module's output for
    every ``(feature_index, coeff)`` pair currently configured. Steering is
    off until :meth:`enable_steering` is called.

    Attributes:
        sae: Object exposing ``W_dec``, indexed by feature index.
        device: Device passed at construction. Kept for callers that read
            it; the hook itself follows the device of the tensor it steers.
        feature_coeffs: Active list of ``(feature_index, coefficient)``.
        hooks: Handles returned by ``register_forward_hook``.
        steering_enabled: Whether registered hooks modify outputs.
    """

    def __init__(self, sae, device):
        """Initialize the SteeringHook with an SAE and device.

        Args:
            sae: Object whose ``W_dec[feature_index]`` yields a steering
                vector (presumably of size d_model -- confirm for your SAE).
            device: Device associated with this hook manager.
        """
        self.sae = sae
        self.device = device
        self.feature_coeffs = []
        self.hooks = []
        self.steering_enabled = False

    def enable_steering(self, feature_coeffs):
        """Enable steering with the given feature coefficients.

        Args:
            feature_coeffs: Iterable of ``(feature_index, coefficient)``.
                It is copied, so mutating the caller's list afterwards does
                not change the active steering configuration.
        """
        self.feature_coeffs = list(feature_coeffs)
        self.steering_enabled = True

    def disable_steering(self):
        """Disable steering, drop the coefficients and remove all hooks."""
        self.steering_enabled = False
        self.feature_coeffs = []
        self.remove_hooks()

    def _steer(self, residual):
        """Return ``residual`` plus every configured steering direction.

        Each decoder row is moved to the residual's own device and dtype so
        the sum neither fails on a device mismatch nor silently promotes a
        half-precision residual. The row broadcasts over the trailing
        dimension, so the residual's rank is preserved.
        """
        for feature_index, coeff in self.feature_coeffs:
            direction = self.sae.W_dec[feature_index].to(
                device=residual.device, dtype=residual.dtype
            )
            # Out-of-place add: never mutate the module's own output tensor.
            residual = residual + coeff * direction
        return residual

    def generate_hook(self):
        """Create a forward hook that steers the hooked module's output.

        Returns:
            A callable with the ``(module, inputs, outputs)`` forward-hook
            signature. While steering is disabled it returns ``outputs``
            untouched. Otherwise it steers ``outputs[0]`` when ``outputs``
            is a tuple/list (returning a tuple with the remaining items
            unchanged), or steers ``outputs`` itself when the module
            returns a bare tensor.
        """
        def hook_fn(module, inputs, outputs):
            if not self.steering_enabled:
                return outputs

            if isinstance(outputs, (tuple, list)):
                return (self._steer(outputs[0]), *outputs[1:])

            # Module returned the hidden state directly rather than a tuple.
            return self._steer(outputs)

        return hook_fn

    def register_hooks(self, model, block_idx):
        """Register the steering hook on the specified block.

        Args:
            model: Model exposing its blocks as ``model.transformer.h``.
            block_idx (int): Index of the block to attach the hook to.
        """
        block = model.transformer.h[block_idx]
        handle = block.register_forward_hook(self.generate_hook())
        self.hooks.append(handle)

    def remove_hooks(self):
        """Remove all registered hooks and forget their handles."""
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()
|
|