"""Explain black-box learners using calibrated prediction intervals.
This module implements the core :class:`CalibratedExplainer` which fits
interval calibrators on calibration data and exposes methods for generating
factual and alternative explanations augmented with uncertainty information.
The implementation follows the approach described in
"Calibrated Explanations: with Uncertainty Information and Counterfactuals"
by Helena Löfström et al.
"""
# pylint: disable=unknown-option-value
# pylint: disable=invalid-name, line-too-long, too-many-lines, too-many-positional-arguments, too-many-public-methods
from __future__ import annotations
import copy
import logging
import sys
import warnings
import contextlib
from time import time
from typing import TYPE_CHECKING
import numpy as np
from typing import Any, Dict, List, Mapping, Optional, Tuple
if TYPE_CHECKING:
from ..explanations import AlternativeExplanations, CalibratedExplanations
from ..plugins.manager import PluginManager
try:
import tomllib as _tomllib
except ModuleNotFoundError: # pragma: no cover - fallback for <3.11
try: # pragma: no cover - optional dependency path
import tomli as _tomllib # type: ignore[assignment]
except ModuleNotFoundError: # pragma: no cover - tomllib unavailable
_tomllib = None # type: ignore[assignment]
# Core imports (no cross-sibling dependencies)
from ..calibration.interval_wrappers import is_fast_interval_collection
from ..utils import check_is_fitted, convert_targets_to_numeric, deprecate, safe_isinstance
from ..utils.exceptions import (
DataShapeError,
ValidationError,
)
from .reject.policy import RejectPolicy
from .prediction.interval_summary import IntervalSummary, coerce_interval_summary
# Lazy imports deferred to avoid cross-sibling coupling
# These are imported inside methods/properties where used
# - cache (CalibratorCache) - lazy in __init__; parallel (ParallelConfig, ParallelExecutor) - lazy in executor/pool helpers
# - plotting (_plot_global) - lazy in plotting methods
# - explanations (AlternativeExplanations, CalibratedExplanations) - lazy as needed
# - integrations (LimeHelper, ShapHelper) - lazy in __init__
# - api.params (canonicalize_kwargs, etc.) - lazy in param handling
# - plugins (IntervalCalibratorContext, PluginManager, LegacyPredictBridge) - lazy in __init__
# - utils.discretizers (EntropyDiscretizer, RegressorDiscretizer) - lazy in validation
[docs]
class CalibratedExplainer:
"""Explain a fitted learner using calibrated intervals and plugins.
The explainer fits internal interval calibrators on provided calibration
data and exposes high-level APIs for producing `CalibratedExplanations`.
The recommended entry point is `WrapCalibratedExplainer`, which wraps both the learner and this explainer.
Examples
--------
>>> from calibrated_explanations import CalibratedExplainer
>>> explainer = CalibratedExplainer(learner, X_cal, y_cal, mode="classification")
>>> explanations = explainer.explain_factual(X_test)
"""
# pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-locals, too-many-branches, too-many-statements
def __init__(
    self,
    learner,
    x_cal,
    y_cal,
    mode="classification",
    feature_names=None,
    categorical_features=None,
    categorical_labels=None,
    class_labels=None,
    bins=None,
    difficulty_estimator=None,
    **kwargs,
) -> None:
    """Initialize the explainer with calibration data and metadata.

    Parameters
    ----------
    learner : Any
        Fitted predictive learner (checked with ``check_is_fitted``). Unless a
        ``predict_function`` is passed in ``kwargs``, ``predict_proba`` is used
        for classification and ``predict`` otherwise.
    x_cal : array-like of shape (n_calibration_samples, n_features)
        Calibration feature matrix used to fit interval calibrators.
    y_cal : array-like of shape (n_calibration_samples,)
        Calibration targets paired with ``x_cal``. Replaced by the learner's
        out-of-bag predictions when ``oob=True``.
    mode : {"classification", "regression"}, default="classification"
        Operating mode controlling which calibrators/plugins are used. The
        value is lower-cased once up front and that normalized value drives
        every mode-dependent decision in this method.
    feature_names : Sequence[str] or None, optional
        Human-readable feature names. When omitted, the keys of the first
        calibration row (for dict rows) or stringified column indices are used.
    categorical_features : Sequence[int] or None, optional
        Indices of features to treat as categorical. Defaults to the keys of
        ``categorical_labels`` when given, otherwise to an empty list.
    categorical_labels : Mapping[int, Mapping[int, str]] or None, optional
        Optional mapping translating categorical feature values to labels.
    class_labels : Mapping[int, str] or None, optional
        Mapping from class indices to display labels. Derived from the
        calibration targets for classification when omitted; always ``None``
        outside classification.
    bins : array-like or None, optional
        Pre-computed Mondrian categories for fast explanations.
    difficulty_estimator : Any or None, optional
        Optional crepes ``DifficultyEstimator`` instance for regression tasks.
    **kwargs : Any
        Advanced configuration flags preserved for backward compatibility.
        ``perf_cache``, ``perf_parallel`` and ``preprocessor_metadata`` are
        popped. Keys read here include ``predict_function``,
        ``suppress_crepes_errors``, ``oob``, ``seed`` (default 42),
        ``sample_percentiles`` (default ``[25, 50, 75]``), ``verbose``,
        ``interval_summary``, ``fast``, ``noise_type`` (default ``"uniform"``),
        ``scale_factor`` (default 5), ``severity`` (default 1),
        ``condition_source`` (``"observed"`` or ``"prediction"``, default
        ``"prediction"``), ``features_to_ignore``, ``reject`` and
        ``default_reject_policy``. The remaining mapping is then passed to
        ``PluginManager.initialize_from_kwargs``.

    Raises
    ------
    DataShapeError
        If ``oob=True`` and the out-of-bag predictions do not have the same
        length as ``x_cal``.
    ValidationError
        If ``condition_source`` is neither ``"observed"`` nor ``"prediction"``.

    Notes
    -----
    Minimal lifecycle logging is available at INFO level. To enable, run::

        import logging
        logging.getLogger("calibrated_explanations").setLevel(logging.INFO)
    """
    # Normalize once so every mode comparison below agrees with the value
    # handed to set_mode(); previously only set_mode() saw the lower-cased
    # mode, so e.g. "Classification" took the regression code paths here.
    mode = str.lower(mode)
    # Popped (not just read) so these keys are absent from the kwargs that are
    # later forwarded to PluginManager.initialize_from_kwargs.
    perf_cache = kwargs.pop("perf_cache", None)
    perf_parallel = kwargs.pop("perf_parallel", None)
    init_time = time()
    # NOTE(review): this name-mangled flag is never set to True in this method;
    # confirm where (or whether) it is flipped.
    self.__initialized = False
    preprocessor_metadata = kwargs.pop("preprocessor_metadata", None)
    if isinstance(preprocessor_metadata, Mapping):
        self._preprocessor_metadata: Dict[str, Any] | None = dict(preprocessor_metadata)
    else:
        self._preprocessor_metadata = None
    check_is_fitted(learner)
    self.learner = learner
    self.predict_function = kwargs.get("predict_function")
    if self.predict_function is None:
        self.predict_function = (
            learner.predict_proba if mode == "classification" else learner.predict
        )
    # Optionally suppress or convert low-level crepes errors into clearer messages.
    # Caller can pass suppress_crepes_errors=True via kwargs to avoid raising on
    # crepes broadcasting/shape errors (useful for synthetic tiny datasets).
    self.suppress_crepes_errors = bool(kwargs.get("suppress_crepes_errors", False))
    self.oob = kwargs.get("oob", False)
    self._categorical_value_counts_cache: Dict[int, Dict[Any, int]] | None = None
    self._numeric_sorted_cache: Dict[int, np.ndarray] | None = None
    self._calibration_summary_shape: Tuple[int, int] | None = None
    if self.oob:
        # Replace the calibration targets with the learner's out-of-bag
        # predictions (oob_decision_function_ / oob_prediction_).
        if mode == "classification":
            y_oob_proba = self.learner.oob_decision_function_
            if (
                len(y_oob_proba.shape) == 1 or y_oob_proba.shape[1] == 1
            ):  # Binary classification
                # NOTE(review): a pandas Categorical y_cal is only handled in the
                # multiclass branch below; confirm np.dtype(y_cal.dtype) is valid
                # for binary categorical targets.
                y_oob = (y_oob_proba > 0.5).astype(np.dtype(y_cal.dtype))
            else:  # Multiclass classification
                y_oob = np.argmax(y_oob_proba, axis=1)
                if safe_isinstance(y_cal, "pandas.core.arrays.categorical.Categorical"):
                    y_oob = y_cal.categories[y_oob]
                else:
                    y_oob = y_oob.astype(np.dtype(y_cal.dtype))
        else:
            y_oob = self.learner.oob_prediction_
        if len(x_cal) != len(y_oob):
            raise DataShapeError(
                "The length of the out-of-bag predictions does not match the length of x_cal."
            )
        y_cal = y_oob
    self.x_cal = x_cal
    self.y_cal = y_cal
    # Initialize RNG with seed
    from ..utils import set_rng_seed  # pylint: disable=import-outside-toplevel

    seed = kwargs.get("seed", 42)
    self.seed = seed
    self.rng = set_rng_seed(seed)
    self.sample_percentiles = kwargs.get("sample_percentiles", [25, 50, 75])
    self.verbose = kwargs.get("verbose", False)
    self.bins = bins
    self.interval_summary = coerce_interval_summary(
        kwargs.get("interval_summary", IntervalSummary.REGULARIZED_MEAN)
    )
    self.__fast = kwargs.get("fast", False)
    self.__noise_type = kwargs.get("noise_type", "uniform")
    self.__scale_factor = kwargs.get("scale_factor", 5)
    self.__severity = kwargs.get("severity", 1)
    # Prefer explicit caller value; otherwise default to 'prediction' as of v0.10.3
    if "condition_source" in kwargs:
        self.condition_source = kwargs.get("condition_source")
    else:
        self.condition_source = "prediction"
        logging.getLogger(__name__).info(
            "condition_source not provided; defaulting to 'prediction' (v0.10.3)"
        )
        if self.verbose:
            warnings.warn(
                "condition_source not provided; defaulting to 'prediction' in v0.10.3. "
                "Pass condition_source='observed' to retain previous behaviour.",
                UserWarning,
                stacklevel=2,
            )
    if self.condition_source not in {"observed", "prediction"}:
        raise ValidationError(
            "condition_source must be either 'observed' or 'prediction'",
            details={
                "param": "condition_source",
                "value": self.condition_source,
                "allowed": ("observed", "prediction"),
            },
        )
    self.categorical_labels = categorical_labels
    self.class_labels = class_labels
    if categorical_features is None:
        if categorical_labels is not None:
            categorical_features = categorical_labels.keys()
        else:
            categorical_features = []
    self.categorical_features = list(categorical_features)
    self._invalidate_calibration_summaries()
    self.features_to_ignore = kwargs.get("features_to_ignore", [])
    # Identify constant calibration features that can be ignored downstream
    from .calibration_helpers import identify_constant_features  # pylint: disable=import-outside-toplevel

    constant_ignore = identify_constant_features(self.x_cal)
    try:
        self.features_to_ignore = (
            np.union1d(self.features_to_ignore, constant_ignore).astype(int).tolist()
        )
    except (TypeError, ValueError):
        # Be defensive: if union fails due to incompatible types, fall back to constants.
        self.features_to_ignore = list(constant_ignore)
    if feature_names is None:
        # NOTE(review): _X_cal and num_features are not assigned in this method;
        # presumably populated by the x_cal assignment above — confirm.
        feature_names = (
            self._X_cal[0].keys()
            if isinstance(self._X_cal[0], dict)
            else [str(i) for i in range(self.num_features)]
        )
    self._feature_names = list(feature_names)
    if mode == "classification":
        # String-like targets are converted to numeric codes; label_map keeps
        # the original labels so class_labels can be derived from it.
        if any(isinstance(val, str) for val in self.y_cal) or any(
            isinstance(val, (np.str_, np.object_)) for val in self.y_cal
        ):
            self.y_cal_numeric, self.label_map = convert_targets_to_numeric(self.y_cal)
            self.y_cal = self.y_cal_numeric  # save to _y_cal to avoid append
            if self.class_labels is None:
                self.class_labels = {v: k for k, v in self.label_map.items()}
        else:
            self.label_map = None
            if self.class_labels is None:
                self.class_labels = {int(label): str(label) for label in np.unique(self.y_cal)}
    else:
        self.label_map = None
        self.class_labels = None
    self.discretizer: Any = None
    self.discretized_X_cal: Optional[np.ndarray] = None
    # Predeclare attributes for fast mode to satisfy type checkers
    self.fast_x_cal: Optional[np.ndarray] = None
    self.scaled_x_cal: Optional[np.ndarray] = None
    self.scaled_y_cal: Optional[np.ndarray] = None
    self.feature_values: Dict[int, List[Any]] = {}
    self.feature_frequencies: Dict[int, np.ndarray] = {}
    # Lazy import helper integrations (deferred from module level)
    from ..integrations import LimeHelper, ShapHelper

    self.latest_explanation: Optional[CalibratedExplanations] = None
    self._lime_helper = LimeHelper(self)
    self._shap_helper = ShapHelper(self)
    self.reject = kwargs.get("reject", False)
    # Optional default reject policy for explainer-level defaults
    # (same class as the module-level RejectPolicy import).
    from .reject.policy import RejectPolicy as _RejectPolicy

    self.default_reject_policy = kwargs.get("default_reject_policy", _RejectPolicy.NONE)
    self.set_difficulty_estimator(difficulty_estimator, initialize=False)
    self.set_mode(mode, initialize=False)
    # Lazy import orchestrator and plugin management (deferred from module level)
    from ..plugins.manager import PluginManager
    from ..plugins.builtins import LegacyPredictBridge
    from ..cache import CalibratorCache

    # Initialize plugin manager (SINGLE SOURCE OF TRUTH for plugin management)
    # PluginManager handles ALL plugin initialization including:
    # - Reading pyproject.toml configurations
    # - Setting up plugin overrides from kwargs
    # - Creating and initializing orchestrators
    # - Building plugin fallback chains
    self.plugin_manager = PluginManager(self)
    self.plugin_manager.initialize_from_kwargs(kwargs)
    self.plugin_manager.initialize_orchestrators()
    # Initialize interval learner after orchestrators are ready
    self.prediction_orchestrator.interval_registry.initialize()
    self.perf_cache: CalibratorCache[Any] | None = perf_cache
    # Initialize parallel executor (ADR-004: Honor CE_PARALLEL overrides)
    self._perf_parallel: Any | None = self._resolve_parallel_executor(perf_parallel)
    # Orchestrator references are now accessed via properties that delegate to PluginManager
    # No direct assignment needed - properties handle the delegation
    # Reject learner initialization (re-reads kwargs["reject"], same key as self.reject)
    self.reject_learner = (
        self.initialize_reject_learner() if kwargs.get("reject", False) else None
    )
    self._predict_bridge = LegacyPredictBridge(self)
    self.init_time = time() - init_time
