Source code for pns.pss

"""Detect the principal subsphere by optimization."""

import warnings

import numpy as np
from scipy.optimize import least_squares

from .base import circle_mean, exp_map, log_map, rotation_matrix

__all__ = [
    "pss",
]


def pss(x, tol=1e-3, maxiter=None, lm_kwargs=None):
    r"""Find the principal subsphere (PSS) from data on a hypersphere.

    Parameters
    ----------
    x : (N, d+1) real array_like
        Extrinsic coordinates of data on a ``d``-dimensional hypersphere,
        embedded in a ``d+1``-dimensional space.
    tol : float, default=1e-3
        Convergence tolerance in radian.
    maxiter : int, optional
        Maximum number of iterations for the optimization.
        If None, the number of iterations is not checked.
    lm_kwargs : dict, optional
        Additional keyword arguments to be passed for Levenberg-Marquardt
        optimization. Follows the signature of
        :func:`scipy.optimize.least_squares`. The ``method`` and ``args``
        entries are ignored, and the passed dictionary is not modified.

    Returns
    -------
    v : (d+1,) real array
        Estimated principal axis of the subsphere in extrinsic coordinates.
    r : scalar in [0, pi]
        Geodesic distance from the pole by *v* to the estimated principal
        subsphere.

    Raises
    ------
    ValueError
        If the second dimension of *x* is 1 or smaller.

    Warns
    -----
    UserWarning
        If *maxiter* iterations were performed without reaching *tol*.

    Notes
    -----
    This function determines the best fitting subsphere
    :math:`\hat{A}_{d-k} = A_{d-k}(\hat{v}_k, \hat{r}_k) \subset S^{d-k+1}`
    for :math:`k = 1, 2, \ldots, d`.
    The Fréchet mean :math:`\hat{A}_0` of the lowest level best fitting
    subsphere :math:`\hat{A}_1` is also determined by this function.

    Examples
    --------
    >>> from pns.pss import pss
    >>> from pns.util import unit_sphere, circular_data, circle_3d
    >>> x = circular_data([0, -1, 0])
    >>> v, r = pss(x)
    >>> import matplotlib.pyplot as plt  # doctest: +SKIP
    ... ax = plt.figure().add_subplot(projection='3d', computed_zorder=False)
    ... ax.plot_surface(*unit_sphere(), color='skyblue', alpha=0.6, edgecolor='gray')
    ... ax.scatter(*x.T, marker="x")
    ... ax.plot(*circle_3d(v, r), color="tab:orange", zorder=10)
    """
    if lm_kwargs is None:
        lm_kwargs = {}
    else:
        # Copy first so the caller's dictionary is left untouched, then drop
        # the keys that the internal ``least_squares`` call sets itself.
        lm_kwargs = lm_kwargs.copy()
        lm_kwargs.pop("method", None)
        lm_kwargs.pop("args", None)

    # Accept any array-like (e.g. nested lists); a no-op for ndarrays.
    x = np.asarray(x)
    _, D = x.shape

    if D <= 1:
        raise ValueError("Data must be on at least 1-sphere.")
    elif D == 2:
        # Circle mean
        r = 0
        v = circle_mean(x)
    else:
        pole = np.array([0] * (D - 1) + [1])
        R = np.eye(D)  # accumulates the rotations applied so far
        _x = x
        v, r = _pss(_x, lm_kwargs=lm_kwargs)
        iter_count = 0
        # Clip the cosine: round-off can push it marginally outside [-1, 1],
        # where arccos returns NaN and the comparison would silently be False.
        while np.arccos(np.clip(np.dot(pole, v), -1.0, 1.0)) > tol:
            if iter_count == maxiter:
                warnings.warn(
                    f"Maximum number of iterations ({maxiter}) reached. "
                    "Optimization may not have converged.",
                    UserWarning,
                    stacklevel=2,
                )
                break
            # Rotate so that v becomes the pole
            _x, _R = _rotate(_x, v)
            v, r = _pss(_x, lm_kwargs=lm_kwargs)
            R = R @ _R.T
            iter_count += 1
        v = R @ v  # re-rotate back
    return v, r
def _rotate(pts, v):
    """Apply the rotation built from *v* to every row of *pts*.

    Returns the rotated points together with the rotation matrix obtained
    from ``rotation_matrix(v)``, so that the caller can undo the rotation.
    """
    rot = rotation_matrix(v)
    rotated = pts @ rot.T
    return rotated, rot


def _pss(pts, lm_kwargs):
    """Run one least-squares fit of the subsphere axis and radius.

    The points are mapped with ``log_map``, a center and a radius are fitted
    to the mapped points by Levenberg-Marquardt, and the fitted center is
    mapped back with ``exp_map``. *lm_kwargs* is forwarded to
    :func:`scipy.optimize.least_squares`.

    Returns the axis as a 1-D array and the radius reduced modulo pi.
    """
    # Projection
    mapped = log_map(pts)

    # Initial guess: centroid of the mapped points and their mean distance
    # from that centroid.
    center_guess = np.mean(mapped, axis=0)
    radius_guess = np.mean(np.linalg.norm(mapped - center_guess, axis=1))
    start = np.append(center_guess, radius_guess)

    # Optimization
    result = least_squares(
        _res, start, _jac, method="lm", args=(mapped,), **lm_kwargs
    )
    solution = result.x
    center_fit = solution[:-1]
    radius_fit = solution[-1]

    axis = exp_map(center_fit[np.newaxis, :]).reshape(-1)
    return axis, radius_fit % np.pi


def _res(params, x_dag):
    """Residuals: distance of each row of *x_dag* from the center, minus radius.

    *params* holds the center coordinates followed by the radius.
    """
    center = params[:-1]
    radius = params[-1]
    distances = np.linalg.norm(x_dag - center[np.newaxis, :], axis=1)
    return distances - radius


def _jac(params, x_dag):
    """Jacobian of :func:`_res` with respect to *params*.

    The result has one row per point and one column per parameter. Rows whose
    point lies within ``1e-12`` of the center get a zero gradient for the
    center coordinates, which avoids dividing by a vanishing distance.
    """
    offsets = x_dag - params[:-1].reshape(1, -1)
    norms = np.linalg.norm(offsets, axis=1)
    regular = norms > 1e-12

    jac = np.zeros((len(x_dag), len(params)))
    jac[regular, :-1] = -offsets[regular] / norms[regular, np.newaxis]
    # The derivative with respect to the radius is constant.
    jac[:, -1] = -1.0
    return jac