Source code for gpyconform.prediction_intervals

#!/usr/bin/env python3

import torch
import warnings

from gpyconform.constants import _CONF_DP, _CONF_SCALE

[docs] class PredictionIntervals: """ Contains the conformal Prediction Intervals (PIs) at one or more confidence levels and provides functionality for their retrieval and evaluation. Notes ----- Confidence levels are internally stored using fixed-point integer keys to avoid float equality issues when retrieving intervals for a requested level. """ def __init__(self, conf_levels: torch.Tensor, all_pis: torch.Tensor): self.all_pis = all_pis # Fixed-point integer keys: exact, to avoid float-equality issues conf_scaled = torch.round(conf_levels * _CONF_SCALE) self._conf_keys = conf_scaled.to(torch.int64) # One-time index map: int_key -> row index in all_pis self._index = {int(k.item()): i for i, k in enumerate(self._conf_keys)}
[docs] def __call__(self, conf_level: float | None = None, *, y_min = float('-inf'), y_max = float('inf'), dp: int = 6): """ Returns the Prediction Intervals for a specified confidence level or all intervals if confidence level is not specified. Parameters ---------- conf_level : float in range (0,1), optional Confidence level for which to return the corresponding Prediction Intervals. If not specified, the Prediction Intervals for all confidence levels will be returned. y_min : float, keyword-only, default=-inf If provided, PIs are cut to exclude values below y_min. y_max : float, keyword-only, default=inf If provided, PIs are cut to exclude values above y_max. dp : int in range [1,6], keyword-only, default=6 Number of decimals to show in the string keys when returning all levels. Returns ------- torch.Tensor or dict[str, torch.Tensor] A torch tensor with the Prediction Intervals for the specified ``conf_level``, or a dictionary with confidence levels as keys (str) and the corresponding Prediction Interval tensors as values if ``conf_level`` is None. Examples -------- Assuming ``PIs`` is an instance of ``PredictionIntervals`` that includes the 95% confidence level. To retrieve the Prediction Intervals at the 95% confidence level as a tensor: >>> intervals = PIs(0.95) >>> print(intervals) To retrieve the Prediction Intervals for all confidence levels as a dictionary: >>> all_intervals = PIs() >>> print(all_intervals) """ if not isinstance(dp, int): raise TypeError(f"dp must be an int, got {type(dp).__name__}.") if not (1 <= dp <= _CONF_DP): raise ValueError(f"dp must be an integer in [1, {_CONF_DP}], got {dp}.") if conf_level is None: # Create a dictionary of all prediction intervals (string keys -> interval tensors) out = {} for i, cl in enumerate(self._conf_keys): cl_str = f"{(cl.item() / _CONF_SCALE):.{dp}f}" cl_pis = self.all_pis[i].clone() cl_pis[:, 0] = cl_pis[:, 0].clamp(min=y_min) cl_pis[:, 1] = cl_pis[:, 1].clamp(max=y_max) out[cl_str] = cl_pis return out cl = int(round(float(conf_level) * _CONF_SCALE)) idx = self._index.get(cl) if idx is None: available = ", ".join( self._format_level_key(cl) for cl in self._index.keys() ) raise ValueError( f"Confidence level {conf_level} not found. Available levels are: [{available}]" ) out = self.all_pis[idx].clone() out[:, 0] = out[:, 0].clamp(min=y_min) out[:, 1] = out[:, 1].clamp(max=y_max) return out
[docs] def evaluate(self, conf_level, metrics=None, y=None, *, y_min = float('-inf'), y_max = float('inf')): """ Evaluates the Prediction Intervals at a specified confidence level. Parameters ---------- conf_level : float in range (0,1) Confidence level of the Prediction Intervals to be evaluated. metrics : list of str or str, optional, default=['mean_width', 'median_width', 'error'] Metrics to calculate. Possible options: - 'mean_width': Average width of the Prediction Intervals. - 'median_width': Median width of the Prediction Intervals. - 'error': Percentage of Prediction Intervals that do not contain the true target value. y : torch.Tensor of shape (n_test,), optional, default=None True target values, required for calculating the 'error' metric. If not provided, 'error' is not calculated. y_min : float, keyword-only, default=-inf If provided, PIs are evaluated after cutting values below y_min. y_max : float, keyword-only, default=inf If provided, PIs are evaluated after cutting values above y_max. Returns ------- results : dict A dictionary with a key for each metric in ``metrics`` and the calculated result as its value. For example: {'mean_width': 3.852, 'error': 0.049}. Examples -------- Assuming ``PIs`` is an instance of ``PredictionIntervals`` that includes the 99% confidence level, and ``test_y`` is a tensor with the true targets. To evaluate the Prediction Intervals at the 99% confidence level using all available metrics (which is the default): >>> results = PIs.evaluate(0.99, y=test_y) To evaluate only the mean width of the Prediction Intervals at the 99% confidence level: >>> results = PIs.evaluate(0.99, metrics='mean_width') """ if metrics is None: metrics = ['mean_width', 'median_width', 'error'] if isinstance(metrics, str): metrics = [metrics] cl_pis = self(conf_level, y_min=y_min, y_max=y_max) results = {} # Check if any metrics require pi_widths before calculating need_widths = any(m in ('mean_width', 'median_width') for m in metrics) if need_widths: pi_widths = cl_pis[:, 1] - cl_pis[:, 0] for name in metrics: if name == 'error': if y is None: warnings.warn( "True labels 'y' not provided for error calculation - skipping 'error' metric.", RuntimeWarning ) else: y = torch.as_tensor(y, device=cl_pis.device, dtype=cl_pis.dtype) errors = (y < cl_pis[:, 0]) | (y > cl_pis[:, 1]) results['error'] = errors.to(torch.float64).mean().item() elif name == 'mean_width': results['mean_width'] = pi_widths.mean().item() elif name == 'median_width': results['median_width'] = pi_widths.median().item() else: warnings.warn(f"'{name}' is not a recognized metric.", RuntimeWarning) return results
def _decimal_places_for_key(self, int_key: int) -> int: tz = 0 tmp = abs(int_key) # count trailing zeros, up to the stored precision while tz < _CONF_DP-2 and (tmp % 10) == 0: tz += 1 tmp //= 10 return _CONF_DP - tz def _format_level_key(self, int_key: int) -> str: dp = self._decimal_places_for_key(int_key) return f"{int_key / _CONF_SCALE:.{dp}f}"