"""Configuration primitives for calibrated_explanations.
This module introduces a lightweight configuration dataclass and a builder
to simplify constructing explainers with validated options.
No wiring to core classes is performed, to avoid behavior changes; consumers
may import and use these types in future-facing code.
See `RELEASE_PLAN_v1` milestone targets and ADR-009 for preprocessing-related fields.
"""
from __future__ import annotations
import sys
from dataclasses import dataclass
from typing import Any, Literal
# Test hook kept for backward compatibility. When a callable is assigned here,
# ``ExplainerBuilder.build_config`` calls it with the config in place of the
# module's default perf-factory builder (``_build_perf_factory``).
_perf_from_config: Any = None

# Task kinds accepted by ``ExplainerConfig.task`` and ``ExplainerBuilder.task``.
TaskLiteral = Literal["classification", "regression", "auto"]
@dataclass
class ExplainerConfig:
    """Plain settings container assembled by :class:`ExplainerBuilder`.

    Only ``model`` is required; every other field has a default, so
    ``ExplainerConfig(model=m)`` is already a complete configuration.
    The fields fall into these groups:

    * core: ``model``, ``task``, ``low_high_percentiles``, ``threshold``;
    * preprocessing (ADR-009): ``preprocessor``, ``auto_encode``,
      ``unseen_category_policy``;
    * ``parallel_workers``: placeholder, not wired to any behavior yet;
    * ``perf_cache_*``, ``perf_parallel_*`` and ``perf_telemetry``:
      performance feature flags (ADR-003/ADR-004), all disabled by default;
    * ``perf_feature_filter_*``: internal FAST-based feature filtering,
      disabled by default.

    Notes
    -----
    - These fields are future-facing.
    - Defaults must track existing behavior so that adopting this type does
      not introduce drift.
    """

    model: Any
    task: TaskLiteral = "auto"

    # --- calibration / explanation knobs (a subset; extend as needed) ---
    low_high_percentiles: tuple[int, int] = (5, 95)
    threshold: float | None = None  # used for probabilistic regression

    # --- preprocessing (ADR-009) ---
    preprocessor: Any | None = None
    auto_encode: bool | Literal["auto"] = "auto"
    unseen_category_policy: Literal["ignore", "error"] = "error"

    # --- parallelism placeholder; nothing reads this yet ---
    parallel_workers: int | None = None

    # --- performance feature flags (ADR-003/ADR-004); off unless enabled ---
    perf_cache_enabled: bool = False
    perf_cache_max_items: int = 512
    perf_cache_max_bytes: int | None = 32 * 1024 * 1024  # 32 MiB
    perf_cache_namespace: str = "calibrator"
    perf_cache_version: str = "v1"
    perf_cache_ttl: float | None = None
    perf_parallel_enabled: bool = False
    perf_parallel_backend: Literal["auto", "sequential", "joblib", "threads", "processes"] = "auto"
    perf_parallel_workers: int | None = None
    perf_parallel_min_batch: int = 8
    perf_parallel_min_instances: int | None = None
    perf_parallel_tiny_workload: int | None = None
    perf_parallel_granularity: Literal["feature", "instance"] = "feature"
    perf_telemetry: Any | None = None

    # --- internal FAST-based feature filtering; off unless enabled ---
    perf_feature_filter_enabled: bool = False
    perf_feature_filter_per_instance_top_k: int = 8

    @property
    def perf_factory(self):
        """Perf-primitive factory attached by ``ExplainerBuilder.build_config``.

        Returns ``None`` when no factory was attached, e.g. when this config
        was constructed directly or factory creation failed.
        """
        try:
            return self._perf_factory  # type: ignore[attr-defined]
        except AttributeError:
            return None
