"""High-level wrapper for building, calibrating and explaining models.
This module provides :class:`WrapCalibratedExplainer`, a convenience wrapper
that mirrors :class:`.CalibratedExplainer` while exposing a scikit-learn
style fit/calibrate/explain surface for downstream users and integrations.
"""
# pylint: disable=unknown-option-value
# pylint: disable=invalid-name, line-too-long, too-many-lines, too-many-positional-arguments, too-many-public-methods
from __future__ import annotations
import base64
import hashlib
import json
import logging as _logging
import os
import pickle # nosec B403 - deserialization is restricted to trusted, checksum-validated state
import shutil
import sys
import tempfile
import warnings as _warnings
from contextlib import suppress
from datetime import datetime, timezone
from pathlib import Path
from time import sleep
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Callable, Dict, Mapping
from crepes.extras import MondrianCategorizer
from ..api.params import (
reject_removed_aliases,
validate_param_combination,
)
from ..utils import check_is_fitted, deprecate, safe_isinstance # noqa: F401
from ..utils.exceptions import (
DataShapeError,
IncompatibleStateError,
NotFittedError,
ValidationError,
)
from .calibrated_explainer import CalibratedExplainer # circular during split
from .validation import validate_inputs_matrix, validate_model
if TYPE_CHECKING: # pragma: no cover - import only for type checking
from calibrated_explanations.api.config import ExplainerConfig
[docs]
class WrapCalibratedExplainer:
"""Provide a high-level fit/calibrate/explain workflow for learners.
The wrapper mirrors :class:`CalibratedExplainer` while orchestrating
fitting, calibration, and explanation steps behind a scikit-learn style
interface.
Attributes
----------
learner : Any
The underlying predictive learner instance.
explainer : CalibratedExplainer | None
The calibrated explainer created during :meth:`calibrate`.
calibrated : bool
True when the wrapper has been calibrated.
"""
learner: Any
explainer: CalibratedExplainer | None
calibrated: bool
mc: Callable[[Any], Any] | MondrianCategorizer | None
_logger: _logging.Logger
_STATE_SCHEMA_VERSION: int = 1
def __init__(self, learner: Any):
    """Set up the wrapper around ``learner``.

    Parameters
    ----------
    learner : Any
        Either a predictive learner, or an object that ``safe_isinstance``
        recognises as a ``CalibratedExplainer``. In the latter case the
        explainer's own ``learner`` is adopted and the wrapper starts out
        both fitted and calibrated.
    """
    # Defaults shared by both construction paths.
    self.mc: Callable[[Any], Any] | MondrianCategorizer | None = None
    self._logger: _logging.Logger = _logging.getLogger(__name__)
    # Optional preprocessing state.
    self._preprocessor: Any | None = None
    self._pre_fitted: bool = False
    self._auto_encode: bool | str = "auto"
    self._unseen_category_policy: str = "error"

    if not safe_isinstance(learner, "calibrated_explanations.core.CalibratedExplainer"):
        # Plain learner: nothing is calibrated yet; probe whether it is fitted.
        self.learner = learner
        self.explainer = None
        self.calibrated = False
        self.fitted = False
        with suppress(TypeError, RuntimeError, NotFittedError):
            check_is_fitted(learner)
            # Only reached when the probe above did not raise.
            self.fitted = True
        return

    # Wrapping an existing explainer: reuse its learner; the fitted check is
    # NOT suppressed here, so an unfitted learner propagates the error.
    self.learner = learner.learner
    check_is_fitted(self.learner)
    self.fitted = True
    self.explainer = learner
    self.calibrated = True
    self._logger.info("Initialized from existing CalibratedExplainer (already fitted & calibrated)")
def __repr__(self) -> str:
"""Return the string representation of the WrapCalibratedExplainer."""
if self.fitted:
if self.calibrated:
return (
f"WrapCalibratedExplainer(learner={self.learner}, fitted=True, "
f"calibrated=True, \n\t\texplainer={self.explainer})"
)
return f"WrapCalibratedExplainer(learner={self.learner}, fitted=True, calibrated=False)"
return f"WrapCalibratedExplainer(learner={self.learner}, fitted=False, calibrated=False)"
@property
def parallel_executor(self) -> Any:
"""Expose the internal parallel executor if available."""
return getattr(self, "_perf_parallel", None)
@parallel_executor.setter
def parallel_executor(self, value: Any) -> None:
"""Allow setting the internal parallel executor."""
self._perf_parallel = value
@property
def auto_encode(self) -> bool | str:
"""Get the auto_encode configuration."""
return self._auto_encode
@auto_encode.setter
def auto_encode(self, value: bool | str) -> None:
"""Set the auto_encode configuration."""
self._auto_encode = value
@property
def preprocessor(self) -> Any:
"""Get the preprocessor."""
return self._preprocessor
@preprocessor.setter
def preprocessor(self, value: Any) -> None:
"""Set the preprocessor."""
self._preprocessor = value
# internal wiring for config
[docs]
@classmethod
def from_config(cls, cfg: ExplainerConfig) -> WrapCalibratedExplainer:
    """Construct a wrapper from an :class:`ExplainerConfig`.

    The wrapper is built from ``cfg.model`` and the config is stashed on the
    instance as ``_cfg``. Three best-effort wiring steps follow; a failure in
    any of them is logged and leaves the wrapper usable:

    1. perf primitives (``perf_cache``, ``_perf_parallel``, ``perf_parallel``),
    2. the feature-filter configuration (``_feature_filter_config``),
    3. preprocessing settings (``_preprocessor``, ``_auto_encode``,
       ``_unseen_category_policy``).

    Parameters
    ----------
    cfg : ExplainerConfig
        Configuration object; only ``cfg.model`` is strictly required.

    Returns
    -------
    WrapCalibratedExplainer
        The newly created wrapper.
    """
    w = cls(cfg.model)
    # Stash config on the instance for later optional use (private attr)
    w._cfg = cfg  # type: ignore[attr-defined]

    # --- 1. perf primitives -------------------------------------------------
    try:
        perf_factory = getattr(cfg, "_perf_factory", None)
        if perf_factory is None:
            # lazy import to avoid import cycles
            from calibrated_explanations.api.config import _build_perf_factory

            perf_factory = _build_perf_factory(cfg)
        if perf_factory is not None:
            cache = perf_factory.make_cache()
            w.perf_cache = cache  # type: ignore[attr-defined]
            w._perf_parallel = perf_factory.make_parallel_executor(cache)  # type: ignore[attr-defined]
        else:
            w.perf_cache = None
            w._perf_parallel = None
        # Public-facing alias; always mirrors the private attribute.
        w.perf_parallel = w._perf_parallel  # type: ignore[attr-defined]
    except Exception as exc:
        # Only ``Exception`` is handled: KeyboardInterrupt/SystemExit propagate.
        w.perf_cache = None
        w._perf_parallel = None
        # Keep the public alias defined on the failure path as well, so that
        # attribute access behaves the same regardless of how wiring ended.
        w.perf_parallel = None  # type: ignore[attr-defined]
        w._logger.debug("Failed to initialize perf primitives from config: %s", exc)

    # --- 2. feature-filter configuration -----------------------------------
    enabled = bool(getattr(cfg, "perf_feature_filter_enabled", False))
    per_instance_top_k = max(1, int(getattr(cfg, "perf_feature_filter_per_instance_top_k", 8)))
    try:
        from .explain._feature_filter import (  # pylint: disable=import-outside-toplevel
            FeatureFilterConfig,
        )

        w._feature_filter_config = FeatureFilterConfig(  # type: ignore[attr-defined]
            enabled=enabled,
            per_instance_top_k=per_instance_top_k,
        )
    except Exception:
        # Best-effort fallback when the internal helper cannot be imported or
        # constructed: expose a lightweight object with the same attributes.
        # Narrowed from a bare ``except`` so interpreter-exit signals are not
        # swallowed.
        from types import SimpleNamespace

        w._feature_filter_config = SimpleNamespace(
            enabled=enabled,
            per_instance_top_k=per_instance_top_k,
            strict_observability=False,
        )
        _logging.getLogger(__name__).debug("Using fallback feature_filter_config")

    # --- 3. preprocessing settings ------------------------------------------
    try:
        w._preprocessor = cfg.preprocessor  # type: ignore[attr-defined]
        w._auto_encode = cfg.auto_encode  # type: ignore[attr-defined]
        w._unseen_category_policy = cfg.unseen_category_policy  # type: ignore[attr-defined]
    except Exception as exc:
        _logging.getLogger(__name__).warning(
            "Failed to transfer preprocessing config to wrapper: %s", exc
        )
    return w
