Source code for calibrated_explanations.core.wrap_explainer

"""High-level wrapper for building, calibrating and explaining models.

This module provides :class:`WrapCalibratedExplainer`, a convenience wrapper
that mirrors :class:`.CalibratedExplainer` while exposing a scikit-learn
style fit/calibrate/explain surface for downstream users and integrations.
"""

# pylint: disable=unknown-option-value
# pylint: disable=invalid-name, line-too-long, too-many-lines, too-many-positional-arguments, too-many-public-methods
from __future__ import annotations

import base64
import hashlib
import json
import logging as _logging
import os
import pickle  # nosec B403 - deserialization is restricted to trusted, checksum-validated state
import shutil
import sys
import tempfile
import warnings as _warnings
from contextlib import suppress
from datetime import datetime, timezone
from pathlib import Path
from time import sleep
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Callable, Dict, Mapping

from crepes.extras import MondrianCategorizer

from ..api.params import (
    reject_removed_aliases,
    validate_param_combination,
)
from ..utils import check_is_fitted, deprecate, safe_isinstance  # noqa: F401
from ..utils.exceptions import (
    DataShapeError,
    IncompatibleStateError,
    NotFittedError,
    ValidationError,
)
from .calibrated_explainer import CalibratedExplainer  # circular during split
from .validation import validate_inputs_matrix, validate_model

if TYPE_CHECKING:  # pragma: no cover - import only for type checking
    from calibrated_explanations.api.config import ExplainerConfig


class WrapCalibratedExplainer:
    """Wrap a learner behind a scikit-learn style fit/calibrate/explain API.

    The wrapper follows the surface of :class:`CalibratedExplainer` and
    coordinates the fitting, calibration and explanation steps for the
    wrapped learner.

    Attributes
    ----------
    learner : Any
        The underlying predictive learner instance.
    explainer : CalibratedExplainer | None
        The calibrated explainer created during :meth:`calibrate`.
    calibrated : bool
        True when the wrapper has been calibrated.
    """

    learner: Any
    explainer: CalibratedExplainer | None
    calibrated: bool
    mc: Callable[[Any], Any] | MondrianCategorizer | None
    _logger: _logging.Logger
    _STATE_SCHEMA_VERSION: int = 1

    def __init__(self, learner: Any):
        """Set up the wrapper around ``learner``.

        Parameters
        ----------
        learner : predictive learner
            Either a learner used to predict the target variable, or an
            existing ``CalibratedExplainer``; in the latter case its learner is
            adopted and the wrapper starts out fitted and calibrated.
        """
        self.mc: Callable[[Any], Any] | MondrianCategorizer | None = None
        self._logger: _logging.Logger = _logging.getLogger(__name__)

        # Preprocessing hooks start out unset.
        self._preprocessor: Any | None = None
        self._pre_fitted: bool = False
        self._auto_encode: bool | str = "auto"
        self._unseen_category_policy: str = "error"

        if safe_isinstance(learner, "calibrated_explanations.core.CalibratedExplainer"):
            # Adopt the explainer and the learner it already wraps.
            self.learner: Any = learner.learner
            check_is_fitted(self.learner)
            self.fitted: bool = True
            self.explainer: CalibratedExplainer | None = learner
            self.calibrated: bool = True
            self._logger.info(
                "Initialized from existing CalibratedExplainer (already fitted & calibrated)"
            )
        else:
            self.learner = learner
            self.explainer = None
            self.calibrated = False
            # Probe the learner: it counts as fitted only if the check passes.
            self.fitted = False
            with suppress(TypeError, RuntimeError, NotFittedError):
                check_is_fitted(learner)
                self.fitted = True

    def __repr__(self) -> str:
        """Return the string representation of the WrapCalibratedExplainer."""
        if not self.fitted:
            return (
                f"WrapCalibratedExplainer(learner={self.learner}, fitted=False, calibrated=False)"
            )
        if not self.calibrated:
            return f"WrapCalibratedExplainer(learner={self.learner}, fitted=True, calibrated=False)"
        return (
            f"WrapCalibratedExplainer(learner={self.learner}, fitted=True, "
            f"calibrated=True, \n\t\texplainer={self.explainer})"
        )

    @property
    def parallel_executor(self) -> Any:
        """Return the internal parallel executor, or ``None`` when none is set."""
        return getattr(self, "_perf_parallel", None)

    @parallel_executor.setter
    def parallel_executor(self, value: Any) -> None:
        """Store ``value`` as the internal parallel executor."""
        self._perf_parallel = value

    @property
    def auto_encode(self) -> bool | str:
        """Return the current ``auto_encode`` setting."""
        return self._auto_encode

    @auto_encode.setter
    def auto_encode(self, value: bool | str) -> None:
        """Replace the ``auto_encode`` setting."""
        self._auto_encode = value

    @property
    def preprocessor(self) -> Any:
        """Return the configured preprocessor (``None`` when unset)."""
        return self._preprocessor

    @preprocessor.setter
    def preprocessor(self, value: Any) -> None:
        """Replace the configured preprocessor."""
        self._preprocessor = value

    # internal wiring for config
@classmethod
def from_config(cls, cfg: ExplainerConfig) -> WrapCalibratedExplainer:
    """Construct a wrapper from an :class:`ExplainerConfig`.

    Parameters
    ----------
    cfg : ExplainerConfig
        Configuration object. ``cfg.model`` becomes the wrapped learner; the
        perf, feature-filter and preprocessing options are read from it on a
        best-effort basis.

    Returns
    -------
    WrapCalibratedExplainer
        A new wrapper with ``_cfg``, ``perf_cache``, ``_perf_parallel``,
        ``perf_parallel`` and ``_feature_filter_config`` set.

    Notes
    -----
    - Intentionally minimal and only uses the provided model.
    - Further wiring of preprocessing and knobs will be added later.
    - Private API to avoid public snapshot changes.
    - Failures while wiring perf primitives or preprocessing options are
      logged and do not abort construction. Only ``Exception`` subclasses are
      absorbed; ``KeyboardInterrupt``/``SystemExit`` always propagate.
    """
    w = cls(cfg.model)
    # Stash config on the instance for later optional use (private attr)
    w._cfg = cfg  # type: ignore[attr-defined]

    # Wire perf factory (opt-in). When flags are disabled, factory returns
    # harmless defaults (None cache / sequential backend) and does not alter
    # runtime behavior.
    try:
        perf_factory = getattr(cfg, "_perf_factory", None)
        if perf_factory is None:
            # lazy import to avoid import cycles
            from calibrated_explanations.api.config import _build_perf_factory

            perf_factory = _build_perf_factory(cfg)
        # stash created primitives for downstream use; keep None when disabled
        if perf_factory is not None:
            cache = perf_factory.make_cache()
            w.perf_cache = cache  # type: ignore[attr-defined]
            w._perf_parallel = perf_factory.make_parallel_executor(cache)  # type: ignore[attr-defined]
            # Public-facing attribute expected by tests
            w.perf_parallel = w._perf_parallel  # type: ignore[attr-defined]
        else:
            w.perf_cache = None
            w._perf_parallel = None
            # Expose public attribute for tests that expect it to exist
            w.perf_parallel = None  # type: ignore[attr-defined]
    except Exception as exc:
        w.perf_cache = None
        w._perf_parallel = None
        # Keep the public attribute present on the failure path as well, so
        # attribute access behaves the same whether or not wiring succeeded.
        w.perf_parallel = None  # type: ignore[attr-defined]
        w._logger.debug("Failed to initialize perf primitives from config: %s", exc)

    # Wire internal feature filter config (FAST-based) when present.
    # The options are parsed once, outside the try, so an invalid value
    # surfaces directly instead of being re-raised from inside the fallback.
    enabled = bool(getattr(cfg, "perf_feature_filter_enabled", False))
    per_instance_top_k = max(
        1, int(getattr(cfg, "perf_feature_filter_per_instance_top_k", 8))
    )
    try:
        from .explain._feature_filter import (  # pylint: disable=import-outside-toplevel
            FeatureFilterConfig,
        )

        w._feature_filter_config = FeatureFilterConfig(  # type: ignore[attr-defined]
            enabled=enabled,
            per_instance_top_k=per_instance_top_k,
        )
    except Exception:
        # Best-effort fallback: if importing or building the internal helper
        # fails, create a lightweight object exposing the attributes the
        # runtime and tests expect. This avoids missing-attribute errors when
        # feature-filter internals are unavailable in constrained environments.
        from types import SimpleNamespace

        w._feature_filter_config = SimpleNamespace(
            enabled=enabled,
            per_instance_top_k=per_instance_top_k,
            strict_observability=False,
        )
        _logging.getLogger(__name__).debug("Using fallback feature_filter_config")

    # Wire optional preprocessing in a controlled way (only if provided)
    try:
        w._preprocessor = cfg.preprocessor  # type: ignore[attr-defined]
        w._auto_encode = cfg.auto_encode  # type: ignore[attr-defined]
        w._unseen_category_policy = cfg.unseen_category_policy  # type: ignore[attr-defined]
    except Exception as exc:
        _logging.getLogger(__name__).warning(
            "Failed to transfer preprocessing config to wrapper: %s", exc
        )
    return w
