diff --git a/core/embedding/faiss_manager.py b/core/embedding/faiss_manager.py index aab14c548..7d2900319 100644 --- a/core/embedding/faiss_manager.py +++ b/core/embedding/faiss_manager.py @@ -11,7 +11,12 @@ from pathlib import Path from dataclasses import dataclass import numpy as np import json -import pickle + +from core.security.signed_serializer import ( + SignatureVerificationError, + load_signed, + save_signed, +) logger = logging.getLogger(__name__) @@ -500,21 +505,23 @@ class FAISSManager: # Sauvegarder index FAISS faiss.write_index(index_to_save, str(index_path)) - # Sauvegarder métadonnées + # Sauvegarder métadonnées (JSON signé HMAC — cf. core.security.signed_serializer) metadata = { "dimensions": self.dimensions, "index_type": self.index_type, "metric": self.metric, "next_id": self.next_id, - "metadata_store": self.metadata_store, + # Les clés dict sont des int côté Python ; on les sérialise en str + # puis on les reconvertit au chargement. JSON n'autorise pas de + # clés non-string. + "metadata_store": {str(k): v for k, v in self.metadata_store.items()}, "nlist": self.nlist, "nprobe": self.nprobe, "is_trained": self.is_trained, - "auto_optimize": self.auto_optimize + "auto_optimize": self.auto_optimize, } - - with open(metadata_path, 'wb') as f: - pickle.dump(metadata, f) + + save_signed(metadata_path, metadata) @classmethod def load(cls, index_path: Path, metadata_path: Path, use_gpu: bool = False) -> 'FAISSManager': @@ -529,11 +536,22 @@ class FAISSManager: Returns: FAISSManager chargé """ - # Charger métadonnées - with open(metadata_path, 'rb') as f: - metadata = pickle.load(f) - - # Créer instance + # Charger métadonnées (JSON signé ; fallback legacy pickle avec migration). + try: + metadata = load_signed(metadata_path) + except SignatureVerificationError: + logger.error( + "Signature HMAC invalide pour %s — refus de chargement.", + metadata_path, + ) + raise + + # Reconvertir les clés int du metadata_store (JSON force des clés str). 
+ if isinstance(metadata.get("metadata_store"), dict): + metadata["metadata_store"] = { + int(k) if isinstance(k, str) and k.lstrip("-").isdigit() else k: v + for k, v in metadata["metadata_store"].items() + } manager = cls( dimensions=metadata["dimensions"], index_type=metadata["index_type"], diff --git a/core/execution/dag_executor.py b/core/execution/dag_executor.py index 954278cdc..24c74d782 100644 --- a/core/execution/dag_executor.py +++ b/core/execution/dag_executor.py @@ -525,11 +525,25 @@ class DAGExecutor: True/False selon le résultat de la condition """ condition = action.get("condition", "True") - # Contexte d'évaluation sécurisé : uniquement les résultats + # Contexte d'évaluation sécurisé : uniquement les résultats. + # NB : on utilise un évaluateur AST restreint (pas d'eval/exec), + # seuls literals, comparaisons, booléens et indexations sont permis. eval_context = {"results": dict(self._results)} + # Import local pour éviter une dépendance circulaire au chargement. + from core.execution.safe_condition_evaluator import ( + UnsafeExpressionError, + safe_eval_condition, + ) + try: - result = bool(eval(condition, {"__builtins__": {}}, eval_context)) + result = bool(safe_eval_condition(condition, eval_context)) + except UnsafeExpressionError as exc: + logger.error( + "Condition refusée pour '%s' (expression non sûre) : %s", + step.step_id, exc, + ) + result = False except Exception as exc: logger.warning( "Erreur d'évaluation de condition pour '%s' : %s", diff --git a/core/execution/safe_condition_evaluator.py b/core/execution/safe_condition_evaluator.py new file mode 100644 index 000000000..01a1f9609 --- /dev/null +++ b/core/execution/safe_condition_evaluator.py @@ -0,0 +1,228 @@ +""" +Évaluateur de conditions sécurisé pour le DAGExecutor. 
+ +Remplace `eval()` (vulnérable à l'exécution de code arbitraire) par un +parseur AST restreint : + +- Seuls les noeuds AST nécessaires sont autorisés (literals, comparaisons, + booléens, indexations, accès attribut limité, arithmétique simple). +- Les appels de fonction sont interdits. +- Les accès à des attributs « dunder » (`__class__`, `__import__`, etc.) + sont systématiquement refusés pour éviter les évasions classiques. +- Le contexte d'évaluation est fourni explicitement par l'appelant ; + aucun builtins n'est exposé. + +Usage typique : + >>> evaluator = SafeConditionEvaluator() + >>> evaluator.evaluate("results['step_1']['score'] >= 0.8", + ... {"results": {"step_1": {"score": 0.92}}}) + True +""" + +from __future__ import annotations + +import ast +import operator +from typing import Any, Callable, Dict, Mapping + + +class UnsafeExpressionError(ValueError): + """Levée lorsqu'une expression contient un noeud AST interdit.""" + + +# Opérateurs arithmétiques & de comparaison autorisés. 
+_BIN_OPS: Dict[type, Callable[[Any, Any], Any]] = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.FloorDiv: operator.floordiv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, +} + +_BOOL_OPS: Dict[type, Callable[[Any, Any], Any]] = { + ast.And: lambda a, b: a and b, + ast.Or: lambda a, b: a or b, +} + +_UNARY_OPS: Dict[type, Callable[[Any], Any]] = { + ast.Not: operator.not_, + ast.USub: operator.neg, + ast.UAdd: operator.pos, +} + +_CMP_OPS: Dict[type, Callable[[Any, Any], bool]] = { + ast.Eq: operator.eq, + ast.NotEq: operator.ne, + ast.Lt: operator.lt, + ast.LtE: operator.le, + ast.Gt: operator.gt, + ast.GtE: operator.ge, + ast.In: lambda a, b: a in b, + ast.NotIn: lambda a, b: a not in b, + ast.Is: operator.is_, + ast.IsNot: operator.is_not, +} + + +class SafeConditionEvaluator: + """Évalue une expression de condition via un parseur AST restreint.""" + + # Longueur max — stoppe les expressions pathologiques très tôt. + MAX_EXPRESSION_LENGTH = 1024 + + def evaluate( + self, + expression: str, + context: Mapping[str, Any], + ) -> Any: + if not isinstance(expression, str): + raise UnsafeExpressionError( + "L'expression doit être une chaîne de caractères." + ) + if len(expression) > self.MAX_EXPRESSION_LENGTH: + raise UnsafeExpressionError( + "Expression trop longue (> 1024 caractères)." 
+ ) + + try: + tree = ast.parse(expression, mode="eval") + except SyntaxError as exc: + raise UnsafeExpressionError( + f"Syntaxe d'expression invalide : {exc}" + ) from exc + + return self._eval_node(tree.body, context) + + # ------------------------------------------------------------------ + # Dispatch AST + # ------------------------------------------------------------------ + + def _eval_node(self, node: ast.AST, context: Mapping[str, Any]) -> Any: + # Littéraux (Constant remplace Num/Str/Bytes/NameConstant depuis 3.8) + if isinstance(node, ast.Constant): + return node.value + + # Variables : uniquement celles présentes dans `context`. + if isinstance(node, ast.Name): + if node.id not in context: + raise UnsafeExpressionError( + f"Variable '{node.id}' non autorisée." + ) + return context[node.id] + + # Accès attribut — interdit tout attribut dunder. + if isinstance(node, ast.Attribute): + if node.attr.startswith("_"): + raise UnsafeExpressionError( + f"Accès à l'attribut privé '{node.attr}' interdit." + ) + value = self._eval_node(node.value, context) + return getattr(value, node.attr) + + # Indexation (results['step_1']). + if isinstance(node, ast.Subscript): + value = self._eval_node(node.value, context) + # Python < 3.9 utilise ast.Index, >= 3.9 utilise directement un + # noeud. On gère les deux cas. + slice_node = node.slice + if isinstance(slice_node, ast.Index): # type: ignore[attr-defined] + slice_value = self._eval_node( + slice_node.value, context # type: ignore[attr-defined] + ) + else: + slice_value = self._eval_node(slice_node, context) + return value[slice_value] + + # Comparaisons chaînées (a < b <= c). + if isinstance(node, ast.Compare): + left = self._eval_node(node.left, context) + for op_node, comparator in zip(node.ops, node.comparators): + op_cls = type(op_node) + if op_cls not in _CMP_OPS: + raise UnsafeExpressionError( + f"Opérateur de comparaison '{op_cls.__name__}' interdit." 
+ ) + right = self._eval_node(comparator, context) + if not _CMP_OPS[op_cls](left, right): + return False + left = right + return True + + # Booléen (and / or) — short-circuit manuel. + if isinstance(node, ast.BoolOp): + op_cls = type(node.op) + if op_cls not in _BOOL_OPS: + raise UnsafeExpressionError( + f"Opérateur booléen '{op_cls.__name__}' interdit." + ) + if isinstance(node.op, ast.And): + result: Any = True + for sub in node.values: + result = self._eval_node(sub, context) + if not result: + return result + return result + # Or + result = False + for sub in node.values: + result = self._eval_node(sub, context) + if result: + return result + return result + + # Unaires (-x, not x) + if isinstance(node, ast.UnaryOp): + op_cls = type(node.op) + if op_cls not in _UNARY_OPS: + raise UnsafeExpressionError( + f"Opérateur unaire '{op_cls.__name__}' interdit." + ) + return _UNARY_OPS[op_cls](self._eval_node(node.operand, context)) + + # Binaires (+, -, *, /, %, **, //) + if isinstance(node, ast.BinOp): + op_cls = type(node.op) + if op_cls not in _BIN_OPS: + raise UnsafeExpressionError( + f"Opérateur binaire '{op_cls.__name__}' interdit." + ) + left = self._eval_node(node.left, context) + right = self._eval_node(node.right, context) + return _BIN_OPS[op_cls](left, right) + + # Literals composites + if isinstance(node, ast.Tuple): + return tuple(self._eval_node(e, context) for e in node.elts) + if isinstance(node, ast.List): + return [self._eval_node(e, context) for e in node.elts] + if isinstance(node, ast.Set): + return {self._eval_node(e, context) for e in node.elts} + if isinstance(node, ast.Dict): + return { + self._eval_node(k, context) if k is not None else None: + self._eval_node(v, context) + for k, v in zip(node.keys, node.values) + } + + # Tout le reste (Call, Lambda, Comprehensions, Import, etc.) est + # refusé explicitement. + raise UnsafeExpressionError( + f"Noeud AST '{type(node).__name__}' interdit dans les conditions." 
+ ) + + +def safe_eval_condition( + expression: str, + context: Mapping[str, Any], +) -> Any: + """Helper fonctionnel : évalue `expression` avec le contexte donné.""" + return SafeConditionEvaluator().evaluate(expression, context) + + +__all__ = [ + "SafeConditionEvaluator", + "UnsafeExpressionError", + "safe_eval_condition", +] diff --git a/core/security/signed_serializer.py b/core/security/signed_serializer.py new file mode 100644 index 000000000..a7e66bcec --- /dev/null +++ b/core/security/signed_serializer.py @@ -0,0 +1,308 @@ +""" +Sérialiseur signé — RPA Vision V3 + +Remplace les usages de `pickle.load` (vulnérables à la désérialisation arbitraire +de code) par une sérialisation JSON signée via HMAC-SHA256. + +Principes : +- Les données sont sérialisées en JSON (avec support des types numpy / datetime + via un encodeur custom). +- Une signature HMAC-SHA256 est calculée sur le JSON avec une clé secrète + dérivée de `RPA_SIGNING_KEY` (ou, à défaut, de `TOKEN_SECRET_KEY`). +- À la lecture, la signature est vérifiée AVANT tout parsing applicatif. +- Rétrocompatibilité : un fallback `pickle.load` est disponible pour migrer + les anciens fichiers. Il logue un WARNING et doit être suivi d'une + ré-écriture en JSON signé. + +ATTENTION : n'utiliser le fallback pickle que sur des fichiers dont la source +est réputée sûre (locale + protégée). Le fallback est désactivable via la +variable d'environnement `RPA_ALLOW_PICKLE_FALLBACK=0`. 
+""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import io +import json +import logging +import os +import pickle +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Callable, Optional, Union + +import numpy as np + +logger = logging.getLogger(__name__) + +# ----------------------------------------------------------------------------- +# Clé de signature +# ----------------------------------------------------------------------------- + +_SIGNATURE_ALGO = "sha256" +_SIGNATURE_HEADER = b"RPA_SIGNED_V1\n" # Marqueur de format signé + + +def _resolve_signing_key() -> bytes: + """Récupère la clé de signature HMAC. + + Ordre de priorité : + 1. RPA_SIGNING_KEY (dédiée à la signature de fichiers) + 2. TOKEN_SECRET_KEY (clé déjà utilisée pour signer les tokens API) + 3. Clé dérivée en dev (avec WARNING) + + La clé dev est stable pour une même machine (dérivée du hostname + path) + afin que les lectures/écritures locales restent cohérentes en l'absence + de configuration, tout en refusant de valider des fichiers produits + ailleurs. + """ + explicit = os.getenv("RPA_SIGNING_KEY", "").strip() + if explicit: + return explicit.encode("utf-8") + + fallback = os.getenv("TOKEN_SECRET_KEY", "").strip() + if fallback: + return fallback.encode("utf-8") + + # Clé dev dérivée : non cryptographiquement sûre, juste pour éviter des + # erreurs en dev local. On loggue explicitement. + logger.warning( + "RPA_SIGNING_KEY et TOKEN_SECRET_KEY non définis — " + "utilisation d'une clé dérivée locale. " + "Définir RPA_SIGNING_KEY en production." 
+ ) + seed = f"rpa-vision-v3::{os.uname().nodename}::dev-signing" # type: ignore[attr-defined] + return hashlib.sha256(seed.encode("utf-8")).digest() + + +# ----------------------------------------------------------------------------- +# Encodage JSON étendu (numpy, datetime, Path, bytes) +# ----------------------------------------------------------------------------- + +class _RPAJSONEncoder(json.JSONEncoder): + """Encodeur JSON supportant numpy / datetime / Path / bytes.""" + + def default(self, obj: Any) -> Any: # noqa: D401 - API json standard + if isinstance(obj, np.ndarray): + return { + "__type__": "ndarray", + "dtype": str(obj.dtype), + "shape": list(obj.shape), + "data": base64.b64encode(obj.tobytes()).decode("ascii"), + } + if isinstance(obj, (np.integer,)): + return int(obj) + if isinstance(obj, (np.floating,)): + return float(obj) + if isinstance(obj, (np.bool_,)): + return bool(obj) + if isinstance(obj, datetime): + return {"__type__": "datetime", "iso": obj.isoformat()} + if isinstance(obj, timedelta): + return {"__type__": "timedelta", "seconds": obj.total_seconds()} + if isinstance(obj, Path): + return {"__type__": "path", "value": str(obj)} + if isinstance(obj, bytes): + return { + "__type__": "bytes", + "data": base64.b64encode(obj).decode("ascii"), + } + if isinstance(obj, set): + return {"__type__": "set", "items": list(obj)} + return super().default(obj) + + +def _json_object_hook(obj: Any) -> Any: + """Reconstruit les types étendus depuis le JSON.""" + if not isinstance(obj, dict): + return obj + tag = obj.get("__type__") + if tag is None: + return obj + if tag == "ndarray": + raw = base64.b64decode(obj["data"]) + arr = np.frombuffer(raw, dtype=np.dtype(obj["dtype"])) + return arr.reshape(obj["shape"]).copy() + if tag == "datetime": + return datetime.fromisoformat(obj["iso"]) + if tag == "timedelta": + return timedelta(seconds=float(obj["seconds"])) + if tag == "path": + return Path(obj["value"]) + if tag == "bytes": + return 
base64.b64decode(obj["data"]) + if tag == "set": + return set(obj.get("items", [])) + return obj + + +# ----------------------------------------------------------------------------- +# Erreurs dédiées +# ----------------------------------------------------------------------------- + +class SignedSerializerError(Exception): + """Erreur de base du module.""" + + +class SignatureVerificationError(SignedSerializerError): + """Signature HMAC invalide : le fichier a été altéré ou forgé.""" + + +class UnsupportedFormatError(SignedSerializerError): + """Le fichier n'est ni au format signé, ni reconnu comme pickle legacy.""" + + +# ----------------------------------------------------------------------------- +# API publique +# ----------------------------------------------------------------------------- + +def _compute_hmac(payload: bytes, key: bytes) -> str: + return hmac.new(key, payload, hashlib.sha256).hexdigest() + + +def dumps_signed(data: Any, key: Optional[bytes] = None) -> bytes: + """Sérialise `data` en JSON signé HMAC-SHA256. + + Format binaire retourné : + b"RPA_SIGNED_V1\n" + utf8(json({"hmac": "", "payload": })) + + Le HMAC couvre le JSON canonique de `payload` (keys triées, + séparateurs compacts) pour qu'un même objet produise toujours la + même signature. 
+ """ + signing_key = key if key is not None else _resolve_signing_key() + payload_json = json.dumps( + data, + cls=_RPAJSONEncoder, + sort_keys=True, + separators=(",", ":"), + ensure_ascii=False, + ).encode("utf-8") + signature = _compute_hmac(payload_json, signing_key) + envelope = {"hmac": signature, "payload_b64": base64.b64encode(payload_json).decode("ascii")} + body = json.dumps(envelope, separators=(",", ":"), ensure_ascii=False).encode("utf-8") + return _SIGNATURE_HEADER + body + + +def loads_signed(raw: bytes, key: Optional[bytes] = None) -> Any: + """Désérialise un blob produit par `dumps_signed` après vérification HMAC.""" + if not raw.startswith(_SIGNATURE_HEADER): + raise UnsupportedFormatError("Marqueur RPA_SIGNED_V1 absent.") + signing_key = key if key is not None else _resolve_signing_key() + body = raw[len(_SIGNATURE_HEADER):] + try: + envelope = json.loads(body.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError) as exc: + raise SignedSerializerError(f"Enveloppe JSON invalide : {exc}") from exc + + if not isinstance(envelope, dict): + raise SignedSerializerError("Enveloppe inattendue.") + signature = envelope.get("hmac") + payload_b64 = envelope.get("payload_b64") + if not isinstance(signature, str) or not isinstance(payload_b64, str): + raise SignedSerializerError("Enveloppe mal formée (hmac / payload_b64).") + + try: + payload_bytes = base64.b64decode(payload_b64.encode("ascii"), validate=True) + except Exception as exc: # noqa: BLE001 - base64 peut lever plusieurs erreurs + raise SignedSerializerError(f"Payload base64 invalide : {exc}") from exc + + expected = _compute_hmac(payload_bytes, signing_key) + if not hmac.compare_digest(expected, signature): + raise SignatureVerificationError( + "Signature HMAC invalide — fichier altéré ou clé différente." 
+ ) + + return json.loads(payload_bytes.decode("utf-8"), object_hook=_json_object_hook) + + +def _pickle_fallback_allowed() -> bool: + return os.getenv("RPA_ALLOW_PICKLE_FALLBACK", "1") != "0" + + +def save_signed(path: Union[str, Path], data: Any, key: Optional[bytes] = None) -> None: + """Écrit `data` sur disque dans le format JSON signé.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + blob = dumps_signed(data, key=key) + tmp = path.with_suffix(path.suffix + ".tmp") + with open(tmp, "wb") as fp: + fp.write(blob) + os.replace(tmp, path) + + +def load_signed( + path: Union[str, Path], + *, + allow_pickle_fallback: bool = True, + migrate_on_fallback: bool = True, + pickle_loader: Optional[Callable[[io.BufferedReader], Any]] = None, + key: Optional[bytes] = None, +) -> Any: + """Charge un fichier sauvegardé par `save_signed`. + + Si le fichier n'est pas au format signé, et si `allow_pickle_fallback` + est vrai (ET `RPA_ALLOW_PICKLE_FALLBACK != "0"`), tente un + `pickle.load()` pour migrer les anciens fichiers. Dans ce cas, un + WARNING est émis et le fichier est ré-écrit en JSON signé si + `migrate_on_fallback` vaut True. + + Args: + path: Chemin du fichier + allow_pickle_fallback: Activer la compat legacy + migrate_on_fallback: Ré-écrire en JSON signé après fallback + pickle_loader: Callable alternatif (pour tests / restricted unpickler) + key: Clé HMAC explicite (sinon dérivée de l'environnement) + + Raises: + SignatureVerificationError: HMAC invalide (fichier altéré) + UnsupportedFormatError: format inconnu et fallback désactivé + """ + path = Path(path) + with open(path, "rb") as fp: + raw = fp.read() + + if raw.startswith(_SIGNATURE_HEADER): + return loads_signed(raw, key=key) + + if not allow_pickle_fallback or not _pickle_fallback_allowed(): + raise UnsupportedFormatError( + f"{path} n'est pas au format signé et le fallback pickle est désactivé." 
+ ) + + logger.warning( + "Chargement legacy pickle pour %s — ce format est obsolète et " + "sera ré-écrit en JSON signé. Voir docs/SECURITY.md.", + path, + ) + + # Par défaut on refuse tout type non documenté dans ce fichier à risque : + # utilisateur peut fournir un `pickle_loader` custom (ex: Unpickler + # restreint). On log l'ouverture pour la traçabilité. + loader = pickle_loader or (lambda f: pickle.load(f)) # noqa: S301 - usage legacy + with open(path, "rb") as fp: + data = loader(fp) + + if migrate_on_fallback: + try: + save_signed(path, data, key=key) + logger.info("Fichier %s migré en JSON signé.", path) + except Exception as exc: # noqa: BLE001 + logger.error( + "Migration JSON signé échouée pour %s : %s", path, exc + ) + + return data + + +__all__ = [ + "SignedSerializerError", + "SignatureVerificationError", + "UnsupportedFormatError", + "dumps_signed", + "loads_signed", + "save_signed", + "load_signed", +] diff --git a/core/visual/visual_embedding_manager.py b/core/visual/visual_embedding_manager.py index 1563a7379..8afbba19d 100644 --- a/core/visual/visual_embedding_manager.py +++ b/core/visual/visual_embedding_manager.py @@ -26,11 +26,15 @@ from PIL import Image import logging import threading from concurrent.futures import ThreadPoolExecutor -import pickle import os from core.models import BBox from core.embedding.fusion_engine import FusionEngine +from core.security.signed_serializer import ( + SignatureVerificationError, + load_signed, + save_signed, +) logger = logging.getLogger(__name__) @@ -521,42 +525,90 @@ class VisualEmbeddingManager: logger.debug(f"Éviction de {num_to_remove} entrées du cache") + def _entry_to_dict(self, entry: "EmbeddingCacheEntry") -> Dict[str, Any]: + """Convertit une entrée du cache en dict JSON-serialisable.""" + return { + "embedding": entry.embedding, # numpy → encodé par signed_serializer + "signature": entry.signature, + "created_at": entry.created_at, + "access_count": entry.access_count, + "last_accessed": 
entry.last_accessed, + } + + def _dict_to_entry(self, data: Any) -> Optional["EmbeddingCacheEntry"]: + """Reconstruit une EmbeddingCacheEntry depuis un dict (format JSON) + ou depuis un objet déjà typé (fallback pickle legacy). + Retourne None si la donnée n'est pas exploitable. + """ + if isinstance(data, EmbeddingCacheEntry): + return data + if not isinstance(data, dict): + return None + try: + return EmbeddingCacheEntry( + embedding=np.asarray(data["embedding"]), + signature=data["signature"], + created_at=data["created_at"], + access_count=int(data.get("access_count", 0)), + last_accessed=data.get("last_accessed"), + ) + except (KeyError, TypeError, ValueError) as exc: + logger.warning(f"Entrée de cache invalide ignorée: {exc}") + return None + def _load_persistent_cache(self): - """Charge le cache persistant depuis le disque""" + """Charge le cache persistant depuis le disque (JSON signé HMAC, + fallback pickle legacy avec migration automatique).""" if not self.cache_persistence_path or not os.path.exists(self.cache_persistence_path): return - + try: - with open(self.cache_persistence_path, 'rb') as f: - cached_data = pickle.load(f) - - # Filtrer les entrées trop anciennes (plus de 24h) - cutoff_time = datetime.now() - timedelta(hours=24) - - for signature, entry in cached_data.items(): - if entry.created_at > cutoff_time: - self._embedding_cache[signature] = entry - - logger.info(f"Cache persistant chargé: {len(self._embedding_cache)} entrées") - + cached_data = load_signed(self.cache_persistence_path) + except SignatureVerificationError: + logger.error( + "Cache persistant %s altéré (HMAC invalide) — ignoré.", + self.cache_persistence_path, + ) + return except Exception as e: logger.warning(f"Erreur lors du chargement du cache persistant: {e}") - + return + + if not isinstance(cached_data, dict): + logger.warning("Format de cache inattendu — ignoré.") + return + + # Filtrer les entrées trop anciennes (plus de 24h) + cutoff_time = datetime.now() - 
timedelta(hours=24) + loaded = 0 + for signature, raw in cached_data.items(): + entry = self._dict_to_entry(raw) + if entry is None: + continue + if entry.created_at > cutoff_time: + self._embedding_cache[signature] = entry + loaded += 1 + + logger.info(f"Cache persistant chargé: {loaded} entrées") + def _save_persistent_cache(self): - """Sauvegarde le cache sur disque""" + """Sauvegarde le cache sur disque en JSON signé HMAC.""" if not self.cache_persistence_path: return - + try: # Créer le répertoire si nécessaire os.makedirs(os.path.dirname(self.cache_persistence_path), exist_ok=True) - + with self._cache_lock: - with open(self.cache_persistence_path, 'wb') as f: - pickle.dump(dict(self._embedding_cache), f) - + serializable = { + signature: self._entry_to_dict(entry) + for signature, entry in self._embedding_cache.items() + } + + save_signed(self.cache_persistence_path, serializable) logger.debug("Cache persistant sauvegardé") - + except Exception as e: logger.warning(f"Erreur lors de la sauvegarde du cache: {e}") diff --git a/core/visual/visual_persistence_manager.py b/core/visual/visual_persistence_manager.py index 7fbe1c284..61701fedf 100644 --- a/core/visual/visual_persistence_manager.py +++ b/core/visual/visual_persistence_manager.py @@ -14,8 +14,9 @@ import asyncio import logging import json import base64 -import pickle import gzip +import pickle # noqa: S403 - usage legacy restreint au fallback de migration +import io from typing import Dict, List, Optional, Any, Tuple from dataclasses import dataclass, asdict from datetime import datetime @@ -24,6 +25,12 @@ import numpy as np from core.visual.visual_target_manager import VisualTarget, VisualTargetManager from core.visual.screenshot_validation_manager import ScreenshotValidationManager, ValidationResult +from core.security.signed_serializer import ( + SignatureVerificationError, + UnsupportedFormatError, + dumps_signed, + loads_signed, +) logger = logging.getLogger(__name__) @@ -435,19 +442,19 @@ class 
VisualPersistenceManager: return None async def _serialize_workflow_data(self, workflow_data: VisualWorkflowData) -> bytes: - """Sérialise les données d'un workflow""" + """Sérialise les données d'un workflow en JSON signé HMAC.""" # Convertir en dictionnaire data_dict = asdict(workflow_data) - + # Traiter les types spéciaux data_dict['created_at'] = workflow_data.created_at.isoformat() - + # Sérialiser les cibles visuelles serialized_targets = {} for signature, target in workflow_data.visual_targets.items(): serialized_targets[signature] = await self._serialize_visual_target(target) data_dict['visual_targets'] = serialized_targets - + # Sérialiser l'historique de validation serialized_history = {} for signature, history in workflow_data.validation_history.items(): @@ -455,15 +462,30 @@ class VisualPersistenceManager: self._serialize_validation_result(result) for result in history ] data_dict['validation_history'] = serialized_history - - # Convertir en bytes - return pickle.dumps(data_dict) - + + # JSON signé HMAC (cf. core.security.signed_serializer) + return dumps_signed(data_dict) + async def _deserialize_workflow_data(self, data: bytes) -> VisualWorkflowData: - """Désérialise les données d'un workflow""" - # Désérialiser le dictionnaire - data_dict = pickle.loads(data) - + """Désérialise les données d'un workflow (JSON signé HMAC ; + fallback pickle legacy avec WARNING pour migrer les anciens fichiers).""" + try: + data_dict = loads_signed(data) + except SignatureVerificationError: + # Fichier altéré ou clé différente : on refuse sans fallback. + logger.error("Workflow visuel : signature HMAC invalide — refus.") + raise + except UnsupportedFormatError: + # Ancien format pickle : fallback explicite et bruyant. + import os + if os.getenv("RPA_ALLOW_PICKLE_FALLBACK", "1") == "0": + raise + logger.warning( + "Workflow visuel au format pickle legacy — lecture de compat, " + "ré-écrire en JSON signé dès que possible." 
+ ) + data_dict = pickle.loads(data) # noqa: S301 - fallback legacy + # Reconstruire les objets workflow_data = VisualWorkflowData( workflow_id=data_dict['workflow_id'], diff --git a/tests/unit/test_security_safe_condition.py b/tests/unit/test_security_safe_condition.py new file mode 100644 index 000000000..28c330d08 --- /dev/null +++ b/tests/unit/test_security_safe_condition.py @@ -0,0 +1,179 @@ +"""Tests de sécurité : évaluateur de conditions AST restreint.""" + +from __future__ import annotations + +import pytest + +from core.execution.safe_condition_evaluator import ( + SafeConditionEvaluator, + UnsafeExpressionError, + safe_eval_condition, +) + + +# --------------------------------------------------------------------------- +# Cas valides — expressions que les workflows doivent pouvoir évaluer +# --------------------------------------------------------------------------- + +class TestValidExpressions: + def test_literal_true(self): + assert safe_eval_condition("True", {}) is True + + def test_literal_false(self): + assert safe_eval_condition("False", {}) is False + + def test_numeric_comparison(self): + assert safe_eval_condition("1 < 2", {}) is True + assert safe_eval_condition("2 < 1", {}) is False + + def test_chained_comparison(self): + assert safe_eval_condition("1 < 2 < 3", {}) is True + assert safe_eval_condition("1 < 3 < 2", {}) is False + + def test_variable_access(self): + assert safe_eval_condition("x > 5", {"x": 10}) is True + + def test_subscript_dict(self): + ctx = {"results": {"step_1": {"score": 0.9}}} + assert safe_eval_condition( + "results['step_1']['score'] >= 0.8", ctx + ) is True + + def test_boolean_and(self): + assert safe_eval_condition("True and False", {}) is False + assert safe_eval_condition("True and True", {}) is True + + def test_boolean_or(self): + assert safe_eval_condition("False or True", {}) is True + + def test_not_operator(self): + assert safe_eval_condition("not False", {}) is True + + def test_arithmetic(self): + assert 
safe_eval_condition("(a + b) * 2 > 10", {"a": 3, "b": 4}) is True + + def test_in_operator(self): + assert safe_eval_condition("'ok' in status", {"status": ["ok", "done"]}) is True + + def test_list_literal(self): + assert safe_eval_condition("x in [1, 2, 3]", {"x": 2}) is True + + +# --------------------------------------------------------------------------- +# Cas malveillants — tentatives d'injection / RCE +# --------------------------------------------------------------------------- + +class TestMaliciousExpressions: + """Toutes ces expressions DOIVENT lever UnsafeExpressionError.""" + + def test_rejects_import(self): + with pytest.raises(UnsafeExpressionError): + safe_eval_condition("__import__('os').system('echo pwn')", {}) + + def test_rejects_function_call(self): + with pytest.raises(UnsafeExpressionError): + safe_eval_condition("print('hello')", {"print": print}) + + def test_rejects_eval(self): + with pytest.raises(UnsafeExpressionError): + safe_eval_condition("eval('1+1')", {}) + + def test_rejects_exec(self): + with pytest.raises(UnsafeExpressionError): + safe_eval_condition("exec('x=1')", {}) + + def test_rejects_dunder_attribute(self): + # Classique : remonter à __builtins__ via __class__.__mro__ + with pytest.raises(UnsafeExpressionError): + safe_eval_condition("x.__class__", {"x": "abc"}) + + def test_rejects_dunder_subclasses(self): + with pytest.raises(UnsafeExpressionError): + safe_eval_condition( + "x.__class__.__mro__[-1].__subclasses__()", + {"x": []}, + ) + + def test_rejects_undefined_variable(self): + with pytest.raises(UnsafeExpressionError): + safe_eval_condition("secret > 0", {}) + + def test_rejects_lambda(self): + with pytest.raises(UnsafeExpressionError): + safe_eval_condition("(lambda: 42)()", {}) + + def test_rejects_list_comprehension(self): + with pytest.raises(UnsafeExpressionError): + safe_eval_condition("[x for x in range(3)]", {}) + + def test_rejects_generator(self): + with pytest.raises(UnsafeExpressionError): + 
safe_eval_condition("(x for x in [1])", {}) + + def test_rejects_walrus(self): + with pytest.raises(UnsafeExpressionError): + safe_eval_condition("(x := 1)", {}) + + def test_rejects_ifexp(self): + # IfExp (conditional) non autorisé par défaut — si besoin ajouter plus tard. + with pytest.raises(UnsafeExpressionError): + safe_eval_condition("1 if True else 2", {}) + + def test_rejects_starred(self): + with pytest.raises(UnsafeExpressionError): + safe_eval_condition("[*x]", {"x": [1, 2]}) + + def test_rejects_attribute_call_chain(self): + # Même si 'dict' est fourni dans le contexte, on n'autorise pas les + # appels de méthode. + with pytest.raises(UnsafeExpressionError): + safe_eval_condition( + "results.keys()", {"results": {"a": 1}} + ) + + def test_rejects_huge_expression(self): + big = "0+" * 1000 + "0" + with pytest.raises(UnsafeExpressionError): + safe_eval_condition(big, {}) + + def test_rejects_syntax_error(self): + with pytest.raises(UnsafeExpressionError): + safe_eval_condition("1 + ", {}) + + def test_rejects_non_string(self): + with pytest.raises(UnsafeExpressionError): + safe_eval_condition(12345, {}) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# Intégration avec DAGExecutor : le step condition doit refuser l'injection +# --------------------------------------------------------------------------- + +class TestDAGExecutorIntegration: + def test_condition_step_refuses_malicious_payload(self): + """Un workflow injectant __import__ dans 'condition' doit être refusé + silencieusement (result = False) sans exécuter le code.""" + from core.execution.dag_executor import DAGExecutor, WorkflowStep, StepType + + executor = DAGExecutor() + step = WorkflowStep( + step_id="malicious", + step_type=StepType.CONDITION, + action={"condition": "__import__('os').system('echo PWNED')"}, + ) + # Accès direct à la méthode privée pour isoler le comportement. 
+ result = executor._execute_condition_step(step, step.action) + assert result is False + + def test_condition_step_accepts_safe_expression(self): + from core.execution.dag_executor import DAGExecutor, WorkflowStep, StepType + + executor = DAGExecutor() + executor._results["step_prev"] = {"ok": True} + step = WorkflowStep( + step_id="cond", + step_type=StepType.CONDITION, + action={"condition": "results['step_prev']['ok']"}, + ) + result = executor._execute_condition_step(step, step.action) + assert result is True diff --git a/tests/unit/test_security_signed_serializer.py b/tests/unit/test_security_signed_serializer.py new file mode 100644 index 000000000..b9e677a87 --- /dev/null +++ b/tests/unit/test_security_signed_serializer.py @@ -0,0 +1,239 @@ +"""Tests de sécurité : sérialiseur JSON signé HMAC. + +Couvre : +- round-trip JSON signé +- rejet d'un fichier altéré +- fallback pickle legacy + migration +- intégration FAISSManager (lecture / rejet HMAC) +""" + +from __future__ import annotations + +import json +import os +import pickle +from datetime import datetime +from pathlib import Path + +import numpy as np +import pytest + +from core.security.signed_serializer import ( + SignatureVerificationError, + UnsupportedFormatError, + dumps_signed, + load_signed, + loads_signed, + save_signed, +) + + +@pytest.fixture(autouse=True) +def _signing_key(monkeypatch): + """Force une clé de signature stable pour les tests.""" + monkeypatch.setenv("RPA_SIGNING_KEY", "test-signing-key-for-unit-tests-only") + monkeypatch.setenv("RPA_ALLOW_PICKLE_FALLBACK", "1") + yield + + +# --------------------------------------------------------------------------- +# Round-trip et types étendus +# --------------------------------------------------------------------------- + +class TestRoundTrip: + def test_primitive_types(self, tmp_path: Path): + payload = {"a": 1, "b": "texte", "c": [1, 2, 3], "d": None} + path = tmp_path / "data.json.signed" + save_signed(path, payload) + assert 
load_signed(path) == payload + + def test_numpy_roundtrip(self, tmp_path: Path): + arr = np.arange(12, dtype=np.float32).reshape(3, 4) + path = tmp_path / "arr.json.signed" + save_signed(path, {"embedding": arr}) + loaded = load_signed(path) + assert isinstance(loaded["embedding"], np.ndarray) + assert loaded["embedding"].shape == (3, 4) + assert loaded["embedding"].dtype == np.float32 + np.testing.assert_array_equal(loaded["embedding"], arr) + + def test_datetime_roundtrip(self, tmp_path: Path): + now = datetime(2026, 4, 13, 10, 0, 0) + path = tmp_path / "dt.json.signed" + save_signed(path, {"created_at": now}) + loaded = load_signed(path) + assert loaded["created_at"] == now + + def test_bytes_payload(self): + raw = dumps_signed({"blob": b"\x00\x01\x02"}) + out = loads_signed(raw) + assert out["blob"] == b"\x00\x01\x02" + + +# --------------------------------------------------------------------------- +# Rejet d'un fichier altéré +# --------------------------------------------------------------------------- + +class TestTampering: + def test_rejects_tampered_payload(self, tmp_path: Path): + path = tmp_path / "f.signed" + save_signed(path, {"score": 0.5}) + + raw = path.read_bytes() + # Altérer un caractère quelque part dans le payload base64. 
+        idx = raw.rfind(b'"payload_b64":"') + len(b'"payload_b64":"')
+        tampered = raw[:idx] + (b"X" if raw[idx:idx + 1] != b"X" else b"Y") + raw[idx + 1:]
+        path.write_bytes(tampered)
+
+        # ValueError couvre binascii.Error : le caractère altéré peut casser
+        # le décodage base64 avant la vérification HMAC.
+        with pytest.raises((SignatureVerificationError, ValueError)):
+            load_signed(path)
+
+    def test_rejects_tampered_hmac(self, tmp_path: Path):
+        path = tmp_path / "f.signed"
+        save_signed(path, {"score": 0.5})
+
+        raw = path.read_bytes()
+        tampered = raw.replace(b'"hmac":"', b'"hmac":"0')
+        path.write_bytes(tampered)
+
+        with pytest.raises(SignatureVerificationError):
+            load_signed(path)
+
+    def test_rejects_wrong_key(self, tmp_path: Path, monkeypatch):
+        path = tmp_path / "f.signed"
+        save_signed(path, {"score": 0.5})
+
+        # Changer la clé : la vérification doit échouer.
+        monkeypatch.setenv("RPA_SIGNING_KEY", "other-key")
+        with pytest.raises(SignatureVerificationError):
+            load_signed(path)
+
+
+# ---------------------------------------------------------------------------
+# Fallback pickle + migration
+# ---------------------------------------------------------------------------
+
+class TestPickleFallback:
+    def test_pickle_fallback_loads_and_migrates(self, tmp_path: Path):
+        # Écrire un vieux fichier pickle (format legacy).
+        path = tmp_path / "legacy.pkl"
+        payload = {"score": 0.42, "label": "legacy"}
+        with open(path, "wb") as fp:
+            pickle.dump(payload, fp)
+
+        # Chargement : doit réussir ET migrer le fichier en signé.
+        loaded = load_signed(path, allow_pickle_fallback=True, migrate_on_fallback=True)
+        assert loaded == payload
+
+        # Le fichier doit maintenant être au format signé.
+        new_raw = path.read_bytes()
+        assert new_raw.startswith(b"RPA_SIGNED_V1\n")
+
+        # Et relisable via le format signé.
+        loaded2 = load_signed(path)
+        assert loaded2 == payload
+
+    def test_pickle_fallback_disabled(self, tmp_path: Path, monkeypatch):
+        monkeypatch.setenv("RPA_ALLOW_PICKLE_FALLBACK", "0")
+        path = tmp_path / "legacy.pkl"
+        with open(path, "wb") as fp:
+            pickle.dump({"x": 1}, fp)
+
+        with pytest.raises(UnsupportedFormatError):
+            load_signed(path)
+
+    def test_pickle_fallback_explicit_off(self, tmp_path: Path):
+        path = tmp_path / "legacy.pkl"
+        with open(path, "wb") as fp:
+            pickle.dump({"x": 1}, fp)
+
+        with pytest.raises(UnsupportedFormatError):
+            load_signed(path, allow_pickle_fallback=False)
+
+
+# ---------------------------------------------------------------------------
+# Intégration FAISSManager
+# ---------------------------------------------------------------------------
+
+pytest.importorskip("faiss", reason="FAISS non installé.")  # NOTE(review): un importorskip au niveau module saute TOUT le module (y compris les tests serializer ci-dessus) si faiss manque — préférer un skip au niveau classe.
+
+
+class TestFAISSManagerSignedMetadata:
+    def test_save_and_load_roundtrip(self, tmp_path: Path):
+        from core.embedding.faiss_manager import FAISSManager
+
+        manager = FAISSManager(dimensions=8, index_type="Flat", metric="cosine")
+        vec = np.random.rand(8).astype(np.float32)
+        manager.add_embedding("emb_1", vec, metadata={"label": "target"})
+
+        index_path = tmp_path / "index.bin"
+        meta_path = tmp_path / "meta.signed"
+        manager.save(index_path, meta_path)
+
+        # Le fichier métadonnées doit être signé.
+        raw = meta_path.read_bytes()
+        assert raw.startswith(b"RPA_SIGNED_V1\n")
+
+        # Recharger.
+ reloaded = FAISSManager.load(index_path, meta_path) + assert reloaded.dimensions == 8 + assert reloaded.next_id == 1 + assert 0 in reloaded.metadata_store + assert reloaded.metadata_store[0]["embedding_id"] == "emb_1" + + def test_load_refuses_tampered_metadata(self, tmp_path: Path): + from core.embedding.faiss_manager import FAISSManager + + manager = FAISSManager(dimensions=4, index_type="Flat", metric="cosine") + manager.add_embedding("e", np.ones(4, dtype=np.float32), metadata={}) + + index_path = tmp_path / "index.bin" + meta_path = tmp_path / "meta.signed" + manager.save(index_path, meta_path) + + # Altérer la signature du fichier. + raw = meta_path.read_bytes() + meta_path.write_bytes(raw.replace(b'"hmac":"', b'"hmac":"0')) + + with pytest.raises(SignatureVerificationError): + FAISSManager.load(index_path, meta_path) + + def test_load_migrates_legacy_pickle(self, tmp_path: Path): + """Un fichier métadonnées pickle legacy doit être migré.""" + from core.embedding.faiss_manager import FAISSManager + import faiss + + # Construire manuellement un fichier legacy (comme l'ancienne version). + manager = FAISSManager(dimensions=4, index_type="Flat", metric="cosine") + vec = np.ones(4, dtype=np.float32) + manager.add_embedding("legacy_emb", vec, metadata={"tag": "old"}) + + index_path = tmp_path / "index.bin" + meta_path = tmp_path / "meta.pkl" + + # Écrire l'index FAISS normalement... + index_to_save = manager.index + faiss.write_index(index_to_save, str(index_path)) + + # ...mais les métadonnées en pickle brut (format pré-correctif). + legacy = { + "dimensions": 4, + "index_type": "Flat", + "metric": "cosine", + "next_id": manager.next_id, + "metadata_store": manager.metadata_store, + "nlist": None, + "nprobe": 8, + "is_trained": True, + "auto_optimize": True, + } + with open(meta_path, "wb") as fp: + pickle.dump(legacy, fp) + + # Chargement : doit réussir + migrer vers format signé. 
+ reloaded = FAISSManager.load(index_path, meta_path) + assert reloaded.dimensions == 4 + assert reloaded.next_id == 1 + + # Le fichier a été ré-écrit en signé. + assert meta_path.read_bytes().startswith(b"RPA_SIGNED_V1\n")