[docs]
def fit(
    self, x_proper_train: Any, y_proper_train: Any, **kwargs: Any
) -> WrapCalibratedExplainer:
    """Train the wrapped learner.

    Both ``fitted`` and ``calibrated`` are reset before training. When a
    preprocessor is attached, the training features are first passed through
    ``_pre_fit_preprocess``. Post-fit bookkeeping is delegated to
    ``_finalize_fit``, which is told whether the wrapper was calibrated
    before this call.

    Parameters
    ----------
    x_proper_train : array-like
        Training input samples.
    y_proper_train : array-like
        Training target values.
    **kwargs
        Extra keyword arguments forwarded to the learner's ``fit``.

    Returns
    -------
    WrapCalibratedExplainer
        Whatever ``_finalize_fit`` returns (annotated as the wrapper itself).

    Examples
    --------
    >>> w = WrapCalibratedExplainer(clf)
    >>> w.fit(X_train, y_train)
    WrapCalibratedExplainer(...)
    """
    was_calibrated = bool(self.calibrated)
    self.fitted = False
    self.calibrated = False

    if self._preprocessor is None:
        features = x_proper_train
    else:
        features = self._pre_fit_preprocess(x_proper_train)

    self._logger.info("Fitting underlying learner: %s", type(self.learner).__name__)
    self.learner.fit(features, y_proper_train, **kwargs)
    return self._finalize_fit(was_calibrated)
[docs]
def calibrate(
    self,
    x_calibration: Any,
    y_calibration: Any,
    mc: Callable[[Any], Any] | MondrianCategorizer | None = None,
    **kwargs: Any,
) -> WrapCalibratedExplainer:
    """Build a :class:`.CalibratedExplainer` from calibration data.

    Parameters
    ----------
    x_calibration : array-like
        Calibration features.
    y_calibration : array-like
        Calibration targets matching ``x_calibration``.
    mc : callable or MondrianCategorizer, optional
        When given, replaces the wrapper's stored Mondrian categorizer. When
        ``None`` a previously stored categorizer is kept.
    **kwargs
        Forwarded to :class:`.CalibratedExplainer`. ``bins`` is always
        overwritten with the result of ``_get_bins``; ``preprocessor_metadata``
        is filled in only when absent.

    Returns
    -------
    WrapCalibratedExplainer
        ``self``, with ``explainer`` set and ``calibrated`` True.

    Raises
    ------
    NotFittedError
        If the wrapper is not marked as fitted.

    Notes
    -----
    Without an explicit ``mode`` in ``kwargs``, ``"classification"`` is used
    when ``predict_proba`` appears in ``dir(learner)``, otherwise
    ``"regression"``.

    Examples
    --------
    >>> w = WrapCalibratedExplainer(clf)
    >>> w.fit(X_train, y_train)
    >>> w.calibrate(X_cal, y_cal)
    """
    self._assert_fitted("The WrapCalibratedExplainer must be fitted before calibration.")
    self.calibrated = False
    if mc is not None:
        self.mc = mc

    # Validate the public options and the learner before touching the data.
    options = self._normalize_public_kwargs(kwargs)
    validate_param_combination(options)
    validate_model(self.learner)
    metadata = self._build_preprocessor_metadata()

    # Bring calibration features into the learner's input space.
    x_cal_local = x_calibration
    if self._preprocessor is not None:
        if self._pre_fitted:
            x_cal_local = self._pre_transform(x_cal_local, stage="calibrate")
        else:
            self._logger.info("Fitting preprocessor on calibration data")
            x_cal_local = self._pre_fit_preprocess(x_cal_local)
        # Extra transform of the raw input whose result is discarded; any
        # failure is deliberately ignored (kept for persistence accounting).
        with suppress(Exception):  # pragma: no cover - defensive
            _ = self._pre_transform(x_calibration, stage="calibrate_check")

    validate_inputs_matrix(x_cal_local, y_calibration, require_y=True, allow_nan=False)
    options["bins"] = self._get_bins(x_cal_local, **options)
    if metadata is not None:
        options.setdefault("preprocessor_metadata", metadata)
    self._logger.info("Calibrating with %s samples", getattr(x_calibration, "shape", ["?"])[0])

    # Infer the task type only when the caller did not choose one.
    if "mode" not in options:
        if "predict_proba" in dir(self.learner):
            options["mode"] = "classification"
        else:
            options["mode"] = "regression"

    self.explainer = CalibratedExplainer(
        self.learner,
        x_cal_local,
        y_calibration,
        perf_cache=getattr(self, "perf_cache", None),
        perf_parallel=getattr(self, "_perf_parallel", None),
        **options,
    )

    # Hand the wrapper-level feature filter config to the new explainer.
    if self.explainer is not None and hasattr(self, "_feature_filter_config"):
        self.explainer.feature_filter_config = self._feature_filter_config
    self.calibrated = True
    if metadata is not None and self.explainer is not None:
        # Explainers lacking the setter are tolerated.
        with suppress(AttributeError):
            self.explainer.set_preprocessor_metadata(metadata)
    return self
@property
def feature_filter_config(self) -> Any:
"""Expose the feature-filter configuration if available.
Tests and plugins may access this property on the wrapper; prefer
the internally-stored config, otherwise delegate to the explainer.
"""
if hasattr(self, "_feature_filter_config"):
return self._feature_filter_config
if self.explainer is not None:
return getattr(self.explainer, "feature_filter_config", None)
return None
[docs]
def explain_factual(self, x: Any, **kwargs: Any) -> Any:
    """Generate factual explanations for the provided instances.

    Parameters
    ----------
    x : array-like
        Instances to explain.
    **kwargs
        Forwarded to :meth:`CalibratedExplainer.explain_factual`. ``bins`` is
        always overwritten with the result of ``_get_bins``.

    Returns
    -------
    Any
        Whatever the underlying explainer's ``explain_factual`` returns.

    Raises
    ------
    NotFittedError
        If the wrapper is not fitted or not calibrated.

    See Also
    --------
    :meth:`CalibratedExplainer.explain_factual`
        For full parameter and return semantics.
    """
    # Run the state checks as plain statements: inside an ``assert`` they
    # would be stripped under ``python -O`` and never raise NotFittedError.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before explaining."
    )._assert_calibrated("The WrapCalibratedExplainer must be calibrated before explaining.")
    assert self.explainer is not None  # type narrowing only
    # Optional preprocessing
    x_local = self._maybe_preprocess_for_inference(x)
    kwargs = self._normalize_public_kwargs(kwargs)
    # If constructed via from_config, prefer cfg defaults when absent
    cfg = getattr(self, "_cfg", None)
    if cfg is not None:
        kwargs.setdefault("threshold", cfg.threshold)
        kwargs.setdefault("low_high_percentiles", cfg.low_high_percentiles)
    validate_inputs_matrix(x_local, allow_nan=True)
    validate_param_combination(kwargs)
    kwargs["bins"] = self._get_bins(x_local, **kwargs)
    return self.explainer.explain_factual(x_local, **kwargs)
[docs]
def explore_alternatives(self, x: Any, **kwargs: Any) -> Any:
    """Generate alternative explanations for the provided instances.

    Parameters
    ----------
    x : array-like
        Instances to explain.
    **kwargs
        Forwarded to :meth:`.CalibratedExplainer.explore_alternatives`.
        ``bins`` is always overwritten with the result of ``_get_bins``.

    Raises
    ------
    NotFittedError
        If the wrapper is not fitted or not calibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.explore_alternatives` : Refer to the docstring for explore_alternatives in CalibratedExplainer for more details.
    """
    # Run the state checks as plain statements: inside an ``assert`` they
    # would be stripped under ``python -O`` and never raise NotFittedError.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before explaining."
    )._assert_calibrated("The WrapCalibratedExplainer must be calibrated before explaining.")
    assert self.explainer is not None  # type narrowing only
    x_local = self._maybe_preprocess_for_inference(x)
    kwargs = self._normalize_public_kwargs(kwargs)
    # Config-supplied defaults apply only when the caller gave no value.
    cfg = getattr(self, "_cfg", None)
    if cfg is not None:
        kwargs.setdefault("threshold", cfg.threshold)
        kwargs.setdefault("low_high_percentiles", cfg.low_high_percentiles)
    validate_inputs_matrix(x_local, allow_nan=True)
    validate_param_combination(kwargs)
    kwargs["bins"] = self._get_bins(x_local, **kwargs)
    return self.explainer.explore_alternatives(x_local, **kwargs)