# TODO(review): this note was left unfinished ("Needs to be ..."); restore its intent or remove it.
def __deepcopy__(self, memo):
    """Return a deep copy of the explainer that tolerates circular references.

    Runtime helpers (plugin manager, parallel executor, caches, integration
    helpers, latest explanation, learner, predict function and RNG) are shared
    by reference with the copy. Every other instance attribute is deep-copied;
    if that fails, the copy falls back to sharing the original object.

    Parameters
    ----------
    memo : dict
        The ``copy.deepcopy`` memo dictionary, keyed by ``id`` of the source.

    Returns
    -------
    CalibratedExplainer
        The new copy, or the copy already registered in ``memo``.
    """
    source_key = id(self)
    if source_key in memo:
        return memo[source_key]
    # Allocate without running __init__ and register the object before
    # recursing, so self-references found below resolve to this new copy.
    klass = self.__class__
    duplicate = klass.__new__(klass)
    memo[source_key] = duplicate
    # Shared rather than deep-copied: these refer back into the explainer or
    # hold unpicklable runtime resources. NOTE(review): sharing
    # _plugin_manager presumably keeps the copy's orchestration bound to the
    # original explainer — confirm that is acceptable.
    shared_by_reference = frozenset(
        (
            "_plugin_manager",
            "_perf_parallel",
            "perf_cache",
            "_lime_helper",
            "_shap_helper",
            "_predict_bridge",
            "latest_explanation",
            "learner",
            "predict_function",
            "rng",
        )
    )
    for attr_name, attr_value in self.__dict__.items():
        if attr_name in shared_by_reference:
            # ADR002_ALLOW: swallowing to keep deepcopy best-effort.
            with contextlib.suppress(Exception):
                setattr(duplicate, attr_name, attr_value)
        else:
            try:
                setattr(duplicate, attr_name, copy.deepcopy(attr_value, memo))
            except Exception:  # ADR002_ALLOW: fallback to shallow copy when deepcopy fails. # pragma: no cover
                # ADR002_ALLOW: ignore attributes that cannot be copied.
                with contextlib.suppress(Exception):
                    setattr(duplicate, attr_name, attr_value)
    return duplicate
def __getstate__(self):
    """Return a picklable copy of the instance state without runtime helpers.

    ``perf_cache`` and ``_perf_parallel`` are set to ``None`` in the returned
    dict; the live instance keeps its own values.
    """
    return {**self.__dict__, "perf_cache": None, "_perf_parallel": None}
def __setstate__(self, state):
    """Merge an unpickled ``state`` mapping into the instance attributes.

    Helpers cleared by ``__getstate__`` are not rebuilt; they stay ``None``.
    """
    vars(self).update(state)
[docs]
def require_plugin_manager(self) -> PluginManager:
    """Return the plugin manager, raising when the explainer is uninitialized.

    Returns
    -------
    PluginManager
        The instance stored on ``_plugin_manager``.

    Raises
    ------
    NotFittedError
        If ``_plugin_manager`` is missing or ``None`` (for example when the
        object was created without running ``__init__``).
    """
    manager = getattr(self, "_plugin_manager", None)
    if manager is not None:
        return manager
    from ..utils.exceptions import NotFittedError  # pylint: disable=import-outside-toplevel

    raise NotFittedError(
        "PluginManager is not initialized. Instantiate CalibratedExplainer via __init__.",
        details={
            "state": "uninitialized",
            "reason": "plugin_manager_missing",
            "required_method": "__init__",
        },
    )
[docs]
def get_plugin_manager(self) -> PluginManager:
    """Return the active plugin manager after enforcing derived preferences.

    Wrapper layers must not mutate plugin manager state directly. Preferences
    derived at runtime (such as the feature-filter requirement on the factual
    explanation plugin) are applied here, so orchestration stays centralized
    in the explainer/manager layers.
    """
    active_manager = self.require_plugin_manager()
    self._enforce_feature_filter_plugin_preferences(active_manager)
    return active_manager
def _deprecate_nonessential_surface(self, symbol: str, replacement: str) -> None:
    """Emit the ADR-011 deprecation warning for a compatibility delegator.

    Parameters
    ----------
    symbol : str
        Name of the deprecated ``CalibratedExplainer`` member.
    replacement : str
        Dotted path callers should use instead.
    """
    message = (
        f"CalibratedExplainer.{symbol} is deprecated since v0.11.1; use "
        f"{replacement} instead. This compatibility delegator will be removed "
        "no earlier than v0.13.0 and only as a major-release change in v1.0.0+ "
        "per ADR-020."
    )
    dedup_key = f"CalibratedExplainer.{symbol}_delegator_deprecation"
    deprecate(message, key=dedup_key, stacklevel=3)
def _enforce_feature_filter_plugin_preferences(self, manager: PluginManager) -> None:
    """Force the sequential factual explanation plugin while the feature filter is on.

    Does nothing unless ``self._feature_filter_config.enabled`` is exactly
    ``True``. Otherwise the ``"factual"`` fallback chain is read from
    ``manager`` (initializing the chains first when it is empty). If its first
    entry is not ``"core.explanation.factual.sequential"``, that identifier is
    installed as the factual override, cached explanation plugin instances and
    identifiers are cleared, and the chains are rebuilt. Failures are logged
    and reported as ``UserWarning`` instead of being raised, so explanation
    continues with the existing plugin selection.

    Parameters
    ----------
    manager : PluginManager
        Plugin manager whose factual explanation selection may be overridden.
    """
    feature_filter_cfg = getattr(self, "_feature_filter_config", None)
    if getattr(feature_filter_cfg, "enabled", False) is not True:
        return
    sequential_id = "core.explanation.factual.sequential"
    log = logging.getLogger(__name__)

    try:
        factual_chain = manager.explanation_plugin_fallbacks.get("factual", ())
    except Exception as err:  # adr002_allow
        log.warning(
            "Failed to read explanation plugin fallback chain; feature filter enforcement skipped: %s",
            err,
            exc_info=True,
        )
        warnings.warn(
            "Feature filter is enabled but plugin fallback chain could not be read; "
            "continuing with the configured explanation plugin selection.",
            UserWarning,
            stacklevel=3,
        )
        return

    # An empty chain means chains were never built; build them before checking.
    if not factual_chain:
        try:
            manager.initialize_chains()
            factual_chain = manager.explanation_plugin_fallbacks.get("factual", ())
        except Exception as err:  # adr002_allow
            log.warning(
                "Failed to initialize plugin chains; feature filter enforcement skipped: %s",
                err,
                exc_info=True,
            )
            warnings.warn(
                "Feature filter is enabled but plugin chains could not be initialized; "
                "continuing with the configured explanation plugin selection.",
                UserWarning,
                stacklevel=3,
            )
            return

    if factual_chain and factual_chain[0] == sequential_id:
        return

    previous = factual_chain[0] if factual_chain else None
    log.warning(
        "Feature filter enabled; forcing factual explanation plugin to '%s' (was '%s')",
        sequential_id,
        previous,
        extra={"mode": "factual", "plugin_identifier": sequential_id},
    )
    warnings.warn(
        f"Feature filter is enabled; overriding the factual explanation plugin from '{previous}' "
        f"to '{sequential_id}'.",
        UserWarning,
        stacklevel=3,
    )
    try:
        manager.explanation_plugin_overrides["factual"] = sequential_id
        manager.clear_explanation_plugin_instances()
        manager.clear_explanation_plugin_identifiers()
        manager.initialize_chains()
    except Exception as err:  # adr002_allow
        log.warning(
            "Failed to enforce factual explanation plugin for feature filter: %s",
            err,
            exc_info=True,
        )
        warnings.warn(
            "Feature filter is enabled but forcing the factual explanation plugin failed; "
            "continuing with the configured explanation plugin selection.",
            UserWarning,
            stacklevel=3,
        )
def _resolve_parallel_executor(self, explicit_executor: Any | None) -> Any | None:
    """Private alias of ``resolve_parallel_executor`` kept for internal callers."""
    resolver = self.resolve_parallel_executor
    return resolver(explicit_executor)
[docs]
def resolve_parallel_executor(self, explicit_executor: Any | None) -> Any | None:
    """Pick the parallel executor for this explainer.

    Parameters
    ----------
    explicit_executor : Any or None
        Executor supplied by the caller; returned unchanged when not ``None``.

    Returns
    -------
    Any or None
        ``explicit_executor`` if given. Otherwise a new ``ParallelExecutor``
        built from ``ParallelConfig.from_env()`` when that config is enabled,
        else ``None``.
    """
    if explicit_executor is not None:
        return explicit_executor
    from ..parallel import ParallelConfig, ParallelExecutor

    environment_config = ParallelConfig.from_env()
    return ParallelExecutor(environment_config) if environment_config.enabled else None
# ------------------------------------------------------------------
# Parallel pool lifecycle helpers
# ------------------------------------------------------------------
[docs]
def initialize_pool(self, n_workers: int | None = None, *, pool_at_init: bool = False) -> None:
    """Provision a ``ParallelExecutor`` for this explainer if none exists yet.

    Parameters
    ----------
    n_workers : int or None, default=None
        When given, overrides ``max_workers`` on the environment-derived config.
    pool_at_init : bool, default=False
        When True, register ``worker_init_from_explainer_spec`` as the worker
        initializer (best effort) and enter the executor immediately, so worker
        processes are spawned now (useful for warm-up).

    Notes
    -----
    Returns without changes when ``_perf_parallel`` is already set. The config
    starts from ``ParallelConfig.from_env()`` and is always forced to
    ``enabled = True``.
    """
    if getattr(self, "_perf_parallel", None) is not None:
        return
    from ..parallel import ParallelConfig, ParallelExecutor

    config = ParallelConfig.from_env()
    config.enabled = True
    if n_workers is not None:
        config.max_workers = n_workers
    if pool_at_init:
        # ADR002_ALLOW: optional initializer wiring should not block.
        with contextlib.suppress(Exception):
            import calibrated_explanations.core.explain.parallel_runtime as pr_mod

            # The worker spec is kept small and picklable. The learner is
            # pickled best-effort; when that fails it is sent as None and the
            # worker initializer must cope without it.
            pickled_learner = None
            try:
                import pickle  # nosec B403

                pickled_learner = pickle.dumps(getattr(self, "learner", None))
            except Exception:  # ADR002_ALLOW: learner pickling best-effort fallback. # pragma: no cover
                pickled_learner = None
            spec = {"learner_bytes": pickled_learner}
            spec.update(
                (attr, getattr(self, attr, None))
                for attr in ("x_cal", "y_cal", "mode", "num_features", "bins", "sample_percentiles")
            )
            config.worker_initializer = pr_mod.worker_init_from_explainer_spec
            config.worker_init_args = (spec,)
    executor = ParallelExecutor(config)
    self._perf_parallel = executor
    if pool_at_init:
        # Entering the context spawns the worker pool now.
        executor.__enter__()
[docs]
def close(self) -> None:
    """Reset transient state, then shut down the parallel pool if one exists.

    ``reset()`` always runs first. A provisioned ``_perf_parallel`` is then
    exited via ``__exit__(None, None, None)`` and cleared to ``None``, even
    when that exit raises (the exception still propagates).
    """
    self.reset()
    executor = getattr(self, "_perf_parallel", None)
    if executor is not None:
        try:
            executor.__exit__(None, None, None)
        finally:
            self._perf_parallel = None
[docs]
def reset(self) -> None:
    """Drop transient runtime state kept between explanation calls.

    Clears ``latest_explanation`` and calls ``reset()`` on the LIME/SHAP
    helpers that provide one. When a plugin manager exists, it also clears the
    cached explanation plugin instances, identifiers and bridge monitors (each
    step best-effort) and empties its ``explanation_contexts`` dict.
    """
    self.latest_explanation = None
    for attr in ("_lime_helper", "_shap_helper"):
        integration = getattr(self, attr, None)
        if integration is not None and hasattr(integration, "reset"):
            integration.reset()
    manager = getattr(self, "_plugin_manager", None)
    if manager is None:
        return
    for clearer in (
        "clear_explanation_plugin_instances",
        "clear_explanation_plugin_identifiers",
        "clear_bridge_monitors",
    ):
        with contextlib.suppress(Exception):
            getattr(manager, clearer)()
    explanation_contexts = getattr(manager, "explanation_contexts", None)
    if isinstance(explanation_contexts, dict):
        explanation_contexts.clear()
def __enter__(self) -> "CalibratedExplainer":
    """Enter the runtime context and spawn the worker pool eagerly.

    Calls ``initialize_pool(pool_at_init=True)`` and returns ``self``.
    """
    eager = True
    self.initialize_pool(pool_at_init=eager)
    return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
    """Leave the runtime context by delegating to ``close()``.

    Returns ``None``, so exceptions raised inside the ``with`` body propagate.
    """
    shutdown = self.close
    shutdown()
[docs]
def infer_explanation_mode(self) -> str:
    """Infer the explanation mode from the currently configured discretizer.

    Returns
    -------
    str
        ``"alternative"`` when ``self.discretizer`` is an ``EntropyDiscretizer``
        or ``RegressorDiscretizer``; ``"factual"`` for any other discretizer or
        when none is set.
    """
    # Lazy import discretizers (deferred from module level)
    from ..utils import EntropyDiscretizer, RegressorDiscretizer

    current = getattr(self, "discretizer", None)
    alternative_types = (EntropyDiscretizer, RegressorDiscretizer)
    return "alternative" if isinstance(current, alternative_types) else "factual"
# ===================================================================
# Delegation methods for orchestrator operations
# ===================================================================
# These methods delegate to PluginManager and orchestrators.
# PluginManager is the single source of truth for plugin defaults and chains.
# Tests that call these directly MUST initialize PluginManager properly.
@property
def prediction_orchestrator(self) -> Any:
    """Return the ``PredictionOrchestrator`` owned by the plugin manager."""
    manager = self.require_plugin_manager()
    return manager.prediction_orchestrator

@prediction_orchestrator.setter
def prediction_orchestrator(self, value: Any) -> None:
    """Install ``value`` as the plugin manager's ``PredictionOrchestrator``."""
    manager = self.require_plugin_manager()
    manager.prediction_orchestrator = value

@prediction_orchestrator.deleter
def prediction_orchestrator(self) -> None:
    """Remove the ``PredictionOrchestrator`` from the plugin manager."""
    manager = self.require_plugin_manager()
    del manager.prediction_orchestrator
@property
def explanation_orchestrator(self) -> Any:
    """Return the ``ExplanationOrchestrator`` owned by the plugin manager."""
    manager = self.require_plugin_manager()
    return manager.explanation_orchestrator

@explanation_orchestrator.setter
def explanation_orchestrator(self, value: Any) -> None:
    """Install ``value`` as the plugin manager's ``ExplanationOrchestrator``."""
    manager = self.require_plugin_manager()
    manager.explanation_orchestrator = value

@explanation_orchestrator.deleter
def explanation_orchestrator(self) -> None:
    """Remove the ``ExplanationOrchestrator`` from the plugin manager."""
    manager = self.require_plugin_manager()
    del manager.explanation_orchestrator
