# Source code for calibrated_explanations.explanations.explanation

# pylint: disable=unknown-option-value
# pylint: disable=too-many-lines, too-many-arguments, invalid-name, too-many-positional-arguments, line-too-long

"""Calibrated explanation containers and visualization helpers.

This module defines the classes used to represent factual, alternative and
fast explanations produced by :class:`~calibrated_explanations.core.CalibratedExplainer`.

Primary classes
---------------
- :class:`CalibratedExplanation` — Abstract base for explanation instances.
- :class:`FactualExplanation` — Factual explanations for an instance.
- :class:`AlternativeExplanation` — Alternative/counterfactual explanations.
- :class:`FastExplanation` — Lightweight fast-mode explanations.
"""

from __future__ import annotations

import contextlib
import math
import re
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from copy import copy
from dataclasses import dataclass
from types import MappingProxyType
from typing import Any, Dict, Literal, Optional, Tuple

import numpy as np
from pandas import Categorical

from ..plotting import plot_alternative, plot_probabilistic, plot_regression, plot_triangular
from ..utils import (
    BinaryEntropyDiscretizer,
    BinaryRegressorDiscretizer,
    EntropyDiscretizer,
    RegressorDiscretizer,
    calculate_metrics,
    prepare_for_saving,
    safe_first_element,
    safe_mean,
)
from ..utils.exceptions import CalibratedError, ValidationError
from ..utils.helper import assign_threshold as normalize_threshold
from ..utils.int_utils import collect_ints
from ._conjunctions import ConjunctionState


@dataclass
class RuleWithImpact:
    """Canonical representation of a rule's impact for consistent plotting and narrative.

    A plain record: values are stored exactly as given; no validation or
    derivation (for example of ``direction`` from ``impact``) happens here.
    """

    # Identifier of the rule.
    rule_id: str
    # Name of the feature the rule refers to.
    feature: str
    # Human-readable rule text.
    text: str
    # Numeric impact of the rule (NOTE(review): presumably signed — confirm with producers).
    impact: float
    # Qualitative direction label; presumably mirrors the sign of ``impact`` — confirm.
    direction: Literal["positive", "negative", "neutral"]
    # Baseline prediction the impact is measured against (assumption — confirm).
    base_predict: float
    # Prediction associated with this rule.
    predict: float
    # Feature value the rule refers to; any type is accepted.
    value: Any
    # Optional lower/upper bounds of the weight's uncertainty envelope.
    weight_envelope_low: Optional[float] = None
    weight_envelope_high: Optional[float] = None
    # Optional lower/upper bounds of the interval around ``predict``.
    predict_low: Optional[float] = None
    predict_high: Optional[float] = None


# @dataclass
# class PredictionInterval:
#     """A dataclass representing a prediction interval for a single feature.

#     Attributes
#     ----------
#     predict: float
#         The model's prediction for this feature
#     low: float
#         The lower bound of the prediction interval
#     high: float
#         The upper bound of the prediction interval
#     """
#     predict: float
#     low: float
#     high: float

# @dataclass
# class FeatureRule:
#     """A dataclass representing a rule for a single feature in an explanation.

#     Attributes
#     ----------
#     weight : PredictionInterval
#         The weight/importance of this feature rule, containing prediction and interval values
#     prediction : PredictionInterval
#         The model's prediction and interval for this feature rule
#     instance_prediction : PredictionInterval
#         The binned prediction data for this feature rule
#     current_bin : int
#         The bin index containing the current feature value
#     rule : str
#         String representation of the rule
#     feature : int
#         Index of the feature this rule applies to
#     feature_value : float
#         The actual value of the feature
#     is_conjunctive : bool
#         Whether this rule is part of a conjunction
#     value_str : str
#         String representation of the feature value
#     """
#     weight: PredictionInterval
#     prediction: PredictionInterval
#     instance_prediction: PredictionInterval # from binned data
#     current_bin: int
#     rule: str
#     feature: int
#     feature_value: float
#     is_conjunctive: bool
#     value_str: str


# pylint: disable=too-many-instance-attributes, too-many-locals, too-many-arguments
class CalibratedExplanation(ABC):
    """Abstract base class for storing and visualizing calibrated explanations.

    Subclasses implement concrete payload building and plotting utilities while
    this base class provides shared validation and convenience accessors. See
    documentation at ``docs/foundations/concepts/explanation_structures.md`` for
    details on the internal payload layout.
    """

    def __init__(
        self,
        calibrated_explanations,
        index,
        x,
        binned,
        feature_weights,
        feature_predict,
        prediction,
        y_threshold=None,
        instance_bin=None,
        condition_source: str = "prediction",
    ):
        """Initialize a CalibratedExplanation instance.

        The collection-level mappings hold one entry per explained instance;
        the entries at ``index`` are copied into per-instance dictionaries.

        Parameters
        ----------
        calibrated_explanations : :class:`.CalibratedExplanations`
            The parent :class:`.CalibratedExplanations` object.
        index : int
            The index of the instance being explained.
        x : array-like
            Feature values of the explained instance. Stored as ``self.x_test``
            and indexed per feature by :meth:`define_conditions`.
        binned : dict
            Mapping of binned feature values; each value is indexed by ``index``.
        feature_weights : dict
            Mapping of feature weights; each value is indexed by ``index``.
        feature_predict : dict
            Mapping of feature predictions. Must provide every key present in
            ``feature_weights``; keys only present here are ignored.
        prediction : dict
            Mapping containing the prediction results. The special key
            ``"__full_probabilities__"`` is stored whole instead of indexed.
        y_threshold : float, tuple or sequence, optional
            Threshold for thresholded explanations. Scalars and tuples are kept
            as-is; any other non-None sequence is indexed by ``index``.
        instance_bin : int, optional
            The bin index of the instance; stored as ``[instance_bin]``.
        condition_source : str, default="prediction"
            Stored on ``self.condition_source``. NOTE(review): the accepted
            values are defined by callers — confirm.
        """
        # Read-only views: the slicing below must never mutate caller data.
        binned = MappingProxyType(binned)
        feature_weights = MappingProxyType(feature_weights)
        feature_predict = MappingProxyType(feature_predict)
        prediction = MappingProxyType(prediction)
        self.calibrated_explanations = calibrated_explanations
        self.index = index
        self.x_test = x
        self.binned = {key: binned[key][index] for key in binned}
        self.feature_weights = {}
        self.feature_predict = {}
        for key in feature_weights:
            self.feature_weights[key] = feature_weights[key][index]
            self.feature_predict[key] = feature_predict[key][index]
        self.prediction = {}
        for key in prediction:
            if key == "__full_probabilities__":
                # Full probability matrix is kept whole (used for golden baseline only).
                self.prediction[key] = prediction[key]
            else:
                self.prediction[key] = prediction[key][index]
        if y_threshold is None or np.isscalar(y_threshold) or isinstance(y_threshold, tuple):
            # Shared threshold (or none) for every instance.
            self.y_threshold = y_threshold
        else:
            # Per-instance thresholds.
            self.y_threshold = y_threshold[index]
        self.conditions = []
        self.rules = None
        self.conjunctive_rules = None
        self.has_rules = False
        self.has_conjunctive_rules = False
        self.bin = [instance_bin] if instance_bin is not None else None
        self.explain_time = None
        self.condition_source = condition_source
        # Resolve the explainer once; only its calibration targets are needed here.
        y_cal = self.get_explainer().y_cal
        if isinstance(y_cal, Categorical):
            # Categorical targets have no numeric range.
            self.y_minmax = [0, 0]
        else:
            self.y_minmax = [np.min(y_cal), np.max(y_cal)]
        self.focus_columns = None
        # Optional reject context attached when explanations are produced under
        # a reject policy (FLAG / ONLY_REJECTED). Populated by orchestrator.
        self.reject_context = None
        self._validate_prediction_invariant()
[docs] def filter_features( self, *, exclude_features=None, include_features=None, copy=True, ): """Filter rules by feature inclusion or exclusion. Parameters ---------- exclude_features : str, int, or list of str/int, optional Features to exclude. Rules containing these features will be removed. include_features : str, int, or list of str/int, optional Features to include. Only rules containing these features will be kept. copy : bool, default=True If True, return a copy of the explanation. If False, modify in place. Returns ------- CalibratedExplanation Filtered explanation. """ if (exclude_features is None) == (include_features is None): raise ValidationError( "Exactly one of exclude_features or include_features must be provided", details={ "exclude_features": exclude_features, "include_features": include_features, }, ) if copy: self = self.copy() # Normalize the features to indices target_features = exclude_features if exclude_features is not None else include_features is_exclude = exclude_features is not None if isinstance(target_features, (str, int)): target_features = [target_features] elif not isinstance(target_features, list): raise ValidationError("Features must be a string, int, or list of strings/ints") if not target_features: raise ValidationError("Features list must not be empty") target_indices = [] for feat in target_features: if isinstance(feat, str): if feat not in self.get_explainer().feature_names: raise ValidationError(f"Feature name '{feat}' not found in feature_names") target_indices.append(self.get_explainer().feature_names.index(feat)) elif isinstance(feat, int): if not (0 <= feat < self.get_explainer().num_features): raise ValidationError( f"Feature index {feat} is out of range [0, {self.get_explainer().num_features})" ) target_indices.append(feat) else: raise ValidationError("Features must contain only strings or ints") # Create mask for rules to keep keep_mask = [] for i, features in enumerate(self.rules["feature"]): if 
self.rules["is_conjunctive"][i]: # For conjunctive rules if isinstance(features, list): has_target = any(f in target_indices for f in features) else: has_target = features in target_indices keep = has_target if not is_exclude else not has_target else: # For disjunctive rules (single feature) has_target = features in target_indices keep = has_target if not is_exclude else not has_target keep_mask.append(keep) # Filter rules filtered_rules = {} for key in self.rules: filtered_rules[key] = [ val for val, keep in zip(self.rules[key], keep_mask, strict=False) if keep ] self.rules = filtered_rules return self
def _validate_prediction_invariant(self) -> None: """Enforce low <= predict <= high invariant on prediction payload.""" import warnings predict = self.prediction.get("predict") low = self.prediction.get("low") high = self.prediction.get("high") if predict is None or low is None or high is None: return with contextlib.suppress(TypeError, ValueError): # Handle scalar values (common case) if ( isinstance(predict, (int, float)) and isinstance(low, (int, float)) and isinstance(high, (int, float)) ): if not low <= high: warnings.warn( "Prediction interval invariant violated: low > high", UserWarning, stacklevel=2, ) # Allow small floating point tolerance epsilon = 1e-9 if not (low - epsilon <= predict <= high + epsilon): warnings.warn( "Prediction invariant violated: predict not in [low, high]", UserWarning, stacklevel=2, ) def __len__(self): """Return the number of rules in the explanation.""" return len(self.get_rules()["rule"])
[docs] @abstractmethod def build_rules_payload(self) -> Dict[str, Any]: """Return structured rule payload separating core content from metadata.""" raise NotImplementedError
@property def prediction_interval(self): """Get the prediction interval from the prediction dictionary. Returns ------- tuple A tuple containing (low, high) values of the prediction interval. """ return (self.prediction["low"], self.prediction["high"]) @property def predict(self): """Get the prediction from the prediction dictionary. Returns ------- float A prediction value. """ return self.prediction["predict"]
[docs] def get_mode(self): """Return the mode of the explanation ('classification' or 'regression').""" return self.get_explainer().mode
[docs] def get_class_labels(self): """Return the class labels.""" return self.get_explainer().class_labels
[docs] def is_multiclass(self): """Determine if the explanation is multiclass.""" return self.get_explainer().is_multiclass()
[docs] def get_explainer(self): """Return the explainer object.""" container = self.calibrated_explanations getter = getattr(container, "get_explainer", None) if callable(getter): return getter() getter = getattr(container, "_get_explainer", None) if callable(getter): return getter() if hasattr(container, "explainer"): return container.explainer return getattr(container, "calibrated_explainer", container)
[docs] def copy(self): """Create a shallow copy of the explanation object. Returns ------- CalibratedExplanation A shallow copy of the explanation object. """ return copy(self)
[docs] def ignored_features_for_instance(self): """Return the set of feature indices ignored for this instance. Combines collection-level ``features_to_ignore`` with any per-instance mask exposed via ``feature_filter_per_instance_ignore``. """ ignored: set[int] = set() global_ignore = getattr(self.calibrated_explanations, "features_to_ignore", None) if global_ignore is None: global_ignore = () ignored.update(collect_ints(global_ignore)) per_instance = getattr( self.calibrated_explanations, "feature_filter_per_instance_ignore", None ) instance_mask = None if isinstance(per_instance, Sequence) and 0 <= self.index < len(per_instance): instance_mask = per_instance[self.index] if instance_mask is not None: ignored.update(collect_ints(instance_mask)) return ignored
[docs] def rank_features(self, feature_weights=None, width=None, num_to_show=None): """Rank the features based on their weights. Parameters ---------- feature_weights : dict, optional A mapping of feature weights. width : dict, optional A mapping of feature widths. num_to_show : int, optional The number of features to show. Returns ------- list The sorted indices of the features. """ if not (feature_weights is not None or width is not None): from ..utils.exceptions import ValidationError raise ValidationError( "Either feature_weights or width (or both) must not be None", details={ "param": "feature_weights/width", "requirement": "at least one must be provided", "feature_weights": feature_weights is not None, "width": width is not None, }, ) num_features = len(feature_weights) if feature_weights is not None else len(width) if num_to_show is None or num_to_show > num_features: num_to_show = num_features # Robust ranking: handle NaN/inf if feature_weights is not None: feature_weights = np.nan_to_num( feature_weights, nan=0.0, posinf=np.finfo(float).max, neginf=-np.finfo(float).max, ) if width is not None: width = np.nan_to_num( width, nan=0.0, posinf=np.finfo(float).max, neginf=-np.finfo(float).max, ) # handle case where there are same weight but different uncertainty if feature_weights is not None and width is not None: # get the indices by first sorting on the absolute value of the # feature_weight and then on the width sorted_indices = [ i for i, x in sorted( enumerate(list(zip(np.abs(feature_weights), width, strict=False))), key=lambda x: (x[1][0], x[1][1]), ) ] return sorted_indices[-num_to_show:] # pylint: disable=invalid-unary-operand-type if width is not None: sorted_indices = np.argsort(width) return sorted_indices[-num_to_show:] # pylint: disable=invalid-unary-operand-type sorted_indices = np.argsort(np.abs(feature_weights)) return sorted_indices[-num_to_show:] # pylint: disable=invalid-unary-operand-type
[docs] def is_one_sided(self) -> bool: """Test if a regression explanation is one-sided. Returns ------- bool: True if one of the low or high percentiles is infinite """ if self.calibrated_explanations.low_high_percentiles is None: return False return np.isinf(self.calibrated_explanations.get_low_percentile()) or np.isinf( self.calibrated_explanations.get_high_percentile() )
[docs] def is_thresholded(self) -> bool: """Check if the explanation is thresholded. Returns ------- bool: True if the y_threshold is not None """ return self.y_threshold is not None
[docs] def is_regression(self) -> bool: """Check if the explanation is for regression. Returns ------- bool: True if mode is 'regression' """ return "regression" in self.get_explainer().mode
[docs] def is_probabilistic(self) -> bool: """Check if the explanation is probabilistic. Returns ------- bool: True if mode is 'classification' or is_thresholded and is_regression are True """ return "classification" in self.get_explainer().mode or ( self.is_regression() and self.is_thresholded() )
@abstractmethod def __repr__(self): """Return a string representation of the explanation."""
[docs] @abstractmethod def plot(self, filter_top=None, **kwargs): """ Plot the explanation. Parameters ---------- filter_top : int, optional The number of top features to display. **kwargs : dict Additional plotting arguments. See each subclass. See Also -------- :meth:`.FactualExplanation.plot` : Refer to the docstring for plot in FactualExplanation for details. :meth:`.AlternativeExplanation.plot` : Refer to the docstring for plot in AlternativeExplanation for details. :meth:`.FastExplanation.plot` : Refer to the docstring for plot in FastExplanation for details. """
[docs] def to_narrative( self, template_path=None, expertise_level=("beginner", "advanced"), output_format="dataframe", conjunction_separator=" AND ", align_weights=True, **kwargs, ): """ Generate narrative explanation for this single instance. This method provides a clean API for generating human-readable narratives from a single calibrated explanation using customizable templates. Parameters ---------- template_path : str or None, default=None Path to the narrative template file (YAML or JSON). If None or the file doesn't exist, the built-in default template is used. expertise_level : str or tuple of str, default=("beginner", "advanced") The expertise level(s) for narrative generation. Can be a single level or a tuple of levels. Valid values: "beginner", "intermediate", "advanced". output_format : str, default="dataframe" Output format. Valid values: "dataframe", "text", "html", "dict", "markdown". conjunction_separator : str, default=" AND " Separator to use for conjunctive rules. Conjunctive rules combine multiple feature conditions (e.g., "Glucose > 120 AND BMI > 28"). align_weights : bool, default=True If True, vertically align columns in the narrative output. For factual explanations this aligns the weight marker ("— weight …"). For alternative explanations this aligns the "then" keyword in rule lines. If False, no alignment is applied. **kwargs : dict Additional keyword arguments passed to the narrative plugin. Returns ------- pd.DataFrame or str or dict The generated narrative in the requested format: - "dataframe": pandas DataFrame with one row - "text": formatted text string - "html": HTML table with one row - "dict": dictionary with narrative fields Raises ------ FileNotFoundError If the template file is not found and no default is available. ValueError If an invalid expertise level or output format is specified. ImportError If pandas is not available and output_format="dataframe" is requested. 
Examples -------- >>> from calibrated_explanations import CalibratedExplainer >>> explainer = CalibratedExplainer(model, x_train, y_train) >>> explanations = explainer.explain_factual(x_test) >>> single_explanation = explanations[0] >>> narrative = single_explanation.to_narrative( ... expertise_level=("beginner", "advanced"), ... output_format="dataframe" ... ) >>> print(narrative) See Also -------- :meth:`.CalibratedExplanations.to_narrative` : Generate narratives for a collection of explanations. :meth:`.plot` : Plot explanations with various visual styles. """ from ..viz.narrative_plugin import NarrativePlotPlugin # Create a temporary collection with just this explanation # We need to wrap this single explanation in a collection-like object # to use the narrative plugin # Create plugin instance plugin = NarrativePlotPlugin(template_path=template_path) # Create a minimal wrapper that looks like a collection class SingleExplanationWrapper: def __init__(self, explanation): self.explanations = [explanation] self._parent = explanation.calibrated_explanations self.calibrated_explainer = explanation.calibrated_explanations.calibrated_explainer self.y_threshold = explanation.y_threshold def is_alternative(self) -> bool: parent_is_alternative = getattr(self._parent, "is_alternative", None) if callable(parent_is_alternative): return bool(parent_is_alternative()) return "Alternative" in type(self.explanations[0]).__name__ wrapper = SingleExplanationWrapper(self) # Generate narrative using the plugin result = plugin.plot( wrapper, template_path=template_path, expertise_level=expertise_level, output=output_format, conjunction_separator=conjunction_separator, align_weights=align_weights, **kwargs, ) # For single explanations, extract the first row/item if it's a collection if output_format == "dataframe": # Return the DataFrame (will have one row) return result elif output_format == "dict": # Return the first (and only) dictionary return result[0] if result else {} else: # 
For text and html, return as is return result
[docs] def to_dataframe(self, *args, **kwargs): """Return the narrative output as a pandas DataFrame. Call :meth:`to_narrative` with ``output_format='dataframe'`` and return the resulting DataFrame. Accepts the same arguments as :meth:`to_narrative`. """ kwargs.setdefault("output_format", "dataframe") return self.to_narrative(*args, **kwargs)
[docs] @abstractmethod def add_conjunctions(self, n_top_features=5, max_rule_size=2): """ Add conjunctive rules to the explanation. Parameters ---------- n_top_features : int, optional Number of top features to combine. max_rule_size : int, optional Maximum size of the conjunctions. Returns ------- :class:`.CalibratedExplanation` """
@abstractmethod def _check_preconditions(self): """Validate that required explanation inputs and state are available.""" pass
[docs] def reset(self): """Reset the explanation to its original state.""" # Reset both base rules and any derived conjunctive overlays. # Conjunctions can change downstream payloads (plots, telemetry), so a # reset should restore the atomic/original representation. self.has_rules = False self.has_conjunctive_rules = False self.conjunctive_rules = None self.get_rules() return self
[docs] def remove_conjunctions(self): """Remove any conjunctive rules.""" self.has_conjunctive_rules = False self.conjunctive_rules = None return self
[docs] def filter_rule_sizes( self, *, rule_sizes: Optional[Any] = None, size_range: Optional[Tuple[int, int]] = None, copy: bool = True, ): """Filter rules by conjunctive rule size. Parameters ---------- rule_sizes : int or sequence of int, optional Explicit rule sizes to keep (e.g., 1, 2, 3). size_range : tuple(int, int), optional Inclusive (min_size, max_size) range of rule sizes to keep. copy : bool, default=True If True, return a filtered copy without mutating the original. """ if (rule_sizes is None) == (size_range is None): raise ValidationError( "Exactly one of rule_sizes or size_range must be provided", details={"rule_sizes": rule_sizes, "size_range": size_range}, ) normalized_sizes: Optional[set[int]] = None min_size: Optional[int] = None max_size: Optional[int] = None if rule_sizes is not None: if isinstance(rule_sizes, (int, np.integer)): normalized_sizes = {int(rule_sizes)} elif isinstance(rule_sizes, (list, tuple, set, np.ndarray)): normalized_sizes = {int(v) for v in rule_sizes} else: raise ValidationError( "rule_sizes must be an int or a sequence of ints", details={"rule_sizes": rule_sizes}, ) if not normalized_sizes: raise ValidationError( "rule_sizes must not be empty", details={"rule_sizes": rule_sizes} ) if any(size <= 0 for size in normalized_sizes): raise ValidationError( "rule_sizes must contain positive integers", details={"rule_sizes": sorted(normalized_sizes)}, ) if size_range is not None: if not isinstance(size_range, (list, tuple)) or len(size_range) != 2: raise ValidationError( "size_range must be a (min_size, max_size) tuple", details={"size_range": size_range}, ) min_size = int(size_range[0]) max_size = int(size_range[1]) if min_size <= 0 or max_size <= 0: raise ValidationError( "size_range bounds must be positive integers", details={"size_range": size_range}, ) if min_size > max_size: raise ValidationError( "size_range must satisfy min_size <= max_size", details={"size_range": size_range}, ) rules = self.get_rules() num_rules = 
len(rules.get("rule", [])) def _rule_size(feature: Any) -> int: if isinstance(feature, (list, tuple, np.ndarray)): return len(np.asarray(feature).ravel()) return 1 mask = [] features = rules.get("feature", [None] * num_rules) for i in range(num_rules): size = _rule_size(features[i]) if i < len(features) else 1 if normalized_sizes is not None: keep = size in normalized_sizes else: keep = min_size <= size <= max_size # type: ignore[operator] mask.append(keep) mask_array = np.asarray(mask, dtype=bool) def _clone_value(val: Any) -> Any: if isinstance(val, MappingProxyType): return {k: _clone_value(v) for k, v in val.items()} if isinstance(val, list): return list(val) if isinstance(val, np.ndarray): return val.copy() return val def _filter_value(val: Any) -> Any: if isinstance(val, list) and len(val) == num_rules: return [v for v, keep in zip(val, mask, strict=False) if keep] if isinstance(val, np.ndarray) and val.shape[0] == num_rules: return val[mask_array].copy() return _clone_value(val) filtered_rules = {k: _filter_value(v) for k, v in rules.items()} target = self.copy() if copy else self has_conjunctive = bool(getattr(target, "has_conjunctive_rules", False)) extractor = getattr( target, "_AlternativeExplanation__extracted_non_conjunctive_rules", None ) if has_conjunctive and callable(extractor): extractor(filtered_rules) target.has_conjunctive_rules = True return target target.rules = filtered_rules if has_conjunctive: target.conjunctive_rules = filtered_rules return target
# ------------------------------------------------------------------ # Telemetry helpers # ------------------------------------------------------------------
[docs] @staticmethod def to_python_number(value: Any) -> Any: """Convert numpy/scalar values to native Python types suitable for telemetry.""" if isinstance(value, np.generic): value = value.item() if isinstance(value, np.ndarray): return [CalibratedExplanation.to_python_number(v) for v in value.tolist()] if isinstance(value, (list, tuple)): return [CalibratedExplanation.to_python_number(v) for v in value] if value is None: return None if isinstance(value, (np.bool_, bool)): return bool(value) if isinstance(value, (np.integer, int)): return int(value) if isinstance(value, (np.floating, float)): if math.isnan(value): return None return float(value) return value
[docs] @staticmethod def normalize_percentile_value(value: Any) -> Optional[float]: """Normalise percentile inputs to decimal fractions.""" value = CalibratedExplanation.to_python_number(value) if value is None: return None if isinstance(value, (float, int)): value = float(value) if math.isinf(value): return value if abs(value) > 1.0: return value / 100.0 return value return None
def _get_percentiles(self) -> Optional[Tuple[Optional[float], Optional[float]]]: """Return decimal percentiles if available.""" percentiles = getattr(self.calibrated_explanations, "low_high_percentiles", None) if percentiles is None or len(percentiles) != 2: return None low = self.normalize_percentile_value(percentiles[0]) high = self.normalize_percentile_value(percentiles[1]) return (low, high)
[docs] def get_percentiles(self) -> Optional[Tuple[Optional[float], Optional[float]]]: """Return the decoded percentile configuration.""" return self._get_percentiles()
[docs] @staticmethod def compute_confidence_level( percentiles: Optional[Tuple[Optional[float], Optional[float]]], ) -> Optional[float]: """Compute confidence level from decimal percentiles.""" if not percentiles: return None low, high = percentiles if low is None or high is None: return None if low == -math.inf: return None if high in (None, math.inf) else high if high == math.inf: return None if low is None else 1 - low return max(0.0, high - low)
[docs] def normalize_threshold_value(self) -> Any: """Normalise threshold metadata to telemetry-friendly structure.""" threshold = self.y_threshold if threshold is None: return None if isinstance(threshold, np.ndarray): threshold = threshold.tolist() if isinstance(threshold, (list, tuple)): if len(threshold) == 0: return None values = [CalibratedExplanation.to_python_number(threshold[0])] if len(threshold) > 1: values.append(CalibratedExplanation.to_python_number(threshold[1])) return values return CalibratedExplanation.to_python_number(threshold)
def _build_uncertainty_payload( self, *, value: Any, low: Any, high: Any, representation: str, percentiles: Optional[Tuple[Optional[float], Optional[float]]] = None, threshold: Any = None, include_percentiles: bool = True, ) -> Dict[str, Any]: """Create a structured uncertainty payload.""" lower = CalibratedExplanation.to_python_number(low) upper = CalibratedExplanation.to_python_number(high) payload: Dict[str, Any] = { "representation": representation, "calibrated_value": CalibratedExplanation.to_python_number(value), "lower_bound": lower, "upper_bound": upper, "legacy_interval": [lower, upper], } payload["threshold"] = threshold payload["raw_percentiles"] = None payload["confidence_level"] = None if include_percentiles and percentiles: payload["raw_percentiles"] = [ CalibratedExplanation.to_python_number(percentiles[0]), CalibratedExplanation.to_python_number(percentiles[1]), ] confidence = CalibratedExplanation.compute_confidence_level(percentiles) if confidence is not None: payload["confidence_level"] = confidence return payload # Public alias for testing build_uncertainty_payload = _build_uncertainty_payload @staticmethod def _build_interval(low: Any, high: Any) -> Dict[str, Any]: """Return a minimal uncertainty interval with Python-native bounds.""" return { "lower": CalibratedExplanation.to_python_number(low), "upper": CalibratedExplanation.to_python_number(high), }
[docs] @staticmethod def build_interval(low: Any, high: Any) -> Dict[str, Any]: """Public helper exposing interval construction.""" return CalibratedExplanation._build_interval(low, high)
def _build_instance_uncertainty(self) -> Dict[str, Any]: """Build uncertainty payload for the current instance prediction.""" if self.is_thresholded(): return self._build_uncertainty_payload( value=self.prediction["predict"], low=self.prediction["low"], high=self.prediction["high"], representation="threshold", threshold=self.normalize_threshold_value(), include_percentiles=False, ) if self.is_probabilistic(): return self._build_uncertainty_payload( value=self.prediction["predict"], low=self.prediction["low"], high=self.prediction["high"], representation="venn_abers", include_percentiles=False, ) percentiles = self._get_percentiles() return self._build_uncertainty_payload( value=self.prediction["predict"], low=self.prediction["low"], high=self.prediction["high"], representation="percentile", percentiles=percentiles, include_percentiles=True, )
[docs] def build_instance_uncertainty(self) -> Dict[str, Any]: """Expose the instance uncertainty payload builder.""" return self._build_instance_uncertainty()
def _safe_feature_name(self, feature_index: Any) -> str: """Return a readable feature name for telemetry.""" feature_names = getattr(self.get_explainer(), "feature_names", None) try: idx = int(feature_index) except ( TypeError, ValueError, ): # ADR002_ALLOW: handle non-numeric feature ids. # pragma: no cover return str(feature_index) if feature_names and 0 <= idx < len(feature_names): return str(feature_names[idx]) return str(idx)
[docs] def safe_feature_name(self, feature_index: Any) -> str: """Return the human-readable name for a feature index.""" return self._safe_feature_name(feature_index)
[docs] @staticmethod def convert_condition_value(raw_value: Optional[str], fallback: Any) -> Any: """Convert textual condition payloads to structured values.""" if raw_value is None: return CalibratedExplanation.to_python_number(fallback) text = raw_value.strip() if text.lower() in {"-inf", "-infinity"}: return float("-inf") if text.lower() in {"inf", "+inf", "infinity"}: return float("inf") try: return float(text) except ( ValueError ): # ADR002_ALLOW: textual rule fragments may not be numeric. # pragma: no cover return text
def _parse_condition(self, feature_name: str, rule_text: str) -> Tuple[str, Optional[str]]:
    """Attempt to parse rule text into operator and value tokens.

    Recognises ``"<feature> <op> <value>"``. ``<op>`` is one of ``<=``,
    ``>=``, ``==``, ``=``, ``<``, ``>`` or the standalone word ``in``;
    ``=`` is normalised to ``==``.

    Returns
    -------
    tuple[str, str | None]
        ``(operator, value_text)`` on success. Text of any other shape gives
        ``("raw", stripped_text)``, and empty text gives ``("raw", None)``.
    """
    if not rule_text:
        return "raw", None
    text = rule_text.strip()
    # ``in`` must be a standalone word (whitespace on both sides). Otherwise
    # "x inside 3" would be parsed as operator "in" with value "side 3".
    pattern = rf"^{re.escape(feature_name)}\s*(<=|>=|==|=|<|>|(?<=\s)in(?=\s))\s*(.+)$"
    match = re.match(pattern, text)
    if match is None:
        return "raw", text
    operator = match.group(1)
    value_text = match.group(2).strip()
    if operator == "=":
        operator = "=="
    return operator.lower(), value_text