[docs]
def explain_guarded_factual(self, x: Any, **kwargs: Any) -> Any:
    """Produce guarded factual explanations via the underlying explainer.

    The wrapper must be fitted and calibrated. Input is preprocessed for
    inference, config defaults (when a config is attached) fill in missing
    ``threshold`` / ``low_high_percentiles``, and ``bins`` is replaced by the
    result of ``_get_bins`` before delegating.

    See Also
    --------
    :meth:`.CalibratedExplainer.explain_guarded_factual` : Refer to the docstring for full parameter documentation.
    """
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before explaining."
    )
    self._assert_calibrated("The WrapCalibratedExplainer must be calibrated before explaining.")

    x_local = self._maybe_preprocess_for_inference(x)
    options = self._normalize_public_kwargs(kwargs)
    cfg = getattr(self, "_cfg", None)
    if cfg is not None:
        # Caller-provided values take precedence over config defaults.
        for key in ("threshold", "low_high_percentiles"):
            options.setdefault(key, getattr(cfg, key))
    validate_inputs_matrix(x_local, allow_nan=True)
    validate_param_combination(options)
    options["bins"] = self._get_bins(x_local, **options)
    return self.explainer.explain_guarded_factual(x_local, **options)
[docs]
def explore_guarded_alternatives(self, x: Any, **kwargs: Any) -> Any:
    """Produce guarded alternative explanations via the underlying explainer.

    The wrapper must be fitted and calibrated. Input is preprocessed for
    inference, config defaults (when a config is attached) fill in missing
    ``threshold`` / ``low_high_percentiles``, and ``bins`` is replaced by the
    result of ``_get_bins`` before delegating.

    See Also
    --------
    :meth:`.CalibratedExplainer.explore_guarded_alternatives` : Refer to the docstring for full parameter documentation.
    """
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before explaining."
    )
    self._assert_calibrated("The WrapCalibratedExplainer must be calibrated before explaining.")

    x_local = self._maybe_preprocess_for_inference(x)
    options = self._normalize_public_kwargs(kwargs)
    cfg = getattr(self, "_cfg", None)
    if cfg is not None:
        # Caller-provided values take precedence over config defaults.
        for key in ("threshold", "low_high_percentiles"):
            options.setdefault(key, getattr(cfg, key))
    validate_inputs_matrix(x_local, allow_nan=True)
    validate_param_combination(options)
    options["bins"] = self._get_bins(x_local, **options)
    return self.explainer.explore_guarded_alternatives(x_local, **options)
[docs]
def explain_fast(self, x: Any, **kwargs: Any) -> Any:
    """Generate fast explanations for the provided instances.

    Parameters
    ----------
    x : array-like
        Instances to explain.
    **kwargs
        Forwarded to :meth:`.CalibratedExplainer.explain_fast`. ``bins`` is
        always overwritten with the result of ``_get_bins``.

    Raises
    ------
    NotFittedError
        If the wrapper is not fitted or not calibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.explain_fast` : Refer to the docstring for explain_fast in CalibratedExplainer for more details.
    """
    # Run the state checks as plain statements: inside an ``assert`` they
    # would be stripped under ``python -O`` and never raise NotFittedError.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before explaining."
    )._assert_calibrated("The WrapCalibratedExplainer must be calibrated before explaining.")
    assert self.explainer is not None  # type narrowing only
    x_local = self._maybe_preprocess_for_inference(x)
    kwargs = self._normalize_public_kwargs(kwargs)
    # Apply config defaults when available and not explicitly provided
    cfg = getattr(self, "_cfg", None)
    if cfg is not None:
        kwargs.setdefault("threshold", cfg.threshold)
        kwargs.setdefault("low_high_percentiles", cfg.low_high_percentiles)
    validate_inputs_matrix(x_local, allow_nan=True)
    validate_param_combination(kwargs)
    kwargs["bins"] = self._get_bins(x_local, **kwargs)
    return self.explainer.explain_fast(x_local, **kwargs)
# pylint: disable=too-many-return-statements
[docs]
def predict(
self,
x: Any,
uq_interval: bool = False,
calibrated: bool = True,
reject_policy: Any | None = None,
**kwargs: Any,
) -> Any:
"""Generate predictions for the test data.
See Also
--------
:meth:`.CalibratedExplainer.predict` : Refer to the docstring for predict in CalibratedExplainer for more details.
"""
self._assert_fitted("The WrapCalibratedExplainer must be fitted before predicting.")
if not self.calibrated:
if "threshold" in kwargs:
raise DataShapeError(
"A thresholded prediction is not possible for uncalibrated learners."
)
if calibrated:
_warnings.warn(
"The WrapCalibratedExplainer must be calibrated to get calibrated predictions.",
UserWarning,
stacklevel=2,
)
if uq_interval:
predict = self.learner.predict(x)
return predict, (predict, predict)
return self.learner.predict(x)
# Optional preprocessing for inference consistency
x_local = self._maybe_preprocess_for_inference(x)
kwargs = self._normalize_public_kwargs(kwargs)
validate_inputs_matrix(x_local, allow_nan=True)
validate_param_combination(kwargs)
kwargs["bins"] = self._get_bins(x_local, **kwargs)
assert (
self._assert_calibrated(
"The WrapCalibratedExplainer must be calibrated to get calibrated predictions."
).explainer
is not None
)
return self.explainer.predict(
x_local,
uq_interval=uq_interval,
calibrated=calibrated,
reject_policy=reject_policy,
**kwargs,
)
[docs]
def predict_proba(
    self,
    x: Any,
    uq_interval: bool = False,
    calibrated: bool = True,
    threshold: float | None = None,
    reject_policy: Any | None = None,
    **kwargs: Any,
) -> Any:
    """Predict probabilities for ``x``, calibrated when possible.

    Learners without ``predict_proba`` (checked via ``dir``) require both a
    ``threshold`` and a calibrated wrapper. An uncalibrated wrapper returns
    the learner's own probabilities passed through ``_format_proba_output``.

    Raises
    ------
    NotFittedError
        If the wrapper is not fitted, or is uncalibrated while the learner
        has no ``predict_proba``.
    ValidationError
        If the learner has no ``predict_proba`` and ``threshold`` is None.
    DataShapeError
        If ``threshold`` is given while the wrapper is uncalibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.predict_proba` : Refer to the docstring for predict_proba in CalibratedExplainer for more details.
    """
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted before predicting probabilities."
    )
    learner_has_proba = "predict_proba" in dir(self.learner)
    if not learner_has_proba:
        if threshold is None:
            raise ValidationError("The threshold parameter must be specified for regression.")
        self._assert_calibrated(
            "The WrapCalibratedExplainer must be calibrated to get calibrated probabilities for regression."
        )

    if not self.calibrated:
        if threshold is not None:
            raise DataShapeError(
                "A thresholded prediction is not possible for uncalibrated learners."
            )
        if calibrated:
            _warnings.warn(
                "The WrapCalibratedExplainer must be calibrated to get calibrated probabilities.",
                UserWarning,
                stacklevel=2,
            )
        # Only reachable when the learner exposes predict_proba (see above).
        return self._format_proba_output(self.learner.predict_proba(x), uq_interval)

    # Optional preprocessing for inference consistency
    x_local = self._maybe_preprocess_for_inference(x)
    options = self._normalize_public_kwargs(kwargs)
    validate_inputs_matrix(x_local, allow_nan=True)
    validate_param_combination(options)
    options["bins"] = self._get_bins(x_local, **options)
    self._assert_calibrated(
        "The WrapCalibratedExplainer must be calibrated to get calibrated probabilities."
    )
    assert self.explainer is not None
    return self.explainer.predict_proba(
        x_local,
        uq_interval=uq_interval,
        calibrated=calibrated,
        threshold=threshold,
        reject_policy=reject_policy,
        **options,
    )
