# ------------------------------------
# Modular Steering Hook Logic
# ------------------------------------
class SteeringHook:
    """Add scaled SAE decoder directions to a transformer block's output.

    While steering is enabled, every forward pass through a hooked block adds
    ``coeff * sae.W_dec[feature_index]`` to the block's hidden-state output
    for each ``(feature_index, coeff)`` pair given to ``enable_steering``.
    """

    def __init__(self, sae, device):
        """
        Initialize the SteeringHook with an SAE and device.

        Args:
            sae: Object exposing a ``W_dec`` tensor whose row ``i`` is used as
                the steering direction for feature ``i`` (presumably shaped
                ``(n_features, d_model)`` -- confirm against the SAE).
            device: Device the caller intends to steer on. Kept for backward
                compatibility; the hook itself moves each steering vector to
                the device of the residual it modifies.
        """
        self.sae = sae
        self.device = device
        self.feature_coeffs = []  # List of (feature_index, coefficient)
        self.hooks = []  # Handles returned by register_forward_hook
        self.steering_enabled = False

    def enable_steering(self, feature_coeffs):
        """
        Enable steering by specifying feature coefficients.

        Args:
            feature_coeffs (iterable): (feature_index, coefficient) pairs.
                A copy is stored, so later mutation of the caller's list does
                not change active steering.
        """
        self.feature_coeffs = list(feature_coeffs)
        self.steering_enabled = True

    def disable_steering(self):
        """
        Disable steering, clear the coefficients and remove all hooks.
        """
        self.steering_enabled = False
        self.feature_coeffs = []
        self.remove_hooks()

    def generate_hook(self):
        """
        Create a forward-hook function that steers the module's output.

        The returned function handles both module output conventions:
        a tuple whose first element is the hidden state (the remaining
        elements are passed through untouched), or a bare hidden-state tensor.

        Returns:
            Callable ``(module, inputs, outputs)`` suitable for
            ``nn.Module.register_forward_hook``.
        """

        def hook_fn(module, inputs, outputs):
            if not self.steering_enabled:
                return outputs

            # Some blocks return (hidden_state, ...) while others return the
            # hidden-state tensor itself; indexing a bare tensor with [0]
            # would silently drop the batch dimension.
            is_tuple = isinstance(outputs, tuple)
            residual = outputs[0] if is_tuple else outputs

            for feature_index, coeff in self.feature_coeffs:
                # Match the residual's device *and* dtype: adding an fp32
                # decoder row to an fp16/bf16 residual would otherwise promote
                # the block's output to fp32 and break downstream layers.
                steering_vector = self.sae.W_dec[feature_index].to(
                    device=residual.device, dtype=residual.dtype
                )
                # Broadcasting adds the vector along the last (feature) axis
                # for any leading shape, e.g. (batch, seq, d_model).
                residual = residual + coeff * steering_vector

            if is_tuple:
                return (residual, *outputs[1:])
            return residual

        return hook_fn

    def register_hooks(self, model, block_idx):
        """
        Register the steering hook on the specified block.

        Args:
            model (nn.Module): The target model; must expose
                ``model.transformer.h`` (a sequence of blocks).
            block_idx (int): Index of the block to attach the hook to.

        Note: each call registers an additional hook; calling it twice on the
        same block applies the steering twice.
        """
        block = model.transformer.h[block_idx]
        handle = block.register_forward_hook(self.generate_hook())
        self.hooks.append(handle)

    def remove_hooks(self):
        """
        Remove all registered hooks.
        """
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()