# pylint: disable=line-too-long
"""Helper utilities for filesystem, typing, and data transformations.
Centralizes small routines for safe imports, conversions, and metric
calculations shared across calibrated explanations.
"""
import importlib
import numbers
import os
import sys
from inspect import isclass
from typing import Any
import numpy as np
import pandas as pd
from pandas import CategoricalDtype
from pandas.api.types import is_object_dtype, is_string_dtype
try: # pragma: no cover - script-mode fallback
from .exceptions import NotFittedError, ValidationError
except ImportError: # pragma: no cover - invoked when run as a script
from calibrated_explanations.utils.exceptions import NotFittedError, ValidationError
[docs]
def make_directory(path: str, save_ext=None, add_plots_folder=True) -> None:  # pylint: disable=unused-private-member
    """Create a directory, including missing parents, if it does not exist.

    Parameters
    ----------
    path : str
        The path to the directory to create.
    save_ext : str or list, optional
        The extension(s) of the file to save, by default None. When an empty
        sequence is passed nothing is created.
    add_plots_folder : bool, optional
        Whether to create the directory underneath a ``plots`` folder
        (relative to the current working directory), by default True.
        A leading ``plots/`` in `path` is not duplicated.
    """
    # An explicitly empty extension list means nothing will be saved.
    if save_ext is not None and len(save_ext) == 0:
        return

    if not add_plots_folder:
        # makedirs handles nested paths and tolerates a concurrent creator.
        os.makedirs(path, exist_ok=True)
        return

    os.makedirs("plots", exist_ok=True)
    if path == "plots":
        return

    sub_path = path.removeprefix("plots/")
    os.makedirs(f"plots/{sub_path}", exist_ok=True)
# Adapted from shap.utils._general.safe_isinstance (MIT License).
# See THIRD_PARTY_NOTICES.md for full license text and attribution.
[docs]
def safe_isinstance(obj, class_path_str):
    """Check ``isinstance`` against classes named by their import path.

    Acts as a safe version of isinstance that does not require importing
    packages which may not exist in the user's environment. Only modules
    already present in ``sys.modules`` are consulted; nothing is imported.

    Parameters
    ----------
    obj : Any
        Some object you want to test against.
    class_path_str : str, list or tuple of str, or None
        A string or sequence of strings specifying full class paths.
        Example: `sklearn.ensemble.RandomForestRegressor`

    Returns
    -------
    bool
        True if `obj` is an instance of any named class whose module is
        already imported, False otherwise (including when the named
        attribute is missing or is not usable as a class).

    Raises
    ------
    ValidationError
        If an entry is not a string containing at least one ``'.'``, or if
        `class_path_str` is of an unsupported type.
    """
    if isinstance(class_path_str, str):
        candidate_paths = [class_path_str]
    elif class_path_str is None:
        candidate_paths = []
    elif isinstance(class_path_str, (list, tuple)):
        candidate_paths = class_path_str
    else:
        # Unsupported argument type: the empty path trips the validation below.
        candidate_paths = [""]

    # Try each module path in order.
    for candidate_path in candidate_paths:
        if not isinstance(candidate_path, str) or "." not in candidate_path:
            raise ValidationError(
                "class_path_str must be a fully qualified module path (e.g., "
                "'sklearn.ensemble.RandomForestRegressor').",
                details={
                    "param": "class_path_str",
                    "requirement": "module path must include at least one '.'",
                    "provided": candidate_path,
                },
            )

        # Split on the last occurrence of ".".
        module_name, class_name = candidate_path.rsplit(".", 1)

        # If the module was never imported, no object of one of its types can
        # have been handed to us, so skip it rather than importing it.
        module = sys.modules.get(module_name)
        if module is None:
            continue

        candidate = getattr(module, class_name, None)
        if candidate is None:
            continue

        try:
            if isinstance(obj, candidate):
                return True
        except TypeError:
            # The attribute exists but is not a class (function, submodule, ...).
            continue

    return False