[docs]
def calibrated_confusion_matrix(self) -> Any:
"""Generate a calibrated confusion matrix.
See Also
--------
:meth:`.CalibratedExplainer.calibrated_confusion_matrix` : Refer to the docstring for calibrated_confusion_matrix in CalibratedExplainer for more details.
"""
assert (
self._assert_fitted(
"The WrapCalibratedExplainer must be fitted and calibrated before providing a confusion matrix."
)
._assert_calibrated(
"The WrapCalibratedExplainer must be calibrated before providing a confusion matrix."
)
.explainer
is not None
)
return self.explainer.calibrated_confusion_matrix()
[docs]
def set_difficulty_estimator(self, difficulty_estimator: Any) -> None:
"""Assign or update the difficulty estimator.
See Also
--------
:meth:`.CalibratedExplainer.set_difficulty_estimator` : Refer to the docstring for set_difficulty_estimator in CalibratedExplainer for more details.
"""
assert (
self._assert_fitted(
"The WrapCalibratedExplainer must be fitted and calibrated before assigning a difficulty estimator."
)
._assert_calibrated(
"The WrapCalibratedExplainer must be calibrated before assigning a difficulty estimator."
)
.explainer
is not None
)
self.explainer.set_difficulty_estimator(difficulty_estimator)
[docs]
def initialize_reject_learner(  # pylint: disable=invalid-name
    self, threshold: float | None = None, ncf=None, w: float = 0.5
) -> Any:
    """Initialize the reject learner with a threshold value.

    .. deprecated:: 0.11.1
        Use ``reject_orchestrator.initialize_reject_learner`` on the
        calibrated explainer instead. This wrapper will be removed no
        earlier than v0.13.0.

    Parameters
    ----------
    threshold : float or None
        Decision threshold (regression only). Defaults to None.
    ncf : str or None, default None
        Non-conformity function type: ``'default'`` or ``'ensured'``.
        The internal default score is task-dependent (margin for
        multiclass, hinge for binary/regression). Legacy ``'entropy'``
        is accepted and mapped to ``'default'``.
    w : float, default 0.5
        Blending weight in [0, 1] used only when ``ncf='ensured'``.
        Ignored for ``ncf='default'``.

    Returns
    -------
    Any
        The result of ``reject_orchestrator.initialize_reject_learner``.

    Raises
    ------
    NotFittedError
        If the wrapper is not fitted or not calibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.initialize_reject_learner` : Refer to the docstring for initialize_reject_learner in CalibratedExplainer for more details.
    """
    # Run the state checks as plain statements: inside an ``assert`` they
    # would be stripped under ``python -O`` and never raise NotFittedError.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted before initializing the reject learner."
    )._assert_calibrated(
        "The WrapCalibratedExplainer must be calibrated before initializing the reject learner."
    )
    assert self.explainer is not None  # type narrowing only
    deprecate(
        "WrapCalibratedExplainer.initialize_reject_learner is deprecated since v0.11.1; "
        "use explainer.reject_orchestrator.initialize_reject_learner instead. "
        "This wrapper will be removed no earlier than v0.13.0.",
        key=(
            "calibrated_explanations.core.wrap_explainer."
            "WrapCalibratedExplainer.initialize_reject_learner_deprecation"
        ),
        stacklevel=2,
    )
    # The orchestrators must exist before the reject orchestrator is used.
    self.explainer.plugin_manager.initialize_orchestrators()
    return self.explainer.reject_orchestrator.initialize_reject_learner(
        threshold=threshold, ncf=ncf, w=w
    )
[docs]
def predict_reject(self, x: Any, bins: Any = None, confidence: float = 0.95) -> Any:
    """Predict whether to reject the explanations for the test data.

    .. deprecated:: 0.11.1
        Use ``reject_orchestrator.predict_reject`` on the calibrated
        explainer instead. This wrapper will be removed no earlier than
        v0.13.0.

    Parameters
    ----------
    x : array-like
        Instances to evaluate; passed to the orchestrator as given.
    bins : Any, optional
        Explicit bins; a configured Mondrian categorizer takes precedence
        (see ``_get_bins``).
    confidence : float, default 0.95
        Forwarded to ``reject_orchestrator.predict_reject``.

    Raises
    ------
    NotFittedError
        If the wrapper is not fitted or not calibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.predict_reject` : Refer to the docstring for predict_reject in CalibratedExplainer for more details.
    """
    # Validate state first, as plain statements: inside an ``assert`` the
    # checks would be stripped under ``python -O``; doing them before the
    # bin lookup also avoids invoking the categorizer on an unusable wrapper.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before predicting rejection."
    )._assert_calibrated(
        "The WrapCalibratedExplainer must be calibrated before predicting rejection."
    )
    assert self.explainer is not None  # type narrowing only
    bins = self._get_bins(x, bins=bins)
    deprecate(
        "WrapCalibratedExplainer.predict_reject is deprecated since v0.11.1; "
        "use explainer.reject_orchestrator.predict_reject instead. "
        "This wrapper will be removed no earlier than v0.13.0.",
        key=(
            "calibrated_explanations.core.wrap_explainer."
            "WrapCalibratedExplainer.predict_reject_deprecation"
        ),
        stacklevel=2,
    )
    # The orchestrators must exist before the reject orchestrator is used.
    self.explainer.plugin_manager.initialize_orchestrators()
    return self.explainer.reject_orchestrator.predict_reject(
        x, bins=bins, confidence=confidence
    )
# pylint: disable=duplicate-code, too-many-branches, too-many-statements, too-many-locals
[docs]
def plot(self, x: Any, y: Any = None, threshold: float | None = None, **kwargs: Any) -> Any:
    """Generate plots for the test data.

    Parameters
    ----------
    x : array-like
        Test instances to plot explanations for.
    y : array-like, optional
        True labels for the test instances.
    threshold : float, optional
        Threshold for probabilistic regression. Falls back to the attached
        config's ``threshold`` when omitted.
    **kwargs : dict
        Additional keyword arguments passed to the plot method. ``bins`` is
        always overwritten with the result of ``_get_bins``.

    Returns
    -------
    None

    Raises
    ------
    NotFittedError
        If the wrapper is not fitted or not calibrated.

    See Also
    --------
    :meth:`.CalibratedExplainer.plot` : Refer to the docstring for plot in CalibratedExplainer for more details.
    """
    # Run the state checks as plain statements: inside an ``assert`` they
    # would be stripped under ``python -O`` and never raise NotFittedError.
    self._assert_fitted(
        "The WrapCalibratedExplainer must be fitted and calibrated before plotting."
    )._assert_calibrated("The WrapCalibratedExplainer must be calibrated before plotting.")
    assert self.explainer is not None  # type narrowing only
    # Apply config defaults when available and not explicitly provided
    cfg = getattr(self, "_cfg", None)
    if cfg is not None:
        if threshold is None:
            threshold = cfg.threshold
        kwargs.setdefault("low_high_percentiles", cfg.low_high_percentiles)
    kwargs["bins"] = self._get_bins(x, **kwargs)
    self.explainer.plot(x, y=y, threshold=threshold, **kwargs)
def _get_bins(self, x: Any, **kwargs: Any) -> Any:
    """Resolve the bins to use for ``x``.

    Precedence: the stored categorizer (``apply`` for a
    ``MondrianCategorizer``, a plain call otherwise), then an explicit
    non-None ``bins`` keyword, then the explainer's own ``bins`` when the
    explainer reports Mondrian mode. Returns ``None`` when nothing applies.
    """
    categorizer = self.mc
    if categorizer is not None:
        if isinstance(categorizer, MondrianCategorizer):
            return categorizer.apply(x)
        return categorizer(x)

    explicit = kwargs.get("bins")
    if explicit is not None:
        return explicit

    # Fallback to explainer bins for Mondrian mode
    explainer = getattr(self, "explainer", None)
    if explainer and hasattr(explainer, "bins") and explainer.is_mondrian():
        return explainer.bins
    return None