@property
def reject_orchestrator(self) -> Any:
    """Return the ``RejectOrchestrator`` owned by the plugin manager."""
    manager = self.require_plugin_manager()
    return manager.reject_orchestrator

@reject_orchestrator.setter
def reject_orchestrator(self, value: Any) -> None:
    """Install ``value`` as the plugin manager's ``RejectOrchestrator``."""
    manager = self.require_plugin_manager()
    manager.reject_orchestrator = value

@reject_orchestrator.deleter
def reject_orchestrator(self) -> None:
    """Remove the ``RejectOrchestrator`` from the plugin manager."""
    manager = self.require_plugin_manager()
    del manager.reject_orchestrator
[docs]
def build_plot_style_chain(self) -> Tuple[str, ...]:
    """Return the plot style chain (deprecated compatibility delegator).

    Emits the ADR-011 deprecation warning through
    ``_deprecate_nonessential_surface`` and then returns
    ``self.plugin_manager.build_plot_chain()``. This method provides no
    fallback of its own (it does not return an empty tuple for explainers
    lacking a plugin manager). Prefer calling
    ``plugin_manager.build_plot_chain()`` directly.

    Returns
    -------
    Tuple[str, ...]
        The chain produced by the plugin manager.
    """
    self._deprecate_nonessential_surface(
        "build_plot_style_chain", "plugin_manager.build_plot_chain"
    )
    return self.plugin_manager.build_plot_chain()
@property
def plot_style_chain(self) -> Tuple[str, ...]:
    """Return the plot style chain currently resolved by the plugin manager.

    Delegates straight to ``plugin_manager.build_plot_chain()``. It does not
    route through the deprecated ``build_plot_style_chain`` method, so reading
    this (non-deprecated) property emits no deprecation warning.
    """
    return self.plugin_manager.build_plot_chain()
[docs]
def instantiate_plugin(self, prototype: Any) -> Any:
    """Instantiate ``prototype`` via the explanation orchestrator (deprecated).

    Emits the ADR-011 deprecation warning, then forwards to
    ``plugin_manager.explanation_orchestrator.instantiate_plugin``.
    """
    self._deprecate_nonessential_surface(
        "instantiate_plugin",
        "plugin_manager.explanation_orchestrator.instantiate_plugin",
    )
    orchestrator = self.plugin_manager.explanation_orchestrator
    return orchestrator.instantiate_plugin(prototype)
[docs]
def build_instance_telemetry_payload(self, explanations: Any) -> Dict[str, Any]:
    """Build the per-instance telemetry payload for ``explanations``.

    Forwards to ``explanation_orchestrator.build_instance_telemetry_payload``
    and returns its result unchanged.
    """
    orchestrator = self.explanation_orchestrator
    return orchestrator.build_instance_telemetry_payload(explanations)
def _invoke_explanation_plugin(self, *args, **kwargs) -> Any:
    """Private alias forwarding every argument to ``invoke_explanation_plugin``."""
    public_entry = self.invoke_explanation_plugin
    return public_entry(*args, **kwargs)
[docs]
def invoke_explanation_plugin(
    self,
    mode: str,
    x: Any,
    threshold: Any,
    low_high_percentiles: Any,
    bins: Any,
    features_to_ignore: Any,
    *,
    extras: Mapping[str, Any] | None = None,
    reject_policy: Any | None = None,
) -> Any:
    """Run an explanation plugin, optionally under a reject policy (deprecated).

    Emits the ADR-011 deprecation warning, then resolves the effective reject
    policy from ``reject_policy`` and ``self.default_reject_policy``.

    - If the effective policy is ``RejectPolicy.NONE``, the call is forwarded
      to ``explanation_orchestrator.invoke`` and its result returned.
    - Otherwise ``reject_orchestrator.apply_policy`` is used with an
      ``explain_fn`` that explains subsets of ``x``. A ``RejectResultV2``
      result is converted to the legacy format when possible.

    Parameters
    ----------
    mode : str
        Explanation mode forwarded to the orchestrator.
    x : Any
        Instances to explain.
    threshold, low_high_percentiles, bins, features_to_ignore : Any
        Forwarded unchanged to ``explanation_orchestrator.invoke``. ``bins``
        may be replaced per subset by the reject orchestrator.
    extras : Mapping[str, Any] or None, optional
        Extra options forwarded to the orchestrator. When it is a mapping, its
        ``"confidence"`` entry (default 0.95) is used for the reject policy.
    reject_policy : Any or None, optional
        Per-call reject policy overriding the explainer-level default.

    Returns
    -------
    Any
        The orchestrator result, or the reject orchestrator result (converted
        to the legacy reject format when it is a ``RejectResultV2`` and the
        conversion succeeds).
    """
    self._deprecate_nonessential_surface(
        "invoke_explanation_plugin", "explanation_orchestrator.invoke"
    )
    # Reject integration (ADR-029):
    # - default remains RejectPolicy.NONE (no reject)
    # - per-call reject_policy overrides the explainer-level default_reject_policy
    # Backward compatibility:
    # - do not pass reject_policy=None / RejectPolicy.NONE through to orchestrator calls
    from .reject.orchestrator import (  # pylint: disable=import-outside-toplevel
        resolve_effective_reject_policy,
    )

    resolution = resolve_effective_reject_policy(
        reject_policy,
        self,
        default_policy=getattr(self, "default_reject_policy", RejectPolicy.NONE),
        logger=logging.getLogger(__name__),
    )
    effective_policy = resolution.policy
    if effective_policy is RejectPolicy.NONE:
        return self.explanation_orchestrator.invoke(
            mode,
            x,
            threshold,
            low_high_percentiles,
            bins,
            features_to_ignore,
            extras=extras,
        )
    # Policy enabled: ensure reject orchestration and delegate via RejectOrchestrator
    confidence = extras.get("confidence", 0.95) if isinstance(extras, Mapping) else 0.95

    def _explain_fn(x_subset, **kw):
        # Explains a subset of x chosen by the reject orchestrator. Subset-level
        # bins override the outer bins when supplied; _ce_skip_reject=True
        # presumably stops the orchestrator re-applying reject — confirm.
        # NOTE(review): extras is forwarded unchanged; if it carries per-row
        # data it may not align with x_subset — verify.
        return self.explanation_orchestrator.invoke(
            mode,
            x_subset,
            threshold,
            low_high_percentiles,
            kw.get("bins", bins),
            features_to_ignore,
            extras=extras,
            _ce_skip_reject=True,
        )

    # Implicitly enable reject orchestration
    try:
        # ensure plugin manager has set up orchestrators
        _ = self.reject_orchestrator
    except Exception:  # adr002_allow
        # fallback: initialize via plugin manager if available
        with contextlib.suppress(Exception):
            self.plugin_manager.initialize_orchestrators()
    result = self.reject_orchestrator.apply_policy(
        effective_policy,
        x,
        explain_fn=_explain_fn,
        bins=bins,
        confidence=confidence,
        threshold=threshold,
        result_schema="v2",
    )
    # Best-effort conversion back to the legacy reject result shape; on any
    # failure the raw result is returned and the error is logged at DEBUG.
    try:
        from ..explanations.reject import (
            RejectResultV2,  # pylint: disable=import-outside-toplevel
            reject_result_v2_to_legacy,
        )

        if isinstance(result, RejectResultV2):
            return reject_result_v2_to_legacy(result, emit_deprecation_warning=False)
    except Exception as exc:  # adr002_allow
        logging.getLogger(__name__).debug(
            "RejectResultV2 compatibility conversion failed in invoke_explanation_plugin: %s",
            exc,
            exc_info=True,
        )
    return result
[docs]
def ensure_interval_runtime_state(self) -> None:
    """Ensure interval runtime state via the prediction orchestrator (deprecated).

    Emits the ADR-011 deprecation warning and returns whatever
    ``prediction_orchestrator.ensure_interval_runtime_state()`` returns.
    """
    self._deprecate_nonessential_surface(
        "ensure_interval_runtime_state",
        "prediction_orchestrator.ensure_interval_runtime_state",
    )
    orchestrator = self.prediction_orchestrator
    return orchestrator.ensure_interval_runtime_state()
[docs]
def gather_interval_hints(self, *, fast: bool) -> Tuple[str, ...]:
    """Return interval plugin hints from the prediction orchestrator (deprecated).

    Emits the ADR-011 deprecation warning, then forwards ``fast`` to
    ``prediction_orchestrator.gather_interval_hints``.
    """
    self._deprecate_nonessential_surface(
        "gather_interval_hints",
        "prediction_orchestrator.gather_interval_hints",
    )
    orchestrator = self.prediction_orchestrator
    return orchestrator.gather_interval_hints(fast=fast)
# ===================================================================
# Backward-compatibility properties for plugin state (via PluginManager)
# ===================================================================
# These properties delegate to PluginManager for backward compatibility
# with code that accesses plugin state directly from explainer.
@property
def _interval_plugin_hints(self) -> Dict[str, Tuple[str, ...]]:
    """Return ``plugin_manager.interval_plugin_hints`` (compatibility alias)."""
    manager = self.plugin_manager
    return manager.interval_plugin_hints

@_interval_plugin_hints.setter
def _interval_plugin_hints(self, value: Dict[str, Tuple[str, ...]]) -> None:
    """Store ``value`` on ``plugin_manager.interval_plugin_hints``."""
    manager = self.plugin_manager
    manager.interval_plugin_hints = value

@_interval_plugin_hints.deleter
def _interval_plugin_hints(self) -> None:
    """Delete ``plugin_manager.interval_plugin_hints``."""
    manager = self.plugin_manager
    del manager.interval_plugin_hints
@property
def _interval_plugin_fallbacks(self) -> Dict[str, Tuple[str, ...]]:
    """Return ``plugin_manager.interval_plugin_fallbacks`` (compatibility alias)."""
    manager = self.plugin_manager
    return manager.interval_plugin_fallbacks

@_interval_plugin_fallbacks.setter
def _interval_plugin_fallbacks(self, value: Dict[str, Tuple[str, ...]]) -> None:
    """Store ``value`` on ``plugin_manager.interval_plugin_fallbacks``."""
    manager = self.plugin_manager
    manager.interval_plugin_fallbacks = value

@_interval_plugin_fallbacks.deleter
def _interval_plugin_fallbacks(self) -> None:
    """Delete ``plugin_manager.interval_plugin_fallbacks``."""
    manager = self.plugin_manager
    del manager.interval_plugin_fallbacks
@property
def _interval_preferred_identifier(self) -> Dict[str, str | None]:
    """Return ``plugin_manager.interval_preferred_identifier`` (compatibility alias)."""
    manager = self.plugin_manager
    return manager.interval_preferred_identifier

@_interval_preferred_identifier.setter
def _interval_preferred_identifier(self, value: Dict[str, str | None]) -> None:
    """Store ``value`` on ``plugin_manager.interval_preferred_identifier``."""
    manager = self.plugin_manager
    manager.interval_preferred_identifier = value

@_interval_preferred_identifier.deleter
def _interval_preferred_identifier(self) -> None:
    """Delete ``plugin_manager.interval_preferred_identifier``."""
    manager = self.plugin_manager
    del manager.interval_preferred_identifier
@property
def _telemetry_interval_sources(self) -> Dict[str, str | None]:
    """Return ``plugin_manager.telemetry_interval_sources`` (compatibility alias)."""
    manager = self.plugin_manager
    return manager.telemetry_interval_sources

@_telemetry_interval_sources.setter
def _telemetry_interval_sources(self, value: Dict[str, str | None]) -> None:
    """Store ``value`` on ``plugin_manager.telemetry_interval_sources``."""
    manager = self.plugin_manager
    manager.telemetry_interval_sources = value

@_telemetry_interval_sources.deleter
def _telemetry_interval_sources(self) -> None:
    """Delete ``plugin_manager.telemetry_interval_sources``."""
    manager = self.plugin_manager
    del manager.telemetry_interval_sources
@property
def _interval_plugin_identifiers(self) -> Dict[str, str | None]:
    """Return ``plugin_manager.interval_plugin_identifiers`` (compatibility alias)."""
    manager = self.plugin_manager
    return manager.interval_plugin_identifiers

@_interval_plugin_identifiers.setter
def _interval_plugin_identifiers(self, value: Dict[str, str | None]) -> None:
    """Store ``value`` on ``plugin_manager.interval_plugin_identifiers``."""
    manager = self.plugin_manager
    manager.interval_plugin_identifiers = value

@_interval_plugin_identifiers.deleter
def _interval_plugin_identifiers(self) -> None:
    """Delete ``plugin_manager.interval_plugin_identifiers``."""
    manager = self.plugin_manager
    del manager.interval_plugin_identifiers
@property
def _interval_context_metadata(self) -> Dict[str, Dict[str, Any]]:
    """Return ``plugin_manager.interval_context_metadata`` (compatibility alias)."""
    manager = self.plugin_manager
    return manager.interval_context_metadata

@_interval_context_metadata.setter
def _interval_context_metadata(self, value: Dict[str, Dict[str, Any]]) -> None:
    """Store ``value`` on ``plugin_manager.interval_context_metadata``."""
    manager = self.plugin_manager
    manager.interval_context_metadata = value

@_interval_context_metadata.deleter
def _interval_context_metadata(self) -> None:
    """Delete ``plugin_manager.interval_context_metadata``."""
    manager = self.plugin_manager
    del manager.interval_context_metadata
@property
def plot_plugin_fallbacks(self) -> Dict[str, Tuple[str, ...]]:
"""Return the plot plugin fallback configuration.
Returns
-------
Dict[str, Tuple[str, ...]]
Mapping of mode to fallback identifiers.
"""
return self.plugin_manager.plot_plugin_fallbacks
@plot_plugin_fallbacks.setter
def plot_plugin_fallbacks(self, value: Dict[str, Tuple[str, ...]]) -> None:
"""Set the plot plugin fallback configuration."""
self.plugin_manager.plot_plugin_fallbacks = value
@property
def _explanation_plugin_overrides(self) -> Dict[str, Any]:
"""Delegate to PluginManager."""
return self.plugin_manager.explanation_plugin_overrides
@_explanation_plugin_overrides.setter
def _explanation_plugin_overrides(self, value: Dict[str, Any]) -> None:
"""Delegate to PluginManager."""
self.plugin_manager.explanation_plugin_overrides = value
@property
def _interval_plugin_override(self) -> Any:
"""Delegate to PluginManager."""
return self.plugin_manager.interval_plugin_override
@_interval_plugin_override.setter
def _interval_plugin_override(self, value: Any) -> None:
"""Delegate to PluginManager."""
self.plugin_manager.interval_plugin_override = value
@property
def _fast_interval_plugin_override(self) -> Any:
"""Delegate to PluginManager."""
return self.plugin_manager.fast_interval_plugin_override
@_fast_interval_plugin_override.setter
def _fast_interval_plugin_override(self, value: Any) -> None:
"""Delegate to PluginManager."""
self.plugin_manager.fast_interval_plugin_override = value
@property
def _plot_style_override(self) -> Any:
"""Delegate to PluginManager."""
return self.plugin_manager.plot_style_override
@_plot_style_override.setter
def _plot_style_override(self, value: Any) -> None:
"""Delegate to PluginManager."""
self.plugin_manager.plot_style_override = value
@property
def _explanation_plugin_instances(self) -> Dict[str, Any]:
"""Delegate to PluginManager."""
return self.plugin_manager.explanation_plugin_instances
@_explanation_plugin_instances.setter
def _explanation_plugin_instances(self, value: Dict[str, Any]) -> None:
"""Delegate to PluginManager."""
self.plugin_manager.explanation_plugin_instances = value
# Public aliases so tests and callers do not need to reach into private
# members; each is a thin delegation to a private attribute or to the
# PluginManager.
@property
def plugin_manager(self) -> PluginManager:
    """Return the PluginManager currently active for this explainer.

    Every access is resolved through :meth:`get_plugin_manager`.
    """
    return self.get_plugin_manager()