[docs]
def safe_import(module_name, class_name=None):
    """Import a module, or names from it, with a descriptive error on failure.

    Parameters
    ----------
    module_name : str
        Name of the module to import. A module already present in
        ``sys.modules`` is reused instead of being imported again.
    class_name : str, list, tuple, np.ndarray, or None, default=None
        Attribute name(s) to fetch from the module. If None, the module
        itself is returned.

    Returns
    -------
    module, object, or list
        The module when `class_name` is None, the single attribute when
        `class_name` is a string, or a list of attributes when a sequence of
        names is given.

    Raises
    ------
    ImportError
        If the module cannot be imported, or if a requested attribute does
        not exist in the module.
    """
    try:
        module = sys.modules.get(module_name)
        if module is None:
            module = importlib.import_module(module_name)
        if class_name is None:
            return module
        if isinstance(class_name, (list, tuple, np.ndarray)):
            return [getattr(module, name) for name in class_name]
        return getattr(module, class_name)
    except ImportError as exc:
        raise ImportError(
            f"The required module '{module_name}' is not installed. "
            f"Please install it using 'pip install {module_name}' or another method."
        ) from exc
    except AttributeError as exc:
        # A missing attribute is reported as an import failure.
        raise ImportError(
            f"The class or function '{class_name}' does "
            + f"not exist in the module '{module_name}'."
        ) from exc
# Adapted from sklearn.utils.validation.check_is_fitted (BSD 3-Clause License).
# See THIRD_PARTY_NOTICES.md for full license text and attribution.
# pylint: disable=inconsistent-return-statements
[docs]
def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
    """Perform is_fitted validation for estimator.

    Checks if the estimator is fitted by verifying the presence of
    fitted attributes (ending with a trailing underscore) and otherwise
    raises a NotFittedError with the given message.

    If an estimator does not set any attributes with a trailing underscore, it
    can define a ``__sklearn_is_fitted__`` method returning a boolean to specify if the
    estimator is fitted or not.

    Parameters
    ----------
    estimator : estimator instance
        Estimator instance for which the check is performed.
    attributes : str, list or tuple of str, default=None
        Attribute name(s) given as string or a list/tuple of strings
        Eg.: ``["coef_", "estimator_", ...], "coef_"``
        If `None`, `estimator` is considered fitted if there exists an
        attribute that ends with an underscore and does not start with double
        underscore.
    msg : str, default=None
        The default error message is, "This %(name)s instance is not fitted
        yet. Call 'fit' with appropriate arguments before using this
        estimator."
        For custom messages if "%(name)s" is present in the message string,
        it is substituted for the estimator name.
        Eg. : "Estimator, %(name)s, must be fitted before sparsifying".
    all_or_any : callable, {all, any}, default=all
        Specify whether all or any of the given attributes must exist.

    Returns
    -------
    object or None
        If the estimator has a ``fitted`` attribute, its value is returned;
        otherwise, if it has an ``is_fitted`` method, that method's result is
        returned. In both cases no exception is raised for an unfitted
        estimator. In every other case None is returned on success.

    Raises
    ------
    ValidationError
        If `estimator` is a class, or exposes none of ``fit``,
        ``partial_fit`` or ``learn_initial_training_set``.
    NotFittedError
        If the attributes are not found.
    """
    if isclass(estimator):
        raise ValidationError(f"{estimator} is a class, not an instance.")
    if msg is None:
        msg = (
            "This %(name)s instance is not fitted yet. Call 'fit' with "
            "appropriate arguments before using this estimator."
        )

    # Objects that report their own fitted state are answered directly.
    if hasattr(estimator, "fitted"):
        return estimator.fitted
    if hasattr(estimator, "is_fitted"):
        return estimator.is_fitted()

    if not (
        hasattr(estimator, "fit")
        or hasattr(estimator, "partial_fit")  # handle online models
        or hasattr(estimator, "learn_initial_training_set")  # handle online_cp package
    ):
        raise ValidationError(f"{estimator} is not an estimator instance.")

    if attributes is not None:
        if not isinstance(attributes, (list, tuple)):
            attributes = [attributes]
        fitted = all_or_any([hasattr(estimator, attr) for attr in attributes])
    elif hasattr(estimator, "__sklearn_is_fitted__"):
        fitted = estimator.__sklearn_is_fitted__()
    elif hasattr(estimator, "XTXinv"):  # handle online_cp package and OnlineRidgeRegressor
        fitted = estimator.XTXinv is not None or bool(hasattr(estimator, "a") and estimator.a != 0)
    else:
        fitted = [v for v in vars(estimator) if v.endswith("_") and not v.startswith("__")]

    # Truthiness alone covers False and the empty list; comparing against []
    # would turn a NumPy boolean into an empty array with no usable truth value.
    if not fitted:
        raise NotFittedError(msg % {"name": type(estimator).__name__})
    return None
[docs]
def is_notebook():
    """Tell whether the code is executing inside a Jupyter notebook kernel.

    Returns
    -------
    bool
        True when IPython is importable and the active shell's config
        contains an ``IPKernelApp`` entry; False otherwise, including when
        IPython is missing or the lookup raises AttributeError.
    """
    try:
        # pylint: disable=import-outside-toplevel
        from IPython import get_ipython

        kernel_configured = "IPKernelApp" in get_ipython().config  # pragma: no cover
    except (ImportError, AttributeError):
        # IPython is not installed, or the config lookup failed on the shell object.
        return False
    return kernel_configured