@property
def runtime_telemetry(self) -> Mapping[str, Any]:
"""Return the most recent telemetry payload reported by the explainer."""
assert (
self._assert_fitted(
"The WrapCalibratedExplainer must be fitted before accessing runtime telemetry."
)
._assert_calibrated(
"The WrapCalibratedExplainer must be calibrated before accessing runtime telemetry."
)
.explainer
is not None
)
return self.explainer.runtime_telemetry
@property
def preprocessor_metadata(self) -> Dict[str, Any] | None:
"""Return the telemetry-safe preprocessing snapshot if available."""
assert (
self._assert_fitted(
"The WrapCalibratedExplainer must be fitted before accessing preprocessor metadata."
)
._assert_calibrated(
"The WrapCalibratedExplainer must be calibrated before accessing preprocessor metadata."
)
.explainer
is not None
)
return self.explainer.preprocessor_metadata
# ------ Internal helpers (reduce duplication) ------
def _assert_fitted(self, message: str | None = None) -> WrapCalibratedExplainer:
if not self.fitted:
raise NotFittedError(
message or "The WrapCalibratedExplainer must be fitted before this operation."
)
return self
def _assert_calibrated(self, message: str | None = None) -> WrapCalibratedExplainer:
if not self.calibrated:
raise NotFittedError(
message or "The WrapCalibratedExplainer must be calibrated before this operation."
)
return self
def _normalize_public_kwargs(
self, kwargs: dict[str, Any], allowed: "set[str] | None" = None
) -> dict[str, Any]:
"""Normalize public kwargs and reject removed aliases."""
if not kwargs:
return {}
original = dict(kwargs)
reject_removed_aliases(original)
base = dict(original)
if allowed is None:
return base
return {k: v for k, v in base.items() if k in allowed}
def _normalize_auto_encode_flag(self) -> str:
"""Return the auto_encode configuration as a telemetry-friendly literal."""
flag = getattr(self, "_auto_encode", "auto")
if isinstance(flag, bool):
return "true" if flag else "false"
flag_str = str(flag).lower()
if flag_str in {"true", "false", "auto"}:
return flag_str
return "auto"
def _serialise_preprocessor_value(self, value: Any) -> Any:
"""Convert preprocessing metadata values into JSON-friendly structures."""
if value is None:
return None
if isinstance(value, dict):
return {str(key): self._serialise_preprocessor_value(val) for key, val in value.items()}
if isinstance(value, (list, tuple, set)):
return [self._serialise_preprocessor_value(item) for item in value]
if hasattr(value, "tolist"):
try:
return value.tolist() # numpy/pandas friendly
except: # noqa: E722
if not isinstance(sys.exc_info()[1], Exception):
raise
return str(value)
if isinstance(value, (str, int, float, bool)):
return value
return str(value)
def _extract_preprocessor_snapshot(self, preprocessor: Any) -> dict[str, Any] | None:
"""Build a lightweight snapshot describing the configured preprocessor."""
snapshot: dict[str, Any] = {}
getter = getattr(preprocessor, "get_mapping_snapshot", None)
if callable(getter):
try:
custom_snapshot = getter()
except: # noqa: E722
if not isinstance(sys.exc_info()[1], Exception):
raise
custom_snapshot = None
if custom_snapshot is not None:
snapshot["custom"] = self._serialise_preprocessor_value(custom_snapshot)
categories = getattr(preprocessor, "categories_", None)
if categories is not None:
snapshot["categories"] = self._serialise_preprocessor_value(categories)
transformers = getattr(preprocessor, "transformers_", None)
if transformers is not None:
serialised = []
for name, transformer, columns in transformers:
serialised.append(
{
"name": name,
"columns": self._serialise_preprocessor_value(columns),
"transformer": (
f"{transformer.__class__.__module__}:{transformer.__class__.__qualname__}"
if transformer is not None
else None
),
}
)
snapshot["transformers"] = serialised
feature_names_out = getattr(preprocessor, "get_feature_names_out", None)
if callable(feature_names_out):
with suppress(Exception):
snapshot["feature_names_out"] = list(feature_names_out())
mapping_attr = getattr(preprocessor, "mapping_", None)
if mapping_attr is not None:
snapshot["mapping"] = self._serialise_preprocessor_value(mapping_attr)
return snapshot or None
def _build_preprocessor_metadata(self) -> dict[str, Any] | None:
"""Return ADR-009 telemetry metadata for the active preprocessor."""
auto_encode_flag = self._normalize_auto_encode_flag()
preprocessor = getattr(self, "_preprocessor", None)
metadata: dict[str, Any] = {"auto_encode": auto_encode_flag}
if preprocessor is not None:
metadata["transformer_id"] = (
f"{preprocessor.__class__.__module__}:{preprocessor.__class__.__qualname__}"
)
snapshot = self._extract_preprocessor_snapshot(preprocessor)
if snapshot is not None:
metadata["mapping_snapshot"] = snapshot
if (
metadata.get("transformer_id") is None
and len(metadata) == 1
and auto_encode_flag == "auto"
):
return None
return metadata
def _raise_non_numeric_without_preprocessor(self, x: Any, stage: str) -> None:
"""Raise actionable diagnostics for non-numeric inputs when preprocessing is disabled."""
auto_encode_flag = self._normalize_auto_encode_flag()
if auto_encode_flag in {"auto", "true"}:
return
x_arr = x.to_numpy() if hasattr(x, "to_numpy") else x
dtype = getattr(x_arr, "dtype", None)
if dtype is not None and getattr(dtype, "kind", None) not in {"b", "i", "u", "f", "c"}:
raise ValidationError(
f"Non-numeric input detected during {stage} while preprocessing is disabled. "
"Set auto_encode='auto' or provide a preprocessor capable of handling categorical values."
)
def _pre_fit_preprocess(self, x: Any) -> Any:
"""Fit the configured preprocessor and return transformed x.
if a user-supplied preprocessor exposes
fit/transform, we use it. No built-in auto encoding is activated here.
"""
try:
# When no preprocessor is provided and auto_encode is enabled,
# activate the small deterministic builtin encoder.
if self._preprocessor is None:
# ADR-009 default mode: auto_encode='auto' activates deterministic
# built-in encoding when no user preprocessor is provided.
if self._normalize_auto_encode_flag() in {"auto", "true"}:
try:
from calibrated_explanations.preprocessing.builtin_encoder import (
BuiltinEncoder,
)
encoder = BuiltinEncoder(unseen_policy=self._unseen_category_policy)
x_out = encoder.fit_transform(x)
# attach encoder so export/import helpers can find it
self._preprocessor = encoder
self._pre_fitted = True
return x_out
except (
ImportError,
TypeError,
ValueError,
AttributeError,
) as exc: # pragma: no cover - defensive
self._logger.warning("Builtin encoder failed; bypassing: %s", exc)
return x
self._raise_non_numeric_without_preprocessor(x, stage="fit")
return x
if hasattr(self._preprocessor, "fit_transform"):
x_out = self._preprocessor.fit_transform(x)
else:
self._preprocessor.fit(x)
x_out = self._preprocessor.transform(x)
self._pre_fitted = True
return x_out
except: # noqa: E722
if not isinstance(sys.exc_info()[1], Exception):
raise
exc = sys.exc_info()[1]
if isinstance(exc, ValidationError):
raise
self._logger.warning("Preprocessor failed; proceeding without it: %s", exc)
return x
def _pre_transform(self, x: Any, stage: str = "predict") -> Any:
"""Transform x with the fitted preprocessor if available."""
try:
if self._preprocessor is None or not self._pre_fitted:
self._raise_non_numeric_without_preprocessor(x, stage=stage)
return x
return self._preprocessor.transform(x)
except: # noqa: E722
if not isinstance(sys.exc_info()[1], Exception):
raise
exc = sys.exc_info()[1]
pre = getattr(self, "_preprocessor", None)
unseen_policy = str(getattr(pre, "unseen_policy", "")).lower()
if isinstance(exc, (KeyError, ValidationError)) and unseen_policy == "error":
raise ValidationError(
f"Unseen category encountered during {stage} preprocessing. "
"Set unseen_category_policy='ignore' or import/export a stable mapping."
) from exc
self._logger.warning("Preprocessor transform failed at %s; bypassing: %s", stage, exc)
return x
def _maybe_preprocess_for_inference(self, x: Any) -> Any:
"""Apply preprocessing for inference paths if configured/fitted."""
return self._pre_transform(x, stage="inference")
def _finalize_fit(self, reinitialize: bool) -> WrapCalibratedExplainer:
    """Complete a fit: verify the learner, flag the wrapper, refresh the explainer.

    Parameters
    ----------
    reinitialize : bool
        Whether an existing calibrated explainer should be reinitialized
        with the freshly fitted learner.

    Returns
    -------
    WrapCalibratedExplainer
        The wrapper itself.
    """
    check_is_fitted(self.learner)
    self.fitted = True

    if not reinitialize or self.explainer is None:
        return self

    # Preserve calibration by pointing the explainer at the new learner.
    self.explainer.reinitialize(self.learner)
    self.calibrated = True
    return self