@plugin_manager.setter
def plugin_manager(self, value: Any) -> None:
    """Install ``value`` as this explainer's PluginManager."""
    self._plugin_manager = value

@plugin_manager.deleter
def plugin_manager(self) -> None:
    """Forget the stored PluginManager; does nothing if none is stored."""
    if not hasattr(self, "_plugin_manager"):
        return
    del self._plugin_manager
@property
def interval_plugin_hints(self) -> Dict[str, Tuple[str, ...]]:
    """Deprecated public alias for ``_interval_plugin_hints``.

    Reading it emits the non-essential-surface deprecation notice that
    points callers at ``plugin_manager.interval_plugin_hints``.
    """
    self._deprecate_nonessential_surface(
        "interval_plugin_hints", "plugin_manager.interval_plugin_hints"
    )
    return self._interval_plugin_hints

@interval_plugin_hints.setter
def interval_plugin_hints(self, value: Dict[str, Tuple[str, ...]]) -> None:
    """Assign through to ``_interval_plugin_hints`` (no deprecation notice)."""
    self._interval_plugin_hints = value

@interval_plugin_hints.deleter
def interval_plugin_hints(self) -> None:
    """Delete the hints held by the PluginManager, when one is reachable."""
    if not hasattr(self, "plugin_manager"):
        return
    del self.plugin_manager.interval_plugin_hints

@property
def interval_plugin_fallbacks(self) -> Dict[str, Tuple[str, ...]]:
    """Deprecated public alias for ``_interval_plugin_fallbacks``."""
    self._deprecate_nonessential_surface(
        "interval_plugin_fallbacks", "plugin_manager.interval_plugin_fallbacks"
    )
    return self._interval_plugin_fallbacks

@interval_plugin_fallbacks.setter
def interval_plugin_fallbacks(self, value: Dict[str, Tuple[str, ...]]) -> None:
    """Assign through to ``_interval_plugin_fallbacks``."""
    self._interval_plugin_fallbacks = value

@interval_plugin_fallbacks.deleter
def interval_plugin_fallbacks(self) -> None:
    """Delete the fallbacks held by the PluginManager, when one is reachable."""
    if not hasattr(self, "plugin_manager"):
        return
    del self.plugin_manager.interval_plugin_fallbacks

@property
def explanation_plugin_overrides(self) -> Dict[str, Any]:
    """Deprecated public alias for ``_explanation_plugin_overrides``.

    Returns an empty dict when no PluginManager is reachable.
    """
    self._deprecate_nonessential_surface(
        "explanation_plugin_overrides", "plugin_manager.explanation_plugin_overrides"
    )
    if not hasattr(self, "plugin_manager"):
        return {}
    return self._explanation_plugin_overrides

@explanation_plugin_overrides.setter
def explanation_plugin_overrides(self, value: Dict[str, Any]) -> None:
    """Assign through to ``_explanation_plugin_overrides``."""
    self._explanation_plugin_overrides = value

@property
def interval_plugin_override(self) -> Any:
    """Deprecated public alias for ``_interval_plugin_override``.

    Returns ``None`` when no PluginManager is reachable.
    """
    self._deprecate_nonessential_surface(
        "interval_plugin_override", "plugin_manager.interval_plugin_override"
    )
    if not hasattr(self, "plugin_manager"):
        return None
    return self._interval_plugin_override

@interval_plugin_override.setter
def interval_plugin_override(self, value: Any) -> None:
    """Assign through to ``_interval_plugin_override``.

    The assignment is silently ignored when no PluginManager is reachable.
    """
    if hasattr(self, "plugin_manager"):
        self._interval_plugin_override = value

@property
def fast_interval_plugin_override(self) -> Any:
    """Deprecated public alias for ``_fast_interval_plugin_override``."""
    self._deprecate_nonessential_surface(
        "fast_interval_plugin_override", "plugin_manager.fast_interval_plugin_override"
    )
    return self._fast_interval_plugin_override

@fast_interval_plugin_override.setter
def fast_interval_plugin_override(self, value: Any) -> None:
    """Assign through to ``_fast_interval_plugin_override``."""
    self._fast_interval_plugin_override = value

@property
def plot_style_override(self) -> Any:
    """Deprecated public alias for ``_plot_style_override``."""
    self._deprecate_nonessential_surface(
        "plot_style_override", "plugin_manager.plot_style_override"
    )
    return self._plot_style_override

@plot_style_override.setter
def plot_style_override(self, value: Any) -> None:
    """Assign through to ``_plot_style_override``."""
    self._plot_style_override = value

@property
def interval_preferred_identifier(self) -> Dict[str, str | None]:
    """Deprecated public alias for ``_interval_preferred_identifier``."""
    self._deprecate_nonessential_surface(
        "interval_preferred_identifier", "plugin_manager.interval_preferred_identifier"
    )
    return self._interval_preferred_identifier

@interval_preferred_identifier.setter
def interval_preferred_identifier(self, value: Dict[str, str | None]) -> None:
    """Assign through to ``_interval_preferred_identifier``."""
    self._interval_preferred_identifier = value

@interval_preferred_identifier.deleter
def interval_preferred_identifier(self) -> None:
    """Delete through ``_interval_preferred_identifier``."""
    del self._interval_preferred_identifier

@property
def telemetry_interval_sources(self) -> Dict[str, str | None]:
    """Deprecated public alias for ``_telemetry_interval_sources``."""
    self._deprecate_nonessential_surface(
        "telemetry_interval_sources", "plugin_manager.telemetry_interval_sources"
    )
    return self._telemetry_interval_sources

@telemetry_interval_sources.setter
def telemetry_interval_sources(self, value: Dict[str, str | None]) -> None:
    """Assign through to ``_telemetry_interval_sources``."""
    self._telemetry_interval_sources = value

@telemetry_interval_sources.deleter
def telemetry_interval_sources(self) -> None:
    """Delete through ``_telemetry_interval_sources``."""
    del self._telemetry_interval_sources

@property
def interval_plugin_identifiers(self) -> Dict[str, str | None]:
    """Deprecated public alias for ``_interval_plugin_identifiers``."""
    self._deprecate_nonessential_surface(
        "interval_plugin_identifiers", "plugin_manager.interval_plugin_identifiers"
    )
    return self._interval_plugin_identifiers

@interval_plugin_identifiers.setter
def interval_plugin_identifiers(self, value: Dict[str, str | None]) -> None:
    """Assign through to ``_interval_plugin_identifiers``."""
    self._interval_plugin_identifiers = value
# NOTE: ``preprocessor_metadata`` and ``get_sigma_test`` were previously also
# declared here as simple read/write properties. Both names are re-bound later
# in this class body (a ``preprocessor_metadata`` property returning a dict copy
# and a ``get_sigma_test(x)`` method). The later bindings silently replaced
# these definitions and their setters, so those duplicates were dead code and
# have been removed.
@property
def feature_names_internal(self) -> Any:
    """Public alias for ``_feature_names``.

    Unlike the ``feature_names`` setter, assignment here stores ``value``
    exactly as given, with no conversion to ``list``.
    """
    return self._feature_names

@feature_names_internal.setter
def feature_names_internal(self, value: Any) -> None:
    """Store ``value`` unchanged in ``_feature_names``."""
    self._feature_names = value

@property
def perf_parallel(self) -> bool:
    """Public alias for ``_perf_parallel``."""
    return self._perf_parallel

@perf_parallel.setter
def perf_parallel(self, value: bool) -> None:
    """Assign ``_perf_parallel``."""
    self._perf_parallel = value
[docs]
def initialize_interval_learner_for_fast_explainer(self, *args, **kwargs) -> Any:
    """Public entry point for fast-mode interval learner initialization.

    All positional and keyword arguments are forwarded verbatim to
    ``_CalibratedExplainer__initialize_interval_learner_for_fast_explainer``,
    and its result is returned.
    """
    initializer = self._CalibratedExplainer__initialize_interval_learner_for_fast_explainer
    return initializer(*args, **kwargs)
@interval_plugin_identifiers.deleter
def interval_plugin_identifiers(self) -> None:
    """Remove the interval plugin identifiers via ``_interval_plugin_identifiers``."""
    # Attaches a deleter to the ``interval_plugin_identifiers`` property
    # declared earlier in the class body.
    delattr(self, "_interval_plugin_identifiers")
@property
def interval_context_metadata(self) -> Dict[str, Dict[str, Any]]:
    """Deprecated public alias for ``_interval_context_metadata``.

    Reading it emits the deprecation notice that points at
    ``plugin_manager.interval_context_metadata``.
    """
    self._deprecate_nonessential_surface(
        "interval_context_metadata", "plugin_manager.interval_context_metadata"
    )
    return self._interval_context_metadata

@interval_context_metadata.setter
def interval_context_metadata(self, value: Dict[str, Dict[str, Any]]) -> None:
    """Assign through to ``_interval_context_metadata``."""
    self._interval_context_metadata = value

@interval_context_metadata.deleter
def interval_context_metadata(self) -> None:
    """Delete through ``_interval_context_metadata``."""
    del self._interval_context_metadata

@property
def bridge_monitors(self) -> Dict[str, Any]:
    """Public alias for ``_bridge_monitors``."""
    return self._bridge_monitors

@bridge_monitors.setter
def bridge_monitors(self, value: Dict[str, Any]) -> None:
    """Replace the bridge monitor registry on the required PluginManager."""
    manager = self.require_plugin_manager()
    manager.bridge_monitors = value

@property
def explanation_plugin_instances(self) -> Dict[str, Any]:
    """Public alias for ``_explanation_plugin_instances``."""
    return self._explanation_plugin_instances

@explanation_plugin_instances.setter
def explanation_plugin_instances(self, value: Dict[str, Any]) -> None:
    """Replace the explanation plugin instances on the required PluginManager."""
    manager = self.require_plugin_manager()
    manager.explanation_plugin_instances = value
@property
def pyproject_explanations(self) -> Dict[str, Any] | None:
    """Public alias for ``_pyproject_explanations``."""
    return self._pyproject_explanations

@pyproject_explanations.setter
def pyproject_explanations(self, value: Dict[str, Any] | None) -> None:
    """Assign through to ``_pyproject_explanations``."""
    self._pyproject_explanations = value

@property
def pyproject_intervals(self) -> Dict[str, Any] | None:
    """Public alias for ``_pyproject_intervals``."""
    return self._pyproject_intervals

@pyproject_intervals.setter
def pyproject_intervals(self, value: Dict[str, Any] | None) -> None:
    """Assign through to ``_pyproject_intervals``."""
    self._pyproject_intervals = value

@property
def pyproject_plots(self) -> Dict[str, Any] | None:
    """Public alias for ``_pyproject_plots``."""
    return self._pyproject_plots

@pyproject_plots.setter
def pyproject_plots(self, value: Dict[str, Any] | None) -> None:
    """Assign through to ``_pyproject_plots``."""
    self._pyproject_plots = value

@property
def lime_helper(self) -> Any:
    """Public alias for ``_lime_helper``."""
    return self._lime_helper

@lime_helper.setter
def lime_helper(self, value: Any) -> None:
    """Store the LIME helper in ``_lime_helper``."""
    self._lime_helper = value

@lime_helper.deleter
def lime_helper(self) -> None:
    """Drop the LIME helper; a no-op when none is stored."""
    if not hasattr(self, "_lime_helper"):
        return
    del self._lime_helper

@property
def shap_helper(self) -> Any:
    """Public alias for ``_shap_helper``."""
    return self._shap_helper

@shap_helper.setter
def shap_helper(self, value: Any) -> None:
    """Store the SHAP helper in ``_shap_helper``."""
    self._shap_helper = value

@shap_helper.deleter
def shap_helper(self) -> None:
    """Drop the SHAP helper; a no-op when none is stored."""
    if not hasattr(self, "_shap_helper"):
        return
    del self._shap_helper
@property
def initialized(self) -> bool:
    """Whether the explainer has finished initializing (``False`` if never set)."""
    return getattr(self, "_CalibratedExplainer__initialized", False)

@initialized.setter
def initialized(self, value: bool) -> None:
    """Record the explainer's initialization state."""
    # Spelled with the mangled name that ``self.__initialized`` resolves to
    # inside ``CalibratedExplainer``, so the getter and setter visibly agree.
    self._CalibratedExplainer__initialized = value

@property
def is_initialized(self) -> bool:
    """Report whether the explainer has been initialized.

    .. deprecated:: 0.10.1
        Use :attr:`initialized` instead.
    """
    return self.initialized

@property
def last_explanation_mode(self) -> str | None:
    """Mode of the most recently generated explanation."""
    return self._last_explanation_mode

@last_explanation_mode.setter
def last_explanation_mode(self, value: str | None) -> None:
    """Remember the mode of the most recently generated explanation."""
    self._last_explanation_mode = value
@property
def feature_filter_per_instance_ignore(self) -> Any:
    """Per-instance feature-filter ignore list, or ``None`` when unset."""
    return getattr(self, "_feature_filter_per_instance_ignore", None)

@feature_filter_per_instance_ignore.setter
def feature_filter_per_instance_ignore(self, value: Any) -> None:
    """Store the per-instance feature-filter ignore list."""
    self._feature_filter_per_instance_ignore = value

@feature_filter_per_instance_ignore.deleter
def feature_filter_per_instance_ignore(self) -> None:
    """Drop the per-instance ignore list; a no-op when it is not stored."""
    if not hasattr(self, "_feature_filter_per_instance_ignore"):
        return
    del self._feature_filter_per_instance_ignore

@property
def parallel_executor(self) -> Any:
    """Active parallel executor (stored as ``_perf_parallel``), or ``None``."""
    return getattr(self, "_perf_parallel", None)

@parallel_executor.setter
def parallel_executor(self, value: Any) -> None:
    """Store the parallel executor in ``_perf_parallel``."""
    self._perf_parallel = value

@property
def feature_filter_config(self) -> Any:
    """Feature-filter configuration, or ``None`` when unset."""
    return getattr(self, "_feature_filter_config", None)

