Source code for calibrated_explanations.explanations.explanations

# pylint: disable=unknown-option-value, too-many-arguments
# pylint: disable=too-many-lines, too-many-public-methods, invalid-name, too-many-positional-arguments, line-too-long
"""Containers for storing, exporting, and visualising calibrated explanations.

This module implements :class:`CalibratedExplanations`, a container that
holds per-instance explanation objects (factual, alternative, fast) and
provides helpers for exporting, iterating and aggregating explanation
collections.
"""

from __future__ import annotations

import contextlib
import json
import logging
import sys
import tracemalloc
import warnings
from collections.abc import Sequence as ABCSequence
from copy import copy, deepcopy
from dataclasses import dataclass
from itertools import permutations
from time import time
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast

import numpy as np

from ..core.prediction_helpers import validate_and_prepare_input
from ..utils import EntropyDiscretizer, RegressorDiscretizer, deprecate, prepare_for_saving
from ..utils.exceptions import ValidationError
from ..utils.helper import calculate_metrics
from .adapters import legacy_to_domain
from .explanation import AlternativeExplanation, FactualExplanation, FastExplanation
from .models import Explanation as DomainExplanation

_LOGGER = logging.getLogger(__name__)


def _plot_alternative_dict(*args, **kwargs):
    """Forward the call to the plotting backend for alternative plots.

    The backend is imported at call time rather than at module import time,
    so importing this module does not pull in the plotting dependencies.
    All positional and keyword arguments are passed through unchanged and
    the backend's return value is returned as-is.
    """
    from ..plotting import _plot_alternative_dict as plot_backend

    outcome = plot_backend(*args, **kwargs)
    return outcome


def _plot_probabilistic_dict(*args, **kwargs):
    """Forward the call to the plotting backend for probabilistic plots.

    The backend is imported at call time rather than at module import time,
    so importing this module does not pull in the plotting dependencies.
    All positional and keyword arguments are passed through unchanged and
    the backend's return value is returned as-is.
    """
    from ..plotting import _plot_probabilistic_dict as plot_backend

    outcome = plot_backend(*args, **kwargs)
    return outcome


def get_multiclass_config():
    """Return the multiclass configuration supplied by the plotting package.

    The plotting package is imported at call time rather than at module
    import time, so importing this module does not pull in the plotting
    dependencies.
    """
    from ..plotting import get_multiclass_config as load_config

    config = load_config()
    return config


@dataclass(frozen=True)
class ExportedExplanationCollection:
    """Lightweight representation of exported explanations plus collection metadata.

    Attributes
    ----------
    metadata : Mapping[str, Any]
        Collection-level metadata.
    explanations : Sequence[DomainExplanation]
        The exported domain explanations.
    """

    metadata: Mapping[str, Any]
    explanations: Sequence[DomainExplanation]

    def __getstate__(self):
        """Get state for pickling.

        Read-only ``mappingproxy`` field values cannot be pickled, so they
        are converted to plain dictionaries in the returned state. Other
        values are passed through unchanged.

        Returns
        -------
        dict
            The state dictionary.
        """
        from types import MappingProxyType  # local import: only needed here

        return {
            name: dict(value) if isinstance(value, MappingProxyType) else value
            for name, value in self.__dict__.items()
        }


@dataclass(frozen=True)
class ExportedMultiClassExplanationCollection:
    """Exported multiclass explanations grouped by instance and class index.

    Attributes
    ----------
    metadata : Mapping[str, Any]
        Collection-level metadata.
    explanations_by_instance : Sequence[Mapping[int, DomainExplanation]]
        One mapping per instance, keyed by class index.
    """

    metadata: Mapping[str, Any]
    explanations_by_instance: Sequence[Mapping[int, DomainExplanation]]

    @property
    def explanations(self) -> Sequence[DomainExplanation]:
        """Return flattened exported explanations for backward-compatible access."""
        return tuple(
            explanation
            for per_instance in self.explanations_by_instance
            for explanation in per_instance.values()
        )

    def __getstate__(self):
        """Get state for pickling.

        Read-only ``mappingproxy`` objects cannot be pickled. They are
        converted to plain dictionaries, both as direct field values and as
        the per-instance mappings inside a list/tuple
        ``explanations_by_instance``.

        Returns
        -------
        dict
            The state dictionary.
        """
        from types import MappingProxyType  # local import: only needed here

        def _thaw(value):
            return dict(value) if isinstance(value, MappingProxyType) else value

        state = {name: _thaw(value) for name, value in self.__dict__.items()}
        per_instance = state.get("explanations_by_instance")
        if isinstance(per_instance, (list, tuple)):
            # Rebuild with the same container type so list/tuple is preserved.
            state["explanations_by_instance"] = type(per_instance)(
                _thaw(mapping) for mapping in per_instance
            )
        return state


def _jsonify(value: Any) -> Any:
    """Convert numpy objects and arrays into JSON-serialisable primitives."""
    if isinstance(value, np.ndarray):
        return [_jsonify(item) for item in value.tolist()]
    if isinstance(value, (list, tuple, set)):
        return [_jsonify(item) for item in value]
    if isinstance(value, Mapping):
        return {str(key): _jsonify(val) for key, val in value.items()}
    if isinstance(value, np.generic):  # numpy scalars
        return value.item()
    if callable(value):
        return str(value)
    return value


class CalibratedExplanations:  # pylint: disable=too-many-instance-attributes
    """Container that stores calibrated explanations for a set of test instances.

    Instances are created by :class:`.CalibratedExplainer`; the container
    offers indexed and iterable access to the per-instance explanations.
    """

    def __init__(
        self,
        calibrated_explainer,
        x,
        y_threshold,
        bins,
        features_to_ignore=None,
        *,
        condition_source: str = "prediction",
    ) -> None:
        """Set up an empty explanation collection for the given test data.

        Parameters
        ----------
        calibrated_explainer : CalibratedExplainer
            The calibrated explainer; it is wrapped in a
            ``FrozenCalibratedExplainer`` before being stored.
        x : array-like
            The test data. It is indexed as ``x[:, 0]`` to count the rows.
        y_threshold : float or tuple
            The threshold for regression explanations.
        bins : array-like
            The bins for conditional explanations.
        features_to_ignore : list of int, optional
            Feature indices to ignore; an empty list is stored when omitted.
        condition_source : str, default="prediction"
            Either ``"observed"`` or ``"prediction"``.

        Raises
        ------
        ValidationError
            If ``condition_source`` is not one of the two accepted values.
        """
        accepted_sources = {"observed", "prediction"}
        if condition_source not in accepted_sources:
            raise ValidationError(
                "condition_source must be 'observed' or 'prediction'",
                details={"param": "condition_source", "value": condition_source},
            )

        frozen = FrozenCalibratedExplainer(calibrated_explainer)
        self.calibrated_explainer: FrozenCalibratedExplainer = frozen
        self.condition_source: str = condition_source
        self.x_test: np.ndarray = x
        self.y_threshold: Optional[Union[float, Tuple[float, float], List[Tuple[float, float]]]] = (
            y_threshold
        )
        self.low_high_percentiles: Optional[Tuple[float, float]] = None
        self.explanations: List[
            Union[FactualExplanation, AlternativeExplanation, FastExplanation]
        ] = []

        # Iteration window: runs from ``start_index`` up to the number of rows.
        self.start_index: int = 0
        self.current_index: int = self.start_index
        self.end_index: int = len(x[:, 0])

        self.bins: Optional[Sequence[Any]] = bins
        self.total_explain_time: Optional[float] = None
        if features_to_ignore is not None:
            self.features_to_ignore: List[int] = features_to_ignore
        else:
            self.features_to_ignore = []

        # Per-instance ignore masks from the internal FAST-based feature
        # filter; when set, each entry lists the indices ignored for that
        # instance in addition to the global ignore list.
        self.feature_filter_per_instance_ignore: Optional[Sequence[Sequence[int]]] = None
        # Telemetry reported by the internal FAST-based feature filter.
        self.filter_telemetry: Optional[Dict[str, Any]] = None

        # Lazily populated caches of values derived from the explanations.
        self._feature_names_cache: Optional[Sequence[str]] = None
        self._predictions_cache: Optional[np.ndarray] = None
        self._probabilities_cache: Optional[np.ndarray] = None  # classification only
        self._lower_cache: Optional[np.ndarray] = None  # regression only
        self._upper_cache: Optional[np.ndarray] = None
        self._class_labels_cache: Optional[Sequence[str]] = None  # classification only

    def __iter__(self):
        """Rewind to ``start_index`` and return the collection as its own iterator."""
        self.current_index = self.start_index
        return self

    def __next__(self):
        """Return the explanation at the cursor and advance the cursor by one."""
        if self.current_index < self.end_index:
            item = self[self.current_index]
            self.current_index += 1
            return item
        raise StopIteration

    def __len__(self):
        """Return the number of rows in the stored test data."""
        first_column = self.x_test[:, 0]
        return len(first_column)
def build_rules_payload(self) -> List[Dict[str, Any]]:
    """Collect the rules payload of every stored explanation.

    Returns
    -------
    list of dict
        One entry per explanation, in collection order, as returned by that
        explanation's own ``build_rules_payload``.
    """
    payloads: List[Dict[str, Any]] = []
    for explanation in self.explanations:
        payloads.append(explanation.build_rules_payload())
    return payloads
def get_guarded_audit(self) -> Dict[str, Any]:
    """Return guarded interval audit for the collection and each instance.

    Returns
    -------
    dict
        ``"instances"`` holds the audit of every explanation and
        ``"summary"`` the counters aggregated over them. An empty collection
        yields all-zero counters and an empty instance list.

    Raises
    ------
    ValidationError
        If any stored explanation lacks a ``get_guarded_audit`` method.
    """
    if not self.explanations:
        return {
            "summary": {
                "n_instances": 0,
                "intervals_tested": 0,
                "intervals_conforming": 0,
                "intervals_removed_guard": 0,
                "intervals_emitted": 0,
                "instances_with_any_removed_guard": 0,
                "instances_all_intervals_removed_guard": 0,
                "instances_with_zero_emitted": 0,
            },
            "instances": [],
        }

    for candidate in self.explanations:
        if not hasattr(candidate, "get_guarded_audit"):
            raise ValidationError(
                "get_guarded_audit is only available for guarded explanation collections. "
                "Use explain_guarded_factual(...) or explore_guarded_alternatives(...).",
                details={"collection_type": type(self).__name__},
            )

    instances = [exp.get_guarded_audit() for exp in self.explanations]
    summaries = [inst["summary"] for inst in instances]

    def _total(counter: str) -> int:
        # Sum one per-instance counter over the whole collection.
        return int(sum(summary[counter] for summary in summaries))

    any_removed = 0
    all_removed = 0
    zero_emitted = 0
    for summary in summaries:
        if summary["intervals_removed_guard"] > 0:
            any_removed += 1
        if (
            summary["intervals_tested"] > 0
            and summary["intervals_removed_guard"] == summary["intervals_tested"]
        ):
            all_removed += 1
        if summary["intervals_emitted"] == 0:
            zero_emitted += 1

    return {
        "summary": {
            "n_instances": int(len(instances)),
            "intervals_tested": _total("intervals_tested"),
            "intervals_conforming": _total("intervals_conforming"),
            "intervals_removed_guard": _total("intervals_removed_guard"),
            "intervals_emitted": _total("intervals_emitted"),
            "instances_with_any_removed_guard": int(any_removed),
            "instances_all_intervals_removed_guard": int(all_removed),
            "instances_with_zero_emitted": int(zero_emitted),
        },
        "instances": instances,
    }
def copy(self, deep=False):
    """Return a copy of the collection.

    Parameters
    ----------
    deep : bool, default=False
        When truthy a deep copy is made, otherwise a shallow one.

    Returns
    -------
    CalibratedExplanations
        A copy of the collection.
    """
    # Bound locally under distinct names so they cannot be confused with
    # this method, which is itself called ``copy``.
    from copy import copy as shallow_copy, deepcopy as deep_copy

    duplicate = deep_copy if deep else shallow_copy
    return duplicate(self)
def __getitem__(self, key: Union[int, slice, List[int], List[bool], np.ndarray]): """Return the explanation for the given key. In case the index key is an integer (or results in a single result), the function returns the explanation corresponding to the index. If the key is a slice or an integer or boolean list (or numpy array) resulting in more than one explanation, the function returns a new `CalibratedExplanations` object with the indexed explanations. """ if isinstance(key, int): # Handle single item access return self.explanations[key] if isinstance(key, (slice, list, np.ndarray)): new_ = copy(self) if isinstance(key, slice): # Handle slicing new_.explanations = list(self.explanations[key]) if isinstance(key, (list, np.ndarray)): if isinstance(key[0], (bool, np.bool_)): # Handle boolean indexing new_.explanations = [ exp for exp, include in zip(self.explanations, key, strict=False) if include ] elif isinstance(key[0], int): # Handle integer list indexing new_.explanations = [self.explanations[i] for i in key] if len(new_.explanations) == 1: return new_.explanations[0] new_.start_index = 0 new_.current_index = new_.start_index new_.end_index = len(new_.explanations) new_.bins = None if self.bins is None else [self.bins[e.index] for e in new_] new_.x_test = np.array([self.x_test[e.index, :] for e in new_]) if self.y_threshold is None: new_.y_threshold = None elif isinstance(self.y_threshold, (int, float)): new_.y_threshold = float(self.y_threshold) elif isinstance(self.y_threshold, tuple): new_.y_threshold = self.y_threshold else: # assume list of tuples aligned with instances new_.y_threshold = [self.y_threshold[e.index] for e in new_] # Preserve per-instance feature ignore masks when present by slicing # them in the same way as bins/x_test/y_threshold. 
masks_value = getattr(self, "feature_filter_per_instance_ignore", None) if isinstance(masks_value, ABCSequence): try: new_.feature_filter_per_instance_ignore = [masks_value[e.index] for e in new_] except IndexError: new_.feature_filter_per_instance_ignore = None # Reset cached aggregates to avoid referencing stale state from the source new_._feature_names_cache = None new_._predictions_cache = None new_._probabilities_cache = None new_._lower_cache = None new_._upper_cache = None new_._class_labels_cache = None for i, e in enumerate(new_): e.index = i return new_ raise ValidationError("Invalid argument type.", details={"argument": key}) def __repr__(self) -> str: """Return the string representation of the CalibratedExplanations object.""" explanations_str = "\n".join([str(e) for e in self.explanations]) return f"CalibratedExplanations({len(self)} explanations):\n{explanations_str}" def __str__(self) -> str: """Return the string representation of the CalibratedExplanations object.""" return self.__repr__() # ------------------------------------------------------------------ # Plugin bridge helpers # ------------------------------------------------------------------
def to_batch(self):
    """Convert this collection into an :class:`ExplanationBatch`.

    The conversion is delegated to ``collection_to_batch`` from the plugin
    built-ins, which is imported at call time.
    """
    from ..plugins.builtins import collection_to_batch as build_batch  # lazy import

    batch = build_batch(self)
    return batch
