Hinton's Forward-Forward algorithm for local learning without backpropagation. Use for biologically plausible learning, on-chip training, memory-efficient networks, or parallel layer-wise training.
/plugin marketplace add plurigrid/asi
/plugin install plurigrid-asi-skills@plurigrid/asi

This skill inherits all available tools. When active, it can use any tool Claude has access to.
Trit: +1 (PLUS - generator) Color: Red (#D82626)
Implements Geoffrey Hinton's Forward-Forward (FF) algorithm (2022) and extensions:
Replace backprop with two forward passes:
$$
\begin{aligned}
&\text{Positive pass: } x^+ \text{ (real data)} \rightarrow \text{high goodness} \\
&\text{Negative pass: } x^- \text{ (generated/corrupted)} \rightarrow \text{low goodness} \\
&\text{Goodness: } G(h) = \sum_i h_i^2 \ \text{(sum of squared activations)} \\
&\text{Layer objective: } \max\ G(h^+) - G(h^-) \ \text{subject to threshold } \theta
\end{aligned}
$$
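A minimal numeric sketch of the goodness function and the threshold test (values chosen purely for illustration):

import torch

h_pos = torch.tensor([1.5, 2.0, 0.5])   # activations for a real sample
h_neg = torch.tensor([0.2, 0.1, 0.3])   # activations for a corrupted sample

goodness_pos = (h_pos ** 2).sum()       # 2.25 + 4.00 + 0.25 = 6.50
goodness_neg = (h_neg ** 2).sum()       # 0.04 + 0.01 + 0.09 = 0.14

theta = 2.0
print(goodness_pos > theta)   # tensor(True)  -> treated as positive
print(goodness_neg > theta)   # tensor(False) -> treated as negative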
Each layer trains independently:
Layer L objective:
$$
P(\text{positive} \mid h_L) = \sigma\!\left(G(h_L) - \theta\right)
$$

$$
\mathcal{L}_L = -\log P(\text{positive} \mid h_L^+) - \log\!\left(1 - P(\text{positive} \mid h_L^-)\right)
$$
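The per-layer loss used in the code below is this negative log-likelihood written in its numerically stable softplus form; a quick check of the identity $-\log\sigma(x) = \mathrm{softplus}(-x)$:

import torch
import torch.nn.functional as F

g = torch.tensor([3.2, 0.7])                 # goodness of two samples
theta = 2.0

nll = -torch.log(torch.sigmoid(g - theta))   # -log P(positive | h)
sp = F.softplus(theta - g)                   # softplus form used in train_step below

print(torch.allclose(nll, sp))               # True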
Generate negative samples from the network itself:
$$
x^- = \text{augment}(x^+) \quad \text{or} \quad x^- = G_\phi(z) \ \text{(learned generator)}
$$
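Two simple options, assuming label-embedded MNIST-style inputs as in the code below: overlay a deliberately wrong label, or mix two real images into a hybrid (both variants appear in Hinton's paper). Function names here are illustrative:

import torch

def wrong_label_negatives(x, y, num_classes=10):
    # Overlay a deliberately wrong label in the first num_classes dimensions
    wrong = (y + torch.randint(1, num_classes, y.shape)) % num_classes
    x_neg = x.clone()
    x_neg[:, :num_classes] = 0
    x_neg[torch.arange(x.size(0)), wrong] = 1
    return x_neg

def hybrid_negatives(x):
    # Mix two real images with a random binary mask
    perm = torch.randperm(x.size(0))
    mask = (torch.rand_like(x) < 0.5).float()
    return mask * x + (1 - mask) * x[perm]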
import torch
import torch.nn as nn
import torch.nn.functional as F
class FFLayer(nn.Module):
"""Forward-Forward layer with local learning."""
def __init__(self, in_dim, out_dim, threshold=2.0):
super().__init__()
self.linear = nn.Linear(in_dim, out_dim)
self.threshold = threshold
self.optimizer = None # Set per-layer optimizer
def goodness(self, h):
"""Compute goodness: sum of squared activations."""
return (h ** 2).sum(dim=-1)
    def forward(self, x, label=None):
        """Forward pass with optional label embedding."""
        if label is not None:
            # Embed label in first 10 dimensions (for MNIST)
            x = x.clone()
            x[:, :10] = 0
            x[torch.arange(x.size(0)), label] = 1
        # Length-normalize the input so goodness from the previous layer
        # cannot simply be copied forward (as in Hinton's recipe)
        x = x / (x.norm(2, dim=-1, keepdim=True) + 1e-8)
        h = F.relu(self.linear(x))
        return h
def train_step(self, x_pos, x_neg):
"""Local training step using FF algorithm."""
h_pos = self.forward(x_pos)
h_neg = self.forward(x_neg)
g_pos = self.goodness(h_pos)
g_neg = self.goodness(h_neg)
# Loss: positive above threshold, negative below
loss_pos = F.softplus(self.threshold - g_pos).mean()
loss_neg = F.softplus(g_neg - self.threshold).mean()
loss = loss_pos + loss_neg
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item(), h_pos.detach(), h_neg.detach()
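# Possible single-layer usage sketch (random stand-in data, illustrative shapes):
# each layer needs only its own inputs and optimizer, no global backward pass.
layer = FFLayer(784, 500)
layer.optimizer = torch.optim.Adam(layer.parameters(), lr=0.03)

x_pos = torch.rand(64, 784)   # stand-in for real data
x_neg = torch.rand(64, 784)   # stand-in for negatives

for step in range(10):
    loss, h_pos, h_neg = layer.train_step(x_pos, x_neg)
    # h_pos / h_neg are detached and could feed the next layer directly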
class FFNetwork(nn.Module):
"""Full Forward-Forward network."""
def __init__(self, dims, threshold=2.0, lr=0.03):
super().__init__()
self.layers = nn.ModuleList([
FFLayer(dims[i], dims[i+1], threshold)
for i in range(len(dims) - 1)
])
        # Per-layer optimizers (updates stay purely local)
        for layer in self.layers:
            layer.optimizer = torch.optim.Adam(layer.parameters(), lr=lr)

    def embed_label(self, x, y):
        """Overlay the label on the first 10 input dimensions (MNIST-style)."""
        x = x.reshape(x.size(0), -1).clone()
        x[:, :10] = 0
        x[torch.arange(x.size(0)), y] = 1
        return x

    def random_labels(self, y):
        """Draw deliberately wrong labels for negative samples."""
        offset = torch.randint(1, 10, y.shape, device=y.device)
        return (y + offset) % 10
def train_epoch(self, dataloader, neg_generator):
"""Train all layers for one epoch."""
total_loss = 0
for x, y in dataloader:
# Generate negative samples
x_neg = neg_generator(x, y)
# Embed labels
x_pos = self.embed_label(x, y)
x_neg = self.embed_label(x_neg, self.random_labels(y))
# Train layer by layer
h_pos, h_neg = x_pos, x_neg
for layer in self.layers:
loss, h_pos, h_neg = layer.train_step(h_pos, h_neg)
total_loss += loss
return total_loss
    def predict(self, x):
        """Predict by trying each label and picking the highest total goodness."""
        best_label, best_goodness = None, -float('inf')
        for label in range(10):
            x_labeled = self.embed_label(x, label)
            h = x_labeled
            goodness = 0.0
            for layer in self.layers:
                h = layer(h)
                # Accumulate goodness across layers for this candidate label
                goodness = goodness + layer.goodness(h).mean()
            if goodness > best_goodness:
                best_label = label
                best_goodness = goodness
        return best_label
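# A possible end-to-end usage sketch. `train_loader` is a hypothetical DataLoader
# of (image, label) batches; label-corrupted negatives are produced inside
# train_epoch via random_labels, so the negative generator can be a no-op here.
model = FFNetwork([784, 500, 500], threshold=2.0, lr=0.03)
identity_negatives = lambda x, y: x.clone()   # negatives differ only by their (wrong) label

for epoch in range(5):
    total = model.train_epoch(train_loader, identity_negatives)
    print(f"epoch {epoch}: total FF loss {total:.3f}")

x, y = next(iter(train_loader))
pred = model.predict(x[:1])   # label with the highest summed goodness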
class SelfContrastiveFF(FFNetwork):
"""Self-Contrastive FF (Nature 2025)."""
def __init__(self, dims, threshold=2.0):
super().__init__(dims, threshold)
# Learned negative generator
self.neg_generator = nn.Sequential(
nn.Linear(dims[0], dims[0]),
nn.ReLU(),
nn.Linear(dims[0], dims[0])
)
    def generate_negatives(self, x_pos):
        """Generate negatives from positives."""
        # Method 1: learned transformation (detached so layer updates stay local)
        x_neg = self.neg_generator(x_pos).detach()
        # Method 2: simple augmentation
        # x_neg = x_pos + 0.1 * torch.randn_like(x_pos)
        return x_neg
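# Hedged usage sketch: the published self-contrastive scheme trains its negative
# generator in its own way, not reproduced here; this only shows how the class
# above plugs into the same layer-local loop.
sc_model = SelfContrastiveFF([784, 500, 500])
x = torch.rand(64, 784)                  # stand-in batch of flattened images
y = torch.randint(0, 10, (64,))

x_neg = sc_model.generate_negatives(x)   # negatives derived from the positives
x_pos = sc_model.embed_label(x, y)
x_neg = sc_model.embed_label(x_neg, sc_model.random_labels(y))

h_pos, h_neg = x_pos, x_neg
for layer in sc_model.layers:
    _, h_pos, h_neg = layer.train_step(h_pos, h_neg)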
class DistanceForwardLayer(FFLayer):
"""Distance-Forward layer (arXiv:2408.14925)."""
def __init__(self, in_dim, out_dim, num_classes=10):
super().__init__(in_dim, out_dim)
self.class_centers = nn.Parameter(torch.randn(num_classes, out_dim))
def distance_goodness(self, h, labels):
"""Goodness based on distance to class centers."""
centers = self.class_centers[labels]
return -((h - centers) ** 2).sum(dim=-1) # Negative distance
def train_step(self, x, labels):
h = self.forward(x)
goodness = self.distance_goodness(h, labels)
loss = -goodness.mean() # Minimize distance to correct center
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item(), h.detach()
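# Minimal usage sketch: the layer needs its own optimizer, created after
# construction so it also covers the learned class centers.
df_layer = DistanceForwardLayer(784, 128, num_classes=10)
df_layer.optimizer = torch.optim.Adam(df_layer.parameters(), lr=0.01)

x = torch.rand(64, 784)
y = torch.randint(0, 10, (64,))
loss, h = df_layer.train_step(x, y)

# Inference: pick the class whose center is nearest to the representation
with torch.no_grad():
    h = df_layer(x)
    preds = torch.cdist(h, df_layer.class_centers).argmin(dim=-1)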
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
class FFLayerJAX(nn.Module):
features: int
threshold: float = 2.0
@nn.compact
def __call__(self, x):
h = nn.Dense(self.features)(x)
h = nn.relu(h)
return h
def goodness(self, h):
return jnp.sum(h ** 2, axis=-1)
def ff_loss(params, model, x_pos, x_neg, threshold):
    """Forward-Forward loss for a single layer in JAX."""
    h_pos = model.apply(params, x_pos)
    h_neg = model.apply(params, x_neg)
    g_pos = jnp.sum(h_pos ** 2, axis=-1)
    g_neg = jnp.sum(h_neg ** 2, axis=-1)
    loss_pos = jax.nn.softplus(threshold - g_pos).mean()
    loss_neg = jax.nn.softplus(g_neg - threshold).mean()
    return loss_pos + loss_neg

def make_ff_train_step(model, optimizer, threshold=2.0):
    """Build a jitted local train step; model and optimizer are closed over
    rather than passed as (untraceable) jit arguments."""
    @jax.jit
    def ff_train_step(params, opt_state, x_pos, x_neg):
        loss, grads = jax.value_and_grad(ff_loss)(params, model, x_pos, x_neg, threshold)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    return ff_train_step
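# Per-layer setup sketch: every FF layer owns its params, optimizer state,
# and jitted train step (shapes and hyperparameters are illustrative).
key = jax.random.PRNGKey(0)
model = FFLayerJAX(features=500)

x_pos = jax.random.normal(key, (64, 784))                      # stand-in positives
x_neg = jax.random.normal(jax.random.PRNGKey(1), (64, 784))    # stand-in negatives

params = model.init(key, x_pos)
optimizer = optax.adam(3e-2)
opt_state = optimizer.init(params)

train_step = make_ff_train_step(model, optimizer, threshold=2.0)
for _ in range(10):
    params, opt_state, loss = train_step(params, opt_state, x_pos, x_neg)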
This skill participates in balanced triads:
sheaf-cohomology (-1) ⊗ sheaf-laplacian-coordination (0) ⊗ forward-forward-learning (+1) = 0 ✓
proofgeneral-narya (-1) ⊗ unworld (0) ⊗ forward-forward-learning (+1) = 0 ✓
persistent-homology (-1) ⊗ open-games (0) ⊗ forward-forward-learning (+1) = 0 ✓
# No need to store activations for backward pass
model = FFNetwork([784, 500, 500, 10])
# Memory usage: O(layer_size) not O(depth * layer_size)
# Each layer can train independently given its cached inputs
# (layers, h_pos_list, h_neg_list: each layer's module plus the positive /
#  negative activations feeding it, e.g. saved from the previous step)
from concurrent.futures import ThreadPoolExecutor

def train_layer(layer, h_pos, h_neg):
    return layer.train_step(h_pos, h_neg)

with ThreadPoolExecutor() as executor:
    # All layer updates run concurrently; each uses only local information
    futures = [executor.submit(train_layer, l, hp, hn)
               for l, hp, hn in zip(layers, h_pos_list, h_neg_list)]
    results = [f.result() for f in futures]
# Suitable for neuromorphic hardware
# No weight transport problem (no backprop)
# Local synaptic updates only
# Forward-Forward for NCA rule learning (sketch)
class FF_NCA(nn.Module):
    def __init__(self):
        super().__init__()
        self.perceive = FFLayer(48, 128)   # 16 grid channels x (identity + 2 Sobel) = 48
        self.update = FFLayer(128, 16)

    def stochastic_mask(self, delta, p=0.5):
        """Per-cell Bernoulli mask: only about half the cells update each step."""
        return (torch.rand(delta.shape[:-1] + (1,), device=delta.device) < p).float()

    def step(self, grid, perception):
        """grid: [cells, 16] state; perception: [cells, 48] identity + Sobel features of grid."""
        h = self.perceive(perception)
        delta = self.update(h)
        return grid + delta * self.stochastic_mask(delta)
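# Rough usage sketch: `sobel_perception` is a hypothetical helper mapping the
# 16-channel cell grid to its 48 perception features (identity + two Sobel filters).
nca = FF_NCA()
grid = torch.zeros(32 * 32, 16)          # flattened 32x32 grid, 16 channels per cell
grid[32 * 16 + 16, :] = 1.0              # seed the centre cell

for _ in range(64):
    perception = sobel_perception(grid)  # hypothetical: [cells, 16] -> [cells, 48]
    grid = nca.step(grid, perception)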
;; In parallel_color_fork.clj
(defn ff-color-learning
"Learn color preferences via Forward-Forward"
[positive-colors negative-colors]
(let [ff-layer (make-ff-layer 3 16) ; RGB -> hidden
goodness-pos (compute-goodness (forward ff-layer positive-colors))
goodness-neg (compute-goodness (forward ff-layer negative-colors))]
(local-update ff-layer goodness-pos goodness-neg)))
| Aspect | Backprop | Forward-Forward |
|---|---|---|
| Memory | O(depth × width) | O(width) |
| Parallelism | Sequential layers | Parallel layers |
| Biological plausibility | Low | Higher |
| Weight transport | Required | Not needed |
| Gradient vanishing | Problem | Avoided |
| On-chip learning | Difficult | Natural |
- sheaf-laplacian-coordination - Distributed coordination (complementary coordinator)
- self-evolving-agent - Continual adaptation (uses FF for local updates)
- jaxlife-open-ended - Open-ended evolution (FF for agent learning)
- gay-mcp - Deterministic colors for positive/negative sample generation

@article{hinton2022forward,
title={The Forward-Forward Algorithm: Some Preliminary Investigations},
author={Hinton, Geoffrey E},
journal={arXiv:2212.13345},
year={2022}
}
@article{nature2025selfcontrastive,
title={Self-Contrastive Forward-Forward Algorithm},
journal={Nature Communications},
year={2025}
}
@article{wu2024distance,
title={Distance-Forward Learning},
author={Wu, Yujie and others},
journal={arXiv:2408.14925},
year={2024}
}