Deep Hebbian Image Encoder
Hebbian Deep Encoding for Image Clustering
This project explores new Hebbian methods for training stackable image encoders. The learned image embeddings are then visualized using UMAP and K-Means clustering.
Below is a sample clustering of learned image embeddings from the Tiny ImageNet dataset.

Source code on Github
Why Hebbian? (Biological Analogy)
Unlike traditional neural networks that rely on error signals and gradient descent, Hebbian models operate more like biological brains. There is no supervision, no target output, and no error propagation: inputs are passed forward, and connections between neurons that activate together are strengthened.
Neurons that fire together, wire together.
This means the network doesn’t know what the correct answer is; it simply recognizes and reinforces co-occurrence in the data. If two features appear frequently at the same time, their connection strengthens. Over time, this results in internal representations that reflect the structure of the data, without needing labels or supervision.
This approach mirrors part of the brain’s learning strategy, where connections are updated locally based on experience rather than toward global goals. It is simpler and more biologically plausible, and that simplicity may be useful for efficient or modular systems in the future.
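As a toy illustration of the rule itself (not part of the encoder below, just the idea), a single connection strengthens only when both units are active at the same time:
import torch

eta = 0.01                      # learning rate
w = torch.tensor(0.0)           # connection strength between two units
for x, y in [(1.0, 1.0), (0.0, 1.0), (1.0, 0.9)]:
    w += eta * x * y            # Hebbian update: only co-activation changes w
print(w.item())                 # ~0.019: the (0.0, 1.0) pair contributed nothing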
Dataset Preparation
Images were extracted either from a sprite sheet or from an image folder (e.g., Tiny ImageNet). For the Pokémon dataset, we sliced a 30×30 grid of 96×96 tiles (898 sprites in total) from a single transparent PNG.
sprites = load_spritesheet("pokemon_all_transparent.png", sprite_size=(96, 96), tile_size=(96, 96))
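The load_spritesheet helper is not reproduced in this write-up; a minimal sketch of how such a function could slice the sheet with PIL (the grid handling and blank-tile filtering are assumptions):
from PIL import Image
from torchvision import transforms

def load_spritesheet(path, sprite_size=(96, 96), tile_size=(96, 96)):
    # Hypothetical implementation: cut the sheet into tiles, left-to-right, top-to-bottom
    sheet = Image.open(path).convert("RGBA")
    cols, rows = sheet.width // tile_size[0], sheet.height // tile_size[1]
    to_tensor = transforms.ToTensor()
    sprites = []
    for r in range(rows):
        for c in range(cols):
            box = (c * tile_size[0], r * tile_size[1],
                   (c + 1) * tile_size[0], (r + 1) * tile_size[1])
            tile = sheet.crop(box).resize(sprite_size)
            if tile.getbbox() is not None:           # skip fully transparent cells
                sprites.append(to_tensor(tile))      # (4, H, W) RGBA tensor
    return sprites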
For Tiny ImageNet:
dataset = ImageFolder("../data/tiny_imagenet/train", transform=transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor()
]))
Hebbian Encoder Architecture
Each Hebbian encoder layer applies:
- A %%3 \times 3%% convolution with stride 2
- ReLU activation
- L2 normalization over spatial dimensions
- Lateral weights updated by Hebbian rule:
$$ \Delta W_{ij} = \eta \cdot \langle a_i \cdot a_j \rangle $$
Where:
- %%\Delta W_{ij}%% is the change in weight from unit %%j%% to %%i%%
- %%\eta%% is the learning rate
- %%a_i%% and %%a_j%% are the activations of units %%i%% and %%j%% respectively
- %%\langle a_i \cdot a_j \rangle%% denotes the batch-averaged outer product
This update promotes co-activation patterns. Inhibition is enforced separately, by subtracting the mean lateral activation when the lateral term is applied back to the activations (see the excerpt just below the update). In code, the update is:
hebbian = torch.einsum("bni,bnj->nij", act_flat, act_flat)  # outer product over batch
delta = 0.001 * hebbian.mean(dim=0)
self.lateral_weights.data += delta
self.lateral_weights.data.clamp_(-1.0, 1.0)
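And the inhibition step, taken from the same forward pass in the full source at the end of this post:
lateral = torch.einsum("bci,cij->bcj", act_flat, self.lateral_weights)
lateral = lateral.view(B, C, H, W)
lateral = lateral - lateral.mean(dim=(2, 3), keepdim=True)  # subtract mean activation
act += lateral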
Energy, Delta, Norm Logging
During training, the following values are logged per step:
- Energy: mean squared activation across all units, i.e., %%\mathbb{E}[|a|^2]%%
- Delta: mean absolute change in lateral weights
- Norm: Frobenius norm of the lateral weight matrix, i.e., %%\|W\|_F%%
Example:
[LOG] Step 122: energy=0.1466, delta=0.000166, norm=155.7744
This gives insight into encoder dynamics: stable energy and delta values indicate convergence, while growing norm may suggest over-association.
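For reference, a small standalone helper (log_metrics is not part of the original code) that computes the same three values from an activation tensor, a weight change, and the lateral weight matrix:
import torch

def log_metrics(act: torch.Tensor, delta: torch.Tensor, W: torch.Tensor) -> dict:
    # Mirrors the values printed in the [LOG] line
    return {
        "energy": act.pow(2).mean().item(),  # mean squared activation E[|a|^2]
        "delta": delta.abs().mean().item(),  # mean absolute lateral weight change
        "norm": W.norm().item(),             # Frobenius norm ||W||_F
    }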
Feature Extraction
Images are passed through a multi-layer encoder consisting of 4 HebbianEncoder layers. The final feature map is flattened to a 1D vector and stored.
features = model(images)
features = F.normalize(features, dim=1).cpu().numpy()
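In practice the feature vectors are collected batch by batch from a DataLoader; a minimal sketch along the lines of the training loop in the full source:
all_features = []
for batch, _ in dataloader:
    z = model(batch)                           # (B, 4608) flattened, detached features
    all_features.append(F.normalize(z, dim=1))
features = torch.cat(all_features, dim=0).cpu().numpy()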
Hebbian Network Structure
The Hebbian encoder processes 96×96 RGBA Pokémon sprites using a stack of convolutional layers with stride 2. Each layer halves the spatial resolution while increasing the channel count. Each Hebbian layer also includes lateral recurrent weights trained with Hebbian updates to reinforce co-activation patterns.
model = MultiLayerHebbian([
    (4, 16, (48, 48)),
    (16, 32, (24, 24)),
    (32, 64, (12, 12)),
    (64, 128, (6, 6))
])
Each tuple in the list specifies the parameters for a HebbianEncoder layer:
(in_channels, out_channels, spatial_shape)
This configuration maps as follows:
| Layer | Input Channels | Output Channels | Input Spatial Size | Output Spatial Size |
|---|---|---|---|---|
| 1 | 4 (RGBA) | 16 | 96×96 | 48×48 |
| 2 | 16 | 32 | 48×48 | 24×24 |
| 3 | 32 | 64 | 24×24 | 12×12 |
| 4 | 64 | 128 | 12×12 | 6×6 |
This structure results in a final feature tensor of shape (B, 128, 6, 6) per image, which is flattened to (B, 4608) and used for clustering and visualization.
The spatial argument in each HebbianEncoder is required to initialize the lateral weight matrix:
self.lateral_weights = torch.nn.Parameter(torch.zeros(C, N, N))
where N = H × W (the number of spatial positions per channel). These weights are updated using Hebbian learning:
$$ \Delta W_{ij} = \eta \cdot \langle a_i \cdot a_j \rangle $$
which in code becomes:
hebbian = torch.einsum("bni,bnj->nij", act_flat, act_flat)
delta = 0.001 * hebbian.mean(dim=0)
self.lateral_weights.data += delta
This configuration balances spatial compression and representational capacity, while the Hebbian lateral updates encourage neurons to specialize by detecting and reinforcing co-activation patterns.
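Since the lateral weights are stored per spatial position, the early layers dominate the memory footprint; a quick count for the configuration above:
# Lateral weight entries per layer: C * N * N, with N = H * W of the layer's output
layers = [(16, 48 * 48), (32, 24 * 24), (64, 12 * 12), (128, 6 * 6)]
for i, (C, N) in enumerate(layers, start=1):
    print(f"layer {i}: {C * N * N:,} lateral weights")
# layer 1: 84,934,656
# layer 2: 10,616,832
# layer 3: 1,327,104
# layer 4: 165,888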
Embedding and Clustering Visualization
UMAP is used to project feature vectors to 2D.
reducer = UMAP(n_components=2, random_state=42)
reduced = reducer.fit_transform(features)
Optionally, KMeans clustering is applied to the feature space:
kmeans = KMeans(n_clusters=6, random_state=42).fit(features)
These steps are primarily for visualization and evaluation. They let us inspect whether the encoder has organized the inputs meaningfully; clustering is not part of the training objective.
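One simple way to use the cluster assignments for inspection (not part of the pipeline above) is to color the 2D UMAP points by their KMeans label:
import numpy as np
import matplotlib.pyplot as plt

labels = kmeans.labels_                 # one cluster index per image
print(np.bincount(labels))              # how many images landed in each cluster
plt.scatter(reduced[:, 0], reduced[:, 1], c=labels, s=4, cmap="tab10")
plt.savefig("umap_kmeans_scatter.png")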
Layout and Plotting
Projected coordinates are normalized to fit inside a square canvas (e.g., 2500x2500 pixels). Margin padding ensures images are not clipped.
reduced -= reduced.min(axis=0)
reduced /= (reduced.max(axis=0) + 1e-8)
reduced *= (canvas_size - 2 * margin)
reduced += margin
Sprites are drawn onto the canvas using their corresponding (x, y) UMAP coordinates.
for (x, y), img in zip(reduced, sprites):
    pil = transforms.ToPILImage()(img).resize((sprite_size, sprite_size))
    canvas.paste(pil, (int(x), int(y)), mask=pil if has_alpha else None)
Code + Results
Tiny ImageNet (Hebbian)

Pokémon Full RGBA Hebbian

Pokémon Similarity
The first column contains randomly selected Pokémon; the five most-similar Pokémon are listed to the right.

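The ranking code is not shown here, but with L2-normalized feature vectors a cosine-similarity nearest-neighbor search reduces to a dot product; a minimal sketch (query_idx is an illustrative index):
import numpy as np

query_idx = 0                               # hypothetical query image
sims = features @ features[query_idx]       # cosine similarity of every image to the query
nearest = np.argsort(-sims)[1:6]            # the five most similar images, skipping the query itself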
Hebbian Image Encoder (Single-Layer)
This was the first prototype. The first column contains randomly selected images from Tiny ImageNet; the five most-similar images are listed to the right.

Hebbian Deep Image Encoder - Tiny ImageNet - Source Code
import sys
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
from umap import UMAP
# === Config ===
IMAGE_DIR = "../data/tiny_imagenet/train"
SPRITE_SIZE = (64, 64)
BATCH_SIZE = 8
NUM_IMAGES = 1000
EPOCHS = 50
CLUSTERS = 4
EMBED_SIZE = 32 # thumbnail size
CANVAS_SIZE = 2500
# === Hebbian Network ===
class HebbianEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels, spatial):
        super().__init__()
        # Strided conv halves the spatial resolution; lateral weights act on the
        # flattened output positions (N = H * W) of each channel.
        self.encode = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        C = out_channels
        N = spatial[0] * spatial[1]
        self.lateral_weights = torch.nn.Parameter(torch.zeros(C, N, N))

    def forward(self, x, step=None):
        act = F.relu(self.encode(x))
        act = act / (act.norm(dim=(2, 3), keepdim=True) + 1e-6)  # L2-normalize over spatial dims
        B, C, H, W = act.shape
        act_flat = act.view(B, C, -1)
        with torch.no_grad():
            # Hebbian update: outer product of co-activations, scaled by the learning rate
            hebbian = torch.einsum("bni,bnj->nij", act_flat, act_flat)
            delta = 0.001 * hebbian.mean(dim=0)
            self.lateral_weights.data += delta
            self.lateral_weights.data.clamp_(-1.0, 1.0)
        # Apply lateral weights, then subtract the spatial mean as a soft inhibition term
        lateral = torch.einsum("bci,cij->bcj", act_flat, self.lateral_weights)
        lateral = lateral.view(B, C, H, W)
        lateral = lateral - lateral.mean(dim=(2, 3), keepdim=True)
        act += lateral
        if step is not None:
            print(f"[LOG] Step {step}: energy={act.pow(2).mean():.4f}, delta={delta.abs().mean():.6f}, norm={self.lateral_weights.data.norm():.4f}")
        return act
class MultiLayerHebbian(torch.nn.Module):
    def __init__(self, layer_shapes):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            HebbianEncoder(in_c, out_c, spatial) for (in_c, out_c, spatial) in layer_shapes
        ])

    def forward(self, x, step=None):
        # Only the final layer logs, to keep the output readable
        for i, layer in enumerate(self.layers):
            x = layer(x, step=step if i == len(self.layers) - 1 else None)
        return x.view(x.size(0), -1).detach()  # flatten to (B, C*H*W)
# === Load Dataset ===
def load_dataset():
    transform = transforms.Compose([
        transforms.Resize(SPRITE_SIZE),
        transforms.ToTensor()
    ])
    dataset = ImageFolder(IMAGE_DIR, transform=transform)
    # Keep only the first NUM_IMAGES images to keep training and plotting fast
    subset = torch.utils.data.Subset(dataset, list(range(min(NUM_IMAGES, len(dataset)))))
    loader = DataLoader(subset, batch_size=BATCH_SIZE, shuffle=False)
    return subset, loader
# === Plot Utility ===
def plot_with_images(embeddings, images, title="Hebbian Clusters", size=32, canvas_size=2500):
    fig, ax = plt.subplots(figsize=(canvas_size / 100, canvas_size / 100), facecolor='black', dpi=100)
    ax.set_facecolor('black')
    ax.set_title(title, color='white')
    ax.set_xticks([])
    ax.set_yticks([])
    from matplotlib.offsetbox import OffsetImage, AnnotationBbox

    # Normalize coordinates to canvas
    margin = size * 2
    embeddings -= embeddings.min(axis=0)
    embeddings /= (embeddings.max(axis=0) + 1e-8)
    embeddings *= (canvas_size - 2 * margin)
    embeddings += margin

    for (x, y), img_tensor in zip(embeddings, images):
        img = transforms.ToPILImage()(img_tensor).resize((size, size), resample=Image.BILINEAR).convert("RGB")
        imbox = OffsetImage(img, zoom=1.5)  # zoom factor for visibility
        ab = AnnotationBbox(imbox, (x, y), frameon=False)
        ax.add_artist(ab)

    ax.set_xlim(0, canvas_size)
    ax.set_ylim(0, canvas_size)
    ax.invert_yaxis()
    plt.tight_layout()
    plt.savefig("tinyimagenet_hebbian_cluster_plot.png", facecolor='black')
    print("[SAVED] tinyimagenet_hebbian_cluster_plot.png")
# === Main ===
if __name__ == "__main__":
    dataset, dataloader = load_dataset()
    model = MultiLayerHebbian([
        (3, 16, (32, 32)),
        (16, 32, (16, 16)),
        (32, 64, (8, 8)),
        (64, 128, (4, 4))
    ])

    # Train with Hebbian updates; features are only collected on the final epoch
    all_features = []
    for epoch in range(EPOCHS):
        for step, (batch, _) in enumerate(dataloader):
            z = model(batch, step=step) if epoch == EPOCHS - 1 else model(batch)
            if epoch == EPOCHS - 1:
                all_features.append(z)

    features = torch.cat(all_features, dim=0).cpu().numpy()
    features = np.nan_to_num(features)
    features /= (np.linalg.norm(features, axis=1, keepdims=True) + 1e-6)

    reducer = UMAP(n_components=2, random_state=42)  # , min_dist=0.2)
    reduced = reducer.fit_transform(features)

    # Normalize UMAP coordinates onto the canvas with a small margin
    margin = EMBED_SIZE // 2
    reduced -= reduced.min(axis=0)
    reduced /= (reduced.max(axis=0) + 1e-8)
    reduced *= (CANVAS_SIZE - 2 * margin)
    reduced += margin

    all_images = [img for img, _ in dataset]
    plot_with_images(reduced, all_images)