# Author: seonglae
# Commit 889f722: "feat: hf space corr-steer"
# ------------------------------------
# Modular Steering Hook Logic
# ------------------------------------
class SteeringHook:
    """Steer a transformer block by adding scaled SAE decoder directions.

    While steering is enabled, the forward hook built by ``generate_hook``
    adds ``coeff * sae.W_dec[feature_index]`` for every configured
    ``(feature_index, coefficient)`` pair to the hidden states the hooked
    block returns. While steering is disabled the hook passes outputs through
    unchanged.
    """

    def __init__(self, sae, device):
        """
        Initialize the SteeringHook with an SAE and device.

        Args:
            sae: Object exposing a ``W_dec`` tensor whose rows are indexed by
                feature id (presumably shape ``(n_features, d_model)`` -- TODO
                confirm against the SAE class).
            device: Device the caller intends to steer on. Stored for backward
                compatibility; the hook itself moves each steering vector to
                the device of the hidden states it modifies.
        """
        self.sae = sae
        self.device = device
        self.feature_coeffs = []  # List of (feature_index, coefficient)
        self.hooks = []  # Handles returned by register_forward_hook
        self.steering_enabled = False

    def enable_steering(self, feature_coeffs):
        """
        Enable steering by specifying feature coefficients.

        Args:
            feature_coeffs (list): List of (feature_index, coefficient).
                The list is copied, so later mutation by the caller does not
                silently change the active steering configuration.
        """
        self.feature_coeffs = list(feature_coeffs)
        self.steering_enabled = True

    def disable_steering(self):
        """
        Disable steering, clear the coefficients, and remove all registered hooks.
        """
        self.steering_enabled = False
        self.feature_coeffs = []
        self.remove_hooks()

    def _steering_delta(self, reference):
        """Return the summed ``coeff * W_dec[i]`` vector matching ``reference``'s device/dtype."""
        delta = None
        for feature_index, coeff in self.feature_coeffs:
            # Cast to the activations' dtype as well as device: adding an fp32
            # SAE row to fp16/bf16 hidden states would promote the block output
            # to fp32 and break the following half-precision layers.
            direction = self.sae.W_dec[feature_index].to(
                device=reference.device, dtype=reference.dtype
            )
            term = coeff * direction
            delta = term if delta is None else delta + term
        return delta

    def generate_hook(self):
        """
        Create a forward-hook function that adds the steering vector to the
        hooked module's hidden-state output.

        The module may return either a bare tensor or a tuple/list whose first
        element is the hidden states; the hook preserves that form (a list is
        returned as a tuple).
        """
        def hook_fn(module, inputs, outputs):
            if not self.steering_enabled or not self.feature_coeffs:
                return outputs
            is_sequence = isinstance(outputs, (tuple, list))
            residual = outputs[0] if is_sequence else outputs
            # The per-feature vector has a single trailing dimension, so it
            # broadcasts over any number of leading (batch/sequence) dims.
            steered = residual + self._steering_delta(residual)
            if is_sequence:
                return (steered, *outputs[1:])
            return steered
        return hook_fn

    def register_hooks(self, model, block_idx):
        """
        Register the steering hook on the specified block.

        The hook stays attached until ``remove_hooks`` or ``disable_steering``
        is called. Registering the same block twice applies steering twice.

        Args:
            model (nn.Module): The target model; must expose
                ``model.transformer.h`` (GPT-2-style block list).
            block_idx (int): Index of the block to attach the hook to.
        """
        handle = model.transformer.h[block_idx].register_forward_hook(self.generate_hook())
        self.hooks.append(handle)

    def remove_hooks(self):
        """
        Remove all registered hooks.
        """
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()