@classmethod
def from_batch(cls, batch):
    """Reconstruct a collection from an :class:`ExplanationBatch`.

    A new container is always built from the batch; a template container
    found under ``collection_metadata["container"]`` is only used as a
    fallback source for constructor arguments and a few attributes.

    Parameters
    ----------
    batch : ExplanationBatch
        Duck-typed: it must expose ``collection_metadata``, ``instances`` and
        ``explanation_cls``; ``container_cls`` is optional when a template
        container is present in the metadata.

    Returns
    -------
    CalibratedExplanations
        A new container (of type ``container_cls``) holding shallow copies of
        the batch explanations, each pointing back at the new container.

    Raises
    ------
    SerializationError
        If the batch has no ``collection_metadata``, if neither a container
        class nor a template is available, if the explainer or test data
        cannot be resolved, or if an instance has no ``"explanation"``.
    ValidationError
        If the container class or template is not compatible with ``cls``, or
        if an explanation is not an instance of ``batch.explanation_cls``.
    """
    from ..utils.exceptions import SerializationError, ValidationError

    # Check for required batch attributes (duck-typing for flexibility)
    if not hasattr(batch, "collection_metadata"):
        raise SerializationError(
            "ExplanationBatch payload has unexpected type",
            details={
                "param": "batch",
                "expected_type": "ExplanationBatch",
                "actual_type": type(batch).__name__,
            },
        )

    # Get container_cls if available (may be None for duck-typed batches with template)
    container_cls = getattr(batch, "container_cls", None)
    # Work on a copy so popping the template does not mutate the batch metadata.
    metadata = dict(batch.collection_metadata)
    template = metadata.pop("container", None)
    if container_cls is None and template is not None:
        container_cls = type(template)

    # If neither container_cls nor template is present, raise error
    if container_cls is None:
        raise SerializationError(
            "ExplanationBatch payload missing container_cls and template",
            details={
                "param": "batch",
                "required": "container_cls or collection_metadata['container']",
            },
        )

    # Validate container_cls if present
    if not issubclass(container_cls, cls):
        raise ValidationError(
            "ExplanationBatch container metadata has unexpected type",
            details={
                "param": "container_cls",
                "expected_type": cls.__name__,
                "actual_type": container_cls.__name__,
            },
        )

    # If template is a valid CalibratedExplanations instance, use it for metadata
    # but still reconstruct a new container from batch.instances to ensure
    # canonical reconstruction (ADR-015).
    if template is not None and not isinstance(template, cls):
        raise ValidationError(
            "ExplanationBatch container metadata has unexpected type",
            details={
                "param": "container",
                "expected_type": cls.__name__,
                "actual_type": type(template).__name__,
            },
        )

    # Resolve constructor arguments: explicit metadata keys win (with legacy
    # aliases "explainer" and "x"), then the template's attributes.
    calibrated_explainer = metadata.get("calibrated_explainer")
    if calibrated_explainer is None:
        calibrated_explainer = metadata.get("explainer")
    if calibrated_explainer is None and template is not None:
        calibrated_explainer = template.calibrated_explainer
    x_test = metadata.get("x_test")
    if x_test is None:
        x_test = metadata.get("x")
    if x_test is None and template is not None:
        x_test = template.x_test
    y_threshold = metadata.get("y_threshold")
    if y_threshold is None and template is not None:
        y_threshold = template.y_threshold
    bins = metadata.get("bins")
    if bins is None and template is not None:
        bins = template.bins
    features_to_ignore = metadata.get("features_to_ignore")
    if features_to_ignore is None and template is not None:
        features_to_ignore = template.features_to_ignore
    condition_source = metadata.get("condition_source")
    if condition_source is None:
        if template is not None:
            condition_source = getattr(template, "condition_source", "prediction")
        else:
            condition_source = "prediction"

    if calibrated_explainer is None or x_test is None:
        raise SerializationError(
            "ExplanationBatch metadata missing explainer context",
            details={
                "artifact": "ExplanationBatch",
                "field": "calibrated_explainer",
                "available_keys": tuple(sorted(metadata.keys())),
            },
        )

    container = container_cls(
        calibrated_explainer,
        x_test,
        y_threshold,
        bins,
        features_to_ignore,
        condition_source=condition_source,
    )
    # The template fallback applies only when the key is absent; a key that
    # is present with value None is kept as None. getattr on a None template
    # simply yields the None default.
    container.low_high_percentiles = metadata.get(
        "low_high_percentiles", getattr(template, "low_high_percentiles", None)
    )
    container.total_explain_time = metadata.get(
        "total_explain_time", getattr(template, "total_explain_time", None)
    )
    container.feature_filter_per_instance_ignore = metadata.get(
        "feature_filter_per_instance_ignore",
        getattr(template, "feature_filter_per_instance_ignore", None),
    )
    container.batch_metadata = dict(metadata)
    # Propagate any full probability cube from instances into batch metadata
    # (keeps collection metadata aligned with telemetry exports). Also
    # populate a minimal `telemetry` attribute on the materialised
    # container so callers using `from_batch` directly receive the
    # same dependency hints and probability summaries as the orchestrator.
    container.telemetry = {"interval_dependencies": metadata.get("interval_dependencies")}
    full_probs = None
    # Only the first instance carrying "__full_probabilities__" is used.
    for inst in getattr(batch, "instances", ()):
        pred = inst.get("prediction") if isinstance(inst, dict) else None
        if isinstance(pred, dict) and "__full_probabilities__" in pred:
            full_probs = pred.get("__full_probabilities__")
            break
    if full_probs is not None:
        container.batch_metadata.setdefault("__full_probabilities__", full_probs)
        # Best-effort summary: on any failure the minimal telemetry set above
        # is left in place.
        with contextlib.suppress(Exception):  # adr002_allow
            arr = np.asarray(full_probs)
            container.telemetry = {
                "full_probabilities_shape": tuple(arr.shape),
                "full_probabilities_summary": {
                    "mean": float(np.mean(arr)),
                    "min": float(np.min(arr)),
                    "max": float(np.max(arr)),
                },
                "interval_dependencies": metadata.get("interval_dependencies"),
            }
    for index, instance in enumerate(batch.instances):
        explanation = instance.get("explanation")
        if explanation is None:
            raise SerializationError(
                "ExplanationBatch instance missing explanation payload",
                details={
                    "artifact": "ExplanationBatch",
                    "field": "explanation",
                    "instance_index": index,
                },
            )
        if not isinstance(explanation, batch.explanation_cls):
            raise ValidationError(
                "ExplanationBatch instance has unexpected explanation type",
                details={
                    "param": "explanation",
                    "expected_type": batch.explanation_cls.__name__,
                    "actual_type": type(explanation).__name__,
                },
            )
        # Shallow copy so re-parenting does not modify the batch's own object.
        explanation_copy = copy(explanation)
        explanation_copy.calibrated_explanations = container
        container.explanations.append(explanation_copy)
    return container
# ------------------------------------------------------------------ # JSON export helpers (schema v1 wrappers) # ------------------------------------------------------------------
def to_json(self, *, include_version: bool = True) -> Mapping[str, Any]:
    """Build a JSON-friendly payload for the whole collection.

    Every explanation is converted to a domain object and serialised with
    the ``to_json`` helper from :mod:`calibrated_explanations.serialization`;
    the collection metadata is added under ``"collection"``.

    Parameters
    ----------
    include_version:
        When ``True`` (default) a ``schema_version`` entry is added to the
        top-level payload and the flag is forwarded to the per-explanation
        serialiser.

    Returns
    -------
    Mapping[str, Any]
        A dictionary with the keys ``"collection"`` and ``"explanations"``,
        plus ``"schema_version"`` when requested.
    """
    from ..serialization import to_json as serialise_explanation

    def _export(exp):
        # Convert one explanation, carrying over provenance/metadata if set.
        domain = legacy_to_domain(exp.index, self._legacy_payload(exp))
        provenance = getattr(exp, "provenance", None)
        metadata = getattr(exp, "metadata", None)
        if provenance is not None:
            domain.provenance = cast(Optional[Mapping[str, Any]], _jsonify(provenance))
        if metadata is not None:
            domain.metadata = cast(Optional[Mapping[str, Any]], _jsonify(metadata))
        return serialise_explanation(domain, include_version=include_version)

    instances = [_export(exp) for exp in self.explanations]
    payload: dict[str, Any] = {}
    payload["collection"] = self._collection_metadata()
    payload["explanations"] = instances
    if include_version:
        payload["schema_version"] = "1.0.0"
    return payload
def to_json_stream(self, *, chunk_size: int = 256, format: str = "jsonl"):
    """Stream the collection as JSON.

    This generator yields either a JSON Lines stream or chunked JSON arrays.

    Parameters
    ----------
    chunk_size:
        Number of explanations per yielded array when ``format == "chunked"``.
        With ``"jsonl"`` it is only recorded in the telemetry.
    format:
        Either ``"jsonl"`` (default) for JSON Lines or ``"chunked"`` for
        chunked JSON arrays.

    Yields
    ------
    str
        JSON fragments (one per yield). The first fragment is a metadata
        object describing the collection, followed by the explanations, and
        the last fragment carries the export telemetry.

    Raises
    ------
    ValidationError
        If ``format`` is not supported. As this is a generator, the error is
        raised when the first fragment is requested.

    Notes
    -----
    Memory tracing is started only if :mod:`tracemalloc` is not already
    tracing, and is stopped again even when the stream is abandoned or fails
    part-way. If tracing was already active it is left running, and the
    reported peak is then that of the enclosing session.
    """
    from ..serialization import to_json as _explanation_to_json

    if format not in {"jsonl", "chunked"}:
        raise ValidationError("Unsupported stream format", details={"format": format})

    start = time()
    # Do not hijack (and later kill) a tracing session owned by the caller.
    started_tracing = not tracemalloc.is_tracing()
    if started_tracing:
        tracemalloc.start()
    try:
        # Prepare collection metadata snapshot (without export telemetry yet)
        metadata = dict(self._collection_metadata())

        # Yield metadata first as a standalone JSON object line
        meta_fragment = {"collection": metadata, "schema_version": "1.0.0"}
        yield json.dumps(meta_fragment, default=_jsonify)

        # Stream explanations
        chunk: List[str] = []
        n = 0
        for exp in self.explanations:
            domain = legacy_to_domain(exp.index, self._legacy_payload(exp))
            provenance = getattr(exp, "provenance", None)
            metadata_exp = getattr(exp, "metadata", None)
            if provenance is not None:
                domain.provenance = cast(Optional[Mapping[str, Any]], _jsonify(provenance))
            if metadata_exp is not None:
                domain.metadata = cast(Optional[Mapping[str, Any]], _jsonify(metadata_exp))
            item = _explanation_to_json(domain, include_version=True)
            line = json.dumps(item, default=_jsonify)
            n += 1
            if format == "jsonl":
                yield line
            else:
                # chunked
                chunk.append(line)
                if len(chunk) >= chunk_size:
                    # yield a JSON array for this chunk
                    yield "[" + ",".join(chunk) + "]"
                    chunk = []

        # flush remaining chunk
        if format == "chunked" and chunk:
            yield "[" + ",".join(chunk) + "]"

        # capture peak memory while tracing is still active
        peak = tracemalloc.get_traced_memory()[1]
    finally:
        # Also runs on early close of the generator or on an exception.
        if started_tracing:
            tracemalloc.stop()
    elapsed = time() - start

    # Build telemetry
    telemetry = {
        "export_rows": n,
        "chunk_size": chunk_size,
        "mode": getattr(self.calibrated_explainer, "mode", None),
        "peak_memory_mb": round(float(peak) / (1024 * 1024), 3),
        "elapsed_seconds": round(float(elapsed), 3),
        "schema_version": "1.0.0",
        # feature_branch replaced by explicit fields
        "build_id": None,
        "feature_flags": None,
    }

    # Attach minimal telemetry to the local metadata snapshot and attempt to
    # store a more complete record on the underlying explainer if available.
    try:
        # minimal metadata
        metadata.setdefault("export_telemetry", {})
        metadata["export_telemetry"].update(telemetry)
        # attempt to update underlying explainer last telemetry
        underlying = getattr(self.calibrated_explainer, "_explainer", None)
        if underlying is not None:
            try:
                last = getattr(underlying, "_last_telemetry", None) or {}
                last.update({"export": telemetry})
                underlying._last_telemetry = last
            except Exception:  # adr002_allow
                # best-effort only: log for observability per fallback policy
                _LOGGER.info(
                    "failed to attach export telemetry to underlying explainer",
                    exc_info=True,
                )
    except Exception:  # adr002_allow
        _LOGGER.info("failed to attach export telemetry to collection", exc_info=True)

    # final telemetry fragment
    yield json.dumps({"export_telemetry": telemetry}, default=_jsonify)
