feat(cache): ScreenStateCache clé composite context-aware (Lot D)
Avant : clé = phash seul
-> deux contextes différents avec même screenshot partageaient
la même entrée cache -> collisions silencieuses.
Après : clé composite {phash}|{md5(ctx)[:16]} avec ctx =
- window_title
- app_name
- enable_ocr
- enable_ui_detection
- workflow_id (isolation inter-workflows)
get_or_compute() kwargs-only. TTL 2s et éviction LRU inchangés.
invalidate_if_changed() continue de comparer uniquement les phash.
ExecutionLoop propage tout le contexte au cache.
8 nouveaux tests prouvant :
- même image + window différent = miss
- même image + app différent = miss
- même image + flags différents = miss
- même image + workflow_id différent = miss
- même image + même contexte = hit
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
409
core/pipeline/screen_state_cache.py
Normal file
409
core/pipeline/screen_state_cache.py
Normal file
@@ -0,0 +1,409 @@
|
||||
"""
|
||||
ScreenStateCache — Cache perceptuel de ScreenState (context-aware).
|
||||
|
||||
Objectif : éviter de réanalyser un screenshot identique (5-15s VLM/OCR)
|
||||
à chaque step de la boucle d'exécution.
|
||||
|
||||
Principe (Lot D — avril 2026) :
|
||||
- Clé = composite de 6 éléments pour éviter les collisions silencieuses
|
||||
entre contextes différents partageant un même screenshot :
|
||||
1. phash (dhash 8x8 du screenshot) — calculé en ~2-5ms
|
||||
2. window_title (titre fenêtre active)
|
||||
3. app_name (nom process actif)
|
||||
4. enable_ocr (flag runtime)
|
||||
5. enable_ui_detection (flag runtime)
|
||||
6. workflow_id (isolation inter-workflows)
|
||||
- TTL par défaut : 2 secondes (configurable)
|
||||
- Invalidation explicite possible (par clé composite ou globale)
|
||||
- invalidate_if_changed reste piloté par le phash seul (détection de
|
||||
changement visuel majeur, indépendant du contexte)
|
||||
- Thread-safe (lock interne)
|
||||
|
||||
API principale :
|
||||
>>> cache = ScreenStateCache(ttl_seconds=2.0)
|
||||
>>> state, hit, ms = cache.get_or_compute(
|
||||
... screenshot_path, compute_fn,
|
||||
... window_title="App", app_name="app.exe",
|
||||
... enable_ocr=True, enable_ui_detection=True,
|
||||
... workflow_id="wf_123",
|
||||
... )
|
||||
|
||||
La fonction `compute_fn` prend le chemin du screenshot et doit retourner
|
||||
un `ScreenState`. Elle n'est appelée qu'en cache miss.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from core.models.screen_state import ScreenState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Hash perceptuel (dhash simple, sans dépendance imagehash)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _hamming_distance_hex(a: str, b: str) -> int:
|
||||
"""
|
||||
Distance de Hamming entre deux chaînes hexadécimales de même longueur.
|
||||
|
||||
Retourne le nombre de bits qui diffèrent entre les deux hashes.
|
||||
Si les longueurs diffèrent, on pad à droite par des zéros.
|
||||
"""
|
||||
if len(a) != len(b):
|
||||
max_len = max(len(a), len(b))
|
||||
a = a.ljust(max_len, "0")
|
||||
b = b.ljust(max_len, "0")
|
||||
try:
|
||||
xor = int(a, 16) ^ int(b, 16)
|
||||
return bin(xor).count("1")
|
||||
except ValueError:
|
||||
# Fallback : comparaison caractère à caractère
|
||||
return sum(1 for ca, cb in zip(a, b) if ca != cb) * 4
|
||||
|
||||
|
||||
def compute_perceptual_hash(screenshot_path: str, size: int = 8) -> str:
|
||||
"""
|
||||
Calculer un dhash (difference hash) pour un screenshot.
|
||||
|
||||
Algorithme :
|
||||
1. Convertir en niveaux de gris
|
||||
2. Redimensionner à (size+1) x size
|
||||
3. Comparer chaque pixel avec son voisin de droite (dhash)
|
||||
4. Retourner un hash hexadécimal de size*size bits
|
||||
|
||||
Robuste aux petites variations (curseur, blink, compression).
|
||||
Coût typique : 2-5 ms sur un 1920x1080.
|
||||
|
||||
Args:
|
||||
screenshot_path: Chemin vers le fichier image
|
||||
size: Taille du hash (8 = 64 bits, défaut)
|
||||
|
||||
Returns:
|
||||
Chaîne hexadécimale (size*size/4 caractères)
|
||||
"""
|
||||
try:
|
||||
img = Image.open(screenshot_path)
|
||||
img = img.convert("L").resize((size + 1, size), Image.LANCZOS)
|
||||
pixels = list(img.getdata())
|
||||
|
||||
# dhash : comparer chaque pixel avec celui de droite
|
||||
bits = []
|
||||
for row in range(size):
|
||||
for col in range(size):
|
||||
left = pixels[row * (size + 1) + col]
|
||||
right = pixels[row * (size + 1) + col + 1]
|
||||
bits.append(1 if left > right else 0)
|
||||
|
||||
# Convertir en hex
|
||||
value = 0
|
||||
for bit in bits:
|
||||
value = (value << 1) | bit
|
||||
return format(value, f"0{size * size // 4}x")
|
||||
except Exception as e:
|
||||
logger.warning(f"Hash perceptuel échoué pour {screenshot_path}: {e}")
|
||||
# Fallback : hash du contenu brut
|
||||
try:
|
||||
data = Path(screenshot_path).read_bytes()
|
||||
return hashlib.md5(data).hexdigest()[:16]
|
||||
except Exception:
|
||||
return f"unhashable_{int(time.time() * 1000)}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Clé composite (Lot D)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _make_cache_key(
|
||||
phash: str,
|
||||
window_title: str,
|
||||
app_name: str,
|
||||
enable_ocr: bool,
|
||||
enable_ui_detection: bool,
|
||||
workflow_id: str,
|
||||
) -> str:
|
||||
"""
|
||||
Construire une clé composite stable pour le cache.
|
||||
|
||||
Combine les 6 dimensions du contexte d'exécution dans une chaîne
|
||||
hexadécimale (md5 tronqué à 16 caractères), préfixée par le phash pour
|
||||
conserver une lisibilité minimale en debug (log : `aabb…|ctx=1234…`).
|
||||
|
||||
NB : On hash plutôt que concaténer brut pour :
|
||||
- Borner la taille de la clé même si window_title est long
|
||||
- Éviter les collisions triviales (séparateur présent dans un titre)
|
||||
- Rendre la clé opaque (pas de PII en clair dans les logs de cache)
|
||||
|
||||
Args:
|
||||
phash: Hash perceptuel du screenshot (dhash 8x8)
|
||||
window_title: Titre de la fenêtre active (str)
|
||||
app_name: Nom du process actif (str)
|
||||
enable_ocr: Flag runtime OCR (bool)
|
||||
enable_ui_detection: Flag runtime détection UI (bool)
|
||||
workflow_id: ID du workflow en cours (str, "" pour legacy)
|
||||
|
||||
Returns:
|
||||
Clé composite `{phash}|{ctx_hash}` où ctx_hash = md5(16)
|
||||
"""
|
||||
# Sérialisation déterministe ; `|` comme séparateur interne puisque hashé.
|
||||
ctx_repr = (
|
||||
f"{window_title or ''}\x1f"
|
||||
f"{app_name or ''}\x1f"
|
||||
f"{int(bool(enable_ocr))}\x1f"
|
||||
f"{int(bool(enable_ui_detection))}\x1f"
|
||||
f"{workflow_id or ''}"
|
||||
)
|
||||
ctx_hash = hashlib.md5(ctx_repr.encode("utf-8")).hexdigest()[:16]
|
||||
return f"{phash}|{ctx_hash}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Entry
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CacheEntry:
|
||||
state: ScreenState
|
||||
created_at: float
|
||||
phash: str # phash seul (utilisé par invalidate_if_changed)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cache
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ScreenStateCache:
|
||||
"""
|
||||
Cache de ScreenState avec TTL et clé composite context-aware.
|
||||
|
||||
Thread-safe. Utilise un lock interne pour les opérations get/set.
|
||||
"""
|
||||
|
||||
def __init__(self, ttl_seconds: float = 2.0, max_entries: int = 16):
|
||||
"""
|
||||
Args:
|
||||
ttl_seconds: Durée de vie d'une entrée (en secondes)
|
||||
max_entries: Nombre max d'entrées avant éviction LRU simple
|
||||
"""
|
||||
self.ttl_seconds = ttl_seconds
|
||||
self.max_entries = max_entries
|
||||
# Clé = composite (_make_cache_key), valeur = _CacheEntry
|
||||
self._store: dict[str, _CacheEntry] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Métriques simples (utile pour le debug / logs)
|
||||
self.hits = 0
|
||||
self.misses = 0
|
||||
self.invalidations = 0
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# API bas niveau (par clé composite)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _get(self, composite_key: str) -> Optional[ScreenState]:
|
||||
"""Retourne l'entrée pour cette clé composite si encore valide."""
|
||||
with self._lock:
|
||||
entry = self._store.get(composite_key)
|
||||
if entry is None:
|
||||
return None
|
||||
if time.time() - entry.created_at > self.ttl_seconds:
|
||||
# Expiré
|
||||
self._store.pop(composite_key, None)
|
||||
return None
|
||||
return entry.state
|
||||
|
||||
def _set(self, composite_key: str, phash: str, state: ScreenState) -> None:
|
||||
"""Enregistre un état pour cette clé composite."""
|
||||
with self._lock:
|
||||
# Éviction simple : si plein, virer l'entrée la plus ancienne
|
||||
if (
|
||||
len(self._store) >= self.max_entries
|
||||
and composite_key not in self._store
|
||||
):
|
||||
oldest_key = min(
|
||||
self._store, key=lambda k: self._store[k].created_at
|
||||
)
|
||||
self._store.pop(oldest_key, None)
|
||||
|
||||
self._store[composite_key] = _CacheEntry(
|
||||
state=state,
|
||||
created_at=time.time(),
|
||||
phash=phash,
|
||||
)
|
||||
|
||||
def invalidate(self, composite_key: Optional[str] = None) -> None:
|
||||
"""
|
||||
Invalider une entrée ou tout le cache.
|
||||
|
||||
Args:
|
||||
composite_key: Clé à invalider. Si None, vide tout le cache.
|
||||
"""
|
||||
with self._lock:
|
||||
if composite_key is None:
|
||||
self._store.clear()
|
||||
else:
|
||||
self._store.pop(composite_key, None)
|
||||
self.invalidations += 1
|
||||
|
||||
def invalidate_if_changed(
|
||||
self,
|
||||
screenshot_path: str,
|
||||
threshold: float = 0.3,
|
||||
) -> bool:
|
||||
"""
|
||||
Invalider le cache si l'écran a suffisamment changé.
|
||||
|
||||
Compare le dhash du screenshot courant avec le phash (seul) de chaque
|
||||
entrée du cache. La décision est volontairement indépendante du reste
|
||||
de la clé composite : un changement visuel majeur rend toutes les
|
||||
entrées obsolètes, quel que soit le contexte.
|
||||
|
||||
Args:
|
||||
screenshot_path: Chemin du screenshot courant
|
||||
threshold: Proportion de bits qui doivent différer (0.0-1.0).
|
||||
0.3 = 30% (~19 bits sur 64) = changement significatif.
|
||||
|
||||
Returns:
|
||||
True si le cache a été invalidé, False sinon.
|
||||
"""
|
||||
if not self._store:
|
||||
return False
|
||||
|
||||
current_phash = compute_perceptual_hash(screenshot_path)
|
||||
|
||||
# Bits totaux : 64 pour un dhash 8x8 standard. On déduit via la
|
||||
# longueur hexa du hash courant pour rester générique.
|
||||
total_bits = len(current_phash) * 4
|
||||
if total_bits == 0:
|
||||
return False
|
||||
|
||||
threshold_bits = threshold * total_bits
|
||||
|
||||
with self._lock:
|
||||
if not self._store:
|
||||
return False
|
||||
|
||||
# Distance de Hamming minimale avec les phashes des entrées
|
||||
# (on regarde entry.phash, pas la clé composite).
|
||||
min_distance = None
|
||||
for entry in self._store.values():
|
||||
distance = _hamming_distance_hex(current_phash, entry.phash)
|
||||
if min_distance is None or distance < min_distance:
|
||||
min_distance = distance
|
||||
|
||||
if min_distance is not None and min_distance > threshold_bits:
|
||||
size_before = len(self._store)
|
||||
self._store.clear()
|
||||
self.invalidations += 1
|
||||
logger.debug(
|
||||
f"[ScreenStateCache] invalidate_if_changed: "
|
||||
f"distance={min_distance}/{total_bits} > "
|
||||
f"threshold={threshold_bits:.1f} → {size_before} entrées purgées"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# API haut niveau (context-aware)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def get_or_compute(
|
||||
self,
|
||||
screenshot_path: str,
|
||||
compute_fn: Callable[[str], ScreenState],
|
||||
*,
|
||||
window_title: str = "",
|
||||
app_name: str = "",
|
||||
enable_ocr: bool = True,
|
||||
enable_ui_detection: bool = True,
|
||||
workflow_id: str = "",
|
||||
force_refresh: bool = False,
|
||||
) -> Tuple[ScreenState, bool, float]:
|
||||
"""
|
||||
Récupérer ou calculer le ScreenState pour un screenshot + contexte.
|
||||
|
||||
Clé de cache = composite(phash, window_title, app_name, enable_ocr,
|
||||
enable_ui_detection, workflow_id). Deux contextes différents partageant
|
||||
le même screenshot n'entrent PAS en collision.
|
||||
|
||||
Rétrocompatibilité : tous les kwargs de contexte ont une valeur par
|
||||
défaut. Un caller legacy qui n'a pas encore été adapté partagera la
|
||||
même entrée de cache qu'un autre caller legacy (comportement antérieur).
|
||||
|
||||
Args:
|
||||
screenshot_path: Chemin du screenshot
|
||||
compute_fn: Fonction qui construit un ScreenState si cache miss
|
||||
window_title: Titre de la fenêtre active (contexte visuel)
|
||||
app_name: Nom du process actif (contexte applicatif)
|
||||
enable_ocr: Flag runtime — différencie états avec/sans OCR
|
||||
enable_ui_detection: Flag runtime — différencie états avec/sans UI
|
||||
workflow_id: ID du workflow — isolation inter-workflows
|
||||
force_refresh: Ignorer le cache et recalculer
|
||||
|
||||
Returns:
|
||||
Tuple (state, cache_hit, elapsed_ms)
|
||||
"""
|
||||
t0 = time.time()
|
||||
phash = compute_perceptual_hash(screenshot_path)
|
||||
composite_key = _make_cache_key(
|
||||
phash=phash,
|
||||
window_title=window_title,
|
||||
app_name=app_name,
|
||||
enable_ocr=enable_ocr,
|
||||
enable_ui_detection=enable_ui_detection,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
|
||||
if not force_refresh:
|
||||
cached = self._get(composite_key)
|
||||
if cached is not None:
|
||||
self.hits += 1
|
||||
elapsed_ms = (time.time() - t0) * 1000
|
||||
logger.debug(
|
||||
f"[ScreenStateCache] HIT key={composite_key[:24]}… "
|
||||
f"({elapsed_ms:.1f}ms)"
|
||||
)
|
||||
return cached, True, elapsed_ms
|
||||
|
||||
# Cache miss → calcul complet
|
||||
self.misses += 1
|
||||
state = compute_fn(screenshot_path)
|
||||
self._set(composite_key, phash, state)
|
||||
elapsed_ms = (time.time() - t0) * 1000
|
||||
logger.debug(
|
||||
f"[ScreenStateCache] MISS key={composite_key[:24]}… "
|
||||
f"({elapsed_ms:.1f}ms)"
|
||||
)
|
||||
return state, False, elapsed_ms
|
||||
|
||||
def stats(self) -> dict:
|
||||
"""Retourne les métriques du cache."""
|
||||
with self._lock:
|
||||
total = self.hits + self.misses
|
||||
return {
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"invalidations": self.invalidations,
|
||||
"hit_rate": self.hits / total if total > 0 else 0.0,
|
||||
"size": len(self._store),
|
||||
"max_entries": self.max_entries,
|
||||
"ttl_seconds": self.ttl_seconds,
|
||||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._store)
|
||||
Reference in New Issue
Block a user