def parse_condition(self, feature_name: str, rule_text: str) -> Tuple[str, Optional[str]]:
    """Split *rule_text* into an ``(operator, value)`` pair for *feature_name*.

    Public alias of ``_parse_condition``.
    """
    parser = self._parse_condition
    return parser(feature_name, rule_text)
def _build_condition_payload(
    self,
    feature_index: Any,
    rule_text: str,
    feature_value: Any,
    display_value: Any,
) -> Dict[str, Any]:
    """Describe a single rule condition as a telemetry dict.

    Parameters
    ----------
    feature_index : Any
        Index resolved to a name through ``_safe_feature_name``.
    rule_text : str
        Human-readable rule, split with ``_parse_condition``.
    feature_value : Any
        Accepted for interface compatibility; not used here.
    display_value : Any
        Reported as the value when the rule text cannot be parsed. Otherwise
        it is the fallback handed to ``convert_condition_value``.

    Returns
    -------
    dict
        Keys ``feature``, ``operator``, ``value`` and ``text``.
    """
    name = self._safe_feature_name(feature_index)
    op, value_token = self._parse_condition(name, rule_text)
    payload_value = (
        CalibratedExplanation.to_python_number(display_value)
        if op == "raw"
        else self.convert_condition_value(value_token, display_value)
    )
    return {"feature": name, "operator": op, "value": payload_value, "text": rule_text}
def build_condition_payload(
    self,
    feature_index: Any,
    rule_text: str,
    feature_value: Any,
    display_value: Any,
) -> Dict[str, Any]:
    """Return the telemetry description of a single rule condition.

    Public alias of ``_build_condition_payload``. All arguments are forwarded
    positionally and unchanged.
    """
    args = (feature_index, rule_text, feature_value, display_value)
    return self._build_condition_payload(*args)
def to_telemetry(self) -> Dict[str, Any]:
    """Bundle the rule payload and the instance uncertainty for telemetry.

    The ``metadata`` mapping from ``build_rules_payload()`` is reused, and
    updated in place when present. Its ``prediction_uncertainty`` entry is
    filled from ``_build_instance_uncertainty()`` only if not already set.

    Returns
    -------
    dict
        ``{"uncertainty": ..., "rules": <rules payload>, "metadata": ...}``.
    """
    rules_payload = self.build_rules_payload()
    meta = rules_payload.get("metadata", {})
    uncertainty = meta.setdefault(
        "prediction_uncertainty", self._build_instance_uncertainty()
    )
    return {"uncertainty": uncertainty, "rules": rules_payload, "metadata": meta}
def define_conditions(self):
    """
    Define the rule conditions for an instance.

    Ignored features get an empty string. Without a discretizer each
    condition is ``"<name> = <raw value>"``. Categorical features use
    ``"<name> = <label>"``; if the label cannot be found, the discretized
    value is used instead. Other features use the discretizer's name for the
    instance's bin.

    Returns
    -------
    list[str]
        A list of conditions for each feature in the instance (also stored
        on ``self.conditions``).
    """
    self.conditions = []  # pylint: disable=invalid-name
    explainer = self.get_explainer()
    discretizer = explainer.discretizer
    feature_names = explainer.feature_names
    categorical_features = (
        explainer.categorical_features if explainer.categorical_features is not None else ()
    )
    categorical_labels = explainer.categorical_labels
    # Without a discretizer (e.g. regression without discretization) the raw
    # values are used, so nothing needs to be discretized.
    x = discretizer.discretize(self.x_test) if discretizer is not None else None
    ignored = self.ignored_features_for_instance()
    for f in range(explainer.num_features):
        if f in ignored:
            self.conditions.append("")
            continue
        if discretizer is None:
            rule = f"{feature_names[f]} = {self.x_test[f]}"
        elif f in categorical_features:
            rule = f"{feature_names[f]} = {x[f]}"
            if categorical_labels is not None:
                # Label containers may be sequences (IndexError) or mappings
                # (KeyError); either way fall back to the discretized value.
                with contextlib.suppress(IndexError, KeyError):
                    rule = f"{feature_names[f]} = {categorical_labels[f][int(x[f])]}"
        else:
            rule = discretizer.names[f][int(x[f])]
        self.conditions.append(rule)
    return self.conditions
def predict_conjunction_tuple(
    self,
    rule_value_set,
    original_features,
    perturbed,
    threshold,
    predicted_class,
    bins=None,
):
    """Calculate the prediction for a conjunctive rule using batched inference.

    Every combination (Cartesian product) of the candidate values in
    ``rule_value_set`` is written into a copy of ``perturbed`` at the columns
    ``original_features``. All combinations are predicted in a single call
    to the explainer's ``predict_internal``, and the results are averaged.

    Parameters
    ----------
    rule_value_set : sequence
        One collection of candidate values per entry of ``original_features``.
        Extra trailing entries are ignored; scalars act as one-element
        collections.
    original_features : sequence of int
        Column indices constrained by the rule.
    perturbed : array-like
        The instance as a 1D or 2D array. It is not modified.
    threshold, predicted_class
        Forwarded to ``predict_internal`` as ``threshold``/``classes``.
    bins : scalar or array-like, optional
        Bin(s) for the instance, repeated to match the batch size.

    Returns
    -------
    tuple[float, float, float]
        Mean prediction, mean lower bound and mean upper bound over all
        combinations. Returns ``(0.0, 0.0, 0.0)`` when there is nothing to
        combine (no features or an empty value collection).

    Raises
    ------
    ValidationError
        If ``perturbed`` is neither 1D nor 2D, ``original_features`` are not
        integers, or fewer value collections than features are supplied.
    """
    predict_fn = self.get_explainer().prediction_orchestrator.predict_internal
    # ``np.repeat`` below allocates a fresh batch, so ``perturbed`` is never written.
    perturbed = np.asarray(perturbed)
    if perturbed.ndim == 1:
        perturbed = perturbed.reshape(1, -1)
    elif perturbed.ndim != 2:
        raise ValidationError(
            "perturbed must be a 1D or 2D array-like",
            details={"param": "perturbed", "ndim": int(perturbed.ndim)},
        )
    try:
        original_features = [int(v) for v in original_features]
    except (TypeError, ValueError) as exc:
        raise ValidationError(
            "original_features must contain integer indices",
            details={"param": "original_features", "value": original_features},
        ) from exc

    n_features = len(original_features)
    value_sets = list(rule_value_set[:n_features])
    if len(value_sets) < n_features:
        # Too few value sets would make the reshape below silently pair
        # values with the wrong columns.
        raise ValidationError(
            "rule_value_set must provide one collection of values per feature",
            details={
                "param": "rule_value_set",
                "count": len(value_sets),
                "requirement": f"at least {n_features} entries",
            },
        )

    # Prepare value arrays
    value_iterables = []
    for values in value_sets:
        arr = np.asarray(values)
        if arr.ndim == 0:
            arr = arr.reshape(1)
        if arr.size == 0:
            return 0.0, 0.0, 0.0
        value_iterables.append(arr)
    if not value_iterables:
        return 0.0, 0.0, 0.0

    # Generate all combinations in a vectorized fashion
    grids = np.meshgrid(*value_iterables, indexing="ij")
    combo_matrix = np.stack(grids, axis=-1).reshape(-1, n_features)
    if combo_matrix.size == 0:
        return 0.0, 0.0, 0.0
    n_combos = combo_matrix.shape[0]

    # Create batch
    batch = np.repeat(perturbed, n_combos, axis=0)

    # Handle bins if provided: a scalar or per-instance array tiled to the batch size
    batch_bins = None
    if bins is not None:
        bins_arr = np.asarray(bins)
        if bins_arr.ndim == 0:
            batch_bins = np.full(n_combos, bins_arr.item())
        else:
            batch_bins = np.tile(bins_arr, n_combos)
        # Validate batch_bins length matches batch size
        if len(batch_bins) != batch.shape[0]:
            batch_bins = batch_bins[: batch.shape[0]]

    # Apply perturbations in bulk
    batch[:, original_features] = combo_matrix

    # Predict
    p_value, low, high, _ = predict_fn(
        batch,
        threshold=threshold,
        low_high_percentiles=self.calibrated_explanations.low_high_percentiles,
        classes=predicted_class,
        bins=batch_bins,
    )
    return (
        float(np.mean(p_value)),
        float(np.mean(low)),
        float(np.mean(high)),
    )