@feature_filter_config.setter
def feature_filter_config(self, value: Any) -> None:
    """Store the feature-filter configuration."""
    self._feature_filter_config = value

@property
def predict_bridge(self) -> Any:
    """Prediction bridge, or ``None`` when unset."""
    return getattr(self, "_predict_bridge", None)

@predict_bridge.setter
def predict_bridge(self, value: Any) -> None:
    """Store the prediction bridge."""
    self._predict_bridge = value

@property
def categorical_value_counts_cache(self) -> Any:
    """Cached categorical value counts, or ``None`` when unset."""
    return getattr(self, "_categorical_value_counts_cache", None)

@categorical_value_counts_cache.setter
def categorical_value_counts_cache(self, value: Any) -> None:
    """Store the categorical value counts cache."""
    self._categorical_value_counts_cache = value

@property
def numeric_sorted_cache(self) -> Any:
    """Cached sorted numeric values, or ``None`` when unset."""
    return getattr(self, "_numeric_sorted_cache", None)

@numeric_sorted_cache.setter
def numeric_sorted_cache(self, value: Any) -> None:
    """Store the sorted numeric values cache."""
    self._numeric_sorted_cache = value

@property
def calibration_summary_shape(self) -> Any:
    """Calibration summary shape, or ``None`` when unset."""
    return getattr(self, "_calibration_summary_shape", None)

@calibration_summary_shape.setter
def calibration_summary_shape(self, value: Any) -> None:
    """Store the calibration summary shape."""
    self._calibration_summary_shape = value
[docs]
def enable_fast_mode(self) -> None:
    """Switch the explainer into fast explanation mode.

    Nothing happens when :meth:`is_fast` already reports fast mode.
    Otherwise the fast flag is raised and the fast-mode interval learner is
    initialized. If initialization raises, the flag is reset to ``False``
    and the exception propagates unchanged.
    """
    if self.is_fast():
        return
    try:
        self._CalibratedExplainer__fast = True
        # Resolve the public alias first so tests that patch
        # ``initialize_interval_learner_for_fast_explainer`` observe the call
        # (and its exceptions); fall back to the name-mangled implementation
        # when the alias is missing or not callable.
        initializer = getattr(self, "initialize_interval_learner_for_fast_explainer", None)
        if not callable(initializer):
            initializer = self._CalibratedExplainer__initialize_interval_learner_for_fast_explainer
        initializer()
    except Exception:  # adr002_allow
        self._CalibratedExplainer__fast = False
        raise
@property
def _bridge_monitors(self) -> Dict[str, Any]:
    """Bridge monitor registry owned by the required PluginManager."""
    manager = self.require_plugin_manager()
    return manager.bridge_monitors

@property
def _pyproject_explanations(self) -> Dict[str, Any] | None:
    """Read the pyproject explanation settings from the PluginManager."""
    manager = self.plugin_manager
    return manager.pyproject_explanations

@_pyproject_explanations.setter
def _pyproject_explanations(self, value: Dict[str, Any] | None) -> None:
    """Write the pyproject explanation settings to the PluginManager."""
    manager = self.plugin_manager
    manager.pyproject_explanations = value

@property
def _pyproject_intervals(self) -> Dict[str, Any] | None:
    """Read the pyproject interval settings from the PluginManager."""
    manager = self.plugin_manager
    return manager.pyproject_intervals

@_pyproject_intervals.setter
def _pyproject_intervals(self, value: Dict[str, Any] | None) -> None:
    """Write the pyproject interval settings to the PluginManager."""
    manager = self.plugin_manager
    manager.pyproject_intervals = value

@property
def _pyproject_plots(self) -> Dict[str, Any] | None:
    """Read the pyproject plot settings from the PluginManager."""
    manager = self.plugin_manager
    return manager.pyproject_plots

@_pyproject_plots.setter
def _pyproject_plots(self, value: Dict[str, Any] | None) -> None:
    """Write the pyproject plot settings to the PluginManager."""
    manager = self.plugin_manager
    manager.pyproject_plots = value

@property
def runtime_telemetry(self) -> Mapping[str, Any]:
    """Shallow copy of the latest telemetry payload recorded by the PluginManager."""
    latest = self.plugin_manager.last_telemetry
    return dict(latest)
@property
def preprocessor_metadata(self) -> Dict[str, Any] | None:
    """Return the telemetry-safe preprocessing snapshot if available.

    Returns
    -------
    dict or None
        A shallow ``dict`` copy of the stored snapshot, so callers cannot
        mutate it through the returned mapping; ``None`` when no snapshot
        is stored.
    """
    if self._preprocessor_metadata is None:
        return None
    return dict(self._preprocessor_metadata)

# This property re-binds the name ``preprocessor_metadata`` in the class body,
# which discards any setter attached to an earlier definition. Without this
# setter, ``explainer.preprocessor_metadata = ...`` raised AttributeError.
@preprocessor_metadata.setter
def preprocessor_metadata(self, value: Dict[str, Any] | None) -> None:
    """Store the preprocessing snapshot (``None`` clears it)."""
    self._preprocessor_metadata = value
@property
def x_cal(self):
    """Calibration inputs as managed by ``CalibrationState``.

    Returns
    -------
    array-like
        The calibration input data, as returned by ``CalibrationState.get_x_cal``.
    """
    from ..calibration.state import CalibrationState as _CalState  # pylint: disable=import-outside-toplevel

    return _CalState.get_x_cal(self)

@x_cal.setter
def x_cal(self, value):
    """Replace the calibration inputs via ``CalibrationState.set_x_cal``.

    Parameters
    ----------
    value : array-like of shape (n_samples, n_features)
        Replacement calibration input data.

    Raises
    ------
    ValueError
        If the number of features in ``value`` does not match the existing
        calibration data (as documented for this setter; the check is
        performed inside ``CalibrationState``).
    """
    from ..calibration.state import CalibrationState as _CalState  # pylint: disable=import-outside-toplevel

    _CalState.set_x_cal(self, value)

@property
def y_cal(self):
    """Calibration targets as managed by ``CalibrationState``.

    Returns
    -------
    array-like
        The calibration target data, as returned by ``CalibrationState.get_y_cal``.
    """
    from ..calibration.state import CalibrationState as _CalState  # pylint: disable=import-outside-toplevel

    return _CalState.get_y_cal(self)

@y_cal.setter
def y_cal(self, value):
    """Replace the calibration targets via ``CalibrationState.set_y_cal``.

    Parameters
    ----------
    value : array-like of shape (n_samples,)
        Replacement calibration target data.
    """
    from ..calibration.state import CalibrationState as _CalState  # pylint: disable=import-outside-toplevel

    _CalState.set_y_cal(self, value)
[docs]
def append_cal(self, x, y):
    """Extend the calibration set with additional instances.

    Parameters
    ----------
    x : array-like of shape (n_samples, n_features)
        Calibration inputs to append.
    y : array-like of shape (n_samples,)
        Calibration targets to append.

    Notes
    -----
    The work is delegated to ``CalibrationState.append_calibration``.
    """
    from ..calibration.state import CalibrationState as _CalState  # pylint: disable=import-outside-toplevel

    _CalState.append_calibration(self, x, y)

def _invalidate_calibration_summaries(self) -> None:
    """Discard the cached calibration summaries used while explaining.

    Cache management lives in ``calibration.summaries``; this method only
    forwards to ``invalidate_calibration_summaries``.
    """
    from ..calibration.summaries import (  # pylint: disable=import-outside-toplevel
        invalidate_calibration_summaries,
    )

    invalidate_calibration_summaries(self)
[docs]
def get_calibration_summaries(
    self, x_cal_np: Optional[np.ndarray] = None
) -> Tuple[Dict[int, Dict[Any, int]], Dict[int, np.ndarray]]:
    """Fetch the cached categorical counts and sorted numeric calibration values.

    Caching of these statistical summaries is handled by
    ``calibration.summaries.get_calibration_summaries``; this method forwards
    ``self`` and ``x_cal_np`` to it and returns its result.
    """
    from ..calibration.summaries import (  # pylint: disable=import-outside-toplevel
        get_calibration_summaries as _fetch_summaries,
    )

    return _fetch_summaries(self, x_cal_np)

@property
def num_features(self):
    """Number of features in the calibration data.

    Returns
    -------
    int
        The key count of the first row when rows are dictionaries, otherwise
        the number of columns in the first row of ``_X_cal``.
    """
    first_row = self._X_cal[0]
    if isinstance(first_row, dict):
        return len(first_row.keys())
    return len(self._X_cal[0, :])

@property
def feature_names(self):
    """Feature names, or ``None`` when none were provided.

    Returns
    -------
    list or None
        The stored feature names.
    """
    return self._feature_names

@feature_names.setter
def feature_names(self, value):
    """Store the feature names.

    Parameters
    ----------
    value : iterable or None
        Feature names; any iterable is materialised as a ``list``, while
        ``None`` is stored as is.
    """
    self._feature_names = None if value is None else list(value)

@property
def interval_learner(self) -> Any:
    """Interval learner held by the prediction orchestrator's interval registry.

    Returns
    -------
    Any
        The interval calibrator (e.g. VennAbers, IntervalRegressor, or a
        list in fast mode).

    Notes
    -----
    Backward-compatible accessor; see ADR-001.
    """
    registry = self.prediction_orchestrator.interval_registry
    return registry.interval_learner

@interval_learner.setter
def interval_learner(self, value: Any) -> None:
    """Replace the interval learner in the prediction orchestrator's registry.

    Parameters
    ----------
    value : Any
        The interval calibrator to install (e.g. VennAbers, IntervalRegressor).
    """
    registry = self.prediction_orchestrator.interval_registry
    registry.interval_learner = value

def _get_sigma_test(self, x: np.ndarray) -> np.ndarray:
    """Estimate the difficulty (sigma) of test instances.

    Parameters
    ----------
    x : np.ndarray
        Test instances whose difficulty should be estimated.

    Returns
    -------
    np.ndarray
        One sigma value per test instance, as produced by the interval registry.

    Notes
    -----
    Backward-compatible method delegating to the interval registry managed by
    the PredictionOrchestrator; see ADR-001.
    """
    registry = self.prediction_orchestrator.interval_registry
    return registry.get_sigma_test(x)
[docs]
def get_sigma_test(self, x: np.ndarray) -> np.ndarray:
    """Public wrapper that estimates the difficulty (sigma) of test instances.

    Parameters
    ----------
    x : np.ndarray
        Test instances whose difficulty should be estimated.

    Returns
    -------
    np.ndarray
        Whatever ``_get_sigma_test`` returns for ``x``.
    """
    sigma = self._get_sigma_test(x)
    return sigma

def _CalibratedExplainer__initialize_interval_learner_for_fast_explainer(self) -> None:  # noqa: N802
    """Initialize the fast-mode interval learner (backward-compatible wrapper).

    Notes
    -----
    Kept under its name-mangled spelling because the external
    fast_explanations plugin and other production code call it directly.
    The work is delegated to the interval registry; see ADR-001.
    """
    registry = self.prediction_orchestrator.interval_registry
    registry.initialize_for_fast_explainer()
[docs]
def reinitialize(self, learner, xs=None, ys=None, bins=None):
    """Reinitialize the explainer with a new or retrained learner.

    Use this when the learner has been updated or retrained. If both ``xs``
    and ``ys`` are given, they are appended to the calibration set and the
    interval learner is updated incrementally. Otherwise the interval learner
    is initialized from scratch.

    Parameters
    ----------
    learner : predictive learner
        A fitted learner with a ``predict_proba`` method (classification) or
        a ``predict`` method (regression).
    xs : array-like, optional
        New calibration input data to append. Only used together with ``ys``.
    ys : array-like, optional
        New calibration target data to append. Only used together with ``xs``.
    bins : array-like, optional
        Mondrian categories for the appended calibration instances. Only
        allowed when the explainer already uses bins, and must have the same
        length as ``ys``.

    Returns
    -------
    None
        The explainer is updated in place.

    Raises
    ------
    ValidationError
        If ``bins`` are supplied but the explainer was calibrated without bins.
    DataShapeError
        If ``len(bins)`` differs from ``len(ys)``.
    """
    self.__initialized = False
    check_is_fitted(learner)
    append_data = xs is not None and ys is not None
    # Validate the Mondrian bins BEFORE mutating any state. Previously the
    # calibration data was appended and the learner swapped first, so a
    # rejected call left calibration instances without matching bins.
    if append_data and bins is not None:
        if self.bins is None:
            raise ValidationError("Cannot mix calibration instances with and without bins.")
        if len(bins) != len(ys):
            raise DataShapeError(
                "The length of bins must match the number of added instances."
            )
    self.learner = learner
    if append_data:
        self.append_cal(xs, ys)
        if bins is not None:
            self.bins = np.concatenate((self.bins, bins)) if self.bins is not None else bins
        # Incrementally update the interval learner with the appended data
        # (``bins`` may be None here and is forwarded as is).
        from ..calibration.interval_learner import update_interval_learner as _upd_il

        _upd_il(self, xs, ys, bins=bins)
    else:
        from ..calibration.interval_learner import initialize_interval_learner as _init_il

        _init_il(self)
    self.__initialized = True
def __repr__(self):
    """Return the string representation of the CalibratedExplainer.

    The header always lists ``mode`` and ``learner``. It also lists
    ``conditional=True`` when bins are set, the discretizer when one is set,
    and the difficulty estimator in regression mode.

    When ``verbose`` is set, further details follow on tab-indented lines:
    init time, total explain time of the latest explanation, sample
    percentiles, seed and verbosity. Feature names, categorical metadata and
    class labels are added when they are set.
    """
    header = [f"mode={self.mode}"]
    if self.bins is not None:
        header.append("conditional=True")
    if self.discretizer is not None:
        header.append(f"discretizer={self.discretizer}")
    header.append(f"learner={self.learner}")
    if self.mode == "regression":
        header.append(f"difficulty_estimator={self.difficulty_estimator}")
    disp_str = f"CalibratedExplainer({', '.join(header)})"
    if self.verbose:
        disp_str += f"\n\tinit_time={self.init_time}"
        if self.latest_explanation is not None:
            disp_str += f"\n\ttotal_explain_time={self.latest_explanation.total_explain_time}"
        # Implicit concatenation instead of a backslash continuation inside
        # the f-string: the continuation embedded the next source line's
        # indentation (runs of spaces) into the repr output.
        disp_str += (
            f"\n\tsample_percentiles={self.sample_percentiles}"
            f"\n\tseed={self.seed}"
            f"\n\tverbose={self.verbose}"
        )
        if self.feature_names is not None:
            disp_str += f"\n\tfeature_names={self.feature_names}"
        if self.categorical_features is not None:
            disp_str += f"\n\tcategorical_features={self.categorical_features}"
        if self.categorical_labels is not None:
            disp_str += f"\n\tcategorical_labels={self.categorical_labels}"
        if self.class_labels is not None:
            disp_str += f"\n\tclass_labels={self.class_labels}"
    return disp_str