@classmethod
def from_json(cls, payload: Mapping[str, Any]) -> ExportedExplanationCollection:
    """Rebuild domain explanations from a payload produced by :meth:`to_json`.

    Parameters
    ----------
    payload : Mapping[str, Any]
        Its ``"explanations"`` entries are parsed one by one and its
        ``"collection"`` entry becomes the returned metadata.

    Returns
    -------
    ExportedExplanationCollection
        The parsed explanations (as a tuple) together with the metadata.
    """
    from ..serialization import from_json as parse_explanation

    restored: list[DomainExplanation] = []
    for item in payload.get("explanations", []) or []:
        # Multiclass annotations may sit on the raw entry itself or, when
        # produced by other exporters, under its "metadata" mapping.
        class_index = None
        class_label = None
        if isinstance(item, Mapping):
            class_index = item.get("class_index")
            class_label = item.get("class_label")
            nested = item.get("metadata")
            if isinstance(nested, Mapping):
                if class_index is None:
                    class_index = nested.get("class_index")
                if class_label is None:
                    class_label = nested.get("class_label")

        explanation = parse_explanation(item)

        # Work on a mutable copy; values already present take precedence.
        existing = explanation.metadata
        merged = dict(existing) if isinstance(existing, Mapping) else {}
        if class_index is not None:
            try:
                merged.setdefault("class_index", int(class_index))
            except (TypeError, ValueError, OverflowError):
                merged.setdefault("class_index", class_index)
        if class_label is not None:
            merged.setdefault("class_label", class_label)
        # An empty mapping is stored as None.
        explanation.metadata = merged or None
        restored.append(explanation)

    collection_meta = payload.get("collection", {}) or {}
    return ExportedExplanationCollection(
        metadata=cast(Mapping[str, Any], _jsonify(collection_meta)),
        explanations=tuple(restored),
    )
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _legacy_payload(self, exp) -> Mapping[str, Any]:
    """Assemble the legacy-shaped payload dictionary for one explanation.

    Rules are taken from the first non-empty source among the conjunctive
    rules (only when ``has_conjunctive_rules`` is truthy), the ``rules``
    attribute and finally ``get_rules()``.
    """
    rules_blob = None
    if getattr(exp, "has_conjunctive_rules", False):
        rules_blob = getattr(exp, "conjunctive_rules", None)
    if not rules_blob:
        rules_blob = getattr(exp, "rules", None)
    if not rules_blob and hasattr(exp, "get_rules"):
        # Warnings emitted while computing the rules are silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                rules_blob = exp.get_rules()  # type: ignore[attr-defined]
            except Exception:  # adr002_allow
                rules_blob = {}

    if isinstance(exp, AlternativeExplanation):
        explanation_type = "alternative"
    elif isinstance(exp, FastExplanation):
        explanation_type = "fast"
    else:
        explanation_type = "factual"

    def _explainer_mode():
        # Fallback used when the explanation has no ``get_mode`` attribute.
        return getattr(self.calibrated_explainer, "mode", None)

    payload: dict[str, Any] = {
        "task": getattr(exp, "get_mode", _explainer_mode)(),
        "rules": _jsonify(rules_blob or {}),
        "feature_weights": _jsonify(getattr(exp, "feature_weights", {})),
        "feature_predict": _jsonify(getattr(exp, "feature_predict", {})),
        "prediction": _jsonify(getattr(exp, "prediction", {})),
        "explanation_type": explanation_type,
    }
    return payload


def _collection_metadata(self) -> Mapping[str, Any]:
    """Gather the collection-level metadata needed to interpret a payload.

    Each lookup is best-effort: a lookup that raises contributes ``None``,
    and entries whose value is ``None`` are left out of the result.
    """
    base = getattr(self, "calibrated_explainer", None)
    underlying = getattr(base, "_explainer", None)

    def _jsonified_attr(name):
        # JSON-friendly value of ``base.<name>``; None if absent or failing.
        if not hasattr(base, name):
            return None
        try:
            return _jsonify(getattr(base, name))
        except Exception:  # adr002_allow
            return None

    feature_names = None
    try:
        names = self.feature_names
        if names is not None:
            feature_names = list(names)
    except Exception:  # adr002_allow
        feature_names = None

    class_labels = _jsonified_attr("class_labels")
    sample_percentiles = _jsonified_attr("sample_percentiles")

    runtime_telemetry = None
    if underlying is not None:
        try:
            runtime_telemetry = getattr(underlying, "runtime_telemetry", None)
            if callable(runtime_telemetry):
                runtime_telemetry = runtime_telemetry()
        except Exception:  # adr002_allow
            runtime_telemetry = None

    metadata = {
        "size": len(self),
        "mode": getattr(base, "mode", None),
        "y_threshold": _jsonify(self.y_threshold),
        "low_high_percentiles": _jsonify(self.low_high_percentiles),
        "feature_names": _jsonify(feature_names),
        "class_labels": class_labels,
        "sample_percentiles": sample_percentiles,
        "runtime_telemetry": _jsonify(runtime_telemetry),
    }
    return {key: value for key, value in metadata.items() if value is not None}


# Public wrappers for formerly-private helpers (temporary, Category A remediation)
def collection_metadata(self) -> Mapping[str, Any]:
    """Return the collection metadata via the internal ``_collection_metadata`` helper."""
    metadata = self._collection_metadata()
    return metadata
def legacy_payload(self, exp) -> Mapping[str, Any]:
    """Return the legacy payload for ``exp`` via the internal ``_legacy_payload`` helper."""
    payload = self._legacy_payload(exp)
    return payload
@property
def prediction_interval(self) -> List[Tuple[Optional[float], Optional[float]]]:
    """Return the prediction intervals for each explanation.

    Returns
    -------
    list of tuples
        One ``prediction_interval`` value per stored explanation, in order.
    """
    return [e.prediction_interval for e in self.explanations]

@property
def predict(self) -> List[Any]:
    """Return the scalar prediction for every explanation.

    Returns
    -------
    list
        One ``predict`` value per stored explanation, in order.
    """
    return [e.predict for e in self.explanations]

# ---- Rich baseline exposure (Phase 1A golden snapshot enrichment) ----
@property
def feature_names(self):  # consistent naming with underlying explainer
    """Return cached feature names sourced from the underlying explainer.

    The value is read from ``calibrated_explainer.feature_names`` while the
    cache is ``None``; a failed lookup yields ``None``.
    """
    if self._feature_names_cache is None:
        try:
            self._feature_names_cache = self.calibrated_explainer.feature_names
        except Exception:  # adr002_allow - feature names are optional
            self._feature_names_cache = None
    return self._feature_names_cache

@property
def class_labels(self):
    """Return class labels for classification explanations if available.

    A dict of labels is normalised to a list ordered by its sorted keys.
    """
    if self._class_labels_cache is None:
        try:
            labels = getattr(self.calibrated_explainer, "class_labels", None)
            if labels is not None and isinstance(labels, dict):
                # Normalise to a list ordered by class index; keys are
                # assumed to be numeric class indices.
                labels = [labels[k] for k in sorted(labels.keys())]
            self._class_labels_cache = labels
        except Exception:  # adr002_allow - class labels are optional
            self._class_labels_cache = None
    return self._class_labels_cache

@property
def predictions(self):  # noqa: D401
    """Vector of scalar predictions for the explained instances (cached)."""
    if self._predictions_cache is None:
        try:
            self._predictions_cache = np.asarray([e.predict for e in self.explanations])
        except Exception:  # adr002_allow - best-effort cache fill
            self._predictions_cache = None
    return self._predictions_cache

@property
def probabilities(self):  # classification only
    """Return cached probability matrices for classification explanations.

    Stays ``None`` unless every explanation has a non-``None``
    ``prediction_probabilities`` attribute.
    """
    if self._probabilities_cache is None:
        try:
            # Each explanation may store:
            # (a) its own probability vector (shape (n_classes,)) OR
            # (b) the full matrix (n_instances, n_classes) due to earlier enrichment
            raw = [getattr(e, "prediction_probabilities", None) for e in self.explanations]
            if all(r is not None for r in raw):
                first = raw[0]
                if isinstance(first, tuple):  # pragma: no cover - defensive
                    first = first[0]
                first = np.asarray(first)
                if first.ndim == 2 and first.shape[0] == len(self.explanations):
                    # Case (b): each explanation redundantly holds the full matrix.
                    self._probabilities_cache = first
                else:
                    # Case (a): stack per-instance vectors.
                    self._probabilities_cache = np.vstack(raw)
        except Exception:  # adr002_allow - best-effort cache fill
            self._probabilities_cache = None
    return self._probabilities_cache

@property
def lower(self):  # regression only
    """Return cached lower bounds for regression prediction intervals.

    Stays ``None`` when every explanation reports ``None`` as lower bound.
    """
    if self._lower_cache is None:
        try:
            lows = [
                getattr(e, "prediction_interval", (None, None))[0] for e in self.explanations
            ]
            if any(low is not None for low in lows):
                self._lower_cache = np.asarray(lows)
        except Exception:  # adr002_allow - best-effort cache fill
            self._lower_cache = None
    return self._lower_cache

@property
def upper(self):  # regression only
    """Return cached upper bounds for regression prediction intervals.

    Stays ``None`` when every explanation reports ``None`` as upper bound.
    """
    if self._upper_cache is None:
        try:
            highs = [
                getattr(e, "prediction_interval", (None, None))[1] for e in self.explanations
            ]
            if any(h is not None for h in highs):
                self._upper_cache = np.asarray(highs)
        except Exception:  # adr002_allow - best-effort cache fill
            self._upper_cache = None
    return self._upper_cache

@property
def is_probabilistic_regression(self) -> bool:
    """Check if the explanations use probabilistic regression (thresholded).

    Probabilistic regression and thresholded regression are synonymous terms.
    See ADR-021 for terminology guidance.
    """
    return self.y_threshold is not None

@property
def is_one_sided(self) -> bool:
    """Check if the explanations are one-sided.

    Returns ``True`` when either percentile bound is infinite, and ``False``
    when ``low_high_percentiles`` is unset.
    """
    if self.low_high_percentiles is None:
        return False
    # ``np.isinf`` yields ``numpy.bool_``; coerce so the annotated ``bool`` is honoured.
    return bool(np.isinf(self.get_low_percentile()) or np.isinf(self.get_high_percentile()))
def get_confidence(self) -> float:
    """Return the confidence level spanned by the percentile bounds.

    The confidence is the distance between the high and the low percentile.

    Returns
    -------
    float
        ``high - low`` when both percentiles are finite.

    Notes
    -----
    - If the high percentile is infinite, the result is ``100 - low``.
    - Otherwise, if the low percentile is infinite, the result is ``high``.
    """
    upper_pct = self.get_high_percentile()
    if np.isinf(upper_pct):
        return 100 - self.get_low_percentile()
    lower_pct = self.get_low_percentile()
    if np.isinf(lower_pct):
        return upper_pct
    return upper_pct - lower_pct
def get_low_percentile(self) -> float:
    """Return the low percentile of the explanations.

    This is the first element of the ``low_high_percentiles`` attribute.

    Returns
    -------
    float
        The low percentile value.

    Raises
    ------
    AssertionError
        If ``low_high_percentiles`` is ``None``.
    """
    percentiles = self.low_high_percentiles
    if percentiles is None:
        # Raised explicitly instead of via ``assert`` so the guard is not
        # stripped when Python runs with ``-O``; the exception type is unchanged.
        raise AssertionError("low_high_percentiles not set")
    return percentiles[0]  # pylint: disable=unsubscriptable-object
def get_high_percentile(self) -> float:
    """Return the high percentile of the explanations.

    This is the second element of the ``low_high_percentiles`` attribute.

    Returns
    -------
    float
        The high percentile value.

    Raises
    ------
    AssertionError
        If ``low_high_percentiles`` is ``None``.
    """
    percentiles = self.low_high_percentiles
    if percentiles is None:
        # Raised explicitly instead of via ``assert`` so the guard is not
        # stripped when Python runs with ``-O``; the exception type is unchanged.
        raise AssertionError("low_high_percentiles not set")
    return percentiles[1]  # pylint: disable=unsubscriptable-object
# pylint: disable=too-many-arguments
def finalize(
    self,
    binned,
    feature_weights,
    feature_predict,
    prediction,
    instance_time=None,
    total_time=None,
) -> "CalibratedExplanations":
    """Build one explanation object per test instance and store it.

    Each row of ``self.x_test`` becomes an ``AlternativeExplanation`` when
    ``self.is_alternative()`` is true, otherwise a ``FactualExplanation``.

    Parameters
    ----------
    binned : array-like
        The binned data for the features.
    feature_weights : array-like
        The weights of the features.
    feature_predict : array-like
        The predicted values for the features.
    prediction : array-like
        The prediction values.
    instance_time : array-like, optional
        Per-instance explanation time, indexed like ``self.x_test``.
    total_time : float, optional
        Start timestamp; the elapsed time since it is stored in
        ``self.total_explain_time``.

    Returns
    -------
    CalibratedExplanations
        ``self``, or the converted alternative collection when
        ``self.is_alternative()`` is true.
    """
    for row_idx, row in enumerate(self.x_test):
        row_bin = None if self.bins is None else self.bins[row_idx]
        # Both explanation classes take the same constructor arguments.
        explanation_cls = AlternativeExplanation if self.is_alternative() else FactualExplanation
        built = explanation_cls(
            self,
            row_idx,
            row,
            binned,
            feature_weights,
            feature_predict,
            prediction,
            self.y_threshold,
            instance_bin=row_bin,
            condition_source=self.condition_source,
        )
        built.explain_time = None if instance_time is None else instance_time[row_idx]
        self.explanations.append(built)

    self.total_explain_time = None if total_time is None else time() - total_time
    if not self.is_alternative():
        return self
    return self.__convert_to_alternative_explanations()