def _predict_conjunctive(
    self,
    rule_value_set,
    original_features,
    perturbed,
    threshold,
    predicted_class,
    bins=None,
    use_batched=False,
):
    """
    Calculate the prediction for a conjunctive rule.

    The result is the average prediction over every combination of the
    candidate values in ``rule_value_set`` written into ``perturbed`` at
    ``original_features``. Any number (>= 2) of features is supported.

    Parameters
    ----------
    rule_value_set : list
        The set of rule values, one collection per feature.
    original_features : list
        The original feature indices.
    perturbed : array-like
        The instance being perturbed (1D). The caller's array is not modified.
    threshold : float
        The threshold for classification or regression.
    predicted_class : int
        The predicted class label.
    bins : array-like, optional
        Bin(s) forwarded unchanged to the prediction call.
    use_batched : bool, optional
        Whether to delegate to the vectorized ``predict_conjunction_tuple``.

    Returns
    -------
    tuple
        The mean predicted value, mean lower bound and mean upper bound
        (``(0.0, 0.0, 0.0)`` when there are no combinations).

    Raises
    ------
    ValidationError
        If fewer than two features are given, or fewer value collections than
        features are supplied (sequential path).
    """
    from itertools import product

    if len(original_features) < 2:
        raise ValidationError(
            "Conjunctive rules require at least two features",
            details={
                "param": "original_features",
                "count": len(original_features),
                "requirement": "minimum 2 features",
            },
        )

    if use_batched:
        return self.predict_conjunction_tuple(
            rule_value_set,
            original_features,
            perturbed,
            threshold,
            predicted_class,
            bins,
        )

    predict_fn = self.get_explainer().prediction_orchestrator.predict_internal
    # Work on a private, writable copy so read-only inputs work and the
    # caller's array is never changed.
    perturbed = np.array(perturbed, copy=True)
    features = list(original_features)
    # ``atleast_1d`` lets scalar candidate values be iterated like one-element sets.
    value_iterables = [
        np.atleast_1d(np.asarray(values)) for values in rule_value_set[: len(features)]
    ]
    if len(value_iterables) < len(features):
        raise ValidationError(
            "rule_value_set must provide one collection of values per feature",
            details={
                "param": "rule_value_set",
                "count": len(value_iterables),
                "requirement": f"at least {len(features)} entries",
            },
        )

    rule_predict = 0.0
    rule_low = 0.0
    rule_high = 0.0
    rule_count = 0
    # ``product`` visits combinations in nested-loop order (first feature
    # slowest), for any number of features. Earlier code silently dropped
    # every feature beyond the third.
    for combination in product(*value_iterables):
        for feat_idx, value in zip(features, combination):
            perturbed[feat_idx] = value
        p_value, low, high, _ = predict_fn(
            perturbed.reshape(1, -1),
            threshold=threshold,
            low_high_percentiles=self.calibrated_explanations.low_high_percentiles,
            classes=predicted_class,
            bins=bins,
        )
        rule_predict += float(safe_first_element(p_value))
        rule_low += float(safe_first_element(low))
        rule_high += float(safe_first_element(high))
        rule_count += 1

    if rule_count:
        rule_predict /= rule_count
        rule_low /= rule_count
        rule_high /= rule_count
    return rule_predict, rule_low, rule_high
def predict_conjunctive(
    self,
    rule_value_set,
    original_features,
    perturbed,
    threshold,
    predicted_class,
    bins=None,
    use_batched=False,
):
    """Predict a conjunctive rule; public wrapper of ``_predict_conjunctive``.

    All arguments are forwarded unchanged, and the helper's result is returned.
    """
    positional = (rule_value_set, original_features, perturbed, threshold, predicted_class)
    return self._predict_conjunctive(*positional, bins=bins, use_batched=use_batched)
@abstractmethod
def _is_lesser(self, rule_boundary, instance_value):
    """Return ``True`` when *instance_value* satisfies a 'less than' rule at *rule_boundary*.

    Subclasses implement this. ``add_new_rule_condition`` uses the result to
    choose between the ``<`` and ``>`` variants of the new rule.
    """
    ...

# pylint: disable=too-many-arguments, too-many-statements, too-many-branches, too-many-return-statements
def add_new_rule_condition(self, feature, rule_boundary):
    """
    Create a new rule condition for a numerical feature.

    Parameters
    ----------
    feature : int or str
        The feature index (Python or NumPy integer) or the feature name.
    rule_boundary : int or float
        The value to define as rule condition.

    Returns
    -------
    :class:`.CalibratedExplanation`
        ``self``, to allow method chaining.

    Notes
    -----
    The function warns and returns the explanation unchanged in these cases:
    the feature is unknown or categorical, the rule is already included, no
    calibration value lies beyond the boundary, or the new rule's interval
    equals the original prediction interval.
    No implementation is provided for the :class:`.FastExplanation` class.
    """
    explainer = self.get_explainer()
    f = None
    if isinstance(feature, (int, np.integer)):
        # NumPy integer indices are accepted too; previously they fell
        # through to a name lookup and were reported as "not found".
        f = int(feature)
    else:
        with contextlib.suppress(ValueError):
            f = explainer.feature_names.index(feature)
    if f is None:
        warnings.warn(f"Feature {feature} not found", stacklevel=2)
        return self
    if explainer.categorical_features is not None and f in explainer.categorical_features:
        warnings.warn(
            "Alternatives for all categorical features are already included", stacklevel=2
        )
        return self

    x_copy = np.array(self.x_test, copy=True)
    is_lesser = self._is_lesser(rule_boundary, x_copy[f])
    new_rule = self.get_rules()
    rule = self._get_rule_str(is_lesser, f, rule_boundary)
    if any(existing == rule for existing in new_rule["rule"]):
        warnings.warn("Rule already included", stacklevel=2)
        return self

    threshold = self.y_threshold
    perturbed_threshold = normalize_threshold(threshold)
    perturbed_bins = np.empty((0,)) if self.bin is not None else None
    perturbed_x = np.empty((0, explainer.num_features))
    perturbed_feature = np.empty((0, 4))  # (feature, instance, bin_index, is_lesser)
    perturbed_class = np.empty((0,), dtype=int)
    cal_x_f = explainer.x_cal[:, f]
    feature_values = np.unique(np.array(cal_x_f))
    sample_percentiles = explainer.sample_percentiles
    if is_lesser:
        if not np.any(feature_values < rule_boundary):
            warnings.warn(
                f"Lowest feature value for feature {feature} is {np.min(feature_values)}",
                stacklevel=2,
            )
            return self
        values = np.percentile(cal_x_f[cal_x_f < rule_boundary], sample_percentiles)
        covered = np.percentile(cal_x_f[cal_x_f >= rule_boundary], sample_percentiles)
    else:
        if not np.any(feature_values > rule_boundary):
            warnings.warn(
                f"Highest feature value for feature {feature} is {np.max(feature_values)}",
                stacklevel=2,
            )
            return self
        values = np.percentile(cal_x_f[cal_x_f > rule_boundary], sample_percentiles)
        covered = np.percentile(cal_x_f[cal_x_f <= rule_boundary], sample_percentiles)

    # Samples on the new rule's side of the boundary are tagged with
    # ``is_lesser``; samples from the opposite side are tagged ``None``. The
    # tag is used below to split the predictions into the two groups.
    samples = [(value, is_lesser) for value in values] + [(value, None) for value in covered]
    for value, marker in samples:
        x_local = x_copy.reshape(1, -1).copy()
        x_local[0, f] = value
        perturbed_x = np.concatenate((perturbed_x, x_local))
        perturbed_feature = np.concatenate((perturbed_feature, [(f, 0, None, marker)]))
        perturbed_bins = (
            np.concatenate((perturbed_bins, self.bin)) if self.bin is not None else None
        )
        perturbed_class = np.concatenate(
            (perturbed_class, np.array([self.prediction["classes"]]))
        )
        if isinstance(threshold, tuple):
            perturbed_threshold = threshold
        elif threshold is None:
            perturbed_threshold = None
        elif np.isscalar(perturbed_threshold) and perturbed_threshold == threshold:
            perturbed_threshold = threshold
        else:
            perturbed_threshold = np.concatenate((perturbed_threshold, threshold))

    predict, low, high, _ = explainer.prediction_orchestrator.predict_internal(
        perturbed_x,
        threshold=perturbed_threshold,
        low_high_percentiles=self.calibrated_explanations.low_high_percentiles,
        classes=perturbed_class,
        bins=perturbed_bins,
    )
    is_rule_sample = [meta[3] is not None for meta in perturbed_feature]
    instance_predict = [p for p, is_rule in zip(predict, is_rule_sample) if not is_rule]
    rule_predict = [p for p, is_rule in zip(predict, is_rule_sample) if is_rule]
    rule_low = [v for v, is_rule in zip(low, is_rule_sample) if is_rule]
    rule_high = [v for v, is_rule in zip(high, is_rule_sample) if is_rule]

    mean_rule_predict = safe_mean(rule_predict)
    mean_rule_low = safe_mean(rule_low)
    mean_rule_high = safe_mean(rule_high)
    mean_instance_predict = safe_mean(instance_predict)

    # skip if identical to original
    if self.prediction["low"] == mean_rule_low and self.prediction["high"] == mean_rule_high:
        warnings.warn(
            "The alternative explanation is identical to the original explanation",
            UserWarning,
            stacklevel=2,
        )
        return self

    new_rule["predict"].append(mean_rule_predict)
    new_rule["predict_low"].append(mean_rule_low)
    new_rule["predict_high"].append(mean_rule_high)
    new_rule["weight"].append(mean_rule_predict - mean_instance_predict)
    # Compare the aggregated bounds with the infinities. The previous code
    # compared the *lists* of bounds, which is always unequal, and its else
    # branch would have appended a list.
    new_rule["weight_low"].append(
        mean_rule_low - mean_instance_predict if mean_rule_low != -np.inf else mean_rule_low
    )
    new_rule["weight_high"].append(
        mean_rule_high - mean_instance_predict if mean_rule_high != np.inf else mean_rule_high
    )
    new_rule["value"].append(str(np.around(self.x_test[f], decimals=2)))
    new_rule["feature"].append(f)
    new_rule["sampled_values"].append(self.binned["rule_values"][f][0][0])
    new_rule["feature_value"].append(self.x_test[f])
    new_rule["is_conjunctive"].append(False)
    new_rule["rule"].append(rule)
    self.rules = new_rule
    return self
def _get_rule_str(self, is_lesser, feature, rule_boundary):
    """Build the display text for a single-feature threshold rule.

    Parameters
    ----------
    is_lesser : bool
        ``True`` produces a ``<`` rule, ``False`` a ``>`` rule.
    feature : int
        Index into the explainer's ``feature_names`` (not the name itself).
    rule_boundary : float
        The boundary value, rendered with two decimals.

    Returns
    -------
    str
        For example ``"age < 42.00"``.
    """
    name = self.get_explainer().feature_names[feature]
    operator = "<" if is_lesser else ">"
    return f"{name} {operator} {rule_boundary:.2f}"
# pylint: disable=too-many-instance-attributes, too-many-locals, too-many-arguments
class FactualExplanation(CalibratedExplanation):
    """Store and visualise calibrated factual explanations.

    The public contract mirrors the CE definition established in the
    classification and regression papers: a factual explanation couples the
    calibrated prediction and its uncertainty interval with a collection of
    factual feature rules. Each rule binds the observed feature value to a
    condition and exposes the calibrated feature weight plus its uncertainty
    interval. Downstream helpers (telemetry, JSON export, plots) rely on this
    invariant, so the internal representation and helper payloads **must**
    keep the prediction + interval pair alongside weight + interval
    information for every factual rule.

    For detailed information about the internal data structures and rule
    generation process, see docs/foundations/concepts/explanation_structures.md.
    """

    def __init__(
        self,
        calibrated_explanations,
        index,
        x,
        binned,
        feature_weights,
        feature_predict,
        prediction,
        y_threshold=None,
        instance_bin=None,
        condition_source: str = "prediction",
    ):
        """Initialize a FactualExplanation instance.

        Provides factual explanations for a given instance, highlighting
        features that contribute to the model's prediction. The factual rules
        are built eagerly via ``get_rules()``.

        Parameters
        ----------
        calibrated_explanations : CalibratedExplanations
            The parent CalibratedExplanations object.
        index : int
            The index of the instance being explained.
        x : array-like
            The test dataset containing the instances to be explained.
        binned : dict
            A mapping of binned feature values.
        feature_weights : dict
            A mapping of feature weights.
        feature_predict : dict
            A mapping of feature predictions.
        prediction : dict
            A mapping containing the prediction results.
        y_threshold : float or tuple, optional
            The threshold for binary classification or regression explanations.
        instance_bin : int, optional
            The bin index of the instance.
        condition_source : str, optional
            Forwarded unchanged to the base-class constructor
            (default ``"prediction"``).
        """
        super().__init__(
            calibrated_explanations,
            index,
            x,
            binned,
            feature_weights,
            feature_predict,
            prediction,
            y_threshold,
            instance_bin,
            condition_source=condition_source,
        )
        self._check_preconditions()
        self.get_rules()
        # Guarantee the attribute exists on every path. Without this it could
        # stay unassigned (e.g. when no full probability matrix was supplied),
        # and later readers would hit AttributeError.
        if not hasattr(self, "prediction_probabilities"):
            self.prediction_probabilities = None
        # Cache per-instance prediction probabilities for golden baseline (classification)
        try:
            if not self.is_regression():
                # Access stored full probability matrix via parent prediction mapping
                full_probs = self.prediction.get("__full_probabilities__")
                if full_probs is not None:
                    # full_probs may be a tuple (proba_matrix, classes) for multiclass
                    if isinstance(full_probs, tuple) and len(full_probs) >= 1:
                        proba_matrix = full_probs[0]
                    else:
                        proba_matrix = full_probs
                    # Attach whole matrix on first explanation, then propagate
                    if self.index == 0:
                        self.prediction_probabilities = proba_matrix
                    else:
                        # ensure earlier explanation already stored full matrix
                        self.prediction_probabilities = getattr(
                            self.calibrated_explanations.explanations[0],
                            "prediction_probabilities",
                            proba_matrix,
                        )
            else:
                self.prediction_probabilities = None
        except (
            Exception
        ):  # ADR002_ALLOW: prediction payloads vary across plugins. # pragma: no cover
            self.prediction_probabilities = None

    def __repr__(self):
        """Return a string representation of the factual explanation."""
        # Use canonical rules to ensure parity with plot and narrative
        canonical_rules = self._rules_with_impact()
        predict = self.prediction
        output = [
            f"{'Prediction':10} [{' Low':5}, {' High':5}]",
            f"{predict['predict']:5.3f} [{predict['low']:5.3f}, {predict['high']:5.3f}]",
            f"{'Value':6}: {'Feature':40s} {'Weight':6} [{' Low':6}, {' High':6}]",
        ]
        for r in canonical_rules:
            output.append(
                f"{str(r.value):6}: {r.text:40s} {r.impact:>6.3f} [{r.weight_envelope_low:>6.3f}, {r.weight_envelope_high:>6.3f}]"
            )
        return "\n".join(output) + "\n"
def build_rules_payload(self) -> Dict[str, Any]:
    """Return structured payload describing factual feature rules.

    Returns
    -------
    dict
        ``{"core": ..., "metadata": ...}``.

        ``core`` holds the calibrated prediction with its uncertainty interval
        (``kind="factual"``). It also holds one entry per rule with the weight,
        its interval and the parsed condition.

        ``metadata`` repeats every rule with full uncertainty payloads. It adds
        the baseline prediction (when one is available) and the instance-level
        ``prediction_uncertainty``.

        When there are no rules, ``core`` carries only the prediction and
        ``metadata`` only an empty ``feature_rules`` list.
    """
    rules = self.get_rules()
    as_number = CalibratedExplanation.to_python_number
    as_interval = CalibratedExplanation._build_interval
    pred = self.prediction
    core: Dict[str, Any] = {
        "kind": "factual",
        "prediction": {
            "value": as_number(pred.get("predict")),
            "uncertainty_interval": as_interval(pred.get("low"), pred.get("high")),
        },
        "feature_rules": [],
    }
    metadata: Dict[str, Any] = {"feature_rules": []}
    if not rules or "rule" not in rules:
        return {"core": core, "metadata": metadata}

    baseline = as_number(rules.get("base_predict", [None])[0])
    if baseline is not None:
        metadata["baseline_prediction"] = baseline
        metadata["baseline_interval"] = as_interval(
            rules.get("base_predict_low", [None])[0],
            rules.get("base_predict_high", [None])[0],
        )

    probabilistic = self.is_probabilistic()
    # Percentiles only accompany percentile-based (non-probabilistic) payloads.
    percentiles = None if probabilistic else self._get_percentiles()
    weight_repr = "venn_abers" if probabilistic else "percentile"

    for idx, rule_text in enumerate(rules.get("rule", [])):
        feature_index = rules["feature"][idx]
        condition = self._build_condition_payload(
            feature_index,
            rule_text,
            rules["feature_value"][idx],
            rules["value"][idx],
        )
        weight_value = as_number(rules["weight"][idx])
        weight_low = rules["weight_low"][idx]
        weight_high = rules["weight_high"][idx]
        core["feature_rules"].append(
            {
                "weight": {
                    "value": weight_value,
                    "uncertainty_interval": as_interval(weight_low, weight_high),
                },
                "condition": condition,
            }
        )

        weight_uncertainty = self._build_uncertainty_payload(
            value=weight_value,
            low=weight_low,
            high=weight_high,
            representation=weight_repr,
            percentiles=percentiles,
            include_percentiles=not probabilistic,
        )
        thresholded = self.is_thresholded()
        prediction_uncertainty = self._build_uncertainty_payload(
            value=rules["predict"][idx],
            low=rules["predict_low"][idx],
            high=rules["predict_high"][idx],
            representation="threshold" if thresholded else weight_repr,
            percentiles=None if thresholded else percentiles,
            threshold=self.normalize_threshold_value() if thresholded else None,
            include_percentiles=not thresholded and not probabilistic,
        )
        metadata["feature_rules"].append(
            {
                "feature": self._safe_feature_name(feature_index),
                "feature_index": as_number(feature_index),
                "weight_uncertainty": weight_uncertainty,
                "prediction_uncertainty": prediction_uncertainty,
                "prediction_value": as_number(rules["predict"][idx]),
                "condition_text": rule_text,
                "instance_value": as_number(rules["feature_value"][idx]),
            }
        )

    metadata["prediction_uncertainty"] = self._build_instance_uncertainty()
    return {"core": core, "metadata": metadata}
def _check_preconditions(self):
    """Warn when the selected discretizer is incompatible with factual explanations.

    Regression expects a ``BinaryRegressorDiscretizer``; classification
    expects a ``BinaryEntropyDiscretizer``. Otherwise only a warning is
    emitted.
    """
    if self.is_regression():
        if not isinstance(self.get_explainer().discretizer, BinaryRegressorDiscretizer):
            warnings.warn(
                "Factual explanations for regression recommend using the binaryRegressor "
                + "discretizer. Consider extracting factual explanations using "
                + "`explainer.explain_factual(test_set)`",
                stacklevel=2,
            )
    elif not isinstance(self.get_explainer().discretizer, BinaryEntropyDiscretizer):
        warnings.warn(
            "Factual explanations for classification recommend using the "
            + "binaryEntropy discretizer. Consider extracting factual "
            + "explanations using `explainer.explain_factual(test_set)`",
            stacklevel=2,
        )

