import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Identity(nn.Module):
    """Identity Module: returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class Squeeze(nn.Module):
    """Squeeze Module: drops dimension ``dim`` of the input when it has size 1."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return x.squeeze(self.dim)


class SuperLinear(nn.Module):
    """SuperLinear Layer: Implements Neuron-Level Models (NLMs) for the CTM.

    Holds one weight slice per neuron: ``w1`` has shape
    ``(in_dims, out_dims, N)`` and ``b1`` has shape ``(1, N, out_dims)``.

    Args:
        in_dims: Size of the last axis of the input (contracted by the einsum).
        out_dims: Size of the last axis of the output.
        N: Number of neurons, i.e. the size of the middle axis of the input.
    """

    def __init__(self, in_dims, out_dims, N):
        super().__init__()
        self.in_dims = in_dims
        bound = 1 / math.sqrt(in_dims + out_dims)
        self.w1 = nn.Parameter(
            torch.empty((in_dims, out_dims, N)).uniform_(-bound, bound)
        )
        self.b1 = nn.Parameter(torch.zeros((1, N, out_dims)))

    def forward(self, x):
        """Apply each neuron's own linear map.

        Args:
            x: Tensor of shape ``(B, N, in_dims)``.

        Returns:
            Tensor of shape ``(B, N, out_dims)``; the last axis is squeezed
            away when ``out_dims == 1``.
        """
        # D indexes neurons, M is contracted, H is the per-neuron output axis.
        out = torch.einsum('BDM,MHD->BDH', x, self.w1) + self.b1
        return out.squeeze(-1)


def compute_normalized_entropy(logits, reduction='mean'):
    """Computes the normalized entropy for certainty-loss.

    The entropy of ``softmax(logits)`` over the last axis is divided by
    ``log(num_classes)``, where ``num_classes`` is the size of that axis.

    Args:
        logits: Tensor whose last axis holds the class logits.
        reduction: When ``'mean'`` and ``logits`` has more than two
            dimensions, the result is averaged over everything except the
            first axis. Any other value leaves the result unreduced.

    Returns:
        Tensor of normalized entropies.
    """
    preds = F.softmax(logits, dim=-1)
    log_preds = torch.log_softmax(logits, dim=-1)
    entropy = -torch.sum(preds * log_preds, dim=-1)
    num_classes = preds.shape[-1]
    max_entropy = torch.log(torch.tensor(num_classes, dtype=torch.float32))
    normalized_entropy = entropy / max_entropy
    if logits.dim() > 2 and reduction == 'mean':
        normalized_entropy = normalized_entropy.flatten(1).mean(-1)
    return normalized_entropy