# pylint: disable=too-many-locals, too-many-branches
[docs]
def assert_threshold(threshold, x):
    """Test if the thresholds are valid.

    Parameters
    ----------
    threshold : int, float, tuple, list, or np.ndarray
        The threshold(s) to be validated. It can be a scalar (int or float),
        a tuple with two values, or a list/np.ndarray of scalars or tuples.
    x : array-like
        The data against which the thresholds are validated. Used to check
        the length of list/np.ndarray thresholds. It is converted with
        ``np.asarray``, so lists, arrays and data frames are accepted.

    Returns
    -------
    None, int, float, tuple, or list
        The validated threshold(s). List and array thresholds are returned
        as a list with each element validated individually.

    Raises
    ------
    AssertionError
        If the length of the list/np.ndarray threshold is not equal to the number of samples.
    ValidationError
        If a tuple threshold does not have exactly two values, or if the
        threshold is not a scalar, binary tuple, or list of scalars or binary tuples.

    Examples
    --------
    >>> assert_threshold(0.5, [1, 2, 3])
    0.5
    >>> assert_threshold((0.2, 0.8), [1, 2, 3])
    (0.2, 0.8)
    >>> assert_threshold([0.1, 0.2, 0.3], [1, 2, 3])
    [0.1, 0.2, 0.3]
    >>> assert_threshold([(0.1, 0.9), (0.2, 0.8)], [1, 2])
    [(0.1, 0.9), (0.2, 0.8)]
    >>> assert_threshold(None, [1, 2, 3])
    >>> assert_threshold([0.1, 0.2], [1])
    Traceback (most recent call last):
    ...
    AssertionError: list thresholds must have the same length as the number of samples
    """
    if threshold is None:
        return threshold
    if np.isscalar(threshold) and isinstance(threshold, (numbers.Integral, numbers.Real)):
        return threshold
    if isinstance(threshold, tuple):
        if len(threshold) != 2:
            raise ValidationError(
                "tuple thresholds must contain exactly two values",
                details={
                    "param": "threshold",
                    "expected_length": 2,
                    "actual_length": len(threshold),
                },
            )
        return threshold
    if isinstance(threshold, (list, np.ndarray)):
        # Convert once and index the array positionally, so containers whose
        # own ``[]`` is label-based are handled by row as well.
        samples = np.asarray(x)
        if len(threshold) != samples.shape[0]:
            raise AssertionError(
                "list thresholds must have the same length as the number of samples"
            )
        # Each element is validated against its own single sample.
        return [assert_threshold(t, [samples[i]]) for i, t in enumerate(threshold)]
    raise ValidationError(
        "thresholds must be a scalar, binary tuple or list of scalars or binary tuples",
        details={
            "param": "threshold",
            "expected_types": ["scalar", "tuple(len=2)", "list/ndarray of scalar or tuple(len=2)"],
            "actual_type": type(threshold).__name__,
        },
    )