def _rules_with_impact(
    self, *, top_k: Optional[int] = None, sort: bool = True
) -> list[RuleWithImpact]:
    """Extract canonical rules with explicit signed impact.

    This method is the source of truth for narrative and plotting consistency.
    It defines 'impact' for FactualExplanation as the weight (delta from baseline).

    Parameters
    ----------
    top_k : int, optional
        Keep only the first ``top_k`` rules (after sorting when ``sort``).
    sort : bool, optional
        Sort by descending absolute impact, ties broken by rule text.

    Returns
    -------
    list[RuleWithImpact]
        One entry per rule. ``predict`` is the instance's global prediction.
    """
    rules_dict = self.get_rules()
    canonical_rules = []
    # Get global prediction
    prediction = self.prediction["predict"]
    # Loop-invariant; ``len`` avoids ambiguous truth tests on NumPy arrays.
    base_predicts = rules_dict.get("base_predict", [])
    n_base = len(base_predicts)
    num_rules = len(rules_dict.get("rule", []))
    for i in range(num_rules):
        # weight in FactualExplanation is defined as (prediction - instance_predict)
        # where instance_predict is the counterfactual/base
        w = rules_dict["weight"][i]
        # Canonical sign behavior:
        # Positive impact = rule increased the prediction relative to base
        # Negative impact = rule decreased the prediction relative to base
        if w > 0:
            direction = "positive"
        elif w < 0:
            direction = "negative"
        else:
            direction = "neutral"

        feature_id = rules_dict["feature"][i]
        if isinstance(feature_id, (list, tuple, np.ndarray)):
            feature_names = []
            for idx in list(feature_id):
                with contextlib.suppress(Exception):
                    feature_names.append(str(self._safe_feature_name(idx)))
            feature_label = " & ".join(feature_names) if feature_names else str(feature_id)
        else:
            feature_label = str(self._safe_feature_name(feature_id))

        # Per-rule baseline when available, else the first one, else None.
        if i < n_base:
            base_predict_value = base_predicts[i]
        elif n_base:
            base_predict_value = base_predicts[0]
        else:
            base_predict_value = None

        canonical_rules.append(
            RuleWithImpact(
                rule_id=str(feature_id),  # Using feature index/list as ID for now
                feature=feature_label,
                text=rules_dict["rule"][i],
                impact=float(w),
                direction=direction,
                base_predict=float(base_predict_value)
                if base_predict_value is not None
                else float("nan"),
                predict=float(prediction),
                value=rules_dict["value"][i],
                weight_envelope_low=float(rules_dict["weight_low"][i]),
                weight_envelope_high=float(rules_dict["weight_high"][i]),
                predict_low=float(rules_dict["predict_low"][i]),
                predict_high=float(rules_dict["predict_high"][i]),
            )
        )

    # Stable sort by absolute impact
    if sort:
        canonical_rules.sort(key=lambda r: (-abs(r.impact), r.text))
    if top_k is not None:
        canonical_rules = canonical_rules[:top_k]
    return canonical_rules

# -- Convenience retrieval API -------------------------------------------------
def _factual_rule_to_dict(self, rules: Dict[str, Any], idx: int) -> Dict[str, Any]:
    """Normalize an internal factual rule into a user-friendly dict.

    ``support`` is taken from ``rules["support"]`` only when it is a list with
    an entry at ``idx``; otherwise it is ``None`` rather than raising.
    """
    feature_index = rules["feature"][idx]
    weight_value = CalibratedExplanation.to_python_number(rules["weight"][idx])
    weight_interval = CalibratedExplanation._build_interval(
        rules["weight_low"][idx], rules["weight_high"][idx]
    )
    support = rules.get("support")
    support_value = None
    if isinstance(support, list):
        # A support list shorter than the rule list must not break retrieval.
        with contextlib.suppress(IndexError):
            support_value = support[idx]
    return {
        "index": int(idx),
        "feature": self._safe_feature_name(feature_index),
        "condition": rules["rule"][idx],
        "weight": weight_value,
        "uncertainty_interval": weight_interval,
        "support": support_value,
    }
def get_rule_by_index(self, index: int) -> Dict[str, Any]:
    """Look up a single factual rule by its position.

    Raises
    ------
    IndexError
        If ``index`` is negative or not smaller than the number of rules.
    """
    rules = self.get_rules()
    total = len(rules.get("rule", []))
    if not 0 <= index < total:
        raise IndexError(f"No factual rule at index {index}")
    return self._factual_rule_to_dict(rules, index)
def get_rules_by_feature(self, feature: str) -> list[Dict[str, Any]]:
    """Collect every factual rule whose feature name equals *feature*.

    Raises
    ------
    KeyError
        If no rule matches.
    """
    rules = self.get_rules()

    def _name_of(raw):
        # Fall back to the plain string form if name resolution fails.
        try:
            return self._safe_feature_name(raw)
        except (TypeError, ValueError):
            return str(raw)

    matches = [
        self._factual_rule_to_dict(rules, position)
        for position, raw in enumerate(rules.get("feature", []))
        if _name_of(raw) == feature
    ]
    if not matches:
        raise KeyError(f"No factual rules for feature {feature}")
    return matches
def list_rules(self) -> list[Dict[str, Any]]:
    """Return every factual rule, in order, as a normalized dict."""
    rules = self.get_rules()
    to_dict = self._factual_rule_to_dict
    normalized = []
    for position in range(len(rules.get("rule", []))):
        normalized.append(to_dict(rules, position))
    return normalized
def get_rules(self):
    """
    Create factual rules.

    Returns cached conjunctive rules or cached factual rules when available.
    Otherwise it builds one factual rule per non-ignored feature whose
    per-feature prediction differs from the instance prediction. Each rule
    has weight ``prediction - instance_predict`` with an interval derived
    from the per-feature bounds (infinite bounds propagate as infinite
    weights).

    Returns
    -------
    dict[str, list]
        The rule state produced by ``ConjunctionState.get_state()``; also
        stored on ``self.rules``.
    """
    if (
        getattr(self, "has_conjunctive_rules", False)
        and getattr(self, "conjunctive_rules", None) is not None
    ):
        return self.conjunctive_rules
    if getattr(self, "has_rules", False) and isinstance(self.rules, dict):
        return self.rules

    explainer = self.get_explainer()
    categorical_features = (
        explainer.categorical_features if explainer.categorical_features is not None else ()
    )
    categorical_labels = explainer.categorical_labels
    instance = np.array(self.x_test, copy=True)
    state_helper = ConjunctionState(None)
    state_helper.state["classes"] = self.prediction["classes"]
    state_helper.set_base_prediction(
        self.prediction["predict"], self.prediction["low"], self.prediction["high"]
    )
    rules = self.define_conditions()
    ignored = self.ignored_features_for_instance()
    prediction = self.prediction["predict"]
    for f in range(len(instance)):  # pylint: disable=invalid-name
        if f in ignored:
            continue
        instance_predict = self.feature_predict["predict"][f]
        # Features that do not move the prediction produce no rule.
        if prediction == instance_predict:
            continue
        if f in categorical_features:
            value_str = str(instance[f])
            if categorical_labels is not None:
                # Same fallback as define_conditions: keep the raw value when
                # the label is missing.
                with contextlib.suppress(IndexError, KeyError):
                    value_str = categorical_labels[f][int(instance[f])]
        else:
            value_str = str(np.around(instance[f], decimals=2))

        # Calculate weights matching legacy behavior (prediction - instance_predict)
        # and using instance_predict uncertainty
        instance_low = self.feature_predict["low"][f]
        instance_high = self.feature_predict["high"][f]
        w = prediction - instance_predict
        w_low = prediction - instance_high if instance_high != np.inf else -np.inf
        w_high = prediction - instance_low if instance_low != -np.inf else np.inf

        state_helper.add_rule(
            predict=self.prediction["predict"],
            low=self.prediction["low"],
            high=self.prediction["high"],
            base_predict=instance_predict,
            value=value_str,
            feature=f,
            sampled_values=self.binned["rule_values"][f][0][-1],
            feature_value=self.x_test[f],
            rule_text=rules[f],
            is_conjunctive=False,
            weight=w,
            weight_low=w_low,
            weight_high=w_high,
        )
    self.rules = state_helper.get_state()
    self.has_rules = True
    return self.rules
# pylint: disable=too-many-locals, too-many-branches, too-many-statements
def add_conjunctions(self, n_top_features=5, max_rule_size=2, **kwargs):
    """Add conjunctive factual rules.

    Builds conjunctions by iterating over feature combinations from
    ``current_size = 2`` up to ``max_rule_size``. For each size the method ranks
    existing rules, then pairs every outer-loop feature (from the original factual
    rules) with top-ranked inner-loop candidates stored in ``ConjunctionState``.
    Duplicate feature combinations are deduplicated and predictions are obtained
    via :meth:`predict_conjunctive`.

    Parameters
    ----------
    n_top_features : int, optional
        Number of top features to consider when ranking candidates (default 5).
        ``None`` means "all factual rules".
    max_rule_size : int, optional
        Maximum number of features in a single conjunctive rule (default 2).
        Values below 2 make the call a no-op.

    Other Parameters
    ----------------
    _use_batched : bool
        Use vectorized batch prediction (default ``True``). When ``False`` the
        legacy implementation is used instead.
    _limit_outer_to_ranked : bool
        Restrict outer loop to rank-filtered features (default ``False``).
    _dedupe_by_feature_only : bool
        Deduplicate by feature index only, ignoring sampled values (default ``True``).
    raise_on_predict_error : bool
        Re-raise prediction exceptions instead of silently skipping (default ``False``).
    _fallback_to_legacy_on_zero : bool
        Fall back to the legacy implementation when no conjunctions are created
        (default ``False``).
    verbose : bool
        Emit a diagnostic ``UserWarning`` when no conjunctions were created
        (default ``False``).

    Returns
    -------
    self : :class:`.FactualExplanation`
        Returns a self reference, to allow for method chaining. When the
        zero-result legacy fallback is taken, the legacy helper's return value is
        returned instead.

    Raises
    ------
    ConfigurationError
        If ``max_rule_size >= 4`` while ``_use_batched`` is ``False``.

    Notes
    -----
    After a batched call, ``self.conjunction_stats`` contains a diagnostic dict with
    keys ``"attempts"``, ``"created"``, ``"skipped"`` (:class:`~collections.Counter`),
    and ``"predict_errors"`` (list of up to 5 captured exception messages). The
    early-return paths (``max_rule_size < 2`` and the non-batched legacy path) do
    not set it.
    """
    use_batched = kwargs.get("_use_batched", True)
    limit_outer_to_ranked = kwargs.get("_limit_outer_to_ranked", False)
    dedupe_by_feature_only = kwargs.get("_dedupe_by_feature_only", True)
    raise_on_predict_error = kwargs.get("raise_on_predict_error", False)
    # The non-batched path only accepts rule sizes 2 and 3 (see valid_range).
    if max_rule_size >= 4 and not use_batched:
        from ..utils.exceptions import ConfigurationError

        raise ConfigurationError(
            "max_rule_size >= 4 requires batched execution (internal flag _use_batched=True)",
            details={
                "param": "max_rule_size",
                "value": max_rule_size,
                "valid_range": [2, 3],
            },
        )
    if max_rule_size < 2:
        return self
    if not use_batched:
        from .legacy_conjunctions import add_conjunctions_factual_legacy

        add_conjunctions_factual_legacy(
            self, n_top_features=n_top_features, max_rule_size=max_rule_size
        )
        return self
    # Reuse the cached atomic rules when they have already been built.
    factual = self.get_rules() if not self.has_rules else self.rules
    # Seed the working state from previously built conjunctions (if any) so a
    # repeated call extends them; otherwise start from the atomic factual rules.
    state_helper = ConjunctionState(
        self.conjunctive_rules
        if self.has_conjunctive_rules and self.conjunctive_rules is not None
        else factual,
        dedupe_by_feature_only=dedupe_by_feature_only,
    )
    # Cleared while building; repopulated from state_helper after the loop.
    self.has_conjunctive_rules = False
    self.conjunctive_rules = []
    threshold = None if self.y_threshold is None else self.y_threshold
    # Copy of the instance handed to predict_conjunctive.
    # NOTE(review): presumably used as a scratch buffer there - confirm.
    scratch = np.array(self.x_test, copy=True)
    predicted_class = factual["classes"]
    state_helper.state["classes"] = predicted_class
    # Atomic-rule weights/widths, used only for optional outer-loop ranking
    # and for the zero-result diagnostic summary.
    base_weight_array = (
        np.asarray(factual["weight"], dtype=float) if factual["weight"] else np.array([])
    )
    base_width_array = (
        np.asarray(factual["weight_high"], dtype=float)
        - np.asarray(factual["weight_low"], dtype=float)
        if factual["weight"]
        else np.array([])
    )
    if n_top_features is None:
        n_top_features = len(factual["rule"])

    def _feature_length(candidate: Any) -> int:
        # Number of features in a (possibly conjunctive) rule's feature entry.
        if isinstance(candidate, (list, tuple, np.ndarray)):
            return len(candidate)
        return 1

    def _coerce_feature_scalar(value: Any) -> int:
        # Atomic feature entries may still arrive wrapped; take the first index.
        if isinstance(value, (list, tuple, np.ndarray)):
            return int(np.asarray(value).ravel()[0])
        return int(value)

    from collections import Counter

    self.conjunction_stats = {
        "attempts": 0,
        "created": 0,
        "skipped": Counter(),
        "predict_errors": [],
    }
    stats = self.conjunction_stats
    max_logged_errors = 5

    def _summarize(arr):
        # Compact numeric summary used in the zero-result warning message.
        if arr is None or arr.size == 0:
            return {"size": 0}
        return {
            "size": int(arr.size),
            "min": float(np.nanmin(arr)),
            "max": float(np.nanmax(arr)),
            "nan": int(np.isnan(arr).sum()),
        }

    # Each pass pairs an atomic outer rule with candidates of exactly
    # (current_size - 1) features, yielding rules of size current_size.
    for current_size in range(2, max_rule_size + 1):
        num_rules = len(factual["rule"])
        if num_rules == 0:
            break
        # Candidates are ranked over the growing state (atomic + conjunctive rules).
        weights_array = state_helper.get_weights()
        width_array = state_helper.get_widths()
        top_n = min(num_rules, n_top_features)
        top_conjunctives = list(
            self.rank_features(
                weights_array,
                width=width_array,
                num_to_show=max(2, top_n) if num_rules >= 2 else top_n,
            )
        )
        if not top_conjunctives and num_rules > 0:
            top_conjunctives = list(range(num_rules))
        # Determine outer loop candidates
        # Legacy behavior: iterate all features
        if limit_outer_to_ranked:
            num_outer = min(num_rules, n_top_features)
            outer_indices = list(
                self.rank_features(
                    base_weight_array,
                    width=base_width_array if base_width_array.size else None,
                    num_to_show=max(2, num_outer) if num_rules >= 2 else num_outer,
                )
            )
            if not outer_indices and num_rules > 0:
                outer_indices = list(range(num_rules))
        else:
            outer_indices = range(len(factual["feature"]))
        for f1 in outer_indices:
            of1 = int(factual["feature"][f1])
            sampled_values1 = factual["sampled_values"][f1]
            rule_value1 = (
                sampled_values1 if isinstance(sampled_values1, np.ndarray) else [sampled_values1]
            )
            for cf2 in top_conjunctives:
                stats["attempts"] += 1
                rule_values = [rule_value1]
                original_features = [of1]
                of2 = state_helper.get_feature(cf2)
                target_length = current_size - 1
                # Only extend candidates of the right size for this pass.
                if _feature_length(of2) != target_length:
                    stats["skipped"]["len_mismatch"] += 1
                    continue
                if state_helper.is_conjunctive(cf2):
                    # A conjunction must not repeat the outer feature.
                    if of1 in of2:
                        stats["skipped"]["same_feature"] += 1
                        continue
                    original_features.extend(int(v) for v in of2)
                    rule_values.extend(list(state_helper.get_sampled_values(cf2)))
                else:
                    of2 = _coerce_feature_scalar(of2)
                    if of1 == of2:
                        stats["skipped"]["same_feature"] += 1
                        continue
                    original_features.append(of2)
                    sampled_values2 = state_helper.get_sampled_values(cf2)
                    rule_values.append(
                        sampled_values2
                        if isinstance(sampled_values2, np.ndarray)
                        else [sampled_values2]
                    )
                # Skip combinations already produced (dedupe policy set at construction).
                if state_helper.has_combination_key(original_features, rule_values):
                    stats["skipped"]["duplicate_combo"] += 1
                    continue
                state_helper.register_combination_key(original_features, rule_values)
                try:
                    rule_predict, rule_low, rule_high = self.predict_conjunctive(
                        rule_values,
                        original_features,
                        scratch,
                        threshold,
                        predicted_class,
                        bins=self.bin,
                        use_batched=use_batched,
                    )
                # The tuple is effectively ``Exception``; failures are counted
                # (and a few messages kept) unless raise_on_predict_error is set.
                except (
                    CalibratedError,
                    ValueError,
                    TypeError,
                    RuntimeError,
                    Exception,  # adr002_allow - defensive guard for predict_conjunctive failures
                ) as e:
                    if raise_on_predict_error:
                        raise
                    stats["skipped"]["predict_error"] += 1
                    if len(stats["predict_errors"]) < max_logged_errors:
                        stats["predict_errors"].append(f"{type(e).__name__}: {e}")
                    continue
                state_helper.add_rule(
                    predict=rule_predict,
                    low=rule_low,
                    high=rule_high,
                    base_predict=self.prediction["predict"],
                    value=factual["value"][f1] + "\n" + state_helper.get_value(cf2),
                    feature=list(original_features),
                    sampled_values=list(rule_values),
                    feature_value=None,
                    rule_text=factual["rule"][f1] + " & \n" + state_helper.get_rule(cf2),
                )
                stats["created"] += 1
    self.conjunctive_rules = state_helper.get_state()
    self.has_conjunctive_rules = True
    # Diagnostics (and optional legacy fallback) when every attempt was skipped.
    if stats["created"] == 0 and stats["attempts"] > 0:
        summary_weights = _summarize(base_weight_array)
        import warnings

        err_msg = (
            f" predict_errors={stats['predict_errors']}" if stats["predict_errors"] else ""
        )
        warning_msg = (
            f"add_conjunctions: created={stats['created']} "
            f"attempts={stats['attempts']} skipped={dict(stats['skipped'])} "
            f"weights={summary_weights}{err_msg}"
        )
        if kwargs.get("_fallback_to_legacy_on_zero", False):
            warnings.warn(warning_msg + " (falling back to legacy)", UserWarning, stacklevel=2)
            from .legacy_conjunctions import add_conjunctions_factual_legacy

            return add_conjunctions_factual_legacy(
                self, n_top_features=n_top_features, max_rule_size=max_rule_size
            )
        if kwargs.get("verbose", False):
            warnings.warn(warning_msg, UserWarning, stacklevel=2)
    return self
def _is_lesser(self, rule_boundary, instance_value):
    """Tell whether ``instance_value`` lies strictly below ``rule_boundary``.

    Parameters
    ----------
    rule_boundary : comparable
        The boundary taken from a rule condition.
    instance_value : comparable
        The instance's feature value.

    Returns
    -------
    bool-like
        Result of ``instance_value < rule_boundary``.
    """
    below_boundary = instance_value < rule_boundary
    return below_boundary