class ExplainerBuilder:
    """Fluent helper to assemble an :class:`ExplainerConfig`.

    Every setter writes to the builder's single working config and returns
    the builder itself, so calls can be chained. :meth:`build_config` returns
    that same config object. In a later step this builder can produce a
    configured `WrapCalibratedExplainer`.
    """

    def __init__(self, model: Any) -> None:
        """Store base model reference and seed configuration defaults.

        Parameters
        ----------
        model : Any
            Model stored as ``ExplainerConfig.model``; every other field starts
            at its ``ExplainerConfig`` default.
        """
        self._cfg = ExplainerConfig(model=model)

    # Simple fluent setters

    def task(self, task: TaskLiteral) -> ExplainerBuilder:
        """Set the task type for the explainer configuration.

        Parameters
        ----------
        task : {"classification", "regression", "auto"}
            Desired task literal to store in the configuration.
        """
        self._cfg.task = task
        return self

    def low_high_percentiles(self, p: tuple[int, int]) -> ExplainerBuilder:
        """Update the percentile pair for interval explanations.

        Parameters
        ----------
        p : tuple of int
            Lower and upper percentiles used for interval computation.
        """
        self._cfg.low_high_percentiles = p
        return self

    def threshold(self, t: float | None) -> ExplainerBuilder:
        """Store a regression-style threshold value on the configuration.

        Parameters
        ----------
        t : float or None
            Threshold for probabilistic regression outputs; ``None`` clears it.
        """
        self._cfg.threshold = t
        return self

    def preprocessor(self, pre: Any | None) -> ExplainerBuilder:
        """Attach an optional preprocessing object to the configuration.

        Parameters
        ----------
        pre : Any or None
            Preprocessor applied to inputs prior to fitting or calibration.
        """
        self._cfg.preprocessor = pre
        return self

    def auto_encode(self, flag: bool | Literal["auto"]) -> ExplainerBuilder:
        """Toggle automatic categorical encoding behavior.

        Parameters
        ----------
        flag : bool or "auto"
            Whether to auto-encode categorical inputs when preprocessing;
            ``"auto"`` is the configuration default.
        """
        self._cfg.auto_encode = flag
        return self

    def unseen_category_policy(self, policy: Literal["ignore", "error"]) -> ExplainerBuilder:
        """Select the strategy for handling unseen categorical values.

        Parameters
        ----------
        policy : {"ignore", "error"}
            Policy to apply when encountering unseen categories at inference time.
        """
        self._cfg.unseen_category_policy = policy
        return self

    def parallel_workers(self, n: int | None) -> ExplainerBuilder:
        """Store the placeholder ``parallel_workers`` count.

        This field is not wired to any behavior yet; the perf parallel worker
        count is configured through :meth:`perf_parallel` instead.

        Parameters
        ----------
        n : int or None
            Desired worker count; ``None`` clears any previously set value.
        """
        self._cfg.parallel_workers = n
        return self

    # Perf flags (feature-flagged; no behavior change when off)

    def perf_cache(
        self,
        enabled: bool,
        *,
        max_items: int | None = None,
        max_bytes: int | None = None,
        namespace: str | None = None,
        version: str | None = None,
        ttl: float | None = None,
    ) -> ExplainerBuilder:
        """Enable or disable the performance cache options.

        Parameters
        ----------
        enabled : bool
            Flag indicating whether caching primitives should be provisioned.
        max_items : int, optional
            Maximum number of cached entries (``perf_cache_max_items``).
        max_bytes : int, optional
            Maximum cache size in bytes (``perf_cache_max_bytes``).
        namespace : str, optional
            Cache namespace (``perf_cache_namespace``).
        version : str, optional
            Cache version tag (``perf_cache_version``).
        ttl : float, optional
            Entry time-to-live (``perf_cache_ttl``, forwarded to the cache
            config as ``ttl_seconds``).

        Notes
        -----
        Keyword arguments left as ``None`` keep their current value, so this
        method cannot reset ``max_bytes`` or ``ttl`` back to ``None``.
        """
        self._cfg.perf_cache_enabled = enabled
        if max_items is not None:
            self._cfg.perf_cache_max_items = max_items
        if max_bytes is not None:
            self._cfg.perf_cache_max_bytes = max_bytes
        if namespace is not None:
            self._cfg.perf_cache_namespace = namespace
        if version is not None:
            self._cfg.perf_cache_version = version
        if ttl is not None:
            self._cfg.perf_cache_ttl = ttl
        return self

    def perf_parallel(
        self,
        enabled: bool,
        *,
        backend: Literal["auto", "sequential", "joblib", "threads", "processes"] | None = None,
        workers: int | None = None,
        min_batch: int | None = None,
        min_instances: int | None = None,
        tiny_workload: int | None = None,
        granularity: Literal["feature", "instance"] | None = None,
    ) -> ExplainerBuilder:
        """Configure the parallel backend used for performance operations.

        Parameters
        ----------
        enabled : bool
            Whether parallel primitives should be created.
        backend : {"auto", "sequential", "joblib", "threads", "processes"}, optional
            Explicit backend selection overriding the default when provided.
        workers : int, optional
            Worker count (``perf_parallel_workers``, forwarded as ``max_workers``).
        min_batch : int, optional
            Minimum batch size (``perf_parallel_min_batch``).
        min_instances : int, optional
            Instance-count threshold for parallelism
            (``perf_parallel_min_instances``).
        tiny_workload : int, optional
            Tiny-workload threshold (``perf_parallel_tiny_workload``).
        granularity : {"feature", "instance"}, optional
            Parallelization granularity (``perf_parallel_granularity``).

        Notes
        -----
        Keyword arguments left as ``None`` keep their current value.
        """
        self._cfg.perf_parallel_enabled = enabled
        if backend is not None:
            self._cfg.perf_parallel_backend = backend
        if workers is not None:
            self._cfg.perf_parallel_workers = workers
        if min_batch is not None:
            self._cfg.perf_parallel_min_batch = min_batch
        if min_instances is not None:
            self._cfg.perf_parallel_min_instances = min_instances
        if tiny_workload is not None:
            self._cfg.perf_parallel_tiny_workload = tiny_workload
        if granularity is not None:
            self._cfg.perf_parallel_granularity = granularity
        return self

    def perf_telemetry(self, callback: Any | None) -> ExplainerBuilder:
        """Register a telemetry callback shared by cache and parallel executors.

        Parameters
        ----------
        callback : Any or None
            Stored as ``perf_telemetry``; ``None`` clears it.
        """
        self._cfg.perf_telemetry = callback
        return self

    def perf_feature_filter(
        self,
        enabled: bool,
        *,
        per_instance_top_k: int | None = None,
    ) -> ExplainerBuilder:
        """Configure internal FAST-based feature filtering.

        Parameters
        ----------
        enabled : bool
            Flag indicating whether the internal FAST-based feature filter is enabled.
        per_instance_top_k : int, optional
            Maximum number of features to keep per instance based on FAST weights.
            The value is coerced with ``int()`` and clamped to at least 1.
        """
        self._cfg.perf_feature_filter_enabled = enabled
        if per_instance_top_k is not None:
            self._cfg.perf_feature_filter_per_instance_top_k = max(1, int(per_instance_top_k))
        return self

    def build_config(self) -> ExplainerConfig:
        """Attach a perf factory to the working config and return it.

        The factory comes from the module-level ``_perf_from_config`` hook when
        it is set, otherwise from ``_build_perf_factory``. It is stored on the
        config as ``_perf_factory`` and is readable via
        ``ExplainerConfig.perf_factory``. Nothing changes behavior unless a
        consumer actually uses the factory.

        Factory creation is best-effort. Any ``Exception`` is logged at debug
        level and ``_perf_factory`` is set to ``None`` instead. Non-``Exception``
        errors such as ``KeyboardInterrupt`` still propagate.

        Returns
        -------
        ExplainerConfig
            The builder's own working config, not a copy. Later setter calls on
            this builder also mutate the returned object. The attached factory,
            however, reflects the settings at the time of this call.
        """
        try:
            factory_builder = _perf_from_config or _build_perf_factory
            # Stash a lightweight factory on the config for downstream wiring.
            self._cfg._perf_factory = factory_builder(self._cfg)  # type: ignore[attr-defined]
        except Exception:  # noqa: BLE001 - config building must never fail here
            import logging

            logging.getLogger(__name__).debug(
                "perf factory creation failed; continuing without it", exc_info=True
            )
            self._cfg._perf_factory = None  # type: ignore[attr-defined]
        return self._cfg
