|
| 1 | +# This file is part of analysis_tools. |
| 2 | +# |
| 3 | +# Developed for the LSST Data Management System. |
| 4 | +# This product includes software developed by the LSST Project |
| 5 | +# (https://www.lsst.org). |
| 6 | +# See the COPYRIGHT file at the top-level directory of this distribution |
| 7 | +# for details of code ownership. |
| 8 | +# |
| 9 | +# This program is free software: you can redistribute it and/or modify |
| 10 | +# it under the terms of the GNU General Public License as published by |
| 11 | +# the Free Software Foundation, either version 3 of the License, or |
| 12 | +# (at your option) any later version. |
| 13 | +# |
| 14 | +# This program is distributed in the hope that it will be useful, |
| 15 | +# but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 16 | +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 17 | +# GNU General Public License for more details. |
| 18 | +# |
| 19 | +# You should have received a copy of the GNU General Public License |
| 20 | +# along with this program. If not, see <https://www.gnu.org/licenses/>. |
| 21 | + |
| 22 | + |
| 23 | +from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector |
| 24 | +from astropy.table import vstack |
| 25 | +from matplotlib.figure import Figure |
| 26 | +import matplotlib.pyplot as plt |
| 27 | +import numpy as np |
| 28 | +from .plotUtils import addPlotInfo |
| 29 | +from typing import Mapping |
| 30 | + |
| 31 | +__all__ = ("PercentilePlot",) |
| 32 | + |
| 33 | + |
class PercentilePlot(PlotAction):
    """Plot the normalized percentile distribution of each amplifier.

    One panel is drawn per amplifier (16 panels in an 8x2 grid). Each panel
    shows the amplifier's percentile values, every one divided by the median
    of that percentile over all plotted amplifiers, as a function of
    percentile.
    """

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys, and their types, that this action reads.

        Returns
        -------
        schema : `KeyedDataSchema`
            ``(key, type)`` pairs for the ``amplifier`` and ``detector``
            columns and for one ``percentile_<N>`` column per plotted
            percentile; all are requested as `Vector`.
        """
        base: list[tuple[str, type[Vector] | ScalarType]] = [
            ("amplifier", Vector),
            ("detector", Vector),
        ]
        base.extend((f"percentile_{pct}", Vector) for pct in (0, 5, 16, 50, 84, 95, 100))
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot.

        ``kwargs`` are forwarded unchanged to both `_validateInput` and
        `makePlot`; `makePlot` requires a ``plotInfo`` entry among them.
        """
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every required key with a usable type.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            Raised if a key from the formatted input schema is missing from
            ``data``, or if a value is a `Scalar` where the schema does not
            ask for one.
        """
        needed = self.getFormattedInputSchema(**kwargs)
        requiredKeys = {key.format(**kwargs) for key, _ in needed}
        availableKeys = {key.format(**kwargs) for key in data.keys()}
        if remainder := requiredKeys - availableKeys:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            colType = type(data[name.format(**kwargs)])
            if issubclass(colType, Scalar) and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(self, data, plotInfo, **kwargs):
        """Makes a plot showing the percentiles of the normalized distribution
        of the data.

        Parameters
        ----------
        data : `KeyedData`
            All the data
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
            ``camera``
                The camera used to take the data (`lsst.afw.cameraGeom.Camera`)
            ``"cameraName"``
                The name of camera used to take the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"ccdKey"``
                The ccd/detector key associated with this camera (`str`).
            ``"visit"``
                The visit of the data; only included if the data is from a
                single epoch dataset (`str`).
            ``"patch"``
                The patch that the data is from; only included if the data is
                from a coadd dataset (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
            ``"photoCalibDataset"``
                The dataset used for the calibration, e.g. "jointcal" or "fgcm"
                (`str`).
            ``"skyWcsDataset"``
                The sky Wcs dataset used (`str`).
            ``"rerun"``
                The rerun the data is stored in (`str`).

        Returns
        -------
        ``fig``
            The figure to be saved (`matplotlib.figure.Figure`).

        Notes
        -----
        Makes a plot showing the normalized percentile distribution of data.
        Only detector 0 is plotted. Points whose normalized value falls
        outside the range [0.1, 10] are drawn in red.
        """
        # Panel order: the axes grid below is filled row by row, so each
        # consecutive pair of names shares one row of the 8x2 grid.
        amplifiers = [
            "C17",
            "C07",
            "C16",
            "C06",
            "C15",
            "C05",
            "C14",
            "C04",
            "C13",
            "C03",
            "C12",
            "C02",
            "C11",
            "C01",
            "C10",
            "C00",
        ]
        # TODO: generalize to make N per-detector plots
        onDetector = data["detector"] == 0
        # Keep the first matching row for each amplifier, stacked in the order
        # listed above. NOTE(review): this relies on ``data`` supporting
        # boolean-mask row selection (table-like); an amplifier with no row on
        # this detector makes the ``[0]`` lookup fail -- confirm that inputs
        # always contain all 16 amplifiers.
        data = vstack([data[onDetector & (data["amplifier"] == amp)][0] for amp in amplifiers])

        percentiles = ["0", "5", "16", "50", "84", "95", "100"]
        pcts = [int(pct) for pct in percentiles]
        distributions = [data[f"percentile_{pct}"] for pct in percentiles]
        # Normalize each percentile column by its median over the plotted
        # amplifiers. A median of zero would produce non-finite values here.
        medians = [np.nanmedian(dist) for dist in distributions]
        normalizedDistributions = [np.abs(dist / med) for med, dist in zip(medians, distributions)]

        fig, axs = plt.subplots(nrows=8, ncols=2, sharex=True, sharey=True)
        # Set threshold for a bad normalized bias.
        lowThreshold, highThreshold = 0.1, 10
        for i, ax in enumerate(axs.flat):
            # Row ``i`` of every normalized column: one amplifier's curve.
            distribution = np.array([dist[i] for dist in normalizedDistributions])
            outOfRange = (distribution < lowThreshold) | (distribution > highThreshold)
            colors = np.where(outOfRange, "r", "C0")
            ax.hlines(1.0, xmin=pcts[0], xmax=pcts[-1], colors="k", linestyle="--")
            ax.scatter(pcts, distribution, c=colors)
            ax.plot(pcts, distribution)
            ax.set_ylabel(data["amplifier"][i])
            ax.set_yscale("log")

        # Set the ticks on the last axes explicitly instead of going through
        # pyplot's "current axes"; the x axis is shared by every panel.
        lastAx = axs[-1, -1]
        lastAx.set_xticks(pcts)
        lastAx.set_xticklabels(percentiles)
        fig.supxlabel("Percentile")
        fig.supylabel("Normalized distribution")
        # Remove the vertical gaps so that the stacked panels touch.
        fig.subplots_adjust(hspace=0)

        # Add useful information to the plot
        addPlotInfo(fig, plotInfo)
        return fig
0 commit comments