def __convert_to_alternative_explanations(self) -> "AlternativeExplanations":
    """Return an ``AlternativeExplanations`` view over this collection's state.

    The view is created without calling ``__init__`` and receives every
    entry of this instance's ``__dict__``, so both objects reference the
    same attribute values.
    """
    shared_state = self.__dict__
    view = AlternativeExplanations.__new__(AlternativeExplanations)
    view.__dict__.update(shared_state)
    return view

# pylint: disable=too-many-arguments
def finalize_fast(
    self, feature_weights, feature_predict, prediction, instance_time=None, total_time=None
) -> None:
    """Build one ``FastExplanation`` per test instance and store it.

    Parameters
    ----------
    feature_weights : array-like
        The weights of the features.
    feature_predict : array-like
        The predicted values for the features.
    prediction : array-like
        The prediction values.
    instance_time : array-like, optional
        Per-instance explanation time, indexed like ``self.x_test``.
    total_time : float, optional
        Start timestamp; the elapsed time since it is stored in
        ``self.total_explain_time``.

    Notes
    -----
    - Each explanation is appended to ``self.explanations``.
    - ``explain_time`` and ``total_explain_time`` are ``None`` when the
      corresponding timing argument is not supplied.
    """
    for row_idx, row in enumerate(self.x_test):
        row_bin = None if self.bins is None else self.bins[row_idx]
        fast = FastExplanation(
            self,
            row_idx,
            row,
            feature_weights,
            feature_predict,
            prediction,
            self.y_threshold,
            instance_bin=row_bin,
            condition_source=self.condition_source,
        )
        fast.explain_time = None if instance_time is None else instance_time[row_idx]
        self.explanations.append(fast)

    self.total_explain_time = None if total_time is None else time() - total_time
def get_explainer(self):
    """Return the explainer object stored on this collection as ``calibrated_explainer``."""
    explainer = self.calibrated_explainer
    return explainer
def get_rules(self):
    """Return the result of ``get_rules()`` for every explanation, in collection order."""
    collected = []
    for member in self.explanations:
        collected.append(member.get_rules())
    return collected
def add_conjunctions(self, n_top_features=5, max_rule_size=2, **kwargs):
    """Add conjunctive rules to every explanation in the collection.

    The call is forwarded to ``add_conjunctions`` on each explanation.

    Parameters
    ----------
    n_top_features : int, optional
        The number of most important rules to try to combine into
        conjunctive rules. Defaults to 5.
    max_rule_size : int, optional
        The maximum size of the conjunctions. Defaults to 2 (meaning
        ``rule_one and rule_two``).
    **kwargs : dict
        Extra keyword arguments forwarded unchanged to each explanation.

    Returns
    -------
    CalibratedExplanations
        ``self``, to allow method chaining.
    """
    for member in self.explanations:
        member.add_conjunctions(n_top_features, max_rule_size, **kwargs)
    return self
def reset(self):
    """Call ``reset()`` on every explanation and return ``self`` for chaining."""
    for member in self.explanations:
        member.reset()
    return self
def remove_conjunctions(self):
    """Call ``remove_conjunctions()`` on every explanation and return ``self`` for chaining."""
    for member in self.explanations:
        member.remove_conjunctions()
    return self
def filter_rule_sizes(
    self,
    *,
    rule_sizes: Optional[Any] = None,
    size_range: Optional[Tuple[int, int]] = None,
    copy: bool = True,
):
    """Filter rules by conjunctive rule size across the collection.

    Parameters
    ----------
    rule_sizes : optional
        Forwarded unchanged to each explanation's ``filter_rule_sizes``.
    size_range : tuple of (int, int), optional
        Forwarded unchanged to each explanation's ``filter_rule_sizes``.
    copy : bool, default=True
        If true, return a copy of the collection holding the filtered
        explanations. Otherwise replace the stored explanations and return
        ``self``.

    Returns
    -------
    CalibratedExplanations
        The filtered collection.
    """
    if not copy:
        for position, member in enumerate(self.explanations):
            self.explanations[position] = member.filter_rule_sizes(
                rule_sizes=rule_sizes, size_range=size_range, copy=False
            )
        return self

    clone = self.copy()
    clone.explanations = [
        member.filter_rule_sizes(rule_sizes=rule_sizes, size_range=size_range, copy=True)
        for member in self.explanations
    ]
    return clone
def filter_features(
    self,
    *,
    exclude_features=None,
    include_features=None,
    copy: bool = True,
) -> "CalibratedExplanations":
    """Filter rules by feature inclusion or exclusion across all explanations.

    Parameters
    ----------
    exclude_features : str, int, or sequence of str/int, optional
        Feature names (str) or indices (int) to exclude. Rules containing
        any of these features will be removed.
    include_features : str, int, or sequence of str/int, optional
        Feature names (str) or indices (int) to include. Only rules
        containing these features will be kept.
    copy : bool, default=True
        If true, return a filtered copy without replacing the stored
        explanations. Otherwise replace them and return ``self``.

    Returns
    -------
    CalibratedExplanations
        The filtered collection.
    """
    if not copy:
        for position, member in enumerate(self.explanations):
            self.explanations[position] = member.filter_features(
                exclude_features=exclude_features, include_features=include_features, copy=False
            )
        return self

    clone = self.copy()
    clone.explanations = [
        member.filter_features(
            exclude_features=exclude_features, include_features=include_features, copy=True
        )
        for member in self.explanations
    ]
    return clone
def is_alternative(self):
    """Return True when the explainer's discretizer marks an alternative workflow.

    That is the case when ``calibrated_explainer.discretizer`` is a
    ``RegressorDiscretizer`` or an ``EntropyDiscretizer``.
    """
    active_discretizer = self.calibrated_explainer.discretizer
    alternative_discretizers = (RegressorDiscretizer, EntropyDiscretizer)
    return isinstance(active_discretizer, alternative_discretizers)
# pylint: disable=too-many-arguments, too-many-locals, unused-argument
def plot(
    self,
    index=None,
    filter_top=10,
    show=True,
    filename="",
    uncertainty=False,
    style="regular",
    rnk_metric=None,
    rnk_weight=0.5,
    style_override=None,
    **kwargs,
):
    """Plot explanations for a given instance, with the option to show or save the plots.

    Parameters
    ----------
    index : int or None, default=None
        The index of the instance for which you want to plot the explanation. If None, the
        function will plot all the explanations.
    filter_top : int or None, default=10
        The number of top features to display in the plot. If set to `None`, all the features
        will be shown.
    show : bool, default=True
        Determines whether the plots should be displayed immediately after they are generated.
        Suitable to set to False when saving the plots to a file.
    filename : str, default=''
        The full path and filename of the plot image file that will be saved. If empty, the
        plot will not be saved. When one plot per instance is produced, the instance index is
        inserted between the file title and its extension.
    uncertainty : bool, default=False
        Determines whether to include uncertainty information in the plots.
    style : str, default='regular'
        The style of the plot. Built-in styles are 'regular', 'triangular' and 'narrative'.
        Use ``style='ensured'`` as an alias for ``style='triangular'``. Any other string is
        treated as a custom style and, when `index` is None, is first offered to the
        collection plot plugin.
    rnk_metric : str, default=None
        The metric used to rank the features. Supported metrics are 'ensured',
        'feature_weight', and 'uncertainty'. If None, the default from the explanation class
        is used.
    rnk_weight : float, default=0.5
        The weight of the uncertainty in the ranking. Used with the 'ensured' ranking metric.
    style_override : str or None, default=None
        Forwarded to each per-explanation ``plot`` call. For custom styles, a non-empty string
        replaces `style` as the explicit style passed to the collection plot plugin.
    **kwargs : dict
        Forwarded to the underlying plot or narrative call. Keys read here:
        ``instance_index`` (custom styles only, used as `index` when it is an int),
        ``template_path``, ``expertise_level`` and ``output`` (narrative style only),
        ``renderer`` and ``return_plot_spec``.

    Returns
    -------
    object or list or None
        - ``style='narrative'``: the narrative output of the selected explanation, or of the
          whole collection when `index` is None.
        - custom style with `index` None: the plugin result when it is not None.
        - `index` given: the result of that explanation's ``plot``.
        - otherwise: the per-explanation results (a single result when the collection holds
          exactly one explanation) when ``return_plot_spec`` is truthy or any result is not
          None; None in all other cases.

    See Also
    --------
    FactualExplanation.plot
        Refer to the docstring for plot in FactualExplanation for details on default
        ranking ('feature_weight').
    AlternativeExplanation.plot
        Refer to the docstring for plot in AlternativeExplanation for details on default
        ranking ('ensured').
    FastExplanation.plot
        Refer to the docstring for plot in FastExplanation for details on default
        ranking ('feature_weight').
    """
    if style == "ensured":
        style = "triangular"
    # NOTE(review): "ensured" can no longer match here because it was remapped above.
    custom_plot_style = isinstance(style, str) and style not in {
        "regular",
        "triangular",
        "ensured",
        "narrative",
    }
    if index is None and custom_plot_style:
        selected_instance_index = kwargs.get("instance_index")
        if isinstance(selected_instance_index, int):
            # Work on a copy so the key is removed from the forwarded options only.
            kwargs = dict(kwargs)
            kwargs.pop("instance_index", None)
            index = selected_instance_index
    if style == "narrative":
        # Imported lazily, only when the narrative style is requested.
        from ..viz.narrative_plugin import NarrativePlotPlugin

        template_path = kwargs.pop("template_path", None)
        expertise_level = kwargs.pop(
            "expertise_level", ("beginner", "intermediate", "advanced")
        )
        output_format = kwargs.pop("output", "dataframe")
        plugin = NarrativePlotPlugin(template_path=template_path)
        if index is not None:
            # Delegate to single explanation helper when a specific index is requested
            return self[index].to_narrative(
                template_path=template_path,
                expertise_level=expertise_level,
                output_format=output_format,
                **kwargs,
            )
        return plugin.plot(
            self,
            template_path=template_path,
            expertise_level=expertise_level,
            output=output_format,
            **kwargs,
        )
    if len(filename) > 0:
        # ``path``, ``title`` and ``ext`` are reused below to build per-instance file names.
        path, filename, title, ext = prepare_for_saving(filename)
        plugin_path = filename
        plugin_save_ext = ext
    else:
        plugin_path = None
        plugin_save_ext = None
    if index is None and custom_plot_style:
        # Imported lazily, only when a custom style is plotted for the whole collection.
        from ..plotting import _render_collection_plot_plugin

        plugin_result = _render_collection_plot_plugin(
            self,
            explicit_style=style_override
            if isinstance(style_override, str) and style_override
            else style,
            show=show,
            path=plugin_path,
            save_ext=plugin_save_ext,
            renderer_override=kwargs.get("renderer"),
            intent_type="alternative" if self.is_alternative() else "factual",
            options=kwargs,
        )
        # A None result falls through to the per-explanation plotting below.
        if plugin_result is not None:
            return plugin_result
    if index is not None:
        if len(filename) > 0:
            filename = path + title + str(index) + ext
        return self[index].plot(
            filter_top=filter_top,
            show=show,
            filename=filename,
            uncertainty=uncertainty,
            style=style,
            rnk_metric=rnk_metric,
            rnk_weight=rnk_weight,
            style_override=style_override,
            **kwargs,
        )
    else:
        results = []
        for i, explanation in enumerate(self.explanations):
            if len(filename) > 0:
                filename = path + title + str(i) + ext
            results.append(
                explanation.plot(
                    filter_top=filter_top,
                    show=show,
                    filename=filename,
                    uncertainty=uncertainty,
                    style=style,
                    rnk_metric=rnk_metric,
                    rnk_weight=rnk_weight,
                    style_override=style_override,
                    **kwargs,
                )
            )
        if kwargs.get("return_plot_spec"):
            return results[0] if len(results) == 1 else results
        non_null_results = [result for result in results if result is not None]
        # Return results only if at least one plot produced a value; otherwise return None implicitly.
        if non_null_results:
            return results[0] if len(results) == 1 else results
