# pylint: disable=unknown-option-value
# pylint: disable=too-many-lines, too-many-arguments, invalid-name, too-many-positional-arguments, line-too-long
"""Calibrated explanation containers and visualization helpers.
This module defines the classes used to represent factual, alternative and
fast explanations produced by :class:`~calibrated_explanations.core.CalibratedExplainer`.
Primary classes
---------------
- :class:`CalibratedExplanation` — Abstract base for explanation instances.
- :class:`FactualExplanation` — Factual explanations for an instance.
- :class:`AlternativeExplanation` — Alternative/counterfactual explanations.
- :class:`FastExplanation` — Lightweight fast-mode explanations.
"""
from __future__ import annotations
import contextlib
import math
import re
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from copy import copy
from dataclasses import dataclass
from types import MappingProxyType
from typing import Any, Dict, Literal, Optional, Tuple
import numpy as np
from pandas import Categorical
from ..plotting import plot_alternative, plot_probabilistic, plot_regression, plot_triangular
from ..utils import (
BinaryEntropyDiscretizer,
BinaryRegressorDiscretizer,
EntropyDiscretizer,
RegressorDiscretizer,
calculate_metrics,
prepare_for_saving,
safe_first_element,
safe_mean,
)
from ..utils.exceptions import CalibratedError, ValidationError
from ..utils.helper import assign_threshold as normalize_threshold
from ..utils.int_utils import collect_ints
from ._conjunctions import ConjunctionState
@dataclass
class RuleWithImpact:
    """Canonical record of one rule and its impact.

    Intended as the single representation shared by plotting and narrative
    code.  The first eight fields are required (positional order is the
    declaration order); the four bound fields are optional and default to
    ``None``.
    """

    # --- rule identity ---------------------------------------------------
    rule_id: str
    feature: str
    text: str

    # --- effect of the rule ----------------------------------------------
    impact: float
    direction: Literal["positive", "negative", "neutral"]
    base_predict: float
    predict: float
    value: Any

    # --- optional bounds; ``None`` when not supplied ----------------------
    weight_envelope_low: Optional[float] = None
    weight_envelope_high: Optional[float] = None
    predict_low: Optional[float] = None
    predict_high: Optional[float] = None
# @dataclass
# class PredictionInterval:
# """A dataclass representing a prediction interval for a single feature.
# Attributes
# ----------
# predict: float
# The model's prediction for this feature
# low: float
# The lower bound of the prediction interval
# high: float
# The upper bound of the prediction interval
# """
# predict: float
# low: float
# high: float
# @dataclass
# class FeatureRule:
# """A dataclass representing a rule for a single feature in an explanation.
# Attributes
# ----------
# weight : PredictionInterval
# The weight/importance of this feature rule, containing prediction and interval values
# prediction : PredictionInterval
# The model's prediction and interval for this feature rule
# instance_prediction : PredictionInterval
# The binned prediction data for this feature rule
# current_bin : int
# The bin index containing the current feature value
# rule : str
# String representation of the rule
# feature : int
# Index of the feature this rule applies to
# feature_value : float
# The actual value of the feature
# is_conjunctive : bool
# Whether this rule is part of a conjunction
# value_str : str
# String representation of the feature value
# """
# weight: PredictionInterval
# prediction: PredictionInterval
# instance_prediction: PredictionInterval # from binned data
# current_bin: int
# rule: str
# feature: int
# feature_value: float
# is_conjunctive: bool
# value_str: str
# pylint: disable=too-many-instance-attributes, too-many-locals, too-many-arguments
[docs]
class CalibratedExplanation(ABC):
"""Abstract base class for storing and visualizing calibrated explanations.
Subclasses implement concrete payload building and plotting utilities while
this base class provides shared validation and convenience accessors.
See documentation at ``docs/foundations/concepts/explanation_structures.md``
for details on the internal payload layout.
"""
def __init__(
self,
calibrated_explanations,
index,
x,
binned,
feature_weights,
feature_predict,
prediction,
y_threshold=None,
instance_bin=None,
condition_source: str = "prediction",
):
"""Abstract base class for storing and visualizing calibrated explanations.
This class defines the interface and shared functionality for different types of calibrated explanations.
Initialize a CalibratedExplanation instance.
Parameters
----------
calibrated_explanations : :class:`.CalibratedExplanations`
The parent :class:`.CalibratedExplanations` object.
index : int
The index of the instance being explained.
x : array-like
The test dataset containing the instances to be explained.
binned : dict
A mapping of binned feature values.
feature_weights : dict
A mapping of feature weights.
feature_predict : dict
A mapping of feature predictions.
prediction : dict
A mapping containing the prediction results.
y_threshold : float or tuple, optional
The threshold for binary classification or regression explanations.
instance_bin : int, optional
The bin index of the instance.
"""
binned = MappingProxyType(binned)
feature_weights = MappingProxyType(feature_weights)
feature_predict = MappingProxyType(feature_predict)
prediction = MappingProxyType(prediction)
self.calibrated_explanations = calibrated_explanations
self.index = index
self.x_test = x
self.binned = {}
self.feature_weights = {}
self.feature_predict = {}
self.prediction = {}
for key in binned:
self.binned[key] = binned[key][index]
for key in feature_weights:
self.feature_weights[key] = feature_weights[key][index]
self.feature_predict[key] = feature_predict[key][index]
for key in prediction:
# Special handling: full probability matrix stored under magic key
if key == "__full_probabilities__":
self.prediction[key] = prediction[
key
] # keep whole matrix (used for golden baseline only)
else:
self.prediction[key] = prediction[key][index]
self.y_threshold = (
y_threshold
if np.isscalar(y_threshold) or isinstance(y_threshold, tuple)
else None
if y_threshold is None
else y_threshold[index]
)
self.conditions = []
self.rules = None
self.conjunctive_rules = None
self.has_rules = False
self.has_conjunctive_rules = False
self.bin = [instance_bin] if instance_bin is not None else None
self.explain_time = None
self.condition_source = condition_source
# reduce dependence on Explainer class
if not isinstance(self.get_explainer().y_cal, Categorical):
self.y_minmax = [
np.min(self.get_explainer().y_cal),
np.max(self.get_explainer().y_cal),
]
else:
self.y_minmax = [0, 0]
self.focus_columns = None
# Optional reject context attached when explanations are produced under
# a reject policy (FLAG / ONLY_REJECTED). Populated by orchestrator.
self.reject_context = None
self._validate_prediction_invariant()
[docs]
def filter_features(
self,
*,
exclude_features=None,
include_features=None,
copy=True,
):
"""Filter rules by feature inclusion or exclusion.
Parameters
----------
exclude_features : str, int, or list of str/int, optional
Features to exclude. Rules containing these features will be removed.
include_features : str, int, or list of str/int, optional
Features to include. Only rules containing these features will be kept.
copy : bool, default=True
If True, return a copy of the explanation. If False, modify in place.
Returns
-------
CalibratedExplanation
Filtered explanation.
"""
if (exclude_features is None) == (include_features is None):
raise ValidationError(
"Exactly one of exclude_features or include_features must be provided",
details={
"exclude_features": exclude_features,
"include_features": include_features,
},
)
if copy:
self = self.copy()
# Normalize the features to indices
target_features = exclude_features if exclude_features is not None else include_features
is_exclude = exclude_features is not None
if isinstance(target_features, (str, int)):
target_features = [target_features]
elif not isinstance(target_features, list):
raise ValidationError("Features must be a string, int, or list of strings/ints")
if not target_features:
raise ValidationError("Features list must not be empty")
target_indices = []
for feat in target_features:
if isinstance(feat, str):
if feat not in self.get_explainer().feature_names:
raise ValidationError(f"Feature name '{feat}' not found in feature_names")
target_indices.append(self.get_explainer().feature_names.index(feat))
elif isinstance(feat, int):
if not (0 <= feat < self.get_explainer().num_features):
raise ValidationError(
f"Feature index {feat} is out of range [0, {self.get_explainer().num_features})"
)
target_indices.append(feat)
else:
raise ValidationError("Features must contain only strings or ints")
# Create mask for rules to keep
keep_mask = []
for i, features in enumerate(self.rules["feature"]):
if self.rules["is_conjunctive"][i]:
# For conjunctive rules
if isinstance(features, list):
has_target = any(f in target_indices for f in features)
else:
has_target = features in target_indices
keep = has_target if not is_exclude else not has_target
else:
# For disjunctive rules (single feature)
has_target = features in target_indices
keep = has_target if not is_exclude else not has_target
keep_mask.append(keep)
# Filter rules
filtered_rules = {}
for key in self.rules:
filtered_rules[key] = [
val for val, keep in zip(self.rules[key], keep_mask, strict=False) if keep
]
self.rules = filtered_rules
return self
def _validate_prediction_invariant(self) -> None:
"""Enforce low <= predict <= high invariant on prediction payload."""
import warnings
predict = self.prediction.get("predict")
low = self.prediction.get("low")
high = self.prediction.get("high")
if predict is None or low is None or high is None:
return
with contextlib.suppress(TypeError, ValueError):
# Handle scalar values (common case)
if (
isinstance(predict, (int, float))
and isinstance(low, (int, float))
and isinstance(high, (int, float))
):
if not low <= high:
warnings.warn(
"Prediction interval invariant violated: low > high",
UserWarning,
stacklevel=2,
)
# Allow small floating point tolerance
epsilon = 1e-9
if not (low - epsilon <= predict <= high + epsilon):
warnings.warn(
"Prediction invariant violated: predict not in [low, high]",
UserWarning,
stacklevel=2,
)
def __len__(self):
"""Return the number of rules in the explanation."""
return len(self.get_rules()["rule"])
[docs]
@abstractmethod
def build_rules_payload(self) -> Dict[str, Any]:
"""Return structured rule payload separating core content from metadata."""
raise NotImplementedError
@property
def prediction_interval(self):
"""Get the prediction interval from the prediction dictionary.
Returns
-------
tuple
A tuple containing (low, high) values of the prediction interval.
"""
return (self.prediction["low"], self.prediction["high"])
@property
def predict(self):
"""Get the prediction from the prediction dictionary.
Returns
-------
float
A prediction value.
"""
return self.prediction["predict"]
[docs]
def get_mode(self):
"""Return the mode of the explanation ('classification' or 'regression')."""
return self.get_explainer().mode
[docs]
def get_class_labels(self):
"""Return the class labels."""
return self.get_explainer().class_labels
[docs]
def is_multiclass(self):
"""Determine if the explanation is multiclass."""
return self.get_explainer().is_multiclass()
[docs]
def get_explainer(self):
"""Return the explainer object."""
container = self.calibrated_explanations
getter = getattr(container, "get_explainer", None)
if callable(getter):
return getter()
getter = getattr(container, "_get_explainer", None)
if callable(getter):
return getter()
if hasattr(container, "explainer"):
return container.explainer
return getattr(container, "calibrated_explainer", container)
[docs]
def copy(self):
"""Create a shallow copy of the explanation object.
Returns
-------
CalibratedExplanation
A shallow copy of the explanation object.
"""
return copy(self)
[docs]
def ignored_features_for_instance(self):
    """Return the set of feature indices ignored for this instance.

    The result is the union of the collection-level ``features_to_ignore``
    and, when the parent exposes ``feature_filter_per_instance_ignore`` as
    a sequence that covers ``self.index``, the entry stored for this
    instance.  Both sources are passed through ``collect_ints``.
    """
    parent = self.calibrated_explanations

    collection_level = getattr(parent, "features_to_ignore", None)
    ignored: set[int] = set()
    ignored.update(collect_ints(() if collection_level is None else collection_level))

    per_instance = getattr(parent, "feature_filter_per_instance_ignore", None)
    if not isinstance(per_instance, Sequence) or not 0 <= self.index < len(per_instance):
        return ignored

    instance_mask = per_instance[self.index]
    if instance_mask is not None:
        ignored.update(collect_ints(instance_mask))
    return ignored
[docs]
def rank_features(self, feature_weights=None, width=None, num_to_show=None):
"""Rank the features based on their weights.
Parameters
----------
feature_weights : dict, optional
A mapping of feature weights.
width : dict, optional
A mapping of feature widths.
num_to_show : int, optional
The number of features to show.
Returns
-------
list
The sorted indices of the features.
"""
if not (feature_weights is not None or width is not None):
from ..utils.exceptions import ValidationError
raise ValidationError(
"Either feature_weights or width (or both) must not be None",
details={
"param": "feature_weights/width",
"requirement": "at least one must be provided",
"feature_weights": feature_weights is not None,
"width": width is not None,
},
)
num_features = len(feature_weights) if feature_weights is not None else len(width)
if num_to_show is None or num_to_show > num_features:
num_to_show = num_features
# Robust ranking: handle NaN/inf
if feature_weights is not None:
feature_weights = np.nan_to_num(
feature_weights,
nan=0.0,
posinf=np.finfo(float).max,
neginf=-np.finfo(float).max,
)
if width is not None:
width = np.nan_to_num(
width,
nan=0.0,
posinf=np.finfo(float).max,
neginf=-np.finfo(float).max,
)
# handle case where there are same weight but different uncertainty
if feature_weights is not None and width is not None:
# get the indices by first sorting on the absolute value of the
# feature_weight and then on the width
sorted_indices = [
i
for i, x in sorted(
enumerate(list(zip(np.abs(feature_weights), width, strict=False))),
key=lambda x: (x[1][0], x[1][1]),
)
]
return sorted_indices[-num_to_show:] # pylint: disable=invalid-unary-operand-type
if width is not None:
sorted_indices = np.argsort(width)
return sorted_indices[-num_to_show:] # pylint: disable=invalid-unary-operand-type
sorted_indices = np.argsort(np.abs(feature_weights))
return sorted_indices[-num_to_show:] # pylint: disable=invalid-unary-operand-type
[docs]
def is_one_sided(self) -> bool:
"""Test if a regression explanation is one-sided.
Returns
-------
bool: True if one of the low or high percentiles is infinite
"""
if self.calibrated_explanations.low_high_percentiles is None:
return False
return np.isinf(self.calibrated_explanations.get_low_percentile()) or np.isinf(
self.calibrated_explanations.get_high_percentile()
)
[docs]
def is_thresholded(self) -> bool:
"""Check if the explanation is thresholded.
Returns
-------
bool: True if the y_threshold is not None
"""
return self.y_threshold is not None
[docs]
def is_regression(self) -> bool:
"""Check if the explanation is for regression.
Returns
-------
bool: True if mode is 'regression'
"""
return "regression" in self.get_explainer().mode
[docs]
def is_probabilistic(self) -> bool:
"""Check if the explanation is probabilistic.
Returns
-------
bool: True if mode is 'classification' or is_thresholded and is_regression are True
"""
return "classification" in self.get_explainer().mode or (
self.is_regression() and self.is_thresholded()
)
@abstractmethod
def __repr__(self):
"""Return a string representation of the explanation."""
[docs]
@abstractmethod
def plot(self, filter_top=None, **kwargs):
"""
Plot the explanation.
Parameters
----------
filter_top : int, optional
The number of top features to display.
**kwargs : dict
Additional plotting arguments. See each subclass.
See Also
--------
:meth:`.FactualExplanation.plot` : Refer to the docstring for plot in FactualExplanation for details.
:meth:`.AlternativeExplanation.plot` : Refer to the docstring for plot in AlternativeExplanation for details.
:meth:`.FastExplanation.plot` : Refer to the docstring for plot in FastExplanation for details.
"""
[docs]
def to_narrative(
    self,
    template_path=None,
    expertise_level=("beginner", "advanced"),
    output_format="dataframe",
    conjunction_separator=" AND ",
    align_weights=True,
    **kwargs,
):
    """
    Generate narrative explanation for this single instance.

    The explanation is wrapped in a minimal collection-like object and
    handed to the narrative plot plugin, which renders a human-readable
    narrative from a customizable template.

    Parameters
    ----------
    template_path : str or None, default=None
        Path to the narrative template file (YAML or JSON).
        If None or the file doesn't exist, the built-in default template is used.
    expertise_level : str or tuple of str, default=("beginner", "advanced")
        The expertise level(s) for narrative generation. Can be a single
        level or a tuple of levels. Valid values: "beginner", "intermediate", "advanced".
    output_format : str, default="dataframe"
        Output format. Valid values: "dataframe", "text", "html", "dict", "markdown".
    conjunction_separator : str, default=" AND "
        Separator to use for conjunctive rules. Conjunctive rules combine
        multiple feature conditions (e.g., "Glucose > 120 AND BMI > 28").
    align_weights : bool, default=True
        If True, vertically align columns in the narrative output.
        For factual explanations this aligns the weight marker ("— weight …").
        For alternative explanations this aligns the "then" keyword in rule lines.
        If False, no alignment is applied.
    **kwargs : dict
        Additional keyword arguments passed to the narrative plugin.

    Returns
    -------
    pd.DataFrame or str or dict
        The generated narrative in the requested format:
        - "dataframe": pandas DataFrame with one row
        - "text": formatted text string
        - "html": HTML table with one row
        - "dict": dictionary with narrative fields (the first item of the
          plugin result, or ``{}`` when the result is empty)

    Raises
    ------
    FileNotFoundError
        If the template file is not found and no default is available.
    ValueError
        If an invalid expertise level or output format is specified.
    ImportError
        If pandas is not available and output_format="dataframe" is requested.

    Examples
    --------
    >>> from calibrated_explanations import CalibratedExplainer
    >>> explainer = CalibratedExplainer(model, x_train, y_train)
    >>> explanations = explainer.explain_factual(x_test)
    >>> single_explanation = explanations[0]
    >>> narrative = single_explanation.to_narrative(
    ...     expertise_level=("beginner", "advanced"),
    ...     output_format="dataframe"
    ... )
    >>> print(narrative)

    See Also
    --------
    :meth:`.CalibratedExplanations.to_narrative` : Generate narratives for a collection of explanations.
    :meth:`.plot` : Plot explanations with various visual styles.
    """
    from ..viz.narrative_plugin import NarrativePlotPlugin

    class SingleExplanationWrapper:
        """Collection-like shell holding exactly one explanation for the plugin."""

        def __init__(self, explanation):
            parent = explanation.calibrated_explanations
            self.explanations = [explanation]
            self._parent = parent
            self.calibrated_explainer = parent.calibrated_explainer
            self.y_threshold = explanation.y_threshold

        def is_alternative(self) -> bool:
            # Prefer the parent's own answer; otherwise fall back to the
            # class name of the wrapped explanation.
            parent_is_alternative = getattr(self._parent, "is_alternative", None)
            if callable(parent_is_alternative):
                return bool(parent_is_alternative())
            return "Alternative" in type(self.explanations[0]).__name__

    narrative_plugin = NarrativePlotPlugin(template_path=template_path)
    result = narrative_plugin.plot(
        SingleExplanationWrapper(self),
        template_path=template_path,
        expertise_level=expertise_level,
        output=output_format,
        conjunction_separator=conjunction_separator,
        align_weights=align_weights,
        **kwargs,
    )

    # Only the "dict" format is unwrapped to its first (and only) item;
    # every other format is returned exactly as the plugin produced it.
    if output_format == "dict":
        return result[0] if result else {}
    return result