class _ConfigPerfFactory:
    """Produce cache and parallel-execution primitives from stored configs.

    Instances are created by ``_build_perf_factory`` with already resolved
    cache and parallel configuration objects. The ``make_*`` methods import
    their backend classes inside the method body, so the ``..cache`` and
    ``..parallel`` modules are only imported when a primitive is requested.
    """

    def __init__(self, cache_cfg: Any, parallel_cfg: Any) -> None:
        """Keep references to the cache and parallel configurations."""
        self._cache_cfg, self._parallel_cfg = cache_cfg, parallel_cfg

    def make_cache(self) -> Any:
        """Return a new ``CalibratorCache`` built from the stored cache config."""
        from ..cache import CalibratorCache

        cache_config = self._cache_cfg
        return CalibratorCache(cache_config)

    def make_parallel_executor(self, cache: Any | None = None) -> Any:
        """Return a new ``ParallelExecutor`` built from the stored parallel config.

        Parameters
        ----------
        cache : Any, optional
            Cache object forwarded to the executor through its ``cache`` keyword.
        """
        from ..parallel import ParallelExecutor

        executor = ParallelExecutor(self._parallel_cfg, cache=cache)
        return executor

    def make_parallel_backend(self, cache: Any | None = None) -> Any:
        """Alias for :meth:`make_parallel_executor`; forwards ``cache`` unchanged."""
        return self.make_parallel_executor(cache=cache)