def to_narrative(
    self,
    template_path=None,
    expertise_level=("beginner", "advanced"),
    output_format="dataframe",
    conjunction_separator=" AND ",
    align_weights=True,
    **kwargs,
):
    """Generate narrative explanations for the whole collection.

    A ``NarrativePlotPlugin`` is created for `template_path` and its ``plot``
    method is called with this collection and the options below.

    Parameters
    ----------
    template_path : str or None, default=None
        Path to the narrative template file (YAML or JSON). If None or the file
        doesn't exist, the built-in default template is used.
    expertise_level : str or tuple of str, default=("beginner", "advanced")
        The expertise level(s) for narrative generation, either a single level
        or a tuple of levels. Valid values: "beginner", "intermediate", "advanced".
    output_format : str, default="dataframe"
        Output format, passed to the plugin as ``output``. Valid values:
        "dataframe", "text", "html", "dict", "markdown".
    conjunction_separator : str, default=" AND "
        Separator used for conjunctive rules, which combine multiple feature
        conditions (e.g., "Glucose > 120 AND BMI > 28").
    align_weights : bool, default=True
        If True, vertically align weight columns in the narrative output.
    **kwargs : dict
        Additional keyword arguments passed to the narrative plugin.

    Returns
    -------
    pd.DataFrame or str or list of dict
        Whatever the plugin returns for the requested format:

        - "dataframe": pandas DataFrame with columns for each expertise level
        - "text": formatted text string with all narratives
        - "html": HTML table with all narratives
        - "dict": list of dictionaries, one per instance

    Notes
    -----
    Exceptions raised by the narrative plugin propagate unchanged.

    Examples
    --------
    >>> from calibrated_explanations import CalibratedExplainer
    >>> explainer = CalibratedExplainer(model, x_train, y_train)
    >>> explanations = explainer.explain_factual(x_test)
    >>> narratives = explanations.to_narrative(
    ...     expertise_level=("beginner", "advanced"),
    ...     output_format="dataframe"
    ... )
    >>> print(narratives)

    See Also
    --------
    :meth:`.plot` : Plot explanations with various visual styles.
    """
    # Imported lazily, only when a narrative is requested.
    from ..viz.narrative_plugin import NarrativePlotPlugin

    narrative_plugin = NarrativePlotPlugin(template_path=template_path)
    narratives = narrative_plugin.plot(
        self,
        template_path=template_path,
        expertise_level=expertise_level,
        output=output_format,
        conjunction_separator=conjunction_separator,
        align_weights=align_weights,
        **kwargs,
    )
    return narratives
def to_dataframe(self, *args, **kwargs):
    """Return the narrative output, defaulting to the DataFrame format.

    Forwards every argument to :meth:`to_narrative`, adding
    ``output_format='dataframe'`` when the caller did not pass
    ``output_format`` as a keyword. Accepts the same arguments as
    :meth:`to_narrative`.
    """
    if "output_format" not in kwargs:
        kwargs["output_format"] = "dataframe"
    return self.to_narrative(*args, **kwargs)
@staticmethod
def _deprecate_lime_shap_surface(
    symbol: str,
    replacement: str,
    *,
    removal_version: str,
) -> None:
    """Emit the Task-21 deprecation warning for a collection LIME/SHAP export helper.

    Parameters
    ----------
    symbol : str
        Name of the deprecated ``CalibratedExplanations`` member.
    replacement : str
        Text naming the API to use instead.
    removal_version : str
        Version by which the member is scheduled for removal.
    """
    message = (
        f"CalibratedExplanations.{symbol} is deprecated since v0.11.1; use "
        f"{replacement} instead. This API is scheduled for removal by {removal_version} "
        "under the pre-v1.0 zero-deprecation closure policy."
    )
    warning_key = f"CalibratedExplanations.{symbol}_lime_shap_deprecation"
    deprecate(message, key=warning_key, stacklevel=4)

# pylint: disable=protected-access
def as_lime(self, num_features_to_show=None):
    """Transform the explanations into LIME explanation objects.

    Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    num_features_to_show : int or None, default=None
        Number of ranked features kept per explanation. When None, the
        explainer's ``num_features`` is used.

    Returns
    -------
    list of lime.Explanation
        One object per explanation, each a deep copy of the preloaded LIME
        explanation filled in with that explanation's values.
    """
    self._deprecate_lime_shap_surface(
        "as_lime",
        "external_plugins.integrations.lime_pipeline.LimePipeline(...).explain(...)",
        removal_version="v0.11.3",
    )
    # Silence deprecation warnings raised while preloading the LIME template.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        _, lime_exp = self.calibrated_explainer.preload_lime()
    exp = []
    for explanation in self.explanations:
        # Each explanation gets its own copy so the shared template stays untouched.
        tmp = deepcopy(lime_exp)
        tmp.intercept[1] = 0
        tmp.local_pred = explanation.prediction["predict"]
        if "regression" in self.calibrated_explainer.mode:
            tmp.predicted_value = explanation.prediction["predict"]
            tmp.min_value = np.min(self.calibrated_explainer.y_cal)
            tmp.max_value = np.max(self.calibrated_explainer.y_cal)
        else:
            # Slot 1 holds the prediction itself, slot 0 its complement.
            tmp.predict_proba[0], tmp.predict_proba[1] = (
                1 - explanation.prediction["predict"],
                explanation.prediction["predict"],
            )

        feature_weights = explanation.feature_weights["predict"]
        num_to_show = (
            num_features_to_show
            if num_features_to_show is not None
            else self.calibrated_explainer.num_features
        )
        features_to_plot = explanation.rank_features(feature_weights, num_to_show=num_to_show)
        # Prefer the public ``define_conditions``; fall back to the private name.
        define_conditions = getattr(explanation, "define_conditions", None)
        if define_conditions is None:
            define_conditions = getattr(explanation, "_define_conditions", None)
        rules = define_conditions() if define_conditions is not None else []
        # Write the ranked features in reverse order into the label-1 slots.
        for j, f in enumerate(features_to_plot[::-1]):  # pylint: disable=invalid-name
            tmp.local_exp[1][j] = (f, feature_weights[f])
        # Drop any template entries beyond the requested number of features.
        del tmp.local_exp[1][num_to_show:]
        tmp.domain_mapper.discretized_feature_names = rules
        tmp.domain_mapper.feature_values = explanation.x_test
        exp.append(tmp)
    return exp
def as_shap(self):
    """Transform the explanations into a SHAP explanation object.

    Deprecated: a deprecation warning is emitted on every call.

    Returns
    -------
    shap.Explanation
        The preloaded SHAP explanation, resized to one row per explanation
        and filled with the negated ``"predict"`` feature weights; its
        ``data`` is set to ``self.x_test``.
    """
    self._deprecate_lime_shap_surface(
        "as_shap",
        "external_plugins.integrations.shap_pipeline.ShapPipeline(...).explain(...)",
        removal_version="v0.11.3",
    )
    # Silence deprecation warnings raised while preloading the SHAP template.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        _, shap_exp = self.calibrated_explainer.preload_shap()

    n_rows = len(self)
    shap_exp.base_values = np.resize(shap_exp.base_values, n_rows)
    n_cols = len(self.x_test[0, :])
    shap_exp.values = np.resize(shap_exp.values, (n_rows, n_cols))
    shap_exp.data = self.x_test

    for row, explanation in enumerate(self.explanations):
        for col in range(n_cols):
            # SHAP values carry the opposite sign of the stored feature weights.
            shap_exp.values[row][col] = -explanation.feature_weights["predict"][col]
    return shap_exp
class AlternativeExplanations(CalibratedExplanations):
    """A class for storing and visualizing alternative explanations.

    Inherits from :class:`.CalibratedExplanations` and provides methods specific
    to alternative explanations, such as filtering explanations by type.
    """

    def __run_on_each(self, method_name, copy, **options):
        """Apply the named per-explanation filter to every stored explanation.

        With a truthy `copy`, a copy of the collection is returned whose
        explanations are the results of calling the filter with ``copy=True``.
        Otherwise the filter is called with ``copy=False`` on each stored
        explanation, its return value is ignored, and ``self`` is returned.
        """
        if not copy:
            for member in self.explanations:
                getattr(member, method_name)(copy=False, **options)
            return self
        clone = self.copy()
        clone.explanations = [
            getattr(member, method_name)(copy=True, **options) for member in self.explanations
        ]
        return clone

    def super_explanations(self, only_ensured=False, include_potential=True, copy=True):
        """Return a collection with only super-explanations.

        Super-explanations are individual rules with higher probability that
        support the predicted class.

        Parameters
        ----------
        only_ensured : bool, default=False
            Determines whether to return only ensured explanations.
        include_potential : bool, default=True
            Determines whether to include potential explanations in the
            super-explanations.
        copy : bool, default=True
            Determines whether to return a copy of the explanations or modify
            them in place.

        Returns
        -------
        AlternativeExplanations
            A copy holding the filtered explanations, or ``self`` when `copy`
            is false.

        Notes
        -----
        Super-explanations are only available for `AlternativeExplanation`
        explanations.
        """
        return self.__run_on_each(
            "super_explanations",
            copy,
            only_ensured=only_ensured,
            include_potential=include_potential,
        )

    def super(self, only_ensured=False, include_potential=True, copy=True):
        """Shorthand delegator for :meth:`.super_explanations`."""
        return self.super_explanations(
            only_ensured=only_ensured, include_potential=include_potential, copy=copy
        )

    @classmethod
    def from_collection(cls, collection: "CalibratedExplanations"):
        """Create an AlternativeExplanations instance from an existing collection.

        This provides a safe public API for tests and callers that previously
        constructed an instance via low-level hacks like `__new__` and direct
        `__dict__` assignment. ``__init__`` is not called; attributes are
        copied one by one, and ``explanations`` and ``features_to_ignore``
        are copied into new lists.
        """
        inst = cls.__new__(cls)
        # Copy the public and necessary internal state conservatively.
        inst.calibrated_explainer = collection.calibrated_explainer
        for name in ("condition_source", "x_test", "y_threshold", "low_high_percentiles"):
            setattr(inst, name, getattr(collection, name, None))
        inst.explanations = list(getattr(collection, "explanations", []))

        inst.start_index = getattr(collection, "start_index", 0)
        inst.current_index = getattr(collection, "current_index", inst.start_index)
        # The fallback is the number of test rows, or 0 without test data.
        fallback_end = len(inst.x_test[:, 0]) if inst.x_test is not None else 0
        inst.end_index = getattr(collection, "end_index", fallback_end)

        for name in ("bins", "total_explain_time"):
            setattr(inst, name, getattr(collection, name, None))
        inst.features_to_ignore = list(getattr(collection, "features_to_ignore", []))
        inst.feature_filter_per_instance_ignore = getattr(
            collection, "feature_filter_per_instance_ignore", None
        )

        # Preserve caches if present
        for cache_name in (
            "_feature_names_cache",
            "_predictions_cache",
            "_probabilities_cache",
            "_lower_cache",
            "_upper_cache",
            "_class_labels_cache",
        ):
            setattr(inst, cache_name, getattr(collection, cache_name, None))
        return inst

    def semi_explanations(self, only_ensured=False, include_potential=True, copy=True):
        """Return a collection with only semi-explanations.

        Semi-explanations are individual rules with lower probability that
        support the predicted class.

        Parameters
        ----------
        only_ensured : bool, default=False
            Determines whether to return only ensured explanations.
        include_potential : bool, default=True
            Determines whether to include potential explanations in the
            semi-explanations.
        copy : bool, default=True
            Determines whether to return a copy of the explanations or modify
            them in place.

        Returns
        -------
        AlternativeExplanations
            A copy holding the filtered explanations, or ``self`` when `copy`
            is false.

        Notes
        -----
        Semi-explanations are only available for `AlternativeExplanation`
        explanations.
        """
        return self.__run_on_each(
            "semi_explanations",
            copy,
            only_ensured=only_ensured,
            include_potential=include_potential,
        )

    def semi(self, only_ensured=False, include_potential=True, copy=True):
        """Shorthand delegator for :meth:`.semi_explanations`."""
        return self.semi_explanations(
            only_ensured=only_ensured, include_potential=include_potential, copy=copy
        )

    def counter_explanations(self, only_ensured=False, include_potential=True, copy=True):
        """Return a collection with only counter-explanations.

        Counter-explanations are individual rules that do not support the
        predicted class.

        Parameters
        ----------
        only_ensured : bool, default=False
            Determines whether to return only ensured explanations.
        include_potential : bool, default=True
            Determines whether to include potential explanations in the
            counter-explanations.
        copy : bool, default=True
            Determines whether to return a copy of the explanations or modify
            them in place.

        Returns
        -------
        AlternativeExplanations
            A copy holding the filtered explanations, or ``self`` when `copy`
            is false.

        Notes
        -----
        Counter-explanations are only available for `AlternativeExplanation`
        explanations.
        """
        return self.__run_on_each(
            "counter_explanations",
            copy,
            only_ensured=only_ensured,
            include_potential=include_potential,
        )

    def counter(self, only_ensured=False, include_potential=True, copy=True):
        """Shorthand delegator for :meth:`.counter_explanations`."""
        return self.counter_explanations(
            only_ensured=only_ensured, include_potential=include_potential, copy=copy
        )

    def ensured_explanations(self, include_potential=True, copy=True):
        """Return a collection with only ensured explanations.

        Ensured explanations are individual rules that have a narrower
        uncertainty interval.

        Parameters
        ----------
        include_potential : bool, default=True
            Determines whether to include potential explanations in the
            ensured explanations.
        copy : bool, default=True
            Determines whether to return a copy of the explanations or modify
            them in place.

        Returns
        -------
        AlternativeExplanations
            A copy holding the filtered explanations, or ``self`` when `copy`
            is false.
        """
        return self.__run_on_each(
            "ensured_explanations", copy, include_potential=include_potential
        )

    def ensured(self, include_potential=True, copy=True):
        """Shorthand delegator for :meth:`.ensured_explanations`."""
        return self.ensured_explanations(include_potential=include_potential, copy=copy)

    def pareto_explanations(
        self,
        include_potential: bool = True,
        copy: bool = True,
        *,
        pareto_cost: str = "uncertainty_width",
    ):
        """Return a collection with only output-envelope Pareto alternatives.

        Parameters
        ----------
        include_potential : bool, default=True
            Determines whether to include potential explanations before
            extracting the Pareto frontier.
        copy : bool, default=True
            Determines whether to return a copy of the explanations or modify
            them in place.
        pareto_cost : str, default="uncertainty_width"
            The Pareto cost dimension minimized along the output axis.

        Returns
        -------
        AlternativeExplanations
            A copy holding the Pareto-front alternatives, or ``self`` when
            `copy` is false.
        """
        return self.__run_on_each(
            "pareto_explanations",
            copy,
            include_potential=include_potential,
            pareto_cost=pareto_cost,
        )

    def pareto(
        self,
        include_potential: bool = True,
        copy: bool = True,
        *,
        pareto_cost: str = "uncertainty_width",
    ):
        """Shorthand delegator for :meth:`.pareto_explanations`."""
        return self.pareto_explanations(
            include_potential=include_potential,
            copy=copy,
            pareto_cost=pareto_cost,
        )