def plot(self, filter_top=None, **kwargs):
    """Plot the factual explanation for a given instance.

    Parameters
    ----------
    filter_top : int, optional
        The number of top features to display. Defaults to all rules.
    **kwargs : dict
        Additional plotting arguments:

        - show (bool): default=True if filename is empty, False otherwise.
          Determines whether the plot should be displayed or not.
        - filename (str): default=''. The full path and filename of the plot
          image file that will be saved. If not provided, empty or ``None``,
          the plot will not be saved.
        - uncertainty (bool): default=False. Whether to plot the uncertainty
          intervals for the feature weights.
        - style (str): default='regular'. The style of the plot. Possible styles:

          * 'regular' - a regular plot with feature weights and uncertainty intervals.
          * 'narrative' - generate human-readable narrative explanations.

          Any style other than 'regular', 'triangular', 'ensured' or
          'narrative' forces the PlotSpec (non-legacy) renderer.
        - rnk_metric (str): default='feature_weight'. The metric used to rank
          the features. Supported metrics are 'ensured', 'feature_weight', and
          'uncertainty' (an alias for 'ensured' with ``rnk_weight=1.0``).
        - rnk_weight (float): default=0.5. The weight of the uncertainty in the
          ranking. Used with the 'ensured' ranking metric.
        - use_legacy (bool): renderer selection. Defaults to the legacy
          renderer unless ``return_plot_spec`` is set or a custom style is used,
          both of which force the new renderer.
        - style_override: forwarded unchanged to the renderer.

        Remaining keyword arguments are forwarded to the renderer.

    Returns
    -------
    object or None
        Whatever the selected renderer (``plot_probabilistic`` or
        ``plot_regression``) returns, or ``None`` when there are no rules to plot.

    Raises
    ------
    Warning
        If ``uncertainty=True`` is requested for a one-sided explanation.
    ConfigurationError
        If the renderer fails with a RuntimeError mentioning the Matplotlib
        'Agg' backend (which cannot ``show()``).
    """
    requested_style = kwargs.get("style")
    custom_plot_style = isinstance(requested_style, str) and requested_style not in {
        "regular",
        "triangular",
        "ensured",
        "narrative",
    }
    # Ensure style_override gets passed through
    style_override = kwargs.pop("style_override", None)
    plot_use_legacy = kwargs.pop("use_legacy", None)
    # PlotSpec request forces new renderer
    if kwargs.get("return_plot_spec") or custom_plot_style:
        plot_use_legacy = False
    # Phase 2 Option B: Default to legacy to ensure parity until PlotSpec is fully hardened
    elif plot_use_legacy is None:
        plot_use_legacy = True
    filename = kwargs.pop("filename", "")
    if filename is None:
        # Treat an explicit ``filename=None`` like "no file" instead of
        # failing on ``len(None)`` further down.
        filename = ""
    show = kwargs.pop("show", filename == "")
    uncertainty = kwargs.pop("uncertainty", False)
    rnk_metric = kwargs.pop("rnk_metric", "feature_weight")
    if rnk_metric is None:
        rnk_metric = "feature_weight"
    rnk_weight = kwargs.pop("rnk_weight", 0.5)
    if rnk_metric == "uncertainty":
        rnk_weight = 1.0
        rnk_metric = "ensured"
    # Consistency guard: one-sided intervals cannot show uncertainty bands
    if uncertainty and self.is_one_sided():
        raise Warning("Interval plot is not supported for one-sided explanations.")
    # Use conjunctive rules when available so that conjunctions appear in plots
    if getattr(self, "has_conjunctive_rules", False) and getattr(
        self, "conjunctive_rules", None
    ):
        factual = self.conjunctive_rules
    else:
        factual = self.get_rules()  # get_explanation(index)
    self._check_preconditions()
    predict = self.prediction
    num_features_to_show = len(factual["weight"])
    if filter_top is None:
        filter_top = num_features_to_show
    filter_top = np.min([num_features_to_show, filter_top])
    if filter_top <= 0:
        warnings.warn(
            f"The explanation has no rules to plot. The index of the instance is {self.index}",
            stacklevel=2,
        )
        return None
    if len(filename) > 0:
        path, filename, title, ext = prepare_for_saving(filename)
        path = f"plots/{path}"
        save_ext = [ext]
    else:
        path = ""
        title = ""
        save_ext = []
    if uncertainty:
        feature_weights = {
            "predict": factual["weight"],
            "low": factual["weight_low"],
            "high": factual["weight_high"],
        }
        # Phase 2: Inject canonical color roles if using PlotSpec
        if not plot_use_legacy:
            # Derive canonical directions from authoritative source
            canonical_rules = self._rules_with_impact(sort=False)
            feature_weights["color_role"] = [r.direction for r in canonical_rules]
    elif not plot_use_legacy:
        # Wrap in dict to pass color info to PlotSpec
        canonical_rules = self._rules_with_impact(sort=False)
        feature_weights = {
            "predict": factual["weight"],
            "color_role": [r.direction for r in canonical_rules],
        }
    else:
        feature_weights = factual["weight"]
    width = np.reshape(
        np.array(factual["weight_high"]) - np.array(factual["weight_low"]),
        (len(factual["weight"])),
    )
    if rnk_metric == "feature_weight":
        features_to_plot = self.rank_features(
            factual["weight"], width=width, num_to_show=filter_top
        )
    else:
        ranking = calculate_metrics(
            uncertainty=[
                factual["predict_high"][i] - factual["predict_low"][i]
                for i in range(len(factual["weight"]))
            ],
            prediction=factual["predict"],
            w=rnk_weight,
            metric=rnk_metric,
        )
        features_to_plot = self.rank_features(width=ranking, num_to_show=filter_top)
    # Prefer explicit feature/column names when available; fall back to rule strings
    column_names = (
        factual.get("feature_names") or factual.get("column_names") or factual.get("rule")
    )
    try:
        if "classification" in self.get_explainer().mode or self.is_thresholded():
            renderer = plot_probabilistic
        else:
            renderer = plot_regression
        return renderer(
            self,
            factual["value"],
            predict,
            feature_weights,
            features_to_plot,
            filter_top,
            column_names,
            title=title,
            path=path,
            interval=uncertainty,
            show=show,
            idx=self.index,
            save_ext=save_ext,
            style_override=style_override,
            use_legacy=plot_use_legacy,
            **kwargs,
        )
    except RuntimeError as e:  # pragma: no cover - plot renderers may fail on headless hosts.
        # Translate the headless-backend failure into an actionable error; any
        # other RuntimeError (and every other exception type) propagates as-is.
        if "Agg" in str(e):
            from ..utils.exceptions import ConfigurationError

            raise ConfigurationError(
                "Matplotlib backend 'Agg' does not support show(). "
                "Either set show=False or switch to a different backend.",
                details={
                    "backend": "Agg",
                    "operation": "show()",
                    "solution": "set show=False or use interactive backend",
                },
            ) from e
        raise
class AlternativeExplanation(CalibratedExplanation):
    """Hold and render calibrated alternative explanations for one instance.

    In line with the CE papers, an alternative explanation is a set of
    alternative feature rules. Every rule couples an alternative condition on a
    feature with the calibrated prediction for that scenario and its uncertainty
    interval.

    Feature-weight deltas are kept internally for ranking and metadata only. The
    user-facing payload **must not** swap the prediction + interval pair for
    weights: the prediction interval is the authoritative quantity of each
    alternative rule.

    See docs/foundations/concepts/explanation_structures.md for the internal
    data structures and how rules are generated.
    """

    def __init__(
        self,
        calibrated_explanations,
        index,
        x,
        binned,
        feature_weights,
        feature_predict,
        prediction,
        y_threshold=None,
        instance_bin=None,
        condition_source: str = "prediction",
    ):
        """Build an alternative explanation and eagerly generate its rules.

        Alternative explanations explore how changing feature values could
        change the model's prediction.

        Parameters
        ----------
        calibrated_explanations : CalibratedExplanations
            The parent CalibratedExplanations object.
        index : int
            The index of the instance being explained.
        x : array-like
            The test dataset containing the instances to be explained.
        binned : dict
            A mapping of binned feature values.
        feature_weights : dict
            A mapping of feature weights.
        feature_predict : dict
            A mapping of feature predictions.
        prediction : dict
            A mapping containing the prediction results.
        y_threshold : float or tuple, optional
            The threshold for binary classification or regression explanations.
        instance_bin : int, optional
            The bin index of the instance.
        condition_source : str, default="prediction"
            Forwarded unchanged to the base-class initializer.
        """
        super().__init__(
            calibrated_explanations,
            index,
            x,
            binned,
            feature_weights,
            feature_predict,
            prediction,
            y_threshold,
            instance_bin,
            condition_source=condition_source,
        )
        self._check_preconditions()
        # Force a fresh rule build on first access.
        self.has_rules = False
        self.get_rules()
        self.__is_super_explanation = self.__is_semi_explanation = (
            self.__is_counter_explanation
        ) = False

    def __repr__(self):
        """Render the base prediction and every alternative rule as a text table."""
        table = self.get_rules()
        lines = [
            f"{'Prediction':10} [{' Low':5}, {' High':5}]",
            f"{table['base_predict'][0]:5.3f} [{table['base_predict_low'][0]:5.3f}, {table['base_predict_high'][0]:5.3f}]",
            f"{'Value':6}: {'Feature':40s} {'Prediction':10} [{' Low':6}, {' High':6}]",
        ]
        spread = np.array(table["weight_high"]) - np.array(table["weight_low"])
        ranked = self.rank_features(
            table["weight"],
            width=spread,
            num_to_show=len(table["rule"]),
        )
        for pos in reversed(ranked):
            lines.append(
                f"{table['value'][pos]:6}: {table['rule'][pos]:40s} {table['predict'][pos]:>6.3f} [{table['predict_low'][pos]:>6.3f}, {table['predict_high'][pos]:>6.3f}]"
            )
        return "\n".join(lines) + "\n"

    def _rules_with_impact(
        self, *, top_k: Optional[int] = None, sort: bool = True
    ) -> list[RuleWithImpact]:
        """Return alternative rules as canonical :class:`RuleWithImpact` records.

        The impact of an alternative rule is its stored weight, i.e. the change
        from the instance's base prediction to the alternative prediction.

        Parameters
        ----------
        top_k : int, optional
            Keep only the first ``top_k`` records (after sorting, if enabled).
        sort : bool, default=True
            Order by decreasing absolute impact, ties broken by rule text.
        """
        table = self.get_rules()

        reference = self.prediction["predict"]
        if isinstance(reference, (list, tuple, np.ndarray)):
            reference = reference[0] if len(reference) else 0.0

        def _identify(feature_index: Any) -> tuple[str, str]:
            # Single features map to (id, name); conjunctions to joined ids/names.
            if not isinstance(feature_index, (list, tuple, np.ndarray)):
                return str(feature_index), self._safe_feature_name(feature_index)
            id_parts = []
            name_parts = []
            for member in feature_index:
                try:
                    id_parts.append(str(int(member)))
                except (TypeError, ValueError):  # ADR002_ALLOW: tolerate non-numeric ids.
                    id_parts.append(str(member))
                name_parts.append(self._safe_feature_name(member))
            return ",".join(id_parts), " & ".join(name_parts)

        def _direction(delta) -> str:
            # Positive: alternative above base; negative: below; otherwise neutral.
            if delta > 0:
                return "positive"
            if delta < 0:
                return "negative"
            return "neutral"

        records = []
        for pos, text in enumerate(table.get("rule", [])):
            delta = table["weight"][pos]
            direction = _direction(delta)
            rule_id, feature_name = _identify(table["feature"][pos])
            records.append(
                RuleWithImpact(
                    rule_id=rule_id,
                    feature=feature_name,
                    text=text,
                    impact=float(delta),
                    direction=direction,
                    base_predict=float(reference),
                    predict=float(table["predict"][pos]),
                    value=table["value"][pos],
                    weight_envelope_low=float(table["weight_low"][pos]),
                    weight_envelope_high=float(table["weight_high"][pos]),
                    predict_low=float(table["predict_low"][pos]),
                    predict_high=float(table["predict_high"][pos]),
                )
            )

        if sort:
            records = sorted(records, key=lambda rec: (-abs(rec.impact), rec.text))
        if top_k is not None:
            records = records[:top_k]
        return records
def build_rules_payload(self) -> Dict[str, Any]:
    """Return a structured payload describing the alternative feature rules.

    Returns
    -------
    dict
        ``{"core": ..., "metadata": ...}`` where

        - ``core`` is ``{"kind": "alternative", "feature_rules": [...]}``. Each
          entry holds the calibrated ``prediction`` (``value`` and
          ``uncertainty_interval``) and the rule ``condition``.
        - ``metadata["feature_rules"]`` holds per-rule diagnostics: feature
          name/index, prediction and weight uncertainty, condition text, the
          instance's value, the alternative value and, for thresholded
          explanations, the threshold. ``metadata["prediction_uncertainty"]``
          describes the instance-level prediction.

        When no rules are available both ``feature_rules`` lists are empty and
        ``prediction_uncertainty`` is omitted.
    """
    rules = self.get_rules()
    core: Dict[str, Any] = {"kind": "alternative", "feature_rules": []}
    metadata: Dict[str, Any] = {"feature_rules": []}
    if not rules or "rule" not in rules:
        return {"core": core, "metadata": metadata}

    probabilistic = self.is_probabilistic()
    thresholded = self.is_thresholded()
    percentiles = None
    if not probabilistic and not thresholded:
        percentiles = self._get_percentiles()
    if thresholded:
        prediction_representation = "threshold"
    elif probabilistic:
        prediction_representation = "venn_abers"
    else:
        prediction_representation = "percentile"
    weight_representation = "venn_abers" if probabilistic else "percentile"

    for idx in range(len(rules.get("rule", []))):
        feature_index = rules["feature"][idx]
        condition = self._build_condition_payload(
            feature_index,
            rules["rule"][idx],
            rules["sampled_values"][idx],
            rules["value"][idx],
        )
        prediction_value = CalibratedExplanation.to_python_number(rules["predict"][idx])
        prediction_interval = CalibratedExplanation._build_interval(
            rules["predict_low"][idx],
            rules["predict_high"][idx],
        )
        core["feature_rules"].append(
            {
                "prediction": {
                    "value": prediction_value,
                    "uncertainty_interval": prediction_interval,
                },
                "condition": condition,
            }
        )

        prediction_uncertainty = self._build_uncertainty_payload(
            value=rules["predict"][idx],
            low=rules["predict_low"][idx],
            high=rules["predict_high"][idx],
            representation=prediction_representation,
            percentiles=(percentiles if prediction_representation == "percentile" else None),
            threshold=self.normalize_threshold_value() if thresholded else None,
            include_percentiles=prediction_representation == "percentile",
        )
        weight_value = CalibratedExplanation.to_python_number(rules["weight"][idx])
        weight_uncertainty = self._build_uncertainty_payload(
            value=rules["weight"][idx],
            low=rules["weight_low"][idx],
            high=rules["weight_high"][idx],
            representation=weight_representation,
            percentiles=(percentiles if weight_representation == "percentile" else None),
            include_percentiles=weight_representation == "percentile",
        )
        # get_rules() stores the instance's own value under "value" and the
        # alternative (sampled) value under "sampled_values"; map them accordingly.
        metadata_rule: Dict[str, Any] = {
            "feature": self._safe_feature_name(feature_index),
            "feature_index": CalibratedExplanation.to_python_number(feature_index),
            "prediction_uncertainty": prediction_uncertainty,
            "prediction_value": prediction_value,
            "weight_value": weight_value,
            "weight_uncertainty": weight_uncertainty,
            "condition_text": rules["rule"][idx],
            "instance_value": CalibratedExplanation.to_python_number(rules["value"][idx]),
            "alternative_value": CalibratedExplanation.to_python_number(
                rules["sampled_values"][idx]
            ),
        }
        if thresholded:
            metadata_rule["threshold"] = self.normalize_threshold_value()
        metadata["feature_rules"].append(metadata_rule)

    metadata["prediction_uncertainty"] = self._build_instance_uncertainty()
    return {"core": core, "metadata": metadata}
def _check_preconditions(self):
    """Warn when the explainer's discretizer is not the one recommended for alternatives.

    Regression explanations recommend a ``RegressorDiscretizer``. Every other
    mode recommends an ``EntropyDiscretizer``. A mismatch only produces a
    warning (via :func:`warnings.warn`); the explanation is still built.
    """
    if self.is_regression():
        recommended = RegressorDiscretizer
        advice = (
            "Alternative explanations for regression recommend using the "
            "regressor discretizer. Consider extracting alternative "
            "explanations using `explainer.explain_alternatives(test_set)`"
        )
    else:
        recommended = EntropyDiscretizer
        advice = (
            "Alternative explanations for classification recommend using "
            "the entropy discretizer. Consider extracting alternative "
            "explanations using `explainer.explain_alternatives(test_set)`"
        )
    if not isinstance(self.get_explainer().discretizer, recommended):
        warnings.warn(advice, stacklevel=2)


# pylint: disable=too-many-statements, too-many-branches
def get_rules(self):
    """
    Create (or return the cached) alternative rules.

    If conjunctive rules are present they are returned as-is. Otherwise the
    cached atomic rules are returned when already built. If neither exists,
    atomic alternative rules are generated for every non-ignored feature and
    cached on ``self.rules``.

    For categorical features, one rule is created per other category value.
    For numeric features, up to two rules are created ("< lower boundary" and
    "> upper boundary"), each only when some calibration value lies beyond that
    boundary. Rules whose prediction interval equals the instance's own
    interval are skipped.

    Returns
    -------
    Mapping[str, list]
        The rule container: parallel per-rule lists keyed by ``"rule"``,
        ``"predict"``, ``"value"``, ``"sampled_values"``, ... In these rules,
        ``"value"`` holds the instance's own (formatted) value and
        ``"sampled_values"`` holds the alternative value.
    """
    if (
        getattr(self, "has_conjunctive_rules", False)
        and getattr(self, "conjunctive_rules", None) is not None
    ):
        return self.conjunctive_rules
    if getattr(self, "has_rules", False) and isinstance(self.rules, dict):
        return self.rules
    self.rules = []
    # Maps rule position -> feature index for rules built with categorical labels.
    self.labels = {}  # pylint: disable=attribute-defined-outside-init
    instance = np.array(self.x_test, copy=True)
    instance.flags.writeable = False  # pylint: disable=protected-access
    discretized = self.get_explainer().discretize(instance.reshape(1, -1))[0]
    instance_predict = self.binned["predict"]
    instance_low = self.binned["low"]
    instance_high = self.binned["high"]
    state_helper = ConjunctionState(None)
    state_helper.state["classes"] = self.prediction["classes"]
    state_helper.set_base_prediction(
        self.prediction["predict"], self.prediction["low"], self.prediction["high"]
    )
    # rule_boundaries[f] is indexed as (lower, upper) below.
    rule_boundaries = self.get_explainer().rule_boundaries(instance)
    ignored = self.ignored_features_for_instance()
    for f, _ in enumerate(instance):  # pylint: disable=invalid-name
        if f in ignored:
            continue
        if f in self.get_explainer().categorical_features:
            # Every category except the instance's discretized one is an alternative.
            values = np.array(self.get_explainer().feature_values[f])
            values = np.delete(values, values == discretized[f])
            for value_bin, value in enumerate(values):
                # skip if identical to original
                if (
                    self.prediction["low"] == instance_low[f][value_bin]
                    and self.prediction["high"] == instance_high[f][value_bin]
                ):
                    continue
                # value_str describes the instance's own value, not the alternative.
                value_str = ""
                if self.get_explainer().categorical_labels is not None:
                    value_str = self.get_explainer().categorical_labels[f][int(instance[f])]
                else:
                    value_str = str(np.around(instance[f], decimals=2))
                rule_text = ""
                if self.get_explainer().categorical_labels is not None:
                    # Keyed by the position the next add_rule call will occupy.
                    self.labels[len(state_helper.state["rule"])] = f
                    rule_text = (
                        f"{self.get_explainer().feature_names[f]} = "
                        + f"{self.get_explainer().categorical_labels[f][int(value)]}"
                    )
                else:
                    rule_text = f"{self.get_explainer().feature_names[f]} = {value}"
                state_helper.add_rule(
                    predict=instance_predict[f][value_bin],
                    low=instance_low[f][value_bin],
                    high=instance_high[f][value_bin],
                    base_predict=self.prediction["predict"],
                    value=value_str,
                    feature=f,
                    sampled_values=value,
                    feature_value=self.x_test[f],
                    rule_text=rule_text,
                    is_conjunctive=False,
                )
        else:
            # Calibration column used to decide whether each side has support.
            values = np.array(self.get_explainer().x_cal[:, f])
            lesser = rule_boundaries[f][0]
            greater = rule_boundaries[f][1]
            value_bin = 0
            if np.any(values < lesser):
                # skip if identical to original
                if self.prediction["low"] == safe_mean(
                    instance_low[f][value_bin]
                ) and self.prediction["high"] == safe_mean(instance_high[f][value_bin]):
                    pass
                else:
                    state_helper.add_rule(
                        predict=safe_mean(instance_predict[f][value_bin]),
                        low=safe_mean(instance_low[f][value_bin]),
                        high=safe_mean(instance_high[f][value_bin]),
                        base_predict=self.prediction["predict"],
                        value=str(np.around(instance[f], decimals=2)),
                        feature=f,
                        sampled_values=self.binned["rule_values"][f][0][0],
                        feature_value=self.x_test[f],
                        rule_text=f"{self.get_explainer().feature_names[f]} < {lesser:.2f}",
                        is_conjunctive=False,
                    )
                # The upper alternative sits in bin 1 only when a lower one exists.
                # NOTE(review): inferred from the sampled_values index below - confirm
                # against the binning code.
                value_bin = 1
            if np.any(values > greater):
                # skip if identical to original
                if self.prediction["low"] == safe_mean(
                    instance_low[f][value_bin]
                ) and self.prediction["high"] == safe_mean(instance_high[f][value_bin]):
                    pass
                else:
                    state_helper.add_rule(
                        predict=safe_mean(instance_predict[f][value_bin]),
                        low=safe_mean(instance_low[f][value_bin]),
                        high=safe_mean(instance_high[f][value_bin]),
                        base_predict=self.prediction["predict"],
                        value=str(np.around(instance[f], decimals=2)),
                        feature=f,
                        # NOTE(review): assumes rule_values[f][0] holds 3 entries when both
                        # sides exist (upper at index 1), else the single alternative at 0.
                        sampled_values=self.binned["rule_values"][f][0][
                            1 if len(self.binned["rule_values"][f][0]) == 3 else 0
                        ],
                        feature_value=self.x_test[f],
                        rule_text=f"{self.get_explainer().feature_names[f]} > {greater:.2f}",
                        is_conjunctive=False,
                    )
    self.rules = state_helper.get_state()
    self.has_rules = True
    return self.rules