def fit(
    self, x_proper_train: Any, y_proper_train: Any, **kwargs: Any
) -> WrapCalibratedExplainer:
    """Train the wrapped learner on the proper training set.

    Parameters
    ----------
    x_proper_train : array-like of shape (n_samples, n_features)
        Training input samples.
    y_proper_train : array-like of shape (n_samples,)
        Training target values.
    **kwargs
        Extra keyword arguments passed straight to the learner's ``fit``.

    Returns
    -------
    WrapCalibratedExplainer
        Whatever ``_finalize_fit`` returns (the wrapper, allowing chaining).

    Examples
    --------
    >>> w = WrapCalibratedExplainer(clf)
    >>> w.fit(X_train, y_train)
    WrapCalibratedExplainer(...)
    """
    # Remember whether a calibration existed before the state flags are reset.
    was_calibrated = bool(self.calibrated)
    self.fitted = self.calibrated = False

    # When a preprocessor is configured it is fitted on the training data.
    features = (
        x_proper_train
        if self._preprocessor is None
        else self._pre_fit_preprocess(x_proper_train)
    )

    self._logger.info("Fitting underlying learner: %s", type(self.learner).__name__)
    self.learner.fit(features, y_proper_train, **kwargs)

    # Shared post-fit logic lives in _finalize_fit.
    return self._finalize_fit(was_calibrated)
def calibrate(
    self,
    x_calibration: Any,
    y_calibration: Any,
    mc: Callable[[Any], Any] | MondrianCategorizer | None = None,
    **kwargs: Any,
) -> WrapCalibratedExplainer:
    """Calibrate the wrapper using calibration data and create an explainer.

    Parameters
    ----------
    x_calibration : array-like of shape (n_samples, n_features)
        Calibration features used to fit internal calibrators.
    y_calibration : array-like of shape (n_samples,)
        Calibration targets corresponding to ``x_calibration``.
    mc : callable or MondrianCategorizer, optional
        Optional Mondrian categories helper. Defaults to ``None``, which
        leaves any previously stored helper in place.
    **kwargs
        Forwarded to :class:`.CalibratedExplainer.__init__` for advanced
        configuration (e.g. ``mode``, ``feature_names``, ``bins``).

    Returns
    -------
    WrapCalibratedExplainer
        The wrapper instance with the ``explainer`` attribute set to a
        configured :class:`.CalibratedExplainer`.

    Raises
    ------
    NotFittedError
        If the underlying learner has not been fitted via :meth:`fit`.

    Examples
    --------
    >>> w = WrapCalibratedExplainer(clf)
    >>> w.fit(X_train, y_train)
    >>> w.calibrate(X_cal, y_cal)

    Notes
    -----
    If ``mode`` is not provided in ``kwargs`` the wrapper will infer
    classification vs regression from the presence of ``predict_proba`` on
    the underlying learner.
    """
    self._assert_fitted("The WrapCalibratedExplainer must be fitted before calibration.")
    self.calibrated = False
    if mc is not None:
        self.mc = mc

    # Normalize kwargs at the public boundary; warn and strip alias keys only
    kwargs = self._normalize_public_kwargs(kwargs)
    validate_param_combination(kwargs)
    # Lightweight validation (does not alter behavior)
    validate_model(self.learner)
    preprocessor_metadata = self._build_preprocessor_metadata()

    # Optional preprocessing: ensure preprocessor is fitted (fit here if needed), then transform
    x_cal_local = x_calibration
    if self._preprocessor is not None:
        if not self._pre_fitted:
            self._logger.info("Fitting preprocessor on calibration data")
            x_cal_local = self._pre_fit_preprocess(x_cal_local)
        else:
            x_cal_local = self._pre_transform(x_cal_local, stage="calibrate")
        # Optional second transform call to ensure deterministic persistence
        # accounting in tests (ignore failures defensively)
        # NOTE(review): the nesting level of this block could not be recovered
        # from the flattened source; it is placed at the "preprocessor is
        # configured" level - confirm against the canonical file.
        with suppress(Exception):  # pragma: no cover - defensive
            _ = self._pre_transform(x_calibration, stage="calibrate_check")

    validate_inputs_matrix(x_cal_local, y_calibration, require_y=True, allow_nan=False)
    kwargs["bins"] = self._get_bins(x_cal_local, **kwargs)
    if preprocessor_metadata is not None:
        kwargs.setdefault("preprocessor_metadata", preprocessor_metadata)

    self._logger.info("Calibrating with %s samples", getattr(x_calibration, "shape", ["?"])[0])

    # A caller-supplied ``mode`` wins; otherwise infer it from the learner.
    # ``kwargs`` is a local copy, so filling in the default is safe. Any
    # ``default_reject_policy`` entry is passed through untouched.
    if "mode" not in kwargs:
        kwargs["mode"] = (
            "classification" if "predict_proba" in dir(self.learner) else "regression"
        )
    self.explainer = CalibratedExplainer(
        self.learner,
        x_cal_local,
        y_calibration,
        perf_cache=getattr(self, "perf_cache", None),
        perf_parallel=getattr(self, "_perf_parallel", None),
        **kwargs,
    )

    # Propagate internal feature filter config to explainer when available
    if self.explainer is not None and hasattr(self, "_feature_filter_config"):
        self.explainer.feature_filter_config = self._feature_filter_config
    self.calibrated = True

    if preprocessor_metadata is not None and self.explainer is not None:
        # Explainers without a metadata setter are tolerated.
        with suppress(AttributeError):
            self.explainer.set_preprocessor_metadata(preprocessor_metadata)
    return self
@property
def feature_filter_config(self) -> Any:
    """Return the feature-filter configuration, if one can be found.

    The wrapper's own ``_feature_filter_config`` takes precedence. When it is
    absent the explainer's ``feature_filter_config`` attribute is returned
    instead; ``None`` is returned when neither is available.
    """
    try:
        return self._feature_filter_config
    except AttributeError:
        pass
    delegate = self.explainer
    return None if delegate is None else getattr(delegate, "feature_filter_config", None)
def explain_factual(self, x: Any, **kwargs: Any) -> Any:
    """Generate factual explanations for provided instances.

    Parameters
    ----------
    x : array-like
        Instances to explain (single or batch). Shape should match the
        feature dimensionality used during calibration.
    **kwargs
        Forwarded to :meth:`CalibratedExplainer.explain_factual`.

    Returns
    -------
    CalibratedExplanations or mapping
        Explanation collection produced by the underlying explainer.

    Raises
    ------
    NotFittedError
        If the wrapper has not been fitted or has not been calibrated.

    See Also
    --------
    :meth:`CalibratedExplainer.explain_factual`
        For full parameter and return semantics.
    """
    # Run the state guards as plain statements: wrapped in ``assert`` they
    # would be skipped entirely when Python runs with ``-O``.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before explaining."
    )
    self._assert_calibrated("The WrapCalibratedExplainer must be calibrated before explaining.")
    assert self.explainer is not None  # invariant after calibration; narrows the Optional

    # Optional preprocessing
    x_local = self._maybe_preprocess_for_inference(x)
    kwargs = self._normalize_public_kwargs(kwargs)

    # If constructed via from_config, prefer cfg defaults when absent
    cfg = getattr(self, "_cfg", None)
    if cfg is not None:
        kwargs.setdefault("threshold", cfg.threshold)
        # low_high_percentiles only applies to regression-style intervals; safe to pass through
        kwargs.setdefault("low_high_percentiles", cfg.low_high_percentiles)

    validate_inputs_matrix(x_local, allow_nan=True)
    validate_param_combination(kwargs)
    kwargs["bins"] = self._get_bins(x_local, **kwargs)
    return self.explainer.explain_factual(x_local, **kwargs)
def explore_alternatives(self, x: Any, **kwargs: Any) -> Any:
    """Generate alternative explanations for the test data.

    Parameters
    ----------
    x : array-like
        Instances to explain.
    **kwargs
        Forwarded to :meth:`.CalibratedExplainer.explore_alternatives`.

    Returns
    -------
    Any
        Whatever the underlying explainer's ``explore_alternatives`` returns.

    Raises
    ------
    NotFittedError
        If the wrapper has not been fitted or has not been calibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.explore_alternatives` : Refer to the docstring for explore_alternatives in CalibratedExplainer for more details.
    """
    # Run the state guards as plain statements: wrapped in ``assert`` they
    # would be skipped entirely when Python runs with ``-O``.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before explaining."
    )
    self._assert_calibrated("The WrapCalibratedExplainer must be calibrated before explaining.")
    assert self.explainer is not None  # invariant after calibration; narrows the Optional

    x_local = self._maybe_preprocess_for_inference(x)
    kwargs = self._normalize_public_kwargs(kwargs)

    # Config values act as defaults only; explicit keyword arguments win.
    cfg = getattr(self, "_cfg", None)
    if cfg is not None:
        kwargs.setdefault("threshold", cfg.threshold)
        kwargs.setdefault("low_high_percentiles", cfg.low_high_percentiles)

    validate_inputs_matrix(x_local, allow_nan=True)
    validate_param_combination(kwargs)
    kwargs["bins"] = self._get_bins(x_local, **kwargs)
    return self.explainer.explore_alternatives(x_local, **kwargs)
def explain_guarded_factual(self, x: Any, **kwargs: Any) -> Any:
    """Produce guarded factual explanations limited to in-distribution perturbations.

    Parameters
    ----------
    x : array-like
        Instances to explain.
    **kwargs
        Forwarded to :meth:`.CalibratedExplainer.explain_guarded_factual`.

    Returns
    -------
    Any
        Whatever the underlying explainer's ``explain_guarded_factual`` returns.

    See Also
    --------
    :meth:`.CalibratedExplainer.explain_guarded_factual` : Refer to the docstring for full parameter documentation.
    """
    ready = self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before explaining."
    )
    ready._assert_calibrated("The WrapCalibratedExplainer must be calibrated before explaining.")

    prepared = self._maybe_preprocess_for_inference(x)
    kwargs = self._normalize_public_kwargs(kwargs)

    # Config values only fill in options the caller left out.
    config = getattr(self, "_cfg", None)
    if config is not None:
        for option in ("threshold", "low_high_percentiles"):
            kwargs.setdefault(option, getattr(config, option))

    validate_inputs_matrix(prepared, allow_nan=True)
    validate_param_combination(kwargs)
    kwargs["bins"] = self._get_bins(prepared, **kwargs)
    return self.explainer.explain_guarded_factual(prepared, **kwargs)