[docs]
def obtain_interval_calibrator(
    self,
    *,
    fast: bool,
    metadata: Mapping[str, Any],
) -> Tuple[Any, str | None]:
    """Ask the prediction orchestrator for an interval calibrator.

    Both keyword arguments are forwarded unchanged, and the orchestrator's
    ``(calibrator, identifier)`` result is returned as is.
    """
    orchestrator = self.prediction_orchestrator
    return orchestrator.obtain_interval_calibrator(fast=fast, metadata=metadata)
[docs]
def explain_factual(
    self,
    x,
    threshold=None,
    low_high_percentiles=(5, 95),
    bins=None,
    features_to_ignore=None,
    *,
    _use_plugin: bool = True,
    **kwargs,
) -> CalibratedExplanations:
    """Build factual explanations for ``x`` using a binary discretizer.

    The discretizer is picked from the explainer mode: ``"binaryRegressor"``
    when the mode contains ``"regression"``, otherwise ``"binaryEntropy"``.
    The call is then forwarded to ``explanation_orchestrator.invoke_factual``.
    When ``_perf_parallel`` is set, it is entered as a context manager around
    the call.

    Parameters
    ----------
    x : array-like
        A set with n_samples of test objects to explain.
    threshold : float, int or array-like, default=None
        Values for which p-values should be returned. Only used for
        probabilistic explanations for regression.
    low_high_percentiles : tuple of floats, default=(5, 95)
        Low and high percentiles used for the interval (regression).
    bins : array-like of shape (n_samples,), default=None
        Mondrian categories. When omitted and the explainer is Mondrian,
        ``self.bins`` is used.
    features_to_ignore : sequence, default=None
        Features to exclude from the explanations; forwarded unchanged.
    **kwargs : dict
        Extra arguments for the explanation orchestrator. A ``reject_policy``
        entry is forwarded only when it is not ``None``.

    Returns
    -------
    CalibratedExplanations : :class:`.CalibratedExplanations`
        One :class:`.FactualExplanation` per instance.
    """
    if bins is None and self.is_mondrian():
        bins = self.bins
    is_regression = "regression" in self.mode
    discretizer = "binaryRegressor" if is_regression else "binaryEntropy"
    parallel = self._perf_parallel
    with (contextlib.nullcontext() if parallel is None else parallel):
        reject_policy = kwargs.pop("reject_policy", None)
        call_args = {
            "x": x,
            "threshold": threshold,
            "low_high_percentiles": low_high_percentiles,
            "bins": bins,
            "features_to_ignore": features_to_ignore,
            "discretizer": discretizer,
            "_use_plugin": _use_plugin,
            **kwargs,
        }
        if reject_policy is not None:
            call_args["reject_policy"] = reject_policy
        return self.explanation_orchestrator.invoke_factual(**call_args)
[docs]
def explore_alternatives(
    self,
    x,
    threshold=None,
    low_high_percentiles=(5, 95),
    bins=None,
    features_to_ignore=None,
    *,
    _use_plugin: bool = True,
    **kwargs,
) -> AlternativeExplanations:
    """Build alternative explanations for ``x`` using a multi-bin discretizer.

    The discretizer is picked from the explainer mode: ``"regressor"`` when
    the mode contains ``"regression"``, otherwise ``"entropy"``. The call is
    then forwarded to ``explanation_orchestrator.invoke_alternative``. When
    ``_perf_parallel`` is set, it is entered as a context manager around the
    call.

    Parameters
    ----------
    x : array-like
        A set with n_samples of test objects to explain.
    threshold : float, int or array-like, default=None
        Values for which p-values should be returned. Only used for
        probabilistic explanations for regression.
    low_high_percentiles : tuple of floats, default=(5, 95)
        Low and high percentiles used for the interval (regression).
    bins : array-like of shape (n_samples,), default=None
        Mondrian categories. When omitted and the explainer is Mondrian,
        ``self.bins`` is used.
    features_to_ignore : sequence, default=None
        Features to exclude from the explanations; forwarded unchanged.
    **kwargs : dict
        Extra arguments for the explanation orchestrator. A ``reject_policy``
        entry is forwarded only when it is not ``None``.

    Returns
    -------
    AlternativeExplanations : :class:`.AlternativeExplanations`

    Notes
    -----
    ``explore_alternatives`` is intended to eventually replace the
    ``explain_counterfactual`` method.
    """
    if bins is None and self.is_mondrian():
        bins = self.bins
    is_regression = "regression" in self.mode
    discretizer = "regressor" if is_regression else "entropy"
    parallel = self._perf_parallel
    with (contextlib.nullcontext() if parallel is None else parallel):
        reject_policy = kwargs.pop("reject_policy", None)
        call_args = {
            "x": x,
            "threshold": threshold,
            "low_high_percentiles": low_high_percentiles,
            "bins": bins,
            "features_to_ignore": features_to_ignore,
            "discretizer": discretizer,
            "_use_plugin": _use_plugin,
            **kwargs,
        }
        if reject_policy is not None:
            call_args["reject_policy"] = reject_policy
        return self.explanation_orchestrator.invoke_alternative(**call_args)  # type: ignore[return-value]
[docs]
def explain_guarded_factual(
    self,
    x,
    threshold=None,
    low_high_percentiles=(5, 95),
    bins=None,
    features_to_ignore=None,
    *,
    _use_plugin: bool = True,
    significance: float = 0.1,
    merge_adjacent: bool = False,
    n_neighbors: int = 5,
    normalize_guard: bool = True,
    verbose: bool = False,
    **kwargs,
):
    """Build guarded factual explanations restricted to in-distribution perturbations.

    In contrast to :meth:`explain_factual`, which uses a binary
    (``max_depth=1``) discretiser, this method uses the multi-bin
    (``max_depth=3``) discretiser shared with :meth:`explore_alternatives`.
    An in-distribution guard checks, per leaf, whether the representative
    perturbation conforms to the calibration distribution. Leaves failing the
    check are dropped.

    Rule conditions are **intervals** such as ``"30 < age <= 50"`` rather
    than single threshold splits. Adjacent conforming bins may be merged into
    wider intervals with ``merge_adjacent=True``.

    Parameters
    ----------
    x : array-like
        A set with n_samples of test objects to explain.
    threshold : float, int or array-like, optional
        Values for which p-values should be returned. Only used for
        probabilistic regression.
    low_high_percentiles : tuple of float, default=(5, 95)
        Low and high percentiles used for the interval.
    bins : array-like of shape (n_samples,), optional
        Mondrian categories. When omitted and the explainer is Mondrian,
        ``self.bins`` is used.
    features_to_ignore : sequence of int or str, optional
        Features to exclude from explanations.
    _use_plugin : bool, default=True
        Not used by guarded explanations. Passing ``False`` together with
        ``verbose=True`` emits a UserWarning.
    significance : float, default=0.1
        Acceptable false-OOD rate. Bins with ``p_value >= significance``
        count as conforming; lower p-values mark a bin out-of-distribution,
        and it is excluded.
    merge_adjacent : bool, default=False
        Merge adjacent conforming bins into one wider interval condition.
    n_neighbors : int, default=5
        Number of nearest calibration neighbours the guard uses for its
        non-conformity scores.
    normalize_guard : bool, default=True
        Apply per-feature min-max normalisation before the guard's KNN
        distance computation.
    verbose : bool, default=False
        Emit UserWarnings with guarded-explanation diagnostics.
    **kwargs : dict
        Additional arguments (reserved for future use). A ``reject_policy``
        entry is extracted and always passed explicitly (``None`` if absent).

    Returns
    -------
    CalibratedExplanations
        A :class:`~calibrated_explanations.CalibratedExplanations` container
        holding
        :class:`~calibrated_explanations.explanations.guarded_explanation.GuardedFactualExplanation`
        instances.
    """
    if verbose and not _use_plugin:
        warnings.warn(
            "_use_plugin has no effect on guarded explanation methods",
            UserWarning,
            stacklevel=2,
        )
    if bins is None and self.is_mondrian():
        bins = self.bins
    parallel = self._perf_parallel
    guard_options = {
        "significance": significance,
        "merge_adjacent": merge_adjacent,
        "n_neighbors": n_neighbors,
        "normalize_guard": normalize_guard,
        "verbose": verbose,
    }
    with (contextlib.nullcontext() if parallel is None else parallel):
        reject_policy = kwargs.pop("reject_policy", None)
        return self.explanation_orchestrator.invoke_guarded_factual(
            x=x,
            threshold=threshold,
            low_high_percentiles=low_high_percentiles,
            bins=bins,
            features_to_ignore=features_to_ignore,
            reject_policy=reject_policy,
            **guard_options,
            **kwargs,
        )
[docs]
def explore_guarded_alternatives(
    self,
    x,
    threshold=None,
    low_high_percentiles=(5, 95),
    bins=None,
    features_to_ignore=None,
    *,
    _use_plugin: bool = True,
    significance: float = 0.1,
    merge_adjacent: bool = False,
    n_neighbors: int = 5,
    normalize_guard: bool = True,
    verbose: bool = False,
    **kwargs,
):
    """Explore alternative explanations restricted to in-distribution perturbations.

    Works like :meth:`explore_alternatives`, but every leaf of the multi-bin
    discretiser is first checked by an in-distribution guard. The feature is
    moved to the leaf's representative value while all other features keep
    their original values. The resulting instance is then tested for
    conformity with the calibration distribution. Leaves that fail the test
    are not offered as alternatives.

    Rule conditions are **intervals** such as ``"30 < age <= 50"``. For the bin
    the instance currently occupies, a factual rule is stored as well
    (``is_factual=True``). Neighbouring conforming bins can be combined into
    one wider interval with ``merge_adjacent=True``.

    Parameters
    ----------
    x : array-like
        A set with n_samples of test objects to explain.
    threshold : float, int or array-like, optional
        Values for which p-values should be returned. Only used for
        probabilistic regression.
    low_high_percentiles : tuple of float, default=(5, 95)
        The low and high percentile used to calculate the interval.
    bins : array-like of shape (n_samples,), optional
        Mondrian categories. When omitted and the explainer is Mondrian, the
        explainer's own ``bins`` are used.
    features_to_ignore : sequence of int or str, optional
        Features to exclude from explanations.
    significance : float, default=0.1
        Acceptable false-OOD rate. A bin conforms when
        ``p_value >= significance``; bins below the threshold are treated as
        out-of-distribution and are not included as alternatives.
    merge_adjacent : bool, default=False
        When ``True``, merge adjacent conforming bins into a single wider
        interval condition.
    n_neighbors : int, default=5
        Number of nearest calibration neighbours used by the guard when
        computing non-conformity scores.
    normalize_guard : bool, default=True
        Apply per-feature min-max normalisation before computing KNN
        distances inside the guard.
    verbose : bool, default=False
        When True, emit UserWarnings for guarded-explanation diagnostics.
    **kwargs : dict
        Extra options. ``reject_policy`` (default ``None``) is extracted and
        forwarded explicitly. Every other key is passed unchanged to
        ``explanation_orchestrator.invoke_guarded_alternative``.

    Returns
    -------
    AlternativeExplanations
        An :class:`~calibrated_explanations.AlternativeExplanations` container
        whose individual explanations are
        :class:`~calibrated_explanations.explanations.guarded_explanation.GuardedAlternativeExplanation`
        instances.
    """
    # Guarded explanations always go through the orchestrator; opting out of
    # plugins is only reported, never honoured.
    if verbose and not _use_plugin:
        warnings.warn(
            "_use_plugin has no effect on guarded explanation methods",
            UserWarning,
            stacklevel=2,
        )

    effective_bins = self.bins if bins is None and self.is_mondrian() else bins

    parallel_ctx = self._perf_parallel
    if parallel_ctx is None:
        parallel_ctx = contextlib.nullcontext()

    with parallel_ctx:
        policy_override = kwargs.pop("reject_policy", None)
        guard_options = {
            "significance": significance,
            "merge_adjacent": merge_adjacent,
            "n_neighbors": n_neighbors,
            "normalize_guard": normalize_guard,
            "verbose": verbose,
        }
        orchestrator = self.explanation_orchestrator
        return orchestrator.invoke_guarded_alternative(
            x=x,
            threshold=threshold,
            low_high_percentiles=low_high_percentiles,
            bins=effective_bins,
            features_to_ignore=features_to_ignore,
            reject_policy=policy_override,
            **guard_options,
            **kwargs,
        )