class FrozenCalibratedExplainer:
    """Read-only facade around a calibrated explainer.

    The wrapper stores a single private attribute, ``_explainer``, and exposes
    selected attributes of it through read-only properties plus a forwarding
    :meth:`predict`. Assigning any other attribute on the wrapper raises
    ``AttributeError`` (see :meth:`__setattr__`).

    Notes
    -----
    The constructor tries to ``deepcopy`` the explainer. If that copy fails,
    the wrapper keeps a reference to the *original* object instead, so in that
    case changes made to the original elsewhere stay visible through the
    wrapper. Values returned by the properties are handed out as-is; no copies
    are made here.
    """

    def __init__(self, explainer):
        """Wrap ``explainer``, preferring a private deep copy.

        Parameters
        ----------
        explainer : CalibratedExplainer
            The explainer to be wrapped.
        """
        try:
            self._explainer = deepcopy(explainer)
        except (
            Exception
        ):  # adr002_allow # pragma: no cover - defensive fallback for unpickleable state
            # Deepcopy of complex explainer objects can fail; log at DEBUG
            # instead of emitting a RuntimeWarning to avoid noisy test output.
            try:
                import logging

                logging.getLogger(__name__).debug(
                    "Deepcopy of explainer failed; using original instance for frozen wrapper"
                )
            except Exception:  # adr002_allow
                # If logging fails, fall back to warnings to preserve behavior
                warnings.warn(
                    "Deepcopy of explainer failed; using original instance for frozen wrapper",
                    UserWarning,
                    stacklevel=2,
                )
            # Fallback: share the caller's instance (no isolation in this case).
            self._explainer = explainer

    @property
    def explainer(self):
        """Return the wrapped explainer instance."""
        return self._explainer

    @property
    def x_cal(self):
        """
        Return ``x_cal`` of the wrapped explainer.

        Returns
        -------
        numpy.ndarray: The calibrated feature matrix.
        """
        return self._explainer.x_cal

    @property
    def y_cal(self):
        """
        Return ``y_cal`` of the wrapped explainer.

        Returns
        -------
        numpy.ndarray: The calibrated target values.
        """
        return self._explainer.y_cal

    @property
    def num_features(self):
        """
        Return ``num_features`` of the wrapped explainer.

        Returns
        -------
        int: The number of features in the dataset.
        """
        return self._explainer.num_features

    @property
    def categorical_features(self):
        """
        Return ``categorical_features`` of the wrapped explainer.

        Returns
        -------
        list: The indices of categorical features.
        """
        return self._explainer.categorical_features

    @property
    def categorical_labels(self):
        """
        Return ``categorical_labels`` of the wrapped explainer.

        Returns
        -------
        list: The labels for categorical features.
        """
        return self._explainer.categorical_labels

    @property
    def feature_values(self):
        """
        Return ``feature_values`` of the wrapped explainer.

        Returns
        -------
        list: The unique values for each feature.
        """
        return self._explainer.feature_values

    @property
    def feature_names(self):
        """
        Return ``feature_names`` of the wrapped explainer.

        Returns
        -------
        list: The names of the features.
        """
        return self._explainer.feature_names

    @property
    def class_labels(self):
        """
        Return ``class_labels`` of the wrapped explainer.

        Returns
        -------
        list: The labels for the classes.
        """
        return self._explainer.class_labels

    @property
    def sample_percentiles(self):
        """
        Return ``sample_percentiles`` of the wrapped explainer.

        Returns
        -------
        list: The sample percentiles as a list.
        """
        return self._explainer.sample_percentiles

    @property
    def mode(self):
        """
        Return ``mode`` of the wrapped explainer.

        Returns
        -------
        str: The mode of the explainer.
        """
        return self._explainer.mode

    @property
    def is_multiclass(self):
        """
        Return ``is_multiclass`` of the wrapped explainer.

        Returns
        -------
        bool: True if the problem is multiclass, False otherwise.
        """
        return self._explainer.is_multiclass

    @property
    def discretizer(self):
        """
        Return ``discretizer`` of the wrapped explainer.

        Returns
        -------
        Discretizer: The discretizer used by the explainer.
        """
        return self._explainer.discretizer

    @property
    def discretize(self):
        """Public accessor for the discretize function (testing helper)."""
        return self._explainer.discretize

    @property
    def rule_boundaries(self):
        """Expose the underlying rule boundaries helper."""
        return self._explainer.rule_boundaries

    @property
    def learner(self):
        """
        Return ``learner`` of the wrapped explainer.

        Returns
        -------
        object: The learner associated with the explainer.
        """
        return self._explainer.learner

    @property
    def difficulty_estimator(self):
        """
        Return ``difficulty_estimator`` of the wrapped explainer.

        Returns
        -------
        object: The estimator for difficulty levels.
        """
        return self._explainer.difficulty_estimator

    @property
    def prediction_orchestrator(self):
        """Expose the underlying prediction orchestrator (read-only)."""
        return self._explainer.prediction_orchestrator

    def predict(self, *args, **kwargs):
        """Forward the public prediction API to the underlying explainer.

        All positional and keyword arguments are passed through unchanged and
        the wrapped explainer's return value is returned as-is.
        """
        return self._explainer.predict(*args, **kwargs)

    @property
    def _preload_lime(self):
        """
        Private alias of :attr:`preload_lime`.

        Reads the *public* ``preload_lime`` attribute of the wrapped explainer,
        i.e. it returns the same object as :attr:`preload_lime`.

        Returns
        -------
        function: The preload_lime function used by the explainer.
        """
        return self._explainer.preload_lime

    @property
    def preload_lime(self):
        """Public accessor for the lime preload helper (testing helper)."""
        return self._explainer.preload_lime

    @property
    def preload_shap(self):
        """Public accessor for the shap preload helper (testing helper)."""
        return self._explainer.preload_shap

    def __setattr__(self, key, value):
        """Reject attribute assignment on the wrapper.

        Only the ``_explainer`` slot may be assigned, which ``__init__`` needs;
        every other attribute name is refused.

        Raises
        ------
        AttributeError
            If ``key`` is anything other than ``"_explainer"``.
        """
        if key == "_explainer":
            super().__setattr__(key, value)
        else:
            raise AttributeError("Cannot modify frozen instance")