[docs]
def to_dataframe(self, *args, **kwargs):
"""Return the narrative output as a pandas DataFrame.
Call :meth:`to_narrative` with ``output_format='dataframe'`` and return
the resulting DataFrame. Accepts the same arguments as
:meth:`to_narrative`.
"""
kwargs.setdefault("output_format", "dataframe")
return self.to_narrative(*args, **kwargs)
[docs]
@abstractmethod
def add_conjunctions(self, n_top_features=5, max_rule_size=2):
"""
Add conjunctive rules to the explanation.
Parameters
----------
n_top_features : int, optional
Number of top features to combine.
max_rule_size : int, optional
Maximum size of the conjunctions.
Returns
-------
:class:`.CalibratedExplanation`
"""
@abstractmethod
def _check_preconditions(self):
"""Validate that required explanation inputs and state are available."""
pass
[docs]
def reset(self):
"""Reset the explanation to its original state."""
# Reset both base rules and any derived conjunctive overlays.
# Conjunctions can change downstream payloads (plots, telemetry), so a
# reset should restore the atomic/original representation.
self.has_rules = False
self.has_conjunctive_rules = False
self.conjunctive_rules = None
self.get_rules()
return self
[docs]
def remove_conjunctions(self):
"""Remove any conjunctive rules."""
self.has_conjunctive_rules = False
self.conjunctive_rules = None
return self
[docs]
def filter_rule_sizes(
self,
*,
rule_sizes: Optional[Any] = None,
size_range: Optional[Tuple[int, int]] = None,
copy: bool = True,
):
"""Filter rules by conjunctive rule size.
Parameters
----------
rule_sizes : int or sequence of int, optional
Explicit rule sizes to keep (e.g., 1, 2, 3).
size_range : tuple(int, int), optional
Inclusive (min_size, max_size) range of rule sizes to keep.
copy : bool, default=True
If True, return a filtered copy without mutating the original.
"""
if (rule_sizes is None) == (size_range is None):
raise ValidationError(
"Exactly one of rule_sizes or size_range must be provided",
details={"rule_sizes": rule_sizes, "size_range": size_range},
)
normalized_sizes: Optional[set[int]] = None
min_size: Optional[int] = None
max_size: Optional[int] = None
if rule_sizes is not None:
if isinstance(rule_sizes, (int, np.integer)):
normalized_sizes = {int(rule_sizes)}
elif isinstance(rule_sizes, (list, tuple, set, np.ndarray)):
normalized_sizes = {int(v) for v in rule_sizes}
else:
raise ValidationError(
"rule_sizes must be an int or a sequence of ints",
details={"rule_sizes": rule_sizes},
)
if not normalized_sizes:
raise ValidationError(
"rule_sizes must not be empty", details={"rule_sizes": rule_sizes}
)
if any(size <= 0 for size in normalized_sizes):
raise ValidationError(
"rule_sizes must contain positive integers",
details={"rule_sizes": sorted(normalized_sizes)},
)
if size_range is not None:
if not isinstance(size_range, (list, tuple)) or len(size_range) != 2:
raise ValidationError(
"size_range must be a (min_size, max_size) tuple",
details={"size_range": size_range},
)
min_size = int(size_range[0])
max_size = int(size_range[1])
if min_size <= 0 or max_size <= 0:
raise ValidationError(
"size_range bounds must be positive integers",
details={"size_range": size_range},
)
if min_size > max_size:
raise ValidationError(
"size_range must satisfy min_size <= max_size",
details={"size_range": size_range},
)
rules = self.get_rules()
num_rules = len(rules.get("rule", []))
def _rule_size(feature: Any) -> int:
if isinstance(feature, (list, tuple, np.ndarray)):
return len(np.asarray(feature).ravel())
return 1
mask = []
features = rules.get("feature", [None] * num_rules)
for i in range(num_rules):
size = _rule_size(features[i]) if i < len(features) else 1
if normalized_sizes is not None:
keep = size in normalized_sizes
else:
keep = min_size <= size <= max_size # type: ignore[operator]
mask.append(keep)
mask_array = np.asarray(mask, dtype=bool)
def _clone_value(val: Any) -> Any:
if isinstance(val, MappingProxyType):
return {k: _clone_value(v) for k, v in val.items()}
if isinstance(val, list):
return list(val)
if isinstance(val, np.ndarray):
return val.copy()
return val
def _filter_value(val: Any) -> Any:
if isinstance(val, list) and len(val) == num_rules:
return [v for v, keep in zip(val, mask, strict=False) if keep]
if isinstance(val, np.ndarray) and val.shape[0] == num_rules:
return val[mask_array].copy()
return _clone_value(val)
filtered_rules = {k: _filter_value(v) for k, v in rules.items()}
target = self.copy() if copy else self
has_conjunctive = bool(getattr(target, "has_conjunctive_rules", False))
extractor = getattr(
target, "_AlternativeExplanation__extracted_non_conjunctive_rules", None
)
if has_conjunctive and callable(extractor):
extractor(filtered_rules)
target.has_conjunctive_rules = True
return target
target.rules = filtered_rules
if has_conjunctive:
target.conjunctive_rules = filtered_rules
return target
# ------------------------------------------------------------------
# Telemetry helpers
# ------------------------------------------------------------------
[docs]
@staticmethod
def to_python_number(value: Any) -> Any:
"""Convert numpy/scalar values to native Python types suitable for telemetry."""
if isinstance(value, np.generic):
value = value.item()
if isinstance(value, np.ndarray):
return [CalibratedExplanation.to_python_number(v) for v in value.tolist()]
if isinstance(value, (list, tuple)):
return [CalibratedExplanation.to_python_number(v) for v in value]
if value is None:
return None
if isinstance(value, (np.bool_, bool)):
return bool(value)
if isinstance(value, (np.integer, int)):
return int(value)
if isinstance(value, (np.floating, float)):
if math.isnan(value):
return None
return float(value)
return value
[docs]
@staticmethod
def normalize_percentile_value(value: Any) -> Optional[float]:
"""Normalise percentile inputs to decimal fractions."""
value = CalibratedExplanation.to_python_number(value)
if value is None:
return None
if isinstance(value, (float, int)):
value = float(value)
if math.isinf(value):
return value
if abs(value) > 1.0:
return value / 100.0
return value
return None
def _get_percentiles(self) -> Optional[Tuple[Optional[float], Optional[float]]]:
"""Return decimal percentiles if available."""
percentiles = getattr(self.calibrated_explanations, "low_high_percentiles", None)
if percentiles is None or len(percentiles) != 2:
return None
low = self.normalize_percentile_value(percentiles[0])
high = self.normalize_percentile_value(percentiles[1])
return (low, high)
[docs]
def get_percentiles(self) -> Optional[Tuple[Optional[float], Optional[float]]]:
"""Return the decoded percentile configuration."""
return self._get_percentiles()
[docs]
@staticmethod
def compute_confidence_level(
percentiles: Optional[Tuple[Optional[float], Optional[float]]],
) -> Optional[float]:
"""Compute confidence level from decimal percentiles."""
if not percentiles:
return None
low, high = percentiles
if low is None or high is None:
return None
if low == -math.inf:
return None if high in (None, math.inf) else high
if high == math.inf:
return None if low is None else 1 - low
return max(0.0, high - low)
[docs]
def normalize_threshold_value(self) -> Any:
"""Normalise threshold metadata to telemetry-friendly structure."""
threshold = self.y_threshold
if threshold is None:
return None
if isinstance(threshold, np.ndarray):
threshold = threshold.tolist()
if isinstance(threshold, (list, tuple)):
if len(threshold) == 0:
return None
values = [CalibratedExplanation.to_python_number(threshold[0])]
if len(threshold) > 1:
values.append(CalibratedExplanation.to_python_number(threshold[1]))
return values
return CalibratedExplanation.to_python_number(threshold)
def _build_uncertainty_payload(
self,
*,
value: Any,
low: Any,
high: Any,
representation: str,
percentiles: Optional[Tuple[Optional[float], Optional[float]]] = None,
threshold: Any = None,
include_percentiles: bool = True,
) -> Dict[str, Any]:
"""Create a structured uncertainty payload."""
lower = CalibratedExplanation.to_python_number(low)
upper = CalibratedExplanation.to_python_number(high)
payload: Dict[str, Any] = {
"representation": representation,
"calibrated_value": CalibratedExplanation.to_python_number(value),
"lower_bound": lower,
"upper_bound": upper,
"legacy_interval": [lower, upper],
}
payload["threshold"] = threshold
payload["raw_percentiles"] = None
payload["confidence_level"] = None
if include_percentiles and percentiles:
payload["raw_percentiles"] = [
CalibratedExplanation.to_python_number(percentiles[0]),
CalibratedExplanation.to_python_number(percentiles[1]),
]
confidence = CalibratedExplanation.compute_confidence_level(percentiles)
if confidence is not None:
payload["confidence_level"] = confidence
return payload
# Public alias of ``_build_uncertainty_payload``, exposed for testing.
build_uncertainty_payload = _build_uncertainty_payload
@staticmethod
def _build_interval(low: Any, high: Any) -> Dict[str, Any]:
"""Return a minimal uncertainty interval with Python-native bounds."""
return {
"lower": CalibratedExplanation.to_python_number(low),
"upper": CalibratedExplanation.to_python_number(high),
}
[docs]
@staticmethod
def build_interval(low: Any, high: Any) -> Dict[str, Any]:
"""Public helper exposing interval construction."""
return CalibratedExplanation._build_interval(low, high)
def _build_instance_uncertainty(self) -> Dict[str, Any]:
"""Build uncertainty payload for the current instance prediction."""
if self.is_thresholded():
return self._build_uncertainty_payload(
value=self.prediction["predict"],
low=self.prediction["low"],
high=self.prediction["high"],
representation="threshold",
threshold=self.normalize_threshold_value(),
include_percentiles=False,
)
if self.is_probabilistic():
return self._build_uncertainty_payload(
value=self.prediction["predict"],
low=self.prediction["low"],
high=self.prediction["high"],
representation="venn_abers",
include_percentiles=False,
)
percentiles = self._get_percentiles()
return self._build_uncertainty_payload(
value=self.prediction["predict"],
low=self.prediction["low"],
high=self.prediction["high"],
representation="percentile",
percentiles=percentiles,
include_percentiles=True,
)
[docs]
def build_instance_uncertainty(self) -> Dict[str, Any]:
"""Expose the instance uncertainty payload builder."""
return self._build_instance_uncertainty()
def _safe_feature_name(self, feature_index: Any) -> str:
"""Return a readable feature name for telemetry."""
feature_names = getattr(self.get_explainer(), "feature_names", None)
try:
idx = int(feature_index)
except (
TypeError,
ValueError,
): # ADR002_ALLOW: handle non-numeric feature ids. # pragma: no cover
return str(feature_index)
if feature_names and 0 <= idx < len(feature_names):
return str(feature_names[idx])
return str(idx)
[docs]
def safe_feature_name(self, feature_index: Any) -> str:
"""Return the human-readable name for a feature index."""
return self._safe_feature_name(feature_index)
[docs]
@staticmethod
def convert_condition_value(raw_value: Optional[str], fallback: Any) -> Any:
"""Convert textual condition payloads to structured values."""
if raw_value is None:
return CalibratedExplanation.to_python_number(fallback)
text = raw_value.strip()
if text.lower() in {"-inf", "-infinity"}:
return float("-inf")
if text.lower() in {"inf", "+inf", "infinity"}:
return float("inf")
try:
return float(text)
except (
ValueError
): # ADR002_ALLOW: textual rule fragments may not be numeric. # pragma: no cover
return text
def _parse_condition(self, feature_name: str, rule_text: str) -> Tuple[str, Optional[str]]:
"""Attempt to parse rule text into operator and value tokens."""
if not rule_text:
return "raw", None
text = rule_text.strip()
pattern = rf"^{re.escape(feature_name)}\s*(<=|>=|==|=|<|>|in)\s*(.+)$"
match = re.match(pattern, text)
if match:
operator = match.group(1)
value_text = match.group(2).strip()
if operator == "=":
operator = "=="
return operator.lower(), value_text
return "raw", text
[docs]
def parse_condition(self, feature_name: str, rule_text: str) -> Tuple[str, Optional[str]]:
    """Parse ``rule_text`` into an ``(operator, value)`` pair.

    Public accessor that forwards to ``_parse_condition``.
    """
    parsed = self._parse_condition(feature_name, rule_text)
    return parsed
def _build_condition_payload(
    self,
    feature_index: Any,
    rule_text: str,
    feature_value: Any,
    display_value: Any,
) -> Dict[str, Any]:
    """Assemble the telemetry description of a single rule condition.

    The rule text is parsed against the resolved feature name. A parsed value
    token is converted with ``convert_condition_value``; an unparsed ("raw")
    rule falls back to the normalised ``display_value``. ``feature_value`` is
    accepted but not read by this method.

    Returns
    -------
    dict
        Mapping with the keys ``feature``, ``operator``, ``value`` and ``text``.
    """
    name = self._safe_feature_name(feature_index)
    op, value_token = self._parse_condition(name, rule_text)
    value = (
        CalibratedExplanation.to_python_number(display_value)
        if op == "raw"
        else self.convert_condition_value(value_token, display_value)
    )
    condition: Dict[str, Any] = {}
    condition["feature"] = name
    condition["operator"] = op
    condition["value"] = value
    condition["text"] = rule_text
    return condition
[docs]
def build_condition_payload(
    self,
    feature_index: Any,
    rule_text: str,
    feature_value: Any,
    display_value: Any,
) -> Dict[str, Any]:
    """Build the condition payload for one rule.

    Public accessor that forwards all arguments to
    ``_build_condition_payload``.
    """
    arguments = (feature_index, rule_text, feature_value, display_value)
    return self._build_condition_payload(*arguments)
[docs]
def to_telemetry(self) -> Dict[str, Any]:
    """Assemble the telemetry payload for this explanation.

    Returns
    -------
    dict
        ``"rules"`` holds the full rules payload, ``"metadata"`` its metadata
        mapping (an empty mapping when the payload has none) and
        ``"uncertainty"`` the prediction uncertainty stored in that metadata.
        The instance uncertainty is filled in only when the metadata does not
        already provide one.
    """
    rules_payload = self.build_rules_payload()
    meta = rules_payload.get("metadata", {})
    instance_uncertainty = self._build_instance_uncertainty()
    meta.setdefault("prediction_uncertainty", instance_uncertainty)

    telemetry: Dict[str, Any] = {"uncertainty": meta["prediction_uncertainty"]}
    telemetry["rules"] = rules_payload
    telemetry["metadata"] = meta
    return telemetry
[docs]
def define_conditions(self):
    """
    Define the rule conditions for an instance.

    One condition string is produced per feature and stored on
    ``self.conditions``:

    - ignored features get an empty string;
    - without a discretizer the condition is ``"<name> = <raw value>"``;
    - categorical features yield ``"<name> = <label>"``, falling back to the
      discretized value when no label is available;
    - all other features use the discretizer's bin name.

    Returns
    -------
    list[str]
        A list of conditions for each feature in the instance.
    """
    self.conditions = []
    # pylint: disable=invalid-name
    explainer = self.get_explainer()
    discretizer = explainer.discretizer
    # Without a discretizer (e.g. regression without discretization) the raw
    # instance values are used directly, so nothing is discretized.
    x = None if discretizer is None else discretizer.discretize(self.x_test)
    ignored = self.ignored_features_for_instance()
    # ``categorical_features`` may be None; treat that as "no categorical
    # features" instead of failing on the membership test.
    categorical_features = getattr(explainer, "categorical_features", None)
    if categorical_features is None:
        categorical_features = ()
    for f in range(explainer.num_features):
        if f in ignored:
            self.conditions.append("")
            continue
        feature_name = explainer.feature_names[f]
        if discretizer is None:
            rule = f"{feature_name} = {self.x_test[f]}"
        elif f in categorical_features:
            target = x[f]
            categorical_labels = explainer.categorical_labels
            if categorical_labels is not None:
                try:
                    target = categorical_labels[f][int(x[f])]
                except LookupError:
                    # Covers IndexError (sequence labels) and KeyError
                    # (mapping labels): keep the discretized value.
                    target = x[f]
            rule = f"{feature_name} = {target}"
        else:
            rule = discretizer.names[f][int(x[f])]
        self.conditions.append(rule)
    return self.conditions
[docs]
def predict_conjunction_tuple(
    self,
    rule_value_set,
    original_features,
    perturbed,
    threshold,
    predicted_class,
    bins=None,
):
    """Predict a conjunctive rule with a single batched inference call.

    Every combination of the per-feature candidate values is written into a
    copy of ``perturbed``; the resulting rows are predicted together and the
    outputs are averaged.

    Returns
    -------
    tuple of float
        ``(predict, low, high)`` means over all combinations, or
        ``(0.0, 0.0, 0.0)`` when there is nothing to evaluate.

    Raises
    ------
    ValidationError
        If ``perturbed`` is not 1D/2D or ``original_features`` cannot be
        converted to integers.
    """
    from ..utils.exceptions import ValidationError

    predict_fn = self.get_explainer().prediction_orchestrator.predict_internal
    nothing_to_evaluate = (0.0, 0.0, 0.0)

    # `perturbed` is expected to be a mutable scratch copy from the caller.
    perturbed = np.asarray(perturbed)
    dims = perturbed.ndim
    if dims not in (1, 2):
        raise ValidationError(
            "perturbed must be a 1D or 2D array-like",
            details={"param": "perturbed", "ndim": int(perturbed.ndim)},
        )
    if dims == 1:
        perturbed = perturbed.reshape(1, -1)

    try:
        original_features = [int(v) for v in original_features]
    except (TypeError, ValueError) as exc:
        raise ValidationError(
            "original_features must contain integer indices",
            details={"param": "original_features", "value": original_features},
        ) from exc

    # One 1-D candidate array per feature; any empty set means no combinations.
    candidate_arrays = []
    for raw_values in rule_value_set[: len(original_features)]:
        candidates = np.asarray(raw_values)
        if candidates.ndim == 0:
            candidates = candidates.reshape(1)
        if candidates.size == 0:
            return nothing_to_evaluate
        candidate_arrays.append(candidates)
    if not candidate_arrays:
        return nothing_to_evaluate

    # Cartesian product of the candidates, one row per combination.
    mesh = np.meshgrid(*candidate_arrays, indexing="ij")
    combo_matrix = np.stack(mesh, axis=-1).reshape(-1, len(original_features))
    if combo_matrix.size == 0:
        return nothing_to_evaluate

    n_combos = combo_matrix.shape[0]
    batch = np.repeat(perturbed, n_combos, axis=0)

    batch_bins = None
    if bins is not None:
        # Expand the instance's bin information to one entry per batch row.
        bins_arr = np.asarray(bins)
        if bins_arr.ndim == 0:
            batch_bins = np.full(n_combos, bins_arr.item())
        else:
            batch_bins = np.tile(bins, n_combos)
        # A length mismatch is resolved by truncating to the batch size.
        if len(batch_bins) != batch.shape[0]:
            batch_bins = batch_bins[: batch.shape[0]]

    # Write every combination into its row, then predict all rows at once.
    batch[:, original_features] = combo_matrix
    p_value, low, high, _ = predict_fn(
        batch,
        threshold=threshold,
        low_high_percentiles=self.calibrated_explanations.low_high_percentiles,
        classes=predicted_class,
        bins=batch_bins,
    )
    averaged = tuple(float(np.mean(output)) for output in (p_value, low, high))
    return averaged