def __set_up_result(self):
    """Return an empty rule container seeded with the instance's base prediction.

    The ``base_predict*`` entries are one-element lists holding the instance's
    prediction, lower and upper bound. Every per-rule key starts as an empty
    list, and ``classes`` is copied from the prediction.
    """
    classes = self.prediction["classes"]
    per_rule_keys = (
        "predict",
        "predict_low",
        "predict_high",
        "weight",
        "weight_low",
        "weight_high",
        "value",
        "rule",
        "feature",
        "sampled_values",
        "feature_value",
        "is_conjunctive",
    )
    container = {
        "base_predict": [self.prediction["predict"]],
        "base_predict_low": [self.prediction["low"]],
        "base_predict_high": [self.prediction["high"]],
    }
    container.update({key: [] for key in per_rule_keys})
    container["classes"] = classes
    return container
def is_super_explanation(self) -> bool:
    """Report whether this explanation was filtered down to *super* alternatives."""
    return self.__is_super_explanation


def is_semi_explanation(self) -> bool:
    """Report whether this explanation was filtered down to *semi* alternatives."""
    return self.__is_semi_explanation


def is_counter_explanation(self) -> bool:
    """Report whether this explanation was filtered down to *counter* alternatives."""
    return self.__is_counter_explanation
def __append_rule(self, new_rules, rules, rule):
    """Copy the rule at position *rule* of *rules* into the container *new_rules*.

    ``feature_value`` falls back to ``None`` when *rules* has no such key.
    """
    new_rules["predict"].append(rules["predict"][rule])
    new_rules["predict_low"].append(rules["predict_low"][rule])
    new_rules["predict_high"].append(rules["predict_high"][rule])
    new_rules["weight"].append(rules["weight"][rule])
    new_rules["weight_low"].append(rules["weight_low"][rule])
    new_rules["weight_high"].append(rules["weight_high"][rule])
    new_rules["value"].append(rules["value"][rule])
    new_rules["rule"].append(rules["rule"][rule])
    new_rules["feature"].append(rules["feature"][rule])
    new_rules["sampled_values"].append(rules["sampled_values"][rule])
    if "feature_value" in rules:
        new_rules["feature_value"].append(rules["feature_value"][rule])
    else:
        new_rules["feature_value"].append(None)
    new_rules["is_conjunctive"].append(rules["is_conjunctive"][rule])

def __filter_rules(
    self,
    only_ensured=False,
    make_super=False,
    make_semi=False,
    make_counter=False,
    include_potential=False,
):
    """Filter the current rules in place according to the explanation type.

    Parameters
    ----------
    only_ensured : bool
        Drop rules whose prediction interval is wider than the instance's.
    make_super, make_semi, make_counter : bool
        Keep only super / semi / counter alternatives (see the per-mode rules below).
    include_potential : bool
        Keep *potential* rules (interval spans 0.5 for probabilistic outputs, or
        covers the base prediction for plain regression).

    Rules whose prediction interval equals the instance's own interval (and, for
    plain regression, whose point prediction is also unchanged) are always dropped.

    Returns
    -------
    self
        With ``self.rules`` replaced by the filtered container.
    """
    is_plain_regression = self.is_regression() and not self.is_probabilistic()
    initial_uncertainty = np.abs(self.prediction["high"] - self.prediction["low"])
    new_rules = self.__set_up_result()
    rules = self.get_rules()  # pylint: disable=protected-access

    def _first(value):
        # base_predict_low/high are stored as one-element containers (see
        # __set_up_result); unwrap them so the comparison is scalar-to-scalar
        # instead of list-to-scalar (which is always False).
        flat = np.asarray(value).ravel()
        return flat[0] if flat.size else value

    ref_low = _first(rules["base_predict_low"])
    ref_high = _first(rules["base_predict_high"])

    if is_plain_regression:
        # For plain regression, redefine filtering concepts:
        # - super: higher prediction than original
        # - semi: intervals mutually include the other's point prediction
        # - counter: lower prediction than original
        # - potential: alternative interval covers the original prediction
        # - ensured: smaller uncertainty interval (unchanged)
        for rule in range(len(rules["rule"])):
            is_potential = (
                rules["predict_low"][rule]
                <= self.prediction["predict"]
                <= rules["predict_high"][rule]
            )
            if not include_potential and is_potential:
                continue
            # Super: keep only rules with higher prediction
            if make_super and rules["predict"][rule] <= self.prediction["predict"]:
                continue
            # Semi: for plain regression, keep alternatives where the
            # uncertainty intervals mutually include the other's mean
            # (i.e. conservative 'semi' definition). Use predict and
            # predict_low/predict_high for comparisons.
            if make_semi:
                try:
                    rule_mean = float(rules["predict"][rule])
                    rule_low = float(rules["predict_low"][rule])
                    rule_high = float(rules["predict_high"][rule])
                    base_mean = float(self.prediction["predict"])
                    base_low = float(self.prediction["low"])
                    base_high = float(self.prediction["high"])
                except (TypeError, ValueError):
                    # If values are not numeric, skip this rule
                    continue
                if not (
                    (rule_low <= base_mean <= rule_high)
                    and (base_low <= rule_mean <= base_high)
                ):
                    continue
            # Counter: keep only rules with lower prediction than original
            if make_counter and rules["predict"][rule] >= self.prediction["predict"]:
                continue
            if (
                only_ensured
                and rules["predict_high"][rule] - rules["predict_low"][rule]
                > initial_uncertainty
            ):
                continue
            # Drop rules that do not change the prediction at all.
            if (
                ref_low == rules["predict_low"][rule]
                and ref_high == rules["predict_high"][rule]
                and rules["predict"][rule] == self.prediction["predict"]
            ):
                continue
            self.__append_rule(new_rules, rules, rule)
    else:
        positive_class = self.prediction["predict"] > 0.5
        for rule in range(len(rules["rule"])):
            is_potential = rules["predict_low"][rule] < 0.5 < rules["predict_high"][rule]
            # filter out potential rules if include_potential is False
            if not include_potential and is_potential:
                continue

            # Compute point-based membership (always enforced).
            rule_predict = rules["predict"][rule]
            # super: moves further into the predicted class (away from 0.5)
            is_super_by_point = (
                positive_class and rule_predict > self.prediction["predict"]
            ) or (not positive_class and rule_predict < self.prediction["predict"])
            # semi: same side as base but closer to the decision boundary (towards 0.5)
            if positive_class:
                is_semi_by_point = (rule_predict > 0.5) and (
                    rule_predict < self.prediction["predict"]
                )
            else:
                is_semi_by_point = (rule_predict < 0.5) and (
                    rule_predict > self.prediction["predict"]
                )
            # counter: crosses the decision boundary (opposite side of 0.5)
            is_counter_by_point = (positive_class and rule_predict <= 0.5) or (
                not positive_class and rule_predict >= 0.5
            )

            # Enforce membership by point-prediction for all modes. Potentials
            # are still controlled by the `include_potential` flag above, but
            # when included they must also satisfy the point-based comparator.
            if make_super and not is_super_by_point:
                continue
            if make_semi and not is_semi_by_point:
                continue
            if make_counter and not is_counter_by_point:
                continue

            # if only_ensured is True, filter out rules that lead to increased uncertainty
            if (
                only_ensured
                and rules["predict_high"][rule] - rules["predict_low"][rule]
                > initial_uncertainty
            ):
                continue
            # filter out rules that do not provide a different prediction
            if ref_low == rules["predict_low"][rule] and ref_high == rules["predict_high"][rule]:
                continue
            self.__append_rule(new_rules, rules, rule)
    new_rules["classes"] = rules["classes"]
    if self.has_conjunctive_rules:  # pylint: disable=protected-access
        self.__extracted_non_conjunctive_rules(new_rules)
    self.rules = new_rules
    return self

def __pareto_rule_indexes(self, rules, *, pareto_cost: str):
    """Return rule indices on the output-envelope Pareto frontier.

    The output value (probability for classification or calibrated output for
    regression) is treated as the coverage axis, while the Pareto *cost*
    dimension is minimized. Among rules sharing the same output (rounded to 12
    decimals) only the cheapest, then lowest-index, rule is considered. A
    candidate is kept when its cost is the minimum over all candidates with
    lower-or-equal output, or over all candidates with higher-or-equal output;
    the extreme-output candidates are always kept.

    Supported Pareto cost dimensions:
    - ``"uncertainty_width"``: minimize interval width (``high - low``).
    - ``"rule_size"``: minimize number of features changed in the rule (1 for
      atomic rules; >1 for conjunctive rules).

    Raises
    ------
    ValidationError
        If *pareto_cost* is not one of the supported dimensions (and there is
        more than one rule).
    """
    rule_count = len(rules.get("rule", []))
    if rule_count <= 1:
        return list(range(rule_count))

    def _rule_size(feature: Any) -> float:
        if isinstance(feature, (list, tuple, np.ndarray)):
            return float(len(np.asarray(feature).ravel()))
        return 1.0

    def _rule_cost(index: int) -> float:
        if pareto_cost == "uncertainty_width":
            return float(rules["predict_high"][index]) - float(rules["predict_low"][index])
        if pareto_cost == "rule_size":
            features = rules.get("feature", [])
            feature_value = features[index] if index < len(features) else None
            return _rule_size(feature_value)
        raise ValidationError(
            "pareto_cost must be one of: uncertainty_width, rule_size",
            details={"pareto_cost": pareto_cost},
        )

    tolerance = 1e-12
    best_per_output = {}
    for index in range(rule_count):
        output_value = float(rules["predict"][index])
        cost_value = _rule_cost(index)
        output_key = round(output_value, 12)
        current_best = best_per_output.get(output_key)
        if current_best is None:
            best_per_output[output_key] = {
                "index": index,
                "output": output_value,
                "cost": cost_value,
            }
            continue
        # Strictly cheaper wins; on a (near) tie the lower index wins.
        if cost_value < current_best["cost"] - tolerance or (
            math.isclose(
                cost_value,
                current_best["cost"],
                rel_tol=tolerance,
                abs_tol=tolerance,
            )
            and index < current_best["index"]
        ):
            best_per_output[output_key] = {
                "index": index,
                "output": output_value,
                "cost": cost_value,
            }

    candidates = sorted(best_per_output.values(), key=lambda candidate: candidate["output"])
    if len(candidates) <= 2:
        return sorted(candidate["index"] for candidate in candidates)

    # Running minima of cost from the low-output and high-output ends.
    left_mins = []
    running_left_min = float("inf")
    for candidate in candidates:
        running_left_min = min(running_left_min, candidate["cost"])
        left_mins.append(running_left_min)

    right_mins = [0.0] * len(candidates)
    running_right_min = float("inf")
    for reverse_index in range(len(candidates) - 1, -1, -1):
        running_right_min = min(running_right_min, candidates[reverse_index]["cost"])
        right_mins[reverse_index] = running_right_min

    kept_indexes = {
        candidates[0]["index"],
        candidates[-1]["index"],
    }
    for position, candidate in enumerate(candidates):
        cost_value = candidate["cost"]
        if (
            cost_value <= left_mins[position] + tolerance
            or cost_value <= right_mins[position] + tolerance
        ):
            kept_indexes.add(candidate["index"])
    return sorted(kept_indexes)

def __pareto_filter_rules(self, *, pareto_cost: str):
    """Reduce current rules to the output-envelope Pareto frontier."""
    rules = self.get_rules()  # pylint: disable=protected-access
    pareto_indexes = set(self.__pareto_rule_indexes(rules, pareto_cost=pareto_cost))
    new_rules = self.__set_up_result()
    for rule in range(len(rules.get("rule", []))):
        if rule not in pareto_indexes:
            continue
        self.__append_rule(new_rules, rules, rule)
    new_rules["classes"] = rules["classes"]
    if self.has_conjunctive_rules:  # pylint: disable=protected-access
        self.__extracted_non_conjunctive_rules(new_rules)
    self.rules = new_rules
    return self

# -- Convenience retrieval API -------------------------------------------------
def _alternative_rule_to_dict(self, rules: Dict[str, Any], idx: int) -> Dict[str, Any]:
    """Normalize an internal alternative rule into a user-friendly dict.

    Keys: ``index``, ``feature`` (name), ``condition`` (rule text),
    ``alternative_prediction`` and ``uncertainty_interval``.
    """
    feature_index = rules["feature"][idx]
    prediction_value = CalibratedExplanation.to_python_number(rules["predict"][idx])
    prediction_interval = CalibratedExplanation._build_interval(
        rules["predict_low"][idx], rules["predict_high"][idx]
    )
    return {
        "index": int(idx),
        "feature": self._safe_feature_name(feature_index),
        "condition": rules["rule"][idx],
        "alternative_prediction": prediction_value,
        "uncertainty_interval": prediction_interval,
    }
def get_rule_by_index(self, index: int) -> Dict[str, Any]:
    """Look up one alternative rule by its position.

    Parameters
    ----------
    index : int
        Zero-based position in the current rule list.

    Returns
    -------
    dict
        The rule normalized by ``_alternative_rule_to_dict``.

    Raises
    ------
    IndexError
        If ``index`` is negative or not smaller than the number of rules.
    """
    rules = self.get_rules()
    total = len(rules.get("rule", []))
    if not 0 <= index < total:
        raise IndexError(f"No alternative rule at index {index}")
    return self._alternative_rule_to_dict(rules, index)
def get_rules_by_feature(self, feature: str) -> list[Dict[str, Any]]:
    """Collect every alternative rule whose feature name equals ``feature``.

    Feature names come from ``_safe_feature_name``; if that raises ``TypeError``
    or ``ValueError`` the stringified feature id is used instead.

    Raises
    ------
    KeyError
        If no rule refers to ``feature``.
    """
    rules = self.get_rules()

    def _name_of(feature_id):
        try:
            return self._safe_feature_name(feature_id)
        except (TypeError, ValueError):
            return str(feature_id)

    matches: list[Dict[str, Any]] = [
        self._alternative_rule_to_dict(rules, position)
        for position, feature_id in enumerate(rules.get("feature", []))
        if _name_of(feature_id) == feature
    ]
    if not matches:
        raise KeyError(f"No alternative rules for feature {feature}")
    return matches
def list_rules(self) -> list[Dict[str, Any]]:
    """Return every alternative rule, in stored order, as a normalized dict."""
    rules = self.get_rules()
    return [
        self._alternative_rule_to_dict(rules, position)
        for position, _ in enumerate(rules.get("rule", []))
    ]
def __extracted_non_conjunctive_rules(self, new_rules):
    """Snapshot all rules, then strip conjunctive ones from ``new_rules`` in place.

    ``self.conjunctive_rules`` receives a read-only mapping whose list values
    are shallow copies of the full rule table, conjunctions included. Every
    list in ``new_rules`` whose length equals the number of rules is then
    filtered to the non-conjunctive entries, and ``new_rules`` becomes
    ``self.rules``.

    Parameters
    ----------
    new_rules : dict
        Rule table; must contain an ``"is_conjunctive"`` sequence of flags.
    """
    snapshot = {}
    for key, value in new_rules.items():
        snapshot[key] = list(value) if isinstance(value, list) else value
    self.conjunctive_rules = MappingProxyType(snapshot)

    keep = [not flag for flag in new_rules["is_conjunctive"]]
    n_entries = len(keep)
    for key, value in new_rules.items():
        # Only per-rule columns (lists aligned with the flags) are filtered;
        # scalars and differently sized values are left untouched.
        if not isinstance(value, list) or len(value) != n_entries:
            continue
        new_rules[key] = [item for item, wanted in zip(value, keep) if wanted]
    self.rules = new_rules
def reset(self):
    """Clear super/semi/counter markers and rebuild the rule table.

    Returns
    -------
    self
        The same explanation, to allow method chaining.
    """
    # Drop the markers set by counter/semi/super_explanations.
    self.__is_counter_explanation = False
    self.__is_semi_explanation = False
    self.__is_super_explanation = False
    # Mark cached rules as stale before regenerating them; presumably
    # get_rules() consults has_rules (add_conjunctions does) — confirm.
    self.has_rules = False
    self.get_rules()
    return self
def super_explanations(self, only_ensured=False, include_potential=True, copy=True):
    """Keep only *super* alternatives, i.e. those reinforcing the current prediction.

    Parameters
    ----------
    only_ensured : bool, default=False
        If ``True``, additionally require the alternative's uncertainty
        interval to be no wider than the base prediction interval.
    include_potential : bool, default=True
        If ``True``, *potential* alternatives are retained as well.
    copy : bool, default=True
        If ``True``, filter and return ``self.copy()``. Otherwise filter
        ``self`` in place and return it.

    Returns
    -------
    :class:`.AlternativeExplanation`
        The filtered explanation (a copy or ``self``, depending on ``copy``).

    Notes
    -----
    What counts as "super" depends on the task mode:

    - **Classification / probabilistic regression**: the output is a
      calibrated probability with a 0.5 decision boundary. With
      ``p_base = prediction['predict']``, alternatives with
      ``p_rule > p_base`` are kept when ``p_base > 0.5`` (positive/event
      predicted). Otherwise alternatives with ``p_rule < p_base`` are kept.
    - **Plain regression**: the output is a calibrated numeric value, and
      alternatives predicting a higher output than the base are kept.

    *Potential* alternatives are those whose uncertainty interval spans the
    decision boundary (classification / probabilistic regression) or covers
    the base prediction (plain regression).

    Examples
    --------
    >>> alternatives = explainer.explore_alternatives(x_query)
    >>> super_alts = alternatives[0].super_explanations()
    """
    target = self if not copy else self.copy()
    filter_options = {
        "only_ensured": only_ensured,
        "make_super": True,
        "include_potential": include_potential,
    }
    target.__filter_rules(**filter_options)
    # Mark the result as a super view; reset() clears this flag again.
    target._AlternativeExplanation__is_super_explanation = True  # pylint: disable=protected-access
    return target
