feat(security): eval()→AST parseur + pickle→JSON+HMAC signé
Vulnérabilité 1 — eval() dans DAG executor : - Nouveau module safe_condition_evaluator.py - Parseur AST avec whitelist (Constants, Names, Compare, BoolOp, BinOp) - Rejet explicite Call/Lambda/Import/__dunder__/walrus/comprehensions - Expression non sûre → logged ERROR + évaluée à False (pas de crash) - 31 tests (12 valides, 17 malveillantes rejetées, 2 intégration) Vulnérabilité 2 — 3× pickle.load() non sécurisés : - Nouveau module signed_serializer.py (JSON+HMAC-SHA256) - Format : RPA_SIGNED_V1\\n + JSON(hmac + payload base64) - Migration automatique transparente au premier chargement - Fallback pickle avec WARNING (désactivable RPA_ALLOW_PICKLE_FALLBACK=0) - Remplacement dans faiss_manager, visual_embedding_manager, visual_persistence_manager - 13 tests Clé signature : RPA_SIGNING_KEY (fallback TOKEN_SECRET_KEY puis hostname-derived). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -11,7 +11,12 @@ from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import json
|
||||
import pickle
|
||||
|
||||
from core.security.signed_serializer import (
|
||||
SignatureVerificationError,
|
||||
load_signed,
|
||||
save_signed,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -500,21 +505,23 @@ class FAISSManager:
|
||||
# Sauvegarder index FAISS
|
||||
faiss.write_index(index_to_save, str(index_path))
|
||||
|
||||
# Sauvegarder métadonnées
|
||||
# Sauvegarder métadonnées (JSON signé HMAC — cf. core.security.signed_serializer)
|
||||
metadata = {
|
||||
"dimensions": self.dimensions,
|
||||
"index_type": self.index_type,
|
||||
"metric": self.metric,
|
||||
"next_id": self.next_id,
|
||||
"metadata_store": self.metadata_store,
|
||||
# Les clés dict sont des int côté Python ; on les sérialise en str
|
||||
# puis on les reconvertit au chargement. JSON n'autorise pas de
|
||||
# clés non-string.
|
||||
"metadata_store": {str(k): v for k, v in self.metadata_store.items()},
|
||||
"nlist": self.nlist,
|
||||
"nprobe": self.nprobe,
|
||||
"is_trained": self.is_trained,
|
||||
"auto_optimize": self.auto_optimize
|
||||
"auto_optimize": self.auto_optimize,
|
||||
}
|
||||
|
||||
with open(metadata_path, 'wb') as f:
|
||||
pickle.dump(metadata, f)
|
||||
save_signed(metadata_path, metadata)
|
||||
|
||||
@classmethod
|
||||
def load(cls, index_path: Path, metadata_path: Path, use_gpu: bool = False) -> 'FAISSManager':
|
||||
@@ -529,11 +536,22 @@ class FAISSManager:
|
||||
Returns:
|
||||
FAISSManager chargé
|
||||
"""
|
||||
# Charger métadonnées
|
||||
with open(metadata_path, 'rb') as f:
|
||||
metadata = pickle.load(f)
|
||||
# Charger métadonnées (JSON signé ; fallback legacy pickle avec migration).
|
||||
try:
|
||||
metadata = load_signed(metadata_path)
|
||||
except SignatureVerificationError:
|
||||
logger.error(
|
||||
"Signature HMAC invalide pour %s — refus de chargement.",
|
||||
metadata_path,
|
||||
)
|
||||
raise
|
||||
|
||||
# Créer instance
|
||||
# Reconvertir les clés int du metadata_store (JSON force des clés str).
|
||||
if isinstance(metadata.get("metadata_store"), dict):
|
||||
metadata["metadata_store"] = {
|
||||
int(k) if isinstance(k, str) and k.lstrip("-").isdigit() else k: v
|
||||
for k, v in metadata["metadata_store"].items()
|
||||
}
|
||||
manager = cls(
|
||||
dimensions=metadata["dimensions"],
|
||||
index_type=metadata["index_type"],
|
||||
|
||||
@@ -525,11 +525,25 @@ class DAGExecutor:
|
||||
True/False selon le résultat de la condition
|
||||
"""
|
||||
condition = action.get("condition", "True")
|
||||
# Contexte d'évaluation sécurisé : uniquement les résultats
|
||||
# Contexte d'évaluation sécurisé : uniquement les résultats.
|
||||
# NB : on utilise un évaluateur AST restreint (pas d'eval/exec),
|
||||
# seuls literals, comparaisons, booléens et indexations sont permis.
|
||||
eval_context = {"results": dict(self._results)}
|
||||
|
||||
# Import local pour éviter une dépendance circulaire au chargement.
|
||||
from core.execution.safe_condition_evaluator import (
|
||||
UnsafeExpressionError,
|
||||
safe_eval_condition,
|
||||
)
|
||||
|
||||
try:
|
||||
result = bool(eval(condition, {"__builtins__": {}}, eval_context))
|
||||
result = bool(safe_eval_condition(condition, eval_context))
|
||||
except UnsafeExpressionError as exc:
|
||||
logger.error(
|
||||
"Condition refusée pour '%s' (expression non sûre) : %s",
|
||||
step.step_id, exc,
|
||||
)
|
||||
result = False
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Erreur d'évaluation de condition pour '%s' : %s",
|
||||
|
||||
228
core/execution/safe_condition_evaluator.py
Normal file
228
core/execution/safe_condition_evaluator.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Évaluateur de conditions sécurisé pour le DAGExecutor.
|
||||
|
||||
Remplace `eval()` (vulnérable à l'exécution de code arbitraire) par un
|
||||
parseur AST restreint :
|
||||
|
||||
- Seuls les noeuds AST nécessaires sont autorisés (literals, comparaisons,
|
||||
booléens, indexations, accès attribut limité, arithmétique simple).
|
||||
- Les appels de fonction sont interdits.
|
||||
- Les accès à des attributs « dunder » (`__class__`, `__import__`, etc.)
|
||||
sont systématiquement refusés pour éviter les évasions classiques.
|
||||
- Le contexte d'évaluation est fourni explicitement par l'appelant ;
|
||||
aucun builtins n'est exposé.
|
||||
|
||||
Usage typique :
|
||||
>>> evaluator = SafeConditionEvaluator()
|
||||
>>> evaluator.evaluate("results['step_1']['score'] >= 0.8",
|
||||
... {"results": {"step_1": {"score": 0.92}}})
|
||||
True
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import operator
|
||||
from typing import Any, Callable, Dict, Mapping
|
||||
|
||||
|
||||
class UnsafeExpressionError(ValueError):
    """Raised when an expression contains a forbidden AST construct."""
|
||||
|
||||
|
||||
# Opérateurs arithmétiques & de comparaison autorisés.
|
||||
_BIN_OPS: Dict[type, Callable[[Any, Any], Any]] = {
|
||||
ast.Add: operator.add,
|
||||
ast.Sub: operator.sub,
|
||||
ast.Mult: operator.mul,
|
||||
ast.Div: operator.truediv,
|
||||
ast.FloorDiv: operator.floordiv,
|
||||
ast.Mod: operator.mod,
|
||||
ast.Pow: operator.pow,
|
||||
}
|
||||
|
||||
_BOOL_OPS: Dict[type, Callable[[Any, Any], Any]] = {
|
||||
ast.And: lambda a, b: a and b,
|
||||
ast.Or: lambda a, b: a or b,
|
||||
}
|
||||
|
||||
_UNARY_OPS: Dict[type, Callable[[Any], Any]] = {
|
||||
ast.Not: operator.not_,
|
||||
ast.USub: operator.neg,
|
||||
ast.UAdd: operator.pos,
|
||||
}
|
||||
|
||||
_CMP_OPS: Dict[type, Callable[[Any, Any], bool]] = {
|
||||
ast.Eq: operator.eq,
|
||||
ast.NotEq: operator.ne,
|
||||
ast.Lt: operator.lt,
|
||||
ast.LtE: operator.le,
|
||||
ast.Gt: operator.gt,
|
||||
ast.GtE: operator.ge,
|
||||
ast.In: lambda a, b: a in b,
|
||||
ast.NotIn: lambda a, b: a not in b,
|
||||
ast.Is: operator.is_,
|
||||
ast.IsNot: operator.is_not,
|
||||
}
|
||||
|
||||
|
||||
class SafeConditionEvaluator:
    """Evaluate a condition expression through a restricted AST walker.

    Only the following constructs are allowed: literals, variable lookups
    resolved from the caller-supplied context, chained comparisons, boolean
    logic (with Python's short-circuit semantics), simple arithmetic,
    subscripting, non-underscore attribute access and composite literals
    (tuple/list/set/dict). Everything else — calls, lambdas, comprehensions,
    dunder access, etc. — raises UnsafeExpressionError.
    """

    # Early cut-off for pathological expressions.
    MAX_EXPRESSION_LENGTH = 1024
    # Security hardening: bound the magnitude of '**' exponents so an
    # expression like "9 ** 9 ** 9" cannot trigger a CPU/memory DoS via
    # huge integer exponentiation (the previous whitelist let it through).
    MAX_POW_EXPONENT = 64

    # Operator whitelists kept on the class so the evaluator is fully
    # self-contained (no reliance on module-level tables).
    _BIN = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    _UNARY = {
        ast.Not: operator.not_,
        ast.USub: operator.neg,
        ast.UAdd: operator.pos,
    }
    _CMP = {
        ast.Eq: operator.eq,
        ast.NotEq: operator.ne,
        ast.Lt: operator.lt,
        ast.LtE: operator.le,
        ast.Gt: operator.gt,
        ast.GtE: operator.ge,
        ast.In: lambda a, b: a in b,
        ast.NotIn: lambda a, b: a not in b,
        ast.Is: operator.is_,
        ast.IsNot: operator.is_not,
    }

    def evaluate(
        self,
        expression: str,
        context: Mapping[str, Any],
    ) -> Any:
        """Parse `expression` and evaluate it against `context`.

        Args:
            expression: Source text of the condition (mode "eval").
            context: Names visible to the expression; nothing else is.

        Returns:
            The expression's value (any type — callers typically bool() it).

        Raises:
            UnsafeExpressionError: non-string input, oversized expression,
                syntax error, or any forbidden AST node / operator.
        """
        if not isinstance(expression, str):
            raise UnsafeExpressionError(
                "L'expression doit être une chaîne de caractères."
            )
        if len(expression) > self.MAX_EXPRESSION_LENGTH:
            # Message derived from the constant so they cannot drift apart
            # (the previous message hardcoded "1024").
            raise UnsafeExpressionError(
                f"Expression trop longue (> {self.MAX_EXPRESSION_LENGTH} caractères)."
            )

        try:
            tree = ast.parse(expression, mode="eval")
        except SyntaxError as exc:
            raise UnsafeExpressionError(
                f"Syntaxe d'expression invalide : {exc}"
            ) from exc

        return self._eval_node(tree.body, context)

    # ------------------------------------------------------------------
    # AST dispatch
    # ------------------------------------------------------------------

    def _eval_node(self, node: ast.AST, context: Mapping[str, Any]) -> Any:
        # Literals (Constant replaces Num/Str/Bytes/NameConstant since 3.8).
        if isinstance(node, ast.Constant):
            return node.value

        # Variables: only those explicitly present in `context`.
        if isinstance(node, ast.Name):
            if node.id not in context:
                raise UnsafeExpressionError(
                    f"Variable '{node.id}' non autorisée."
                )
            return context[node.id]

        # Attribute access — any leading-underscore attribute is refused,
        # which blocks the classic __class__/__globals__ escape chains.
        if isinstance(node, ast.Attribute):
            if node.attr.startswith("_"):
                raise UnsafeExpressionError(
                    f"Accès à l'attribut privé '{node.attr}' interdit."
                )
            value = self._eval_node(node.value, context)
            return getattr(value, node.attr)

        # Subscripting (results['step_1']).
        if isinstance(node, ast.Subscript):
            value = self._eval_node(node.value, context)
            slice_node = node.slice
            # Python < 3.9 wraps the index in ast.Index; newer versions use
            # the value node directly, and ast.Index may eventually vanish,
            # hence the getattr lookup.
            legacy_index = getattr(ast, "Index", None)
            if legacy_index is not None and isinstance(slice_node, legacy_index):
                slice_value = self._eval_node(
                    slice_node.value, context  # type: ignore[attr-defined]
                )
            else:
                slice_value = self._eval_node(slice_node, context)
            return value[slice_value]

        # Chained comparisons (a < b <= c).
        if isinstance(node, ast.Compare):
            left = self._eval_node(node.left, context)
            for op_node, comparator in zip(node.ops, node.comparators):
                op_cls = type(op_node)
                if op_cls not in self._CMP:
                    raise UnsafeExpressionError(
                        f"Opérateur de comparaison '{op_cls.__name__}' interdit."
                    )
                right = self._eval_node(comparator, context)
                if not self._CMP[op_cls](left, right):
                    return False
                left = right
            return True

        # Boolean and/or — manual short-circuit, like real Python.
        if isinstance(node, ast.BoolOp):
            if not isinstance(node.op, (ast.And, ast.Or)):
                op_cls = type(node.op)
                raise UnsafeExpressionError(
                    f"Opérateur booléen '{op_cls.__name__}' interdit."
                )
            if isinstance(node.op, ast.And):
                result: Any = True
                for sub in node.values:
                    result = self._eval_node(sub, context)
                    if not result:
                        return result
                return result
            # Or
            result = False
            for sub in node.values:
                result = self._eval_node(sub, context)
                if result:
                    return result
            return result

        # Unary (-x, not x).
        if isinstance(node, ast.UnaryOp):
            op_cls = type(node.op)
            if op_cls not in self._UNARY:
                raise UnsafeExpressionError(
                    f"Opérateur unaire '{op_cls.__name__}' interdit."
                )
            return self._UNARY[op_cls](self._eval_node(node.operand, context))

        # Binary (+, -, *, /, %, **, //).
        if isinstance(node, ast.BinOp):
            op_cls = type(node.op)
            if op_cls not in self._BIN:
                raise UnsafeExpressionError(
                    f"Opérateur binaire '{op_cls.__name__}' interdit."
                )
            left = self._eval_node(node.left, context)
            right = self._eval_node(node.right, context)
            if op_cls is ast.Pow and isinstance(right, (int, float)):
                # DoS guard: cap the exponent, and cap the resulting
                # integer size for nested pows ((2**64)**64**... grows
                # doubly exponentially even with bounded exponents).
                if abs(right) > self.MAX_POW_EXPONENT:
                    raise UnsafeExpressionError(
                        "Exposant de '**' trop grand (protection DoS)."
                    )
                if (
                    isinstance(left, int)
                    and isinstance(right, int)
                    and right > 0
                    and max(left.bit_length(), 1) * right
                    > 64 * self.MAX_POW_EXPONENT
                ):
                    raise UnsafeExpressionError(
                        "Résultat de '**' trop grand (protection DoS)."
                    )
            return self._BIN[op_cls](left, right)

        # Composite literals.
        if isinstance(node, ast.Tuple):
            return tuple(self._eval_node(e, context) for e in node.elts)
        if isinstance(node, ast.List):
            return [self._eval_node(e, context) for e in node.elts]
        if isinstance(node, ast.Set):
            return {self._eval_node(e, context) for e in node.elts}
        if isinstance(node, ast.Dict):
            return {
                self._eval_node(k, context) if k is not None else None:
                self._eval_node(v, context)
                for k, v in zip(node.keys, node.values)
            }

        # Everything else (Call, Lambda, comprehensions, walrus, etc.)
        # is explicitly refused.
        raise UnsafeExpressionError(
            f"Noeud AST '{type(node).__name__}' interdit dans les conditions."
        )
|
||||
|
||||
|
||||
def safe_eval_condition(
    expression: str,
    context: Mapping[str, Any],
) -> Any:
    """Functional helper: evaluate `expression` with the given context."""
    evaluator = SafeConditionEvaluator()
    return evaluator.evaluate(expression, context)
|
||||
|
||||
|
||||
# Explicit public surface of the module.
__all__ = [
    "SafeConditionEvaluator",
    "UnsafeExpressionError",
    "safe_eval_condition",
]
|
||||
308
core/security/signed_serializer.py
Normal file
308
core/security/signed_serializer.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
Sérialiseur signé — RPA Vision V3
|
||||
|
||||
Remplace les usages de `pickle.load` (vulnérables à la désérialisation arbitraire
|
||||
de code) par une sérialisation JSON signée via HMAC-SHA256.
|
||||
|
||||
Principes :
|
||||
- Les données sont sérialisées en JSON (avec support des types numpy / datetime
|
||||
via un encodeur custom).
|
||||
- Une signature HMAC-SHA256 est calculée sur le JSON avec une clé secrète
|
||||
dérivée de `RPA_SIGNING_KEY` (ou, à défaut, de `TOKEN_SECRET_KEY`).
|
||||
- À la lecture, la signature est vérifiée AVANT tout parsing applicatif.
|
||||
- Rétrocompatibilité : un fallback `pickle.load` est disponible pour migrer
|
||||
les anciens fichiers. Il logue un WARNING et doit être suivi d'une
|
||||
ré-écriture en JSON signé.
|
||||
|
||||
ATTENTION : n'utiliser le fallback pickle que sur des fichiers dont la source
|
||||
est réputée sûre (locale + protégée). Le fallback est désactivable via la
|
||||
variable d'environnement `RPA_ALLOW_PICKLE_FALLBACK=0`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Clé de signature
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Signed-format constants.
# NOTE(review): _SIGNATURE_ALGO is not referenced anywhere in this module —
# _compute_hmac hardcodes hashlib.sha256. Either wire it up or remove it.
_SIGNATURE_ALGO = "sha256"
_SIGNATURE_HEADER = b"RPA_SIGNED_V1\n"  # Magic prefix marking the signed on-disk format.
|
||||
|
||||
|
||||
def _resolve_signing_key() -> bytes:
|
||||
"""Récupère la clé de signature HMAC.
|
||||
|
||||
Ordre de priorité :
|
||||
1. RPA_SIGNING_KEY (dédiée à la signature de fichiers)
|
||||
2. TOKEN_SECRET_KEY (clé déjà utilisée pour signer les tokens API)
|
||||
3. Clé dérivée en dev (avec WARNING)
|
||||
|
||||
La clé dev est stable pour une même machine (dérivée du hostname + path)
|
||||
afin que les lectures/écritures locales restent cohérentes en l'absence
|
||||
de configuration, tout en refusant de valider des fichiers produits
|
||||
ailleurs.
|
||||
"""
|
||||
explicit = os.getenv("RPA_SIGNING_KEY", "").strip()
|
||||
if explicit:
|
||||
return explicit.encode("utf-8")
|
||||
|
||||
fallback = os.getenv("TOKEN_SECRET_KEY", "").strip()
|
||||
if fallback:
|
||||
return fallback.encode("utf-8")
|
||||
|
||||
# Clé dev dérivée : non cryptographiquement sûre, juste pour éviter des
|
||||
# erreurs en dev local. On loggue explicitement.
|
||||
logger.warning(
|
||||
"RPA_SIGNING_KEY et TOKEN_SECRET_KEY non définis — "
|
||||
"utilisation d'une clé dérivée locale. "
|
||||
"Définir RPA_SIGNING_KEY en production."
|
||||
)
|
||||
seed = f"rpa-vision-v3::{os.uname().nodename}::dev-signing" # type: ignore[attr-defined]
|
||||
return hashlib.sha256(seed.encode("utf-8")).digest()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Encodage JSON étendu (numpy, datetime, Path, bytes)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
class _RPAJSONEncoder(json.JSONEncoder):
|
||||
"""Encodeur JSON supportant numpy / datetime / Path / bytes."""
|
||||
|
||||
def default(self, obj: Any) -> Any: # noqa: D401 - API json standard
|
||||
if isinstance(obj, np.ndarray):
|
||||
return {
|
||||
"__type__": "ndarray",
|
||||
"dtype": str(obj.dtype),
|
||||
"shape": list(obj.shape),
|
||||
"data": base64.b64encode(obj.tobytes()).decode("ascii"),
|
||||
}
|
||||
if isinstance(obj, (np.integer,)):
|
||||
return int(obj)
|
||||
if isinstance(obj, (np.floating,)):
|
||||
return float(obj)
|
||||
if isinstance(obj, (np.bool_,)):
|
||||
return bool(obj)
|
||||
if isinstance(obj, datetime):
|
||||
return {"__type__": "datetime", "iso": obj.isoformat()}
|
||||
if isinstance(obj, timedelta):
|
||||
return {"__type__": "timedelta", "seconds": obj.total_seconds()}
|
||||
if isinstance(obj, Path):
|
||||
return {"__type__": "path", "value": str(obj)}
|
||||
if isinstance(obj, bytes):
|
||||
return {
|
||||
"__type__": "bytes",
|
||||
"data": base64.b64encode(obj).decode("ascii"),
|
||||
}
|
||||
if isinstance(obj, set):
|
||||
return {"__type__": "set", "items": list(obj)}
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
def _json_object_hook(obj: Any) -> Any:
|
||||
"""Reconstruit les types étendus depuis le JSON."""
|
||||
if not isinstance(obj, dict):
|
||||
return obj
|
||||
tag = obj.get("__type__")
|
||||
if tag is None:
|
||||
return obj
|
||||
if tag == "ndarray":
|
||||
raw = base64.b64decode(obj["data"])
|
||||
arr = np.frombuffer(raw, dtype=np.dtype(obj["dtype"]))
|
||||
return arr.reshape(obj["shape"]).copy()
|
||||
if tag == "datetime":
|
||||
return datetime.fromisoformat(obj["iso"])
|
||||
if tag == "timedelta":
|
||||
return timedelta(seconds=float(obj["seconds"]))
|
||||
if tag == "path":
|
||||
return Path(obj["value"])
|
||||
if tag == "bytes":
|
||||
return base64.b64decode(obj["data"])
|
||||
if tag == "set":
|
||||
return set(obj.get("items", []))
|
||||
return obj
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Erreurs dédiées
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
class SignedSerializerError(Exception):
    """Base error for this module."""


class SignatureVerificationError(SignedSerializerError):
    """HMAC signature mismatch: the file was tampered with or forged."""


class UnsupportedFormatError(SignedSerializerError):
    """The file is neither signed-format nor recognized legacy pickle."""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# API publique
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
def _compute_hmac(payload: bytes, key: bytes) -> str:
|
||||
return hmac.new(key, payload, hashlib.sha256).hexdigest()
|
||||
|
||||
|
||||
def dumps_signed(data: Any, key: Optional[bytes] = None) -> bytes:
    """Serialize `data` as HMAC-SHA256-signed JSON.

    Returned binary format:
        b"RPA_SIGNED_V1\\n" + utf8(json({"hmac": "<hex>", "payload_b64": <b64>}))

    The HMAC covers the canonical JSON of the payload (sorted keys,
    compact separators) so the same object always yields the same
    signature.
    """
    signing_key = _resolve_signing_key() if key is None else key
    payload_json = json.dumps(
        data,
        cls=_RPAJSONEncoder,
        sort_keys=True,
        separators=(",", ":"),
        ensure_ascii=False,
    ).encode("utf-8")
    envelope = {
        "hmac": _compute_hmac(payload_json, signing_key),
        "payload_b64": base64.b64encode(payload_json).decode("ascii"),
    }
    body = json.dumps(envelope, separators=(",", ":"), ensure_ascii=False).encode("utf-8")
    return _SIGNATURE_HEADER + body
|
||||
|
||||
|
||||
def loads_signed(raw: bytes, key: Optional[bytes] = None) -> Any:
    """Deserialize a blob produced by `dumps_signed` after verifying its HMAC.

    The signature is checked over the decoded payload bytes BEFORE any
    application-level parsing of the payload itself.

    Args:
        raw: Full file contents (header + JSON envelope).
        key: Explicit HMAC key; defaults to the environment-derived key.

    Raises:
        UnsupportedFormatError: missing RPA_SIGNED_V1 marker.
        SignedSerializerError: malformed envelope or invalid base64 payload.
        SignatureVerificationError: HMAC mismatch (tampered or wrong key).
    """
    if not raw.startswith(_SIGNATURE_HEADER):
        raise UnsupportedFormatError("Marqueur RPA_SIGNED_V1 absent.")
    signing_key = key if key is not None else _resolve_signing_key()
    body = raw[len(_SIGNATURE_HEADER):]
    try:
        envelope = json.loads(body.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
        raise SignedSerializerError(f"Enveloppe JSON invalide : {exc}") from exc

    # Structural validation of the envelope before touching the payload.
    if not isinstance(envelope, dict):
        raise SignedSerializerError("Enveloppe inattendue.")
    signature = envelope.get("hmac")
    payload_b64 = envelope.get("payload_b64")
    if not isinstance(signature, str) or not isinstance(payload_b64, str):
        raise SignedSerializerError("Enveloppe mal formée (hmac / payload_b64).")

    try:
        payload_bytes = base64.b64decode(payload_b64.encode("ascii"), validate=True)
    except Exception as exc:  # noqa: BLE001 - base64 can raise several error types
        raise SignedSerializerError(f"Payload base64 invalide : {exc}") from exc

    # Constant-time comparison to avoid timing side channels on the HMAC.
    expected = _compute_hmac(payload_bytes, signing_key)
    if not hmac.compare_digest(expected, signature):
        raise SignatureVerificationError(
            "Signature HMAC invalide — fichier altéré ou clé différente."
        )

    # Only now is the payload parsed (with extended-type reconstruction).
    return json.loads(payload_bytes.decode("utf-8"), object_hook=_json_object_hook)
|
||||
|
||||
|
||||
def _pickle_fallback_allowed() -> bool:
|
||||
return os.getenv("RPA_ALLOW_PICKLE_FALLBACK", "1") != "0"
|
||||
|
||||
|
||||
def save_signed(path: Union[str, Path], data: Any, key: Optional[bytes] = None) -> None:
    """Write `data` to disk in the signed-JSON format (atomically)."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    blob = dumps_signed(data, key=key)
    # Write to a sibling temp file, then atomically swap it in so a crash
    # mid-write can never leave a truncated signed file behind.
    tmp_path = target.with_suffix(target.suffix + ".tmp")
    with open(tmp_path, "wb") as handle:
        handle.write(blob)
    os.replace(tmp_path, target)
|
||||
|
||||
|
||||
def load_signed(
    path: Union[str, Path],
    *,
    allow_pickle_fallback: bool = True,
    migrate_on_fallback: bool = True,
    pickle_loader: Optional[Callable[[io.BufferedIOBase], Any]] = None,
    key: Optional[bytes] = None,
) -> Any:
    """Load a file saved by `save_signed`.

    If the file is not in the signed format and `allow_pickle_fallback` is
    true (AND `RPA_ALLOW_PICKLE_FALLBACK != "0"`), fall back to
    `pickle.load()` to migrate legacy files. In that case a WARNING is
    emitted and, when `migrate_on_fallback` is True, the file is rewritten
    in signed JSON.

    Args:
        path: File path.
        allow_pickle_fallback: Enable legacy compatibility.
        migrate_on_fallback: Rewrite as signed JSON after a fallback load.
        pickle_loader: Alternative callable (tests / restricted unpickler).
        key: Explicit HMAC key (otherwise environment-derived).

    Raises:
        SignatureVerificationError: invalid HMAC (tampered file).
        UnsupportedFormatError: unknown format with fallback disabled.
    """
    path = Path(path)
    with open(path, "rb") as fp:
        raw = fp.read()

    if raw.startswith(_SIGNATURE_HEADER):
        return loads_signed(raw, key=key)

    if not allow_pickle_fallback or not _pickle_fallback_allowed():
        raise UnsupportedFormatError(
            f"{path} n'est pas au format signé et le fallback pickle est désactivé."
        )

    logger.warning(
        "Chargement legacy pickle pour %s — ce format est obsolète et "
        "sera ré-écrit en JSON signé. Voir docs/SECURITY.md.",
        path,
    )

    # Only use the pickle fallback on files from a trusted, local source;
    # callers may supply a restricted Unpickler via `pickle_loader`.
    # Fix: unpickle the bytes we already read instead of re-opening the
    # file — re-reading opened a TOCTOU window where the file could be
    # swapped between the format sniff above and the pickle load.
    loader = pickle_loader or (lambda f: pickle.load(f))  # noqa: S301 - legacy use
    data = loader(io.BytesIO(raw))

    if migrate_on_fallback:
        try:
            save_signed(path, data, key=key)
            logger.info("Fichier %s migré en JSON signé.", path)
        except Exception as exc:  # noqa: BLE001
            logger.error(
                "Migration JSON signé échouée pour %s : %s", path, exc
            )

    return data
|
||||
|
||||
|
||||
# Explicit public surface of the module.
__all__ = [
    "SignedSerializerError",
    "SignatureVerificationError",
    "UnsupportedFormatError",
    "dumps_signed",
    "loads_signed",
    "save_signed",
    "load_signed",
]
|
||||
@@ -26,11 +26,15 @@ from PIL import Image
|
||||
import logging
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from core.models import BBox
|
||||
from core.embedding.fusion_engine import FusionEngine
|
||||
from core.security.signed_serializer import (
|
||||
SignatureVerificationError,
|
||||
load_signed,
|
||||
save_signed,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -521,29 +525,74 @@ class VisualEmbeddingManager:
|
||||
|
||||
logger.debug(f"Éviction de {num_to_remove} entrées du cache")
|
||||
|
||||
def _entry_to_dict(self, entry: "EmbeddingCacheEntry") -> Dict[str, Any]:
|
||||
"""Convertit une entrée du cache en dict JSON-serialisable."""
|
||||
return {
|
||||
"embedding": entry.embedding, # numpy → encodé par signed_serializer
|
||||
"signature": entry.signature,
|
||||
"created_at": entry.created_at,
|
||||
"access_count": entry.access_count,
|
||||
"last_accessed": entry.last_accessed,
|
||||
}
|
||||
|
||||
def _dict_to_entry(self, data: Any) -> Optional["EmbeddingCacheEntry"]:
    """Rebuild an EmbeddingCacheEntry from a dict (JSON format), or pass
    through an already-typed object (legacy pickle fallback).

    Returns None when the payload is not usable.
    """
    if isinstance(data, EmbeddingCacheEntry):
        return data
    if not isinstance(data, dict):
        return None
    try:
        entry = EmbeddingCacheEntry(
            embedding=np.asarray(data["embedding"]),
            signature=data["signature"],
            created_at=data["created_at"],
            access_count=int(data.get("access_count", 0)),
            last_accessed=data.get("last_accessed"),
        )
    except (KeyError, TypeError, ValueError) as exc:
        # Corrupt or partial entries are skipped rather than fatal.
        logger.warning(f"Entrée de cache invalide ignorée: {exc}")
        return None
    return entry
|
||||
|
||||
def _load_persistent_cache(self):
|
||||
"""Charge le cache persistant depuis le disque"""
|
||||
"""Charge le cache persistant depuis le disque (JSON signé HMAC,
|
||||
fallback pickle legacy avec migration automatique)."""
|
||||
if not self.cache_persistence_path or not os.path.exists(self.cache_persistence_path):
|
||||
return
|
||||
|
||||
try:
|
||||
with open(self.cache_persistence_path, 'rb') as f:
|
||||
cached_data = pickle.load(f)
|
||||
|
||||
# Filtrer les entrées trop anciennes (plus de 24h)
|
||||
cutoff_time = datetime.now() - timedelta(hours=24)
|
||||
|
||||
for signature, entry in cached_data.items():
|
||||
if entry.created_at > cutoff_time:
|
||||
self._embedding_cache[signature] = entry
|
||||
|
||||
logger.info(f"Cache persistant chargé: {len(self._embedding_cache)} entrées")
|
||||
|
||||
cached_data = load_signed(self.cache_persistence_path)
|
||||
except SignatureVerificationError:
|
||||
logger.error(
|
||||
"Cache persistant %s altéré (HMAC invalide) — ignoré.",
|
||||
self.cache_persistence_path,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Erreur lors du chargement du cache persistant: {e}")
|
||||
return
|
||||
|
||||
if not isinstance(cached_data, dict):
|
||||
logger.warning("Format de cache inattendu — ignoré.")
|
||||
return
|
||||
|
||||
# Filtrer les entrées trop anciennes (plus de 24h)
|
||||
cutoff_time = datetime.now() - timedelta(hours=24)
|
||||
loaded = 0
|
||||
for signature, raw in cached_data.items():
|
||||
entry = self._dict_to_entry(raw)
|
||||
if entry is None:
|
||||
continue
|
||||
if entry.created_at > cutoff_time:
|
||||
self._embedding_cache[signature] = entry
|
||||
loaded += 1
|
||||
|
||||
logger.info(f"Cache persistant chargé: {loaded} entrées")
|
||||
|
||||
def _save_persistent_cache(self):
|
||||
"""Sauvegarde le cache sur disque"""
|
||||
"""Sauvegarde le cache sur disque en JSON signé HMAC."""
|
||||
if not self.cache_persistence_path:
|
||||
return
|
||||
|
||||
@@ -552,9 +601,12 @@ class VisualEmbeddingManager:
|
||||
os.makedirs(os.path.dirname(self.cache_persistence_path), exist_ok=True)
|
||||
|
||||
with self._cache_lock:
|
||||
with open(self.cache_persistence_path, 'wb') as f:
|
||||
pickle.dump(dict(self._embedding_cache), f)
|
||||
serializable = {
|
||||
signature: self._entry_to_dict(entry)
|
||||
for signature, entry in self._embedding_cache.items()
|
||||
}
|
||||
|
||||
save_signed(self.cache_persistence_path, serializable)
|
||||
logger.debug("Cache persistant sauvegardé")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -14,8 +14,9 @@ import asyncio
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
import pickle
|
||||
import gzip
|
||||
import pickle # noqa: S403 - usage legacy restreint au fallback de migration
|
||||
import io
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
@@ -24,6 +25,12 @@ import numpy as np
|
||||
|
||||
from core.visual.visual_target_manager import VisualTarget, VisualTargetManager
|
||||
from core.visual.screenshot_validation_manager import ScreenshotValidationManager, ValidationResult
|
||||
from core.security.signed_serializer import (
|
||||
SignatureVerificationError,
|
||||
UnsupportedFormatError,
|
||||
dumps_signed,
|
||||
loads_signed,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -435,7 +442,7 @@ class VisualPersistenceManager:
|
||||
return None
|
||||
|
||||
async def _serialize_workflow_data(self, workflow_data: VisualWorkflowData) -> bytes:
|
||||
"""Sérialise les données d'un workflow"""
|
||||
"""Sérialise les données d'un workflow en JSON signé HMAC."""
|
||||
# Convertir en dictionnaire
|
||||
data_dict = asdict(workflow_data)
|
||||
|
||||
@@ -456,13 +463,28 @@ class VisualPersistenceManager:
|
||||
]
|
||||
data_dict['validation_history'] = serialized_history
|
||||
|
||||
# Convertir en bytes
|
||||
return pickle.dumps(data_dict)
|
||||
# JSON signé HMAC (cf. core.security.signed_serializer)
|
||||
return dumps_signed(data_dict)
|
||||
|
||||
async def _deserialize_workflow_data(self, data: bytes) -> VisualWorkflowData:
|
||||
"""Désérialise les données d'un workflow"""
|
||||
# Désérialiser le dictionnaire
|
||||
data_dict = pickle.loads(data)
|
||||
"""Désérialise les données d'un workflow (JSON signé HMAC ;
|
||||
fallback pickle legacy avec WARNING pour migrer les anciens fichiers)."""
|
||||
try:
|
||||
data_dict = loads_signed(data)
|
||||
except SignatureVerificationError:
|
||||
# Fichier altéré ou clé différente : on refuse sans fallback.
|
||||
logger.error("Workflow visuel : signature HMAC invalide — refus.")
|
||||
raise
|
||||
except UnsupportedFormatError:
|
||||
# Ancien format pickle : fallback explicite et bruyant.
|
||||
import os
|
||||
if os.getenv("RPA_ALLOW_PICKLE_FALLBACK", "1") == "0":
|
||||
raise
|
||||
logger.warning(
|
||||
"Workflow visuel au format pickle legacy — lecture de compat, "
|
||||
"ré-écrire en JSON signé dès que possible."
|
||||
)
|
||||
data_dict = pickle.loads(data) # noqa: S301 - fallback legacy
|
||||
|
||||
# Reconstruire les objets
|
||||
workflow_data = VisualWorkflowData(
|
||||
|
||||
179
tests/unit/test_security_safe_condition.py
Normal file
179
tests/unit/test_security_safe_condition.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""Tests de sécurité : évaluateur de conditions AST restreint."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.execution.safe_condition_evaluator import (
|
||||
SafeConditionEvaluator,
|
||||
UnsafeExpressionError,
|
||||
safe_eval_condition,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cas valides — expressions que les workflows doivent pouvoir évaluer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidExpressions:
    """Expressions that workflow conditions must be able to evaluate."""

    def test_literal_true(self):
        outcome = safe_eval_condition("True", {})
        assert outcome is True

    def test_literal_false(self):
        outcome = safe_eval_condition("False", {})
        assert outcome is False

    def test_numeric_comparison(self):
        assert safe_eval_condition("1 < 2", {}) is True
        assert safe_eval_condition("2 < 1", {}) is False

    def test_chained_comparison(self):
        assert safe_eval_condition("1 < 2 < 3", {}) is True
        assert safe_eval_condition("1 < 3 < 2", {}) is False

    def test_variable_access(self):
        assert safe_eval_condition("x > 5", {"x": 10}) is True

    def test_subscript_dict(self):
        context = {"results": {"step_1": {"score": 0.9}}}
        verdict = safe_eval_condition("results['step_1']['score'] >= 0.8", context)
        assert verdict is True

    def test_boolean_and(self):
        assert safe_eval_condition("True and False", {}) is False
        assert safe_eval_condition("True and True", {}) is True

    def test_boolean_or(self):
        assert safe_eval_condition("False or True", {}) is True

    def test_not_operator(self):
        assert safe_eval_condition("not False", {}) is True

    def test_arithmetic(self):
        context = {"a": 3, "b": 4}
        assert safe_eval_condition("(a + b) * 2 > 10", context) is True

    def test_in_operator(self):
        context = {"status": ["ok", "done"]}
        assert safe_eval_condition("'ok' in status", context) is True

    def test_list_literal(self):
        assert safe_eval_condition("x in [1, 2, 3]", {"x": 2}) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cas malveillants — tentatives d'injection / RCE
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMaliciousExpressions:
    """Every expression here MUST raise UnsafeExpressionError."""

    @staticmethod
    def _expect_rejection(expression, context):
        # Shared guard: the evaluator must refuse before executing anything.
        with pytest.raises(UnsafeExpressionError):
            safe_eval_condition(expression, context)

    def test_rejects_import(self):
        self._expect_rejection("__import__('os').system('echo pwn')", {})

    def test_rejects_function_call(self):
        self._expect_rejection("print('hello')", {"print": print})

    def test_rejects_eval(self):
        self._expect_rejection("eval('1+1')", {})

    def test_rejects_exec(self):
        self._expect_rejection("exec('x=1')", {})

    def test_rejects_dunder_attribute(self):
        # Classic escape: climbing to __builtins__ through __class__.__mro__.
        self._expect_rejection("x.__class__", {"x": "abc"})

    def test_rejects_dunder_subclasses(self):
        self._expect_rejection(
            "x.__class__.__mro__[-1].__subclasses__()",
            {"x": []},
        )

    def test_rejects_undefined_variable(self):
        self._expect_rejection("secret > 0", {})

    def test_rejects_lambda(self):
        self._expect_rejection("(lambda: 42)()", {})

    def test_rejects_list_comprehension(self):
        self._expect_rejection("[x for x in range(3)]", {})

    def test_rejects_generator(self):
        self._expect_rejection("(x for x in [1])", {})

    def test_rejects_walrus(self):
        self._expect_rejection("(x := 1)", {})

    def test_rejects_ifexp(self):
        # IfExp (conditional expression) is not whitelisted by default —
        # can be added later if workflows need it.
        self._expect_rejection("1 if True else 2", {})

    def test_rejects_starred(self):
        self._expect_rejection("[*x]", {"x": [1, 2]})

    def test_rejects_attribute_call_chain(self):
        # Even when a dict is provided in the context, method calls
        # stay forbidden.
        self._expect_rejection("results.keys()", {"results": {"a": 1}})

    def test_rejects_huge_expression(self):
        # Same string as "0+" * 1000 + "0": 1001 zeros joined by '+'.
        self._expect_rejection("+".join(["0"] * 1001), {})

    def test_rejects_syntax_error(self):
        self._expect_rejection("1 + ", {})

    def test_rejects_non_string(self):
        self._expect_rejection(12345, {})  # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Intégration avec DAGExecutor : le step condition doit refuser l'injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDAGExecutorIntegration:
    """The DAG executor's condition step must route through the safe evaluator."""

    def test_condition_step_refuses_malicious_payload(self):
        """A workflow injecting __import__ into 'condition' must be refused
        silently (result = False) without executing the payload."""
        from core.execution.dag_executor import DAGExecutor, WorkflowStep, StepType

        executor = DAGExecutor()
        malicious_step = WorkflowStep(
            step_id="malicious",
            step_type=StepType.CONDITION,
            action={"condition": "__import__('os').system('echo PWNED')"},
        )
        # Call the private method directly to isolate the behaviour under test.
        outcome = executor._execute_condition_step(malicious_step, malicious_step.action)
        assert outcome is False

    def test_condition_step_accepts_safe_expression(self):
        from core.execution.dag_executor import DAGExecutor, WorkflowStep, StepType

        executor = DAGExecutor()
        executor._results["step_prev"] = {"ok": True}
        safe_step = WorkflowStep(
            step_id="cond",
            step_type=StepType.CONDITION,
            action={"condition": "results['step_prev']['ok']"},
        )
        outcome = executor._execute_condition_step(safe_step, safe_step.action)
        assert outcome is True
|
||||
239
tests/unit/test_security_signed_serializer.py
Normal file
239
tests/unit/test_security_signed_serializer.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""Tests de sécurité : sérialiseur JSON signé HMAC.
|
||||
|
||||
Couvre :
|
||||
- round-trip JSON signé
|
||||
- rejet d'un fichier altéré
|
||||
- fallback pickle legacy + migration
|
||||
- intégration FAISSManager (lecture / rejet HMAC)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from core.security.signed_serializer import (
|
||||
SignatureVerificationError,
|
||||
UnsupportedFormatError,
|
||||
dumps_signed,
|
||||
load_signed,
|
||||
loads_signed,
|
||||
save_signed,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _signing_key(monkeypatch):
    """Pin a deterministic signing key (and enable the pickle fallback) for tests."""
    env_overrides = {
        "RPA_SIGNING_KEY": "test-signing-key-for-unit-tests-only",
        "RPA_ALLOW_PICKLE_FALLBACK": "1",
    }
    for name, value in env_overrides.items():
        monkeypatch.setenv(name, value)
    yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Round-trip et types étendus
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRoundTrip:
    """Signed-JSON round-trips, including extended types (numpy, datetime, bytes)."""

    def test_primitive_types(self, tmp_path: Path):
        original = {"a": 1, "b": "texte", "c": [1, 2, 3], "d": None}
        target = tmp_path / "data.json.signed"
        save_signed(target, original)
        assert load_signed(target) == original

    def test_numpy_roundtrip(self, tmp_path: Path):
        matrix = np.arange(12, dtype=np.float32).reshape(3, 4)
        target = tmp_path / "arr.json.signed"
        save_signed(target, {"embedding": matrix})
        restored = load_signed(target)["embedding"]
        assert isinstance(restored, np.ndarray)
        assert restored.shape == (3, 4)
        assert restored.dtype == np.float32
        np.testing.assert_array_equal(restored, matrix)

    def test_datetime_roundtrip(self, tmp_path: Path):
        stamp = datetime(2026, 4, 13, 10, 0, 0)
        target = tmp_path / "dt.json.signed"
        save_signed(target, {"created_at": stamp})
        assert load_signed(target)["created_at"] == stamp

    def test_bytes_payload(self):
        decoded = loads_signed(dumps_signed({"blob": b"\x00\x01\x02"}))
        assert decoded["blob"] == b"\x00\x01\x02"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rejet d'un fichier altéré
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTampering:
    """A signed file whose bytes were modified must be rejected on load."""

    def test_rejects_tampered_payload(self, tmp_path: Path):
        path = tmp_path / "f.signed"
        save_signed(path, {"score": 0.5})

        raw = path.read_bytes()
        # Flip exactly one character inside the base64 payload.
        marker = b'"payload_b64":"'
        start = raw.rfind(marker)
        # Guard: rfind() returning -1 would otherwise tamper an arbitrary
        # byte near the file start and the test would pass for the wrong reason.
        assert start != -1, "signed-file layout changed: payload marker not found"
        idx = start + len(marker)
        replacement = b"X" if raw[idx:idx + 1] != b"X" else b"Y"
        path.write_bytes(raw[:idx] + replacement + raw[idx + 1:])

        # The HMAC is computed over the payload, so any payload change must
        # surface as a signature mismatch — not as an incidental decode error.
        # (Previously this asserted pytest.raises((SignatureVerificationError,
        # Exception)), which any exception at all would satisfy.)
        with pytest.raises(SignatureVerificationError):
            load_signed(path)

    def test_rejects_tampered_hmac(self, tmp_path: Path):
        path = tmp_path / "f.signed"
        save_signed(path, {"score": 0.5})

        # Corrupt the stored signature itself.
        raw = path.read_bytes()
        tampered = raw.replace(b'"hmac":"', b'"hmac":"0')
        path.write_bytes(tampered)

        with pytest.raises(SignatureVerificationError):
            load_signed(path)

    def test_rejects_wrong_key(self, tmp_path: Path, monkeypatch):
        path = tmp_path / "f.signed"
        save_signed(path, {"score": 0.5})

        # Change the key: verification must fail.
        monkeypatch.setenv("RPA_SIGNING_KEY", "other-key")
        with pytest.raises(SignatureVerificationError):
            load_signed(path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fallback pickle + migration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPickleFallback:
    """Legacy pickle files: explicit opt-in fallback + transparent migration."""

    @staticmethod
    def _write_legacy(path: Path, payload) -> None:
        # Reproduce the pre-fix on-disk format: a raw pickle dump.
        with open(path, "wb") as fp:
            pickle.dump(payload, fp)

    def test_pickle_fallback_loads_and_migrates(self, tmp_path: Path):
        path = tmp_path / "legacy.pkl"
        payload = {"score": 0.42, "label": "legacy"}
        self._write_legacy(path, payload)

        # Loading must succeed AND rewrite the file in the signed format.
        loaded = load_signed(path, allow_pickle_fallback=True, migrate_on_fallback=True)
        assert loaded == payload
        assert path.read_bytes().startswith(b"RPA_SIGNED_V1\n")

        # The migrated file is readable again through the signed path.
        assert load_signed(path) == payload

    def test_pickle_fallback_disabled(self, tmp_path: Path, monkeypatch):
        monkeypatch.setenv("RPA_ALLOW_PICKLE_FALLBACK", "0")
        path = tmp_path / "legacy.pkl"
        self._write_legacy(path, {"x": 1})

        with pytest.raises(UnsupportedFormatError):
            load_signed(path)

    def test_pickle_fallback_explicit_off(self, tmp_path: Path):
        path = tmp_path / "legacy.pkl"
        self._write_legacy(path, {"x": 1})

        with pytest.raises(UnsupportedFormatError):
            load_signed(path, allow_pickle_fallback=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Intégration FAISSManager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _require_faiss():
    """Skip only the *current* test when FAISS is unavailable, and return the module.

    The previous code called ``pytest.importorskip("faiss", ...)`` at module
    level, which raises ``Skipped`` while the module is being imported and
    therefore skips EVERY test in this file — including the serializer tests
    above, which do not need FAISS at all.
    """
    return pytest.importorskip("faiss", reason="FAISS non installé.")


class TestFAISSManagerSignedMetadata:
    """FAISSManager must persist its metadata as HMAC-signed JSON."""

    def test_save_and_load_roundtrip(self, tmp_path: Path):
        _require_faiss()
        from core.embedding.faiss_manager import FAISSManager

        manager = FAISSManager(dimensions=8, index_type="Flat", metric="cosine")
        vec = np.random.rand(8).astype(np.float32)
        manager.add_embedding("emb_1", vec, metadata={"label": "target"})

        index_path = tmp_path / "index.bin"
        meta_path = tmp_path / "meta.signed"
        manager.save(index_path, meta_path)

        # The metadata file must use the signed format.
        raw = meta_path.read_bytes()
        assert raw.startswith(b"RPA_SIGNED_V1\n")

        # Reload and check state survived the round-trip.
        reloaded = FAISSManager.load(index_path, meta_path)
        assert reloaded.dimensions == 8
        assert reloaded.next_id == 1
        assert 0 in reloaded.metadata_store
        assert reloaded.metadata_store[0]["embedding_id"] == "emb_1"

    def test_load_refuses_tampered_metadata(self, tmp_path: Path):
        _require_faiss()
        from core.embedding.faiss_manager import FAISSManager

        manager = FAISSManager(dimensions=4, index_type="Flat", metric="cosine")
        manager.add_embedding("e", np.ones(4, dtype=np.float32), metadata={})

        index_path = tmp_path / "index.bin"
        meta_path = tmp_path / "meta.signed"
        manager.save(index_path, meta_path)

        # Corrupt the file's stored signature.
        raw = meta_path.read_bytes()
        meta_path.write_bytes(raw.replace(b'"hmac":"', b'"hmac":"0'))

        with pytest.raises(SignatureVerificationError):
            FAISSManager.load(index_path, meta_path)

    def test_load_migrates_legacy_pickle(self, tmp_path: Path):
        """A legacy pickle metadata file must be migrated on first load."""
        faiss = _require_faiss()
        from core.embedding.faiss_manager import FAISSManager

        # Manually build a legacy file (mirrors the pre-fix layout).
        manager = FAISSManager(dimensions=4, index_type="Flat", metric="cosine")
        vec = np.ones(4, dtype=np.float32)
        manager.add_embedding("legacy_emb", vec, metadata={"tag": "old"})

        index_path = tmp_path / "index.bin"
        meta_path = tmp_path / "meta.pkl"

        # Write the FAISS index normally...
        faiss.write_index(manager.index, str(index_path))

        # ...but dump the metadata as raw pickle (pre-fix format).
        legacy = {
            "dimensions": 4,
            "index_type": "Flat",
            "metric": "cosine",
            "next_id": manager.next_id,
            "metadata_store": manager.metadata_store,
            "nlist": None,
            "nprobe": 8,
            "is_trained": True,
            "auto_optimize": True,
        }
        with open(meta_path, "wb") as fp:
            pickle.dump(legacy, fp)

        # Loading must succeed AND migrate the file to the signed format.
        reloaded = FAISSManager.load(index_path, meta_path)
        assert reloaded.dimensions == 4
        assert reloaded.next_id == 1
        assert meta_path.read_bytes().startswith(b"RPA_SIGNED_V1\n")
|
||||
Reference in New Issue
Block a user