def _predict_conjunctive(
    self,
    rule_value_set,
    original_features,
    perturbed,
    threshold,
    predicted_class,
    bins=None,
    use_batched=False,
):
    """
    Calculate the prediction for a conjunctive rule.

    In the non-batched mode every combination of the per-feature candidate
    values is written into a private copy of ``perturbed`` and predicted one
    row at a time; the outputs are averaged. All features listed in
    ``original_features`` take part in the combination, whatever their count.

    Parameters
    ----------
    rule_value_set : list
        The set of rule values, one collection per feature.
    original_features : list
        The original feature indices.
    perturbed : array-like
        The instance to perturb; it is copied before modification.
    threshold : float
        The threshold for classification or regression.
    predicted_class : int
        The predicted class label.
    bins : array-like, optional
        The bins forwarded to the prediction function.
    use_batched : bool, optional
        Whether to delegate to :meth:`predict_conjunction_tuple`.

    Returns
    -------
    tuple of float
        ``(predict, low, high)`` averaged over all evaluated combinations
        (all zeros when no combination was evaluated).

    Raises
    ------
    ValidationError
        If fewer than two features are supplied.
    ValueError
        If ``rule_value_set`` holds fewer value collections than features.
    """
    from itertools import product

    if len(original_features) < 2:
        from ..utils.exceptions import ValidationError

        raise ValidationError(
            "Conjunctive rules require at least two features",
            details={
                "param": "original_features",
                "count": len(original_features),
                "requirement": "minimum 2 features",
            },
        )
    if use_batched:
        return self.predict_conjunction_tuple(
            rule_value_set,
            original_features,
            perturbed,
            threshold,
            predicted_class,
            bins,
        )
    predict_fn = self.get_explainer().prediction_orchestrator.predict_internal
    # Ensure perturbed is a writable copy to avoid "read-only" errors
    perturbed = np.array(perturbed, copy=True)
    base_values = np.array([perturbed[idx] for idx in original_features], copy=True)
    value_iterables = [
        np.asarray(values) for values in rule_value_set[: len(original_features)]
    ]
    if len(value_iterables) != len(original_features):
        raise ValueError(
            "rule_value_set must provide one value collection per feature "
            f"(got {len(value_iterables)} for {len(original_features)} features)"
        )

    rule_predict = 0.0
    rule_low = 0.0
    rule_high = 0.0
    rule_count = 0
    try:
        # product() iterates in the same order as nested loops (last feature
        # varies fastest), for any number of features.
        for combination in product(*value_iterables):
            for feat_idx, value in zip(original_features, combination):
                perturbed[feat_idx] = value
            p_value, low, high, _ = predict_fn(
                perturbed.reshape(1, -1),
                threshold=threshold,
                low_high_percentiles=self.calibrated_explanations.low_high_percentiles,
                classes=predicted_class,
                bins=bins,
            )
            rule_predict += float(safe_first_element(p_value))
            rule_low += float(safe_first_element(low))
            rule_high += float(safe_first_element(high))
            rule_count += 1
    finally:
        # Put the scratch copy back into its starting state.
        for pos, feat_idx in enumerate(original_features):
            perturbed[feat_idx] = base_values[pos]
    if rule_count:
        rule_predict /= rule_count
        rule_low /= rule_count
        rule_high /= rule_count
    return rule_predict, rule_low, rule_high
[docs]
def predict_conjunctive(
    self,
    rule_value_set,
    original_features,
    perturbed,
    threshold,
    predicted_class,
    bins=None,
    use_batched=False,
):
    """Predict a conjunctive rule.

    Public accessor that forwards every argument to ``_predict_conjunctive``
    and returns its result.
    """
    options = {"bins": bins, "use_batched": use_batched}
    return self._predict_conjunctive(
        rule_value_set,
        original_features,
        perturbed,
        threshold,
        predicted_class,
        **options,
    )
@abstractmethod
def _is_lesser(self, rule_boundary, instance_value):
    """Tell whether ``instance_value`` satisfies a 'less than' rule at ``rule_boundary``.

    Abstract hook: concrete explanation classes supply the comparison. The
    base implementation has no body and returns ``None``.
    """
# pylint: disable=too-many-arguments, too-many-statements, too-many-branches, too-many-return-statements
[docs]
def add_new_rule_condition(self, feature, rule_boundary):
    """
    Create a new rule condition for a numerical feature.

    The calibration values of the feature are split at ``rule_boundary``.
    Percentiles of the values on the rule side are used as perturbed feature
    values for the new rule, and percentiles of the values on the other side
    serve as the reference ("covered") predictions. The averaged rule
    prediction, its interval and the weight relative to the averaged covered
    prediction are appended to the explanation's rules.

    Parameters
    ----------
    feature : int or str
        The feature index or name.
    rule_boundary : int or float
        The value to define as rule condition.

    Returns
    -------
    :class:`.CalibratedExplanation`
        ``self``. It is returned unmodified, after a warning, when the feature
        is unknown or categorical, the rule already exists, no calibration
        value lies beyond the boundary, or the resulting interval equals the
        original prediction interval.

    Notes
    -----
    The function will return the same explanation if the rule is already included or if the feature is categorical.
    No implementation is provided for the :class:`.FastExplanation` class.
    """
    # Resolve ``feature`` to a column index; names are looked up in feature_names.
    f = None
    # NOTE(review): NumPy integer indices are not ``int`` instances and take
    # the name-lookup branch below — confirm whether they should be accepted.
    if isinstance(feature, int):
        f = feature
    else:
        with contextlib.suppress(ValueError):
            f = self.get_explainer().feature_names.index(feature)
    if f is None:
        warnings.warn(f"Feature {feature} not found", stacklevel=2)
        return self
    if (
        self.get_explainer().categorical_features is not None
        and f in self.get_explainer().categorical_features
    ):
        warnings.warn(
            "Alternatives for all categorical features are already included", stacklevel=2
        )
        return self

    x_copy = np.array(self.x_test, copy=True)
    # The subclass decides whether the boundary forms a "<" or a ">" rule.
    is_lesser = self._is_lesser(rule_boundary, x_copy[f])
    new_rule = self.get_rules()
    rule = self._get_rule_str(is_lesser, f, rule_boundary)
    if np.any([new_rule["rule"][i] == rule for i in range(len(new_rule["rule"]))]):
        warnings.warn("Rule already included", stacklevel=2)
        return self

    # Accumulators for the batch of perturbed rows sent to the predictor.
    threshold = self.y_threshold
    perturbed_threshold = normalize_threshold(threshold)
    perturbed_bins = np.empty((0,)) if self.bin is not None else None
    perturbed_x = np.empty((0, self.get_explainer().num_features))
    perturbed_feature = np.empty((0, 4))  # (feature, instance, bin_index, is_lesser)
    perturbed_class = np.empty((0,), dtype=int)

    cal_x_f = self.get_explainer().x_cal[:, f]
    feature_values = np.unique(np.array(cal_x_f))
    sample_percentiles = self.get_explainer().sample_percentiles
    # ``values`` are sampled on the rule side of the boundary, ``covered`` on
    # the opposite side (boundary included).
    if is_lesser:
        if not np.any(feature_values < rule_boundary):
            warnings.warn(
                f"Lowest feature value for feature {feature} is {np.min(feature_values)}",
                stacklevel=2,
            )
            return self
        values = np.percentile(cal_x_f[cal_x_f < rule_boundary], sample_percentiles)
        covered = np.percentile(cal_x_f[cal_x_f >= rule_boundary], sample_percentiles)
    else:
        if not np.any(feature_values > rule_boundary):
            warnings.warn(
                f"Highest feature value for feature {feature} is {np.max(feature_values)}",
                stacklevel=2,
            )
            return self
        values = np.percentile(cal_x_f[cal_x_f > rule_boundary], sample_percentiles)
        covered = np.percentile(cal_x_f[cal_x_f <= rule_boundary], sample_percentiles)

    # Rule rows: the instance with feature ``f`` replaced by each sampled value.
    # The fourth entry of ``perturbed_feature`` (is_lesser) marks a rule row.
    for value in values:
        x_local = np.reshape(x_copy, (1, -1))
        x_local[0, f] = value
        perturbed_x = np.concatenate((perturbed_x, np.array(x_local)))
        perturbed_feature = np.concatenate((perturbed_feature, [(f, 0, None, is_lesser)]))
        perturbed_bins = (
            np.concatenate((perturbed_bins, self.bin)) if self.bin is not None else None
        )
        perturbed_class = np.concatenate(
            (perturbed_class, np.array([self.prediction["classes"]]))
        )
        # Tuple, None and unchanged scalar thresholds are reused as-is; other
        # thresholds are extended by concatenation for every added row.
        if isinstance(threshold, tuple):
            perturbed_threshold = threshold
        elif threshold is None:
            perturbed_threshold = None
        elif np.isscalar(perturbed_threshold) and perturbed_threshold == threshold:
            perturbed_threshold = threshold
        else:
            perturbed_threshold = np.concatenate((perturbed_threshold, threshold))
    # Covered rows: same construction, marked with ``None`` in the fourth entry.
    for value in covered:
        x_local = np.reshape(x_copy, (1, -1))
        x_local[0, f] = value
        perturbed_x = np.concatenate((perturbed_x, np.array(x_local)))
        perturbed_feature = np.concatenate((perturbed_feature, [(f, 0, None, None)]))
        perturbed_bins = (
            np.concatenate((perturbed_bins, self.bin)) if self.bin is not None else None
        )
        perturbed_class = np.concatenate(
            (perturbed_class, np.array([self.prediction["classes"]]))
        )
        if isinstance(threshold, tuple):
            perturbed_threshold = threshold
        elif threshold is None:
            perturbed_threshold = None
        elif np.isscalar(perturbed_threshold) and perturbed_threshold == threshold:
            perturbed_threshold = threshold
        else:
            perturbed_threshold = np.concatenate((perturbed_threshold, threshold))

    # Predict all perturbed rows in a single call.
    predict, low, high, _ = self.get_explainer().prediction_orchestrator.predict_internal(
        perturbed_x,
        threshold=perturbed_threshold,
        low_high_percentiles=self.calibrated_explanations.low_high_percentiles,
        classes=perturbed_class,
        bins=perturbed_bins,
    )
    # Split the outputs back into covered rows (marker None) and rule rows.
    instance_predict = [
        predict[i] for i in range(len(predict)) if perturbed_feature[i][3] is None
    ]
    rule_predict = [
        predict[i] for i in range(len(predict)) if perturbed_feature[i][3] is not None
    ]
    rule_low = [low[i] for i in range(len(low)) if perturbed_feature[i][3] is not None]
    rule_high = [high[i] for i in range(len(high)) if perturbed_feature[i][3] is not None]

    # skip if identical to original
    if self.prediction["low"] == safe_mean(rule_low) and self.prediction["high"] == safe_mean(
        rule_high
    ):
        warnings.warn(
            "The alternative explanation is identical to the original explanation",
            UserWarning,
            stacklevel=2,
        )
        return self

    # Append the new rule; the weight is the rule mean minus the covered mean.
    new_rule["predict"].append(safe_mean(rule_predict))
    new_rule["predict_low"].append(safe_mean(rule_low))
    new_rule["predict_high"].append(safe_mean(rule_high))
    new_rule["weight"].append(safe_mean(rule_predict) - safe_mean(instance_predict))
    # NOTE(review): ``rule_low`` / ``rule_high`` are lists here, so the
    # comparisons with ``-np.inf`` / ``np.inf`` below are always true and the
    # fallback branches look unreachable — confirm the intended check.
    new_rule["weight_low"].append(
        safe_mean(rule_low) - safe_mean(instance_predict) if rule_low != -np.inf else rule_low
    )
    new_rule["weight_high"].append(
        safe_mean(rule_high) - safe_mean(instance_predict) if rule_high != np.inf else rule_high
    )
    new_rule["value"].append(str(np.around(self.x_test[f], decimals=2)))
    new_rule["feature"].append(f)
    new_rule["sampled_values"].append(self.binned["rule_values"][f][0][0])
    new_rule["feature_value"].append(self.x_test[f])
    new_rule["is_conjunctive"].append(False)
    new_rule["rule"].append(rule)
    self.rules = new_rule
    return self
def _get_rule_str(self, is_lesser, feature, rule_boundary):
    """Format the textual rule for a numerical boundary.

    Parameters
    ----------
    is_lesser : bool
        Whether the rule is a "less than" condition.
    feature : int
        Index into the explainer's ``feature_names``.
    rule_boundary : float
        The rule boundary value, rendered with two decimals.

    Returns
    -------
    str
        ``"<name> < <boundary>"`` or ``"<name> > <boundary>"``.
    """
    names = self.get_explainer().feature_names
    if not is_lesser:
        return f"{names[feature]} > {rule_boundary:.2f}"
    return f"{names[feature]} < {rule_boundary:.2f}"
# pylint: disable=too-many-instance-attributes, too-many-locals, too-many-arguments
[docs]
class FactualExplanation(CalibratedExplanation):
"""Store and visualise calibrated factual explanations.
The public contract mirrors the CE definition established in the
classification and regression papers: a factual explanation couples the
calibrated prediction and its uncertainty interval with a collection of
factual feature rules. Each rule binds the observed feature value to a
condition and exposes the calibrated feature weight plus its uncertainty
interval. Downstream helpers (telemetry, JSON export, plots) rely on this
invariant, so the internal representation and helper payloads **must** keep
the prediction + interval pair alongside weight + interval information for
every factual rule.
For detailed information about the internal data structures and rule generation
process, see docs/foundations/concepts/explanation_structures.md.
"""
def __init__(
    self,
    calibrated_explanations,
    index,
    x,
    binned,
    feature_weights,
    feature_predict,
    prediction,
    y_threshold=None,
    instance_bin=None,
    condition_source: str = "prediction",
):
    """Class for storing and visualizing factual explanations.

    Provides factual explanations for a given instance, highlighting features that contribute to the model's prediction.

    Initialize a FactualExplanation instance.

    After the base initializer has run, the discretizer preconditions are
    checked, the factual rules are built, and, for non-regression
    explanations, the full probability matrix found under the
    ``"__full_probabilities__"`` key of the prediction is cached on
    ``prediction_probabilities``.

    Parameters
    ----------
    calibrated_explanations : CalibratedExplanations
        The parent CalibratedExplanations object.
    index : int
        The index of the instance being explained.
    x : array-like
        The test dataset containing the instances to be explained.
    binned : dict
        A mapping of binned feature values.
    feature_weights : dict
        A mapping of feature weights.
    feature_predict : dict
        A mapping of feature predictions.
    prediction : dict
        A mapping containing the prediction results.
    y_threshold : float or tuple, optional
        The threshold for binary classification or regression explanations.
    instance_bin : int, optional
        The bin index of the instance.
    condition_source : str, optional
        Forwarded unchanged to the base-class initializer (default
        ``"prediction"``).
    """
    super().__init__(
        calibrated_explanations,
        index,
        x,
        binned,
        feature_weights,
        feature_predict,
        prediction,
        y_threshold,
        instance_bin,
        condition_source=condition_source,
    )
    self._check_preconditions()
    # Build the factual rules eagerly.
    self.get_rules()
    # Cache per-instance prediction probabilities for golden baseline (classification)
    # NOTE(review): when is_regression() is true nothing in this block assigns
    # ``prediction_probabilities`` — confirm the base initializer provides it.
    try:
        if not self.is_regression():
            # Access stored full probability matrix via parent prediction mapping
            full_probs = self.prediction.get("__full_probabilities__")
            if full_probs is not None:
                # full_probs may be a tuple (proba_matrix, classes) for multiclass
                if isinstance(full_probs, tuple) and len(full_probs) >= 1:
                    proba_matrix = full_probs[0]
                else:
                    proba_matrix = full_probs
                # Attach whole matrix on first explanation, then propagate
                if self.index == 0:
                    self.prediction_probabilities = proba_matrix
                else:
                    # ensure earlier explanation already stored full matrix
                    self.prediction_probabilities = getattr(
                        self.calibrated_explanations.explanations[0],
                        "prediction_probabilities",
                        proba_matrix,
                    )
            else:
                self.prediction_probabilities = None
    except (
        Exception
    ):  # ADR002_ALLOW: prediction payloads vary across plugins.  # pragma: no cover
        # Best effort: any failure while reading the payload disables the cache.
        self.prediction_probabilities = None
def __repr__(self):
    """Render the prediction and every factual rule as a fixed-width text table.

    The rule rows come from ``_rules_with_impact`` so the text matches the
    ordering used by plots and narratives.
    """
    # Use canonical rules to ensure parity with plot and narrative
    impact_rules = self._rules_with_impact()
    pred = self.prediction
    header_lines = [
        f"{'Prediction':10} [{' Low':5}, {' High':5}]",
        f"{pred['predict']:5.3f} [{pred['low']:5.3f}, {pred['high']:5.3f}]",
        f"{'Value':6}: {'Feature':40s} {'Weight':6} [{' Low':6}, {' High':6}]",
    ]
    rule_lines = [
        f"{str(r.value):6}: {r.text:40s} {r.impact:>6.3f} [{r.weight_envelope_low:>6.3f}, {r.weight_envelope_high:>6.3f}]"
        for r in impact_rules
    ]
    return "\n".join(header_lines + rule_lines) + "\n"
[docs]
def build_rules_payload(self) -> Dict[str, Any]:
    """Build the structured payload describing the factual feature rules.

    Returns
    -------
    dict
        ``{"core": ..., "metadata": ...}``. ``core`` holds the prediction with
        its interval and one weight/condition entry per rule. ``metadata``
        holds per-rule uncertainty details, the baseline prediction (when
        available) and the instance-level prediction uncertainty. When no
        rules are available both ``feature_rules`` lists stay empty and the
        baseline and instance uncertainty entries are omitted.
    """
    to_number = CalibratedExplanation.to_python_number
    make_interval = CalibratedExplanation._build_interval

    rules = self.get_rules()
    core: Dict[str, Any] = {
        "kind": "factual",
        "prediction": {
            "value": to_number(self.prediction.get("predict")),
            "uncertainty_interval": make_interval(
                self.prediction.get("low"),
                self.prediction.get("high"),
            ),
        },
        "feature_rules": [],
    }
    metadata: Dict[str, Any] = {"feature_rules": []}
    payload = {"core": core, "metadata": metadata}
    if not rules or "rule" not in rules:
        return payload

    # Baseline (first entry of the base prediction lists), when one exists.
    base_predict = rules.get("base_predict", [None])
    base_low = rules.get("base_predict_low", [None])
    base_high = rules.get("base_predict_high", [None])
    baseline_value = to_number(base_predict[0])
    if baseline_value is not None:
        metadata["baseline_prediction"] = baseline_value
        metadata["baseline_interval"] = make_interval(base_low[0], base_high[0])

    # Probabilistic explanations use the "venn_abers" representation; the
    # others report percentile-based intervals.
    probabilistic = self.is_probabilistic()
    percentiles = None if probabilistic else self._get_percentiles()
    representation = "venn_abers" if probabilistic else "percentile"
    weights_use_percentiles = representation == "percentile"

    for idx in range(len(rules.get("rule", []))):
        feature_index = rules["feature"][idx]
        rule_text = rules["rule"][idx]
        condition = self._build_condition_payload(
            feature_index,
            rule_text,
            rules["feature_value"][idx],
            rules["value"][idx],
        )
        weight_value = to_number(rules["weight"][idx])
        core["feature_rules"].append(
            {
                "weight": {
                    "value": weight_value,
                    "uncertainty_interval": make_interval(
                        rules["weight_low"][idx],
                        rules["weight_high"][idx],
                    ),
                },
                "condition": condition,
            }
        )

        weight_uncertainty = self._build_uncertainty_payload(
            value=weight_value,
            low=rules["weight_low"][idx],
            high=rules["weight_high"][idx],
            representation=representation,
            percentiles=percentiles if weights_use_percentiles else None,
            include_percentiles=weights_use_percentiles,
        )

        # Thresholded explanations override the representation of predictions.
        thresholded = self.is_thresholded()
        prediction_representation = "threshold" if thresholded else representation
        prediction_uses_percentiles = (
            prediction_representation == "percentile" and not thresholded
        )
        prediction_uncertainty = self._build_uncertainty_payload(
            value=rules["predict"][idx],
            low=rules["predict_low"][idx],
            high=rules["predict_high"][idx],
            representation=prediction_representation,
            percentiles=percentiles if prediction_uses_percentiles else None,
            threshold=self.normalize_threshold_value() if thresholded else None,
            include_percentiles=prediction_uses_percentiles,
        )

        metadata["feature_rules"].append(
            {
                "feature": self._safe_feature_name(feature_index),
                "feature_index": to_number(feature_index),
                "weight_uncertainty": weight_uncertainty,
                "prediction_uncertainty": prediction_uncertainty,
                "prediction_value": to_number(rules["predict"][idx]),
                "condition_text": rule_text,
                "instance_value": to_number(rules["feature_value"][idx]),
            }
        )

    metadata["prediction_uncertainty"] = self._build_instance_uncertainty()
    return payload