class MultiClassCalibratedExplanations(CalibratedExplanations):
    """
    A class for storing and visualizing calibrated explanations for multi-class classification.

    This class extends `CalibratedExplanations` to support multi-class explanations,
    allowing storage and retrieval of explanations per instance using a dictionary.
""" def __init__(self, calibrated_explainer, x_test, bins, num_classes, explanations=None): """Initialize multiclass explanation storage for one or more instances.""" x_test = validate_and_prepare_input(calibrated_explainer, x_test) super().__init__(calibrated_explainer, x_test, None, bins) self.num_classes = num_classes if explanations is None: self.explanations = [{} for _ in range(len(x_test))] else: self.explanations = deepcopy(explanations) def _first_explanation_for_instance(self, index): """Return the first explanation stored for an instance regardless of class key.""" if index < 0 or index >= len(self.explanations): return None instance_explanations = self.explanations[index] if not instance_explanations: return None return next(iter(instance_explanations.values())) @property def X_test(self): # noqa: N802 """Backward-compatible alias for x_test.""" return self.x_test def __repr__(self): """Return the string representation of the MultiClassCalibratedExplanations object.""" explanations_str = ( "\n" + f"MultiClassCalibratedExplanations({len(self.explanations)} explanations):\n" ) first_explanation = self._first_explanation_for_instance(0) if first_explanation is None: return explanations_str labels = first_explanation.get_class_labels() for i in range(len(self.explanations)): explanations_str += f"explanation({i}):\n" for class_key, label in labels.items(): label_explanation = self.__getitem__((i, class_key)) explanations_str += f"explanation for label({label}):\n" explanations_str += str(label_explanation) return explanations_str def __getitem__(self, key): """ Return the explanation for the given key. If key is an integer, return all class labels explanations at that index as MultiClassCalibratedExplanations. If key is a tuple (index, class_idx), return the explanation for a specific class label as FactualExplanation. """ if isinstance(key, int): # Mirror CalibratedExplanations semantics: integer indexing returns # a single-instance view. 
x_single = np.atleast_2d(self.x_test[key]) return MultiClassCalibratedExplanations( self.calibrated_explainer, x_single, self.bins, self.num_classes, [self.explanations[key]], ) if isinstance(key, slice): return MultiClassCalibratedExplanations( self.calibrated_explainer, self.x_test[key], self.bins, self.num_classes, self.explanations[key], ) if isinstance(key, (list, np.ndarray)): arr = np.asarray(key) if arr.dtype == bool: if len(arr) != len(self.explanations): raise IndexError( "Boolean index length must match number of explanations in collection." ) indices = np.where(arr)[0] else: indices = np.asarray(arr, dtype=int) selected_explanations = [self.explanations[int(i)] for i in indices] return MultiClassCalibratedExplanations( self.calibrated_explainer, self.x_test[indices], self.bins, self.num_classes, selected_explanations, ) elif isinstance(key, tuple) and len(key) == 2: # Return Factual explanation of only one class label explanation index, class_idx = key # Accept both Python ints and numpy integer types if isinstance(class_idx, (int, np.integer)): return self.explanations[index].get(int(class_idx), None) elif isinstance(class_idx, str): first_explanation = self._first_explanation_for_instance(index) if first_explanation is None: return None labels = first_explanation.get_class_labels() try: class_idx = list(labels.keys())[list(labels.values()).index(class_idx)] except ValueError as exc: raise KeyError(f"Unknown class label '{class_idx}' for index {index}.") from exc return self.explanations[index].get(int(class_idx), None) raise ValidationError("Invalid argument type. 
Use an index (int) or (index, class) tuple.") def get_explanation(self, index, class_idx=None): """Return explanation(s) at ``index``, optionally narrowed to ``class_idx``.""" if class_idx is None: return self[index] return self[(index, class_idx)] # ------------------------------------------------------------------ # Multiclass-specific overrides (dispatch into per-class dicts) # ------------------------------------------------------------------ def __iter__(self): """Iterate yielding single-instance views (align with base semantics).""" for i in range(len(self.explanations)): yield self[i] @classmethod def from_json(cls, payload: Mapping[str, Any]) -> ExportedMultiClassExplanationCollection: """Materialise grouped multiclass explanations from exported JSON payload. Raises ------ ValidationError If top-level or item-level schema versions are missing/unsupported, or required multiclass keys cannot be restored. """ from ..serialization import from_json as _explanation_from_json expected_schema = "1.0.0" schema_version = payload.get("schema_version") if schema_version != expected_schema: raise ValidationError( "Unsupported multiclass payload schema version.", details={"expected": expected_schema, "received": schema_version}, ) explanations_blob = payload.get("explanations", []) if not isinstance(explanations_blob, list): raise ValidationError( "Multiclass payload explanations must be a list.", details={"type": type(explanations_blob).__name__}, ) grouped: dict[int, dict[int, DomainExplanation]] = {} for item in explanations_blob: if not isinstance(item, Mapping): raise ValidationError( "Each multiclass explanation item must be a mapping.", details={"type": type(item).__name__}, ) item_schema = item.get("schema_version") if item_schema != expected_schema: raise ValidationError( "Unsupported multiclass explanation item schema version.", details={"expected": expected_schema, "received": item_schema}, ) try: instance_index = int(item.get("index")) except (TypeError, 
ValueError, OverflowError) as exc: raise ValidationError( "Multiclass explanation item is missing a valid instance index.", details={"index": item.get("index")}, ) from exc metadata_map = item.get("metadata") metadata_dict = metadata_map if isinstance(metadata_map, Mapping) else {} class_index_raw = item.get("class_index", metadata_dict.get("class_index")) class_label = item.get("class_label", metadata_dict.get("class_label")) if class_index_raw is None: raise ValidationError( "Multiclass explanation item is missing class_index.", details={"index": instance_index}, ) try: class_index = int(class_index_raw) except (TypeError, ValueError, OverflowError) as exc: raise ValidationError( "Multiclass explanation item has invalid class_index.", details={"index": instance_index, "class_index": class_index_raw}, ) from exc domain_exp = _explanation_from_json(item) metadata_out = ( dict(domain_exp.metadata) if isinstance(domain_exp.metadata, Mapping) else {} ) metadata_out["class_index"] = class_index if class_label is not None: metadata_out["class_label"] = class_label domain_exp.metadata = metadata_out or None per_instance = grouped.setdefault(instance_index, {}) if class_index in per_instance: raise ValidationError( "Duplicate class_index for multiclass explanation item.", details={"index": instance_index, "class_index": class_index}, ) per_instance[class_index] = domain_exp ordered = tuple(grouped[idx] for idx in sorted(grouped)) metadata = payload.get("collection", {}) or {} return ExportedMultiClassExplanationCollection( metadata=cast(Mapping[str, Any], _jsonify(metadata)), explanations_by_instance=ordered, ) def add_conjunctions(self, n_top_features=5, max_rule_size=2, **kwargs): """Apply add_conjunctions to every class-specific explanation.""" for class_dict in self.explanations: for explanation in class_dict.values(): explanation.add_conjunctions(n_top_features, max_rule_size, **kwargs) return self def remove_conjunctions(self): """Apply remove_conjunctions to every 
class-specific explanation.""" for class_dict in self.explanations: for explanation in class_dict.values(): explanation.remove_conjunctions() return self def reset(self): """Reset each class-specific explanation to original state.""" for class_dict in self.explanations: for explanation in class_dict.values(): explanation.reset() return self def filter_rule_sizes( self, *, rule_sizes: Optional[Any] = None, size_range: Optional[Tuple[int, int]] = None, copy: bool = True, ): """Filter rules by size across every class-specific explanation.""" if copy: new_obj = self.copy() new_obj.explanations = [ { k: exp.filter_rule_sizes( rule_sizes=rule_sizes, size_range=size_range, copy=True ) for k, exp in class_dict.items() } for class_dict in self.explanations ] return new_obj for idx, class_dict in enumerate(self.explanations): for cls_key, explanation in class_dict.items(): self.explanations[idx][cls_key] = explanation.filter_rule_sizes( rule_sizes=rule_sizes, size_range=size_range, copy=False ) return self def filter_features(self, *, exclude_features=None, include_features=None, copy: bool = True): """Filter features across every class-specific explanation.""" if copy: new_obj = self.copy() new_obj.explanations = [ { k: exp.filter_features( exclude_features=exclude_features, include_features=include_features, copy=True, ) for k, exp in class_dict.items() } for class_dict in self.explanations ] return new_obj for idx, class_dict in enumerate(self.explanations): for cls_key, explanation in class_dict.items(): self.explanations[idx][cls_key] = explanation.filter_features( exclude_features=exclude_features, include_features=include_features, copy=False ) return self def get_rules(self): """Return per-instance, per-class rule payloads. Returns ------- list of dict Each item is a mapping {class_key: rules_payload} for that instance. 
""" return [ {cls_key: exp.get_rules() for cls_key, exp in class_dict.items()} for class_dict in self.explanations ] # Safe adapters / explicit not-implemented for adapters that assume flat lists def as_lime(self): """Raise for multiclass collections where a flat LIME export is undefined.""" self._deprecate_lime_shap_surface( "as_lime", "external_plugins.integrations.lime_pipeline.LimePipeline(...).explain(...)", removal_version="v0.11.3", ) raise NotImplementedError( "as_lime() is not supported for multi-label collections. " "Call get_explanation(i, cls).as_lime() for a specific class, or iterate over the collection " "to build a per-class LIME mapping. If you need an aggregated LIME export, convert each per-class " "explanation via get_explanation(i, cls).as_lime() and combine the results in your caller." ) def as_shap(self): """Raise for multiclass collections where a flat SHAP export is undefined.""" self._deprecate_lime_shap_surface( "as_shap", "external_plugins.integrations.shap_pipeline.ShapPipeline(...).explain(...)", removal_version="v0.11.3", ) raise NotImplementedError( "as_shap() is not supported for multi-label collections. " "Call get_explanation(i, cls).as_shap() for a specific class, or iterate and aggregate per-class SHAP outputs. " "Aggregating SHAP across classes is application-specific; prefer per-class SHAP objects for downstream use." ) def to_narrative(self, *args, **kwargs): """ Generate narratives for a multiclass (multi-label) collection. The method returns per-instance, per-class narratives. The behaviour depends on ``output_format`` (same semantics as single-instance :meth:`to_narrative`): - ``output_format='dict'``: returns ``List[Dict[class_key, narrative_dict]]`` where each item corresponds to an instance and maps class keys to the narrative dict for that class. - ``output_format='text'``: returns a single combined text containing the narratives for every instance and class (human-readable). 
- ``output_format='dataframe'``: returns a pandas DataFrame with columns ``['instance', 'class', 'narrative']`` (requires pandas). For other formats (e.g., 'html', 'markdown') the implementation will attempt to coerce per-class outputs into the requested format where reasonable. """ # Normalize kwargs used by the single-explanation API template_path = kwargs.pop("template_path", args[0] if len(args) > 0 else "exp.yaml") expertise_level = kwargs.pop( "expertise_level", kwargs.get("expertise_level", ("beginner", "advanced")) ) output_format = kwargs.pop( "output_format", kwargs.get("output_format", kwargs.get("output", "dataframe")) ) conjunction_separator = kwargs.pop( "conjunction_separator", kwargs.get("conjunction_separator", " AND ") ) align_weights = kwargs.pop("align_weights", kwargs.get("align_weights", True)) # Helper to convert a per-class explanation to the desired intermediate dict per_instance = [] for _i, class_dict in enumerate(self.explanations): inst_map = {} for cls_key, explanation in class_dict.items(): try: narr = explanation.to_narrative( template_path=template_path, expertise_level=expertise_level, output_format="dict", conjunction_separator=conjunction_separator, align_weights=align_weights, **kwargs, ) except (AttributeError, TypeError, ValueError, KeyError): # Fallback: try to obtain text output narr = { "text": explanation.to_narrative( template_path=template_path, expertise_level=expertise_level, output_format="text", conjunction_separator=conjunction_separator, align_weights=align_weights, **kwargs, ) } inst_map[int(cls_key)] = narr per_instance.append(inst_map) # Return according to requested format if output_format == "dict": return per_instance if output_format == "text": parts = [] for i, inst_map in enumerate(per_instance): parts.append(f"Instance {i}:") for cls_key, narr in inst_map.items(): label = None first_exp = self._first_explanation_for_instance(i) if first_exp is not None: labels = first_exp.get_class_labels() label = 
labels.get(cls_key, None) hdr = f" Class {cls_key}" + (f" ({label})" if label is not None else "") parts.append(hdr) if isinstance(narr, dict): text = narr.get("text") or narr.get("short") or str(narr) else: text = str(narr) parts.append(text) parts.append("") return "\n".join(parts) if output_format == "dataframe": try: import pandas as pd except ImportError as exc: # pragma: no cover - pandas import error path raise ImportError("pandas is required for output_format='dataframe'") from exc rows = [] for i, inst_map in enumerate(per_instance): for cls_key, narr in inst_map.items(): # narr is a dict produced by single-explanation output_format='dict' # Attempt to extract a compact textual narrative for a 'narrative' column if isinstance(narr, dict): text = narr.get("text") or narr.get("short") or str(narr) else: text = str(narr) rows.append({"instance": i, "class": int(cls_key), "narrative": text}) df = pd.DataFrame(rows) return df # Fall back to returning the dict structure for unknown formats return per_instance def to_json(self, *, include_version: bool = True) -> Mapping[str, Any]: """Return a JSON-friendly payload describing this multiclass collection. This mirrors :meth:`CalibratedExplanations.to_json` but emits one exported explanation per (instance, class) pair. Each legacy payload is augmented with ``class_index`` and, when available, ``class_label``. 
""" from ..serialization import to_json as _explanation_to_json instances = [] for idx, class_dict in enumerate(self.explanations): for cls_key, exp in class_dict.items(): # Build legacy-shaped payload and annotate with class info payload = dict(self._legacy_payload(exp)) payload["class_index"] = int(cls_key) try: first = self._first_explanation_for_instance(idx) if first is not None: labels = first.get_class_labels() payload.setdefault("class_label", labels.get(int(cls_key))) except (AttributeError, TypeError, ValueError, KeyError): _LOGGER.debug( "Failed to resolve class_label while exporting multiclass payload", exc_info=True, ) domain = legacy_to_domain(int(idx), payload) provenance = getattr(exp, "provenance", None) metadata = getattr(exp, "metadata", None) if provenance is not None: domain.provenance = cast(Optional[Mapping[str, Any]], _jsonify(provenance)) if metadata is not None: domain.metadata = cast(Optional[Mapping[str, Any]], _jsonify(metadata)) instances.append(_explanation_to_json(domain, include_version=include_version)) payload: dict[str, Any] = { "collection": self._collection_metadata(), "explanations": instances, } if include_version: payload.setdefault("schema_version", "1.0.0") return payload def to_json_stream(self, *, chunk_size: int = 256, format: str = "jsonl"): """Stream the multiclass collection as JSON. Yields the same fragments as :meth:`CalibratedExplanations.to_json_stream` but emits one item per (instance, class) pair. 
""" from ..serialization import to_json as _explanation_to_json if format not in {"jsonl", "chunked"}: raise ValidationError("Unsupported stream format", details={"format": format}) start = time() tracemalloc.start() metadata = dict(self._collection_metadata()) meta_fragment = {"collection": metadata, "schema_version": "1.0.0"} yield json.dumps(meta_fragment, default=_jsonify) chunk: List[str] = [] n = 0 for idx, class_dict in enumerate(self.explanations): for cls_key, exp in class_dict.items(): payload = dict(self._legacy_payload(exp)) payload["class_index"] = int(cls_key) try: first = self._first_explanation_for_instance(idx) if first is not None: labels = first.get_class_labels() payload.setdefault("class_label", labels.get(int(cls_key))) except (AttributeError, TypeError, ValueError, KeyError): _LOGGER.debug( "Failed to resolve class_label while streaming multiclass payload", exc_info=True, ) domain = legacy_to_domain(int(idx), payload) provenance = getattr(exp, "provenance", None) metadata_exp = getattr(exp, "metadata", None) if provenance is not None: domain.provenance = cast(Optional[Mapping[str, Any]], _jsonify(provenance)) if metadata_exp is not None: domain.metadata = cast(Optional[Mapping[str, Any]], _jsonify(metadata_exp)) item = _explanation_to_json(domain, include_version=True) line = json.dumps(item, default=_jsonify) n += 1 if format == "jsonl": yield line else: # chunked chunk.append(line) if len(chunk) >= chunk_size: yield "[" + ",".join(chunk) + "]" chunk = [] if format == "chunked" and chunk: yield "[" + ",".join(chunk) + "]" peak = tracemalloc.get_traced_memory()[1] tracemalloc.stop() elapsed = time() - start telemetry = { "export_rows": n, "chunk_size": chunk_size, "mode": getattr(self.calibrated_explainer, "mode", None), "peak_memory_mb": round(float(peak) / (1024 * 1024), 3), "elapsed_seconds": round(float(elapsed), 3), "schema_version": "1.0.0", "build_id": None, "feature_flags": None, } try: metadata.setdefault("export_telemetry", {}) 
metadata["export_telemetry"].update(telemetry) underlying = getattr(self.calibrated_explainer, "_explainer", None) if underlying is not None: try: last = getattr(underlying, "_last_telemetry", None) or {} last.update({"export": telemetry}) underlying._last_telemetry = last except Exception: # adr002_allow _LOGGER.info( "failed to attach export telemetry to underlying explainer", exc_info=True, ) except Exception: # adr002_allow _LOGGER.info("failed to attach export telemetry to collection", exc_info=True) yield json.dumps({"export_telemetry": telemetry}, default=_jsonify) # Properties that aggregate per-class values into per-instance dicts @property def predictions(self): """Return per-instance per-class scalar predictions as a list of dicts.""" return [ {int(cls_key): getattr(exp, "predict", None) for cls_key, exp in class_dict.items()} for class_dict in self.explanations ] @property def prediction_interval(self): """Return per-instance per-class prediction intervals as a list of dicts.""" return [ { int(cls_key): getattr(exp, "prediction_interval", (None, None)) for cls_key, exp in class_dict.items() } for class_dict in self.explanations ] @property def probabilities(self): """Return per-instance per-class probability vectors as a list of dicts. Each dict maps class_key -> the stored `prediction_probabilities` (if present) or None when unavailable. 
""" return [ { int(cls_key): getattr(exp, "prediction_probabilities", None) for cls_key, exp in class_dict.items() } for class_dict in self.explanations ] def plot( self, index=None, class_idx=None, filter_top=10, show=True, filename="", uncertainty=False, style="regular", **kwargs, ): """Plot multiclass explanations as factual or alternative views.""" if len(self.explanations) > 0: first_explanation = self._first_explanation_for_instance(0) if isinstance(first_explanation, FactualExplanation): self.plot_factual( index=index, class_idx=class_idx, filter_top=filter_top, show=show, filename=filename, uncertainty=uncertainty, style=style, **kwargs, ) elif isinstance(first_explanation, AlternativeExplanation): self.plot_alternative( index=index, class_idx=class_idx, filter_top=filter_top, show=show, filename=filename, uncertainty=uncertainty, style=style, **kwargs, ) else: warnings.warn("No explanations found", stacklevel=2) def plot_alternative( self, index=None, class_idx=None, filter_top=10, show=True, filename="", uncertainty=False, style="regular", **kwargs, ): """ Plot explanations for a given instance and class. If no class is specified, plots explanations for all classes at that index. 
""" style_override = kwargs.get("style_override", get_multiclass_config()) if index is not None: if class_idx is not None: explanation = self[index, class_idx] if explanation: explanation.plot( filter_top=filter_top, show=show, filename=filename, uncertainty=uncertainty, style=style, ) else: warnings.warn( f"No explanation found for instance {index}, class {class_idx}", stacklevel=2, ) else: self.__getitem__(index).plot( filter_top=filter_top, show=show, filename=filename, uncertainty=uncertainty, style=style, ) else: import matplotlib.colors as mcolors rgb = np.array(list(permutations(range(0, 256, 11), 3))) / 255.0 colors = [rgb.tolist()[i * 23] for i in range(25)] colors = list(mcolors.BASE_COLORS.values()) for i, class_explanations in enumerate(self.explanations): # Ensure style_override gets passed through class_explanations_list = list(class_explanations.values()) # Respect the explicit arguments passed to plot_alternative() # (do not override via kwargs in this multi-label/all-classes branch). iter_filename = filename iter_show = show rnk_metric = kwargs.get("rnk_metric", "ensured") if rnk_metric is None: rnk_metric = "ensured" rnk_weight = kwargs.get("rnk_weight", 0.5) if rnk_metric == "uncertainty": rnk_weight = 1.0 rnk_metric = "ensured" alternatives = [] for ex in list(class_explanations.values()): get_rules = getattr(ex, "get_rules", None) if callable(get_rules): alternatives.append(get_rules()) else: alternatives.append(ex._get_rules()) # Ensure each explanation has a sensible `index` set before # precondition checks or plotting. Some explanation objects # may be frozen; use best-effort assignment. 
for ex in class_explanations_list: with contextlib.suppress(Exception): ex.index = i with contextlib.suppress(Exception): ex._check_preconditions() predicts = [getattr(ex, "prediction", None) for ex in class_explanations_list] filter_top = [len(alternative["rule"]) for alternative in alternatives] """if filter_top is None: filter_top = num_features_to_show_list else: filter_top = [filter_top for factual in factuals] filter_top = [np.min([num_features_to_show, filter_]) for num_features_to_show, filter_ in zip(num_features_to_show_list,filter_top)]""" if len(filter_top) <= 0: warnings.warn( f"The explanation has no rules to plot. The index of the instance is {i}", stacklevel=2, ) return if len(iter_filename) > 0: path, iter_filename, title, ext = prepare_for_saving(iter_filename) path = f"plots/{path}" save_ext = [ext] else: path = "" title = "" save_ext = [] feature_predicts = [ { "predict": alternative["predict"], "low": alternative["predict_low"], "high": alternative["predict_high"], "classes": alternative["classes"], } for alternative in alternatives ] widths = [ np.reshape( np.array(alternative["weight_high"]) - np.array(alternative["weight_low"]), (len(alternative["weight"])), ) for alternative in alternatives ] features_weights = [ np.reshape(alternative["weight"], (len(alternative["weight"]))) for alternative in alternatives ] def _rank_features_for_multiclass(explanation, *args, **rank_kwargs): rank_fn = getattr(explanation, "rank_features", None) if callable(rank_fn): return rank_fn(*args, **rank_kwargs) return explanation._rank_features(*args, **rank_kwargs) if rnk_metric == "feature_weight": features_list_to_plot = [ _rank_features_for_multiclass( ex, feature_weights, width=width, num_to_show=num_to_show ) for ex, feature_weights, width, num_to_show in zip( list(class_explanations.values()), features_weights, widths, filter_top, strict=False, ) ] else: predictions = [ alternative["predict"] if predict["predict"] > 0.5 else [1 - p for p in 
alternative["predict"]] for alternative, predict in zip(alternatives, predicts, strict=False) ] rankings = [ calculate_metrics( uncertainty=[ alternative["predict_high"][i] - alternative["predict_low"][i] for i in range(len(alternative["rule"])) ], prediction=prediction, w=rnk_weight, metric=rnk_metric, ) for alternative, prediction in zip(alternatives, predictions, strict=False) ] features_list_to_plot = [ _rank_features_for_multiclass(ex, width=ranking, num_to_show=num_to_show) for ex, ranking, num_to_show in zip( list(class_explanations.values()), rankings, filter_top, strict=False ) ] #################### if "style" in kwargs and kwargs["style"] == "triangular": raise ValidationError( "triangular style does not support multi labels explanation, please set multi_explanation to None and try again!." ) """probas = [predict["predict"] for predict in predicts] uncertainties = [np.abs(predict["high"] - predict["low"]) for predict in predicts] rule_probas = [alternative["predict"] for alternative in alternatives] rule_uncertainties = [np.abs( np.array(alternative["predict_high"]) - np.array(alternative["predict_low"]) ) for alternative in alternatives] # Use list comprehension or NumPy array indexing to select elements selected_rule_probas = [[rule_proba[i] for i in features_to_plot] \ for rule_proba, features_to_plot in zip(rule_probas, features_list_to_plot)] selected_rule_uncertainties = [[rule_uncertainty[i] for i in features_to_plot] \ for rule_uncertainty, features_to_plot in zip(rule_uncertainties, features_list_to_plot)] _plot_triangular( self, proba, uncertainty, selected_rule_proba, selected_rule_uncertainty, num_to_show_, title=title, path=path, show=show, save_ext=save_ext, style_override=style_override, )""" return alternatives_values = [alternative["value"] for alternative in alternatives] column_names_list = [alternative["rule"] for alternative in alternatives] _plot_alternative_dict( list(class_explanations.values()), alternatives_values, predicts, 
feature_predicts, features_list_to_plot, num_to_show_list=filter_top, colors=colors, column_names_list=column_names_list, title=title, path=path, show=iter_show, save_ext=save_ext, style_override=style_override, idx=i, ) def merge_rules(self, factuals): # pragma: no cover # dead code: zero callers """Merge rule dictionaries from multiple class-specific factual explanations.""" merged_factuals = { "base_predict": [], "base_predict_low": [], "base_predict_high": [], "predict": [], "predict_low": [], "predict_high": [], "weight": [], "weight_low": [], "weight_high": [], "value": [], "rule": [], "feature": [], "feature_value": [], "is_conjunctive": [], "classes": [], } for _i, factual in enumerate(factuals): # pylint: disable=invalid-name base_predicts = [factual["base_predict"][0] for _ in range(len(factual["rule"]))] base_predict_low = [factual["base_predict_low"][0] for _ in range(len(factual["rule"]))] base_predict_high = [ factual["base_predict_high"][0] for _ in range(len(factual["rule"])) ] classes = [factual["classes"] for _ in range(len(factual["rule"]))] merged_factuals["base_predict"].extend(base_predicts) merged_factuals["base_predict_low"].extend(base_predict_low) merged_factuals["base_predict_high"].extend(base_predict_high) merged_factuals["classes"].extend(classes) merged_factuals["predict"].extend(factual["predict"]) merged_factuals["predict_low"].extend(factual["predict_low"]) merged_factuals["predict_high"].extend(factual["predict_high"]) merged_factuals["weight"].extend(factual["weight"]) merged_factuals["weight_low"].extend(factual["weight_low"]) merged_factuals["weight_high"].extend(factual["weight_high"]) merged_factuals["value"].extend(factual["value"]) merged_factuals["rule"].extend(factual["rule"]) merged_factuals["feature"].extend(factual["feature"]) merged_factuals["feature_value"].extend(factual["feature_value"]) merged_factuals["is_conjunctive"].extend(factual["is_conjunctive"]) return merged_factuals def plot_factual( self, index=None, 
class_idx=None, filter_top=10, show=True, filename="", uncertainty=False, style="regular", **kwargs, ): """ Plot explanations for a given instance and class. If no class is specified, plots explanations for all classes at that index. """ style_override = kwargs.get("style_override", get_multiclass_config()) if index is not None: if class_idx is not None: explanation = self[index, class_idx] if explanation: explanation.plot( filter_top=filter_top, show=show, filename=filename, uncertainty=uncertainty, style=style, ) else: warnings.warn( f"No explanation found for instance {index}, class {class_idx}", stacklevel=2, ) else: self.__getitem__(index).plot( filter_top=filter_top, show=show, filename=filename, uncertainty=uncertainty, style=style, ) else: for i, class_explanations in enumerate(self.explanations): # Ensure style_override gets passed through # Delegate non-render payload construction to helper for testability payload = self._build_factual_plot_payload( i=i, class_explanations=class_explanations, filename=filename, show=show, uncertainty=uncertainty, style_override=style_override, kwargs=kwargs, ) if payload is None: # No rules to plot for this instance continue _plot_probabilistic_dict( payload["class_explanations_list"], payload["factual_values"], payload["predicts"], payload["feature_weights_list"], payload["features_list_to_plot"], payload["filter_top"], payload["colors"], payload["column_names_list"], title=payload["title"], path=payload["path"], interval=payload["interval"], show=payload["show"], idx=payload["idx"], save_ext=payload["save_ext"], style_override=payload["style_override"], ) def sort_factuals_by_rule(self, factuals): # pragma: no cover # ADR-023: multiclass viz """Group factual explanation entries by rule string across classes.""" sorted_factuals = {} factual_rule = { "base_predict": [], "base_predict_low": [], "base_predict_high": [], "predict": [], "predict_low": [], "predict_high": [], "weight": [], "weight_low": [], "weight_high": [], 
"value": [], "rule": [], "feature": [], "feature_value": [], "is_conjunctive": [], "classes": [], } for _i, factual in enumerate(factuals): # pylint: disable=invalid-name base_predict = factual["base_predict"][0] base_predict_low = factual["base_predict_low"][0] base_predict_high = factual["base_predict_high"][0] cls = factual["classes"] for j, rule in enumerate(factual["rule"]): if rule not in sorted_factuals: sorted_factuals[rule] = deepcopy(factual_rule) sorted_factuals[rule]["base_predict"].append(base_predict) sorted_factuals[rule]["base_predict_low"].append(base_predict_low) sorted_factuals[rule]["base_predict_high"].append(base_predict_high) sorted_factuals[rule]["classes"].append(cls) sorted_factuals[rule]["predict"].append(factual["predict"][j]) sorted_factuals[rule]["predict_low"].append(factual["predict_low"][j]) sorted_factuals[rule]["predict_high"].append(factual["predict_high"][j]) sorted_factuals[rule]["weight"].append(factual["weight"][j]) sorted_factuals[rule]["weight_low"].append(factual["weight_low"][j]) sorted_factuals[rule]["weight_high"].append(factual["weight_high"][j]) sorted_factuals[rule]["value"].append(factual["value"][j]) sorted_factuals[rule]["rule"].append(factual["rule"][j]) sorted_factuals[rule]["feature"].append(factual["feature"][j]) sorted_factuals[rule]["feature_value"].append(factual["feature_value"][j]) sorted_factuals[rule]["is_conjunctive"].append(factual["is_conjunctive"][j]) return sorted_factuals def _build_factual_plot_payload( self, *, i: int, class_explanations: Mapping[Any, Any], filename: str, show: bool, uncertainty: bool, style_override: Any, kwargs: Mapping[str, Any], ) -> dict | None: """Construct the non-render payload for plotting factual multiclass explanations. Returns a dict containing the exact arguments needed by `_plot_probabilistic_dict`. Returns ``None`` when there are no rules to plot for the given instance. 
""" # Prepare colors similar to previous inline logic import matplotlib.colors as mcolors rgb = np.array(list(permutations(range(0, 256, 11), 3))) / 255.0 colors = [rgb.tolist()[i * 23] for i in range(25)] colors = list(mcolors.BASE_COLORS.values()) class_explanations_list = list(class_explanations.values()) rnk_metric = kwargs.get("rnk_metric", "feature_weight") if rnk_metric is None: rnk_metric = "feature_weight" rnk_weight = kwargs.get("rnk_weight", 0.5) if rnk_metric == "uncertainty": rnk_weight = 1.0 rnk_metric = "ensured" factuals = [ex.get_rules() for ex in class_explanations_list] factuals = self.sort_factuals_by_rule(factuals) # Ensure each explanation has a sensible `index` set before checks for ex in class_explanations_list: with contextlib.suppress(Exception): ex.index = i with contextlib.suppress(Exception): ex._check_preconditions() predicts = [getattr(ex, "prediction", None) for ex in class_explanations_list] filter_top = [len(factual["weight"]) for factual in list(factuals.values())] if len(filter_top) <= 0: return None if uncertainty: feature_weights_list = [ { "predict": factual["weight"], "low": factual["weight_low"], "high": factual["weight_high"], "classes": factual["classes"], } for factual in list(factuals.values()) ] else: feature_weights_list = [ {"predict": factual["weight"], "classes": factual["classes"]} for factual in list(factuals.values()) ] widths = [ np.reshape( np.array(factual["weight_high"]) - np.array(factual["weight_low"]), (len(factual["weight"])), ) for factual in list(factuals.values()) ] first_explanation = next(iter(class_explanations.values())) rank_features = getattr(first_explanation, "rank_features", None) if not callable(rank_features): rank_features = first_explanation._rank_features if rnk_metric == "feature_weight": features_list_to_plot = [ rank_features(factual["weight"], width=width, num_to_show=num_to_show) for factual, width, num_to_show in zip( list(factuals.values()), widths, filter_top, strict=False ) ] 
else: rankings = [ calculate_metrics( uncertainty=[ factual["predict_high"][j] - factual["predict_low"][j] for j in range(len(factual["weight"])) ], prediction=factual["predict"], w=rnk_weight, metric=rnk_metric, ) for factual in list(factuals.values()) ] features_list_to_plot = [ rank_features(width=ranking, num_to_show=num_to_show) for ranking, num_to_show in zip(rankings, filter_top, strict=False) ] column_names_list = list(factuals) factual_values = [factual["value"] for factual in list(factuals.values())] # Prepare filename/path/title/save_ext if len(filename) > 0: path, _, title, ext = prepare_for_saving(str(i) + "_" + filename) path = f"plots/{path}" save_ext = [ext] else: path = "" title = "" save_ext = [] return { "class_explanations_list": list(class_explanations.values()), "factual_values": factual_values, "predicts": predicts, "feature_weights_list": feature_weights_list, "features_list_to_plot": features_list_to_plot, "filter_top": filter_top, "colors": colors, "column_names_list": column_names_list, "title": title, "path": path, "interval": uncertainty, "show": show, "idx": i, "save_ext": save_ext, "style_override": style_override, }