Point cloud optimisation with summary statistics
One interesting use case of torch_topological involves changing the
shape of a point cloud using topological summary statistics. Such
summary statistics can either be used as simple loss functions,
constituting a computationally cheap way of assessing the topological
similarity of a given point cloud to a target point cloud, or
serve to highlight certain topological properties of a single point
cloud.
In this example, we will consider both operations.
Ingredients
Our main ingredient is the torch_topological.nn.SummaryStatisticLoss
class. This class bundles different summary statistics on persistence
diagrams and permits their calculation and comparison.
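To make the notion of a summary statistic more concrete, here is a minimal sketch in plain Python of two of the statistics this example offers on its command line, total persistence and persistent entropy, evaluated on a toy persistence diagram. The helper names and the toy diagram are made up for illustration, and the formulas are the textbook definitions; the library's own implementation works on tensors and may differ in details.

```python
import math


def total_persistence(diagram, p=2.0):
    """Sum of persistence values (death - birth), each raised to the power p."""
    return sum(abs(death - birth) ** p for birth, death in diagram)


def persistent_entropy(diagram):
    """Shannon entropy of the normalised persistence values."""
    lifetimes = [abs(death - birth) for birth, death in diagram]
    total = sum(lifetimes)
    if total == 0:
        return 0.0
    probabilities = [value / total for value in lifetimes if value > 0]
    return -sum(prob * math.log(prob) for prob in probabilities)


# Hypothetical diagram, given as (birth, death) pairs.
toy_diagram = [(0.0, 1.0), (0.0, 2.0), (0.5, 1.5)]

tp = total_persistence(toy_diagram, p=2.0)
entropy = persistent_entropy(toy_diagram)
print(tp, entropy)
```

Both functions reduce a whole diagram to a single number, which is what makes them cheap to use as loss terms.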
This class can operate in two modes:
1. Calculating the loss for a single input data set.
2. Calculating the loss difference for two input data sets.
Our example will showcase both of these modes!
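The two modes can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the function summary_statistic_loss and the sum_of_lifetimes statistic are hypothetical stand-ins, and taking the absolute difference of the two statistic values is an assumption about how such a comparison might be done, not a statement about the library's internals.

```python
def sum_of_lifetimes(diagram):
    """Hypothetical summary statistic: sum of (death - birth) values."""
    return sum(death - birth for birth, death in diagram)


def summary_statistic_loss(statistic, diagram_x, diagram_y=None):
    """Evaluate a statistic on one diagram or compare it across two."""
    value_x = statistic(diagram_x)

    # Mode 1: a single input, so the loss is the statistic itself.
    if diagram_y is None:
        return abs(value_x)

    # Mode 2: two inputs, so the loss is the difference of statistics.
    return abs(statistic(diagram_y) - value_x)


source = [(0.0, 1.0), (0.0, 0.5)]
target = [(0.0, 2.0)]

single_loss = summary_statistic_loss(sum_of_lifetimes, source)
pair_loss = summary_statistic_loss(sum_of_lifetimes, source, target)
print(single_loss, pair_loss)
```

Minimising the single-input loss shrinks the statistic of one diagram, whereas minimising the two-input loss pulls the statistic of the source towards that of the target.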
Optimising all the point clouds
Here’s the bulk of the code required to optimise a point cloud. We will walk through the most important parts!
def main(args):
    """Run example."""
    n_iterations = args.n_iterations
    statistic = args.statistic
    p = args.p
    q = args.q

    X, Y = create_data_set(args)
    vr = VietorisRipsComplex(dim=2)

    if not args.single:
        pi_target = vr(Y)

    loss_fn = SummaryStatisticLoss(
        summary_statistic=statistic,
        p=p,
        q=q
    )

    opt = optim.SGD([X], lr=0.05)
    progress = tqdm(range(n_iterations))

    for i in progress:
        pi_source = vr(X)

        if not args.single:
            loss = loss_fn(pi_source, pi_target)
        else:
            loss = loss_fn(pi_source)

        opt.zero_grad()
        loss.backward()
        opt.step()

        progress.set_postfix(loss=f'{loss.item():.08f}')

    X = X.detach().numpy()

    if args.single:
        plt.scatter(X[:, 0], X[:, 1], label='Result')
        plt.scatter(Y[:, 0], Y[:, 1], label='Initial')
    else:
        plt.scatter(X[:, 0], X[:, 1], label='Source')
        plt.scatter(Y[:, 0], Y[:, 1], label='Target')

    plt.legend()
    plt.show()
Apart from creating some test data sets (check out torch_topological.data
for more routines), the most important thing is to make sure that X,
our point cloud, is a trainable parameter.
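What "trainable parameter" means in practice can be checked with PyTorch alone. The sketch below uses random points and a deliberately simple, non-topological loss (both are stand-ins chosen for illustration) to show that a tensor wrapped in torch.nn.Parameter receives a gradient for every coordinate.

```python
import torch

# Stand-in for a sampled point cloud: 100 points in the plane.
points = torch.rand(100, 2)

# Wrapping the tensor in a Parameter marks it as a trainable leaf.
X = torch.nn.Parameter(points.clone(), requires_grad=True)

# Toy loss (not topological): sum of squared coordinates.
toy_loss = X.pow(2).sum()
toy_loss.backward()

# Every point now carries a gradient of the same shape as X.
print(X.requires_grad, X.is_leaf, X.grad.shape)
```

An optimiser that is handed [X] can use these gradients to move the points, which is how a topological loss ends up changing the shape of the cloud.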
With that out of the way, we can set up the summary statistic loss
and start training. The main loop of the training might be familiar to
those of you who already have some experience with PyTorch: it
merely evaluates the loss and optimises it, following a general
structure:
# Set up your favourite optimiser
opt = optim.SGD(...)

for i in range(100):
    # Do some calculations and obtain a loss term. In our specific
    # example, we have to get persistence information from data and
    # evaluate the loss based on that.
    loss = ...

    # This is what you will see in many such examples: we set all
    # gradients to zero and do a backwards pass.
    opt.zero_grad()
    loss.backward()
    opt.step()
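For readers who want to run this skeleton without any topological machinery, here is a self-contained variant. It is a sketch under stated assumptions: the function stand_in_loss is hypothetical and simply pulls every point towards the unit circle, taking the place of the persistence-based loss used in the actual example.

```python
import torch
import torch.optim as optim

torch.manual_seed(0)

# Trainable point cloud: 50 random points in the unit square.
X = torch.nn.Parameter(torch.rand(50, 2))
opt = optim.SGD([X], lr=0.05)


def stand_in_loss(points):
    """Hypothetical loss: squared deviation of point norms from 1."""
    return ((points.norm(dim=1) - 1.0) ** 2).sum()


initial_loss = stand_in_loss(X).item()

for i in range(100):
    # Evaluate the loss for the current positions of the points.
    loss = stand_in_loss(X)

    # Reset old gradients, compute new ones, and update the points.
    opt.zero_grad()
    loss.backward()
    opt.step()

final_loss = stand_in_loss(X).item()
print(f'{initial_loss:.4f} -> {final_loss:.4f}')
```

Swapping stand_in_loss for a loss computed from persistence diagrams leaves the structure of the loop unchanged.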
The rest of this example just involves some nice plotting.
Source code
Here’s the full source code of this example.
"""Demo for summary statistics minimisation of a point cloud.

This example demonstrates how to use various topological summary
statistics in order to change the shape of an input point cloud.
The script can either demonstrate how to adjust the shape of two
point clouds, i.e. using a summary statistic as a loss function,
or how to change the shape of a *single* point cloud. By default
two point clouds will be used.
"""

import argparse

import matplotlib.pyplot as plt

from torch_topological.data import sample_from_disk
from torch_topological.data import sample_from_unit_cube

from torch_topological.nn import SummaryStatisticLoss
from torch_topological.nn import VietorisRipsComplex

from tqdm import tqdm

import torch
import torch.optim as optim


def create_data_set(args):
    """Create data set based on user-provided arguments."""
    n = args.n_samples
    if args.single:
        X = sample_from_unit_cube(n=n, d=2)
        Y = X.clone()
    else:
        X = sample_from_disk(n=n, r=0.5, R=0.6)
        Y = sample_from_disk(n=n, r=0.9, R=1.0)

    # Make source point cloud adjustable by treating it as a parameter.
    # This enables topological loss functions to influence the shape of
    # `X`.
    X = torch.nn.Parameter(torch.as_tensor(X), requires_grad=True)
    return X, Y


def main(args):
    """Run example."""
    n_iterations = args.n_iterations
    statistic = args.statistic
    p = args.p
    q = args.q

    X, Y = create_data_set(args)
    vr = VietorisRipsComplex(dim=2)

    if not args.single:
        pi_target = vr(Y)

    loss_fn = SummaryStatisticLoss(
        summary_statistic=statistic,
        p=p,
        q=q
    )

    opt = optim.SGD([X], lr=0.05)
    progress = tqdm(range(n_iterations))

    for i in progress:
        pi_source = vr(X)

        if not args.single:
            loss = loss_fn(pi_source, pi_target)
        else:
            loss = loss_fn(pi_source)

        opt.zero_grad()
        loss.backward()
        opt.step()

        progress.set_postfix(loss=f'{loss.item():.08f}')

    X = X.detach().numpy()

    if args.single:
        plt.scatter(X[:, 0], X[:, 1], label='Result')
        plt.scatter(Y[:, 0], Y[:, 1], label='Initial')
    else:
        plt.scatter(X[:, 0], X[:, 1], label='Source')
        plt.scatter(Y[:, 0], Y[:, 1], label='Target')

    plt.legend()
    plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-i', '--n-iterations',
        default=250,
        type=int,
        help='Number of iterations'
    )
    parser.add_argument(
        '-n', '--n-samples',
        default=100,
        type=int,
        help='Number of samples in point clouds'
    )
    parser.add_argument(
        '-s', '--statistic',
        choices=[
            'persistent_entropy',
            'polynomial_function',
            'total_persistence',
        ],
        default='polynomial_function',
        help='Name of summary statistic to use for the loss'
    )
    parser.add_argument(
        '-S', '--single',
        action='store_true',
        help='If set, uses only a single point cloud'
    )
    parser.add_argument(
        '-p',
        type=float,
        default=2.0,
        help='Outer exponent for summary statistic loss calculation'
    )
    parser.add_argument(
        '-q',
        type=float,
        default=2.0,
        help='Inner exponent for summary statistic loss calculation. Will '
             'only be used for certain summary statistics.'
    )

    args = parser.parse_args()
    main(args)
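The command-line flags map onto the attributes that main() reads, which may not be obvious at first glance: hyphens in long options turn into underscores, and an option that only has a short form, such as -p, is stored under that letter. The sketch below rebuilds a reduced copy of the parser (three of the six options, help texts omitted) and parses a made-up argument list instead of the real command line, so it can be run on its own.

```python
import argparse

# Reduced copy of the example's parser, for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--n-iterations', default=250, type=int)
parser.add_argument('-S', '--single', action='store_true')
parser.add_argument('-p', type=float, default=2.0)

# Hypothetical invocation: 50 iterations, single point cloud, p = 1.
args = parser.parse_args(['-i', '50', '--single', '-p', '1.0'])

# '--n-iterations' is exposed as 'n_iterations', '-p' as 'p'.
print(args.n_iterations, args.single, args.p)

# Without any flags, the defaults from the parser definition apply.
defaults = parser.parse_args([])
print(defaults.n_iterations, defaults.single, defaults.p)
```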