def _check_preconditions(self):
    """Warn when the active discretizer is not the one recommended for factual explanations.

    Regression explanations expect a ``BinaryRegressorDiscretizer``; all other
    explanations expect a ``BinaryEntropyDiscretizer``.
    """
    if self.is_regression():
        recommended = BinaryRegressorDiscretizer
        message = (
            "Factual explanations for regression recommend using the binaryRegressor "
            + "discretizer. Consider extracting factual explanations using "
            + "`explainer.explain_factual(test_set)`"
        )
    else:
        recommended = BinaryEntropyDiscretizer
        message = (
            "Factual explanations for classification recommend using the "
            + "binaryEntropy discretizer. Consider extracting factual "
            + "explanations using `explainer.explain_factual(test_set)`"
        )
    if isinstance(self.get_explainer().discretizer, recommended):
        return
    warnings.warn(message, stacklevel=2)
def _rules_with_impact(
    self, *, top_k: Optional[int] = None, sort: bool = True
) -> list[RuleWithImpact]:
    """Convert the factual rules into ``RuleWithImpact`` records.

    This method is the source of truth for narrative and plotting consistency.
    The impact of a factual rule is its weight, and the sign of the weight
    decides the direction ("positive", "negative" or "neutral").

    Parameters
    ----------
    top_k : int, optional
        Keep only the first ``top_k`` records (applied after sorting).
    sort : bool, optional
        Order records by decreasing absolute impact, ties broken by rule text.
    """
    rules_dict = self.get_rules()
    # Every record carries the instance-level prediction.
    prediction = self.prediction["predict"]
    records = []

    for i in range(len(rules_dict.get("rule", []))):
        w = rules_dict["weight"][i]
        # Canonical sign behavior:
        # Positive impact = rule increased the prediction relative to base
        # Negative impact = rule decreased the prediction relative to base
        direction = "positive" if w > 0 else ("negative" if w < 0 else "neutral")

        feature_id = rules_dict["feature"][i]
        if isinstance(feature_id, (list, tuple, np.ndarray)):
            # Several features: join the names that can be resolved.
            resolved = []
            for member in list(feature_id):
                with contextlib.suppress(Exception):
                    resolved.append(str(self._safe_feature_name(member)))
            feature_label = " & ".join(resolved) if resolved else str(feature_id)
        else:
            feature_label = str(self._safe_feature_name(feature_id))

        # Per-rule base prediction, falling back to the first entry, then None.
        base_predicts = rules_dict.get("base_predict", [])
        if i < len(base_predicts):
            base_value = base_predicts[i]
        elif base_predicts:
            base_value = base_predicts[0]
        else:
            base_value = None

        records.append(
            RuleWithImpact(
                rule_id=str(feature_id),  # Using feature index/list as ID for now
                feature=feature_label,
                text=rules_dict["rule"][i],
                impact=float(w),
                direction=direction,
                base_predict=float("nan") if base_value is None else float(base_value),
                predict=float(prediction),
                value=rules_dict["value"][i],
                weight_envelope_low=float(rules_dict["weight_low"][i]),
                weight_envelope_high=float(rules_dict["weight_high"][i]),
                predict_low=float(rules_dict["predict_low"][i]),
                predict_high=float(rules_dict["predict_high"][i]),
            )
        )

    if sort:
        records = sorted(records, key=lambda r: (-abs(r.impact), r.text))
    if top_k is not None:
        records = records[:top_k]
    return records
# -- Convenience retrieval API -------------------------------------------------
def _factual_rule_to_dict(self, rules: Dict[str, Any], idx: int) -> Dict[str, Any]:
    """Describe the rule at position ``idx`` as a plain dictionary.

    Returns
    -------
    dict
        Keys ``index``, ``feature``, ``condition``, ``weight``,
        ``uncertainty_interval`` and ``support``. ``support`` is ``None``
        unless ``rules["support"]`` is a list.
    """
    weight_value = CalibratedExplanation.to_python_number(rules["weight"][idx])
    weight_interval = CalibratedExplanation._build_interval(
        rules["weight_low"][idx], rules["weight_high"][idx]
    )
    feature_name = self._safe_feature_name(rules["feature"][idx])

    support_values = rules.get("support")
    support = support_values[idx] if isinstance(support_values, list) else None

    entry: Dict[str, Any] = {"index": int(idx)}
    entry["feature"] = feature_name
    entry["condition"] = rules["rule"][idx]
    entry["weight"] = weight_value
    entry["uncertainty_interval"] = weight_interval
    entry["support"] = support
    return entry
[docs]
def get_rule_by_index(self, index: int) -> Dict[str, Any]:
    """Look up one factual rule by position.

    Raises
    ------
    IndexError
        If ``index`` is negative or not smaller than the number of rules.
    """
    rules = self.get_rules()
    total = len(rules.get("rule", []))
    if not 0 <= index < total:
        raise IndexError(f"No factual rule at index {index}")
    return self._factual_rule_to_dict(rules, index)
[docs]
def get_rules_by_feature(self, feature: str) -> list[Dict[str, Any]]:
    """Collect every factual rule whose feature name equals ``feature``.

    Raises
    ------
    KeyError
        If no rule matches the given feature name.
    """
    rules = self.get_rules()
    found: list[Dict[str, Any]] = []
    for position, feature_id in enumerate(rules.get("feature", [])):
        try:
            name = self._safe_feature_name(feature_id)
        except (TypeError, ValueError):
            # Identifiers that cannot be resolved are compared by their text.
            name = str(feature_id)
        if name != feature:
            continue
        found.append(self._factual_rule_to_dict(rules, position))
    if found:
        return found
    raise KeyError(f"No factual rules for feature {feature}")
[docs]
def list_rules(self) -> list[Dict[str, Any]]:
    """List every factual rule as a normalized dictionary, in rule order."""
    rules = self.get_rules()
    listing: list[Dict[str, Any]] = []
    for position in range(len(rules.get("rule", []))):
        listing.append(self._factual_rule_to_dict(rules, position))
    return listing
[docs]
def get_rules(self):
    """
    Create (or return the cached) factual rules.

    Conjunctive rules take precedence when they have been generated, followed
    by previously built factual rules. Otherwise one rule is created for every
    feature that is not ignored and whose feature-level prediction differs
    from the instance prediction; the result is cached on ``self.rules``.

    Returns
    -------
    dict
        The rule state, as returned by ``ConjunctionState.get_state()``.
    """
    if getattr(self, "has_conjunctive_rules", False):
        conjunctive = getattr(self, "conjunctive_rules", None)
        if conjunctive is not None:
            return conjunctive
    if getattr(self, "has_rules", False) and isinstance(self.rules, dict):
        return self.rules

    instance = np.array(self.x_test, copy=True)
    state_helper = ConjunctionState(None)
    state_helper.state["classes"] = self.prediction["classes"]
    state_helper.set_base_prediction(
        self.prediction["predict"], self.prediction["low"], self.prediction["high"]
    )
    rules = self.define_conditions()
    ignored = self.ignored_features_for_instance()

    for f in range(len(instance)):  # pylint: disable=invalid-name
        if f in ignored:
            continue
        # A feature whose prediction equals the instance prediction adds no rule.
        if self.prediction["predict"] == self.feature_predict["predict"][f]:
            continue

        # Textual feature value: label or raw value for categorical features,
        # rounded value otherwise.
        explainer = self.get_explainer()
        if f not in explainer.categorical_features:
            value_str = str(np.around(instance[f], decimals=2))
        elif self.get_explainer().categorical_labels is not None:
            value_str = self.get_explainer().categorical_labels[f][int(instance[f])]
        else:
            value_str = str(instance[f])

        # Calculate weights matching legacy behavior (prediction - instance_predict)
        # and using instance_predict uncertainty
        instance_predict = self.feature_predict["predict"][f]
        instance_low = self.feature_predict["low"][f]
        instance_high = self.feature_predict["high"][f]
        prediction = self.prediction["predict"]
        w = prediction - instance_predict
        # Unbounded feature-level intervals give unbounded weight bounds.
        if instance_high != np.inf:
            w_low = prediction - instance_high
        else:
            w_low = -np.inf
        if instance_low != -np.inf:
            w_high = prediction - instance_low
        else:
            w_high = np.inf

        state_helper.add_rule(
            predict=self.prediction["predict"],
            low=self.prediction["low"],
            high=self.prediction["high"],
            base_predict=instance_predict,
            value=value_str,
            feature=f,
            sampled_values=self.binned["rule_values"][f][0][-1],
            feature_value=self.x_test[f],
            rule_text=rules[f],
            is_conjunctive=False,
            weight=w,
            weight_low=w_low,
            weight_high=w_high,
        )

    self.rules = state_helper.get_state()
    self.has_rules = True
    return self.rules
# pylint: disable=too-many-locals, too-many-branches, too-many-statements
[docs]
def add_conjunctions(self, n_top_features=5, max_rule_size=2, **kwargs):
"""Add conjunctive factual rules.
Builds conjunctions by iterating over feature combinations from
``current_size = 2`` up to ``max_rule_size``. For each size the
method ranks existing rules, then pairs every outer-loop feature
(from the original factual rules) with top-ranked inner-loop
candidates stored in ``ConjunctionState``. Duplicate feature
combinations are deduplicated and predictions are obtained via
:meth:`predict_conjunctive`.
Parameters
----------
n_top_features : int, optional
Number of top features to consider when ranking candidates.
max_rule_size : int, optional
Maximum number of features in a single conjunctive rule.
Other Parameters
----------------
_use_batched : bool
Use vectorized batch prediction (default ``True``).
_limit_outer_to_ranked : bool
Restrict outer loop to rank-filtered features (default ``False``).
_dedupe_by_feature_only : bool
Deduplicate by feature index only, ignoring sampled values
(default ``True``).
raise_on_predict_error : bool
Re-raise prediction exceptions instead of silently skipping
(default ``False``).
_fallback_to_legacy_on_zero : bool
Fall back to the legacy implementation when no conjunctions are
created (default ``False``).
Returns
-------
self : :class:`.FactualExplanation`
Returns a self reference, to allow for method chaining.
Notes
-----
After the call, ``self.conjunction_stats`` contains a diagnostic
dict with keys ``"attempts"``, ``"created"``,
``"skipped"`` (:class:`~collections.Counter`), and
``"predict_errors"`` (list of up to 5 captured exception messages).
"""
use_batched = kwargs.get("_use_batched", True)
limit_outer_to_ranked = kwargs.get("_limit_outer_to_ranked", False)
dedupe_by_feature_only = kwargs.get("_dedupe_by_feature_only", True)
raise_on_predict_error = kwargs.get("raise_on_predict_error", False)
if max_rule_size >= 4 and not use_batched:
from ..utils.exceptions import ConfigurationError
raise ConfigurationError(
"max_rule_size >= 4 requires batched execution (internal flag _use_batched=True)",
details={
"param": "max_rule_size",
"value": max_rule_size,
"valid_range": [2, 3],
},
)
if max_rule_size < 2:
return self
if not use_batched:
from .legacy_conjunctions import add_conjunctions_factual_legacy
add_conjunctions_factual_legacy(
self, n_top_features=n_top_features, max_rule_size=max_rule_size
)
return self
factual = self.get_rules() if not self.has_rules else self.rules
state_helper = ConjunctionState(
self.conjunctive_rules
if self.has_conjunctive_rules and self.conjunctive_rules is not None
else factual,
dedupe_by_feature_only=dedupe_by_feature_only,
)
self.has_conjunctive_rules = False
self.conjunctive_rules = []
threshold = None if self.y_threshold is None else self.y_threshold
scratch = np.array(self.x_test, copy=True)
predicted_class = factual["classes"]
state_helper.state["classes"] = predicted_class
base_weight_array = (
np.asarray(factual["weight"], dtype=float) if factual["weight"] else np.array([])
)
base_width_array = (
np.asarray(factual["weight_high"], dtype=float)
- np.asarray(factual["weight_low"], dtype=float)
if factual["weight"]
else np.array([])
)
if n_top_features is None:
n_top_features = len(factual["rule"])
def _feature_length(candidate: Any) -> int:
if isinstance(candidate, (list, tuple, np.ndarray)):
return len(candidate)
return 1
def _coerce_feature_scalar(value: Any) -> int:
if isinstance(value, (list, tuple, np.ndarray)):
return int(np.asarray(value).ravel()[0])
return int(value)
from collections import Counter
self.conjunction_stats = {
"attempts": 0,
"created": 0,
"skipped": Counter(),
"predict_errors": [],
}
stats = self.conjunction_stats
max_logged_errors = 5
def _summarize(arr):
if arr is None or arr.size == 0:
return {"size": 0}
return {
"size": int(arr.size),
"min": float(np.nanmin(arr)),
"max": float(np.nanmax(arr)),
"nan": int(np.isnan(arr).sum()),
}
for current_size in range(2, max_rule_size + 1):
num_rules = len(factual["rule"])
if num_rules == 0:
break
weights_array = state_helper.get_weights()
width_array = state_helper.get_widths()
top_n = min(num_rules, n_top_features)
top_conjunctives = list(
self.rank_features(
weights_array,
width=width_array,
num_to_show=max(2, top_n) if num_rules >= 2 else top_n,
)
)
if not top_conjunctives and num_rules > 0:
top_conjunctives = list(range(num_rules))
# Determine outer loop candidates
# Legacy behavior: iterate all features
if limit_outer_to_ranked:
num_outer = min(num_rules, n_top_features)
outer_indices = list(
self.rank_features(
base_weight_array,
width=base_width_array if base_width_array.size else None,
num_to_show=max(2, num_outer) if num_rules >= 2 else num_outer,
)
)
if not outer_indices and num_rules > 0:
outer_indices = list(range(num_rules))
else:
outer_indices = range(len(factual["feature"]))
for f1 in outer_indices:
of1 = int(factual["feature"][f1])
sampled_values1 = factual["sampled_values"][f1]
rule_value1 = (
sampled_values1
if isinstance(sampled_values1, np.ndarray)
else [sampled_values1]
)
for cf2 in top_conjunctives:
stats["attempts"] += 1
rule_values = [rule_value1]
original_features = [of1]
of2 = state_helper.get_feature(cf2)
target_length = current_size - 1
if _feature_length(of2) != target_length:
stats["skipped"]["len_mismatch"] += 1
continue
if state_helper.is_conjunctive(cf2):
if of1 in of2:
stats["skipped"]["same_feature"] += 1
continue
original_features.extend(int(v) for v in of2)
rule_values.extend(list(state_helper.get_sampled_values(cf2)))
else:
of2 = _coerce_feature_scalar(of2)
if of1 == of2:
stats["skipped"]["same_feature"] += 1
continue
original_features.append(of2)
sampled_values2 = state_helper.get_sampled_values(cf2)
rule_values.append(
sampled_values2
if isinstance(sampled_values2, np.ndarray)
else [sampled_values2]
)
if state_helper.has_combination_key(original_features, rule_values):
stats["skipped"]["duplicate_combo"] += 1
continue
state_helper.register_combination_key(original_features, rule_values)
try:
rule_predict, rule_low, rule_high = self.predict_conjunctive(
rule_values,
original_features,
scratch,
threshold,
predicted_class,
bins=self.bin,
use_batched=use_batched,
)
except (
CalibratedError,
ValueError,
TypeError,
RuntimeError,
Exception, # adr002_allow - defensive guard for predict_conjunctive failures
) as e:
if raise_on_predict_error:
raise
stats["skipped"]["predict_error"] += 1
if len(stats["predict_errors"]) < max_logged_errors:
stats["predict_errors"].append(f"{type(e).__name__}: {e}")
continue
state_helper.add_rule(
predict=rule_predict,
low=rule_low,
high=rule_high,
base_predict=self.prediction["predict"],
value=factual["value"][f1] + "\n" + state_helper.get_value(cf2),
feature=list(original_features),
sampled_values=list(rule_values),
feature_value=None,
rule_text=factual["rule"][f1] + " & \n" + state_helper.get_rule(cf2),
)
stats["created"] += 1
self.conjunctive_rules = state_helper.get_state()
self.has_conjunctive_rules = True
if stats["created"] == 0 and stats["attempts"] > 0:
summary_weights = _summarize(base_weight_array)
import warnings
err_msg = (
f" predict_errors={stats['predict_errors']}" if stats["predict_errors"] else ""
)
warning_msg = (
f"add_conjunctions: created={stats['created']} "
f"attempts={stats['attempts']} skipped={dict(stats['skipped'])} "
f"weights={summary_weights}{err_msg}"
)
if kwargs.get("_fallback_to_legacy_on_zero", False):
warnings.warn(warning_msg + " (falling back to legacy)", UserWarning, stacklevel=2)
from .legacy_conjunctions import add_conjunctions_factual_legacy
return add_conjunctions_factual_legacy(
self, n_top_features=n_top_features, max_rule_size=max_rule_size
)
if kwargs.get("verbose", False):
warnings.warn(warning_msg, UserWarning, stacklevel=2)
return self
def _is_lesser(self, rule_boundary, instance_value):
    """Tell whether ``instance_value`` lies strictly below ``rule_boundary``."""
    below_boundary = instance_value < rule_boundary
    return below_boundary
