Source code for calibrated_explanations.api.config

"""Configuration primitives for calibrated_explanations.

This module introduces a light-weight configuration dataclass and a builder
to simplify constructing explainers with validated options.
No wiring to core classes is performed, to avoid behavior changes; consumers
may import and use these types for future-facing code.

See `RELEASE_PLAN_v1` milestone targets and ADR-009 for preprocessing-related fields.
"""

from __future__ import annotations

import sys
from dataclasses import dataclass
from typing import Any, Literal

# Backward-compatible patch point used by tests. ``ExplainerBuilder.build_config``
# evaluates ``_perf_from_config or _build_perf_factory``, so any truthy value
# assigned here replaces the internal factory builder and is called with the
# assembled config as its only argument. Leave as ``None`` for normal operation.
_perf_from_config = None

# Accepted values for the ``task`` option of ``ExplainerConfig`` and
# ``ExplainerBuilder.task``.
TaskLiteral = Literal["classification", "regression", "auto"]


@dataclass
class ExplainerConfig:
    """Plain container of options used when building an explainer wrapper.

    Only ``model`` is required; every other field carries a default.

    Notes
    -----
    - Fields included here are future-facing.
    - Keep defaults aligned with existing behavior to prevent drift when adopted.
    """

    model: Any
    task: TaskLiteral = "auto"

    # --- Calibration / explanation knobs (subset; extend later as needed) ---
    low_high_percentiles: tuple[int, int] = (5, 95)
    # Used for probabilistic regression use-cases.
    threshold: float | None = None

    # --- Preprocessing (ADR-009) ---
    preprocessor: Any | None = None
    auto_encode: bool | Literal["auto"] = "auto"
    unseen_category_policy: Literal["ignore", "error"] = "error"

    # --- Parallelism placeholder (not wired yet) ---
    parallel_workers: int | None = None

    # --- Performance feature flags (ADR-003/ADR-004); all off by default ---
    # Cache settings.
    perf_cache_enabled: bool = False
    perf_cache_max_items: int = 512
    perf_cache_max_bytes: int | None = 32 * 1024 * 1024
    perf_cache_namespace: str = "calibrator"
    perf_cache_version: str = "v1"
    perf_cache_ttl: float | None = None
    # Parallel execution settings.
    perf_parallel_enabled: bool = False
    perf_parallel_backend: Literal[
        "auto", "sequential", "joblib", "threads", "processes"
    ] = "auto"
    perf_parallel_workers: int | None = None
    perf_parallel_min_batch: int = 8
    perf_parallel_min_instances: int | None = None
    perf_parallel_tiny_workload: int | None = None
    perf_parallel_granularity: Literal["feature", "instance"] = "feature"
    # Telemetry hook; stored as given.
    perf_telemetry: Any | None = None

    # --- Internal FAST-based feature filtering (disabled by default) ---
    perf_feature_filter_enabled: bool = False
    perf_feature_filter_per_instance_top_k: int = 8

    @property
    def perf_factory(self):
        """Return the object stored as ``_perf_factory``, or ``None`` if unset.

        ``_perf_factory`` is not a dataclass field; it only exists once some
        caller has assigned it on the instance.
        """
        try:
            return self._perf_factory
        except AttributeError:
            return None
