feat(cache): ScreenStateCache clé composite context-aware (Lot D)

Avant : clé = phash seul
-> deux contextes différents avec même screenshot partageaient
la même entrée cache -> collisions silencieuses.

Après : clé composite {phash}|{md5(ctx)[:16]} avec ctx =
  - window_title
  - app_name
  - enable_ocr
  - enable_ui_detection
  - workflow_id (isolation inter-workflows)

get_or_compute() kwargs-only. TTL 2s et éviction LRU inchangés.
invalidate_if_changed() continue de comparer uniquement les phash.

ExecutionLoop propage tout le contexte au cache.

8 nouveaux tests prouvant :
  - même image + window différent = miss
  - même image + app différent = miss
  - même image + flags différents = miss
  - même image + workflow_id différent = miss
  - même image + même contexte = hit

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Dom
2026-04-15 09:06:51 +02:00
parent 9ca277a63f
commit c8a3618e27
2 changed files with 858 additions and 0 deletions

View File

@@ -0,0 +1,409 @@
"""
ScreenStateCache — Cache perceptuel de ScreenState (context-aware).
Objectif : éviter de réanalyser un screenshot identique (5-15s VLM/OCR)
à chaque step de la boucle d'exécution.
Principe (Lot D — avril 2026) :
- Clé = composite de 6 éléments pour éviter les collisions silencieuses
entre contextes différents partageant un même screenshot :
1. phash (dhash 8x8 du screenshot) — calculé en ~2-5ms
2. window_title (titre fenêtre active)
3. app_name (nom process actif)
4. enable_ocr (flag runtime)
5. enable_ui_detection (flag runtime)
6. workflow_id (isolation inter-workflows)
- TTL par défaut : 2 secondes (configurable)
- Invalidation explicite possible (par clé composite ou globale)
- invalidate_if_changed reste piloté par le phash seul (détection de
changement visuel majeur, indépendant du contexte)
- Thread-safe (lock interne)
API principale :
>>> cache = ScreenStateCache(ttl_seconds=2.0)
>>> state, hit, ms = cache.get_or_compute(
... screenshot_path, compute_fn,
... window_title="App", app_name="app.exe",
... enable_ocr=True, enable_ui_detection=True,
... workflow_id="wf_123",
... )
La fonction `compute_fn` prend le chemin du screenshot et doit retourner
un `ScreenState`. Elle n'est appelée qu'en cache miss.
"""
from __future__ import annotations
import hashlib
import logging
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional, Tuple
from PIL import Image
from core.models.screen_state import ScreenState
logger = logging.getLogger(__name__)
# =============================================================================
# Hash perceptuel (dhash simple, sans dépendance imagehash)
# =============================================================================
def _hamming_distance_hex(a: str, b: str) -> int:
"""
Distance de Hamming entre deux chaînes hexadécimales de même longueur.
Retourne le nombre de bits qui diffèrent entre les deux hashes.
Si les longueurs diffèrent, on pad à droite par des zéros.
"""
if len(a) != len(b):
max_len = max(len(a), len(b))
a = a.ljust(max_len, "0")
b = b.ljust(max_len, "0")
try:
xor = int(a, 16) ^ int(b, 16)
return bin(xor).count("1")
except ValueError:
# Fallback : comparaison caractère à caractère
return sum(1 for ca, cb in zip(a, b) if ca != cb) * 4
def compute_perceptual_hash(screenshot_path: str, size: int = 8) -> str:
"""
Calculer un dhash (difference hash) pour un screenshot.
Algorithme :
1. Convertir en niveaux de gris
2. Redimensionner à (size+1) x size
3. Comparer chaque pixel avec son voisin de droite (dhash)
4. Retourner un hash hexadécimal de size*size bits
Robuste aux petites variations (curseur, blink, compression).
Coût typique : 2-5 ms sur un 1920x1080.
Args:
screenshot_path: Chemin vers le fichier image
size: Taille du hash (8 = 64 bits, défaut)
Returns:
Chaîne hexadécimale (size*size/4 caractères)
"""
try:
img = Image.open(screenshot_path)
img = img.convert("L").resize((size + 1, size), Image.LANCZOS)
pixels = list(img.getdata())
# dhash : comparer chaque pixel avec celui de droite
bits = []
for row in range(size):
for col in range(size):
left = pixels[row * (size + 1) + col]
right = pixels[row * (size + 1) + col + 1]
bits.append(1 if left > right else 0)
# Convertir en hex
value = 0
for bit in bits:
value = (value << 1) | bit
return format(value, f"0{size * size // 4}x")
except Exception as e:
logger.warning(f"Hash perceptuel échoué pour {screenshot_path}: {e}")
# Fallback : hash du contenu brut
try:
data = Path(screenshot_path).read_bytes()
return hashlib.md5(data).hexdigest()[:16]
except Exception:
return f"unhashable_{int(time.time() * 1000)}"
# =============================================================================
# Clé composite (Lot D)
# =============================================================================
def _make_cache_key(
phash: str,
window_title: str,
app_name: str,
enable_ocr: bool,
enable_ui_detection: bool,
workflow_id: str,
) -> str:
"""
Construire une clé composite stable pour le cache.
Combine les 6 dimensions du contexte d'exécution dans une chaîne
hexadécimale (md5 tronqué à 16 caractères), préfixée par le phash pour
conserver une lisibilité minimale en debug (log : `aabb…|ctx=1234…`).
NB : On hash plutôt que concaténer brut pour :
- Borner la taille de la clé même si window_title est long
- Éviter les collisions triviales (séparateur présent dans un titre)
- Rendre la clé opaque (pas de PII en clair dans les logs de cache)
Args:
phash: Hash perceptuel du screenshot (dhash 8x8)
window_title: Titre de la fenêtre active (str)
app_name: Nom du process actif (str)
enable_ocr: Flag runtime OCR (bool)
enable_ui_detection: Flag runtime détection UI (bool)
workflow_id: ID du workflow en cours (str, "" pour legacy)
Returns:
Clé composite `{phash}|{ctx_hash}` où ctx_hash = md5(16)
"""
# Sérialisation déterministe ; `|` comme séparateur interne puisque hashé.
ctx_repr = (
f"{window_title or ''}\x1f"
f"{app_name or ''}\x1f"
f"{int(bool(enable_ocr))}\x1f"
f"{int(bool(enable_ui_detection))}\x1f"
f"{workflow_id or ''}"
)
ctx_hash = hashlib.md5(ctx_repr.encode("utf-8")).hexdigest()[:16]
return f"{phash}|{ctx_hash}"
# =============================================================================
# Entry
# =============================================================================
@dataclass
class _CacheEntry:
state: ScreenState
created_at: float
phash: str # phash seul (utilisé par invalidate_if_changed)
# =============================================================================
# Cache
# =============================================================================
class ScreenStateCache:
"""
Cache de ScreenState avec TTL et clé composite context-aware.
Thread-safe. Utilise un lock interne pour les opérations get/set.
"""
def __init__(self, ttl_seconds: float = 2.0, max_entries: int = 16):
"""
Args:
ttl_seconds: Durée de vie d'une entrée (en secondes)
max_entries: Nombre max d'entrées avant éviction LRU simple
"""
self.ttl_seconds = ttl_seconds
self.max_entries = max_entries
# Clé = composite (_make_cache_key), valeur = _CacheEntry
self._store: dict[str, _CacheEntry] = {}
self._lock = threading.Lock()
# Métriques simples (utile pour le debug / logs)
self.hits = 0
self.misses = 0
self.invalidations = 0
# -------------------------------------------------------------------------
# API bas niveau (par clé composite)
# -------------------------------------------------------------------------
def _get(self, composite_key: str) -> Optional[ScreenState]:
"""Retourne l'entrée pour cette clé composite si encore valide."""
with self._lock:
entry = self._store.get(composite_key)
if entry is None:
return None
if time.time() - entry.created_at > self.ttl_seconds:
# Expiré
self._store.pop(composite_key, None)
return None
return entry.state
def _set(self, composite_key: str, phash: str, state: ScreenState) -> None:
"""Enregistre un état pour cette clé composite."""
with self._lock:
# Éviction simple : si plein, virer l'entrée la plus ancienne
if (
len(self._store) >= self.max_entries
and composite_key not in self._store
):
oldest_key = min(
self._store, key=lambda k: self._store[k].created_at
)
self._store.pop(oldest_key, None)
self._store[composite_key] = _CacheEntry(
state=state,
created_at=time.time(),
phash=phash,
)
def invalidate(self, composite_key: Optional[str] = None) -> None:
"""
Invalider une entrée ou tout le cache.
Args:
composite_key: Clé à invalider. Si None, vide tout le cache.
"""
with self._lock:
if composite_key is None:
self._store.clear()
else:
self._store.pop(composite_key, None)
self.invalidations += 1
def invalidate_if_changed(
self,
screenshot_path: str,
threshold: float = 0.3,
) -> bool:
"""
Invalider le cache si l'écran a suffisamment changé.
Compare le dhash du screenshot courant avec le phash (seul) de chaque
entrée du cache. La décision est volontairement indépendante du reste
de la clé composite : un changement visuel majeur rend toutes les
entrées obsolètes, quel que soit le contexte.
Args:
screenshot_path: Chemin du screenshot courant
threshold: Proportion de bits qui doivent différer (0.0-1.0).
0.3 = 30% (~19 bits sur 64) = changement significatif.
Returns:
True si le cache a été invalidé, False sinon.
"""
if not self._store:
return False
current_phash = compute_perceptual_hash(screenshot_path)
# Bits totaux : 64 pour un dhash 8x8 standard. On déduit via la
# longueur hexa du hash courant pour rester générique.
total_bits = len(current_phash) * 4
if total_bits == 0:
return False
threshold_bits = threshold * total_bits
with self._lock:
if not self._store:
return False
# Distance de Hamming minimale avec les phashes des entrées
# (on regarde entry.phash, pas la clé composite).
min_distance = None
for entry in self._store.values():
distance = _hamming_distance_hex(current_phash, entry.phash)
if min_distance is None or distance < min_distance:
min_distance = distance
if min_distance is not None and min_distance > threshold_bits:
size_before = len(self._store)
self._store.clear()
self.invalidations += 1
logger.debug(
f"[ScreenStateCache] invalidate_if_changed: "
f"distance={min_distance}/{total_bits} > "
f"threshold={threshold_bits:.1f}{size_before} entrées purgées"
)
return True
return False
# -------------------------------------------------------------------------
# API haut niveau (context-aware)
# -------------------------------------------------------------------------
def get_or_compute(
self,
screenshot_path: str,
compute_fn: Callable[[str], ScreenState],
*,
window_title: str = "",
app_name: str = "",
enable_ocr: bool = True,
enable_ui_detection: bool = True,
workflow_id: str = "",
force_refresh: bool = False,
) -> Tuple[ScreenState, bool, float]:
"""
Récupérer ou calculer le ScreenState pour un screenshot + contexte.
Clé de cache = composite(phash, window_title, app_name, enable_ocr,
enable_ui_detection, workflow_id). Deux contextes différents partageant
le même screenshot n'entrent PAS en collision.
Rétrocompatibilité : tous les kwargs de contexte ont une valeur par
défaut. Un caller legacy qui n'a pas encore été adapté partagera la
même entrée de cache qu'un autre caller legacy (comportement antérieur).
Args:
screenshot_path: Chemin du screenshot
compute_fn: Fonction qui construit un ScreenState si cache miss
window_title: Titre de la fenêtre active (contexte visuel)
app_name: Nom du process actif (contexte applicatif)
enable_ocr: Flag runtime — différencie états avec/sans OCR
enable_ui_detection: Flag runtime — différencie états avec/sans UI
workflow_id: ID du workflow — isolation inter-workflows
force_refresh: Ignorer le cache et recalculer
Returns:
Tuple (state, cache_hit, elapsed_ms)
"""
t0 = time.time()
phash = compute_perceptual_hash(screenshot_path)
composite_key = _make_cache_key(
phash=phash,
window_title=window_title,
app_name=app_name,
enable_ocr=enable_ocr,
enable_ui_detection=enable_ui_detection,
workflow_id=workflow_id,
)
if not force_refresh:
cached = self._get(composite_key)
if cached is not None:
self.hits += 1
elapsed_ms = (time.time() - t0) * 1000
logger.debug(
f"[ScreenStateCache] HIT key={composite_key[:24]}"
f"({elapsed_ms:.1f}ms)"
)
return cached, True, elapsed_ms
# Cache miss → calcul complet
self.misses += 1
state = compute_fn(screenshot_path)
self._set(composite_key, phash, state)
elapsed_ms = (time.time() - t0) * 1000
logger.debug(
f"[ScreenStateCache] MISS key={composite_key[:24]}"
f"({elapsed_ms:.1f}ms)"
)
return state, False, elapsed_ms
def stats(self) -> dict:
"""Retourne les métriques du cache."""
with self._lock:
total = self.hits + self.misses
return {
"hits": self.hits,
"misses": self.misses,
"invalidations": self.invalidations,
"hit_rate": self.hits / total if total > 0 else 0.0,
"size": len(self._store),
"max_entries": self.max_entries,
"ttl_seconds": self.ttl_seconds,
}
def __len__(self) -> int:
with self._lock:
return len(self._store)