[docs]
def plot(self, filter_top=None, **kwargs):
    """Plot the factual explanation for a given instance.

    The rules are ranked, the plotting options are normalised, and rendering
    is delegated to ``plot_probabilistic`` (classification mode or
    thresholded explanations) or ``plot_regression`` (otherwise).

    Parameters
    ----------
    filter_top : int, optional
        The number of top features to display. ``None`` means all rules;
        the value is capped at the number of available rules.
    **kwargs : dict
        Additional plotting arguments:

        - show (bool): default=True if filename is empty, False otherwise.
          Determines whether the plot should be displayed or not.
        - filename (str): default=''. The full path and filename of the plot
          image file that will be saved. If not provided or empty, the plot
          will not be saved.
        - uncertainty (bool): default=False. Whether to plot the uncertainty
          intervals for the feature weights.
        - style (str): default='regular'. The style of the plot. Possible
          styles:

          * 'regular' - a regular plot with feature weights and uncertainty
            intervals.
          * 'narrative' - generate human-readable narrative explanations.
        - rnk_metric (str): default='feature_weight'. The metric used to
          rank the features. Supported metrics are 'ensured',
          'feature_weight', and 'uncertainty'. ``None`` is treated as
          'feature_weight'; 'uncertainty' is mapped to 'ensured' with
          ``rnk_weight`` forced to 1.0.
        - rnk_weight (float): default=0.5. The weight of the uncertainty in
          the ranking. Used with the 'ensured' ranking metric.
        - style_override: default=None. Forwarded to the renderer.
        - use_legacy (bool): default=None, which resolves to ``True``. It is
          forced to ``False`` when ``return_plot_spec`` is truthy or when
          ``style`` is a string other than 'regular', 'triangular',
          'ensured' or 'narrative'.

        Any remaining keyword arguments are forwarded to the renderer.

    Returns
    -------
    object or None
        Whatever the selected renderer returns, or ``None`` when there are
        no rules to plot (a warning is emitted in that case).

    Raises
    ------
    Warning
        If ``uncertainty`` is requested for a one-sided explanation.
    ConfigurationError
        If the renderer raises a ``RuntimeError`` whose message mentions
        ``"Agg"``. Any other renderer exception is re-raised unchanged.
    """
    requested_style = kwargs.get("style")
    # A string style outside the built-in set is treated as a custom style.
    custom_plot_style = isinstance(requested_style, str) and requested_style not in {
        "regular",
        "triangular",
        "ensured",
        "narrative",
    }
    # Ensure style_override gets passed through
    style_override = kwargs.pop("style_override", None)
    plot_use_legacy = kwargs.pop("use_legacy", None)
    # PlotSpec request forces new renderer
    if kwargs.get("return_plot_spec") or custom_plot_style:
        plot_use_legacy = False
    # Phase 2 Option B: Default to legacy to ensure parity until PlotSpec is fully hardened
    elif plot_use_legacy is None:
        plot_use_legacy = True
    # Options consumed here are popped so that they are not forwarded twice
    # through **kwargs to the renderer.
    filename = kwargs.pop("filename", "")
    show = kwargs.pop("show", filename == "")
    uncertainty = kwargs.pop("uncertainty", False)
    rnk_metric = kwargs.pop("rnk_metric", "feature_weight")
    if rnk_metric is None:
        rnk_metric = "feature_weight"
    rnk_weight = kwargs.pop("rnk_weight", 0.5)
    # "uncertainty" is an alias for the "ensured" metric with full weight.
    if rnk_metric == "uncertainty":
        rnk_weight = 1.0
        rnk_metric = "ensured"
    # Consistency guard: one-sided intervals cannot show uncertainty bands
    if uncertainty and self.is_one_sided():
        raise Warning("Interval plot is not supported for one-sided explanations.")
    # Use conjunctive rules when available so that conjunctions appear in plots
    if getattr(self, "has_conjunctive_rules", False) and getattr(
        self, "conjunctive_rules", None
    ):
        factual = self.conjunctive_rules
    else:
        factual = self.get_rules()  # get_explanation(index)
    self._check_preconditions()
    predict = self.prediction
    num_features_to_show = len(factual["weight"])
    if filter_top is None:
        filter_top = num_features_to_show
    # Never ask for more rules than exist.
    filter_top = np.min([num_features_to_show, filter_top])
    if filter_top <= 0:
        warnings.warn(
            f"The explanation has no rules to plot. The index of the instance is {self.index}",
            stacklevel=2,
        )
        return
    if len(filename) > 0:
        path, filename, title, ext = prepare_for_saving(filename)
        # Saved plots are placed under a "plots/" prefix.
        path = f"plots/{path}"
        save_ext = [ext]
    else:
        path = ""
        title = ""
        save_ext = []
    if uncertainty:
        feature_weights = {
            "predict": factual["weight"],
            "low": factual["weight_low"],
            "high": factual["weight_high"],
        }
        # Phase 2: Inject canonical color roles if using PlotSpec
        if not plot_use_legacy:
            # Derive canonical directions from authoritative source
            canonical_rules = self._rules_with_impact(sort=False)
            feature_weights["color_role"] = [r.direction for r in canonical_rules]
    else:
        if not plot_use_legacy:
            # Wrap in dict to pass color info to PlotSpec
            canonical_rules = self._rules_with_impact(sort=False)
            feature_weights = {
                "predict": factual["weight"],
                "color_role": [r.direction for r in canonical_rules],
            }
        else:
            # Legacy renderer without intervals receives the bare weights.
            feature_weights = factual["weight"]
    # Per-rule width of the weight interval, flattened to one value per rule.
    width = np.reshape(
        np.array(factual["weight_high"]) - np.array(factual["weight_low"]),
        (len(factual["weight"])),
    )
    if rnk_metric == "feature_weight":
        features_to_plot = self.rank_features(
            factual["weight"], width=width, num_to_show=filter_top
        )
    else:
        # Any other metric is scored by calculate_metrics from the prediction
        # interval widths, and the resulting score drives the ranking.
        ranking = calculate_metrics(
            uncertainty=[
                factual["predict_high"][i] - factual["predict_low"][i]
                for i in range(len(factual["weight"]))
            ],
            prediction=factual["predict"],
            w=rnk_weight,
            metric=rnk_metric,
        )
        features_to_plot = self.rank_features(width=ranking, num_to_show=filter_top)
    # Prefer explicit feature/column names when available; fall back to rule strings
    column_names = (
        factual.get("feature_names") or factual.get("column_names") or factual.get("rule")
    )
    try:
        if "classification" in self.get_explainer().mode or self.is_thresholded():
            return plot_probabilistic(
                self,
                factual["value"],
                predict,
                feature_weights,
                features_to_plot,
                filter_top,
                column_names,
                title=title,
                path=path,
                interval=uncertainty,
                show=show,
                idx=self.index,
                save_ext=save_ext,
                style_override=style_override,
                use_legacy=plot_use_legacy,
                **kwargs,
            )
        else:
            return plot_regression(
                self,
                factual["value"],
                predict,
                feature_weights,
                features_to_plot,
                filter_top,
                column_names,
                title=title,
                path=path,
                interval=uncertainty,
                show=show,
                idx=self.index,
                save_ext=save_ext,
                style_override=style_override,
                use_legacy=plot_use_legacy,
                **kwargs,
            )
    except (
        Exception
    ) as e:  # ADR002_ALLOW: plot renderers may fail on headless hosts.  # pragma: no cover
        # A RuntimeError mentioning "Agg" is translated into a ConfigurationError
        # that tells the caller how to work around the non-interactive backend.
        if isinstance(e, RuntimeError) and "Agg" in str(e):
            from ..utils.exceptions import ConfigurationError

            raise ConfigurationError(
                "Matplotlib backend 'Agg' does not support show(). "
                "Either set show=False or switch to a different backend.",
                details={
                    "backend": "Agg",
                    "operation": "show()",
                    "solution": "set show=False or use interactive backend",
                },
            ) from e
        raise  # core-only test runs do not fail when visualization extras are
        # unavailable. Tests that require viz should use pytest.importorskip.
        # NOTE(review): the two statements below follow an unconditional bare
        # ``raise`` and therefore look unreachable; the comment above reads like
        # the remains of an earlier "warn and return None" fallback. Confirm the
        # intended behaviour before relying on either path.
        warnings.warn(
            f"Plotting unavailable: {e}",
            UserWarning,
            stacklevel=2,
        )
        return None
[docs]
class AlternativeExplanation(CalibratedExplanation):
"""Store and visualise calibrated alternative explanations.
Consistent with the CE papers, alternative explanations surface a collection
of alternative feature rules. Each rule pairs an alternative condition for
the feature with the calibrated prediction estimate and its uncertainty
interval for that scenario. Feature-weight deltas are retained internally
for ranking and metadata, but the user-facing payload **must not** replace
the prediction + interval pair with weights—the prediction interval is the
authoritative quantity for each alternative rule.
For detailed information about the internal data structures and rule generation
process, see docs/foundations/concepts/explanation_structures.md.
"""
def __init__(
    self,
    calibrated_explanations,
    index,
    x,
    binned,
    feature_weights,
    feature_predict,
    prediction,
    y_threshold=None,
    instance_bin=None,
    condition_source: str = "prediction",
):
    """Build the alternative explanation for one instance.

    An alternative explanation explores how changes to feature values could
    alter the model's prediction. Construction delegates to the base class,
    runs the discretizer precondition check, generates the rules eagerly and
    clears the super/semi/counter flags.

    Parameters
    ----------
    calibrated_explanations : CalibratedExplanations
        The parent CalibratedExplanations object.
    index : int
        The index of the instance being explained.
    x : array-like
        The test dataset containing the instances to be explained.
    binned : dict
        A mapping of binned feature values.
    feature_weights : dict
        A mapping of feature weights.
    feature_predict : dict
        A mapping of feature predictions.
    prediction : dict
        A mapping containing the prediction results.
    y_threshold : float or tuple, optional
        The threshold for binary classification or regression explanations.
    instance_bin : int, optional
        The bin index of the instance.
    condition_source : str, default="prediction"
        Forwarded unchanged to the base-class initializer.
    """
    positional = (
        calibrated_explanations,
        index,
        x,
        binned,
        feature_weights,
        feature_predict,
        prediction,
        y_threshold,
        instance_bin,
    )
    super().__init__(*positional, condition_source=condition_source)
    self._check_preconditions()
    # Reset the cache flag so the following call really builds the rules.
    self.has_rules = False
    self.get_rules()
    # A freshly built explanation is neither super, semi nor counter.
    self.__is_super_explanation = (
        self.__is_semi_explanation
    ) = self.__is_counter_explanation = False
def __repr__(self):
    """Render the alternative rules as a fixed-width text table.

    The first two lines show the base prediction and its interval; the
    remaining lines list one rule each, walking the ``rank_features``
    ordering in reverse.
    """
    rules = self.get_rules()
    lines = [
        f"{'Prediction':10} [{' Low':5}, {' High':5}]",
        f"{rules['base_predict'][0]:5.3f} [{rules['base_predict_low'][0]:5.3f}, {rules['base_predict_high'][0]:5.3f}]",
        f"{'Value':6}: {'Feature':40s} {'Prediction':10} [{' Low':6}, {' High':6}]",
    ]
    ranking = self.rank_features(
        rules["weight"],
        width=np.array(rules["weight_high"]) - np.array(rules["weight_low"]),
        num_to_show=len(rules["rule"]),
    )
    for pos in reversed(ranking):
        lines.append(
            f"{rules['value'][pos]:6}: {rules['rule'][pos]:40s} {rules['predict'][pos]:>6.3f} [{rules['predict_low'][pos]:>6.3f}, {rules['predict_high'][pos]:>6.3f}]"
        )
    return "\n".join(lines) + "\n"
def _rules_with_impact(
    self, *, top_k: Optional[int] = None, sort: bool = True
) -> list[RuleWithImpact]:
    """Convert the alternative rules into ``RuleWithImpact`` records.

    The impact of a rule is its stored weight, which for alternative
    explanations is the alternative prediction minus the instance prediction.
    The direction is "positive" for a weight above zero, "negative" below
    zero and "neutral" otherwise.

    Parameters
    ----------
    top_k : int, optional
        When given, only the first ``top_k`` records are returned (after
        sorting, if enabled).
    sort : bool, default=True
        Order the records by decreasing absolute impact, breaking ties on
        the rule text.

    Returns
    -------
    list[RuleWithImpact]
        One record per rule.
    """
    rules = self.get_rules()

    # The instance prediction may be wrapped in a sequence; use its first
    # element, or 0.0 when the sequence is empty.
    baseline = self.prediction["predict"]
    if isinstance(baseline, (list, tuple, np.ndarray)):
        baseline = baseline[0] if len(baseline) else 0.0

    def _describe(feature_index: Any) -> tuple[str, str]:
        """Return ``(rule_id, display_name)`` for a single or combined feature."""
        if not isinstance(feature_index, (list, tuple, np.ndarray)):
            return str(feature_index), self._safe_feature_name(feature_index)
        ids = []
        labels = []
        for member in feature_index:
            try:
                ids.append(str(int(member)))
            except (TypeError, ValueError):  # ADR002_ALLOW: tolerate non-numeric ids.
                ids.append(str(member))
            labels.append(self._safe_feature_name(member))
        return ",".join(ids), " & ".join(labels)

    def _direction(weight) -> str:
        """Map the sign of ``weight`` to its canonical direction label."""
        if weight > 0:
            return "positive"
        if weight < 0:
            return "negative"
        return "neutral"

    collected = []
    for position in range(len(rules.get("rule", []))):
        weight = rules["weight"][position]
        direction = _direction(weight)
        rule_id, feature_name = _describe(rules["feature"][position])
        collected.append(
            RuleWithImpact(
                rule_id=rule_id,
                feature=feature_name,
                text=rules["rule"][position],
                impact=float(weight),
                direction=direction,
                base_predict=float(baseline),
                predict=float(rules["predict"][position]),  # alternative prediction
                value=rules["value"][position],
                weight_envelope_low=float(rules["weight_low"][position]),
                weight_envelope_high=float(rules["weight_high"][position]),
                predict_low=float(rules["predict_low"][position]),
                predict_high=float(rules["predict_high"][position]),
            )
        )

    if sort:
        # Largest absolute impact first; rule text breaks ties.
        collected.sort(key=lambda record: (-abs(record.impact), record.text))
    if top_k is not None:
        collected = collected[:top_k]
    return collected
[docs]
def build_rules_payload(self) -> Dict[str, Any]:
    """Assemble the structured payload for the alternative feature rules.

    Returns
    -------
    dict
        A mapping with two entries. ``"core"`` holds, per rule, the
        alternative prediction with its uncertainty interval and the rule
        condition. ``"metadata"`` holds, per rule, the feature name and
        index, prediction and weight values with their uncertainty payloads,
        the condition text and the instance/alternative values, plus the
        instance-level prediction uncertainty. When there are no rules both
        ``feature_rules`` lists are empty and no instance-level entry is
        added.
    """
    rules = self.get_rules()
    core: Dict[str, Any] = {"kind": "alternative", "feature_rules": []}
    metadata: Dict[str, Any] = {"feature_rules": []}
    payload = {"core": core, "metadata": metadata}
    if not rules or "rule" not in rules:
        return payload

    # Percentiles only apply to plain (non-probabilistic, non-thresholded) output.
    percentiles = None
    if not (self.is_probabilistic() or self.is_thresholded()):
        percentiles = self._get_percentiles()

    if self.is_thresholded():
        prediction_representation = "threshold"
    elif self.is_probabilistic():
        prediction_representation = "venn_abers"
    else:
        prediction_representation = "percentile"
    weight_representation = "venn_abers" if self.is_probabilistic() else "percentile"

    prediction_uses_percentiles = prediction_representation == "percentile"
    weight_uses_percentiles = weight_representation == "percentile"
    to_number = CalibratedExplanation.to_python_number

    for idx in range(len(rules.get("rule", []))):
        feature_index = rules["feature"][idx]
        condition = self._build_condition_payload(
            feature_index,
            rules["rule"][idx],
            rules["sampled_values"][idx],
            rules["value"][idx],
        )
        prediction_value = to_number(rules["predict"][idx])
        prediction_interval = CalibratedExplanation._build_interval(
            rules["predict_low"][idx],
            rules["predict_high"][idx],
        )
        # User-facing part: prediction + interval paired with the condition.
        core["feature_rules"].append(
            {
                "prediction": {
                    "value": prediction_value,
                    "uncertainty_interval": prediction_interval,
                },
                "condition": condition,
            }
        )

        prediction_uncertainty = self._build_uncertainty_payload(
            value=rules["predict"][idx],
            low=rules["predict_low"][idx],
            high=rules["predict_high"][idx],
            representation=prediction_representation,
            percentiles=(percentiles if prediction_uses_percentiles else None),
            threshold=self.normalize_threshold_value() if self.is_thresholded() else None,
            include_percentiles=prediction_uses_percentiles,
        )
        weight_value = to_number(rules["weight"][idx])
        weight_uncertainty = self._build_uncertainty_payload(
            value=rules["weight"][idx],
            low=rules["weight_low"][idx],
            high=rules["weight_high"][idx],
            representation=weight_representation,
            percentiles=(percentiles if weight_uses_percentiles else None),
            include_percentiles=weight_uses_percentiles,
        )
        rule_metadata: Dict[str, Any] = {
            "feature": self._safe_feature_name(feature_index),
            "feature_index": to_number(feature_index),
            "prediction_uncertainty": prediction_uncertainty,
            "prediction_value": prediction_value,
            "weight_value": weight_value,
            "weight_uncertainty": weight_uncertainty,
            "condition_text": rules["rule"][idx],
            "instance_value": to_number(rules["sampled_values"][idx]),
            "alternative_value": to_number(rules["value"][idx]),
        }
        if self.is_thresholded():
            rule_metadata["threshold"] = self.normalize_threshold_value()
        metadata["feature_rules"].append(rule_metadata)

    metadata["prediction_uncertainty"] = self._build_instance_uncertainty()
    return payload
def _check_preconditions(self):
    """Emit a warning when the explainer's discretizer does not fit this task.

    Regression explanations expect a ``RegressorDiscretizer``; every other
    case expects an ``EntropyDiscretizer``. Nothing is raised and nothing is
    returned; a mismatch only produces a warning.
    """
    if self.is_regression():
        expected_type = RegressorDiscretizer
        message = (
            "Alternative explanations for regression recommend using the "
            "regressor discretizer. Consider extracting alternative "
            "explanations using `explainer.explain_alternatives(test_set)`"
        )
    else:
        expected_type = EntropyDiscretizer
        message = (
            "Alternative explanations for classification recommend using "
            "the entropy discretizer. Consider extracting alternative "
            "explanations using `explainer.explain_alternatives(test_set)`"
        )
    if isinstance(self.get_explainer().discretizer, expected_type):
        return
    warnings.warn(message, stacklevel=2)
# pylint: disable=too-many-statements, too-many-branches
[docs]
def get_rules(self):
    """
    Create alternative rules.

    Cached results are returned when available: the conjunctive rules if
    ``has_conjunctive_rules`` is set and ``conjunctive_rules`` is not None,
    otherwise ``self.rules`` if ``has_rules`` is set and it is a dict.
    Otherwise one rule is built per alternative value of each non-ignored
    feature, the result is stored on ``self.rules`` and ``has_rules`` is set.

    Returns
    -------
    dict
        The rule state produced by ``ConjunctionState.get_state()`` (or the
        cached equivalent). Other methods of this class index it by keys such
        as ``"rule"``, ``"feature"``, ``"predict"``, ``"predict_low"``,
        ``"predict_high"`` and ``"weight"``, each holding one entry per rule.
    """
    if (
        getattr(self, "has_conjunctive_rules", False)
        and getattr(self, "conjunctive_rules", None) is not None
    ):
        return self.conjunctive_rules
    if getattr(self, "has_rules", False) and isinstance(self.rules, dict):
        return self.rules
    self.rules = []
    self.labels = {}  # pylint: disable=attribute-defined-outside-init
    # Work on a read-only copy so the stored instance cannot be modified here.
    instance = np.array(self.x_test, copy=True)
    instance.flags.writeable = False
    # pylint: disable=protected-access
    discretized = self.get_explainer().discretize(instance.reshape(1, -1))[0]
    instance_predict = self.binned["predict"]
    instance_low = self.binned["low"]
    instance_high = self.binned["high"]
    state_helper = ConjunctionState(None)
    state_helper.state["classes"] = self.prediction["classes"]
    state_helper.set_base_prediction(
        self.prediction["predict"], self.prediction["low"], self.prediction["high"]
    )
    rule_boundaries = self.get_explainer().rule_boundaries(instance)
    ignored = self.ignored_features_for_instance()
    for f, _ in enumerate(instance):  # pylint: disable=invalid-name
        if f in ignored:
            continue
        if f in self.get_explainer().categorical_features:
            # Categorical feature: one candidate rule per category other than
            # the instance's own (discretized) value.
            values = np.array(self.get_explainer().feature_values[f])
            values = np.delete(values, values == discretized[f])
            for value_bin, value in enumerate(values):
                # skip if identical to original
                if (
                    self.prediction["low"] == instance_low[f][value_bin]
                    and self.prediction["high"] == instance_high[f][value_bin]
                ):
                    continue
                value_str = ""
                if self.get_explainer().categorical_labels is not None:
                    value_str = self.get_explainer().categorical_labels[f][int(instance[f])]
                else:
                    value_str = str(np.around(instance[f], decimals=2))
                rule_text = ""
                if self.get_explainer().categorical_labels is not None:
                    # Remember which feature the rule at this position refers to.
                    self.labels[len(state_helper.state["rule"])] = f
                    rule_text = (
                        f"{self.get_explainer().feature_names[f]} = "
                        + f"{self.get_explainer().categorical_labels[f][int(value)]}"
                    )
                else:
                    rule_text = f"{self.get_explainer().feature_names[f]} = {value}"
                state_helper.add_rule(
                    predict=instance_predict[f][value_bin],
                    low=instance_low[f][value_bin],
                    high=instance_high[f][value_bin],
                    base_predict=self.prediction["predict"],
                    value=value_str,
                    feature=f,
                    sampled_values=value,
                    feature_value=self.x_test[f],
                    rule_text=rule_text,
                    is_conjunctive=False,
                )
        else:
            # Numerical feature: at most one "< lesser" rule and one
            # "> greater" rule, each only when calibration values exist on
            # that side of the boundary.
            values = np.array(self.get_explainer().x_cal[:, f])
            lesser = rule_boundaries[f][0]
            greater = rule_boundaries[f][1]
            value_bin = 0
            if np.any(values < lesser):
                # skip if identical to original
                if self.prediction["low"] == safe_mean(
                    instance_low[f][value_bin]
                ) and self.prediction["high"] == safe_mean(instance_high[f][value_bin]):
                    pass
                else:
                    state_helper.add_rule(
                        predict=safe_mean(instance_predict[f][value_bin]),
                        low=safe_mean(instance_low[f][value_bin]),
                        high=safe_mean(instance_high[f][value_bin]),
                        base_predict=self.prediction["predict"],
                        value=str(np.around(instance[f], decimals=2)),
                        feature=f,
                        sampled_values=self.binned["rule_values"][f][0][0],
                        feature_value=self.x_test[f],
                        rule_text=f"{self.get_explainer().feature_names[f]} < {lesser:.2f}",
                        is_conjunctive=False,
                    )
                # NOTE(review): assumed to sit inside the "values < lesser"
                # branch, so the upper bin is index 1 only when a lower bin
                # exists — confirm against the original indentation.
                value_bin = 1
            if np.any(values > greater):
                # skip if identical to original
                if self.prediction["low"] == safe_mean(
                    instance_low[f][value_bin]
                ) and self.prediction["high"] == safe_mean(instance_high[f][value_bin]):
                    pass
                else:
                    state_helper.add_rule(
                        predict=safe_mean(instance_predict[f][value_bin]),
                        low=safe_mean(instance_low[f][value_bin]),
                        high=safe_mean(instance_high[f][value_bin]),
                        base_predict=self.prediction["predict"],
                        value=str(np.around(instance[f], decimals=2)),
                        feature=f,
                        # Index 1 is used when three rule values are stored,
                        # index 0 otherwise.
                        sampled_values=self.binned["rule_values"][f][0][
                            1 if len(self.binned["rule_values"][f][0]) == 3 else 0
                        ],
                        feature_value=self.x_test[f],
                        rule_text=f"{self.get_explainer().feature_names[f]} > {greater:.2f}",
                        is_conjunctive=False,
                    )
    self.rules = state_helper.get_state()
    self.has_rules = True
    return self.rules