def explore_guarded_alternatives(self, x: Any, **kwargs: Any) -> Any:
    """Produce guarded alternative explanations limited to in-distribution perturbations.

    Parameters
    ----------
    x : array-like
        Instances to explain.
    **kwargs
        Forwarded to :meth:`.CalibratedExplainer.explore_guarded_alternatives`.

    Returns
    -------
    Any
        Whatever the underlying explainer's ``explore_guarded_alternatives``
        returns.

    See Also
    --------
    :meth:`.CalibratedExplainer.explore_guarded_alternatives` : Refer to the docstring for full parameter documentation.
    """
    ready = self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before explaining."
    )
    ready._assert_calibrated("The WrapCalibratedExplainer must be calibrated before explaining.")

    prepared = self._maybe_preprocess_for_inference(x)
    kwargs = self._normalize_public_kwargs(kwargs)

    # Config values only fill in options the caller left out.
    config = getattr(self, "_cfg", None)
    if config is not None:
        for option in ("threshold", "low_high_percentiles"):
            kwargs.setdefault(option, getattr(config, option))

    validate_inputs_matrix(prepared, allow_nan=True)
    validate_param_combination(kwargs)
    kwargs["bins"] = self._get_bins(prepared, **kwargs)
    return self.explainer.explore_guarded_alternatives(prepared, **kwargs)
def explain_fast(self, x: Any, **kwargs: Any) -> Any:
    """Generate fast explanations for the test data.

    Parameters
    ----------
    x : array-like
        Instances to explain.
    **kwargs
        Forwarded to :meth:`.CalibratedExplainer.explain_fast`.

    Returns
    -------
    Any
        Whatever the underlying explainer's ``explain_fast`` returns.

    Raises
    ------
    NotFittedError
        If the wrapper has not been fitted or has not been calibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.explain_fast` : Refer to the docstring for explain_fast in CalibratedExplainer for more details.
    """
    # Run the state guards as plain statements: wrapped in ``assert`` they
    # would be skipped entirely when Python runs with ``-O``.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before explaining."
    )
    self._assert_calibrated("The WrapCalibratedExplainer must be calibrated before explaining.")
    assert self.explainer is not None  # invariant after calibration; narrows the Optional

    x_local = self._maybe_preprocess_for_inference(x)
    kwargs = self._normalize_public_kwargs(kwargs)

    # Apply config defaults when available and not explicitly provided
    cfg = getattr(self, "_cfg", None)
    if cfg is not None:
        kwargs.setdefault("threshold", cfg.threshold)
        kwargs.setdefault("low_high_percentiles", cfg.low_high_percentiles)

    validate_inputs_matrix(x_local, allow_nan=True)
    validate_param_combination(kwargs)
    kwargs["bins"] = self._get_bins(x_local, **kwargs)
    return self.explainer.explain_fast(x_local, **kwargs)
# pylint: disable=too-many-return-statements
def predict(
    self,
    x: Any,
    uq_interval: bool = False,
    calibrated: bool = True,
    reject_policy: Any | None = None,
    **kwargs: Any,
) -> Any:
    """Return predictions for the test data.

    While the wrapper is uncalibrated the learner's own ``predict`` output is
    returned (paired with a degenerate ``(prediction, prediction)`` interval
    when ``uq_interval`` is true). Once calibrated, the call is delegated to
    the explainer.

    Parameters
    ----------
    x : array-like
        Instances to predict.
    uq_interval : bool, default False
        Whether to return an interval alongside the prediction.
    calibrated : bool, default True
        Forwarded to the explainer; on an uncalibrated wrapper a true value
        only triggers a ``UserWarning``.
    reject_policy : Any, optional
        Forwarded to the explainer.
    **kwargs
        Forwarded to :meth:`.CalibratedExplainer.predict`.

    Raises
    ------
    NotFittedError
        If the wrapper has not been fitted.
    DataShapeError
        If ``threshold`` is passed while the wrapper is uncalibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.predict` : Refer to the docstring for predict in CalibratedExplainer for more details.
    """
    self._assert_fitted("The WrapCalibratedExplainer must be fitted before predicting.")

    if not self.calibrated:
        if "threshold" in kwargs:
            raise DataShapeError(
                "A thresholded prediction is not possible for uncalibrated learners."
            )
        if calibrated:
            _warnings.warn(
                "The WrapCalibratedExplainer must be calibrated to get calibrated predictions.",
                UserWarning,
                stacklevel=2,
            )
        raw = self.learner.predict(x)
        # Without calibration the "interval" collapses onto the point prediction.
        return (raw, (raw, raw)) if uq_interval else raw

    # Optional preprocessing for inference consistency
    prepared = self._maybe_preprocess_for_inference(x)
    kwargs = self._normalize_public_kwargs(kwargs)
    validate_inputs_matrix(prepared, allow_nan=True)
    validate_param_combination(kwargs)
    kwargs["bins"] = self._get_bins(prepared, **kwargs)

    self._assert_calibrated(
        "The WrapCalibratedExplainer must be calibrated to get calibrated predictions."
    )
    assert self.explainer is not None
    return self.explainer.predict(
        prepared,
        uq_interval=uq_interval,
        calibrated=calibrated,
        reject_policy=reject_policy,
        **kwargs,
    )
def predict_proba(
    self,
    x: Any,
    uq_interval: bool = False,
    calibrated: bool = True,
    threshold: float | None = None,
    reject_policy: Any | None = None,
    **kwargs: Any,
) -> Any:
    """Return probability predictions for the test data.

    Parameters
    ----------
    x : array-like
        Instances to predict.
    uq_interval : bool, default False
        Whether to return an interval alongside the probabilities.
    calibrated : bool, default True
        Forwarded to the explainer; on an uncalibrated wrapper a true value
        only triggers a ``UserWarning``.
    threshold : float, optional
        Required when the learner exposes no ``predict_proba``.
    reject_policy : Any, optional
        Forwarded to the explainer.
    **kwargs
        Forwarded to :meth:`.CalibratedExplainer.predict_proba`.

    Raises
    ------
    NotFittedError
        If the wrapper is not fitted, or if the learner has no
        ``predict_proba`` and the wrapper is not calibrated.
    ValidationError
        If the learner has no ``predict_proba`` and ``threshold`` is ``None``.
    DataShapeError
        If ``threshold`` is given while the wrapper is uncalibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.predict_proba` : Refer to the docstring for predict_proba in CalibratedExplainer for more details.
    """
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted before predicting probabilities."
    )

    learner_has_proba = "predict_proba" in dir(self.learner)
    if not learner_has_proba:
        if threshold is None:
            raise ValidationError("The threshold parameter must be specified for regression.")
        self._assert_calibrated(
            "The WrapCalibratedExplainer must be calibrated to get calibrated probabilities for regression."
        )

    if not self.calibrated:
        if threshold is not None:
            raise DataShapeError(
                "A thresholded prediction is not possible for uncalibrated learners."
            )
        if calibrated:
            _warnings.warn(
                "The WrapCalibratedExplainer must be calibrated to get calibrated probabilities.",
                UserWarning,
                stacklevel=2,
            )
        # Fall back to the learner's own probabilities.
        return self._format_proba_output(self.learner.predict_proba(x), uq_interval)

    # Optional preprocessing for inference consistency
    prepared = self._maybe_preprocess_for_inference(x)
    kwargs = self._normalize_public_kwargs(kwargs)
    validate_inputs_matrix(prepared, allow_nan=True)
    validate_param_combination(kwargs)
    kwargs["bins"] = self._get_bins(prepared, **kwargs)

    self._assert_calibrated(
        "The WrapCalibratedExplainer must be calibrated to get calibrated probabilities."
    )
    assert self.explainer is not None
    return self.explainer.predict_proba(
        prepared,
        uq_interval=uq_interval,
        calibrated=calibrated,
        threshold=threshold,
        reject_policy=reject_policy,
        **kwargs,
    )
def calibrated_confusion_matrix(self) -> Any:
    """Generate a calibrated confusion matrix.

    Returns
    -------
    Any
        Whatever the underlying explainer's ``calibrated_confusion_matrix``
        returns.

    Raises
    ------
    NotFittedError
        If the wrapper has not been fitted or has not been calibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.calibrated_confusion_matrix` : Refer to the docstring for calibrated_confusion_matrix in CalibratedExplainer for more details.
    """
    # Run the state guards as plain statements: wrapped in ``assert`` they
    # would be skipped entirely when Python runs with ``-O``.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before providing a confusion matrix."
    )
    self._assert_calibrated(
        "The WrapCalibratedExplainer must be calibrated before providing a confusion matrix."
    )
    assert self.explainer is not None  # invariant after calibration; narrows the Optional
    return self.explainer.calibrated_confusion_matrix()
def set_difficulty_estimator(self, difficulty_estimator: Any) -> None:
    """Assign or update the difficulty estimator.

    Parameters
    ----------
    difficulty_estimator : Any
        Passed unchanged to the explainer's ``set_difficulty_estimator``.

    Raises
    ------
    NotFittedError
        If the wrapper has not been fitted or has not been calibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.set_difficulty_estimator` : Refer to the docstring for set_difficulty_estimator in CalibratedExplainer for more details.
    """
    # Run the state guards as plain statements: wrapped in ``assert`` they
    # would be skipped entirely when Python runs with ``-O``.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before assigning a difficulty estimator."
    )
    self._assert_calibrated(
        "The WrapCalibratedExplainer must be calibrated before assigning a difficulty estimator."
    )
    assert self.explainer is not None  # invariant after calibration; narrows the Optional
    self.explainer.set_difficulty_estimator(difficulty_estimator)