def _format_proba_output(self, proba: Any, uq_interval: bool) -> Any:
"""Format probability output (with optional trivial intervals) without duplicating logic."""
if not uq_interval:
return proba
# Multiclass: return matrix and identical bounds
if proba.ndim == 2 and proba.shape[1] > 2:
return proba, (proba, proba)
# Binary (assume second column is positive class probability)
if proba.ndim == 2 and proba.shape[1] == 2:
return proba, (proba[:, 1], proba[:, 1])
# Fallback (unexpected shape) -> mirror array
return proba, (proba, proba)
# Public aliases for testing
[docs]
def serialise_preprocessor_value(self, value: Any) -> Any:
    """Public alias of :meth:`_serialise_preprocessor_value`.

    Parameters
    ----------
    value : Any
        The value to serialise.

    Returns
    -------
    Any
        The result of the private serialiser for ``value``.
    """
    serialised = self._serialise_preprocessor_value(value)
    return serialised
[docs]
def pre_fit_preprocess(self, x: Any) -> Any:
    """Public alias of :meth:`_pre_fit_preprocess`.

    Parameters
    ----------
    x : Any
        The input data.

    Returns
    -------
    Any
        The result of the private fit-time preprocessing step for ``x``.
    """
    transformed = self._pre_fit_preprocess(x)
    return transformed
[docs]
def maybe_preprocess_for_inference(self, X: Any) -> Any:
    """Public alias of :meth:`_maybe_preprocess_for_inference`.

    Parameters
    ----------
    X : Any
        The input data.

    Returns
    -------
    Any
        The result of the private inference-time preprocessing step for ``X``.
    """
    prepared = self._maybe_preprocess_for_inference(X)
    return prepared
[docs]
def export_preprocessor_mapping(self) -> dict[str, Any] | None:
"""Export the current preprocessor mapping snapshot.
Returns
-------
dict[str, Any] | None
A mapping snapshot suitable for telemetry or round-tripping, or
``None`` when no mapping information is available.
"""
pre = getattr(self, "_preprocessor", None)
if pre is None:
return None
# Prefer a custom getter when available
getter = getattr(pre, "get_mapping_snapshot", None)
if callable(getter):
try:
snapshot = getter()
if snapshot is not None:
if not isinstance(snapshot, Mapping):
raise ValidationError(
"Preprocessor mapping snapshot must be a mapping.",
details={"source": "get_mapping_snapshot"},
)
self._validate_json_safe_mapping(snapshot, source="get_mapping_snapshot")
return dict(snapshot)
return None
except ValidationError:
raise
except (AttributeError, TypeError, ValueError):
self._logger.warning(
"Preprocessor.get_mapping_snapshot failed; falling back to mapping_"
)
# Fall back to attribute if present
mapping_attr = getattr(pre, "mapping_", None)
if mapping_attr is not None:
# Shallow copy to avoid exposing internal objects
try:
snapshot = dict(mapping_attr)
self._validate_json_safe_mapping(snapshot, source="mapping_")
return snapshot
except ValidationError:
raise
except (AttributeError, TypeError, ValueError):
return None
return None
[docs]
def import_preprocessor_mapping(self, mapping: Mapping[str, Any]) -> None:
    """Apply a mapping snapshot to the attached preprocessor, best effort.

    The mapping is validated for JSON safety first. It is then handed to the
    preprocessor's ``set_mapping`` hook when one exists, otherwise assigned
    to its ``mapping_`` attribute. When neither succeeds (or no preprocessor
    is attached) a copy is stashed on the wrapper as
    ``_imported_preprocessor_mapping`` and a ``UserWarning`` is emitted so
    the fallback stays visible.

    Parameters
    ----------
    mapping : Mapping[str, Any]
        The snapshot to apply.

    Raises
    ------
    ValidationError
        If ``mapping`` is not JSON-serialisable.
    """
    self._validate_json_safe_mapping(mapping, source="import")

    pre = getattr(self, "_preprocessor", None)
    applied = False
    if pre is not None:
        apply_hook = getattr(pre, "set_mapping", None)
        if callable(apply_hook):
            try:
                apply_hook(mapping)
            except Exception:
                self._logger.warning("Preprocessor.set_mapping failed; stashing mapping")
            else:
                applied = True
        else:
            # No setter: try the attribute, tolerating read-only objects.
            with suppress(Exception):
                pre.mapping_ = mapping
                applied = True

    if applied:
        return

    # Keep for later application or external tooling.
    self._imported_preprocessor_mapping = None if mapping is None else dict(mapping)
    _warnings.warn(
        "Preprocessor mapping could not be applied directly; mapping stashed on wrapper",
        UserWarning,
        stacklevel=2,
    )
