Autoencoders with geometrical–topological losses
In this example, we will create a simple autoencoder based on the Topological Signature Loss introduced by Moor et al. [Moor20a].
A simple autoencoder
We first define a simple linear autoencoder. The representations obtained from this autoencoder are very similar to those obtained via PCA.
class LinearAutoencoder(torch.nn.Module):
    """Simple linear autoencoder class.

    This module performs simple embeddings based on an MSE loss. This is
    similar to ordinary principal component analysis. Notice that the
    class is only meant to provide a simple example that can be run
    easily even without the availability of a GPU. In practice, there
    are many more architectures with improved expressive power
    available.
    """

    def __init__(self, input_dim, latent_dim=2):
        """Create new autoencoder with pre-defined latent dimension."""
        super().__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(self.input_dim, self.latent_dim)
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim, self.input_dim)
        )

        self.loss_fn = torch.nn.MSELoss()

    def encode(self, x):
        """Embed data in latent space."""
        return self.encoder(x)

    def decode(self, z):
        """Decode data from latent space."""
        return self.decoder(z)

    def forward(self, x):
        """Embeds and reconstructs data, returning a loss."""
        z = self.encode(x)
        x_hat = self.decode(z)

        # The loss can of course be changed. If this is your first time
        # working with autoencoders, a good exercise would be to 'grok'
        # the meaning of different losses.
        reconstruction_error = self.loss_fn(x, x_hat)

        return reconstruction_error
Of particular interest in the code are the `encode` and `decode` functions. With `encode`, we embed data in a latent space, whereas with `decode`, we reconstruct it to its ‘original’ space.
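To make the data flow concrete, here is a minimal sketch that calls both functions directly. The input dimension of 101 and the random batch are arbitrary choices for illustration:

# Illustrative only: the input dimension and the random batch are
# arbitrary choices, not part of the example above.
model = LinearAutoencoder(input_dim=101, latent_dim=2)

x = torch.rand(32, 101)    # batch of 32 samples with 101 features each
z = model.encode(x)        # latent representation of shape (32, 2)
x_hat = model.decode(z)    # reconstruction of shape (32, 101)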
This reconstruction is of course never perfect. We therefore measure its quality using a reconstruction loss. Let’s zoom into the specific function for this:
def forward(self, x):
    """Embeds and reconstructs data, returning a loss."""
    z = self.encode(x)
    x_hat = self.decode(z)

    # The loss can of course be changed. If this is your first time
    # working with autoencoders, a good exercise would be to 'grok'
    # the meaning of different losses.
    reconstruction_error = self.loss_fn(x, x_hat)

    return reconstruction_error
The important take-away here is that `forward` should return at least one loss value. We will make use of this later on!
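Since `forward` returns a scalar loss, a single training step follows the usual PyTorch pattern. Here is a minimal sketch, assuming a `model`, an `optimizer`, and a batch `x` have already been set up:

# Minimal training step; `model`, `optimizer`, and the batch `x` are
# assumed to exist already.
loss = model(x)            # forward pass returns the loss directly

optimizer.zero_grad()      # reset gradients from the previous step
loss.backward()            # backpropagate through encoder and decoder
optimizer.step()           # update the parameters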
A topological wrapper for autoencoder models
Our previous model uses `encode` to provide us with a lower-dimensional representation, the so-called latent representation. We can use this representation in order to calculate a topology-based loss! To this end, let’s write a new `forward` function that uses an existing model `model` for the latent space generation:
def forward(self, x):
    z = self.model.encode(x)

    pi_x = self.vr(x)
    pi_z = self.vr(z)

    geom_loss = self.model(x)
    topo_loss = self.loss([x, pi_x], [z, pi_z])

    loss = geom_loss + self.lam * topo_loss
    return loss
In the code above, the important things are:

- The use of a Vietoris–Rips complex `self.vr` to obtain persistence information about the input space `x` and the latent space `z`, respectively. We call this type of data `pi_x` and `pi_z`, respectively.
- The call to a topology-based loss function `self.loss()`, which takes the two spaces `x` and `z`, as well as their corresponding persistence information, to calculate the signature loss from [Moor20a]. See the standalone sketch after this list.
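To see these two building blocks in isolation, here is a standalone sketch that compares a random point cloud with a random stand-in for its latent representation; the data is purely illustrative:

import torch

from torch_topological.nn import SignatureLoss
from torch_topological.nn import VietorisRipsComplex

# Illustrative data only: in the wrapper class below, `z` would come
# from the encoder of an existing model instead.
x = torch.rand(100, 3)    # point cloud in the input space
z = torch.rand(100, 2)    # stand-in for a latent representation

vr = VietorisRipsComplex(dim=0)
loss_fn = SignatureLoss(p=2)

pi_x = vr(x)              # persistence information of the input space
pi_z = vr(z)              # persistence information of the latent space

topo_loss = loss_fn([x, pi_x], [z, pi_z])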
Putting this all together, we have the following ‘wrapper class’ that makes an existing model topology-aware:
class TopologicalAutoencoder(torch.nn.Module):
    """Wrapper for a topologically-regularised autoencoder.

    This class uses another autoencoder model and imbues it with an
    additional topology-based loss term.
    """

    def __init__(self, model, lam=1.0):
        super().__init__()

        self.lam = lam
        self.model = model
        self.loss = SignatureLoss(p=2)

        # TODO: Make dimensionality configurable
        self.vr = VietorisRipsComplex(dim=0)

    def forward(self, x):
        z = self.model.encode(x)

        pi_x = self.vr(x)
        pi_z = self.vr(z)

        geom_loss = self.model(x)
        topo_loss = self.loss([x, pi_x], [z, pi_z])

        loss = geom_loss + self.lam * topo_loss
        return loss
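Using the wrapper is then a one-liner on top of any model that exposes `encode` and returns a loss from its `forward`. A brief sketch, with the input dimension chosen arbitrarily:

# The input dimension is arbitrary here; see the full example below
# for a version driven by an actual data set.
model = LinearAutoencoder(input_dim=101)
topo_model = TopologicalAutoencoder(model, lam=10)

loss = topo_model(torch.rand(32, 101))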
See [Moor20a] for more models that could be extended in this fashion; being topology-aware can be crucial for many applications.
Source code
Here’s the full source code of this example.
"""Demo for topology-regularised autoencoders.
This example demonstrates how to use `pytorch-topological` to create an
additional differentiable loss term that makes autoencoders aware of
topological features. See [Moor20a]_ for more information.
"""
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch_topological.datasets import Spheres
from torch_topological.nn import SignatureLoss
from torch_topological.nn import VietorisRipsComplex
class LinearAutoencoder(torch.nn.Module):
"""Simple linear autoencoder class.
This module performs simple embeddings based on an MSE loss. This is
similar to ordinary principal component analysis. Notice that the
class is only meant to provide a simple example that can be run
easily even without the availability of a GPU. In practice, there
are many more architectures with improved expressive power
available.
"""
def __init__(self, input_dim, latent_dim=2):
"""Create new autoencoder with pre-defined latent dimension."""
super().__init__()
self.input_dim = input_dim
self.latent_dim = latent_dim
self.encoder = torch.nn.Sequential(
torch.nn.Linear(self.input_dim, self.latent_dim)
)
self.decoder = torch.nn.Sequential(
torch.nn.Linear(self.latent_dim, self.input_dim)
)
self.loss_fn = torch.nn.MSELoss()
def encode(self, x):
"""Embed data in latent space."""
return self.encoder(x)
def decode(self, z):
"""Decode data from latent space."""
return self.decoder(z)
def forward(self, x):
"""Embeds and reconstructs data, returning a loss."""
z = self.encode(x)
x_hat = self.decode(z)
# The loss can of course be changed. If this is your first time
# working with autoencoders, a good exercise would be to 'grok'
# the meaning of different losses.
reconstruction_error = self.loss_fn(x, x_hat)
return reconstruction_error
class TopologicalAutoencoder(torch.nn.Module):
"""Wrapper for a topologically-regularised autoencoder.
This class uses another autoencoder model and imbues it with an
additional topology-based loss term.
"""
def __init__(self, model, lam=1.0):
super().__init__()
self.lam = lam
self.model = model
self.loss = SignatureLoss(p=2)
# TODO: Make dimensionality configurable
self.vr = VietorisRipsComplex(dim=0)
def forward(self, x):
z = self.model.encode(x)
pi_x = self.vr(x)
pi_z = self.vr(z)
geom_loss = self.model(x)
topo_loss = self.loss([x, pi_x], [z, pi_z])
loss = geom_loss + self.lam * topo_loss
return loss
if __name__ == '__main__':
# We first have to create a data set. This follows the original
# publication by Moor et al. by introducing a simple 'manifold'
# data set consisting of multiple spheres.
n_spheres = 11
data_set = Spheres(n_spheres=n_spheres)
train_loader = DataLoader(
data_set,
batch_size=32,
shuffle=True,
drop_last=True
)
# Let's set up the two models that we are training. Note that in
# a real application, you would have a more complicated training
# setup, potentially with early stopping etc. This training loop
# is merely to be seen as a proof of concept.
model = LinearAutoencoder(input_dim=data_set.dimension)
topo_model = TopologicalAutoencoder(model, lam=10)
optimizer = optim.Adam(topo_model.parameters(), lr=1e-3)
n_epochs = 5
progress = tqdm(range(n_epochs))
for i in progress:
topo_model.train()
for batch, (x, y) in enumerate(train_loader):
loss = topo_model(x)
optimizer.zero_grad()
loss.backward()
optimizer.step()
progress.set_postfix(loss=loss.item())
# Evaluate the autoencoder on a new instance of the data set.
data_set = Spheres(
train=False,
n_samples=2000,
n_spheres=n_spheres,
)
test_loader = DataLoader(
data_set,
shuffle=False,
batch_size=len(data_set)
)
X, y = next(iter(test_loader))
Z = topo_model.model.encode(X).detach().numpy()
plt.scatter(
Z[:, 0], Z[:, 1],
c=y,
cmap='Set1',
marker='o',
alpha=0.9,
)
plt.show()