def __set_up_result(self):
    """Create an empty rule container seeded with the instance prediction.

    Every per-rule field starts as an empty list. The three ``base_predict*``
    lists receive the instance prediction and its low/high bounds, and
    ``"classes"`` is copied from ``self.prediction``.
    """
    list_fields = (
        "base_predict",
        "base_predict_low",
        "base_predict_high",
        "predict",
        "predict_low",
        "predict_high",
        "weight",
        "weight_low",
        "weight_high",
        "value",
        "rule",
        "feature",
        "sampled_values",
        "feature_value",
        "is_conjunctive",
    )
    container = {field: [] for field in list_fields}
    container["classes"] = self.prediction["classes"]
    seeds = (
        ("base_predict", "predict"),
        ("base_predict_low", "low"),
        ("base_predict_high", "high"),
    )
    for target, source in seeds:
        container[target].append(self.prediction[source])
    return container
[docs]
def is_super_explanation(self):
    """Return the stored flag marking this explanation as a super-explanation."""
    flag = self.__is_super_explanation
    return flag
[docs]
def is_semi_explanation(self):
    """Return the stored flag marking this explanation as a semi-explanation."""
    flag = self.__is_semi_explanation
    return flag
[docs]
def is_counter_explanation(self):
    """Return the stored flag marking this explanation as a counter-explanation."""
    flag = self.__is_counter_explanation
    return flag
def __append_rule(self, new_rules, rules, rule):
    """Copy the rule at position *rule* of *rules* onto the end of *new_rules*.

    All per-rule fields are copied. ``"feature_value"`` is optional in
    *rules*; when it is absent ``None`` is appended instead.
    """
    required_fields = (
        "predict",
        "predict_low",
        "predict_high",
        "weight",
        "weight_low",
        "weight_high",
        "value",
        "rule",
        "feature",
        "sampled_values",
    )
    for field in required_fields:
        new_rules[field].append(rules[field][rule])
    new_rules["feature_value"].append(
        rules["feature_value"][rule] if "feature_value" in rules else None
    )
    new_rules["is_conjunctive"].append(rules["is_conjunctive"][rule])
def __filter_rules(
    self,
    only_ensured=False,
    make_super=False,
    make_semi=False,
    make_counter=False,
    include_potential=False,
):
    """Filter the current rules down to the requested explanation type.

    The rules returned by ``get_rules()`` are scanned once; the survivors
    replace ``self.rules``.

    Parameters
    ----------
    only_ensured : bool, default=False
        Drop rules whose prediction interval is wider than the interval of
        the instance prediction.
    make_super : bool, default=False
        Plain regression: keep only rules with a higher prediction than the
        instance. Otherwise: keep only rules that move further into the
        predicted class (away from 0.5).
    make_semi : bool, default=False
        Plain regression: keep only rules where each interval contains the
        other's point prediction. Otherwise: keep only rules on the same
        side of 0.5 as the instance but closer to 0.5.
    make_counter : bool, default=False
        Plain regression: keep only rules with a lower prediction than the
        instance. Otherwise: keep only rules on the opposite side of 0.5.
    include_potential : bool, default=False
        Keep "potential" rules. Plain regression: the rule interval covers
        the instance prediction. Otherwise: the rule interval straddles 0.5.

    Returns
    -------
    self
        To allow method chaining.
    """
    is_plain_regression = self.is_regression() and not self.is_probabilistic()
    initial_uncertainty = np.abs(self.prediction["high"] - self.prediction["low"])
    new_rules = self.__set_up_result()
    rules = self.get_rules()  # pylint: disable=protected-access

    def _base_bound(key):
        """Return the scalar stored under the ``base_predict*`` key *key*."""
        # The base interval is stored as a one-element sequence. Comparing the
        # sequence itself with a scalar bound is always False for a list, which
        # silently disabled the "no different prediction" filter below.
        stored = rules.get(key)
        if isinstance(stored, (list, tuple, np.ndarray)):
            flat = np.asarray(stored).ravel()
            return flat[0] if flat.size else None
        return stored

    base_interval_low = _base_bound("base_predict_low")
    base_interval_high = _base_bound("base_predict_high")

    if is_plain_regression:
        # For plain regression, redefine filtering concepts:
        # - super: higher prediction than original
        # - semi: rule and instance intervals each contain the other's prediction
        # - counter: lower prediction than original
        # - potential: alternative interval covers the original prediction
        # - ensured: smaller uncertainty interval (unchanged)
        for rule in range(len(rules["rule"])):
            is_potential = (
                rules["predict_low"][rule]
                <= self.prediction["predict"]
                <= rules["predict_high"][rule]
            )
            if not include_potential and is_potential:
                continue
            # Super: keep only rules with higher prediction
            if make_super and rules["predict"][rule] <= self.prediction["predict"]:
                continue
            # Semi: keep alternatives where the uncertainty intervals mutually
            # include the other's point prediction (a conservative definition).
            if make_semi:
                try:
                    rule_mean = float(rules["predict"][rule])
                    rule_low = float(rules["predict_low"][rule])
                    rule_high = float(rules["predict_high"][rule])
                    base_mean = float(self.prediction["predict"])
                    base_low = float(self.prediction["low"])
                    base_high = float(self.prediction["high"])
                except (TypeError, ValueError):
                    # If values are not numeric, skip this rule
                    continue
                if not (
                    (rule_low <= base_mean <= rule_high)
                    and (base_low <= rule_mean <= base_high)
                ):
                    continue
            # Counter: keep only rules with lower prediction than original
            if make_counter and rules["predict"][rule] >= self.prediction["predict"]:
                continue
            if (
                only_ensured
                and rules["predict_high"][rule] - rules["predict_low"][rule]
                > initial_uncertainty
            ):
                continue
            # Filter out rules that reproduce the original prediction exactly.
            if (
                base_interval_low == rules["predict_low"][rule]
                and base_interval_high == rules["predict_high"][rule]
                and rules["predict"][rule] == self.prediction["predict"]
            ):
                continue
            self.__append_rule(new_rules, rules, rule)
    else:
        positive_class = self.prediction["predict"] > 0.5
        for rule in range(len(rules["rule"])):
            is_potential = rules["predict_low"][rule] < 0.5 < rules["predict_high"][rule]
            # filter out potential rules if include_potential is False
            if not include_potential and is_potential:
                continue
            # Compute point-based membership (always enforced).
            rule_predict = rules["predict"][rule]
            # super: moves further into the predicted class (away from 0.5)
            is_super_by_point = (
                positive_class and rule_predict > self.prediction["predict"]
            ) or (not positive_class and rule_predict < self.prediction["predict"])
            # semi: same side as base but closer to the decision boundary (towards 0.5)
            if positive_class:
                is_semi_by_point = (rule_predict > 0.5) and (
                    rule_predict < self.prediction["predict"]
                )
            else:
                is_semi_by_point = (rule_predict < 0.5) and (
                    rule_predict > self.prediction["predict"]
                )
            # counter: crosses the decision boundary (opposite side of 0.5)
            is_counter_by_point = (positive_class and rule_predict <= 0.5) or (
                not positive_class and rule_predict >= 0.5
            )
            # Enforce membership by point-prediction for all modes. Potentials
            # are still controlled by the `include_potential` flag above, but
            # when included they must also satisfy the point-based comparator.
            if make_super and not is_super_by_point:
                continue
            if make_semi and not is_semi_by_point:
                continue
            if make_counter and not is_counter_by_point:
                continue
            # if only_ensured is True, filter out rules that lead to increased uncertainty
            if (
                only_ensured
                and rules["predict_high"][rule] - rules["predict_low"][rule]
                > initial_uncertainty
            ):
                continue
            # filter out rules that do not provide a different prediction interval
            if (
                base_interval_low == rules["predict_low"][rule]
                and base_interval_high == rules["predict_high"][rule]
            ):
                continue
            self.__append_rule(new_rules, rules, rule)
    new_rules["classes"] = rules["classes"]
    if self.has_conjunctive_rules:  # pylint: disable=protected-access
        self.__extracted_non_conjunctive_rules(new_rules)
    self.rules = new_rules
    return self
def __pareto_rule_indexes(self, rules, *, pareto_cost: str):
    """Select the rule indices lying on the output-envelope Pareto frontier.

    The output value (probability for classification or calibrated output
    for regression) is treated as the coverage axis, while the Pareto
    *cost* dimension is minimized. Rules sharing the same output (rounded to
    12 decimals) are first reduced to their cheapest representative; a
    representative is kept when no rule with a lower-or-equal output, or no
    rule with a higher-or-equal output, is cheaper. The two extreme outputs
    are always kept.

    Supported Pareto cost dimensions:

    - ``"uncertainty_width"``: minimize interval width (``high - low``).
    - ``"rule_size"``: minimize number of features changed in the rule
      (1 for atomic rules; >1 for conjunctive rules).

    Returns a sorted list of rule indices. Raises ``ValidationError`` for an
    unsupported ``pareto_cost`` (only checked when there are at least two
    rules).
    """
    total = len(rules.get("rule", []))
    if total <= 1:
        return list(range(total))

    eps = 1e-12

    def _cost_of(position: int) -> float:
        """Return the cost of the rule at *position* for the chosen dimension."""
        if pareto_cost == "uncertainty_width":
            return float(rules["predict_high"][position]) - float(
                rules["predict_low"][position]
            )
        if pareto_cost == "rule_size":
            features = rules.get("feature", [])
            feature = features[position] if position < len(features) else None
            if isinstance(feature, (list, tuple, np.ndarray)):
                return float(len(np.asarray(feature).ravel()))
            return 1.0
        raise ValidationError(
            "pareto_cost must be one of: uncertainty_width, rule_size",
            details={"pareto_cost": pareto_cost},
        )

    # Cheapest rule per distinct output, stored as (output, cost, index).
    cheapest = {}
    for position in range(total):
        output = float(rules["predict"][position])
        cost = _cost_of(position)
        bucket = round(output, 12)
        incumbent = cheapest.get(bucket)
        if incumbent is None:
            cheapest[bucket] = (output, cost, position)
            continue
        _, best_cost, best_position = incumbent
        strictly_cheaper = cost < best_cost - eps
        tie_with_lower_index = (
            math.isclose(cost, best_cost, rel_tol=eps, abs_tol=eps)
            and position < best_position
        )
        if strictly_cheaper or tie_with_lower_index:
            cheapest[bucket] = (output, cost, position)

    ordered = sorted(cheapest.values(), key=lambda entry: entry[0])
    if len(ordered) <= 2:
        return sorted(entry[2] for entry in ordered)

    # Running minimum cost seen from the low-output end ...
    prefix_min = []
    lowest = float("inf")
    for _, cost, _ in ordered:
        lowest = min(lowest, cost)
        prefix_min.append(lowest)

    # ... and from the high-output end.
    suffix_min = []
    lowest = float("inf")
    for _, cost, _ in reversed(ordered):
        lowest = min(lowest, cost)
        suffix_min.append(lowest)
    suffix_min.reverse()

    frontier = {ordered[0][2], ordered[-1][2]}
    for rank, (_, cost, position) in enumerate(ordered):
        if cost <= prefix_min[rank] + eps or cost <= suffix_min[rank] + eps:
            frontier.add(position)
    return sorted(frontier)
def __pareto_filter_rules(self, *, pareto_cost: str):
    """Keep only the rules that lie on the output-envelope Pareto frontier.

    Parameters
    ----------
    pareto_cost : str
        Name of the cost dimension forwarded to the Pareto index selection.

    Returns
    -------
    self
        The explanation itself, with ``self.rules`` replaced by the reduced
        rule set.
    """
    current = self.get_rules()  # pylint: disable=protected-access
    frontier = set(self.__pareto_rule_indexes(current, pareto_cost=pareto_cost))
    reduced = self.__set_up_result()
    total = len(current.get("rule", []))
    # Walk the rules in their original order so the surviving rules keep it.
    for position in (i for i in range(total) if i in frontier):
        self.__append_rule(reduced, current, position)
    reduced["classes"] = current["classes"]
    if self.has_conjunctive_rules:  # pylint: disable=protected-access
        self.__extracted_non_conjunctive_rules(reduced)
    self.rules = reduced
    return self
# -- Convenience retrieval API -------------------------------------------------
def _alternative_rule_to_dict(self, rules: Dict[str, Any], idx: int) -> Dict[str, Any]:
    """Turn the rule stored at position ``idx`` into a plain, user-facing dict.

    Parameters
    ----------
    rules : dict
        Internal rule payload; the ``feature``, ``predict``, ``predict_low``,
        ``predict_high`` and ``rule`` entries are read at ``idx``.
    idx : int
        Position of the rule inside ``rules``.

    Returns
    -------
    dict
        Mapping with the keys ``index``, ``feature``, ``condition``,
        ``alternative_prediction`` and ``uncertainty_interval``.
    """
    feature_ref = rules["feature"][idx]
    point_estimate = CalibratedExplanation.to_python_number(rules["predict"][idx])
    interval = CalibratedExplanation._build_interval(
        rules["predict_low"][idx], rules["predict_high"][idx]
    )
    normalized: Dict[str, Any] = {}
    normalized["index"] = int(idx)
    normalized["feature"] = self._safe_feature_name(feature_ref)
    normalized["condition"] = rules["rule"][idx]
    normalized["alternative_prediction"] = point_estimate
    normalized["uncertainty_interval"] = interval
    return normalized
[docs]
def get_rule_by_index(self, index: int) -> Dict[str, Any]:
    """Look up one alternative rule by its position.

    Parameters
    ----------
    index : int
        Zero-based position of the rule. Negative positions are rejected.

    Returns
    -------
    dict
        The normalized rule produced by ``_alternative_rule_to_dict``.

    Raises
    ------
    IndexError
        If ``index`` is negative or not smaller than the number of rules.
    """
    rules = self.get_rules()
    total = len(rules.get("rule", []))
    below_range = index < 0
    if below_range or index >= total:
        raise IndexError(f"No alternative rule at index {index}")
    return self._alternative_rule_to_dict(rules, index)
[docs]
def get_rules_by_feature(self, feature: str) -> list[Dict[str, Any]]:
    """Collect every alternative rule whose feature name equals ``feature``.

    Parameters
    ----------
    feature : str
        Feature name to match exactly.

    Returns
    -------
    list of dict
        Normalized rules, in their stored order.

    Raises
    ------
    KeyError
        If no rule matches ``feature``.
    """
    rules = self.get_rules()

    def _display_name(raw):
        # Fall back to the plain string form when the name lookup rejects the value.
        try:
            return self._safe_feature_name(raw)
        except (TypeError, ValueError):
            return str(raw)

    found: list[Dict[str, Any]] = [
        self._alternative_rule_to_dict(rules, position)
        for position, raw in enumerate(rules.get("feature", []))
        if _display_name(raw) == feature
    ]
    if found:
        return found
    raise KeyError(f"No alternative rules for feature {feature}")
[docs]
def list_rules(self) -> list[Dict[str, Any]]:
    """Return every alternative rule as a normalized dict, in stored order."""
    rules = self.get_rules()
    normalized: list[Dict[str, Any]] = []
    for position in range(len(rules.get("rule", []))):
        normalized.append(self._alternative_rule_to_dict(rules, position))
    return normalized
def __extracted_non_conjunctive_rules(self, new_rules):
"""Split out non-conjunctive rules while preserving the original mapping."""
self.conjunctive_rules = MappingProxyType(
{k: list(v) if isinstance(v, list) else v for k, v in new_rules.items()}
)
mask = [not is_conj for is_conj in new_rules["is_conjunctive"]]
for k, v in new_rules.items():
if isinstance(v, list) and len(v) == len(mask):
new_rules[k] = [val for i, val in enumerate(v) if mask[i]]
self.rules = new_rules
[docs]
def reset(self):
    """Clear the super/semi/counter flags, drop the rule cache flag and rebuild rules.

    Returns
    -------
    self
        The explanation itself, to allow method chaining.
    """
    # Targets are assigned left to right, i.e. super, semi, counter.
    self.__is_super_explanation = self.__is_semi_explanation = (
        self.__is_counter_explanation
    ) = False
    self.has_rules = False
    self.get_rules()
    return self
[docs]
def super_explanations(self, only_ensured=False, include_potential=True, copy=True):
    """Filter the alternatives down to *super* explanations.

    A *super* alternative reinforces the model's current prediction.

    Parameters
    ----------
    only_ensured : bool, default=False
        If ``True``, retain only alternatives whose uncertainty interval is
        no wider than the interval of the base prediction.
    include_potential : bool, default=True
        Whether *potential* alternatives are retained.
    copy : bool, default=True
        If ``True``, the filtering is applied to a copy which is returned;
        if ``False``, this explanation is filtered in place.

    Returns
    -------
    :class:`.AlternativeExplanation`
        The filtered alternative explanation.

    Notes
    -----
    What counts as "super" depends on the task mode:

    - **Classification / probabilistic regression**: the output is read as a
      calibrated probability with a 0.5 decision boundary. With
      ``p_base = prediction['predict']``:

      - if ``p_base > 0.5`` (predicted positive/event), alternatives with
        ``p_rule > p_base`` are kept;
      - otherwise, alternatives with ``p_rule < p_base`` are kept.

    - **Plain regression**: the output is a calibrated numeric value and
      alternatives with a higher predicted output than the base are kept.

    Potential alternatives are those whose uncertainty interval spans the
    decision boundary (classification / probabilistic regression) or covers
    the base prediction (plain regression).

    Examples
    --------
    >>> alternatives = explainer.explore_alternatives(x_query)
    >>> super_alts = alternatives[0].super_explanations()
    """
    if copy:
        target = self.copy()
    else:
        target = self
    filter_options = {
        "only_ensured": only_ensured,
        "make_super": True,
        "include_potential": include_potential,
    }
    target.__filter_rules(**filter_options)
    # Record on the returned object that it now holds a super-filtered view.
    target._AlternativeExplanation__is_super_explanation = True  # pylint: disable=protected-access
    return target