def super(self, only_ensured=False, include_potential=True, copy=True):
    """Alias for :meth:`.super_explanations`; arguments and result are passed through unchanged."""
    options = {
        "only_ensured": only_ensured,
        "include_potential": include_potential,
        "copy": copy,
    }
    return self.super_explanations(**options)
def semi_explanations(self, only_ensured=False, include_potential=True, copy=True):
    """Keep only *semi* alternatives.

    A *semi* alternative moves the prediction toward the decision boundary
    without crossing it (classification / probabilistic regression). In
    plain regression it moves the output in the opposite direction of a
    *super* alternative.

    Parameters
    ----------
    only_ensured : bool, default=False
        If ``True``, additionally require the alternative's uncertainty
        interval to be no wider than the base prediction interval.
    include_potential : bool, default=True
        If ``True``, *potential* alternatives are retained as well.
    copy : bool, default=True
        If ``True``, filter and return ``self.copy()``. Otherwise filter
        ``self`` in place and return it.

    Returns
    -------
    :class:`.AlternativeExplanation`
        The filtered explanation (a copy or ``self``, depending on ``copy``).

    Notes
    -----
    - **Classification / probabilistic regression**: semi alternatives stay
      on the *same side* of the 0.5 boundary as the base prediction but lie
      closer to it than the base does (unless marked potential and
      ``include_potential=True``).
    - **Plain regression**: semi alternatives are rules predicting a lower
      output than the base prediction.

    Examples
    --------
    >>> alternatives = explainer.explore_alternatives(x_query)
    >>> semi_alts = alternatives[0].semi_explanations()
    """
    target = self if not copy else self.copy()
    filter_options = {
        "only_ensured": only_ensured,
        "make_semi": True,
        "include_potential": include_potential,
    }
    target.__filter_rules(**filter_options)
    # Mark the result as a semi view; reset() clears this flag again.
    target._AlternativeExplanation__is_semi_explanation = True  # pylint: disable=protected-access
    return target
def semi(self, only_ensured=False, include_potential=True, copy=True):
    """Alias for :meth:`.semi_explanations`; arguments and result are passed through unchanged."""
    options = {
        "only_ensured": only_ensured,
        "include_potential": include_potential,
        "copy": copy,
    }
    return self.semi_explanations(**options)
def counter_explanations(self, only_ensured=False, include_potential=True, copy=True):
    """Keep only *counter* alternatives, i.e. those opposing the current prediction.

    Parameters
    ----------
    only_ensured : bool, default=False
        If ``True``, additionally require the alternative's uncertainty
        interval to be no wider than the base prediction interval.
    include_potential : bool, default=True
        If ``True``, *potential* alternatives are retained as well.
    copy : bool, default=True
        If ``True``, filter and return ``self.copy()``. Otherwise filter
        ``self`` in place and return it.

    Returns
    -------
    :class:`.AlternativeExplanation`
        The filtered explanation (a copy or ``self``, depending on ``copy``).

    Notes
    -----
    - **Classification / probabilistic regression**: counter alternatives
      cross the 0.5 decision boundary. For a base prediction
      ``p_base > 0.5`` the rules with ``p_rule <= 0.5`` are kept, and
      vice versa.
    - **Plain regression**: counter alternatives are rules predicting a
      lower output than the base prediction. In plain regression
      :meth:`.semi_explanations` and :meth:`.counter_explanations`
      therefore currently share the same output semantics.

    Examples
    --------
    >>> alternatives = explainer.explore_alternatives(x_query)
    >>> counter_alts = alternatives[0].counter_explanations()
    """
    target = self if not copy else self.copy()
    filter_options = {
        "only_ensured": only_ensured,
        "make_counter": True,
        "include_potential": include_potential,
    }
    target.__filter_rules(**filter_options)
    # Mark the result as a counter view; reset() clears this flag again.
    target._AlternativeExplanation__is_counter_explanation = True  # pylint: disable=protected-access
    return target
def counter(self, only_ensured=False, include_potential=True, copy=True):
    """Alias for :meth:`.counter_explanations`; arguments and result are passed through unchanged."""
    options = {
        "only_ensured": only_ensured,
        "include_potential": include_potential,
        "copy": copy,
    }
    return self.counter_explanations(**options)
def ensured_explanations(self, include_potential=True, copy=True):
    """Keep only *ensured* alternatives.

    An alternative is *ensured* when its uncertainty interval is no wider
    than the base prediction interval.

    Parameters
    ----------
    include_potential : bool, default=True
        If ``True``, *potential* alternatives are retained as well.
    copy : bool, default=True
        If ``True``, filter and return ``self.copy()``. Otherwise filter
        ``self`` in place and return it.

    Returns
    -------
    :class:`.AlternativeExplanation`
        The filtered explanation (a copy or ``self``, depending on ``copy``).

    Notes
    -----
    The filter only looks at *uncertainty interval width*. It is therefore
    task-agnostic and applies equally to classification, probabilistic
    regression and plain regression. Unlike super/semi/counter, no marker
    flag is set on the result.

    Examples
    --------
    >>> alternatives = explainer.explore_alternatives(x_query)
    >>> ensured = alternatives[0].ensured_explanations()
    """
    target = self if not copy else self.copy()
    filter_options = {"only_ensured": True, "include_potential": include_potential}
    target.__filter_rules(**filter_options)
    return target
def ensured(self, include_potential=True, copy=True):
    """Alias for :meth:`.ensured_explanations`; arguments and result are passed through unchanged."""
    options = {"include_potential": include_potential, "copy": copy}
    return self.ensured_explanations(**options)
def pareto_explanations(
    self,
    include_potential: bool = True,
    copy: bool = True,
    *,
    pareto_cost: Literal["uncertainty_width", "rule_size"] = "uncertainty_width",
):
    """Reduce the alternatives to their output-envelope Pareto frontier.

    The rules are filtered first (optionally keeping potential alternatives).
    The Pareto frontier is then extracted from what remains.

    Parameters
    ----------
    include_potential : bool, default=True
        If ``True``, potential alternatives take part in the frontier search.
    copy : bool, default=True
        If ``True``, work on and return ``self.copy()``. Otherwise modify
        ``self`` in place and return it.
    pareto_cost : {"uncertainty_width", "rule_size"}, default="uncertainty_width"
        Quantity minimized along the output axis. ``"uncertainty_width"``
        (the historical behaviour) minimizes interval width. ``"rule_size"``
        minimizes the number of changed features in a rule, which is useful
        once conjunctions are present.

    Returns
    -------
    :class:`.AlternativeExplanation`

    Notes
    -----
    The frontier keeps alternatives for which the chosen cost cannot be
    reduced without changing the output. The output axis is the calibrated
    probability (classification / probabilistic regression) or the
    calibrated numeric output (regression).
    """
    target = self if not copy else self.copy()
    # Order matters: the Pareto step works on the already-filtered rules.
    target.__filter_rules(include_potential=include_potential)
    target.__pareto_filter_rules(pareto_cost=pareto_cost)
    return target
def pareto(
    self,
    include_potential: bool = True,
    copy: bool = True,
    *,
    pareto_cost: Literal["uncertainty_width", "rule_size"] = "uncertainty_width",
):
    """Alias for :meth:`.pareto_explanations`; arguments and result are passed through unchanged."""
    options = {
        "include_potential": include_potential,
        "copy": copy,
        "pareto_cost": pareto_cost,
    }
    return self.pareto_explanations(**options)
def add_conjunctions(self, n_top_features=5, max_rule_size=2, **kwargs):
    """
    Add conjunctive alternative rules.

    Parameters
    ----------
    n_top_features : int, optional
        Number of top-ranked rules combined with each outer rule. ``None``
        means all rules.
    max_rule_size : int, optional
        Maximum number of features in a conjunction. Values below 2 leave
        the explanation untouched.
    **kwargs : dict
        Internal controls for batching, deduplication, and diagnostics.

    Other Parameters
    ----------------
    _use_batched : bool, default True
        Controls batched vs sequential prediction. ``False`` delegates to the
        legacy implementation.
    _limit_outer_to_ranked : bool, default False
        Whether to rank-filter outer loop candidates.
    _dedupe_by_feature_only : bool, default True
        Deduplication strategy for conjunctions.
    raise_on_predict_error : bool, default False
        Whether to re-raise prediction errors instead of counting them.
    _fallback_to_legacy_on_zero : bool, default False
        Whether to fall back to the legacy implementation when attempts were
        made but no conjunction was created.
    verbose : bool, default False
        When no conjunction was created, emit a ``UserWarning`` summarizing
        ``conjunction_stats``.

    Returns
    -------
    self : :class:`.AlternativeExplanation`
        Returns a self reference, to allow for method chaining. This holds
        for every path, including the legacy fallbacks.

    Raises
    ------
    ConfigurationError
        If ``max_rule_size >= 4`` while ``_use_batched`` is ``False``.

    Notes
    -----
    After a batched run ``self.conjunction_stats`` holds a summary dict with
    the keys ``attempts``, ``created``, ``skipped`` (a ``Counter`` of skip
    reasons) and ``predict_errors`` (at most five captured messages).
    """
    # Two-phase conjunction search:
    # 1) Rank features to select candidate rules for the inner loop.
    # 2) Combine outer rules with ranked candidates to build conjunctions.
    # current_size grows rules by one feature each iteration.
    # ConjunctionState tracks accumulated rules and deduplication keys.
    use_batched = kwargs.get("_use_batched", True)
    limit_outer_to_ranked = kwargs.get("_limit_outer_to_ranked", False)
    dedupe_by_feature_only = kwargs.get("_dedupe_by_feature_only", True)
    raise_on_predict_error = kwargs.get("raise_on_predict_error", False)
    _MAX_LOGGED_ERRORS = 5  # noqa: N806
    if max_rule_size >= 4 and not use_batched:
        from ..utils.exceptions import ConfigurationError

        raise ConfigurationError(
            "max_rule_size >= 4 requires batched execution (internal flag _use_batched=True)",
            details={
                "param": "max_rule_size",
                "value": max_rule_size,
                "valid_range": [2, 3],
            },
        )
    if max_rule_size < 2:
        return self
    if not use_batched:
        from .legacy_conjunctions import add_conjunctions_alternative_legacy

        add_conjunctions_alternative_legacy(
            self, n_top_features=n_top_features, max_rule_size=max_rule_size
        )
        return self

    alternative = self.get_rules() if not self.has_rules else self.rules
    state_helper = ConjunctionState(
        self.conjunctive_rules
        if self.has_conjunctive_rules and self.conjunctive_rules is not None
        else alternative,
        dedupe_by_feature_only=dedupe_by_feature_only,
    )
    self.has_conjunctive_rules = False
    self.conjunctive_rules = []
    threshold = self.y_threshold
    scratch = np.array(self.x_test, copy=True)
    predicted_class = alternative["classes"]
    state_helper.state["classes"] = predicted_class

    # Use an explicit emptiness test: plain truthiness is ambiguous for
    # numpy arrays and misreads a single 0.0 weight as "no weights".
    raw_weights = alternative["weight"]
    has_weights = raw_weights is not None and len(raw_weights) > 0
    base_weight_array = np.asarray(raw_weights, dtype=float) if has_weights else np.array([])
    base_width_array = (
        np.asarray(alternative["weight_high"], dtype=float)
        - np.asarray(alternative["weight_low"], dtype=float)
        if has_weights
        else np.array([])
    )
    if n_top_features is None:
        n_top_features = len(alternative["rule"])

    def _feature_length(candidate: Any) -> int:
        # Conjunctive rules store a sequence of features; single rules a scalar.
        if isinstance(candidate, (list, tuple, np.ndarray)):
            return len(candidate)
        return 1

    def _coerce_feature_scalar(value: Any) -> int:
        # Normalize a single-feature identifier (possibly wrapped) to an int.
        if isinstance(value, (list, tuple, np.ndarray)):
            return int(np.asarray(value).ravel()[0])
        return int(value)

    from collections import Counter

    self.conjunction_stats = {
        "attempts": 0,
        "created": 0,
        "skipped": Counter(),
        "predict_errors": [],
    }
    stats = self.conjunction_stats

    def _summarize(arr):
        # Compact diagnostic view of a numeric array for the warning message.
        if arr is None or arr.size == 0:
            return {"size": 0}
        return {
            "size": int(arr.size),
            "min": float(np.nanmin(arr)),
            "max": float(np.nanmax(arr)),
            "nan": int(np.isnan(arr).sum()),
        }

    for current_size in range(2, max_rule_size + 1):
        num_rules = len(alternative["rule"])
        if num_rules == 0:
            break
        weights_array = state_helper.get_weights()
        width_array = state_helper.get_widths()
        top_n = min(num_rules, n_top_features)
        top_conjunctives = list(
            self.rank_features(
                weights_array,
                width=width_array,
                num_to_show=max(2, top_n) if num_rules >= 2 else top_n,
            )
        )
        if not top_conjunctives and num_rules > 0:
            top_conjunctives = list(range(num_rules))
        # Determine outer loop candidates
        # Legacy behavior: iterate all features
        if limit_outer_to_ranked:
            num_outer = min(num_rules, n_top_features)
            outer_indices = list(
                self.rank_features(
                    base_weight_array,
                    width=base_width_array if base_width_array.size else None,
                    num_to_show=max(2, num_outer) if num_rules >= 2 else num_outer,
                )
            )
            if not outer_indices and num_rules > 0:
                outer_indices = list(range(num_rules))
        else:
            outer_indices = range(len(alternative["feature"]))
        for f1 in outer_indices:
            of1 = int(alternative["feature"][f1])
            sampled_values1 = alternative["sampled_values"][f1]
            rule_value1 = (
                sampled_values1 if isinstance(sampled_values1, np.ndarray) else [sampled_values1]
            )
            for cf2 in top_conjunctives:
                stats["attempts"] += 1
                rule_values = [rule_value1]
                original_features = [of1]
                original_feature_values = [alternative["feature_value"][f1]]
                of2 = state_helper.get_feature(cf2)
                # Only extend rules that have exactly current_size - 1 features.
                target_length = current_size - 1
                if _feature_length(of2) != target_length:
                    stats["skipped"]["len_mismatch"] += 1
                    continue
                if state_helper.is_conjunctive(cf2):
                    if of1 in of2:
                        stats["skipped"]["same_feature"] += 1
                        continue
                    original_features.extend(int(v) for v in of2)
                    rule_values.extend(list(state_helper.get_sampled_values(cf2)))
                    feature_values2 = state_helper.get_feature_value(cf2)
                    if isinstance(feature_values2, (list, tuple, np.ndarray)):
                        original_feature_values.extend(list(feature_values2))
                    else:
                        original_feature_values.append(feature_values2)
                else:
                    of2 = _coerce_feature_scalar(of2)
                    if of1 == of2:
                        stats["skipped"]["same_feature"] += 1
                        continue
                    original_features.append(of2)
                    original_feature_values.append(state_helper.get_feature_value(cf2))
                    sampled_values2 = state_helper.get_sampled_values(cf2)
                    rule_values.append(
                        sampled_values2 if isinstance(sampled_values2, np.ndarray) else [sampled_values2]
                    )
                if state_helper.has_combination_key(original_features, rule_values):
                    stats["skipped"]["duplicate_combo"] += 1
                    continue
                state_helper.register_combination_key(original_features, rule_values)
                try:
                    rule_predict, rule_low, rule_high = self.predict_conjunctive(
                        rule_values,
                        original_features,
                        scratch,
                        threshold,
                        predicted_class,
                        bins=self.bin,
                        use_batched=use_batched,
                    )
                except (CalibratedError, ValueError, TypeError, RuntimeError) as e:
                    if raise_on_predict_error:
                        raise
                    stats["skipped"]["predict_error"] += 1
                    if len(stats["predict_errors"]) < _MAX_LOGGED_ERRORS:
                        stats["predict_errors"].append(f"{type(e).__name__}: {e}")
                    continue
                state_helper.add_rule(
                    predict=rule_predict,
                    low=rule_low,
                    high=rule_high,
                    base_predict=self.prediction["predict"],
                    value=alternative["value"][f1] + "\n" + state_helper.get_value(cf2),
                    feature=list(original_features),
                    sampled_values=list(rule_values),
                    feature_value=list(original_feature_values),
                    rule_text=alternative["rule"][f1] + " & \n" + state_helper.get_rule(cf2),
                )
                stats["created"] += 1
        # Publish the accumulated state after each size so a legacy fallback
        # below sees it.
        self.conjunctive_rules = state_helper.get_state()
        self.has_conjunctive_rules = True

    if stats["created"] == 0 and stats["attempts"] > 0:
        summary_weights = _summarize(base_weight_array)
        err_msg = (
            f" predict_errors={stats['predict_errors']}" if stats["predict_errors"] else ""
        )
        warning_msg = (
            f"add_conjunctions: created={stats['created']} "
            f"attempts={stats['attempts']} skipped={dict(stats['skipped'])} "
            f"weights={summary_weights}{err_msg}"
        )
        if kwargs.get("_fallback_to_legacy_on_zero", False):
            warnings.warn(
                warning_msg + " (falling back to legacy)",
                UserWarning,
                stacklevel=2,
            )
            from .legacy_conjunctions import add_conjunctions_alternative_legacy

            # Mirror the non-batched path: the legacy routine updates ``self``
            # and ``self`` is returned, so method chaining works regardless of
            # the legacy helper's own return value.
            add_conjunctions_alternative_legacy(
                self, n_top_features=n_top_features, max_rule_size=max_rule_size
            )
            return self
        if kwargs.get("verbose", False):
            warnings.warn(warning_msg, UserWarning, stacklevel=2)
    self.conjunctive_rules = state_helper.get_state()
    self.has_conjunctive_rules = True
    return self
def _is_lesser(self, rule_boundary, instance_value):
    """Tell whether ``rule_boundary`` lies strictly below ``instance_value``.

    Equivalently, this is True when the instance value exceeds the boundary.
    """
    boundary_is_lower = rule_boundary < instance_value
    return boundary_is_lower