class ExplainerBuilder:
    """Fluent helper to assemble an :class:`ExplainerConfig`.

    Each setter stores its argument on the wrapped configuration and returns
    the builder itself so calls can be chained. Values are stored as given;
    the only coercion is the clamp applied by :meth:`perf_feature_filter`.

    In a later step this builder can produce a configured
    `WrapCalibratedExplainer`.
    """

    def __init__(self, model: Any) -> None:
        """Store base model reference and seed configuration defaults."""
        self._cfg = ExplainerConfig(model=model)

    def _apply_overrides(self, **overrides: Any) -> None:
        """Copy every non-``None`` keyword onto the config attribute of the same name."""
        for attr, value in overrides.items():
            if value is not None:
                setattr(self._cfg, attr, value)

    # Simple fluent setters

    def task(self, task: TaskLiteral) -> ExplainerBuilder:
        """Set the task type for the explainer configuration.

        Parameters
        ----------
        task : {"classification", "regression", "auto"}
            Desired task literal to store in the configuration.

        Returns
        -------
        ExplainerBuilder
            This builder, for chaining.
        """
        self._cfg.task = task
        return self

    def low_high_percentiles(self, p: tuple[int, int]) -> ExplainerBuilder:
        """Update the percentile pair for interval explanations.

        Parameters
        ----------
        p : tuple of int
            Inclusive lower and upper percentiles used for interval computation.

        Returns
        -------
        ExplainerBuilder
            This builder, for chaining.
        """
        self._cfg.low_high_percentiles = p
        return self

    def threshold(self, t: float | None) -> ExplainerBuilder:
        """Store a regression-style threshold value on the configuration.

        Parameters
        ----------
        t : float or None
            Threshold applied when producing probabilistic regression outputs.

        Returns
        -------
        ExplainerBuilder
            This builder, for chaining.
        """
        self._cfg.threshold = t
        return self

    def preprocessor(self, pre: Any | None) -> ExplainerBuilder:
        """Attach an optional preprocessing object to the configuration.

        Parameters
        ----------
        pre : Any or None
            Preprocessor applied to inputs prior to fitting or calibration.

        Returns
        -------
        ExplainerBuilder
            This builder, for chaining.
        """
        self._cfg.preprocessor = pre
        return self

    def auto_encode(self, flag: bool | Literal["auto"]) -> ExplainerBuilder:
        """Toggle automatic categorical encoding behavior.

        Parameters
        ----------
        flag : bool or "auto"
            Whether to auto-encode categorical inputs when preprocessing.

        Returns
        -------
        ExplainerBuilder
            This builder, for chaining.
        """
        self._cfg.auto_encode = flag
        return self

    def unseen_category_policy(self, policy: Literal["ignore", "error"]) -> ExplainerBuilder:
        """Select the strategy for handling unseen categorical values.

        Parameters
        ----------
        policy : {"ignore", "error"}
            Policy to apply when encountering unseen categories at inference time.

        Returns
        -------
        ExplainerBuilder
            This builder, for chaining.
        """
        self._cfg.unseen_category_policy = policy
        return self

    def parallel_workers(self, n: int | None) -> ExplainerBuilder:
        """Configure the desired number of parallel worker processes.

        Parameters
        ----------
        n : int or None
            Worker count for parallel execution; ``None`` leaves the default in place.

        Returns
        -------
        ExplainerBuilder
            This builder, for chaining.
        """
        self._cfg.parallel_workers = n
        return self

    # Perf flags (feature-flagged; no behavior change when off)

    def perf_cache(
        self,
        enabled: bool,
        *,
        max_items: int | None = None,
        max_bytes: int | None = None,
        namespace: str | None = None,
        version: str | None = None,
        ttl: float | None = None,
    ) -> ExplainerBuilder:
        """Enable or disable the performance cache options.

        ``enabled`` is always written. Each keyword option is written only when
        it is not ``None``, so ``None`` means "keep the current value" and
        cannot be used to clear a previously set option.

        Parameters
        ----------
        enabled : bool
            Flag indicating whether caching primitives should be provisioned.
        max_items : int, optional
            Maximum number of cached entries when caching is enabled.
        max_bytes : int, optional
            Stored as ``perf_cache_max_bytes``.
        namespace : str, optional
            Stored as ``perf_cache_namespace``.
        version : str, optional
            Stored as ``perf_cache_version``.
        ttl : float, optional
            Stored as ``perf_cache_ttl``.

        Returns
        -------
        ExplainerBuilder
            This builder, for chaining.
        """
        self._cfg.perf_cache_enabled = enabled
        self._apply_overrides(
            perf_cache_max_items=max_items,
            perf_cache_max_bytes=max_bytes,
            perf_cache_namespace=namespace,
            perf_cache_version=version,
            perf_cache_ttl=ttl,
        )
        return self

    def perf_parallel(
        self,
        enabled: bool,
        *,
        backend: Literal["auto", "sequential", "joblib", "threads", "processes"] | None = None,
        workers: int | None = None,
        min_batch: int | None = None,
        min_instances: int | None = None,
        tiny_workload: int | None = None,
        granularity: Literal["feature", "instance"] | None = None,
    ) -> ExplainerBuilder:
        """Configure the parallel backend used for performance operations.

        ``enabled`` is always written. Each keyword option is written only when
        it is not ``None``, so ``None`` means "keep the current value".

        Parameters
        ----------
        enabled : bool
            Whether parallel primitives should be created.
        backend : {"auto", "sequential", "joblib", "threads", "processes"}, optional
            Explicit backend selection overriding the default when provided.
        workers : int, optional
            Stored as ``perf_parallel_workers``.
        min_batch : int, optional
            Stored as ``perf_parallel_min_batch``.
        min_instances : int, optional
            Stored as ``perf_parallel_min_instances``.
        tiny_workload : int, optional
            Stored as ``perf_parallel_tiny_workload``.
        granularity : {"feature", "instance"}, optional
            Stored as ``perf_parallel_granularity``.

        Returns
        -------
        ExplainerBuilder
            This builder, for chaining.
        """
        self._cfg.perf_parallel_enabled = enabled
        self._apply_overrides(
            perf_parallel_backend=backend,
            perf_parallel_workers=workers,
            perf_parallel_min_batch=min_batch,
            perf_parallel_min_instances=min_instances,
            perf_parallel_tiny_workload=tiny_workload,
            perf_parallel_granularity=granularity,
        )
        return self

    def perf_telemetry(self, callback: Any | None) -> ExplainerBuilder:
        """Register a telemetry callback shared by cache and parallel executors.

        Parameters
        ----------
        callback : Any or None
            Stored as ``perf_telemetry``; ``None`` is stored as-is.

        Returns
        -------
        ExplainerBuilder
            This builder, for chaining.
        """
        self._cfg.perf_telemetry = callback
        return self

    def perf_feature_filter(
        self,
        enabled: bool,
        *,
        per_instance_top_k: int | None = None,
    ) -> ExplainerBuilder:
        """Configure internal FAST-based feature filtering.

        Parameters
        ----------
        enabled : bool
            Flag indicating whether the internal FAST-based feature filter is enabled.
        per_instance_top_k : int, optional
            Maximum number of features to keep per instance based on FAST weights.
            Coerced with ``int()`` and clamped to a minimum of 1; ``None`` keeps
            the current value.

        Returns
        -------
        ExplainerBuilder
            This builder, for chaining.
        """
        self._cfg.perf_feature_filter_enabled = enabled
        if per_instance_top_k is not None:
            self._cfg.perf_feature_filter_per_instance_top_k = max(1, int(per_instance_top_k))
        return self

    def build_config(self) -> ExplainerConfig:
        """Return the assembled configuration with a perf factory attached.

        The returned object is the builder's own config instance, not a copy,
        so later setter calls on this builder remain visible through it.

        A perf factory is stashed on the config as ``_perf_factory`` so later
        consumers can opt in to perf primitives consistently. It comes from the
        module-level ``_perf_from_config`` hook when that is set, otherwise
        from ``_build_perf_factory``. If creating it raises an ``Exception``,
        ``_perf_factory`` is set to ``None`` and the config is still returned.
        Non-``Exception`` errors such as ``KeyboardInterrupt`` propagate.

        Returns
        -------
        ExplainerConfig
            The configuration assembled so far.
        """
        try:
            factory_builder = _perf_from_config or _build_perf_factory
            # stash a lightweight factory on the config for downstream wiring
            self._cfg._perf_factory = factory_builder(self._cfg)  # type: ignore[attr-defined]
        except Exception:  # noqa: BLE001 - deliberate best-effort fallback
            # Be conservative: do not fail config building if perf factory
            # creation fails, but leave a trace for anyone debugging it.
            import logging

            logging.getLogger(__name__).debug(
                "perf factory creation failed; continuing without perf primitives",
                exc_info=True,
            )
            self._cfg._perf_factory = None  # type: ignore[attr-defined]
        return self._cfg