@staticmethod
def _validate_json_safe_mapping(mapping: Mapping[str, Any], *, source: str) -> None:
"""Validate that mapping snapshots are JSON-serialisable primitives.
Parameters
----------
mapping : Mapping[str, Any]
Mapping snapshot to validate.
source : str
Context string used in validation error details.
Raises
------
ValidationError
If the mapping cannot be serialised with standard JSON encoding.
"""
try:
json.dumps(mapping, sort_keys=True, separators=(",", ":"))
except (TypeError, ValueError) as exc:
raise ValidationError(
"Preprocessor mapping must be JSON-serialisable.",
details={"source": source, "error": str(exc)},
) from exc
def _state_path(self, path_or_fileobj: Any) -> Path:
"""Normalize and validate state path inputs."""
if hasattr(path_or_fileobj, "read") or hasattr(path_or_fileobj, "write"):
raise ValidationError(
"Only filesystem paths are supported for state persistence.",
details={"path_or_fileobj_type": type(path_or_fileobj).__name__},
)
try:
return Path(path_or_fileobj)
except TypeError as exc:
raise ValidationError(
"Invalid state path provided to save/load_state.",
details={"path_or_fileobj_type": type(path_or_fileobj).__name__},
) from exc
@staticmethod
def _sha256_bytes(payload: bytes) -> str:
"""Return SHA-256 checksum for raw bytes."""
return hashlib.sha256(payload).hexdigest()
@staticmethod
def _sha256_file(path: Path) -> str:
"""Return SHA-256 checksum for a file."""
digest = hashlib.sha256()
with path.open("rb") as handle:
for chunk in iter(lambda: handle.read(65536), b""):
digest.update(chunk)
return digest.hexdigest()
def _calibrator_to_primitive(self, calibrator: Any) -> dict[str, Any]:
"""Serialize a single calibrator into the ADR-031 primitive contract."""
to_primitive = getattr(calibrator, "to_primitive", None)
if callable(to_primitive):
primitive = to_primitive()
if isinstance(primitive, Mapping):
return dict(primitive)
payload_bytes = pickle.dumps(calibrator, protocol=pickle.HIGHEST_PROTOCOL)
return {
"schema_version": self._STATE_SCHEMA_VERSION,
"calibrator_type": "python_pickle",
"parameters": {
"class_name": calibrator.__class__.__name__,
"module": calibrator.__class__.__module__,
},
"checksums": {
"sha256": self._sha256_bytes(payload_bytes),
},
"payload": {
"pickle_b64": base64.b64encode(payload_bytes).decode("ascii"),
},
}
def _build_calibrator_primitive(self) -> dict[str, Any] | None:
"""Build calibrator primitive payload from the active explainer, if any."""
explainer = getattr(self, "explainer", None)
if explainer is None:
return None
calibrator = getattr(explainer, "interval_learner", None)
if calibrator is None:
return None
if isinstance(calibrator, (list, tuple)):
children = [self._calibrator_to_primitive(item) for item in calibrator]
payload_bytes = json.dumps(children, sort_keys=True).encode("utf-8")
return {
"schema_version": self._STATE_SCHEMA_VERSION,
"calibrator_type": "fast_collection",
"parameters": {"size": len(children)},
"checksums": {"sha256": self._sha256_bytes(payload_bytes)},
"calibrators": children,
}
return self._calibrator_to_primitive(calibrator)
@classmethod
def _restore_calibrator_from_primitive(cls, primitive: Mapping[str, Any]) -> Any:
"""Rehydrate a calibrator object from a persisted primitive payload."""
schema_version = primitive.get("schema_version")
if schema_version != cls._STATE_SCHEMA_VERSION:
raise IncompatibleStateError(
"Unsupported calibrator primitive schema_version.",
details={
"schema_version": schema_version,
"supported_versions": [cls._STATE_SCHEMA_VERSION],
},
)
calibrator_type = primitive.get("calibrator_type")
if calibrator_type == "venn_abers":
from ..calibration.venn_abers import VennAbers
return VennAbers.from_primitive(primitive)
if calibrator_type == "interval_regressor":
from ..calibration.interval_regressor import IntervalRegressor
return IntervalRegressor.from_primitive(primitive)
if calibrator_type == "fast_collection":
children = primitive.get("calibrators")
if not isinstance(children, list):
raise IncompatibleStateError(
"Invalid fast_collection primitive: expected calibrators list.",
details={"field": "calibrators"},
)
expected_sha = primitive.get("checksums", {}).get("sha256")
child_bytes = json.dumps(children, sort_keys=True).encode("utf-8")
actual_sha = cls._sha256_bytes(child_bytes)
if not isinstance(expected_sha, str) or expected_sha != actual_sha:
raise IncompatibleStateError(
"Calibrator primitive checksum validation failed.",
details={"expected_sha256": expected_sha, "actual_sha256": actual_sha},
)
return [cls._restore_calibrator_from_primitive(item) for item in children]
if calibrator_type == "python_pickle":
payload = primitive.get("payload")
if not isinstance(payload, Mapping) or not isinstance(payload.get("pickle_b64"), str):
raise IncompatibleStateError(
"Invalid python_pickle primitive payload.",
details={"field": "payload.pickle_b64"},
)
raw = base64.b64decode(payload["pickle_b64"].encode("ascii"))
expected_sha = primitive.get("checksums", {}).get("sha256")
actual_sha = cls._sha256_bytes(raw)
if not isinstance(expected_sha, str) or expected_sha != actual_sha:
raise IncompatibleStateError(
"Calibrator primitive checksum validation failed.",
details={"expected_sha256": expected_sha, "actual_sha256": actual_sha},
)
return pickle.loads(raw) # noqa: S301 # nosec B301 - trusted, checksum-validated payload
raise IncompatibleStateError(
"Unsupported calibrator_type in persisted state.",
details={"calibrator_type": calibrator_type},
)
def _build_explainer_config_payload(self) -> dict[str, Any]:
"""Build JSON-safe explainer configuration metadata for persistence."""
payload: dict[str, Any] = {}
explainer = getattr(self, "explainer", None)
if explainer is not None:
payload["mode"] = getattr(explainer, "mode", None)
payload["seed"] = getattr(explainer, "seed", None)
payload["condition_source"] = getattr(explainer, "condition_source", None)
payload["interval_summary"] = str(getattr(explainer, "interval_summary", ""))
payload["preprocessor_metadata"] = self._serialise_preprocessor_value(
getattr(explainer, "_preprocessor_metadata", None)
)
plugin_manager = getattr(explainer, "_plugin_manager", None)
if plugin_manager is not None:
payload["plugin_overrides"] = self._serialise_preprocessor_value(
getattr(plugin_manager, "plugin_overrides", None)
)
return payload
[docs]
def save_state(self, path_or_fileobj: Any) -> Path:
    """Persist wrapper state as a directory with an ADR-031 manifest + checksums.

    The artifact is staged in a temporary sibling directory and then moved
    onto the target. An existing target is parked as a backup during the
    swap and put back if the move fails. The staged files are:

    * ``wrapper.pkl`` - pickle of the wrapper itself
    * ``calibrator_primitive.json`` - only when a calibrator is available
    * ``preprocessing_mapping.json`` - only when a mapping can be exported
    * ``explainer_config.json`` - explainer configuration metadata
    * ``manifest.json`` - schema version, timestamp and per-file SHA-256

    Parameters
    ----------
    path_or_fileobj : Any
        Filesystem path of the state directory. File-like objects are
        rejected.

    Returns
    -------
    Path
        The directory the state was written to.

    Raises
    ------
    ValidationError
        If the path is invalid, or staging/moving the artifact fails with
        ``OSError``, ``TypeError``, ``ValueError`` or ``AttributeError``.
        Other exceptions propagate unchanged; the staging directory is
        removed in every case.
    """
    target = self._state_path(path_or_fileobj)
    target_parent = target.parent
    target_parent.mkdir(parents=True, exist_ok=True)
    temp_dir_name = f"{target.name}.tmp-{os.getpid()}-{id(self)}"
    temp_dir = Path(tempfile.mkdtemp(prefix=temp_dir_name, dir=str(target_parent)))
    checksums: dict[str, str] = {}

    def _stage(file_name: str, payload: bytes) -> None:
        """Write one artifact into the staging directory and record its checksum."""
        (temp_dir / file_name).write_bytes(payload)
        checksums[file_name] = self._sha256_bytes(payload)

    try:
        _stage("wrapper.pkl", pickle.dumps(self, protocol=pickle.HIGHEST_PROTOCOL))

        calibrator_primitive = self._build_calibrator_primitive()
        if calibrator_primitive is not None:
            _stage(
                "calibrator_primitive.json",
                json.dumps(calibrator_primitive, indent=2, sort_keys=True).encode("utf-8"),
            )

        mapping = self.export_preprocessor_mapping()
        if mapping is not None:
            _stage(
                "preprocessing_mapping.json",
                json.dumps(mapping, indent=2, sort_keys=True).encode("utf-8"),
            )

        config_payload = self._build_explainer_config_payload()
        _stage(
            "explainer_config.json",
            json.dumps(config_payload, indent=2, sort_keys=True).encode("utf-8"),
        )

        manifest = {
            "schema_version": self._STATE_SCHEMA_VERSION,
            "created_at_utc": datetime.now(timezone.utc).isoformat(),
            "artifact_type": "wrap_calibrated_explainer_state",
            "files": checksums,
        }
        (temp_dir / "manifest.json").write_text(
            json.dumps(manifest, indent=2, sort_keys=True), encoding="utf-8"
        )

        backup: Path | None = None
        if target.exists():
            backup = target.with_name(f"{target.name}.bak-{os.getpid()}-{id(self)}")
            os.replace(target, backup)
        try:
            replaced = False
            last_permission_error: PermissionError | None = None
            # A rename can fail transiently with PermissionError, so retry a
            # few times before falling back to shutil.move.
            for _ in range(3):
                try:
                    os.replace(temp_dir, target)
                    replaced = True
                    break
                except PermissionError as exc:
                    last_permission_error = exc
                    sleep(0.05)
            if not replaced:
                if last_permission_error is not None:
                    self._logger.debug(
                        "os.replace failed during save_state; falling back to shutil.move: %s",
                        last_permission_error,
                    )
                shutil.move(str(temp_dir), str(target))
        except OSError:
            # Roll the previous artifact back into place when nothing landed.
            if backup is not None and backup.exists() and not target.exists():
                os.replace(backup, target)
            raise

        if backup is not None and backup.exists():
            # The previous target may have been a regular file or a symlink,
            # which shutil.rmtree refuses to remove.
            if backup.is_dir() and not backup.is_symlink():
                shutil.rmtree(backup)
            else:
                backup.unlink()
        return target
    except (OSError, TypeError, ValueError, AttributeError) as exc:
        raise ValidationError(
            f"Failed to save state to '{target}'.",
            details={"path": str(target), "reason": str(exc)},
        ) from exc
    finally:
        # After a successful move the staging directory no longer exists.
        # On any failure, including exceptions not converted above (such as
        # pickle.PicklingError), it must not be left behind.
        if temp_dir.exists():
            shutil.rmtree(temp_dir, ignore_errors=True)