[docs]
def super(self, only_ensured=False, include_potential=True, copy=True):
    """Short alias that forwards all arguments to :meth:`.super_explanations`."""
    forwarded = {
        "only_ensured": only_ensured,
        "include_potential": include_potential,
        "copy": copy,
    }
    return self.super_explanations(**forwarded)
[docs]
def semi_explanations(self, only_ensured=False, include_potential=True, copy=True):
    """Filter the alternatives down to *semi* explanations.

    A *semi* alternative moves the prediction toward the decision boundary
    without crossing it (classification / probabilistic regression), or moves
    the regression output in the direction opposite to a *super* alternative
    (plain regression).

    Parameters
    ----------
    only_ensured : bool, default=False
        If ``True``, retain only alternatives whose uncertainty interval is
        no wider than the interval of the base prediction.
    include_potential : bool, default=True
        Whether *potential* alternatives are retained.
    copy : bool, default=True
        If ``True``, the filtering is applied to a copy which is returned;
        if ``False``, this explanation is filtered in place.

    Returns
    -------
    :class:`.AlternativeExplanation`
        The filtered alternative explanation.

    Notes
    -----
    - **Classification / probabilistic regression**: semi alternatives stay on
      the *same side* of the 0.5 boundary as the base prediction while being
      closer to that boundary than the base (unless they are marked as
      potential and ``include_potential=True``).
    - **Plain regression**: semi alternatives are the rules with a lower
      predicted output than the base prediction.

    Examples
    --------
    >>> alternatives = explainer.explore_alternatives(x_query)
    >>> semi_alts = alternatives[0].semi_explanations()
    """
    if copy:
        target = self.copy()
    else:
        target = self
    filter_options = {
        "only_ensured": only_ensured,
        "make_semi": True,
        "include_potential": include_potential,
    }
    target.__filter_rules(**filter_options)
    # Record on the returned object that it now holds a semi-filtered view.
    target._AlternativeExplanation__is_semi_explanation = True  # pylint: disable=protected-access
    return target
[docs]
def semi(self, only_ensured=False, include_potential=True, copy=True):
    """Short alias that forwards all arguments to :meth:`.semi_explanations`."""
    forwarded = {
        "only_ensured": only_ensured,
        "include_potential": include_potential,
        "copy": copy,
    }
    return self.semi_explanations(**forwarded)
[docs]
def counter_explanations(self, only_ensured=False, include_potential=True, copy=True):
    """Filter the alternatives down to *counter* explanations.

    A *counter* alternative opposes the model's current prediction.

    Parameters
    ----------
    only_ensured : bool, default=False
        If ``True``, retain only alternatives whose uncertainty interval is
        no wider than the interval of the base prediction.
    include_potential : bool, default=True
        Whether *potential* alternatives are retained.
    copy : bool, default=True
        If ``True``, the filtering is applied to a copy which is returned;
        if ``False``, this explanation is filtered in place.

    Returns
    -------
    :class:`.AlternativeExplanation`
        The filtered alternative explanation.

    Notes
    -----
    - **Classification / probabilistic regression**: counter alternatives
      cross the 0.5 decision boundary. For a base prediction
      ``p_base > 0.5`` the rules with ``p_rule <= 0.5`` are kept (and
      vice-versa).
    - **Plain regression**: counter alternatives are the rules with a lower
      predicted output than the base prediction.

    In plain regression, :meth:`.semi_explanations` and
    :meth:`.counter_explanations` currently have the same output semantics.

    Examples
    --------
    >>> alternatives = explainer.explore_alternatives(x_query)
    >>> counter_alts = alternatives[0].counter_explanations()
    """
    if copy:
        target = self.copy()
    else:
        target = self
    filter_options = {
        "only_ensured": only_ensured,
        "make_counter": True,
        "include_potential": include_potential,
    }
    target.__filter_rules(**filter_options)
    # Record on the returned object that it now holds a counter-filtered view.
    target._AlternativeExplanation__is_counter_explanation = True  # pylint: disable=protected-access
    return target
[docs]
def counter(self, only_ensured=False, include_potential=True, copy=True):
    """Short alias that forwards all arguments to :meth:`.counter_explanations`."""
    forwarded = {
        "only_ensured": only_ensured,
        "include_potential": include_potential,
        "copy": copy,
    }
    return self.counter_explanations(**forwarded)
[docs]
def ensured_explanations(self, include_potential=True, copy=True):
    """Filter the alternatives down to *ensured* explanations.

    An alternative is ensured when its uncertainty interval is no wider than
    the interval of the base prediction.

    Parameters
    ----------
    include_potential : bool, default=True
        Whether *potential* alternatives are retained.
    copy : bool, default=True
        If ``True``, the filtering is applied to a copy which is returned;
        if ``False``, this explanation is filtered in place.

    Returns
    -------
    :class:`.AlternativeExplanation`
        The filtered alternative explanation.

    Notes
    -----
    The filter is task-agnostic: it only looks at the *uncertainty interval
    width*, so it applies to classification, probabilistic regression and
    plain regression alike.

    Examples
    --------
    >>> alternatives = explainer.explore_alternatives(x_query)
    >>> ensured = alternatives[0].ensured_explanations()
    """
    if copy:
        target = self.copy()
    else:
        target = self
    filter_options = {"only_ensured": True, "include_potential": include_potential}
    target.__filter_rules(**filter_options)
    return target
[docs]
def ensured(self, include_potential=True, copy=True):
    """Short alias that forwards all arguments to :meth:`.ensured_explanations`."""
    forwarded = {"include_potential": include_potential, "copy": copy}
    return self.ensured_explanations(**forwarded)
[docs]
def pareto_explanations(
    self,
    include_potential: bool = True,
    copy: bool = True,
    *,
    pareto_cost: Literal["uncertainty_width", "rule_size"] = "uncertainty_width",
):
    """Return the alternatives on the output-envelope Pareto frontier.

    Parameters
    ----------
    include_potential : bool, default=True
        Whether potential explanations are kept before the Pareto frontier
        is extracted.
    copy : bool, default=True
        If ``True``, the filtering is applied to a copy which is returned;
        if ``False``, this explanation is modified in place.
    pareto_cost : {"uncertainty_width", "rule_size"}, default="uncertainty_width"
        The dimension minimized along the output axis when selecting the
        frontier. ``"uncertainty_width"`` reproduces the historical behavior
        (minimize interval width). ``"rule_size"`` minimizes the number of
        changed features in the rule (useful when conjunctions are present).

    Returns
    -------
    :class:`.AlternativeExplanation`

    Notes
    -----
    Pareto filtering keeps an output-envelope frontier where no alternative
    can reduce the chosen Pareto cost without changing the output.
    The output axis is the calibrated probability (classification /
    probabilistic regression) or the calibrated numeric output (regression).
    """
    if copy:
        target = self.copy()
    else:
        target = self
    # First apply the potential-alternative filter, then reduce to the frontier.
    target.__filter_rules(include_potential=include_potential)
    target.__pareto_filter_rules(pareto_cost=pareto_cost)
    return target
[docs]
def pareto(
    self,
    include_potential: bool = True,
    copy: bool = True,
    *,
    pareto_cost: Literal["uncertainty_width", "rule_size"] = "uncertainty_width",
):
    """Short alias that forwards all arguments to :meth:`.pareto_explanations`."""
    forwarded = {
        "include_potential": include_potential,
        "copy": copy,
        "pareto_cost": pareto_cost,
    }
    return self.pareto_explanations(**forwarded)
[docs]
def add_conjunctions(self, n_top_features=5, max_rule_size=2, **kwargs):
"""
Add conjunctive alternative rules.
Existing alternative rules are combined into conjunctions of up to
``max_rule_size`` conditions. Each new combination is scored through
``predict_conjunctive`` and the accumulated result is stored in
``self.conjunctive_rules``.
Parameters
----------
n_top_features : int, optional
Number of top features to combine. ``None`` means that every rule is
considered.
max_rule_size : int, optional
Maximum size of the conjunctions. Values below 2 leave the explanation
unchanged.
**kwargs : dict
Internal controls for batching, deduplication, and diagnostics.
Other Parameters
----------------
_use_batched : bool, default True
Controls batched vs sequential prediction. When ``False`` the work is
delegated to the legacy implementation.
_limit_outer_to_ranked : bool, default False
Whether to rank-filter outer loop candidates.
_dedupe_by_feature_only : bool, default True
Deduplication strategy for conjunctions.
raise_on_predict_error : bool, default False
Whether to surface prediction errors.
_fallback_to_legacy_on_zero : bool, default False
Whether to fall back to legacy on zero created.
verbose : bool, default False
Whether to emit a ``UserWarning`` with diagnostics when candidates were
attempted but no conjunction was created (only consulted when the legacy
fallback is not taken).
Attributes
----------
conjunction_stats : dict
Summary of attempts, created, skipped, and captured prediction errors.
Returns
-------
self : :class:`.AlternativeExplanation`
Returns a self reference, to allow for method chaining
Raises
------
ConfigurationError
If ``max_rule_size >= 4`` while ``_use_batched`` is ``False``.
"""
# Two-phase conjunction search:
# 1) Rank features to select candidate rules for the inner loop.
# 2) Combine outer rules with ranked candidates to build conjunctions.
# current_size grows rules by one feature each iteration.
# ConjunctionState tracks accumulated rules and deduplication keys.
use_batched = kwargs.get("_use_batched", True)
limit_outer_to_ranked = kwargs.get("_limit_outer_to_ranked", False)
dedupe_by_feature_only = kwargs.get("_dedupe_by_feature_only", True)
raise_on_predict_error = kwargs.get("raise_on_predict_error", False)
# Cap on how many prediction error messages are kept in the stats.
_MAX_LOGGED_ERRORS = 5 # noqa: N806
# Sequential (non-batched) execution is only accepted for rule sizes 2 and 3.
if max_rule_size >= 4 and not use_batched:
from ..utils.exceptions import ConfigurationError
raise ConfigurationError(
"max_rule_size >= 4 requires batched execution (internal flag _use_batched=True)",
details={
"param": "max_rule_size",
"value": max_rule_size,
"valid_range": [2, 3],
},
)
# A conjunction needs at least two conditions; nothing to do otherwise.
if max_rule_size < 2:
return self
# Non-batched mode delegates entirely to the legacy implementation.
if not use_batched:
from .legacy_conjunctions import add_conjunctions_alternative_legacy
add_conjunctions_alternative_legacy(
self, n_top_features=n_top_features, max_rule_size=max_rule_size
)
return self
# Reuse the cached rules when they exist; otherwise build them first.
alternative = self.get_rules() if not self.has_rules else self.rules
# Seed the accumulator from previously built conjunctions when available,
# otherwise from the plain alternative rules.
state_helper = ConjunctionState(
self.conjunctive_rules
if self.has_conjunctive_rules and self.conjunctive_rules is not None
else alternative,
dedupe_by_feature_only=dedupe_by_feature_only,
)
# Clear the published conjunction state while the new state is being built.
self.has_conjunctive_rules = False
self.conjunctive_rules = []
# Equivalent to ``self.y_threshold`` (``None`` stays ``None``).
threshold = None if self.y_threshold is None else self.y_threshold
# Working copy of the instance handed to predict_conjunctive.
scratch = np.array(self.x_test, copy=True)
predicted_class = alternative["classes"]
state_helper.state["classes"] = predicted_class
# Base weights/widths are only used for the optional outer-loop ranking
# (_limit_outer_to_ranked) and for the zero-created diagnostics below.
base_weight_array = (
np.asarray(alternative["weight"], dtype=float)
if alternative["weight"]
else np.array([])
)
base_width_array = (
np.asarray(alternative["weight_high"], dtype=float)
- np.asarray(alternative["weight_low"], dtype=float)
if alternative["weight"]
else np.array([])
)
# ``None`` means "consider every rule".
if n_top_features is None:
n_top_features = len(alternative["rule"])
def _feature_length(candidate: Any) -> int:
# Number of features referenced by a rule: sequence length, or 1 for a scalar.
if isinstance(candidate, (list, tuple, np.ndarray)):
return len(candidate)
return 1
def _coerce_feature_scalar(value: Any) -> int:
# Reduce a possibly wrapped feature index to a plain int (first element).
if isinstance(value, (list, tuple, np.ndarray)):
return int(np.asarray(value).ravel()[0])
return int(value)
from collections import Counter
# Diagnostics for this call; replaced on every invocation.
self.conjunction_stats = {
"attempts": 0,
"created": 0,
"skipped": Counter(),
"predict_errors": [],
}
stats = self.conjunction_stats
def _summarize(arr):
# Compact size/min/max/NaN-count summary used in the diagnostic warning.
if arr is None or arr.size == 0:
return {"size": 0}
return {
"size": int(arr.size),
"min": float(np.nanmin(arr)),
"max": float(np.nanmax(arr)),
"nan": int(np.isnan(arr).sum()),
}
# One pass per target conjunction size (2, 3, ..., max_rule_size).
for current_size in range(2, max_rule_size + 1):
num_rules = len(alternative["rule"])
if num_rules == 0:
break
# Inner-loop candidates are ranked on the accumulated state, which also
# holds the conjunctions created by earlier passes.
weights_array = state_helper.get_weights()
width_array = state_helper.get_widths()
top_n = min(num_rules, n_top_features)
top_conjunctives = list(
self.rank_features(
weights_array,
width=width_array,
num_to_show=max(2, top_n) if num_rules >= 2 else top_n,
)
)
# Fall back to every rule index when the ranking returns nothing.
if not top_conjunctives and num_rules > 0:
top_conjunctives = list(range(num_rules))
# Determine outer loop candidates
# Legacy behavior: iterate all features
if limit_outer_to_ranked:
num_outer = min(num_rules, n_top_features)
outer_indices = list(
self.rank_features(
base_weight_array,
width=base_width_array if base_width_array.size else None,
num_to_show=max(2, num_outer) if num_rules >= 2 else num_outer,
)
)
if not outer_indices and num_rules > 0:
outer_indices = list(range(num_rules))
else:
outer_indices = range(len(alternative["feature"]))
for f1 in outer_indices:
# Outer rule: always a single-feature rule from the base alternatives.
of1 = int(alternative["feature"][f1])
sampled_values1 = alternative["sampled_values"][f1]
rule_value1 = (
sampled_values1
if isinstance(sampled_values1, np.ndarray)
else [sampled_values1]
)
for cf2 in top_conjunctives:
stats["attempts"] += 1
rule_values = [rule_value1]
original_features = [of1]
original_feature_values = [alternative["feature_value"][f1]]
of2 = state_helper.get_feature(cf2)
# The partner must have exactly current_size - 1 features so the
# combined rule has current_size features.
target_length = current_size - 1
if _feature_length(of2) != target_length:
stats["skipped"]["len_mismatch"] += 1
continue
if state_helper.is_conjunctive(cf2):
# Skip when the outer feature already takes part in the partner conjunction.
if of1 in of2:
stats["skipped"]["same_feature"] += 1
continue
original_features.extend(int(v) for v in of2)
rule_values.extend(list(state_helper.get_sampled_values(cf2)))
feature_values2 = state_helper.get_feature_value(cf2)
if isinstance(feature_values2, (list, tuple, np.ndarray)):
original_feature_values.extend(list(feature_values2))
else:
original_feature_values.append(feature_values2)
else:
of2 = _coerce_feature_scalar(of2)
# A feature is never combined with itself.
if of1 == of2:
stats["skipped"]["same_feature"] += 1
continue
original_features.append(of2)
original_feature_values.append(state_helper.get_feature_value(cf2))
sampled_values2 = state_helper.get_sampled_values(cf2)
rule_values.append(
sampled_values2
if isinstance(sampled_values2, np.ndarray)
else [sampled_values2]
)
if state_helper.has_combination_key(original_features, rule_values):
stats["skipped"]["duplicate_combo"] += 1
continue
# The key is registered before predicting, so a combination whose
# prediction fails is presumably also treated as seen afterwards.
state_helper.register_combination_key(original_features, rule_values)
try:
rule_predict, rule_low, rule_high = self.predict_conjunctive(
rule_values,
original_features,
scratch,
threshold,
predicted_class,
bins=self.bin,
use_batched=use_batched,
)
except (CalibratedError, ValueError, TypeError, RuntimeError) as e:
if raise_on_predict_error:
raise
# Best effort: count the failure, keep a few messages, move on.
stats["skipped"]["predict_error"] += 1
if len(stats["predict_errors"]) < _MAX_LOGGED_ERRORS:
stats["predict_errors"].append(f"{type(e).__name__}: {e}")
continue
state_helper.add_rule(
predict=rule_predict,
low=rule_low,
high=rule_high,
base_predict=self.prediction["predict"],
value=alternative["value"][f1] + "\n" + state_helper.get_value(cf2),
feature=list(original_features),
sampled_values=list(rule_values),
feature_value=list(original_feature_values),
rule_text=alternative["rule"][f1] + " & \n" + state_helper.get_rule(cf2),
)
stats["created"] += 1
# Publish the accumulated state after each rule-size pass.
self.conjunctive_rules = state_helper.get_state()
self.has_conjunctive_rules = True
# Diagnostics: candidates were attempted but nothing was created.
if stats["created"] == 0 and stats["attempts"] > 0:
summary_weights = _summarize(base_weight_array)
# Function-local import; the module already imports ``warnings`` at the top.
import warnings
err_msg = (
f" predict_errors={stats['predict_errors']}" if stats["predict_errors"] else ""
)
warning_msg = (
f"add_conjunctions: created={stats['created']} "
f"attempts={stats['attempts']} skipped={dict(stats['skipped'])} "
f"weights={summary_weights}{err_msg}"
)
if kwargs.get("_fallback_to_legacy_on_zero", False):
warnings.warn(
warning_msg + " (falling back to legacy)",
UserWarning,
stacklevel=2,
)
from .legacy_conjunctions import add_conjunctions_alternative_legacy
# NOTE(review): this returns the legacy helper's result, whereas the
# non-batched branch above discards it and returns ``self`` - confirm the
# helper returns ``self`` so method chaining keeps working here.
return add_conjunctions_alternative_legacy(
self, n_top_features=n_top_features, max_rule_size=max_rule_size
)
if kwargs.get("verbose", False):
warnings.warn(warning_msg, UserWarning, stacklevel=2)
# Also covers the case where the loop exited before publishing (no rules).
self.conjunctive_rules = state_helper.get_state()
self.has_conjunctive_rules = True
return self
def _is_lesser(self, rule_boundary, instance_value):
"""Return whether the instance value exceeds the provided rule boundary."""
return rule_boundary < instance_value
# pylint: disable=consider-iterating-dictionary
[docs]
def plot(self, filter_top=None, **kwargs):
"""
Plot the alternative explanation.
Rules are ranked, rules identical to the base prediction are dropped, and
the remainder is drawn with ``plot_triangular`` (styles 'triangular' and
'ensured') or ``plot_alternative`` (every other style). Nothing is
returned; a warning is issued when there are no rules to plot.
Parameters
----------
filter_top : int, optional
The number of top features to display.
**kwargs : dict
Additional plotting arguments, such as:
show : bool, default=True if filename is empty, False otherwise
A boolean parameter that determines whether the plot should be displayed or not. If set to
True, the plot will be displayed. If set to False, the plot will not be displayed.
filename : str, default=''
The filename parameter is a string that represents the full path and filename of the plot
image file that will be saved. If this parameter is not provided or is an empty string, the plot
will not be saved as an image file.
style : str, default='regular'
The `style` parameter is a string that determines the style of the plot. Possible styles are for :class:`.AlternativeExplanation`:
* 'regular' - a regular plot with feature weights and uncertainty intervals (if applicable)
* 'triangular' - a triangular plot for alternative explanations highlighting the interplay between the prediction and the uncertainty intervals
* 'ensured' - alias for 'triangular' (intended for ensured-style alternative interpretation)
rnk_metric : str, default='ensured'
The metric used to rank the features. Supported metrics are 'ensured', 'feature_weight', and 'uncertainty'.
``None`` is treated as 'ensured'; 'uncertainty' is handled as 'ensured' with ``rnk_weight`` forced to 1.0.
rnk_weight : float, default=0.5
The weight of the uncertainty in the ranking. Used with the 'ensured' ranking metric.
style_override : optional
Forwarded unchanged to the plotting function.
use_legacy : bool, optional
Forwarded to ``plot_alternative``. Defaults to ``True``, and is forced to ``False`` when
``return_plot_spec`` is truthy or ``style`` is a string other than 'regular', 'triangular',
'ensured' or 'narrative'.
return_plot_spec : bool, optional
When truthy, forces ``use_legacy`` to ``False``.
"""
requested_style = kwargs.get("style")
# Any string style outside the built-in set counts as a custom plot style.
custom_plot_style = isinstance(requested_style, str) and requested_style not in {
"regular",
"triangular",
"ensured",
"narrative",
}
# Ensure style_override gets passed through
style_override = kwargs.pop("style_override", None)
plot_use_legacy = kwargs.pop("use_legacy", None)
# PlotSpec request forces new renderer
if kwargs.get("return_plot_spec") or custom_plot_style:
plot_use_legacy = False
# Phase 2 Option B: Default to legacy to ensure parity until PlotSpec is fully hardened
elif plot_use_legacy is None:
plot_use_legacy = True
filename = kwargs.pop("filename", "")
# Show interactively only when the plot is not being saved, unless told otherwise.
show = kwargs.pop("show", filename == "")
rnk_metric = kwargs.pop("rnk_metric", "ensured")
if rnk_metric is None:
rnk_metric = "ensured"
rnk_weight = kwargs.pop("rnk_weight", 0.5)
# Put the most uncertain rules at the top
if rnk_metric == "uncertainty":
rnk_weight = 1.0
rnk_metric = "ensured"
# Use conjunctive rules when available so that conjunctions appear in plots
if getattr(self, "has_conjunctive_rules", False) and getattr(
self, "conjunctive_rules", None
):
alternative = self.conjunctive_rules
else:
alternative = self.get_rules() # get_explanation(index)
self._check_preconditions()
predict = self.prediction
if len(filename) > 0:
path, filename, title, ext = prepare_for_saving(filename)
# Saved plots are placed under a ``plots/`` prefix.
path = f"plots/{path}"
save_ext = [ext]
else:
path = ""
title = ""
save_ext = []
feature_predict = {
"predict": alternative["predict"],
"low": alternative["predict_low"],
"high": alternative["predict_high"],
}
# Flatten weights and interval widths to one value per rule.
feature_weights = np.reshape(alternative["weight"], (len(alternative["weight"])))
width = np.reshape(
np.array(alternative["weight_high"]) - np.array(alternative["weight_low"]),
(len(alternative["weight"])),
)
num_rules = len(alternative["rule"])
if filter_top is None:
filter_top = num_rules
num_to_show_ = np.min([num_rules, filter_top])
if num_to_show_ <= 0:
warnings.warn(
f"The explanation has no rules to plot. The index of the instance is {self.index}",
stacklevel=2,
)
return
if rnk_metric == "feature_weight":
features_to_plot = self.rank_features(
feature_weights, width=width, num_to_show=num_to_show_
)
else:
# Always rank base on predicted class
prediction = alternative["predict"]
# For classification/thresholded output, rank on the probability of the
# predicted side: flip to 1 - p when the base prediction is not above 0.5.
if self.get_mode() == "classification" or self.is_thresholded():
prediction = prediction if predict["predict"] > 0.5 else [1 - p for p in prediction]
ranking = calculate_metrics(
uncertainty=[
alternative["predict_high"][i] - alternative["predict_low"][i]
for i in range(num_rules)
],
prediction=prediction,
w=rnk_weight,
metric=rnk_metric,
)
features_to_plot = self.rank_features(width=ranking, num_to_show=num_to_show_)
# Display highest-impact rules at the top: reverse the index order returned by
# rank_features (which yields ascending by design).
features_to_plot = list(reversed(features_to_plot))
# Filter out rules that don't change the prediction or uncertainty (exactly identical to base).
# Keep ordering from the ranking.
features_to_plot = [
i
for i in features_to_plot
if not (
np.isclose(feature_predict["predict"][i], predict["predict"])
and np.isclose(feature_predict["low"][i], predict["low"])
and np.isclose(feature_predict["high"][i], predict["high"])
)
]
# Adjust the number to show after filtering
num_to_show_filtered = min(num_to_show_, len(features_to_plot))
style = kwargs.get("style")
# 'ensured' is an alias for the triangular plot.
if style == "ensured":
kwargs["style"] = "triangular"
style = "triangular"
if style == "triangular":
proba = predict["predict"]
# Uncertainty is the calibrated interval width (high-low).
# Keep semantics consistent with ADR-021 and other plot styles.
y_minmax = getattr(self, "y_minmax", None)
base_low = predict["low"]
base_high = predict["high"]
# Replace infinite base endpoints with the stored y_minmax bounds when available.
if y_minmax is not None:
if base_low == -np.inf:
base_low = y_minmax[0]
if base_high == np.inf:
base_high = y_minmax[1]
uncertainty = base_high - base_low
rule_proba = alternative["predict"]
rule_low = np.array(alternative["predict_low"], dtype=float)
rule_high = np.array(alternative["predict_high"], dtype=float)
# Replace infinite endpoints with observed bounds so the triangle plot
# can render meaningful widths.
if y_minmax is not None:
rule_low = np.where(np.isneginf(rule_low), y_minmax[0], rule_low)
rule_high = np.where(np.isposinf(rule_high), y_minmax[1], rule_high)
rule_uncertainty = rule_high - rule_low
# Use list comprehension or NumPy array indexing to select elements
selected_rule_proba = [rule_proba[i] for i in features_to_plot]
selected_rule_uncertainty = [rule_uncertainty[i] for i in features_to_plot]
# Use the filtered number of rules to plot so the number of arrow
# positions (num_to_show) matches the length of the selected rule
# arrays. Previously we passed the original num_to_show_ which could
# be larger than the number of selected rules and caused a size
# mismatch in matplotlib.quiver.
num_to_show_for_plot = min(num_to_show_, len(selected_rule_proba))
plot_triangular(
self,
proba,
uncertainty,
selected_rule_proba,
selected_rule_uncertainty,
num_to_show_for_plot,
title=title,
path=path,
show=show,
save_ext=save_ext,
style_override=style_override,
)
return
column_names = alternative["rule"]
# NOTE(review): ``return_plot_spec`` (and any remaining kwargs such as ``style``)
# is not forwarded here and the result of plot_alternative is not returned, so a
# requested PlotSpec never reaches the caller - confirm this is intended.
plot_alternative(
self,
alternative["value"],
predict,
feature_predict,
features_to_plot,
num_to_show=num_to_show_filtered,
column_names=column_names,
title=title,
path=path,
show=show,
save_ext=save_ext,
style_override=style_override,
use_legacy=plot_use_legacy,
)
[docs]
class FastExplanation(CalibratedExplanation):
"""Class representing fast explanations.
Represents fast, SHAP-like explanations, enabling efficient interpretation of model behavior for large datasets.
"""
def __init__(
    self,
    calibrated_explanations,
    index,
    x,
    feature_weights,
    feature_predict,
    prediction,
    y_threshold=None,
    instance_bin=None,
    condition_source="prediction",
):
    """Initialize a FastExplanation instance.

    Fast explanations are SHAP-like explanations intended for efficient
    interpretation of model behavior on large datasets.

    Parameters
    ----------
    calibrated_explanations : CalibratedExplanations
        The parent CalibratedExplanations object.
    index : int
        The index of the instance being explained.
    x : array-like
        The test dataset containing the instances to be explained.
    feature_weights : dict
        A mapping of feature weights.
    feature_predict : dict
        A mapping of feature predictions.
    prediction : dict
        A mapping containing the prediction results.
    y_threshold : float or tuple, optional
        The threshold for binary classification or regression explanations.
    instance_bin : int, optional
        The bin index of the instance.
    condition_source : str, default="prediction"
        The source of the conditions for the explanation.
    """
    # The fourth positional argument of the base initializer is always an
    # empty dict for fast explanations.
    empty_mapping = {}
    super().__init__(
        calibrated_explanations,
        index,
        x,
        empty_mapping,
        feature_weights,
        feature_predict,
        prediction,
        y_threshold,
        instance_bin,
        condition_source=condition_source,
    )
    self._check_preconditions()
    # Build the rules eagerly so they are available right after construction.
    self.get_rules()