def __call__(
self,
x,
threshold=None,
low_high_percentiles=(5, 95),
bins=None,
features_to_ignore=None,
*,
reject_policy: Any | None = None,
_use_plugin: bool = True,
_skip_instance_parallel: bool = False,
) -> CalibratedExplanations:
"""Call self as a function to create a :class:`.CalibratedExplanations` object for the test data with the already assigned discretizer.
Since v0.4.0, this method is equivalent to the `_explain` method.
"""
call_kwargs: dict[str, Any] = {
"_use_plugin": _use_plugin,
"_skip_instance_parallel": _skip_instance_parallel,
}
if reject_policy is not None:
call_kwargs["reject_policy"] = reject_policy
return self._explain(
x,
threshold,
low_high_percentiles,
bins,
features_to_ignore,
**call_kwargs,
)
def _explain(self, *args, **kwargs) -> CalibratedExplanations:
"""Generate explanations for test instances by analyzing feature effects.
This is an internal orchestration primitive that delegates to the explanation orchestrator.
It is NOT part of the public API and should not be called directly.
This method:
1. Makes predictions on original test instances
2. Creates perturbed versions by varying feature values
3. Analyzes how predictions change with feature perturbations
4. Generates feature importance weights and prediction intervals
Returns
-------
CalibratedExplanations : :class:`.CalibratedExplanations`
A :class:`.CalibratedExplanations` containing one :class:`.CalibratedExplanation` for each instance.
See Also
--------
:meth:`.CalibratedExplainer.explain_factual` : Refer to the documentation for `explain_factual` for more details.
:meth:`.CalibratedExplainer.explore_alternatives` : Refer to the documentation for `explore_alternatives` for more details.
"""
# Delegate the args to the actual implementation
return self._explain_impl(*args, **kwargs)
def _explain_impl(
self,
x,
threshold=None,
low_high_percentiles=(5, 95),
bins=None,
features_to_ignore=None,
*,
reject_policy: Any | None = None,
_use_plugin: bool = True,
_skip_instance_parallel: bool = False,
) -> CalibratedExplanations:
if bins is None and self.is_mondrian():
bins = self.bins
# Thin delegator to orchestrator
if _use_plugin:
mode = self.infer_explanation_mode()
invoke_kwargs: dict[str, Any] = {
"extras": {"mode": mode, "_skip_instance_parallel": _skip_instance_parallel}
}
if reject_policy is not None:
invoke_kwargs["reject_policy"] = reject_policy
return self.explanation_orchestrator.invoke(
mode,
x,
threshold,
low_high_percentiles,
bins,
features_to_ignore,
**invoke_kwargs,
)
# Legacy path for backward compatibility and testing
from .explain import legacy_explain # pylint: disable=import-outside-toplevel
return legacy_explain(
self,
x,
threshold=threshold,
low_high_percentiles=low_high_percentiles,
bins=bins,
features_to_ignore=features_to_ignore,
)
# NOTE: Instance- and feature-parallel helpers have been moved into the
# plugin-based implementation under `core.explain.*`. The legacy helper
# methods were intentionally removed to centralize parallel execution in
# the plugin modules. Tests should exercise the plugin classes
# (e.g. InstanceParallelExplainExecutor, FeatureParallelExplainExecutor,
# SequentialExplainExecutor) rather than calling these private helpers.
# NOTE: merge_feature_result functionality has been moved to
# `calibrated_explanations.core.explain._helpers.merge_feature_result`.
# Plugins and explain code should call that free-function directly.
# NOTE: Thin wrapper methods (_slice_threshold, _slice_bins, _validate_and_prepare_input,
# _initialize_explanation, _compute_weight_delta, _discretize) have been removed.
# Callers should import these directly from core.explain submodules:
# - core.explain._helpers: slice_threshold, slice_bins, validate_and_prepare_input,
#   compute_weight_delta
# - core.explain._computation: initialize_explanation, discretize
[docs]
def explain_fast(
    self,
    x,
    threshold=None,
    low_high_percentiles=(5, 95),
    bins=None,
    *,
    reject_policy: Any | None = None,
    _use_plugin: bool = True,
) -> CalibratedExplanations:
    """Create fast explanations for ``x`` as a :class:`.CalibratedExplanations` object.

    Parameters
    ----------
    x : array-like
        A set with n_samples of test objects to predict.
    threshold : float, int or array-like of shape (n_samples,), default=None
        Values for which p-values should be returned. Only used for
        probabilistic explanations for regression.
    low_high_percentiles : a tuple of floats, default=(5, 95)
        The low and high percentile used to calculate the interval.
        Applicable to regression.
    bins : array-like of shape (n_samples,), default=None
        Mondrian categories. Falls back to the explainer's ``bins`` when the
        explainer is Mondrian.
    reject_policy : optional
        Reject policy forwarded to the orchestrator (plugin path only).
    _use_plugin : bool, default=True
        When False, the ``FastExplanationPipeline`` from the repository's
        ``external_plugins`` package is used instead of the orchestrator.

    Raises
    ------
    ValueError: The number of features in the test data must be the same as in the calibration data.
    Warning: The threshold-parameter is only supported for mode='regression'.
    ValueError: The length of the threshold parameter must be either a constant or the same as the number of
        instances in x.
    RuntimeError: Fast explanations are only possible if the explainer is a Fast Calibrated Explainer.

    Returns
    -------
    CalibratedExplanations : :class:`.CalibratedExplanations`
        A `CalibratedExplanations` containing one :class:`.FastExplanation` for each instance.
    """
    if bins is None and self.is_mondrian():
        bins = self.bins

    if not _use_plugin:
        # Legacy route: make the repository root importable, then delegate to the
        # external fast-explanation pipeline. NOTE: the sys.path insertion is
        # permanent, and ``reject_policy`` is not forwarded on this path.
        from pathlib import Path  # pylint: disable=import-outside-toplevel

        root = str(Path(__file__).resolve().parents[3])
        if root not in sys.path:
            sys.path.insert(0, root)
        from external_plugins.fast_explanations.pipeline import (  # pylint: disable=import-outside-toplevel
            FastExplanationPipeline,
        )

        return FastExplanationPipeline(self).explain(x, threshold, low_high_percentiles, bins)

    ignored = tuple(self.features_to_ignore)
    return self.explanation_orchestrator.invoke(
        "fast",
        x,
        threshold,
        low_high_percentiles,
        bins,
        ignored,
        extras={"mode": "fast"},
        reject_policy=reject_policy,
    )
# feature-merge and feature-parallel logic moved to plugin helpers
[docs]
def is_multiclass(self) -> bool:
    """Report whether the task has more than two classes.

    Returns
    -------
    bool
        ``True`` when ``num_classes`` is greater than 2.
    """
    return 2 < self.num_classes
[docs]
def is_fast(self) -> bool:
    """Report whether the explainer was configured for fast mode.

    Returns
    -------
    bool
        The value of the private fast-mode flag.
    """
    fast_enabled = self.__fast
    return fast_enabled
[docs]
def is_mondrian(self) -> bool:
    """Report whether Mondrian (per-bin) calibration is in use.

    Returns
    -------
    bool
        ``True`` when ``bins`` has been set, i.e. is not ``None``.
    """
    has_bins = self.bins is not None
    return has_bins
[docs]
def discretize(self, data: np.ndarray) -> np.ndarray:
    """Discretize ``data`` with the explainer's current discretizer.

    Thin wrapper around the ``discretize`` function of ``core.explain``.

    Parameters
    ----------
    data : np.ndarray
        The data to discretize.

    Returns
    -------
    np.ndarray
        The discretized data.
    """
    from .explain import discretize as apply_discretizer  # pylint: disable=import-outside-toplevel

    discretized = apply_discretizer(self, data)
    return discretized
[docs]
def rule_boundaries(self, instances, perturbed_instances=None):
    """Return the rule boundaries for a set of instances.

    Thin wrapper around the ``rule_boundaries`` function of ``core.explain``.

    Parameters
    ----------
    instances : array-like
        The instances to extract boundaries for.
    perturbed_instances : array-like, optional
        Discretized versions of ``instances``. Defaults to None.

    Returns
    -------
    array-like
        Min and max values for each feature for each instance.
    """
    from .explain import rule_boundaries as compute_boundaries  # pylint: disable=import-outside-toplevel

    boundaries = compute_boundaries(self, instances, perturbed_instances)
    return boundaries
[docs]
def set_difficulty_estimator(self, difficulty_estimator, initialize: bool = True) -> None:
    """Assign, replace or remove the difficulty estimator.

    If a difficulty estimator is set, the explainer can be used to reject
    explanations that are deemed too difficult.

    Parameters
    ----------
    difficulty_estimator : crepes.extras.DifficultyEstimator or None
        A :class:`crepes.extras.DifficultyEstimator` object from the crepes
        package, or ``None`` to remove the current estimator. It is validated
        with ``validate_difficulty_estimator`` before being stored.
    initialize : bool, default=True
        If true, the interval registry is re-initialized once the estimator
        has been assigned.

    Notes
    -----
    Cached interval entries ("calibrator", "fast_calibrators",
    "existing_fast_calibrators", "difficulty_estimator") in the "default" and
    "fast" buckets of the plugin manager's ``interval_context_metadata`` are
    dropped. This forces the regression backend to be re-fitted with the new
    estimator.
    """
    from .difficulty_estimator_helpers import (  # pylint: disable=import-outside-toplevel
        validate_difficulty_estimator,
    )

    validate_difficulty_estimator(difficulty_estimator)
    # Mark the explainer as not initialized until the interval learner is rebuilt.
    self.__initialized = False
    self.difficulty_estimator = difficulty_estimator
    # Invalidate cached interval plugin metadata.
    # Interval resolution persists context metadata (including a cached calibrator)
    # across invocations for performance. When the difficulty estimator changes,
    # we must drop that cache so the regression backend (IntervalRegressor)
    # re-fits crepes' ConformalPredictiveSystem with the updated `sigmas`.
    plugin_manager = getattr(self, "_plugin_manager", None)
    if plugin_manager is not None:
        meta = getattr(plugin_manager, "interval_context_metadata", None)
        if isinstance(meta, dict):
            for key in ("default", "fast"):
                bucket = meta.get(key)
                if isinstance(bucket, dict):
                    bucket.pop("calibrator", None)
                    bucket.pop("fast_calibrators", None)
                    bucket.pop("existing_fast_calibrators", None)
                    bucket.pop("difficulty_estimator", None)
    # Clear the active interval learner if orchestrators are available.
    # (During __init__ we call set_difficulty_estimator before orchestrator setup.)
    orchestrator_ready = (
        plugin_manager is not None
        and getattr(plugin_manager, "_prediction_orchestrator", None) is not None
    )
    if orchestrator_ready:
        self.interval_learner = None
    # NOTE(review): initialization is not gated on ``orchestrator_ready``
    # (same pattern as set_mode); confirm callers that run before orchestrator
    # setup pass ``initialize=False``.
    if initialize:
        self.prediction_orchestrator.interval_registry.initialize()  # type: ignore[attr-defined]
[docs]
def set_mode(self, mode, initialize=True) -> None:
    """Set the task mode of the explainer.

    Parameters
    ----------
    mode : {'classification', 'regression'}
        Task type. For ``'classification'`` the number of classes is taken
        from the distinct values of ``y_cal``. For ``'regression'`` it is set
        to 0.
    initialize : bool, default=True
        If true, re-initialize the interval registry afterwards.

    Raises
    ------
    ValidationError
        If ``mode`` is neither ``'classification'`` nor ``'regression'``.
    """
    self.__initialized = False
    if mode == "regression":
        self.num_classes = 0
    elif mode == "classification":
        self.num_classes = len(np.unique(self.y_cal))
    else:
        raise ValidationError("The mode must be either 'classification' or 'regression'.")
    self.mode = mode
    if not initialize:
        return
    self.prediction_orchestrator.interval_registry.initialize()  # type: ignore[attr-defined]
[docs]
def initialize_reject_learner(  # pylint: disable=invalid-name
    self, calibration_set=None, threshold=None, ncf=None, w=0.5
):
    """Initialize the reject learner (deprecated compatibility wrapper).

    .. deprecated:: 0.11.1
        Use ``reject_orchestrator.initialize_reject_learner`` instead.
        This wrapper will be removed no earlier than v0.13.0.

    Parameters
    ----------
    calibration_set : array-like, optional
        Optional calibration set override.
    threshold : float, optional
        Decision threshold (required for regression reject calibration).
    ncf : str or None, default None
        Non-conformity function type.
    w : float, default 0.5
        Blending weight used only when ``ncf='ensured'``.
        Ignored for ``ncf='default'``.

    Returns
    -------
    Any
        Whatever ``reject_orchestrator.initialize_reject_learner`` returns.
    """
    deprecate(
        "CalibratedExplainer.initialize_reject_learner is deprecated since v0.11.1; "
        "use reject_orchestrator.initialize_reject_learner instead. "
        "This wrapper will be removed no earlier than v0.13.0.",
        key="calibrated_explanations.core.calibrated_explainer."
        "CalibratedExplainer.initialize_reject_learner_deprecation",
        stacklevel=2,
    )
    # Initialize the orchestrators before delegating to the reject orchestrator.
    self.plugin_manager.initialize_orchestrators()
    orchestrator = self.reject_orchestrator
    return orchestrator.initialize_reject_learner(
        calibration_set=calibration_set,
        threshold=threshold,
        ncf=ncf,
        w=w,
    )
[docs]
def predict_reject(self, x, bins=None, confidence=0.95):
    """Predict whether explanations for ``x`` should be rejected (deprecated).

    .. deprecated:: 0.11.1
        Use ``reject_orchestrator.predict_reject`` instead.
        This wrapper will be removed no earlier than v0.13.0.

    Parameters
    ----------
    x : array-like
        The test data.
    bins : array-like, optional
        Mondrian categories for conditional calibration.
    confidence : float, default=0.95
        Confidence level used by the reject predictor.

    Returns
    -------
    tuple
        Rejection decisions and summary rates.
    """
    deprecate(
        "CalibratedExplainer.predict_reject is deprecated since v0.11.1; "
        "use reject_orchestrator.predict_reject instead. "
        "This wrapper will be removed no earlier than v0.13.0.",
        key="calibrated_explanations.core.calibrated_explainer."
        "CalibratedExplainer.predict_reject_deprecation",
        stacklevel=2,
    )
    # Initialize the orchestrators before delegating to the reject orchestrator.
    self.plugin_manager.initialize_orchestrators()
    orchestrator = self.reject_orchestrator
    return orchestrator.predict_reject(x, bins=bins, confidence=confidence)
# pylint: disable=too-many-branches
[docs]
def set_discretizer(
    self,
    discretizer,
    x_cal=None,
    y_cal=None,
    features_to_ignore=None,
    *,
    condition_source: Optional[str] = None,
) -> None:
    """Assign the discretizer used to build explanation rules.

    All arguments are forwarded unchanged to
    ``explanation_orchestrator.set_discretizer``.

    Parameters
    ----------
    discretizer : str or discretizer object
        The discretizer to be used.
    x_cal : array-like, optional
        The calibration data for the discretizer.
    y_cal : array-like, optional
        The calibration target data for the discretizer.
    features_to_ignore : sequence, optional
        Features to leave out; forwarded to the orchestrator.
    condition_source : str, optional
        Keyword-only; forwarded to the orchestrator.
    """
    orchestrator = self.explanation_orchestrator
    orchestrator.set_discretizer(
        discretizer,
        x_cal=x_cal,
        y_cal=y_cal,
        features_to_ignore=features_to_ignore,
        condition_source=condition_source,
    )
