Source code for pns.util

"""Utility functions to generate and plot sample data."""

import numpy as np

from .base import rotation_matrix

__all__ = [
    "unit_sphere",
    "circular_data",
    "circle_3d",
]


def unit_sphere():
    """Build the surface coordinates of a unit sphere for 3D plotting.

    The sphere is sampled on a 100 x 100 grid of azimuthal and polar angles.

    Returns
    -------
    x, y, z : ndarray of shape (100, 100)
        Cartesian coordinates of the sampled sphere surface.

    Examples
    --------
    >>> from pns.util import unit_sphere
    >>> import matplotlib.pyplot as plt  # doctest: +SKIP
    ... ax = plt.figure().add_subplot(projection='3d')
    ... ax.plot_surface(*unit_sphere(), color='skyblue', alpha=0.6, edgecolor='gray')
    """
    azimuth = np.linspace(0, 2 * np.pi, 100)
    polar = np.linspace(0, np.pi, 100)

    # sin(polar) scales both horizontal coordinates, so compute it once.
    sin_polar = np.sin(polar)

    xs = np.outer(np.cos(azimuth), sin_polar)
    ys = np.outer(np.sin(azimuth), sin_polar)
    # The height depends on the polar angle only; repeat it for every azimuth.
    zs = np.outer(np.ones_like(azimuth), np.cos(polar))
    return xs, ys, zs
def circular_data(v=(0, 0, 1)):
    """Generate sample data points on a 3D unit sphere.

    Points are laid out on a 10 x 10 grid of polar and azimuthal angles and
    then multiplied by ``rotation_matrix(v)``.

    Parameters
    ----------
    v : array of shape (3,), default=(0, 0, 1)
        Unit vector to center of circle.

    Returns
    -------
    ndarray of shape (100, 3)
        Data coordinates.

    Examples
    --------
    >>> from pns.util import unit_sphere, circular_data
    >>> v = [0, -1, 0]
    >>> X = circular_data(v)
    >>> X.shape
    (100, 3)
    >>> import matplotlib.pyplot as plt  # doctest: +SKIP
    ... ax = plt.figure().add_subplot(projection='3d', computed_zorder=False)
    ... ax.plot_surface(*unit_sphere(), color='skyblue', alpha=0.6, edgecolor='gray')
    ... ax.scatter(*X.T)
    """
    polar = np.linspace(0.1 * np.pi, 0.4 * np.pi, 10)
    azimuth = np.linspace(-np.pi, np.pi / 2, 10)
    polar_grid, azimuth_grid = np.meshgrid(polar, azimuth)

    sin_polar = np.sin(polar_grid)
    # Stack the Cartesian components along a leading axis: shape (3, 10, 10).
    components = np.stack(
        [
            sin_polar * np.cos(azimuth_grid),
            sin_polar * np.sin(azimuth_grid),
            np.cos(polar_grid),
        ]
    )

    # Transposing moves the component axis last so each row is one point.
    points = components.T.reshape(-1, 3)
    return points @ rotation_matrix(v)
def circle_3d(v, theta, n=100):
    """Helper function to plot a circle in 3D.

    The circle is first built around the north pole ``(0, 0, 1)`` and then
    mapped onto `v` with a Householder reflection that swaps the north pole
    and `v`.

    Parameters
    ----------
    v : (3,) array
        Unit vector to center of circle in 3D.
    theta : scalar
        Geodesic distance.
    n : int, default=100
        Number of points.

    Returns
    -------
    ndarray of shape (3, n)
        Coordinates of the circle; rows are the x, y and z components.

    Examples
    --------
    >>> from pns.util import unit_sphere, circle_3d
    >>> circle = circle_3d([0, -1, 0], 0.5)
    >>> import matplotlib.pyplot as plt  # doctest: +SKIP
    ... ax = plt.figure().add_subplot(projection='3d')
    ... ax.plot_surface(*unit_sphere(), color='skyblue', alpha=0.6, edgecolor='gray')
    ... ax.plot(*circle, color="tab:orange", zorder=10)
    """
    phi = np.linspace(0, 2 * np.pi, n)
    x = np.sin(theta) * np.cos(phi)
    y = np.sin(theta) * np.sin(phi)
    z = np.full_like(phi, np.cos(theta))
    circle = np.stack([x, y, z], axis=1)

    north_pole = np.array([0.0, 0.0, 1.0])
    u = np.asarray(v, dtype=float) - north_pole
    norm = np.linalg.norm(u)
    if norm == 0.0:
        # v is the north pole itself: the reflection axis is undefined and
        # normalizing would produce NaNs. The circle is already centered
        # on v, so return it as is.
        return circle.T
    u /= norm

    # Householder reflection mapping the north pole onto v.
    H = np.eye(3) - 2 * np.outer(u, u)
    return H @ circle.T