def initialize_reject_learner(  # pylint: disable=invalid-name
    self, threshold: float | None = None, ncf=None, w: float = 0.5
) -> Any:
    """Initialize the reject learner with a threshold value.

    .. deprecated:: 0.11.1
        Use ``reject_orchestrator.initialize_reject_learner`` on the calibrated
        explainer instead. This wrapper will be removed no earlier than v0.13.0.

    Parameters
    ----------
    threshold : float or None
        Decision threshold (regression only). Defaults to None.
    ncf : str or None, default None
        Non-conformity function type: ``'default'`` or ``'ensured'``. The
        internal default score is task-dependent (margin for multiclass,
        hinge for binary/regression). Legacy ``'entropy'`` is accepted and
        mapped to ``'default'``.
    w : float, default 0.5
        Blending weight in [0, 1] used only when ``ncf='ensured'``. Ignored
        for ``ncf='default'``.

    Returns
    -------
    Any
        Whatever ``reject_orchestrator.initialize_reject_learner`` returns.

    Raises
    ------
    NotFittedError
        If the wrapper has not been fitted or has not been calibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.initialize_reject_learner` : Refer to the docstring for initialize_reject_learner in CalibratedExplainer for more details.
    """
    # Run the state guards as plain statements: wrapped in ``assert`` they
    # would be skipped entirely when Python runs with ``-O``.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted before initializing the reject learner."
    )
    self._assert_calibrated(
        "The WrapCalibratedExplainer must be calibrated before initializing the reject learner."
    )
    assert self.explainer is not None  # invariant after calibration; narrows the Optional

    deprecate(
        "WrapCalibratedExplainer.initialize_reject_learner is deprecated since v0.11.1; "
        "use explainer.reject_orchestrator.initialize_reject_learner instead. "
        "This wrapper will be removed no earlier than v0.13.0.",
        key=(
            "calibrated_explanations.core.wrap_explainer."
            "WrapCalibratedExplainer.initialize_reject_learner_deprecation"
        ),
        stacklevel=2,
    )
    # The orchestrators are initialised before the reject orchestrator is used.
    self.explainer.plugin_manager.initialize_orchestrators()
    return self.explainer.reject_orchestrator.initialize_reject_learner(
        threshold=threshold, ncf=ncf, w=w
    )
def predict_reject(self, x: Any, bins: Any = None, confidence: float = 0.95) -> Any:
    """Predict whether to reject the explanations for the test data.

    .. deprecated:: 0.11.1
        Use ``reject_orchestrator.predict_reject`` on the calibrated explainer
        instead. This wrapper will be removed no earlier than v0.13.0.

    Parameters
    ----------
    x : array-like
        Instances to evaluate.
    bins : Any, optional
        Explicit bin assignments; a configured Mondrian categorizer takes
        precedence (see ``_get_bins``).
    confidence : float, default 0.95
        Forwarded to ``reject_orchestrator.predict_reject``.

    Returns
    -------
    Any
        Whatever ``reject_orchestrator.predict_reject`` returns.

    Raises
    ------
    NotFittedError
        If the wrapper has not been fitted or has not been calibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.predict_reject` : Refer to the docstring for predict_reject in CalibratedExplainer for more details.
    """
    # Validate state first (as plain statements, so ``python -O`` cannot strip
    # the checks) before any work is done on the input.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before predicting rejection."
    )
    self._assert_calibrated(
        "The WrapCalibratedExplainer must be calibrated before predicting rejection."
    )
    assert self.explainer is not None  # invariant after calibration; narrows the Optional

    bins = self._get_bins(x, bins=bins)

    deprecate(
        "WrapCalibratedExplainer.predict_reject is deprecated since v0.11.1; "
        "use explainer.reject_orchestrator.predict_reject instead. "
        "This wrapper will be removed no earlier than v0.13.0.",
        key=(
            "calibrated_explanations.core.wrap_explainer."
            "WrapCalibratedExplainer.predict_reject_deprecation"
        ),
        stacklevel=2,
    )
    # The orchestrators are initialised before the reject orchestrator is used.
    self.explainer.plugin_manager.initialize_orchestrators()
    return self.explainer.reject_orchestrator.predict_reject(
        x, bins=bins, confidence=confidence
    )
# pylint: disable=duplicate-code, too-many-branches, too-many-statements, too-many-locals
def plot(self, x: Any, y: Any = None, threshold: float | None = None, **kwargs: Any) -> Any:
    """Generate plots for the test data.

    Parameters
    ----------
    x : array-like
        Test instances to plot explanations for.
    y : array-like, optional
        True labels for the test instances.
    threshold : float, optional
        Threshold for probabilistic regression. When ``None`` and the wrapper
        carries a config, the config's ``threshold`` is used.
    **kwargs : dict
        Additional keyword arguments passed to the plot method.

    Returns
    -------
    None

    Raises
    ------
    NotFittedError
        If the wrapper has not been fitted or has not been calibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.plot` : Refer to the docstring for plot in CalibratedExplainer for more details.
    """
    # Run the state guards as plain statements: wrapped in ``assert`` they
    # would be skipped entirely when Python runs with ``-O``.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before plotting."
    )
    self._assert_calibrated("The WrapCalibratedExplainer must be calibrated before plotting.")
    assert self.explainer is not None  # invariant after calibration; narrows the Optional

    # Apply config defaults when available and not explicitly provided
    cfg = getattr(self, "_cfg", None)
    if cfg is not None:
        if threshold is None:
            threshold = cfg.threshold
        kwargs.setdefault("low_high_percentiles", cfg.low_high_percentiles)

    kwargs["bins"] = self._get_bins(x, **kwargs)
    self.explainer.plot(x, y=y, threshold=threshold, **kwargs)
def _get_bins(self, x: Any, **kwargs: Any) -> Any:
    """Derive bin assignments from the configured Mondrian categorizer.

    Resolution order: a ``MondrianCategorizer`` stored in ``self.mc`` (via
    ``apply``), any other non-``None`` ``self.mc`` (called directly), an
    explicit ``bins`` keyword argument, and finally the calibrated
    explainer's own ``bins`` when it reports Mondrian mode. Returns ``None``
    when none of these apply.
    """
    if isinstance(self.mc, MondrianCategorizer):
        return self.mc.apply(x)
    if self.mc is not None:
        return self.mc(x)
    bins = kwargs.get("bins")
    if bins is not None:
        return bins
    # Fallback to explainer bins for Mondrian mode
    if (
        hasattr(self, "explainer")
        and self.explainer
        and hasattr(self.explainer, "bins")
        and self.explainer.is_mondrian()
    ):
        return self.explainer.bins
    return None


@property
def runtime_telemetry(self) -> Mapping[str, Any]:
    """Return the most recent telemetry payload reported by the explainer.

    Raises
    ------
    NotFittedError
        If the wrapper is not fitted or not calibrated.
    """
    # State checks are plain statements so they are not stripped by ``-O``.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted before accessing runtime telemetry."
    )._assert_calibrated(
        "The WrapCalibratedExplainer must be calibrated before accessing runtime telemetry."
    )
    assert self.explainer is not None  # type narrowing only
    return self.explainer.runtime_telemetry


@property
def preprocessor_metadata(self) -> Dict[str, Any] | None:
    """Return the telemetry-safe preprocessing snapshot if available.

    Raises
    ------
    NotFittedError
        If the wrapper is not fitted or not calibrated.
    """
    # State checks are plain statements so they are not stripped by ``-O``.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted before accessing preprocessor metadata."
    )._assert_calibrated(
        "The WrapCalibratedExplainer must be calibrated before accessing preprocessor metadata."
    )
    assert self.explainer is not None  # type narrowing only
    return self.explainer.preprocessor_metadata
def set_preprocessor_metadata(self, metadata: Mapping[str, Any] | None) -> None:
    """Update the stored preprocessing metadata snapshot.

    Parameters
    ----------
    metadata : Mapping[str, Any] | None
        Snapshot forwarded to ``explainer.set_preprocessor_metadata``.

    Raises
    ------
    NotFittedError
        If the wrapper is not fitted or not calibrated.
    """
    # Validate state with real statements; an ``assert``-wrapped check would
    # disappear under ``python -O``.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted before setting preprocessor metadata."
    )._assert_calibrated(
        "The WrapCalibratedExplainer must be calibrated before setting preprocessor metadata."
    )
    assert self.explainer is not None  # type narrowing only
    self.explainer.set_preprocessor_metadata(metadata)
# ------ Internal helpers (reduce duplication) ------
def _assert_fitted(self, message: str | None = None) -> WrapCalibratedExplainer:
    """Raise ``NotFittedError`` unless ``self.fitted`` is truthy; return ``self`` for chaining."""
    if not self.fitted:
        raise NotFittedError(
            message or "The WrapCalibratedExplainer must be fitted before this operation."
        )
    return self


def _assert_calibrated(self, message: str | None = None) -> WrapCalibratedExplainer:
    """Raise ``NotFittedError`` unless ``self.calibrated`` is truthy; return ``self`` for chaining."""
    if not self.calibrated:
        raise NotFittedError(
            message or "The WrapCalibratedExplainer must be calibrated before this operation."
        )
    return self


def _normalize_public_kwargs(
    self, kwargs: dict[str, Any], allowed: "set[str] | None" = None
) -> dict[str, Any]:
    """Normalize public kwargs and reject removed aliases.

    Returns a fresh dict. When ``allowed`` is given, only keys contained in
    it are kept. A falsy ``kwargs`` yields an empty dict.
    """
    if not kwargs:
        return {}
    original = dict(kwargs)
    reject_removed_aliases(original)
    base = dict(original)
    if allowed is None:
        return base
    return {k: v for k, v in base.items() if k in allowed}