# pylint: disable=duplicate-code, too-many-branches, too-many-statements, too-many-locals
[docs]
def predict(self, x, uq_interval=False, calibrated=True, **kwargs):
    """Generate predictions for the test data.

    Parameters
    ----------
    x : array-like
        The test data.
    uq_interval : bool, default=False
        Whether to return uncertainty intervals.
    calibrated : bool, default=True
        If True, the calibrator is used for prediction. If False, the
        underlying learner is used for prediction and no reject policy is
        applied.
    **kwargs : Various types, optional
        Additional parameters to customize the prediction. Supported parameters include:

        - threshold : float, int, or array-like of shape (n_samples,), optional, default=None
            Specifies the threshold for probabilistic regression. Returns calibrated probabilities
            P(y <= threshold) for regression tasks. This parameter is ignored for classification tasks.
        - low_high_percentiles : tuple of two floats, optional, default=(5, 95)
            The lower and upper percentiles used to calculate the prediction interval for regression tasks.
            This parameter is ignored for classification tasks and when threshold is provided.
        - interval_summary : optional
            Interval summary strategy. ``None`` (or absent) falls back to the
            explainer's ``interval_summary``; other values are normalised with
            ``coerce_interval_summary``.
        - reject_policy : optional
            Per-call reject policy overriding the explainer's default.
        - confidence : float, optional, default=0.95
            Confidence level used only by reject orchestration. It is never
            forwarded to the prediction orchestrator.
        - bins : array-like, optional
            Mondrian categories.

    Raises
    ------
    RuntimeError
        If the learner has not been fitted prior to making predictions.
    Warning
        If the learner is not calibrated.

    Returns
    -------
    calibrated_prediction : float or array-like, or str
        The calibrated prediction. For regression tasks without threshold, this is the median of the
        conformal predictive system. For probabilistic regression (with threshold), this is a probability
        P(y <= threshold). For classification tasks, it is the class label with the highest calibrated probability.
    interval : tuple of floats, optional
        A tuple (low, high) representing the lower and upper bounds of the uncertainty interval. This is returned only if ``uq_interval=True``.
    RejectResult
        When a reject policy is active, a RejectResult envelope is returned
        instead. Its ``prediction`` attribute holds the formatted payload
        above, or the raw payload if formatting failed.

    Examples
    --------
    For a prediction without prediction intervals:

    .. code-block:: python

        w.predict(x)

    For a prediction with uncertainty quantification intervals:

    .. code-block:: python

        w.predict(x, uq_interval=True)

    Notes
    -----
    The `threshold` and `low_high_percentiles` parameters are only used for regression tasks.
    """
    from .prediction_helpers import (  # pylint: disable=import-outside-toplevel
        handle_uncalibrated_regression_prediction,
        handle_uncalibrated_classification_prediction,
        format_regression_prediction,
        format_classification_prediction,
    )

    # Lazy import API params functions (deferred from module level)
    from ..api.params import (  # pylint: disable=import-outside-toplevel
        canonicalize_kwargs,
        reject_removed_aliases,
        validate_param_combination,
    )

    logger = logging.getLogger(__name__)

    # Reject removed aliases and normalize kwargs.
    reject_removed_aliases(kwargs)
    kwargs = canonicalize_kwargs(kwargs)
    validate_param_combination(kwargs)

    summary = kwargs.get("interval_summary")
    kwargs["interval_summary"] = (
        self.interval_summary if summary is None else coerce_interval_summary(summary)
    )
    # ``confidence`` only parameterises reject orchestration. Strip it here so it
    # never leaks into the prediction orchestrator on the non-reject path
    # (matches predict_proba).
    confidence_arg = kwargs.pop("confidence", 0.95)
    threshold = kwargs.get("threshold")

    if not calibrated:
        if self.mode == "regression":
            return handle_uncalibrated_regression_prediction(
                self.learner, x, threshold=threshold, uq_interval=uq_interval
            )
        return handle_uncalibrated_classification_prediction(
            self.learner, x, threshold=threshold, uq_interval=uq_interval
        )

    def _format(raw):
        """Turn a raw ``(predict, low, high, extra)`` orchestrator payload into the public format."""
        if self.mode == "regression":
            predict_, low_, high_, _ = raw
            return format_regression_prediction(
                predict_, low_, high_, threshold=threshold, uq_interval=uq_interval
            )
        predict_, low_, high_, new_classes = raw
        return format_classification_prediction(
            predict_,
            low_,
            high_,
            new_classes,
            self.is_multiclass(),
            label_map=self.label_map,
            class_labels=self.class_labels,
            uq_interval=uq_interval,
        )

    # Resolve reject policy (per-call overrides explainer default).
    from .reject.orchestrator import (  # pylint: disable=import-outside-toplevel
        resolve_effective_reject_policy,
    )

    # Internal callers may skip reject orchestration by setting this flag.
    skip_reject = bool(kwargs.pop("_ce_skip_reject", False))
    resolution = None
    policy = RejectPolicy.NONE
    if not skip_reject:
        resolution = resolve_effective_reject_policy(
            kwargs.pop("reject_policy", None),
            self,
            default_policy=getattr(self, "default_reject_policy", RejectPolicy.NONE),
            logger=logger,
        )
        policy = resolution.policy
    reject_active = not skip_reject and policy is not RejectPolicy.NONE
    implicit_default_used = reject_active and resolution.used_default

    # No reject orchestration requested: plain calibrated prediction.
    if not reject_active:
        return _format(self.prediction_orchestrator.predict(x, **kwargs))

    # Reject policy active: use orchestrator to apply policy and return RejectResult envelope.
    bins_arg = kwargs.pop("bins", None)
    rr = self.reject_orchestrator.apply_policy(
        policy,
        x,
        explain_fn=None,
        bins=bins_arg,
        confidence=confidence_arg,
        result_schema="v2",
        **kwargs,
    )
    try:
        from ..explanations.reject import (  # pylint: disable=import-outside-toplevel
            RejectResultV2,
            reject_result_v2_to_legacy,
        )

        if isinstance(rr, RejectResultV2):
            rr = reject_result_v2_to_legacy(rr, emit_deprecation_warning=False)
    except Exception as exc:  # adr002_allow
        logger.debug(
            "RejectResultV2 compatibility conversion failed in predict: %s",
            exc,
            exc_info=True,
        )
    # Format the legacy payload into rr.prediction for consumer ergonomics.
    try:
        if rr.prediction is not None:
            rr.prediction = _format(rr.prediction)
    except Exception as exc:  # adr002_allow
        # If formatting fails, leave rr.prediction as-is but warn.
        logger.info("Failed to format RejectResult.prediction; leaving raw.", exc_info=True)
        warnings.warn(
            f"Failed to format RejectResult.prediction: {exc!s}", UserWarning, stacklevel=2
        )
    # Log once-per-call when an implicit default caused an envelope return.
    if implicit_default_used:
        logger.info(
            "Default reject policy %s applied implicitly; returning RejectResult envelope for this call.",
            str(policy),
        )
    return rr
[docs]
def predict_proba(self, x, uq_interval=False, calibrated=True, threshold=None, **kwargs):
    """Generate probability predictions for the test data.

    Similar to scikit-learn's ``predict_proba``, but it may also return
    uncertainty intervals. For regression, ``threshold`` enables probabilistic
    regression. The second column of the result holds the calibrated
    probability of being below the threshold, and the first column holds its
    complement.

    Parameters
    ----------
    x : array-like
        The test data for which predictions are to be made. This should be in a format compatible with sklearn (e.g., numpy arrays, pandas DataFrames).
    uq_interval : bool, default=False
        If true, then the prediction interval is returned as well.
    calibrated : bool, default=True
        If True, the calibrator is used for prediction. If False, the underlying learner is used for prediction.
    threshold : float, int or array-like of shape (n_samples,), optional, default=None
        Threshold values used with regression to get probability of being below the threshold. Only applicable to regression.
    **kwargs
        Additional options:

        - ``interval_summary``: ``None`` (or absent) falls back to the
          explainer's ``interval_summary``. Other values are normalised with
          ``coerce_interval_summary``, exactly as in :meth:`predict`.
        - ``confidence`` (default 0.95): used only by reject orchestration.
        - ``reject_policy``: per-call reject policy overriding the default.
        - ``bins``: Mondrian categories.
        - ``show`` / ``style_override``: plotting-only keys that are discarded.

    Raises
    ------
    RuntimeError
        If the learner is not fitted before predicting.
    ValidationError
        If a reject policy is active for regression without ``threshold``, or
        if ``threshold`` is given together with ``calibrated=False``.
    ValueError
        If the `threshold` parameter's length does not match the number of instances in `x`, or if it is not a single constant value applicable to all instances.
    Warning
        If the learner is not calibrated.

    Returns
    -------
    calibrated probability :
        The calibrated probability of the positive class (or the predicted class for multiclass).
    (low, high) : tuple of float lists, corresponding to the lower and upper bound of each prediction interval.
        Only returned when ``uq_interval=True``.
    RejectResult
        When a reject policy is active, a RejectResult envelope whose
        ``prediction`` attribute holds the payload above.

    Examples
    --------
    For a prediction without uncertainty quantification intervals:

    .. code-block:: python

        w.predict_proba(x)

    For a prediction with uncertainty quantification intervals:

    .. code-block:: python

        w.predict_proba(x, uq_interval=True)

    Notes
    -----
    The `threshold` parameter is only used for regression tasks.
    """
    # Strip plotting-only keys that callers may pass.
    kwargs.pop("show", None)
    kwargs.pop("style_override", None)

    # Lazy import API params functions (deferred from module level)
    from ..api.params import (  # pylint: disable=import-outside-toplevel
        canonicalize_kwargs,
        reject_removed_aliases,
        validate_param_combination,
    )

    logger = logging.getLogger(__name__)

    # Reject removed aliases and normalize kwargs.
    reject_removed_aliases(kwargs)
    kwargs = canonicalize_kwargs(kwargs)
    validate_param_combination(kwargs)

    # Normalise interval_summary the same way predict() does. The previous
    # ``setdefault`` let an explicit None or an uncoerced value through.
    summary = kwargs.get("interval_summary")
    kwargs["interval_summary"] = (
        self.interval_summary if summary is None else coerce_interval_summary(summary)
    )
    confidence_arg = kwargs.pop("confidence", 0.95)

    # Resolve reject policy (per-call override else explainer default).
    from .reject.orchestrator import (  # pylint: disable=import-outside-toplevel
        resolve_effective_reject_policy,
    )

    # Internal callers may skip reject orchestration by setting this flag.
    skip_reject = bool(kwargs.pop("_ce_skip_reject", False))
    resolution = None
    policy = RejectPolicy.NONE
    if not skip_reject:
        resolution = resolve_effective_reject_policy(
            kwargs.pop("reject_policy", None),
            self,
            default_policy=getattr(self, "default_reject_policy", RejectPolicy.NONE),
            logger=logger,
        )
        policy = resolution.policy
    reject_active = not skip_reject and policy is not RejectPolicy.NONE
    implicit_default_used = reject_active and resolution.used_default

    if reject_active and self.mode == "regression" and threshold is None:
        raise ValidationError("reject learner unavailable for regression without threshold")

    # Compute the legacy probability payload for this call.
    if not calibrated:
        if threshold is not None:
            raise ValidationError(
                "A thresholded prediction is not possible for uncalibrated learners."
            )
        proba = self.learner.predict_proba(x)
        if not uq_interval:
            proba_payload = proba
        elif proba.shape[1] > 2:
            proba_payload = (proba, (proba, proba))
        else:
            proba_payload = (proba, (proba[:, 1], proba[:, 1]))
    else:
        learner = self.interval_learner
        if is_fast_interval_collection(learner):
            # Fast explainers hold a collection; predictions use its last entry.
            learner = learner[-1]
        if self.mode == "regression":
            proba_1, low, high, _ = learner.predict_probability(
                x, y_threshold=threshold, **kwargs
            )
            positive = np.asarray(proba_1)
            # Columns: [P(y > threshold), P(y <= threshold)] per instance.
            proba = np.stack((1 - positive, positive), axis=1)
        elif self.is_multiclass():
            proba, low, high, _ = learner.predict_proba(x, output_interval=True, **kwargs)
        else:
            proba, low, high = learner.predict_proba(x, output_interval=True, **kwargs)
        proba_payload = (proba, (low, high)) if uq_interval else proba

    # If no reject orchestration requested, return legacy payload.
    if not reject_active:
        return proba_payload

    # Reject policy active: compute envelope via orchestrator and attach legacy payload.
    bins_arg = kwargs.pop("bins", None)
    rr = self.reject_orchestrator.apply_policy(
        policy,
        x,
        explain_fn=None,
        bins=bins_arg,
        confidence=confidence_arg,
        threshold=threshold,
        result_schema="v2",
        **kwargs,
    )
    try:
        from ..explanations.reject import (  # pylint: disable=import-outside-toplevel
            RejectResultV2,
            reject_result_v2_to_legacy,
        )

        if isinstance(rr, RejectResultV2):
            rr = reject_result_v2_to_legacy(rr, emit_deprecation_warning=False)
    except Exception as exc:  # adr002_allow
        logger.debug(
            "RejectResultV2 compatibility conversion failed in predict_proba: %s",
            exc,
            exc_info=True,
        )
    rr.prediction = proba_payload
    # Log once-per-call when an implicit default caused an envelope return.
    if implicit_default_used:
        logger.info(
            "Default reject policy %s applied implicitly; returning RejectResult envelope for this call.",
            str(policy),
        )
    return rr
# pylint: disable=duplicate-code, too-many-branches, too-many-statements, too-many-locals
[docs]
def plot(self, x, y=None, threshold=None, **kwargs):
    """Generate global plots for the test data.

    Thin wrapper around ``plotting.plot_global``. A ``style_override`` keyword
    is always passed along, defaulting to ``None`` when the caller gave none.
    """
    # Lazy import plotting function (deferred from module level)
    from ..plotting import plot_global  # pylint: disable=import-outside-toplevel

    kwargs["style_override"] = kwargs.pop("style_override", None)
    plot_global(self, x, y=y, threshold=threshold, **kwargs)
[docs]
def calibrated_confusion_matrix(self):
    """Generate a calibrated confusion matrix for the calibration set.

    Only available for classification. The work is delegated to
    ``compute_calibrated_confusion_matrix``, which receives the calibration
    data, the learner and the explainer's ``bins``. Per the original
    documentation, that helper uses stratified cross-validation on the
    calibration set to avoid quadratic recalibration overhead.

    Returns
    -------
    array-like
        The calibrated confusion matrix.

    Raises
    ------
    ValidationError
        If the explainer is not in classification mode.
    """
    if self.mode == "classification":
        from .calibration_metrics import (  # pylint: disable=import-outside-toplevel
            compute_calibrated_confusion_matrix,
        )

        return compute_calibrated_confusion_matrix(
            self.x_cal, self.y_cal, self.learner, bins=self.bins
        )
    raise ValidationError(
        "The confusion matrix is only available for classification tasks."
    )
[docs]
def predict_calibration(self):
    """Predict the target values for the calibration data.

    Returns
    -------
    array-like
        The result of ``self.predict_function(self.x_cal)``. This method adds
        no model-specific handling of its own. Any special treatment, such as
        hat-matrix based updates, would have to happen inside
        ``predict_function``.
    """
    return self.predict_function(self.x_cal)
# Public alias for testing purposes (to avoid private member access in tests)
@property
def fast(self) -> bool:
    """bool: Whether the explainer runs in fast mode.

    Backed by a name-mangled private attribute.
    """
    return self.__fast

@fast.setter
def fast(self, value: bool) -> None:
    self.__fast = value

# Private alias kept for backward compatibility; reads and writes go through ``fast``.
_fast = property(
    lambda self: self.fast,
    lambda self, value: setattr(self, "fast", value),
)
@property
def noise_type(self) -> str:
    """str: Type of noise used for perturbations.

    Backed by a name-mangled private attribute.
    """
    return self.__noise_type

@noise_type.setter
def noise_type(self, value: str) -> None:
    self.__noise_type = value

# Private alias kept for backward compatibility; reads and writes go through ``noise_type``.
_noise_type = property(
    lambda self: self.noise_type,
    lambda self, value: setattr(self, "noise_type", value),
)
@property
def scale_factor(self) -> float | None:
    """float | None: Scale factor applied to perturbations.

    Backed by a name-mangled private attribute.
    """
    return self.__scale_factor

@scale_factor.setter
def scale_factor(self, value: float | None) -> None:
    self.__scale_factor = value

# Private alias kept for backward compatibility; reads and writes go through ``scale_factor``.
_scale_factor = property(
    lambda self: self.scale_factor,
    lambda self, value: setattr(self, "scale_factor", value),
)
@property
def severity(self) -> float | None:
    """float | None: Severity of perturbations.

    Backed by a name-mangled private attribute.
    """
    return self.__severity

@severity.setter
def severity(self, value: float | None) -> None:
    self.__severity = value

# Private alias kept for backward compatibility; reads and writes go through ``severity``.
_severity = property(
    lambda self: self.severity,
    lambda self, value: setattr(self, "severity", value),
)
# Public API of this module: only the explainer class is exported by ``import *``.
__all__: list[str] = ["CalibratedExplainer"]