Point cloud optimisation with summary statistics

One interesting use case of torch_topological involves changing the shape of a point cloud using topological summary statistics. Such summary statistics can either serve as simple loss functions, providing a computationally cheap way of assessing the topological similarity between a given point cloud and a target point cloud, or highlight certain topological properties of a single point cloud.

In this example, we will consider both of these use cases.

Ingredients

Our main ingredient is the torch_topological.nn.SummaryStatisticLoss class. This class bundles different summary statistics on persistence diagrams and permits their calculation and comparison.
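
As a concrete example of such a statistic, the total persistence of a persistence diagram D sums the pth powers of all persistence values, i.e. pers_p(D) = Σ_{(b, d) ∈ D} |d − b|^p, where b and d denote the birth and death of a topological feature. Larger values indicate more prominent topological features.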

This class can operate in two modes:

  1. Calculating the loss for a single input data set.

  2. Calculating the loss difference for two input data sets.

Our example will showcase both of these modes!
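
To make these modes concrete, here is a minimal sketch that mirrors the calls from the full example below, using the default polynomial_function statistic:

from torch_topological.data import sample_from_disk

from torch_topological.nn import SummaryStatisticLoss
from torch_topological.nn import VietorisRipsComplex

X = sample_from_disk(n=100, r=0.5, R=0.6)
Y = sample_from_disk(n=100, r=0.9, R=1.0)

vr = VietorisRipsComplex(dim=2)
loss_fn = SummaryStatisticLoss(
    summary_statistic='polynomial_function',
    p=2,
    q=2
)

# Mode 1: evaluate the summary statistic of a single data set.
loss_single = loss_fn(vr(X))

# Mode 2: evaluate the difference of the summary statistic
# between two data sets.
loss_pair = loss_fn(vr(X), vr(Y))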

Optimising all the point clouds

Here’s the bulk of the code required to optimise a point cloud. We will walk through the most important parts!

def main(args):
    """Run example."""
    n_iterations = args.n_iterations
    statistic = args.statistic
    p = args.p
    q = args.q

    X, Y = create_data_set(args)
    vr = VietorisRipsComplex(dim=2)

    # The target point cloud is fixed, so its persistence
    # information only has to be computed once.
    if not args.single:
        pi_target = vr(Y)

    loss_fn = SummaryStatisticLoss(
        summary_statistic=statistic,
        p=p,
        q=q
    )

    opt = optim.SGD([X], lr=0.05)
    progress = tqdm(range(n_iterations))

    for i in progress:
        # The source point cloud changes in every iteration, so its
        # persistence information has to be recomputed.
        pi_source = vr(X)

        if not args.single:
            loss = loss_fn(pi_source, pi_target)
        else:
            loss = loss_fn(pi_source)

        opt.zero_grad()
        loss.backward()
        opt.step()

        progress.set_postfix(loss=f'{loss.item():.08f}')

    X = X.detach().numpy()

    if args.single:
        plt.scatter(X[:, 0], X[:, 1], label='Result')
        plt.scatter(Y[:, 0], Y[:, 1], label='Initial')
    else:
        plt.scatter(X[:, 0], X[:, 1], label='Source')
        plt.scatter(Y[:, 0], Y[:, 1], label='Target')

    plt.legend()
    plt.show()
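
Notice that pi_source and pi_target store the persistence information obtained by applying the Vietoris–Rips complex to the respective point cloud; these objects are what SummaryStatisticLoss consumes, either individually or as a pair.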

Besides creating some test data sets (check out torch_topological.data for more routines), the most important thing is to make sure that X, our point cloud, is a trainable parameter.
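
This conversion happens at the end of create_data_set in the full listing below; the crucial line wraps the raw tensor in a torch.nn.Parameter:

# Make source point cloud adjustable by treating it as a parameter.
# This enables topological loss functions to influence the shape of
# `X`.
X = torch.nn.Parameter(torch.as_tensor(X), requires_grad=True)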

With that out of the way, we can set up the summary statistic loss and start training. The main training loop might be familiar to those of you who already have some experience with PyTorch: it merely evaluates the loss and optimises it, following a general structure:

# Set up your favourite optimiser.
opt = optim.SGD(...)

for i in range(100):
    # Do some calculations and obtain a loss term. In our specific
    # example, we have to get persistence information from the data
    # and evaluate the loss based on that.
    loss = ...

    # This is the pattern you will see in many such examples: we reset
    # all gradients (they would otherwise accumulate across iterations),
    # perform a backward pass, and let the optimiser update the
    # parameters.
    opt.zero_grad()
    loss.backward()
    opt.step()

The rest of this example just involves some nice plotting.

Source code

Here’s the full source code of this example.

"""Demo for summary statistics minimisation of a point cloud.

This example demonstrates how to use various topological summary
statistics in order to change the shape of an input point cloud.
The script can either demonstrate how to adjust the shape of two
point clouds, i.e. using a summary statistic as a loss function,
or how to change the shape of a *single* point cloud. By default
two point clouds will be used.
"""

import argparse

import matplotlib.pyplot as plt

from torch_topological.data import sample_from_disk
from torch_topological.data import sample_from_unit_cube

from torch_topological.nn import SummaryStatisticLoss
from torch_topological.nn import VietorisRipsComplex

from tqdm import tqdm

import torch
import torch.optim as optim


def create_data_set(args):
    """Create data set based on user-provided arguments."""
    n = args.n_samples
    if args.single:
        X = sample_from_unit_cube(n=n, d=2)
        Y = X.clone()
    else:
        X = sample_from_disk(n=n, r=0.5, R=0.6)
        Y = sample_from_disk(n=n, r=0.9, R=1.0)

    # Make source point cloud adjustable by treating it as a parameter.
    # This enables topological loss functions to influence the shape of
    # `X`.
    X = torch.nn.Parameter(torch.as_tensor(X), requires_grad=True)
    return X, Y


def main(args):
    """Run example."""
    n_iterations = args.n_iterations
    statistic = args.statistic
    p = args.p
    q = args.q

    X, Y = create_data_set(args)
    vr = VietorisRipsComplex(dim=2)

    # The target point cloud is fixed, so its persistence
    # information only has to be computed once.
    if not args.single:
        pi_target = vr(Y)

    loss_fn = SummaryStatisticLoss(
        summary_statistic=statistic,
        p=p,
        q=q
    )

    opt = optim.SGD([X], lr=0.05)
    progress = tqdm(range(n_iterations))

    for i in progress:
        # The source point cloud changes in every iteration, so its
        # persistence information has to be recomputed.
        pi_source = vr(X)

        if not args.single:
            loss = loss_fn(pi_source, pi_target)
        else:
            loss = loss_fn(pi_source)

        opt.zero_grad()
        loss.backward()
        opt.step()

        progress.set_postfix(loss=f'{loss.item():.08f}')

    X = X.detach().numpy()

    if args.single:
        plt.scatter(X[:, 0], X[:, 1], label='Result')
        plt.scatter(Y[:, 0], Y[:, 1], label='Initial')
    else:
        plt.scatter(X[:, 0], X[:, 1], label='Source')
        plt.scatter(Y[:, 0], Y[:, 1], label='Target')

    plt.legend()
    plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '-i', '--n-iterations',
        default=250,
        type=int,
        help='Number of iterations'
    )

    parser.add_argument(
        '-n', '--n-samples',
        default=100,
        type=int,
        help='Number of samples in point clouds'
    )

    parser.add_argument(
        '-s', '--statistic',
        choices=[
            'persistent_entropy',
            'polynomial_function',
            'total_persistence',
        ],
        default='polynomial_function',
        help='Name of summary statistic to use for the loss'
    )

    parser.add_argument(
        '-S', '--single',
        action='store_true',
        help='If set, uses only a single point cloud'
    )

    parser.add_argument(
        '-p',
        type=float,
        default=2.0,
        help='Outer exponent for summary statistic loss calculation'
    )

    parser.add_argument(
        '-q',
        type=float,
        default=2.0,
        help='Inner exponent for summary statistic loss calculation. Will '
             'only be used for certain summary statistics.'
    )

    args = parser.parse_args()
    main(args)
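
To reproduce both scenarios, run the script without extra arguments for the default two-point-cloud setting, or pass -S/--single to optimise a single point cloud. The -s/--statistic option selects the summary statistic, while -p and -q control its outer and inner exponents.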