[docs]
@classmethod
def load_state(cls, path_or_fileobj: Any) -> WrapCalibratedExplainer:
    """Load wrapper state from an ADR-031 persisted artifact.

    Every file listed in ``manifest.json`` is checked against its recorded
    SHA-256 digest before anything is deserialised. Only files that passed
    this check are consumed: ``wrapper.pkl`` must be listed, and the
    optional ``calibrator_primitive.json`` / ``preprocessing_mapping.json``
    are ignored unless the manifest covers them.

    Parameters
    ----------
    path_or_fileobj : Any
        Filesystem path of the state directory. File-like objects are
        rejected.

    Returns
    -------
    WrapCalibratedExplainer
        The restored wrapper.

    Raises
    ------
    ValidationError
        If the path is invalid.
    IncompatibleStateError
        If the manifest is missing or malformed, the schema version is
        unsupported, a listed file is missing or fails its checksum,
        ``wrapper.pkl`` is not covered by the manifest, or the payload
        restores an object of an unexpected type.
    """
    temp_instance = cls.__new__(cls)
    path = temp_instance._state_path(path_or_fileobj)

    manifest_path = path / "manifest.json"
    if not manifest_path.exists():
        raise IncompatibleStateError(
            "State artifact is missing manifest.json.",
            details={"path": str(path)},
        )
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    if not isinstance(manifest, Mapping):
        raise IncompatibleStateError(
            "Invalid state manifest: expected a JSON object.",
            details={"manifest_type": type(manifest).__name__},
        )

    schema_version = manifest.get("schema_version")
    if schema_version != cls._STATE_SCHEMA_VERSION:
        raise IncompatibleStateError(
            "Unsupported state schema_version.",
            details={
                "schema_version": schema_version,
                "supported_versions": [cls._STATE_SCHEMA_VERSION],
            },
        )

    files = manifest.get("files")
    if not isinstance(files, Mapping):
        raise IncompatibleStateError(
            "Invalid state manifest: files checksum mapping missing.",
            details={"field": "files"},
        )

    verified: set[str] = set()
    for file_name, expected_sha in files.items():
        if not isinstance(file_name, str) or not isinstance(expected_sha, str):
            raise IncompatibleStateError(
                "Invalid state manifest: malformed checksum entry.",
                details={"file": file_name, "checksum": expected_sha},
            )
        file_path = path / file_name
        if not file_path.exists():
            raise IncompatibleStateError(
                "State artifact is incomplete: expected file is missing.",
                details={"file": file_name},
            )
        actual_sha = cls._sha256_file(file_path)
        if actual_sha != expected_sha:
            raise IncompatibleStateError(
                "State checksum validation failed.",
                details={
                    "file": file_name,
                    "expected_sha256": expected_sha,
                    "actual_sha256": actual_sha,
                },
            )
        verified.add(file_name)

    if "wrapper.pkl" not in verified:
        # Without a manifest entry the payload was never checksum-verified.
        raise IncompatibleStateError(
            "State manifest does not cover wrapper.pkl; refusing to load an unverified payload.",
            details={"file": "wrapper.pkl"},
        )

    # NOTE(review): the manifest lives next to the files it describes, so the
    # checksums detect corruption only, not deliberate tampering. Unpickling
    # executes arbitrary code; load state from trusted locations only.
    wrapper_bytes = (path / "wrapper.pkl").read_bytes()
    wrapper = pickle.loads(wrapper_bytes)  # noqa: S301 # nosec B301 - trusted, checksum-validated payload
    if not isinstance(wrapper, cls):
        raise IncompatibleStateError(
            "Persisted wrapper payload restored unexpected object type.",
            details={"restored_type": type(wrapper).__name__},
        )

    if "calibrator_primitive.json" in verified:
        primitive = json.loads(
            (path / "calibrator_primitive.json").read_text(encoding="utf-8")
        )
        restored = cls._restore_calibrator_from_primitive(primitive)
        if getattr(wrapper, "explainer", None) is not None:
            wrapper.explainer.interval_learner = restored

    if "preprocessing_mapping.json" in verified:
        mapping_payload = json.loads(
            (path / "preprocessing_mapping.json").read_text(encoding="utf-8")
        )
        if isinstance(mapping_payload, Mapping):
            wrapper.import_preprocessor_mapping(mapping_payload)

    return wrapper
@property
def pre_fitted(self) -> bool:
    """Expose the private ``_pre_fitted`` flag read-only.

    Returns
    -------
    bool
        The current value of ``self._pre_fitted``.
    """
    fitted_flag = self._pre_fitted
    return fitted_flag
[docs]
def finalize_fit(self, reinitialize: bool) -> WrapCalibratedExplainer:
    """Public alias of :meth:`_finalize_fit`.

    Parameters
    ----------
    reinitialize : bool
        Forwarded unchanged to the private implementation.

    Returns
    -------
    WrapCalibratedExplainer
        Whatever the private implementation returns.
    """
    finalized = self._finalize_fit(reinitialize)
    return finalized
[docs]
def normalize_auto_encode_flag(self, auto_encode: Any = None) -> str:
    """Return the wrapper's normalised auto-encode flag.

    Parameters
    ----------
    auto_encode : Any, optional
        Accepted for backwards compatibility only and **ignored**: the
        private normaliser is called without arguments.

    Returns
    -------
    str
        The value produced by the private normaliser. Other methods of this
        class compare it against the strings ``"auto"`` and ``"true"``.
    """
    # Public adapter: legacy callers may pass a value or nothing at all;
    # either way the private helper is the single source of truth.
    return self._normalize_auto_encode_flag()
[docs]
def normalize_public_kwargs(self, payload: Any = None, **kwargs: Any) -> Dict[str, Any]:
    """Public alias of :meth:`_normalize_public_kwargs`.

    Parameters
    ----------
    payload : Any, optional
        Forwarded as the single positional argument when it is not ``None``;
        omitted from the call otherwise.
    **kwargs : Any
        Forwarded unchanged as keyword arguments.

    Returns
    -------
    Dict[str, Any]
        Whatever the private implementation returns.
    """
    # Supports both positional (payload, allowed=...) and keyword-only usage.
    positional = () if payload is None else (payload,)
    return self._normalize_public_kwargs(*positional, **kwargs)
@property
def cfg(self) -> Any:
    """Expose the private ``_cfg`` attribute read-only.

    Returns
    -------
    Any
        The current value of ``self._cfg``.
    """
    configuration = self._cfg
    return configuration
def __getstate__(self):
"""Get state for pickling.
Returns
-------
dict
The state dictionary.
"""
state = self.__dict__.copy()
# Exclude mc as it may contain unpicklable objects like RNG in mappingproxy
state["mc"] = None
# Convert any types.MappingProxyType (mappingproxy) instances to plain
# dicts recursively so pickle/joblib can serialize them.
def _convert(obj: Any) -> Any:
if isinstance(obj, MappingProxyType):
# Recursively convert mappingproxy to plain dict and convert
# nested values as well.
return _convert(dict(obj))
if isinstance(obj, dict):
return {k: _convert(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple, set)):
cls = type(obj)
converted = [_convert(v) for v in obj]
return cls(converted)
return obj
for k, v in list(state.items()):
try:
state[k] = _convert(v)
except (TypeError, AttributeError, RecursionError) as exc:
# Defensive: if conversion fails due to type/attribute/recursion
# issues, leave original value and hope it's picklable; avoid
# failing during state build. Suppress the same specific
# exceptions when logging to satisfy ADR-002.
with suppress((TypeError, AttributeError, RecursionError)):
self._logger.debug("__getstate__ conversion skipped for %s: %s", k, exc)
continue
return state
def __setstate__(self, state):
"""Set state for unpickling.
Parameters
----------
state : dict
The state dictionary.
"""
self.__dict__.update(state)
# Public API of this module: the wrapper class is the only exported name.
__all__ = ["WrapCalibratedExplainer"]