class ContinuousThoughtMachine(nn.Module):
    """Recurrent model that attends to input features over several iterations.

    Args:
        iterations: Number of steps of the recurrent loop in ``forward``.
        d_model: Number of neurons (width of the activated state).
        d_input: Channel width of the backbone and attention embedding size.
        memory_length: Length of the per-neuron trace fed to the NLMs.
        heads: Number of attention heads.
        n_synch_out: Number of neurons (taken from the start of the state)
            used for the output synchronisation.
        n_synch_action: Number of neurons (taken from the end of the state)
            used for the action synchronisation.
        out_dims: Size of each prediction vector.
        memory_hidden_dims: Hidden width of the neuron-level models.
    """

    def __init__(self,
                 iterations,
                 d_model,
                 d_input,
                 memory_length,
                 heads,
                 n_synch_out,
                 n_synch_action,
                 out_dims,
                 memory_hidden_dims,
                 ):
        super().__init__()

        # --- Core Parameters ---
        self.iterations = iterations
        self.d_model = d_model
        self.d_input = d_input
        self.memory_length = memory_length
        self.n_synch_out = n_synch_out
        self.n_synch_action = n_synch_action
        self.out_dims = out_dims
        self.memory_hidden_dims = memory_hidden_dims

        # NOTE: sub-modules are created in a fixed order so the random
        # initialisation stream stays reproducible under a given seed.

        # --- Input Processing ---
        self.backbone = nn.Sequential(
            nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(d_input),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(d_input),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.attention = nn.MultiheadAttention(self.d_input, heads, batch_first=True)
        self.kv_proj = nn.Sequential(nn.LazyLinear(self.d_input), nn.LayerNorm(self.d_input))
        self.q_proj = nn.LazyLinear(self.d_input)

        # --- Core CTM Modules ---
        self.synapses = nn.Sequential(
            nn.LazyLinear(d_model * 2),
            nn.GLU(),
            nn.LayerNorm(d_model),
        )
        self.trace_processor = nn.Sequential(
            SuperLinear(in_dims=memory_length, out_dims=2 * memory_hidden_dims, N=d_model),
            nn.GLU(),
            SuperLinear(in_dims=memory_hidden_dims, out_dims=2, N=d_model),
            nn.GLU(),
            Squeeze(-1),
        )

        # --- Start States ---
        state_bound = math.sqrt(1 / (d_model))
        self.start_activated_state = nn.Parameter(
            torch.zeros((d_model)).uniform_(-state_bound, state_bound)
        )
        trace_bound = math.sqrt(1 / (d_model + memory_length))
        self.start_trace = nn.Parameter(
            torch.zeros((d_model, memory_length)).uniform_(-trace_bound, trace_bound)
        )

        # --- Synchronisation ---
        # One entry per unordered neuron pair (including self-pairs).
        self.synch_representation_size_action = (self.n_synch_action * (self.n_synch_action + 1)) // 2
        self.synch_representation_size_out = (self.n_synch_out * (self.n_synch_out + 1)) // 2
        for synch_type, size in [('action', self.synch_representation_size_action),
                                 ('out', self.synch_representation_size_out)]:
            print(f"Synch representation size {synch_type}: {size}")
        self.set_synchronisation_parameters('out', self.n_synch_out)
        self.set_synchronisation_parameters('action', self.n_synch_action)

        # --- Output Processing ---
        self.output_projector = nn.Sequential(nn.LazyLinear(self.out_dims))

    def set_synchronisation_parameters(self, synch_type: str, n_synch: int):
        """Register the index buffers and decay parameters for ``synch_type``.

        Creates ``{synch_type}_neuron_indices_left`` and
        ``{synch_type}_neuron_indices_right`` buffers and a zero-initialised
        ``decay_params_{synch_type}`` parameter with one entry per neuron pair.
        """
        left, right = self.initialize_left_right_neurons(synch_type, self.d_model, n_synch)
        if synch_type == 'action':
            synch_representation_size = self.synch_representation_size_action
        else:
            synch_representation_size = self.synch_representation_size_out
        self.register_buffer(f'{synch_type}_neuron_indices_left', left)
        self.register_buffer(f'{synch_type}_neuron_indices_right', right)
        self.register_parameter(
            f'decay_params_{synch_type}',
            nn.Parameter(torch.zeros(synch_representation_size), requires_grad=True),
        )

    def initialize_left_right_neurons(self, synch_type, d_model, n_synch):
        """Return the (left, right) neuron index tensors for ``synch_type``.

        ``'out'`` uses the first ``n_synch`` neurons and ``'action'`` the last
        ``n_synch`` of ``d_model``; left and right are the same tensor.

        Raises:
            ValueError: If ``synch_type`` is neither ``'out'`` nor ``'action'``.
        """
        if synch_type == 'out':
            neuron_indices_left = neuron_indices_right = torch.arange(0, n_synch)
        elif synch_type == 'action':
            neuron_indices_left = neuron_indices_right = torch.arange(d_model - n_synch, d_model)
        else:
            raise ValueError(
                f"Unknown synch_type {synch_type!r}; expected 'out' or 'action'"
            )
        return neuron_indices_left, neuron_indices_right

    def compute_synchronisation(self, activated_state, decay_alpha, decay_beta, synch_type):
        """Update and return the decayed pairwise synchronisation.

        Args:
            activated_state: Tensor of shape ``(B, d_model)``.
            decay_alpha: Running numerator from the previous step, or ``None``
                to start a new accumulation.
            decay_beta: Running normaliser from the previous step, or ``None``
                to start a new accumulation.
            synch_type: ``'action'`` (last ``n_synch_action`` neurons) or
                ``'out'`` (first ``n_synch_out`` neurons).

        Returns:
            Tuple ``(synchronisation, decay_alpha, decay_beta)``, each of shape
            ``(B, n_synch * (n_synch + 1) // 2)``.

        Raises:
            ValueError: If ``synch_type`` is neither ``'out'`` nor ``'action'``.
        """
        if synch_type == 'action':
            n_synch = self.n_synch_action
            decay_params = self.decay_params_action
            selected_left = selected_right = activated_state[:, -n_synch:]
        elif synch_type == 'out':
            n_synch = self.n_synch_out
            decay_params = self.decay_params_out
            selected_left = selected_right = activated_state[:, :n_synch]
        else:
            raise ValueError(
                f"Unknown synch_type {synch_type!r}; expected 'out' or 'action'"
            )

        # Outer product of the selected neurons; keep only the upper triangle
        # because the product is symmetric.
        outer = selected_left.unsqueeze(2) * selected_right.unsqueeze(1)
        # Build the indices on the activations' device to avoid a host-to-device
        # copy on every call.
        i, j = torch.triu_indices(n_synch, n_synch, device=activated_state.device)
        pairwise_product = outer[:, i, j]

        # Per-pair decay factor in (0, 1]; shape (1, P) broadcasts over the batch.
        r = torch.exp(-torch.clamp(decay_params, 0, 15)).unsqueeze(0)
        if decay_alpha is None or decay_beta is None:
            decay_alpha = pairwise_product
            decay_beta = torch.ones_like(pairwise_product)
        else:
            decay_alpha = r * decay_alpha + pairwise_product
            decay_beta = r * decay_beta + 1

        synchronisation = decay_alpha / (torch.sqrt(decay_beta))
        return synchronisation, decay_alpha, decay_beta

    def compute_features(self, x):
        """Run the backbone and project its output to attention keys/values.

        The backbone output is flattened from dimension 2 onwards and
        transposed so the flattened positions form the sequence axis.
        """
        input_features = self.backbone(x)
        kv = self.kv_proj(input_features.flatten(2).transpose(1, 2))
        return kv

    def compute_certainty(self, current_prediction):
        """Return ``(normalized_entropy, 1 - normalized_entropy)`` stacked on the last axis."""
        ne = compute_normalized_entropy(current_prediction)
        current_certainty = torch.stack((ne, 1 - ne), -1)
        return current_certainty

    def forward(self, x, track=False):
        """Run the recurrent loop over ``x``.

        Args:
            x: Input batch for the convolutional backbone; the first axis is
                the batch axis (presumably ``(B, C, H, W)`` images).
            track: When true, also return per-iteration NumPy histories.

        Returns:
            If ``track`` is false: ``(predictions, certainties,
            synchronisation_out)`` where ``predictions`` has shape
            ``(B, out_dims, iterations)``, ``certainties`` has shape
            ``(B, 2, iterations)`` and ``synchronisation_out`` is the output
            synchronisation of the last step (or of the start state when
            ``iterations`` is 0).

            If ``track`` is true: ``(predictions, certainties,
            (synch_out_history, synch_action_history), pre_activations,
            post_activations, attention_weights)``, where the histories are
            NumPy arrays stacked along a leading iteration axis.
        """
        B = x.size(0)
        device = x.device

        # --- Tracking Initialization ---
        pre_activations_tracking = []
        post_activations_tracking = []
        synch_out_tracking = []
        synch_action_tracking = []
        attention_tracking = []

        # --- Featurise Input Data ---
        kv = self.compute_features(x)

        # --- Initialise Recurrent State ---
        state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1)  # Shape: (B, H, T)
        activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1)  # Shape: (B, H)

        # --- Storage for outputs per iteration ---
        predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype)
        certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype)

        decay_alpha_action, decay_beta_action = None, None
        # Seed the output accumulators from the start state. The initial
        # synchronisation is kept so the return value is defined even when the
        # loop body never runs.
        synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(
            activated_state, None, None, synch_type='out'
        )

        # --- Recurrent Loop ---
        for stepi in range(self.iterations):

            # --- Calculate Synchronisation for Input Data Interaction ---
            synchronisation_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(
                activated_state, decay_alpha_action, decay_beta_action, synch_type='action'
            )

            # --- Interact with Data via Attention ---
            q = self.q_proj(synchronisation_action).unsqueeze(1)
            attn_out, attn_weights = self.attention(
                q, kv, kv, average_attn_weights=False, need_weights=True
            )
            attn_out = attn_out.squeeze(1)
            pre_synapse_input = torch.cat((attn_out, activated_state), dim=-1)

            # --- Apply Synapses ---
            state = self.synapses(pre_synapse_input)
            # Slide the trace window: drop the oldest entry, append the newest.
            state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)

            # --- Activate ---
            activated_state = self.trace_processor(state_trace)

            # --- Calculate Synchronisation for Output Predictions ---
            synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(
                activated_state, decay_alpha_out, decay_beta_out, synch_type='out'
            )

            # --- Get Predictions and Certainties ---
            current_prediction = self.output_projector(synchronisation_out)
            current_certainty = self.compute_certainty(current_prediction)

            predictions[..., stepi] = current_prediction
            certainties[..., stepi] = current_certainty

            # --- Tracking ---
            if track:
                pre_activations_tracking.append(state_trace[:, :, -1].detach().cpu().numpy())
                post_activations_tracking.append(activated_state.detach().cpu().numpy())
                attention_tracking.append(attn_weights.detach().cpu().numpy())
                synch_out_tracking.append(synchronisation_out.detach().cpu().numpy())
                synch_action_tracking.append(synchronisation_action.detach().cpu().numpy())

        # --- Return Values ---
        if track:
            return (
                predictions,
                certainties,
                (np.array(synch_out_tracking), np.array(synch_action_tracking)),
                np.array(pre_activations_tracking),
                np.array(post_activations_tracking),
                np.array(attention_tracking),
            )
        return predictions, certainties, synchronisation_out