# pylint: disable=too-many-arguments, too-many-statements
[docs]
def calculate_metrics(
    uncertainty=None,
    prediction=None,
    w=0.5,
    metric=None,
    normalize=False,
):
    """Calculate different metrics based on the uncertainty and prediction values.

    Parameters
    ----------
    uncertainty : float, list, tuple, or np.ndarray
        The uncertainty of the explanation. Uncertainty is a measure of the
        confidence of the explanation. For classification, this is a value
        between 0 and 1, where 0 means the explanation is certain and 1 means
        the explanation is uncertain. For regression, this is the width of the
        uncertainty interval determined by the user defined percentiles.
    prediction : float, list, tuple, or np.ndarray
        The prediction of the explanation. For classification, this is the
        probability of the predicted class. For regression, this is the
        predicted value.
    w : float, default=0.5
        The weight given to the prediction term; the uncertainty term is
        weighted by ``1 - abs(w)``. The weight must be between -1 and 1. A
        negative weight uses its absolute value and negates the prediction.
    metric : str, list of str, or None, default=None
        The metric(s) to calculate. If None, all available metrics are
        calculated. The available metrics are:

        - 'ensured' : Weighted Sum Method
    normalize : bool, default=False
        Whether to min-max normalize the uncertainty and prediction values
        before computing the metrics.

    Returns
    -------
    list, float, np.ndarray, or dict
        The list of available metric names when called with neither
        `uncertainty` nor `prediction`. Otherwise the computed metric value
        when exactly one metric was computed, or a dict keyed by metric name
        when more than one was computed.

    Raises
    ------
    ValidationError
        If only one of `uncertainty` and `prediction` is provided, if `w` is
        outside [-1, 1], or if none of the requested metrics is available.

    Notes
    -----
    If the method is called with no arguments, it will return the list of available metrics.
    """
    if uncertainty is None and prediction is None:
        return ["ensured"]
    if uncertainty is None or prediction is None:
        raise ValidationError(
            "Both uncertainty and prediction must be provided if any other argument is provided",
            details={
                "params": ["uncertainty", "prediction"],
                "requirement": "both required when computing metrics",
                "uncertainty_is_none": uncertainty is None,
                "prediction_is_none": prediction is None,
            },
        )

    # Plain Python sequences do not support the arithmetic below.
    if isinstance(uncertainty, (list, tuple)):
        uncertainty = np.array(uncertainty)
    if isinstance(prediction, (list, tuple)):
        prediction = np.array(prediction)

    if not (-1 <= w <= 1):
        raise ValidationError(
            "The weight must be between -1 and 1.",
            details={"param": "w", "min": -1, "max": 1, "provided": w},
        )
    # A negative weight means "same magnitude, inverted prediction".
    inverse_prediction = w < 0
    if inverse_prediction:
        w = -w

    if metric is None:
        metric = calculate_metrics()
    elif isinstance(metric, str):
        metric = [metric]

    if normalize:
        # NOTE(review): constant inputs make a range zero, so the division
        # yields nan/inf; confirm whether callers need a guard here.
        min_uncertainty, max_uncertainty = np.min(uncertainty), np.max(uncertainty)
        min_prediction, max_prediction = np.min(prediction), np.max(prediction)
        uncertainty = (uncertainty - min_uncertainty) / (max_uncertainty - min_uncertainty)
        prediction = (prediction - min_prediction) / (max_prediction - min_prediction)

    if inverse_prediction:
        prediction = -1 * prediction

    metrics = {}
    if "ensured" in metric:
        metrics["ensured"] = (1 - w) * (1 - uncertainty) + w * (prediction)

    if not metrics:
        # Without this check the lookup below fails with an opaque IndexError.
        raise ValidationError(
            "None of the requested metrics is available.",
            details={
                "param": "metric",
                "provided": list(metric),
                "available": calculate_metrics(),
            },
        )
    return metrics if len(metrics) > 1 else next(iter(metrics.values()))
[docs]
def convert_targets_to_numeric(y):
    """Map string targets to integer codes, keeping the label mapping.

    Parameters
    ----------
    y : array-like
        Target values that may contain strings.

    Returns
    -------
    tuple
        - array-like: ``y`` unchanged when no element is a string-like
          value; otherwise a numpy array of integer codes.
        - dict or None: mapping from each unique label (in ``np.unique``
          order) to its code, or None when no conversion took place.
    """
    text_like = (str, np.str_, np.object_)
    if not any(isinstance(item, text_like) for item in y):
        return y, None

    # Codes follow the sorted order produced by np.unique.
    label_map = {label: code for code, label in enumerate(np.unique(y))}
    numeric_y = np.array([label_map[item] for item in y])
    return numeric_y, label_map
[docs]
def concatenate_thresholds(perturbed_threshold, threshold, indices):
    """
    Append the thresholds selected by `indices` to `perturbed_threshold`.

    Parameters
    ----------
    perturbed_threshold : list or np.ndarray
        The thresholds accumulated so far.
    threshold : list, np.ndarray, or other
        The original thresholds. Anything that is not a list or ndarray
        leaves `perturbed_threshold` untouched.
    indices : iterable of int
        Positions to pick from `threshold`.

    Returns
    -------
    list or np.ndarray
        `perturbed_threshold` unchanged for non-sequence thresholds; a plain
        list of the picked entries when the thresholds are tuples and nothing
        has been accumulated yet; otherwise the numpy concatenation of the
        accumulated and the picked thresholds.
    """
    if not isinstance(threshold, (list, np.ndarray)):
        return perturbed_threshold

    # The first batch of tuple thresholds replaces the empty accumulator
    # instead of being concatenated onto it.
    start_fresh = isinstance(threshold[0], tuple) and len(perturbed_threshold) == 0
    picked = [threshold[i] for i in indices]
    if start_fresh:
        return picked
    return np.concatenate((perturbed_threshold, picked))