class _ConfigPerfFactory:
    """Internal cache/parallel primitive builder for config-based wrapper wiring.

    Holds one cache configuration and one parallel configuration and creates
    the matching runtime objects on request. The backing modules are imported
    only when a ``make_*`` method is called.
    """

    def __init__(self, cache_cfg: Any, parallel_cfg: Any) -> None:
        self._cache_cfg = cache_cfg
        self._parallel_cfg = parallel_cfg

    def make_cache(self) -> Any:
        """Build a cache backend based on the stored configuration."""
        from ..cache import CalibratorCache

        cache = CalibratorCache(self._cache_cfg)
        return cache

    def make_parallel_executor(self, cache: Any | None = None) -> Any:
        """Create a parallel executor wired to the stored parallel configuration."""
        from ..parallel import ParallelExecutor

        executor = ParallelExecutor(self._parallel_cfg, cache=cache)
        return executor

    def make_parallel_backend(self, cache: Any | None = None) -> Any:
        """Alias for :meth:`make_parallel_executor`."""
        return self.make_parallel_executor(cache=cache)


def _build_perf_factory(cfg: Any) -> _ConfigPerfFactory:
    """Create perf primitives from config without using removed perf root facade.

    Each ``perf_*`` option is read from ``cfg`` with ``getattr`` and a
    fallback default, so ``cfg`` may be any object and may lack some options.
    Both resulting configs are passed through their ``from_env`` classmethod
    before being wrapped in a :class:`_ConfigPerfFactory`.
    """
    from ..cache import CacheConfig
    from ..parallel import ParallelConfig

    def option(name: str, fallback: Any) -> Any:
        # Tolerate config objects that do not define every perf_* attribute.
        return getattr(cfg, name, fallback)

    base_cache_cfg = CacheConfig(
        enabled=option("perf_cache_enabled", False),
        namespace=option("perf_cache_namespace", "calibrator"),
        version=option("perf_cache_version", "v1"),
        max_items=option("perf_cache_max_items", 512),
        max_bytes=option("perf_cache_max_bytes", 32 * 1024 * 1024),
        ttl_seconds=option("perf_cache_ttl", None),
        telemetry=option("perf_telemetry", None),
    )
    resolved_cache_cfg = CacheConfig.from_env(base_cache_cfg)

    base_parallel_cfg = ParallelConfig(
        enabled=option("perf_parallel_enabled", False),
        strategy=option("perf_parallel_backend", "auto"),
        max_workers=option("perf_parallel_workers", None),
        min_batch_size=option("perf_parallel_min_batch", 8),
        min_instances_for_parallel=option("perf_parallel_min_instances", None),
        tiny_workload_threshold=option("perf_parallel_tiny_workload", None),
        granularity=option("perf_parallel_granularity", "feature"),
        telemetry=option("perf_telemetry", None),
    )
    resolved_parallel_cfg = ParallelConfig.from_env(base_parallel_cfg)

    return _ConfigPerfFactory(
        cache_cfg=resolved_cache_cfg,
        parallel_cfg=resolved_parallel_cfg,
    )


__all__ = [
    "ExplainerConfig",
    "ExplainerBuilder",
    "TaskLiteral",
]