# pylint: disable=consider-iterating-dictionary
def plot(self, filter_top=None, **kwargs):
    """
    Plot the alternative explanation.

    Parameters
    ----------
    filter_top : int, optional
        Maximum number of rules to display. ``None`` (default) shows all rules.
    **kwargs : dict
        Additional plotting arguments:

        show : bool, default=True if filename is empty, False otherwise
            Whether the plot is displayed.
        filename : str, default=''
            Full path and filename of the image file to save. An empty
            string means the plot is not saved.
        style : str, default='regular'
            Plot style for :class:`.AlternativeExplanation`:

            * 'regular' - feature weights and uncertainty intervals (if applicable)
            * 'triangular' - triangular plot highlighting the interplay between
              the prediction and the uncertainty interval width
            * 'ensured' - alias for 'triangular' (ensured-style interpretation)
        rnk_metric : str, default='ensured'
            Metric used to rank the rules: 'ensured', 'feature_weight' or
            'uncertainty'. 'uncertainty' is 'ensured' with ``rnk_weight=1.0``,
            which puts the most uncertain rules on top.
        rnk_weight : float, default=0.5
            Weight of the uncertainty in the 'ensured' ranking.
        style_override : optional
            Forwarded unchanged to the plotting function.
        use_legacy : bool, optional
            Renderer selection, forwarded to ``plot_alternative`` only.
            Defaults to ``True``. It is forced to ``False`` when
            ``return_plot_spec`` is truthy or ``style`` is a string outside
            {'regular', 'triangular', 'ensured', 'narrative'}.
        return_plot_spec : bool, optional
            Only read here to force ``use_legacy=False``.

    Returns
    -------
    None
        When there is nothing to plot, a ``UserWarning`` is emitted and the
        method returns early.
    """
    requested_style = kwargs.get("style")
    # Any string style outside the built-in set is treated as a custom style.
    custom_plot_style = isinstance(requested_style, str) and requested_style not in {
        "regular",
        "triangular",
        "ensured",
        "narrative",
    }
    # Ensure style_override gets passed through
    style_override = kwargs.pop("style_override", None)
    plot_use_legacy = kwargs.pop("use_legacy", None)
    # PlotSpec request forces new renderer
    if kwargs.get("return_plot_spec") or custom_plot_style:
        plot_use_legacy = False
    # Phase 2 Option B: Default to legacy to ensure parity until PlotSpec is fully hardened
    elif plot_use_legacy is None:
        plot_use_legacy = True
    # NOTE(review): the remaining kwargs (including a custom 'style' and
    # 'return_plot_spec') are not forwarded to plot_alternative below; they
    # only influence plot_use_legacy here. Confirm this is intended.
    filename = kwargs.pop("filename", "")
    show = kwargs.pop("show", filename == "")
    rnk_metric = kwargs.pop("rnk_metric", "ensured")
    if rnk_metric is None:
        rnk_metric = "ensured"
    rnk_weight = kwargs.pop("rnk_weight", 0.5)
    # Put the most uncertain rules at the top
    if rnk_metric == "uncertainty":
        rnk_weight = 1.0
        rnk_metric = "ensured"
    # Use conjunctive rules when available so that conjunctions appear in plots
    if getattr(self, "has_conjunctive_rules", False) and getattr(
        self, "conjunctive_rules", None
    ):
        alternative = self.conjunctive_rules
    else:
        alternative = self.get_rules()  # get_explanation(index)
    self._check_preconditions()
    predict = self.prediction
    if len(filename) > 0:
        path, filename, title, ext = prepare_for_saving(filename)
        path = f"plots/{path}"
        save_ext = [ext]
    else:
        path = ""
        title = ""
        save_ext = []
    # Per-rule prediction interval, keyed the way the plotting helpers expect.
    feature_predict = {
        "predict": alternative["predict"],
        "low": alternative["predict_low"],
        "high": alternative["predict_high"],
    }
    feature_weights = np.reshape(alternative["weight"], (len(alternative["weight"])))
    width = np.reshape(
        np.array(alternative["weight_high"]) - np.array(alternative["weight_low"]),
        (len(alternative["weight"])),
    )
    num_rules = len(alternative["rule"])
    if filter_top is None:
        filter_top = num_rules
    num_to_show_ = np.min([num_rules, filter_top])
    if num_to_show_ <= 0:
        warnings.warn(
            f"The explanation has no rules to plot. The index of the instance is {self.index}",
            stacklevel=2,
        )
        return
    if rnk_metric == "feature_weight":
        features_to_plot = self.rank_features(
            feature_weights, width=width, num_to_show=num_to_show_
        )
    else:
        # Always rank base on predicted class
        prediction = alternative["predict"]
        if self.get_mode() == "classification" or self.is_thresholded():
            # Flip probabilities when the base prediction is <= 0.5 so the
            # ranking is expressed relative to the predicted class.
            prediction = prediction if predict["predict"] > 0.5 else [1 - p for p in prediction]
        ranking = calculate_metrics(
            uncertainty=[
                alternative["predict_high"][i] - alternative["predict_low"][i]
                for i in range(num_rules)
            ],
            prediction=prediction,
            w=rnk_weight,
            metric=rnk_metric,
        )
        features_to_plot = self.rank_features(width=ranking, num_to_show=num_to_show_)
    # Display highest-impact rules at the top: reverse the index order returned by
    # rank_features (which yields ascending by design).
    features_to_plot = list(reversed(features_to_plot))
    # Filter out rules whose prediction, low and high bounds are all
    # numerically indistinguishable from the base (np.isclose, not exact
    # equality). Keep ordering from the ranking.
    features_to_plot = [
        i
        for i in features_to_plot
        if not (
            np.isclose(feature_predict["predict"][i], predict["predict"])
            and np.isclose(feature_predict["low"][i], predict["low"])
            and np.isclose(feature_predict["high"][i], predict["high"])
        )
    ]
    # Adjust the number to show after filtering
    num_to_show_filtered = min(num_to_show_, len(features_to_plot))
    style = kwargs.get("style")
    # 'ensured' is an alias for the triangular plot.
    if style == "ensured":
        kwargs["style"] = "triangular"
        style = "triangular"
    if style == "triangular":
        proba = predict["predict"]
        # Uncertainty is the calibrated interval width (high-low).
        # Keep semantics consistent with ADR-021 and other plot styles.
        y_minmax = getattr(self, "y_minmax", None)
        base_low = predict["low"]
        base_high = predict["high"]
        # Clamp one-sided (infinite) base bounds to the observed target range.
        if y_minmax is not None:
            if base_low == -np.inf:
                base_low = y_minmax[0]
            if base_high == np.inf:
                base_high = y_minmax[1]
        uncertainty = base_high - base_low
        rule_proba = alternative["predict"]
        rule_low = np.array(alternative["predict_low"], dtype=float)
        rule_high = np.array(alternative["predict_high"], dtype=float)
        # Replace infinite endpoints with observed bounds so the triangle plot
        # can render meaningful widths.
        if y_minmax is not None:
            rule_low = np.where(np.isneginf(rule_low), y_minmax[0], rule_low)
            rule_high = np.where(np.isposinf(rule_high), y_minmax[1], rule_high)
        rule_uncertainty = rule_high - rule_low
        # Select the ranked, non-trivial rules in plotting order.
        selected_rule_proba = [rule_proba[i] for i in features_to_plot]
        selected_rule_uncertainty = [rule_uncertainty[i] for i in features_to_plot]
        # Use the filtered number of rules so the number of arrow positions
        # (num_to_show) matches the length of the selected rule arrays;
        # passing the unfiltered num_to_show_ previously caused a size
        # mismatch in matplotlib.quiver.
        num_to_show_for_plot = min(num_to_show_, len(selected_rule_proba))
        plot_triangular(
            self,
            proba,
            uncertainty,
            selected_rule_proba,
            selected_rule_uncertainty,
            num_to_show_for_plot,
            title=title,
            path=path,
            show=show,
            save_ext=save_ext,
            style_override=style_override,
        )
        return
    # Rule texts double as the row labels in the regular alternative plot.
    column_names = alternative["rule"]
    plot_alternative(
        self,
        alternative["value"],
        predict,
        feature_predict,
        features_to_plot,
        num_to_show=num_to_show_filtered,
        column_names=column_names,
        title=title,
        path=path,
        show=show,
        save_ext=save_ext,
        style_override=style_override,
        use_legacy=plot_use_legacy,
    )
class FastExplanation(CalibratedExplanation):
    """Fast, SHAP-like explanation of a single instance.

    Designed for efficient interpretation of model behaviour on large
    datasets.
    """

    def __init__(
        self,
        calibrated_explanations,
        index,
        x,
        feature_weights,
        feature_predict,
        prediction,
        y_threshold=None,
        instance_bin=None,
        condition_source="prediction",
    ):
        """Create a fast explanation for one instance and build its rules.

        Parameters
        ----------
        calibrated_explanations : CalibratedExplanations
            The owning CalibratedExplanations collection.
        index : int
            Position of the explained instance.
        x : array-like
            The test data containing the instance to be explained.
        feature_weights : dict
            Mapping of per-feature weights.
        feature_predict : dict
            Mapping of per-feature predictions.
        prediction : dict
            Mapping holding the instance's prediction results.
        y_threshold : float or tuple, optional
            Threshold for binary classification or regression explanations.
        instance_bin : int, optional
            Bin index of the instance.
        condition_source : str, default="prediction"
            Source used for the explanation's conditions.
        """
        # NOTE(review): the fourth base-class argument receives an empty
        # dict; presumably a slot fast explanations do not use — confirm
        # against CalibratedExplanation.__init__.
        super().__init__(
            calibrated_explanations,
            index,
            x,
            {},
            feature_weights,
            feature_predict,
            prediction,
            y_threshold,
            instance_bin,
            condition_source=condition_source,
        )
        self._check_preconditions()
        # Build the rule table eagerly so it is available right away.
        self.get_rules()

    def __repr__(self):
        """Render the prediction followed by every rule's weight interval, one per line."""
        table = self.get_rules()
        lines = [
            f"{'Prediction':10} [{' Low':5}, {' High':5}]",
            f" {table['base_predict'][0]:5.3f} [{table['base_predict_low'][0]:5.3f}, {table['base_predict_high'][0]:5.3f}]",
            f"{'Value':6}: {'Feature':40s} {'Weight':6} [{' Low':6}, {' High':6}]",
        ]
        interval_widths = np.array(table["weight_high"]) - np.array(table["weight_low"])
        ranked = self.rank_features(
            table["weight"],
            width=interval_widths,
            num_to_show=len(table["rule"]),
        )
        # rank_features orders ascending, so iterate it backwards to list the
        # most important rule first.
        for idx in reversed(ranked):
            lines.append(
                f"{table['value'][idx]:6}: {table['rule'][idx]:40s} {table['weight'][idx]:>6.3f} [{table['weight_low'][idx]:>6.3f}, {table['weight_high'][idx]:>6.3f}]"
            )
        return "\n".join(lines) + "\n"
def build_rules_payload(self) -> Dict[str, Any]:
    """Build the rules payload with the factual-explanation layout.

    Fast explanations reuse the factual payload structure, so the factual
    implementation is invoked directly on ``self``.
    """
    factual_builder = FactualExplanation.build_rules_payload
    return factual_builder(self)
def add_conjunctions(self, n_top_features=5, max_rule_size=2, **kwargs):
    """Warn that conjunctions are not supported for ``FastExplanation``; add nothing.

    Parameters
    ----------
    n_top_features : int
        The number of top features to consider for conjunctions. Default is 5.
        Ignored.
    max_rule_size : int
        The maximum size of the conjunctive rules. Default is 2. Ignored.
    **kwargs : dict
        Accepted for signature compatibility with other explanation types; ignored.

    Returns
    -------
    self : :class:`.FastExplanation`
        Returned unchanged, so that method chains written for
        ``AlternativeExplanation.add_conjunctions`` (which returns ``self``)
        do not break on ``None``.

    Warns
    -----
    UserWarning
        Always. This method is not supported for :class:`.FastExplanation`
        and does not alter the explanation.
    """
    warnings.warn(
        "The add_conjunctions method is currently not supported for `FastExplanation`, making this call resulting in no change.",
        stacklevel=2,
    )
    return self
def _is_lesser(self, rule_boundary, instance_value):
    """Return ``False``: fast explanations do not support ordered rule comparisons.

    The previous body was ``pass`` and returned ``None``, contradicting this
    documented contract. An explicit ``False`` keeps the same falsy result
    while making the return type a real ``bool``.
    """
    return False
def add_new_rule_condition(self, feature, rule_boundary):
    """Would add a rule condition for a numerical feature; unsupported here.

    Parameters
    ----------
    feature : int
        Ignored.
    rule_boundary : float
        Ignored.

    Warnings
    --------
    This method is not supported for :class:`.FastExplanation`. It only
    emits a ``UserWarning`` and leaves the explanation unchanged.
    """
    message = (
        "The add_new_rule_condition method is currently not supported for "
        "`FastExplanation`, making this call resulting in no change."
    )
    warnings.warn(message, stacklevel=2)
def _check_preconditions(self):
    """Accept the explanation as-is: fast explanations have no preconditions to verify."""
    return None


# pylint: disable=too-many-statements, too-many-branches
def get_rules(self):
    """
    Build the fast explanation rules and cache them on ``self.rules``.

    One entry is produced per feature whose prediction
    ``self.feature_predict["predict"][f]`` differs (by exact equality) from
    the instance prediction ``self.prediction["predict"]``. Features with an
    identical prediction are skipped.

    Returns
    -------
    dict
        Parallel per-rule lists under ``"predict"``, ``"predict_low"``,
        ``"predict_high"``, ``"weight"``, ``"weight_low"``, ``"weight_high"``,
        ``"value"``, ``"rule"``, ``"feature"``, ``"sampled_values"``,
        ``"feature_value"`` and ``"is_conjunctive"``. Also contains
        single-element ``"base_predict"``, ``"base_predict_low"`` and
        ``"base_predict_high"`` lists, plus ``"classes"`` copied from
        ``self.prediction``. Sets ``self.has_rules = True``.
    """
    instance = np.array(self.x_test, copy=True)
    fast = {
        "base_predict": [],
        "base_predict_low": [],
        "base_predict_high": [],
        "predict": [],
        "predict_low": [],
        "predict_high": [],
        "weight": [],
        "weight_low": [],
        "weight_high": [],
        "value": [],
        "rule": [],
        "feature": [],
        "sampled_values": [],
        "feature_value": [],
        "is_conjunctive": [],
        "classes": self.prediction["classes"],
    }
    fast["base_predict"].append(self.prediction["predict"])
    fast["base_predict_low"].append(self.prediction["low"])
    fast["base_predict_high"].append(self.prediction["high"])
    rules = self.define_conditions()
    # The explainer is loop-invariant; fetch it once instead of up to three
    # times per feature inside the loop.
    explainer = self.get_explainer()
    for f, _ in enumerate(instance):  # pylint: disable=invalid-name
        # Features that do not move the prediction produce no rule.
        if self.prediction["predict"] == self.feature_predict["predict"][f]:
            continue
        fast["predict"].append(self.feature_predict["predict"][f])
        fast["predict_low"].append(self.feature_predict["low"][f])
        fast["predict_high"].append(self.feature_predict["high"][f])
        fast["weight"].append(self.feature_weights["predict"][f])
        fast["weight_low"].append(self.feature_weights["low"][f])
        fast["weight_high"].append(self.feature_weights["high"][f])
        if f in explainer.categorical_features:
            # Categorical values are shown by label when labels are known.
            if explainer.categorical_labels is not None:
                fast["value"].append(explainer.categorical_labels[f][int(instance[f])])
            else:
                fast["value"].append(str(instance[f]))
        else:
            fast["value"].append(str(np.around(instance[f], decimals=2)))
        fast["rule"].append(rules[f])
        fast["feature"].append(f)
        # Fast rules carry no sampled values, feature values or conjunctions.
        fast["sampled_values"].append(None)
        fast["feature_value"].append(None)
        fast["is_conjunctive"].append(False)
    self.rules = fast
    self.has_rules = True
    return self.rules
def define_conditions(self):
    """
    Build one condition per feature for the fast explanation.

    For fast explanations the condition is simply the feature's name,
    rendered as a string.

    Returns
    -------
    list[str]
        ``self.conditions``, with one entry per feature index.
    """
    self.conditions = []
    self.conditions.extend(
        f"{self.get_explainer().feature_names[position]}"
        for position in range(self.get_explainer().num_features)
    )
    return self.conditions
def plot(self, filter_top=None, **kwargs):
    """
    Plot the fast explanation.

    Parameters
    ----------
    filter_top : int, optional
        Maximum number of features to display. ``None`` (default) shows all.
    **kwargs : dict
        Additional plotting arguments (read with ``kwargs.get``; the dict
        itself is not modified):

        show : bool, default=True if filename is empty, False otherwise
            Whether the plot is displayed.
        filename : str, default=''
            Full path and filename of the image file to save. An empty
            string means the plot is not saved.
        uncertainty : bool, default=False
            If ``True``, the low/high bounds of the feature weights are
            passed to the plot as an interval. If ``False``, only the
            feature weights are shown.
        style : str, default='regular'
            Plot style for :class:`.FastExplanation`:

            * 'regular' - feature weights and uncertainty intervals (if applicable)
        rnk_metric : str, default='feature_weight'
            Metric used to rank the features: 'ensured', 'feature_weight' or
            'uncertainty' ('uncertainty' is 'ensured' with ``rnk_weight=1.0``).
        rnk_weight : float, default=0.5
            Weight of the uncertainty in the 'ensured' ranking.
        style_override : optional
            Forwarded unchanged to the plotting function.
        use_legacy : bool, optional
            Forwarded unchanged to the plotting function.

    Returns
    -------
    None
        When there is nothing to plot, a ``UserWarning`` is emitted and the
        method returns early.

    Raises
    ------
    Warning
        Raised (not merely emitted) when ``uncertainty=True`` and the
        explanation is one-sided.
    """
    # Ensure style_override gets passed through
    style_override = kwargs.get("style_override")
    plot_use_legacy = kwargs.get("use_legacy")
    filename = kwargs.get("filename", "")
    show = kwargs.get("show", filename == "")
    uncertainty = kwargs.get("uncertainty", False)
    rnk_metric = kwargs.get("rnk_metric", "feature_weight")
    if rnk_metric is None:
        rnk_metric = "feature_weight"
    rnk_weight = kwargs.get("rnk_weight", 0.5)
    # 'uncertainty' ranking is 'ensured' with full weight on uncertainty.
    if rnk_metric == "uncertainty":
        rnk_weight = 1.0
        rnk_metric = "ensured"
    # Consistency guard: one-sided intervals cannot show uncertainty bands
    if uncertainty and self.is_one_sided():
        raise Warning("Interval plot is not supported for one-sided explanations.")
    # Use conjunctive rules when available so that conjunctions appear in plots
    if getattr(self, "has_conjunctive_rules", False) and getattr(
        self, "conjunctive_rules", None
    ):
        factual = self.conjunctive_rules
    else:
        factual = self.get_rules()  # get_explanation(index)
    self._check_preconditions()
    predict = self.prediction
    num_features_to_show = len(factual["weight"])
    if filter_top is None:
        filter_top = num_features_to_show
    filter_top = np.min([num_features_to_show, filter_top])
    if filter_top <= 0:
        warnings.warn(
            f"The explanation has no rules to plot. The index of the instance is {self.index}",
            stacklevel=2,
        )
        return
    if len(filename) > 0:
        path, filename, title, ext = prepare_for_saving(filename)
        path = f"plots/{path}"
        save_ext = [ext]
    else:
        path = ""
        title = ""
        save_ext = []
    # With uncertainty the plotting helpers receive the full weight interval,
    # otherwise only the point weights.
    if uncertainty:
        feature_weights = {
            "predict": factual["weight"],
            "low": factual["weight_low"],
            "high": factual["weight_high"],
        }
    else:
        feature_weights = factual["weight"]
    width = np.reshape(
        np.array(factual["weight_high"]) - np.array(factual["weight_low"]),
        (len(factual["weight"])),
    )
    if rnk_metric == "feature_weight":
        features_to_plot = self.rank_features(
            factual["weight"], width=width, num_to_show=filter_top
        )
    else:
        # NOTE(review): unlike AlternativeExplanation.plot, the predictions
        # are not flipped for classification when the base prediction is
        # <= 0.5 — confirm whether this difference is intended.
        ranking = calculate_metrics(
            uncertainty=[
                factual["predict_high"][i] - factual["predict_low"][i]
                for i in range(len(factual["weight"]))
            ],
            prediction=factual["predict"],
            w=rnk_weight,
            metric=rnk_metric,
        )
        features_to_plot = self.rank_features(width=ranking, num_to_show=filter_top)
    # Prefer explicit feature/column names when available; fall back to rule strings
    column_names = (
        factual.get("feature_names") or factual.get("column_names") or factual.get("rule")
    )
    # Probabilistic plot for classification or thresholded regression,
    # regression plot otherwise.
    if "classification" in self.get_explainer().mode or self.is_thresholded():
        plot_probabilistic(
            self,
            factual["value"],
            predict,
            feature_weights,
            features_to_plot,
            filter_top,
            column_names,
            title=title,
            path=path,
            interval=uncertainty,
            show=show,
            idx=self.index,
            save_ext=save_ext,
            style_override=style_override,
            use_legacy=plot_use_legacy,
        )
    else:
        plot_regression(
            self,
            factual["value"],
            predict,
            feature_weights,
            features_to_plot,
            filter_top,
            column_names,
            title=title,
            path=path,
            interval=uncertainty,
            show=show,
            idx=self.index,
            save_ext=save_ext,
            style_override=style_override,
            use_legacy=plot_use_legacy,
        )