def _normalize_auto_encode_flag(self) -> str:
    """Return the auto_encode configuration as a telemetry-friendly literal.

    The result is always one of ``"true"``, ``"false"`` or ``"auto"``;
    unrecognised values map to ``"auto"``.
    """
    flag = getattr(self, "_auto_encode", "auto")
    if isinstance(flag, bool):
        return "true" if flag else "false"
    flag_str = str(flag).lower()
    if flag_str in {"true", "false", "auto"}:
        return flag_str
    return "auto"


def _serialise_preprocessor_value(self, value: Any) -> Any:
    """Convert preprocessing metadata values into JSON-friendly structures.

    Dicts are rebuilt with string keys, list/tuple/set become lists, objects
    exposing ``tolist`` use it, primitives pass through, and everything else
    is stringified.
    """
    if value is None:
        return None
    if isinstance(value, dict):
        return {str(key): self._serialise_preprocessor_value(val) for key, val in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [self._serialise_preprocessor_value(item) for item in value]
    if hasattr(value, "tolist"):
        try:
            return value.tolist()  # numpy/pandas friendly
        except:  # noqa: E722
            # Only ``Exception`` subclasses fall through to the string
            # fallback; KeyboardInterrupt/SystemExit are re-raised.
            if not isinstance(sys.exc_info()[1], Exception):
                raise
            return str(value)
    if isinstance(value, (str, int, float, bool)):
        return value
    return str(value)


def _extract_preprocessor_snapshot(self, preprocessor: Any) -> dict[str, Any] | None:
    """Build a lightweight snapshot describing the configured preprocessor.

    Probes the optional attributes ``get_mapping_snapshot``, ``categories_``,
    ``transformers_``, ``get_feature_names_out`` and ``mapping_``. Returns
    ``None`` when none of them yields data.
    """
    snapshot: dict[str, Any] = {}
    getter = getattr(preprocessor, "get_mapping_snapshot", None)
    if callable(getter):
        try:
            custom_snapshot = getter()
        except:  # noqa: E722
            if not isinstance(sys.exc_info()[1], Exception):
                raise
            custom_snapshot = None
        if custom_snapshot is not None:
            snapshot["custom"] = self._serialise_preprocessor_value(custom_snapshot)
    categories = getattr(preprocessor, "categories_", None)
    if categories is not None:
        snapshot["categories"] = self._serialise_preprocessor_value(categories)
    transformers = getattr(preprocessor, "transformers_", None)
    if transformers is not None:
        serialised = []
        for name, transformer, columns in transformers:
            serialised.append(
                {
                    "name": name,
                    "columns": self._serialise_preprocessor_value(columns),
                    "transformer": (
                        f"{transformer.__class__.__module__}:{transformer.__class__.__qualname__}"
                        if transformer is not None
                        else None
                    ),
                }
            )
        snapshot["transformers"] = serialised
    feature_names_out = getattr(preprocessor, "get_feature_names_out", None)
    if callable(feature_names_out):
        # Best effort: a failing feature-name lookup just omits the key.
        with suppress(Exception):
            snapshot["feature_names_out"] = list(feature_names_out())
    mapping_attr = getattr(preprocessor, "mapping_", None)
    if mapping_attr is not None:
        snapshot["mapping"] = self._serialise_preprocessor_value(mapping_attr)
    return snapshot or None


def _build_preprocessor_metadata(self) -> dict[str, Any] | None:
    """Return ADR-009 telemetry metadata for the active preprocessor.

    Returns ``None`` when no preprocessor is attached and ``auto_encode`` is
    in its default ``"auto"`` state, i.e. there is nothing to report.
    """
    auto_encode_flag = self._normalize_auto_encode_flag()
    preprocessor = getattr(self, "_preprocessor", None)
    metadata: dict[str, Any] = {"auto_encode": auto_encode_flag}
    if preprocessor is not None:
        metadata["transformer_id"] = (
            f"{preprocessor.__class__.__module__}:{preprocessor.__class__.__qualname__}"
        )
        snapshot = self._extract_preprocessor_snapshot(preprocessor)
        if snapshot is not None:
            metadata["mapping_snapshot"] = snapshot
    if (
        metadata.get("transformer_id") is None
        and len(metadata) == 1
        and auto_encode_flag == "auto"
    ):
        return None
    return metadata


def _raise_non_numeric_without_preprocessor(self, x: Any, stage: str) -> None:
    """Raise actionable diagnostics for non-numeric inputs when preprocessing is disabled.

    Does nothing while ``auto_encode`` is ``"auto"`` or ``"true"``. Otherwise
    a ``ValidationError`` is raised when ``x`` exposes a dtype whose ``kind``
    is not boolean/integer/unsigned/float/complex.
    """
    auto_encode_flag = self._normalize_auto_encode_flag()
    if auto_encode_flag in {"auto", "true"}:
        return
    x_arr = x.to_numpy() if hasattr(x, "to_numpy") else x
    dtype = getattr(x_arr, "dtype", None)
    if dtype is not None and getattr(dtype, "kind", None) not in {"b", "i", "u", "f", "c"}:
        raise ValidationError(
            f"Non-numeric input detected during {stage} while preprocessing is disabled. "
            "Set auto_encode='auto' or provide a preprocessor capable of handling categorical values."
        )


def _pre_fit_preprocess(self, x: Any) -> Any:
    """Fit the configured preprocessor and return transformed x.

    A user-supplied preprocessor is fitted via ``fit_transform`` (or
    ``fit`` + ``transform``). When no preprocessor is configured and
    ``auto_encode`` is ``"auto"`` or ``"true"``, the built-in encoder is
    fitted and attached to the wrapper. With auto encoding disabled and no
    preprocessor, non-numeric input raises ``ValidationError``. Any other
    preprocessing failure is logged and ``x`` is returned unchanged.
    """
    try:
        # When no preprocessor is provided and auto_encode is enabled,
        # activate the small deterministic builtin encoder.
        if self._preprocessor is None:
            # ADR-009 default mode: auto_encode='auto' activates deterministic
            # built-in encoding when no user preprocessor is provided.
            if self._normalize_auto_encode_flag() in {"auto", "true"}:
                try:
                    from calibrated_explanations.preprocessing.builtin_encoder import (
                        BuiltinEncoder,
                    )

                    encoder = BuiltinEncoder(unseen_policy=self._unseen_category_policy)
                    x_out = encoder.fit_transform(x)
                    # attach encoder so export/import helpers can find it
                    self._preprocessor = encoder
                    self._pre_fitted = True
                    return x_out
                except (
                    ImportError,
                    TypeError,
                    ValueError,
                    AttributeError,
                ) as exc:  # pragma: no cover - defensive
                    self._logger.warning("Builtin encoder failed; bypassing: %s", exc)
                    return x
            self._raise_non_numeric_without_preprocessor(x, stage="fit")
            return x
        if hasattr(self._preprocessor, "fit_transform"):
            x_out = self._preprocessor.fit_transform(x)
        else:
            self._preprocessor.fit(x)
            x_out = self._preprocessor.transform(x)
        self._pre_fitted = True
        return x_out
    except:  # noqa: E722
        if not isinstance(sys.exc_info()[1], Exception):
            raise
        exc = sys.exc_info()[1]
        if isinstance(exc, ValidationError):
            raise
        self._logger.warning("Preprocessor failed; proceeding without it: %s", exc)
        return x


def _pre_transform(self, x: Any, stage: str = "predict") -> Any:
    """Transform x with the fitted preprocessor if available.

    Without a fitted preprocessor ``x`` is returned unchanged, after the
    non-numeric guard has run. Transform failures are logged and bypassed,
    except that an unseen-category failure under ``unseen_policy == "error"``
    is surfaced as ``ValidationError``.

    Raises
    ------
    ValidationError
        For non-numeric input while preprocessing is disabled, or for unseen
        categories when the preprocessor's policy is ``"error"``.
    """
    try:
        if self._preprocessor is None or not self._pre_fitted:
            self._raise_non_numeric_without_preprocessor(x, stage=stage)
            return x
        return self._preprocessor.transform(x)
    except:  # noqa: E722
        if not isinstance(sys.exc_info()[1], Exception):
            raise
        exc = sys.exc_info()[1]
        pre = getattr(self, "_preprocessor", None)
        if isinstance(exc, ValidationError) and (
            pre is None or not getattr(self, "_pre_fitted", False)
        ):
            # No transform ran, so this is the non-numeric guard's own
            # diagnostic. Propagate it (as _pre_fit_preprocess does) instead
            # of logging a misleading "transform failed" warning.
            raise
        unseen_policy = str(getattr(pre, "unseen_policy", "")).lower()
        if isinstance(exc, (KeyError, ValidationError)) and unseen_policy == "error":
            raise ValidationError(
                f"Unseen category encountered during {stage} preprocessing. "
                "Set unseen_category_policy='ignore' or import/export a stable mapping."
            ) from exc
        self._logger.warning("Preprocessor transform failed at %s; bypassing: %s", stage, exc)
        return x


def _maybe_preprocess_for_inference(self, x: Any) -> Any:
    """Apply preprocessing for inference paths if configured/fitted."""
    return self._pre_transform(x, stage="inference")


def _finalize_fit(self, reinitialize: bool) -> WrapCalibratedExplainer:
    """Finalize fit logic shared across fit implementations.

    Parameters
    ----------
    reinitialize : bool
        Whether an existing calibrated explainer should be reinitialized.

    Returns
    -------
    WrapCalibratedExplainer
        ``self``, with ``fitted`` set to ``True``.
    """
    check_is_fitted(self.learner)
    self.fitted = True
    if reinitialize and self.explainer is not None:
        # Preserve calibration by updating underlying learner reference
        self.explainer.reinitialize(self.learner)
        self.calibrated = True
    return self


def _format_proba_output(self, proba: Any, uq_interval: bool) -> Any:
    """Format probability output (with optional trivial intervals) without duplicating logic.

    Returns ``proba`` alone when ``uq_interval`` is false, otherwise
    ``(proba, (low, high))`` where low and high are identical.
    """
    if not uq_interval:
        return proba
    # Multiclass: return matrix and identical bounds
    if proba.ndim == 2 and proba.shape[1] > 2:
        return proba, (proba, proba)
    # Binary (assume second column is positive class probability)
    if proba.ndim == 2 and proba.shape[1] == 2:
        return proba, (proba[:, 1], proba[:, 1])
    # Fallback (unexpected shape) -> mirror array
    return proba, (proba, proba)


# Public aliases for testing
def serialise_preprocessor_value(self, value: Any) -> Any:
    """Public alias of the internal preprocessor-value serialiser.

    Parameters
    ----------
    value : Any
        The value to serialise.

    Returns
    -------
    Any
        Whatever ``_serialise_preprocessor_value`` produces for ``value``.
    """
    serialised = self._serialise_preprocessor_value(value)
    return serialised
def extract_preprocessor_snapshot(self, preprocessor: Any) -> dict[str, Any] | None:
    """Public alias of the internal preprocessor snapshot builder.

    Parameters
    ----------
    preprocessor : Any
        The preprocessor to snapshot.

    Returns
    -------
    dict[str, Any] | None
        The snapshot dictionary, or ``None`` when the internal helper
        returns nothing.
    """
    snapshot = self._extract_preprocessor_snapshot(preprocessor)
    return snapshot
def build_preprocessor_metadata(self) -> Dict[str, Any] | None:
    """Build metadata for the preprocessor.

    Returns
    -------
    Dict[str, Any] | None
        The metadata dictionary, or ``None`` when the internal builder has
        nothing to report (it returns ``None`` in that case, so the alias
        must advertise the same contract).
    """
    return self._build_preprocessor_metadata()
def pre_fit_preprocess(self, x: Any) -> Any:
    """Public alias of the fit-time preprocessing helper.

    Parameters
    ----------
    x : Any
        The input data.

    Returns
    -------
    Any
        The result of ``_pre_fit_preprocess`` for ``x``.
    """
    prepared = self._pre_fit_preprocess(x)
    return prepared
def pre_transform(self, X: Any) -> Any:
    """Public alias of the transform-time preprocessing helper.

    Parameters
    ----------
    X : Any
        The input data.

    Returns
    -------
    Any
        The result of ``_pre_transform`` for ``X`` (called with its default
        stage).
    """
    transformed = self._pre_transform(X)
    return transformed
def maybe_preprocess_for_inference(self, X: Any) -> Any:
    """Public alias of the inference-time preprocessing helper.

    Parameters
    ----------
    X : Any
        The input data.

    Returns
    -------
    Any
        The result of ``_maybe_preprocess_for_inference`` for ``X``.
    """
    prepared = self._maybe_preprocess_for_inference(X)
    return prepared
def export_preprocessor_mapping(self) -> dict[str, Any] | None:
    """Export the current preprocessor mapping snapshot.

    The preprocessor's ``get_mapping_snapshot`` hook is preferred; when it is
    absent (or fails with an attribute/type/value error) the ``mapping_``
    attribute is used instead. Snapshots are validated for JSON safety
    before being returned as a shallow ``dict`` copy.

    Returns
    -------
    dict[str, Any] | None
        A mapping snapshot suitable for telemetry or round-tripping, or
        ``None`` when no mapping information is available.

    Raises
    ------
    ValidationError
        If the snapshot is not a mapping or is not JSON-serialisable.
    """
    preprocessor = getattr(self, "_preprocessor", None)
    if preprocessor is None:
        return None

    # Prefer a custom getter when available
    snapshot_getter = getattr(preprocessor, "get_mapping_snapshot", None)
    if callable(snapshot_getter):
        try:
            exported = snapshot_getter()
            if exported is None:
                # An explicit ``None`` from the hook is final: no fallback.
                return None
            if not isinstance(exported, Mapping):
                raise ValidationError(
                    "Preprocessor mapping snapshot must be a mapping.",
                    details={"source": "get_mapping_snapshot"},
                )
            self._validate_json_safe_mapping(exported, source="get_mapping_snapshot")
            return dict(exported)
        except ValidationError:
            raise
        except (AttributeError, TypeError, ValueError):
            self._logger.warning(
                "Preprocessor.get_mapping_snapshot failed; falling back to mapping_"
            )

    # Fall back to attribute if present
    raw_mapping = getattr(preprocessor, "mapping_", None)
    if raw_mapping is None:
        return None
    try:
        # Shallow copy to avoid exposing internal objects
        copied = dict(raw_mapping)
        self._validate_json_safe_mapping(copied, source="mapping_")
    except ValidationError:
        raise
    except (AttributeError, TypeError, ValueError):
        return None
    return copied
def import_preprocessor_mapping(self, mapping: Mapping[str, Any]) -> None:
    """Attempt to apply a mapping snapshot to the configured preprocessor.

    Best-effort helper. The mapping is first validated for JSON safety. If
    the attached preprocessor exposes a callable ``set_mapping`` it is used;
    otherwise the mapping is assigned to ``mapping_``. When neither route
    succeeds (or no preprocessor is attached) the mapping is stashed on the
    wrapper as ``_imported_preprocessor_mapping`` and a ``UserWarning`` is
    emitted so the fallback stays visible.

    Parameters
    ----------
    mapping : Mapping[str, Any]
        The mapping snapshot to apply.
    """
    self._validate_json_safe_mapping(mapping, source="import")
    preprocessor = getattr(self, "_preprocessor", None)
    was_applied = False

    if preprocessor is not None:
        apply_mapping = getattr(preprocessor, "set_mapping", None)
        if not callable(apply_mapping):
            # Try to set mapping_ directly when writable
            try:
                preprocessor.mapping_ = mapping
                was_applied = True
            except:  # noqa: E722
                if not isinstance(sys.exc_info()[1], Exception):
                    raise
                # fall through to stashing below
        else:
            try:
                apply_mapping(mapping)
                was_applied = True
            except:  # noqa: E722
                if not isinstance(sys.exc_info()[1], Exception):
                    raise
                self._logger.warning("Preprocessor.set_mapping failed; stashing mapping")

    if was_applied:
        return

    # Keep for later application or external tooling
    self._imported_preprocessor_mapping = None if mapping is None else dict(mapping)
    _warnings.warn(
        "Preprocessor mapping could not be applied directly; mapping stashed on wrapper",
        UserWarning,
        stacklevel=2,
    )
@staticmethod
def _validate_json_safe_mapping(mapping: Mapping[str, Any], *, source: str) -> None:
    """Validate that mapping snapshots are JSON-serialisable primitives.

    Parameters
    ----------
    mapping : Mapping[str, Any]
        Mapping snapshot to validate.
    source : str
        Context string used in validation error details.

    Raises
    ------
    ValidationError
        If the mapping cannot be serialised with standard JSON encoding.
    """
    try:
        json.dumps(mapping, sort_keys=True, separators=(",", ":"))
    except (TypeError, ValueError) as exc:
        raise ValidationError(
            "Preprocessor mapping must be JSON-serialisable.",
            details={"source": source, "error": str(exc)},
        ) from exc


def _state_path(self, path_or_fileobj: Any) -> Path:
    """Normalize and validate state path inputs.

    Raises
    ------
    ValidationError
        If a file-like object (anything with ``read`` or ``write``) is
        passed, or if the value cannot be converted to a ``Path``.
    """
    if hasattr(path_or_fileobj, "read") or hasattr(path_or_fileobj, "write"):
        raise ValidationError(
            "Only filesystem paths are supported for state persistence.",
            details={"path_or_fileobj_type": type(path_or_fileobj).__name__},
        )
    try:
        return Path(path_or_fileobj)
    except TypeError as exc:
        raise ValidationError(
            "Invalid state path provided to save/load_state.",
            details={"path_or_fileobj_type": type(path_or_fileobj).__name__},
        ) from exc


@staticmethod
def _sha256_bytes(payload: bytes) -> str:
    """Return SHA-256 checksum for raw bytes."""
    return hashlib.sha256(payload).hexdigest()


@staticmethod
def _sha256_file(path: Path) -> str:
    """Return SHA-256 checksum for a file, read in 64 KiB chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        for chunk in iter(lambda: handle.read(65536), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _calibrator_to_primitive(self, calibrator: Any) -> dict[str, Any]:
    """Serialize a single calibrator into the ADR-031 primitive contract.

    A calibrator exposing ``to_primitive`` that returns a mapping is used
    directly; anything else is pickled and wrapped in a ``python_pickle``
    primitive carrying the base64 payload and its SHA-256 checksum.
    """
    to_primitive = getattr(calibrator, "to_primitive", None)
    if callable(to_primitive):
        primitive = to_primitive()
        if isinstance(primitive, Mapping):
            return dict(primitive)
    payload_bytes = pickle.dumps(calibrator, protocol=pickle.HIGHEST_PROTOCOL)
    return {
        "schema_version": self._STATE_SCHEMA_VERSION,
        "calibrator_type": "python_pickle",
        "parameters": {
            "class_name": calibrator.__class__.__name__,
            "module": calibrator.__class__.__module__,
        },
        "checksums": {
            "sha256": self._sha256_bytes(payload_bytes),
        },
        "payload": {
            "pickle_b64": base64.b64encode(payload_bytes).decode("ascii"),
        },
    }


def _build_calibrator_primitive(self) -> dict[str, Any] | None:
    """Build calibrator primitive payload from the active explainer, if any.

    Returns ``None`` when there is no explainer or it has no
    ``interval_learner``. A list/tuple of calibrators becomes a
    ``fast_collection`` primitive whose checksum covers the JSON encoding of
    its children.
    """
    explainer = getattr(self, "explainer", None)
    if explainer is None:
        return None
    calibrator = getattr(explainer, "interval_learner", None)
    if calibrator is None:
        return None
    if isinstance(calibrator, (list, tuple)):
        children = [self._calibrator_to_primitive(item) for item in calibrator]
        payload_bytes = json.dumps(children, sort_keys=True).encode("utf-8")
        return {
            "schema_version": self._STATE_SCHEMA_VERSION,
            "calibrator_type": "fast_collection",
            "parameters": {"size": len(children)},
            "checksums": {"sha256": self._sha256_bytes(payload_bytes)},
            "calibrators": children,
        }
    return self._calibrator_to_primitive(calibrator)


@classmethod
def _restore_calibrator_from_primitive(cls, primitive: Mapping[str, Any]) -> Any:
    """Rehydrate a calibrator object from a persisted primitive payload.

    Raises
    ------
    IncompatibleStateError
        If the primitive is not a mapping, has an unsupported schema version
        or calibrator type, is structurally malformed, carries an
        undecodable payload, or fails checksum validation.
    """
    if not isinstance(primitive, Mapping):
        # Malformed JSON (e.g. a list child that is not an object) must
        # surface as a state error, not an AttributeError on ``.get``.
        raise IncompatibleStateError(
            "Invalid calibrator primitive: expected a mapping.",
            details={"primitive_type": type(primitive).__name__},
        )
    schema_version = primitive.get("schema_version")
    if schema_version != cls._STATE_SCHEMA_VERSION:
        raise IncompatibleStateError(
            "Unsupported calibrator primitive schema_version.",
            details={
                "schema_version": schema_version,
                "supported_versions": [cls._STATE_SCHEMA_VERSION],
            },
        )
    calibrator_type = primitive.get("calibrator_type")
    # A non-mapping ``checksums`` field is treated as "no checksum", which
    # then fails validation below instead of raising AttributeError.
    checksums = primitive.get("checksums")
    expected_sha = checksums.get("sha256") if isinstance(checksums, Mapping) else None

    if calibrator_type == "venn_abers":
        from ..calibration.venn_abers import VennAbers

        return VennAbers.from_primitive(primitive)
    if calibrator_type == "interval_regressor":
        from ..calibration.interval_regressor import IntervalRegressor

        return IntervalRegressor.from_primitive(primitive)
    if calibrator_type == "fast_collection":
        children = primitive.get("calibrators")
        if not isinstance(children, list):
            raise IncompatibleStateError(
                "Invalid fast_collection primitive: expected calibrators list.",
                details={"field": "calibrators"},
            )
        child_bytes = json.dumps(children, sort_keys=True).encode("utf-8")
        actual_sha = cls._sha256_bytes(child_bytes)
        if not isinstance(expected_sha, str) or expected_sha != actual_sha:
            raise IncompatibleStateError(
                "Calibrator primitive checksum validation failed.",
                details={"expected_sha256": expected_sha, "actual_sha256": actual_sha},
            )
        return [cls._restore_calibrator_from_primitive(item) for item in children]
    if calibrator_type == "python_pickle":
        payload = primitive.get("payload")
        if not isinstance(payload, Mapping) or not isinstance(payload.get("pickle_b64"), str):
            raise IncompatibleStateError(
                "Invalid python_pickle primitive payload.",
                details={"field": "payload.pickle_b64"},
            )
        try:
            raw = base64.b64decode(payload["pickle_b64"].encode("ascii"))
        except ValueError as exc:
            # binascii.Error and UnicodeEncodeError are both ValueError
            # subclasses; report corrupt payloads as incompatible state.
            raise IncompatibleStateError(
                "Invalid python_pickle primitive payload.",
                details={"field": "payload.pickle_b64", "error": str(exc)},
            ) from exc
        actual_sha = cls._sha256_bytes(raw)
        if not isinstance(expected_sha, str) or expected_sha != actual_sha:
            raise IncompatibleStateError(
                "Calibrator primitive checksum validation failed.",
                details={"expected_sha256": expected_sha, "actual_sha256": actual_sha},
            )
        # NOTE(review): the checksum is stored next to the payload, so it
        # detects corruption but not tampering; only load trusted artifacts.
        return pickle.loads(raw)  # noqa: S301  # nosec B301 - trusted, checksum-validated payload
    raise IncompatibleStateError(
        "Unsupported calibrator_type in persisted state.",
        details={"calibrator_type": calibrator_type},
    )


def _build_explainer_config_payload(self) -> dict[str, Any]:
    """Build JSON-safe explainer configuration metadata for persistence.

    Returns an empty dict when no explainer is attached.
    """
    payload: dict[str, Any] = {}
    explainer = getattr(self, "explainer", None)
    if explainer is not None:
        payload["mode"] = getattr(explainer, "mode", None)
        payload["seed"] = getattr(explainer, "seed", None)
        payload["condition_source"] = getattr(explainer, "condition_source", None)
        payload["interval_summary"] = str(getattr(explainer, "interval_summary", ""))
        payload["preprocessor_metadata"] = self._serialise_preprocessor_value(
            getattr(explainer, "_preprocessor_metadata", None)
        )
        plugin_manager = getattr(explainer, "_plugin_manager", None)
        if plugin_manager is not None:
            payload["plugin_overrides"] = self._serialise_preprocessor_value(
                getattr(plugin_manager, "plugin_overrides", None)
            )
    return payload
def save_state(self, path_or_fileobj: Any) -> Path:
    """Persist wrapper state using an ADR-031 manifest + checksums.

    The artifact is a directory holding ``wrapper.pkl``,
    ``explainer_config.json``, optional ``calibrator_primitive.json`` and
    ``preprocessing_mapping.json``, and a ``manifest.json`` that lists a
    SHA-256 checksum per file. Everything is written to a temporary sibling
    directory first and then moved into place; a pre-existing target is
    kept as a backup until the move has succeeded.

    Parameters
    ----------
    path_or_fileobj : Any
        Filesystem path of the target directory (file objects are rejected).

    Returns
    -------
    Path
        The target path.

    Raises
    ------
    ValidationError
        If the path is invalid or writing/serialising the state fails.
    """
    target = self._state_path(path_or_fileobj)
    target_parent = target.parent
    target_parent.mkdir(parents=True, exist_ok=True)
    temp_dir_name = f"{target.name}.tmp-{os.getpid()}-{id(self)}"
    temp_dir = Path(tempfile.mkdtemp(prefix=temp_dir_name, dir=str(target_parent)))
    checksums: dict[str, str] = {}
    try:
        wrapper_bytes = pickle.dumps(self, protocol=pickle.HIGHEST_PROTOCOL)
        wrapper_file = temp_dir / "wrapper.pkl"
        wrapper_file.write_bytes(wrapper_bytes)
        checksums["wrapper.pkl"] = self._sha256_bytes(wrapper_bytes)

        calibrator_primitive = self._build_calibrator_primitive()
        if calibrator_primitive is not None:
            calibrator_file = temp_dir / "calibrator_primitive.json"
            calibrator_bytes = json.dumps(
                calibrator_primitive, indent=2, sort_keys=True
            ).encode("utf-8")
            calibrator_file.write_bytes(calibrator_bytes)
            checksums["calibrator_primitive.json"] = self._sha256_bytes(calibrator_bytes)

        mapping = self.export_preprocessor_mapping()
        if mapping is not None:
            mapping_file = temp_dir / "preprocessing_mapping.json"
            mapping_bytes = json.dumps(mapping, indent=2, sort_keys=True).encode("utf-8")
            mapping_file.write_bytes(mapping_bytes)
            checksums["preprocessing_mapping.json"] = self._sha256_bytes(mapping_bytes)

        config_payload = self._build_explainer_config_payload()
        config_file = temp_dir / "explainer_config.json"
        config_bytes = json.dumps(config_payload, indent=2, sort_keys=True).encode("utf-8")
        config_file.write_bytes(config_bytes)
        checksums["explainer_config.json"] = self._sha256_bytes(config_bytes)

        manifest = {
            "schema_version": self._STATE_SCHEMA_VERSION,
            "created_at_utc": datetime.now(timezone.utc).isoformat(),
            "artifact_type": "wrap_calibrated_explainer_state",
            "files": checksums,
        }
        manifest_file = temp_dir / "manifest.json"
        manifest_file.write_text(
            json.dumps(manifest, indent=2, sort_keys=True), encoding="utf-8"
        )

        backup: Path | None = None
        if target.exists():
            backup = target.with_name(f"{target.name}.bak-{os.getpid()}-{id(self)}")
            os.replace(target, backup)
        try:
            replaced = False
            last_permission_error: PermissionError | None = None
            # Retry a few times: the rename can fail transiently with
            # PermissionError, after which shutil.move is used as a fallback.
            for _ in range(3):
                try:
                    os.replace(temp_dir, target)
                    replaced = True
                    break
                except PermissionError as exc:
                    last_permission_error = exc
                    sleep(0.05)
            if not replaced:
                if last_permission_error is not None:
                    self._logger.debug(
                        "os.replace failed during save_state; falling back to shutil.move: %s",
                        last_permission_error,
                    )
                shutil.move(str(temp_dir), str(target))
        except OSError:
            # Roll the previous artifact back into place before reporting.
            if backup is not None and backup.exists() and not target.exists():
                os.replace(backup, target)
            raise

        if backup is not None and backup.exists():
            # The new state is already in place, so backup cleanup is best
            # effort and must not turn a successful save into an error. The
            # previous target may have been a regular file, which rmtree
            # cannot remove.
            if backup.is_dir() and not backup.is_symlink():
                shutil.rmtree(backup, ignore_errors=True)
            else:
                with suppress(OSError):
                    backup.unlink()
            if backup.exists():
                self._logger.warning(
                    "save_state could not remove backup artifact: %s", backup
                )
        return target
    except (OSError, TypeError, ValueError, AttributeError, pickle.PicklingError) as exc:
        # pickle.PicklingError is not a subclass of the other types; without
        # it an unpicklable wrapper would leak the temporary directory.
        if temp_dir.exists():
            shutil.rmtree(temp_dir, ignore_errors=True)
        raise ValidationError(
            f"Failed to save state to '{target}'.",
            details={"path": str(target), "reason": str(exc)},
        ) from exc
@classmethod
def load_state(cls, path_or_fileobj: Any) -> WrapCalibratedExplainer:
    """Load wrapper state from an ADR-031 persisted artifact.

    Every file listed in ``manifest.json`` is checksum-verified before
    anything is deserialised. Only files that have a manifest entry are
    loaded, so nothing is unpickled or parsed without verification.

    Parameters
    ----------
    path_or_fileobj : Any
        Filesystem path of the artifact directory.

    Returns
    -------
    WrapCalibratedExplainer
        The restored wrapper.

    Raises
    ------
    ValidationError
        If ``path_or_fileobj`` is not a usable filesystem path.
    IncompatibleStateError
        If the manifest is missing, malformed or of an unsupported schema
        version, a listed file is missing or fails its checksum, the wrapper
        payload has no checksum entry, or the restored object has an
        unexpected type.
    """
    temp_instance = cls.__new__(cls)
    path = temp_instance._state_path(path_or_fileobj)
    manifest_path = path / "manifest.json"
    if not manifest_path.exists():
        raise IncompatibleStateError(
            "State artifact is missing manifest.json.",
            details={"path": str(path)},
        )

    try:
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except ValueError as exc:
        # JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        raise IncompatibleStateError(
            "State manifest is not valid JSON.",
            details={"path": str(manifest_path), "error": str(exc)},
        ) from exc
    if not isinstance(manifest, Mapping):
        raise IncompatibleStateError(
            "Invalid state manifest: expected a JSON object.",
            details={"manifest_type": type(manifest).__name__},
        )

    schema_version = manifest.get("schema_version")
    if schema_version != cls._STATE_SCHEMA_VERSION:
        raise IncompatibleStateError(
            "Unsupported state schema_version.",
            details={
                "schema_version": schema_version,
                "supported_versions": [cls._STATE_SCHEMA_VERSION],
            },
        )

    files = manifest.get("files")
    if not isinstance(files, Mapping):
        raise IncompatibleStateError(
            "Invalid state manifest: files checksum mapping missing.",
            details={"field": "files"},
        )

    for file_name, expected_sha in files.items():
        if not isinstance(file_name, str) or not isinstance(expected_sha, str):
            raise IncompatibleStateError(
                "Invalid state manifest: malformed checksum entry.",
                details={"file": file_name, "checksum": expected_sha},
            )
        file_path = path / file_name
        if not file_path.exists():
            raise IncompatibleStateError(
                "State artifact is incomplete: expected file is missing.",
                details={"file": file_name},
            )
        actual_sha = cls._sha256_file(file_path)
        if actual_sha != expected_sha:
            raise IncompatibleStateError(
                "State checksum validation failed.",
                details={
                    "file": file_name,
                    "expected_sha256": expected_sha,
                    "actual_sha256": actual_sha,
                },
            )

    # The wrapper pickle must have been covered by the verification loop;
    # otherwise an unverified payload would be passed to pickle.loads.
    if "wrapper.pkl" not in files:
        raise IncompatibleStateError(
            "State artifact is incomplete: wrapper.pkl has no checksum entry.",
            details={"file": "wrapper.pkl"},
        )

    wrapper_bytes = (path / "wrapper.pkl").read_bytes()
    # NOTE(review): the manifest lives beside the payload, so the checksum
    # guards against corruption, not tampering; only load trusted artifacts.
    wrapper = pickle.loads(wrapper_bytes)  # noqa: S301  # nosec B301 - trusted, checksum-validated payload
    if not isinstance(wrapper, cls):
        raise IncompatibleStateError(
            "Persisted wrapper payload restored unexpected object type.",
            details={"restored_type": type(wrapper).__name__},
        )

    # Optional artefacts are only honoured when the manifest vouches for them.
    primitive_path = path / "calibrator_primitive.json"
    if "calibrator_primitive.json" in files and primitive_path.exists():
        primitive = json.loads(primitive_path.read_text(encoding="utf-8"))
        restored = cls._restore_calibrator_from_primitive(primitive)
        if getattr(wrapper, "explainer", None) is not None:
            wrapper.explainer.interval_learner = restored

    mapping_path = path / "preprocessing_mapping.json"
    if "preprocessing_mapping.json" in files and mapping_path.exists():
        mapping_payload = json.loads(mapping_path.read_text(encoding="utf-8"))
        if isinstance(mapping_payload, Mapping):
            wrapper.import_preprocessor_mapping(mapping_payload)
    return wrapper
@property
def pre_fitted(self) -> bool:
    """Report whether the wrapper's preprocessor has been fitted.

    Returns
    -------
    bool
        The current value of the internal ``_pre_fitted`` flag.
    """
    is_fitted = self._pre_fitted
    return is_fitted
def finalize_fit(self, reinitialize: bool) -> WrapCalibratedExplainer:
    """Public alias of the shared fit-finalisation helper.

    Parameters
    ----------
    reinitialize : bool
        Whether to reinitialize.

    Returns
    -------
    WrapCalibratedExplainer
        Whatever ``_finalize_fit`` returns for ``reinitialize``.
    """
    finalized = self._finalize_fit(reinitialize)
    return finalized
def format_proba_output(self, proba: Any, uq_interval: bool) -> Any:
    """Public alias of the probability-output formatter.

    Parameters
    ----------
    proba : Any
        The probability values.
    uq_interval : bool
        Whether to include uncertainty interval.

    Returns
    -------
    Any
        Whatever ``_format_proba_output`` returns for the given arguments.
    """
    formatted = self._format_proba_output(proba, uq_interval)
    return formatted
def normalize_auto_encode_flag(self, auto_encode: Any = None) -> str:
    """Normalize the auto encode flag.

    Parameters
    ----------
    auto_encode : Any, optional
        Accepted for backward compatibility and ignored; the value stored on
        the wrapper is what gets normalised.

    Returns
    -------
    str
        The normalized flag as produced by ``_normalize_auto_encode_flag``
        (a string literal, not a ``bool``).
    """
    # Public adapter: legacy callers may pass no argument. The
    # internal helper reads `self._auto_encode` so ignore any
    # provided value and delegate to the internal normaliser.
    return self._normalize_auto_encode_flag()
def normalize_public_kwargs(self, payload: Any = None, **kwargs: Any) -> Dict[str, Any]:
    """Normalize public keyword arguments.

    Parameters
    ----------
    payload : Any, optional
        The kwargs mapping to normalise. May also be supplied by keyword as
        ``kwargs=...``.
    **kwargs : Any
        Additional keyword arguments forwarded to the internal helper
        (for example ``allowed``).

    Returns
    -------
    Dict[str, Any]
        The normalized kwargs; an empty dict when no payload is given.
    """
    # Accept either positional (payload, allowed=...) or keyword-only usage.
    # The internal helper requires the payload positionally, so resolve it
    # here; calling without any payload used to raise TypeError.
    if payload is None:
        payload = kwargs.pop("kwargs", None)
    return self._normalize_public_kwargs(payload, **kwargs)
@property
def cfg(self) -> Any:
    """Configuration property.

    Returns
    -------
    Any
        The configuration stored as ``_cfg``, or ``None`` when the wrapper
        was created without one (the attribute is then never set, and other
        readers such as ``plot`` already treat it as optional).
    """
    return getattr(self, "_cfg", None)


def __getstate__(self):
    """Get state for pickling.

    Returns
    -------
    dict
        A copy of the instance ``__dict__`` with ``mc`` cleared and any
        ``mappingproxy`` values converted to plain dicts.
    """
    state = self.__dict__.copy()
    # Exclude mc as it may contain unpicklable objects like RNG in mappingproxy
    state["mc"] = None

    # Convert any types.MappingProxyType (mappingproxy) instances to plain
    # dicts recursively so pickle/joblib can serialize them.
    def _convert(obj: Any) -> Any:
        if isinstance(obj, MappingProxyType):
            # Recursively convert mappingproxy to plain dict and convert
            # nested values as well.
            return _convert(dict(obj))
        if isinstance(obj, dict):
            return {k: _convert(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple, set)):
            cls = type(obj)
            converted = [_convert(v) for v in obj]
            return cls(converted)
        return obj

    for k, v in list(state.items()):
        try:
            state[k] = _convert(v)
        except (TypeError, AttributeError, RecursionError) as exc:
            # Defensive: if conversion fails due to type/attribute/recursion
            # issues, leave original value and hope it's picklable; avoid
            # failing during state build. Suppress the same specific
            # exceptions when logging to satisfy ADR-002.
            with suppress((TypeError, AttributeError, RecursionError)):
                self._logger.debug("__getstate__ conversion skipped for %s: %s", k, exc)
            continue
    return state


def __setstate__(self, state):
    """Set state for unpickling.

    Parameters
    ----------
    state : dict
        The state dictionary; merged into the instance ``__dict__``.
    """
    self.__dict__.update(state)
# Names exported by ``from <module> import *``: only the wrapper class.
__all__ = ["WrapCalibratedExplainer"]