From 7dcbae6b28f73af2f1ca8bd7d9a0e5e34aa9bc76 Mon Sep 17 00:00:00 2001 From: Merlin Fisher-Levine Date: Wed, 3 Jul 2024 06:24:13 -0700 Subject: [PATCH] Refactor PSF plotting code and add docs and types --- .../summit/extras/plotting/psfPlotting.py | 288 +++++++++++++++++- 1 file changed, 274 insertions(+), 14 deletions(-) diff --git a/python/lsst/summit/extras/plotting/psfPlotting.py b/python/lsst/summit/extras/plotting/psfPlotting.py index 63f2fbd..bed8abf 100644 --- a/python/lsst/summit/extras/plotting/psfPlotting.py +++ b/python/lsst/summit/extras/plotting/psfPlotting.py @@ -19,26 +19,75 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from __future__ import annotations + __all__ = [ "addColorbarToAxes", + "makeTableFromSourceCatalogs", + "makeFigureAndAxes", "extendTable", "makeFocalPlanePlot", - "makeAzElPlot", "makeEquatorialPlot", - "makeFigureAndAxes", + "makeAzElPlot", ] +from typing import TYPE_CHECKING + import matplotlib.pyplot as plt import numpy as np +from astropy.table import vstack from mpl_toolkits.axes_grid1 import make_axes_locatable from lsst.afw.cameraGeom import FOCAL_PLANE from lsst.afw.geom.ellipses import Quadrupole -from lsst.geom import LinearTransform +from lsst.geom import LinearTransform, radians + +if TYPE_CHECKING: + import numpy.typing as npt + from astropy.table import Table # noqa: F401 + + from lsst.afw.cameraGeom import Camera # noqa: F401 + from lsst.afw.image import VisitInfo # noqa: F401 + from lsst.afw.table import SourceCatalog # noqa: F401 + + +def selectPoints(table: Table, maxPoints: int) -> Table: + """Select a random subset of points from the given table. + + Parameters + ---------- + table : `astropy.table.Table` + The table containing the data to be plotted. + maxPoints : `int` + The maximum number of points to plot. + + Returns + ------- + table : `astropy.table.Table` + The table containing the randomly selected subset of points. + """ + n = len(table) + if n > maxPoints: + rng = np.random.default_rng() + indices = rng.choice(n, maxPoints, replace=False) + table = table[indices] + return table + + +def addColorbarToAxes(mappable: plt.cm.ScalarMappable): + """Add a colorbar to the given axes. + Parameters + ---------- + mappable : `matplotlib.cm.ScalarMappable` + The mappable object to which the colorbar will be added. -def addColorbarToAxes(mappable): + Returns + ------- + cbar : `matplotlib.colorbar.Colorbar` + The colorbar object that was added to the axes. + """ ax = mappable.axes fig = ax.figure divider = make_axes_locatable(ax) @@ -47,7 +96,83 @@ def addColorbarToAxes(mappable): return cbar -def extendTable(table, rot, prefix): +def makeTableFromSourceCatalogs(icSrcs: dict[int, SourceCatalog], visitInfo: VisitInfo) -> Table: + """Extract the shapes from the source catalogs into an astropy table. + + The shapes of the PSF candidates are extracted from the source catalogs and + transformed into the required coordinate systems for plotting either focal + plane coordinates, az/el coordinates, or equatorial coordinates. + + Parameters + ---------- + icSrcs : `dict` [`int`, `lsst.afw.table.SourceCatalog`] + A dictionary of source catalogs, keyed by the detector numbers. + visitInfo : `lsst.afw.image.VisitInfo` + The visit information for a representative visit. + + Returns + ------- + table : `astropy.table.Table` + The table containing the data from the source catalogs. + """ + tables = [] + + for detectorNum, icSrc in icSrcs.items(): + icSrc = icSrc.asAstropy() + icSrc = icSrc[icSrc["calib_psf_candidate"]] + icSrc["detector"] = detectorNum + tables.append(icSrc) + + table = vstack(tables) + # Add shape columns + table["Ixx"] = table["slot_Shape_xx"] * (0.2) ** 2 + table["Ixy"] = table["slot_Shape_xy"] * (0.2) ** 2 + table["Iyy"] = table["slot_Shape_yy"] * (0.2) ** 2 + table["T"] = table["Ixx"] + table["Iyy"] + table["e1"] = (table["Ixx"] - table["Iyy"]) / table["T"] + table["e2"] = 2 * table["Ixy"] / table["T"] + table["e"] = np.hypot(table["e1"], table["e2"]) + table["x"] = table["base_FPPosition_x"] + table["y"] = table["base_FPPosition_y"] + + table.meta["rotTelPos"] = ( + visitInfo.boresightParAngle - visitInfo.boresightRotAngle - (np.pi / 2 * radians) + ).asRadians() + table.meta["rotSkyPos"] = visitInfo.boresightRotAngle.asRadians() + + rtp = table.meta["rotTelPos"] + srtp, crtp = np.sin(rtp), np.cos(rtp) + aaRot = np.array([[crtp, srtp], [-srtp, crtp]]) @ np.array([[0, 1], [1, 0]]) @ np.array([[-1, 0], [0, 1]]) + table = extendTable(table, aaRot, "aa") + table.meta["aaRot"] = aaRot + + rsp = table.meta["rotSkyPos"] + srsp, crsp = np.sin(rsp), np.cos(rsp) + nwRot = np.array([[crsp, -srsp], [srsp, crsp]]) + table = extendTable(table, nwRot, "nw") + table.meta["nwRot"] = nwRot + + return table + + +def extendTable(table: Table, rot: npt.NDArray[np.float_], prefix: str) -> Table: + """Extend the given table with additional columns for the rotated shapes. + + Parameters + ---------- + table : `astropy.table.Table` + The input table containing the original shapes. + rot : `np.ndarray` + The rotation matrix used to rotate the shapes. + prefix : `str` + The prefix to be added to the column names of the rotated shapes. + + Returns + ------- + table : `astropy.table.Table` + The extended table with additional columns representing the rotated + shapes. + """ transform = LinearTransform(rot) rot_shapes = [] for row in table: @@ -64,12 +189,65 @@ def extendTable(table, rot, prefix): return table -def makeFigureAndAxes(): +def makeFigureAndAxes() -> tuple[plt.Figure, np.ndarray[plt.Axes]]: + """Create a figure and axes for plotting. + + Returns + ------- + fig : `matplotlib.figure.Figure`: + The created figure. + axes : `numpy.ndarray` + The created axes. + """ fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 6)) return fig, axes -def makeFocalPlanePlot(fig, axes, table, camera, saveAs=""): +def makeFocalPlanePlot( + fig: plt.Figure, + axes: np.ndarray[plt.Axes], + table: Table, + camera: Camera, + maxPoints: int = 1000, + saveAs: str = "", +): + """Plot the PSFs in focal plane (detector) coordinates i.e. the raw shapes. + + Top left: + A scatter plot of the T values in square arcseconds. + Top right: + A quiver plot of e1 and e2 + Bottom left: + A scatter plot of e1 + Bottom right: + A scatter plot of e2 + + This function plots the data from the `table` on the provided `fig` and + `axes` objects. It also plots the camera detector outlines on the focal + plane plot, respecting the camera rotation for the exposure. + + If `saveAs` is provided, the figure will be saved at the specified file + path. + + Parameters + ---------- + fig : `matplotlib.figure.Figure` + The figure object to plot on. + axes : `numpy.ndarray` + The array of axes objects to plot on. + table : `numpy.ndarray` + The table containing the data to be plotted. + camera : `list` + The list of camera detector objects. + maxPoints : `int`, optional + The maximum number of points to plot. If the number of points in the + table is greater than this value, a random subset of points will be + plotted. + saveAs : `str`, optional + The file path to save the figure. + """ + table = selectPoints(table, maxPoints) + cbar = addColorbarToAxes(axes[0, 0].scatter(table["x"], table["y"], c=table["T"], s=5)) cbar.set_label("T [arcsec$^2$]") @@ -120,7 +298,51 @@ def makeFocalPlanePlot(fig, axes, table, camera, saveAs=""): fig.savefig(saveAs) -def makeEquatorialPlot(fig, axes, table, camera, saveAs=""): +def makeEquatorialPlot( + fig: plt.Figure, + axes: np.ndarray[plt.Axes], + table: Table, + camera: Camera, + maxPoints: int = 1000, + saveAs: str = "", +): + """Plot the PSFs on the focal plane, rotated to equatorial coordinates. + + Top left: + A scatter plot of the T values in square arcseconds. + Top right: + A quiver plot of e1 and e2 + Bottom left: + A scatter plot of e1 + Bottom right: + A scatter plot of e2 + + This function plots the data from the `table` on the provided `fig` and + `axes` objects. It also plots the camera detector outlines on the focal + plane plot, respecting the camera rotation for the exposure. + + If `saveAs` is provided, the figure will be saved at the specified file + path. + + Parameters + ---------- + fig : `matplotlib.figure.Figure` + The figure object to plot on. + axes : `numpy.ndarray` + The array of axes objects to plot on. + table : `numpy.ndarray` + The table containing the data to be plotted. + camera : `list` + The list of camera detector objects. + maxPoints : `int`, optional + The maximum number of points to plot. If the number of points in the + table is greater than this value, a random subset of points will be + plotted. + saveAs : `str`, optional + The file path to save the figure. + """ + table = selectPoints(table, maxPoints) + cbar = addColorbarToAxes(axes[0, 0].scatter(table["nw_x"], table["nw_y"], c=table["T"], s=5)) cbar.set_label("T [arcsec$^2$]") @@ -180,12 +402,50 @@ def makeEquatorialPlot(fig, axes, table, camera, saveAs=""): fig.savefig(saveAs) -def makeAzElPlot(fig, axes, table, camera, maxPoints=1000, saveAs=""): - n = len(table) - if n > maxPoints: - rng = np.random.default_rng() - indices = rng.choice(n, maxPoints, replace=False) - table = table[indices] +def makeAzElPlot( + fig: plt.Figure, + axes: np.ndarray[plt.Axes], + table: Table, + camera: Camera, + maxPoints: int = 1000, + saveAs: str = "", +): + """Plot the PSFs on the focal plane, rotated to az/el coordinates. + + Top left: + A scatter plot of the T values in square arcseconds. + Top right: + A quiver plot of e1 and e2 + Bottom left: + A scatter plot of e1 + Bottom right: + A scatter plot of e2 + + This function plots the data from the `table` on the provided `fig` and + `axes` objects. It also plots the camera detector outlines on the focal + plane plot, respecting the camera rotation for the exposure. + + If `saveAs` is provided, the figure will be saved at the specified file + path. + + Parameters + ---------- + fig : `matplotlib.figure.Figure` + The figure object to plot on. + axes : `numpy.ndarray` + The array of axes objects to plot on. + table : `numpy.ndarray` + The table containing the data to be plotted. + camera : `list` + The list of camera detector objects. + maxPoints : `int`, optional + The maximum number of points to plot. If the number of points in the + table is greater than this value, a random subset of points will be + plotted. + saveAs : `str`, optional + The file path to save the figure. + """ + table = selectPoints(table, maxPoints) cbar = addColorbarToAxes(axes[0, 0].scatter(table["aa_x"], table["aa_y"], c=table["T"], s=5)) cbar.set_label("T [arcsec$^2$]")