def _build_perf_factory(cfg: Any) -> _ConfigPerfFactory:
    """Translate ``perf_*`` config attributes into a :class:`_ConfigPerfFactory`.

    Each attribute is read with ``getattr`` and a fallback equal to the matching
    ``ExplainerConfig`` default, so any object exposing a subset of the
    ``perf_*`` attributes is accepted. The resulting ``CacheConfig`` and
    ``ParallelConfig`` are each passed through their ``from_env`` hook
    (presumably environment-variable overrides; see those classes). The
    primitives are built directly rather than through the removed perf root
    facade.
    """
    from ..cache import CacheConfig
    from ..parallel import ParallelConfig

    def read(name: str, fallback: Any = None) -> Any:
        # Tolerant attribute access; mirrors the ExplainerConfig defaults.
        return getattr(cfg, name, fallback)

    cache_settings = {
        "enabled": read("perf_cache_enabled", False),
        "namespace": read("perf_cache_namespace", "calibrator"),
        "version": read("perf_cache_version", "v1"),
        "max_items": read("perf_cache_max_items", 512),
        "max_bytes": read("perf_cache_max_bytes", 32 * 1024 * 1024),
        "ttl_seconds": read("perf_cache_ttl"),
        "telemetry": read("perf_telemetry"),
    }
    cache_cfg = CacheConfig.from_env(CacheConfig(**cache_settings))

    parallel_settings = {
        "enabled": read("perf_parallel_enabled", False),
        "strategy": read("perf_parallel_backend", "auto"),
        "max_workers": read("perf_parallel_workers"),
        "min_batch_size": read("perf_parallel_min_batch", 8),
        "min_instances_for_parallel": read("perf_parallel_min_instances"),
        "tiny_workload_threshold": read("perf_parallel_tiny_workload"),
        "granularity": read("perf_parallel_granularity", "feature"),
        "telemetry": read("perf_telemetry"),
    }
    parallel_cfg = ParallelConfig.from_env(ParallelConfig(**parallel_settings))

    return _ConfigPerfFactory(cache_cfg=cache_cfg, parallel_cfg=parallel_cfg)
# Public API of this module; the perf-factory helpers and the test hook stay private.
__all__ = ["ExplainerConfig", "ExplainerBuilder", "TaskLiteral"]