[docs]
def immutable_array(array):
    """
    Return a read-only numpy view of the given data.

    The input itself is left writeable: when `array` is already an ndarray
    the result is a read-only view sharing its memory, so no data is copied
    and the caller's array flags are not modified.

    Parameters
    ----------
    array : list or np.ndarray
        The data to expose as a read-only array.

    Returns
    -------
    np.ndarray
        The read-only numpy array.

    Examples
    --------
    >>> arr = immutable_array([1, 2, 3])
    >>> arr.flags.writeable
    False
    >>> int(arr[0])
    1
    >>> arr[0] = 10
    Traceback (most recent call last):
    ...
    ValueError: assignment destination is read-only
    """
    # np.asarray returns an existing ndarray as-is; freezing a view keeps the
    # caller's own array writeable.
    frozen = np.asarray(array).view()
    frozen.flags.writeable = False
    return frozen
[docs]
def prepare_for_saving(filename):
    """
    Split a target file name into its parts and create its directory.

    Parameters
    ----------
    filename : str
        The full path to the file to save.

    Returns
    -------
    tuple:
        - str: The directory part of `filename` followed by ``/``.
        - str: The base file name.
        - str: The base file name without its extension.
        - str: The extension, including the leading dot.

        Four empty strings are returned when `filename` is empty.
    """
    if len(filename) == 0:
        return "", "", "", ""

    # NOTE(review): a bare file name has an empty dirname, so the directory
    # part becomes "/" - confirm that callers expect this.
    directory = f"{os.path.dirname(filename)}/"
    base_name = os.path.basename(filename)
    title, ext = os.path.splitext(base_name)
    make_directory(directory, save_ext=np.array([ext]))
    return directory, base_name, title, ext
[docs]
def safe_mean(values, default=0.0):
    """Compute the mean of `values`, falling back to `default`.

    The fallback is returned as given (not converted to float) when
    `values` is empty, which avoids numpy's "Mean of empty slice"
    RuntimeWarning, and also when computing the mean raises an ordinary
    exception. Otherwise the mean is returned as a Python float.
    """
    try:
        data = np.asarray(values)
        return float(np.mean(data)) if data.size else default
    except Exception:  # non-Exception BaseExceptions (e.g. KeyboardInterrupt) propagate
        return default
[docs]
def safe_first_element(values, default=0.0, col=None):
    """Pick a sensible leading element of `values` as a float.

    - Empty input (size == 0) gives `default`.
    - A scalar is returned as float.
    - With `col` None, the first flattened element is returned.
    - With `col` given and 1D input, ``values[col]`` is returned when `col`
      is below the size.
    - With `col` given and input of 2 or more dimensions, ``values[0, col]``
      is returned when the second axis is long enough.

    In every other case, including any ordinary exception raised along the
    way, ``float(default)`` is returned. This protects callers that index
    ``[0]`` (or ``[0, 1]``) on prediction outputs which may be empty.
    """
    try:
        data = np.asarray(values)
        if data.size == 0:
            picked = default
        elif data.ndim == 0:
            picked = data
        elif col is None:
            picked = data.flat[0]
        elif data.ndim == 1:
            picked = data[col] if col < data.size else default
        elif data.shape[0] > 0 and data.shape[1] > col:
            picked = data[0, col]
        else:
            picked = default
        return float(picked)
    except Exception:  # non-Exception BaseExceptions (e.g. KeyboardInterrupt) propagate
        return float(default)
[docs]
def assign_threshold(threshold: Any) -> Any:
    """Prepare a regression threshold for prediction calls.

    List and array thresholds are swapped for an empty array; every other
    value, None included, passes through untouched.

    Parameters
    ----------
    threshold : scalar, tuple, list, np.ndarray, or None
        Optional threshold value for regression explanations.

    Returns
    -------
    None, scalar, tuple, or empty array
        For None: returns None.
        For any non-list, non-array value: returns it unchanged.
        For list/array: returns an empty 1D array, of object dtype when the
        first element is a tuple and of the default float dtype otherwise.

    Examples
    --------
    Scalar threshold (valid for single prediction):

    >>> assign_threshold(5.0)
    5.0
    """
    if threshold is None or not isinstance(threshold, (list, np.ndarray)):
        return threshold

    # An empty container tells callers not to broadcast the original list.
    holds_tuples = len(threshold) > 0 and isinstance(threshold[0], tuple)
    if holds_tuples:
        return np.empty((0,), dtype=tuple)
    return np.empty((0,))
if __name__ == "__main__":
    # Script mode: run the module's doctests and signal failure via exit code.
    import doctest

    failed_count = doctest.testmod().failed
    if failed_count:
        sys.exit(1)