def __repr__(self):
"""Return a string representation of the fast explanation."""
fast = self.get_rules()
output = [
f"{'Prediction':10} [{' Low':5}, {' High':5}]",
f" {fast['base_predict'][0]:5.3f} [{fast['base_predict_low'][0]:5.3f}, {fast['base_predict_high'][0]:5.3f}]",
f"{'Value':6}: {'Feature':40s} {'Weight':6} [{' Low':6}, {' High':6}]",
]
feature_order = self.rank_features(
fast["weight"],
width=np.array(fast["weight_high"]) - np.array(fast["weight_low"]),
num_to_show=len(fast["rule"]),
)
# feature_order = range(len(fast['rule']))
output.extend(
f"{fast['value'][f]:6}: {fast['rule'][f]:40s} {fast['weight'][f]:>6.3f} [{fast['weight_low'][f]:>6.3f}, {fast['weight_high'][f]:>6.3f}]"
for f in reversed(feature_order)
)
# sum_weights = np.sum((fast['weight']))
# sum_weights_low = np.sum((fast['weight_low']))
# sum_weights_high = np.sum((fast['weight_high']))
# output.append(f"{'Mean':6}: {'':40s} {sum_weights:>6.3f} [{sum_weights_low:>6.3f}, {sum_weights_high:>6.3f}]")
return "\n".join(output) + "\n"
[docs]
def build_rules_payload(self) -> Dict[str, Any]:
    """Build the rules payload by delegating to ``FactualExplanation.build_rules_payload``."""
    payload = FactualExplanation.build_rules_payload(self)
    return payload
[docs]
def add_conjunctions(self, n_top_features=5, max_rule_size=2, **kwargs):
    """Warn that conjunctions are not supported for ``FastExplanation`` and perform no work.

    Parameters
    ----------
    n_top_features : int
        The number of top features to consider for conjunctions. Default is 5.
        Ignored.
    max_rule_size : int
        The maximum size of the conjunctive rules. Default is 2. Ignored.
    **kwargs : dict
        Accepted for signature compatibility and ignored.

    Returns
    -------
    self : :class:`.FastExplanation`
        The unchanged explanation, so that call chains written against
        explanation types that support conjunctions keep working.

    Warnings
    --------
    This method is not supported for :class:`.FastExplanation` and will not alter the explanation.
    """
    warnings.warn(
        "The add_conjunctions method is currently not supported for `FastExplanation`, making this call resulting in no change.",
        stacklevel=2,
    )
    # Return self (instead of an implicit None) so `x.add_conjunctions().plot()`
    # does not fail after a call that is documented as a harmless no-op.
    return self
def _is_lesser(self, rule_boundary, instance_value):
"""Return False as fast explanations do not support ordered rule comparisons."""
pass
[docs]
def add_new_rule_condition(self, feature, rule_boundary):
    """Warn that adding rule conditions is unsupported; the explanation is left untouched.

    Both ``feature`` and ``rule_boundary`` are ignored.

    Warnings
    --------
    This method is not supported for :class:`.FastExplanation` and will not alter the explanation.
    """
    message = "The add_new_rule_condition method is currently not supported for `FastExplanation`, making this call resulting in no change."
    warnings.warn(message, stacklevel=2)
def _check_preconditions(self):
"""Provide a placeholder hook; FAST explanations require no extra checks."""
pass
# pylint: disable=too-many-statements, too-many-branches
[docs]
def get_rules(self):
    """
    Create fast explanation rules.

    One rule is produced per feature whose feature-level prediction differs
    from the base prediction. The result is stored in ``self.rules`` and
    ``self.has_rules`` is set to ``True``.

    Returns
    -------
    dict
        A dictionary containing the fast explanation rules.
    """
    instance = np.array(self.x_test, copy=True)
    per_rule_keys = (
        "predict",
        "predict_low",
        "predict_high",
        "weight",
        "weight_low",
        "weight_high",
        "value",
        "rule",
        "feature",
        "sampled_values",
        "feature_value",
        "is_conjunctive",
    )
    fast = {"base_predict": [], "base_predict_low": [], "base_predict_high": []}
    fast.update({key: [] for key in per_rule_keys})
    fast["classes"] = self.prediction["classes"]

    for target_key, source_key in (
        ("base_predict", "predict"),
        ("base_predict_low", "low"),
        ("base_predict_high", "high"),
    ):
        fast[target_key].append(self.prediction[source_key])

    conditions = self.define_conditions()
    for f in range(len(instance)):  # pylint: disable=invalid-name
        # Features that leave the prediction unchanged produce no rule.
        if self.prediction["predict"] == self.feature_predict["predict"][f]:
            continue
        for suffix, source_key in (("", "predict"), ("_low", "low"), ("_high", "high")):
            fast["predict" + suffix].append(self.feature_predict[source_key][f])
        for suffix, source_key in (("", "predict"), ("_low", "low"), ("_high", "high")):
            fast["weight" + suffix].append(self.feature_weights[source_key][f])

        explainer = self.get_explainer()
        raw_value = instance[f]
        if f not in explainer.categorical_features:
            shown = str(np.around(raw_value, decimals=2))
        elif explainer.categorical_labels is None:
            shown = str(raw_value)
        else:
            shown = explainer.categorical_labels[f][int(raw_value)]
        fast["value"].append(shown)

        fast["rule"].append(conditions[f])
        fast["feature"].append(f)
        fast["sampled_values"].append(None)
        fast["feature_value"].append(None)
        fast["is_conjunctive"].append(False)

    self.rules = fast
    self.has_rules = True
    return self.rules
[docs]
def define_conditions(self):
    """
    Build the per-feature rule conditions for the fast explanation.

    Each condition is the feature's name formatted as a string. The list is
    stored on ``self.conditions``, replacing any previous value.

    Returns
    -------
    list[str]
        One condition per feature, in feature-index order.
    """
    self.conditions = []
    feature_count = self.get_explainer().num_features
    # Fill the freshly reset list in place, one formatted name per feature.
    self.conditions.extend(
        f"{self.get_explainer().feature_names[idx]}" for idx in range(feature_count)
    )
    return self.conditions
[docs]
def plot(self, filter_top=None, **kwargs):
    """
    Render the fast explanation as a feature-weight plot.

    Conjunctive rules are plotted when the explanation has them; otherwise
    the rules returned by :meth:`get_rules` are used. The figure is drawn by
    ``plot_probabilistic`` when the explainer mode contains
    ``"classification"`` or the explanation is thresholded, and by
    ``plot_regression`` otherwise.

    Parameters
    ----------
    filter_top : int, optional
        Maximum number of rules to draw. ``None`` draws every available rule;
        larger values are capped at the number of available rules.
    **kwargs : dict
        Additional plotting arguments, such as:

        show : bool, default=True if filename is empty, False otherwise
            Whether the plot should be displayed.
        filename : str, default=''
            Full path and filename of the image file to save. When empty, no
            save location is passed on to the plotting function.
        uncertainty : bool, default=False
            When ``True``, the lower and upper bounds of the feature weights
            are passed along with the weights so that intervals can be drawn.
            When ``False``, only the feature weights are passed.
        style : str, default='regular'
            Plot style. Possible styles for :class:`.FastExplanation`:

            * 'regular' - a regular plot with feature weights and uncertainty intervals (if applicable)

            This method does not read the value itself.
        rnk_metric : str, default='feature_weight'
            The metric used to rank the features. Supported metrics are
            'ensured', 'feature_weight', and 'uncertainty'. ``None`` is
            treated as 'feature_weight'; 'uncertainty' is handled as
            'ensured' with ``rnk_weight`` forced to 1.0.
        rnk_weight : float, default=0.5
            The weight of the uncertainty in the ranking. Used with the
            'ensured' ranking metric.
        style_override, use_legacy : optional
            Forwarded unchanged to the plotting function.

    Raises
    ------
    Warning
        If ``uncertainty`` is requested for a one-sided explanation.

    Warns
    -----
    UserWarning
        If there is nothing to plot (no rules, or ``filter_top <= 0``); the
        method then returns without plotting.
    """
    # Plain option lookups; none of them has side effects.
    filename = kwargs.get("filename", "")
    show = kwargs.get("show", filename == "")
    uncertainty = kwargs.get("uncertainty", False)
    style_override = kwargs.get("style_override")
    plot_use_legacy = kwargs.get("use_legacy")
    rnk_weight = kwargs.get("rnk_weight", 0.5)
    rnk_metric = kwargs.get("rnk_metric")
    if rnk_metric is None:
        rnk_metric = "feature_weight"
    elif rnk_metric == "uncertainty":
        # Pure-uncertainty ranking is the 'ensured' metric with full weight.
        rnk_metric, rnk_weight = "ensured", 1.0

    # Consistency guard: one-sided intervals cannot show uncertainty bands
    if uncertainty and self.is_one_sided():
        raise Warning("Interval plot is not supported for one-sided explanations.")

    # Prefer conjunctive rules when present so conjunctions show up in the plot.
    use_conjunctions = getattr(self, "has_conjunctive_rules", False) and getattr(
        self, "conjunctive_rules", None
    )
    rule_set = self.conjunctive_rules if use_conjunctions else self.get_rules()
    self._check_preconditions()
    predict = self.prediction

    available = len(rule_set["weight"])
    filter_top = np.min([available, available if filter_top is None else filter_top])
    if filter_top <= 0:
        warnings.warn(
            f"The explanation has no rules to plot. The index of the instance is {self.index}",
            stacklevel=2,
        )
        return

    if len(filename) == 0:
        path, title, save_ext = "", "", []
    else:
        path, filename, title, ext = prepare_for_saving(filename)
        path = f"plots/{path}"
        save_ext = [ext]

    feature_weights = (
        {
            "predict": rule_set["weight"],
            "low": rule_set["weight_low"],
            "high": rule_set["weight_high"],
        }
        if uncertainty
        else rule_set["weight"]
    )
    weight_spread = np.array(rule_set["weight_high"]) - np.array(rule_set["weight_low"])
    width = np.reshape(weight_spread, (len(rule_set["weight"])))

    if rnk_metric != "feature_weight":
        upper, lower = rule_set["predict_high"], rule_set["predict_low"]
        interval_sizes = [upper[k] - lower[k] for k in range(len(rule_set["weight"]))]
        ranking = calculate_metrics(
            uncertainty=interval_sizes,
            prediction=rule_set["predict"],
            w=rnk_weight,
            metric=rnk_metric,
        )
        features_to_plot = self.rank_features(width=ranking, num_to_show=filter_top)
    else:
        features_to_plot = self.rank_features(
            rule_set["weight"], width=width, num_to_show=filter_top
        )

    # Prefer explicit feature/column names when available; fall back to rule strings
    column_names = (
        rule_set.get("feature_names") or rule_set.get("column_names") or rule_set.get("rule")
    )

    # Both plotting functions take the same arguments, so pick one and call once.
    probabilistic = "classification" in self.get_explainer().mode or self.is_thresholded()
    draw = plot_probabilistic if probabilistic else plot_regression
    draw(
        self,
        rule_set["value"],
        predict,
        feature_weights,
        features_to_plot,
        filter_top,
        column_names,
        title=title,
        path=path,
        interval=uncertainty,
        show=show,
        idx=self.index,
        save_ext=save_ext,
        style_override=style_override,
        use_legacy=plot_use_legacy,
    )