Compare commits
137 Commits
203dc00d53
...
backup/pre
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5543e25f9d | ||
|
|
2a07d8084b | ||
|
|
35b27ae492 | ||
|
|
b584bbabc3 | ||
|
|
8817f527e7 | ||
|
|
964856ab30 | ||
|
|
a67d896104 | ||
|
|
90c1d8036f | ||
|
|
6261002039 | ||
|
|
0e6e61f2b1 | ||
|
|
41c1250c99 | ||
|
|
2af3bc3b93 | ||
|
|
6154423a91 | ||
|
|
41eba898c0 | ||
|
|
9452e86fd1 | ||
|
|
5e31cdf666 | ||
|
|
487bcb8618 | ||
|
|
3d6868f029 | ||
|
|
f73a2a59a9 | ||
|
|
77faa03ec9 | ||
|
|
343d6fbe95 | ||
|
|
cc64439738 | ||
|
|
90007cc7c1 | ||
|
|
73cea2385e | ||
|
|
e2046837cf | ||
|
|
b30d4b6656 | ||
|
|
e4a48e78bf | ||
|
|
ea36bba5cc | ||
|
|
9da589c8c2 | ||
|
|
16ff396dbf | ||
|
|
e44fd7b328 | ||
|
|
66815b7a1a | ||
|
|
c6b695eca8 | ||
|
|
99d2083dea | ||
|
|
a718086140 | ||
|
|
c82979e72b | ||
|
|
2185c41cc1 | ||
|
|
26804eb123 | ||
|
|
d71d5df4a8 | ||
|
|
6829ad8e79 | ||
|
|
8903f35433 | ||
|
|
4ab2c15e5c | ||
|
|
eba6fea779 | ||
|
|
f04398d5a7 | ||
|
|
4ce9c47f45 | ||
|
|
9dfcdb5fb0 | ||
|
|
3efe15d2c7 | ||
|
|
9d87ed64c5 | ||
|
|
00134963e5 | ||
|
|
0ec5e2a25b | ||
|
|
0c5fffe951 | ||
|
|
5027ed9a23 | ||
|
|
6caab2c600 | ||
|
|
552e66dbf6 | ||
|
|
de1026ee2e | ||
|
|
7b50725bf8 | ||
|
|
7feef3b6a9 | ||
|
|
0b06db222d | ||
|
|
74ee0dadee | ||
|
|
0b452f975a | ||
|
|
6ab385d671 | ||
|
|
b3eab83a0f | ||
|
|
27490849a8 | ||
|
|
cebbf0809a | ||
|
|
3e227d28ad | ||
|
|
8ce63fcba2 | ||
|
|
4202431421 | ||
|
|
4923623dd4 | ||
|
|
84181cc982 | ||
|
|
7355d315a3 | ||
|
|
c50adab3a1 | ||
|
|
2fbb305f65 | ||
|
|
ff581be397 | ||
|
|
203e5cc6c1 | ||
|
|
d1b556b6cd | ||
|
|
729cd67743 | ||
|
|
73ddcdb29d | ||
|
|
14a9442343 | ||
|
|
5da4581e76 | ||
|
|
cbe8dc95d2 | ||
|
|
04a14a56b2 | ||
|
|
2290f1846b | ||
|
|
c57b40ae1d | ||
|
|
bc21b27da7 | ||
|
|
6a2248ddcd | ||
|
|
82d7b38cff | ||
|
|
6c7f88c05d | ||
|
|
447fbb2c6e | ||
|
|
623be15bfe | ||
|
|
55d5aebbd2 | ||
|
|
73b731fef8 | ||
|
|
ffd97ae9a5 | ||
|
|
d168833609 | ||
|
|
23a06a744c | ||
|
|
af4eae28b9 | ||
|
|
c198c930a1 | ||
|
|
e3efef2fe7 | ||
|
|
95fddeebb3 | ||
|
|
71523cebd3 | ||
|
|
3aa806a630 | ||
|
|
588c8f22c1 | ||
|
|
3d243d731d | ||
|
|
2431a6c9e9 | ||
|
|
969236da03 | ||
|
|
f30461b88c | ||
|
|
f34eca20f9 | ||
|
|
309dfd5287 | ||
|
|
f5a672d7b9 | ||
|
|
1acea85fa6 | ||
|
|
4f61741420 | ||
|
|
2fa864b5c7 | ||
|
|
10739c33fa | ||
|
|
39bea1b042 | ||
|
|
26b4e6d8ce | ||
|
|
4fb84b1090 | ||
|
|
7f2bc6fe97 | ||
|
|
eded968c70 | ||
|
|
53d29d9b24 | ||
|
|
690053bd57 | ||
|
|
c7b0649716 | ||
|
|
2bfcfa4535 | ||
|
|
b808e48b1f | ||
|
|
78ee962918 | ||
|
|
c8a3618e27 | ||
|
|
9ca277a63f | ||
|
|
8c7b6e5696 | ||
|
|
af4ffa189a | ||
|
|
42f571d496 | ||
|
|
36737cfe9d | ||
|
|
93ef93e563 | ||
|
|
376e4a88b3 | ||
|
|
bb4ed2a75d | ||
|
|
f7b8cddd2b | ||
|
|
a9a99953dd | ||
|
|
aee64f54b1 | ||
|
|
c77844fa9a | ||
|
|
013fe071a2 |
12
.env.example
12
.env.example
@@ -30,7 +30,9 @@ DASHBOARD_PORT=5001
|
||||
CLIP_MODEL=ViT-B-32
|
||||
CLIP_PRETRAINED=openai
|
||||
CLIP_DEVICE=cpu # cpu or cuda
|
||||
VLM_MODEL=qwen3-vl:8b
|
||||
RPA_VLM_MODEL=gemma4:latest # gemma4:latest (défaut), qwen3-vl:8b, ui-tars (fallback)
|
||||
VLM_MODEL=gemma4:latest # alias de compatibilité
|
||||
# VLM_ALLOW_CLOUD=false # true pour activer les APIs cloud en fallback (OpenAI, Gemini, Anthropic)
|
||||
VLM_ENDPOINT=http://localhost:11434
|
||||
OWL_MODEL=google/owlv2-base-patch16-ensemble
|
||||
OWL_CONFIDENCE_THRESHOLD=0.1
|
||||
@@ -44,6 +46,14 @@ LOGS_PATH=logs
|
||||
UPLOADS_PATH=data/training/uploads
|
||||
SESSIONS_PATH=data/training/sessions
|
||||
|
||||
# ============================================================================
|
||||
# Feedback Bus (Léa parle pendant exécution)
|
||||
# ============================================================================
|
||||
# Bus SocketIO unifié 'lea:*' (action_started, action_done, need_confirm, paused).
|
||||
# Désactivé par défaut. Mettre à 1 pour activer les bulles temps réel dans ChatWindow.
|
||||
# Si la connexion bus échoue, l'exécution continue normalement (fail-safe).
|
||||
LEA_FEEDBACK_BUS=0
|
||||
|
||||
# ============================================================================
|
||||
# FAISS
|
||||
# ============================================================================
|
||||
|
||||
207
.gitea/workflows/security-audit.yml
Normal file
207
.gitea/workflows/security-audit.yml
Normal file
@@ -0,0 +1,207 @@
|
||||
# ------------------------------------------------------------------
|
||||
# Audit sécurité — bandit + pip-audit + scan secrets
|
||||
# ------------------------------------------------------------------
|
||||
# Jamais bloquant : on reporte les warnings, on ne casse pas la CI.
|
||||
# Utile pour détecter les dérives progressives (nouveaux CVE, secrets
|
||||
# oubliés dans un commit, patterns risqués).
|
||||
#
|
||||
# Fréquence : à chaque push sur main + hebdo (cron).
|
||||
# ------------------------------------------------------------------
|
||||
name: security-audit
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
schedule:
|
||||
# Tous les lundis à 6h UTC (7h Paris hiver, 8h Paris été).
|
||||
- cron: "0 6 * * 1"
|
||||
workflow_dispatch: {}
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
# ----------------------------------------------------------------
|
||||
# Job 1 — bandit (bonnes pratiques sécu Python)
|
||||
# ----------------------------------------------------------------
|
||||
bandit:
|
||||
name: Bandit (scan statique)
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
continue-on-error: true
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
cache: "pip"
|
||||
|
||||
- name: Installation bandit
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install "bandit[toml]==1.7.10"
|
||||
|
||||
- name: Scan bandit sur core/
|
||||
run: |
|
||||
# Pas d'option -l : sévérité LOW minimum par défaut (remonte tout ; -ll limiterait à MEDIUM+)
|
||||
# Pas d'option -i : confiance LOW minimum par défaut (-ii limiterait à MEDIUM+)
|
||||
# --skip B101,B404,B603 : on ignore les asserts (B101, usuels en tests/validation) et les alertes génériques subprocess (B404 import, B603 appel sans shell)
|
||||
bandit -r core/ \
|
||||
--skip B101,B404,B603 \
|
||||
--format txt \
|
||||
--exit-zero \
|
||||
--output bandit-report.txt
|
||||
echo "=== RAPPORT BANDIT ==="
|
||||
cat bandit-report.txt
|
||||
|
||||
- name: Upload rapport bandit
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: bandit-report
|
||||
path: bandit-report.txt
|
||||
retention-days: 30
|
||||
if-no-files-found: ignore
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Job 2 — pip-audit (CVE sur requirements)
|
||||
# ----------------------------------------------------------------
|
||||
pip-audit:
|
||||
name: pip-audit (CVE dépendances)
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
continue-on-error: true
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
cache: "pip"
|
||||
|
||||
- name: Installation pip-audit
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install "pip-audit==2.7.3"
|
||||
|
||||
- name: Audit CVE sur requirements-ci.txt
|
||||
run: |
|
||||
if [ -f requirements-ci.txt ]; then
|
||||
pip-audit -r requirements-ci.txt \
|
||||
--format json \
|
||||
--output pip-audit-ci.json \
|
||||
--progress-spinner off \
|
||||
--disable-pip || echo "::warning::CVE détectées dans requirements-ci.txt"
|
||||
echo "=== RAPPORT pip-audit (CI) ==="
|
||||
cat pip-audit-ci.json || true
|
||||
else
|
||||
echo "::notice::requirements-ci.txt absent — skip"
|
||||
fi
|
||||
|
||||
- name: Audit CVE sur requirements.txt (best-effort)
|
||||
run: |
|
||||
# Timeout généreux car requirements.txt est massif (torch, CUDA).
|
||||
timeout 120 pip-audit -r requirements.txt \
|
||||
--format json \
|
||||
--output pip-audit-full.json \
|
||||
--progress-spinner off \
|
||||
--disable-pip 2>&1 | head -200 || \
|
||||
echo "::warning::pip-audit sur requirements.txt a timeout ou échoué (non bloquant)"
|
||||
|
||||
- name: Upload rapports pip-audit
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: pip-audit-reports
|
||||
path: |
|
||||
pip-audit-ci.json
|
||||
pip-audit-full.json
|
||||
retention-days: 30
|
||||
if-no-files-found: ignore
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Job 3 — Scan secrets en clair (grep simple)
|
||||
# ----------------------------------------------------------------
|
||||
# Patterns recherchés : clés API Anthropic (sk-ant-), OpenAI (sk-),
|
||||
# Google (AIzaSy), AWS (AKIA), tokens Hugging Face (hf_).
|
||||
# Ne cherche QUE dans les fichiers trackés (pas .env, pas .venv).
|
||||
# ----------------------------------------------------------------
|
||||
secrets-scan:
|
||||
name: Scan secrets (grep)
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 3
|
||||
continue-on-error: true
|
||||
|
||||
steps:
|
||||
- name: Checkout (historique complet)
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Scan patterns de secrets
|
||||
run: |
|
||||
# Chemins exclus : venvs, caches, data, htmlcov, models.
|
||||
EXCLUDES='--exclude-dir=.venv --exclude-dir=venv_v3 --exclude-dir=.git \
|
||||
--exclude-dir=node_modules --exclude-dir=htmlcov --exclude-dir=models \
|
||||
--exclude-dir=data --exclude-dir=__pycache__ --exclude-dir=.pytest_cache \
|
||||
--exclude=*.lock --exclude=*.log --exclude=*.md'
|
||||
|
||||
echo "=== Recherche de secrets potentiels ==="
|
||||
FOUND=0
|
||||
|
||||
# Anthropic
|
||||
if grep -rnI $EXCLUDES -E 'sk-ant-[a-zA-Z0-9_-]{20,}' . 2>/dev/null; then
|
||||
echo "::warning::Clé Anthropic potentielle détectée"
|
||||
FOUND=1
|
||||
fi
|
||||
|
||||
# OpenAI
|
||||
if grep -rnI $EXCLUDES -E 'sk-proj-[a-zA-Z0-9_-]{20,}|sk-[a-zA-Z0-9]{40,}' . 2>/dev/null; then
|
||||
echo "::warning::Clé OpenAI potentielle détectée"
|
||||
FOUND=1
|
||||
fi
|
||||
|
||||
# Google Cloud / API Keys
|
||||
if grep -rnI $EXCLUDES -E 'AIzaSy[a-zA-Z0-9_-]{33}' . 2>/dev/null; then
|
||||
echo "::warning::Clé Google API potentielle détectée"
|
||||
FOUND=1
|
||||
fi
|
||||
|
||||
# AWS
|
||||
if grep -rnI $EXCLUDES -E 'AKIA[0-9A-Z]{16}' . 2>/dev/null; then
|
||||
echo "::warning::Clé AWS potentielle détectée"
|
||||
FOUND=1
|
||||
fi
|
||||
|
||||
# Hugging Face
|
||||
if grep -rnI $EXCLUDES -E 'hf_[a-zA-Z0-9]{30,}' . 2>/dev/null; then
|
||||
echo "::warning::Token Hugging Face potentiel détecté"
|
||||
FOUND=1
|
||||
fi
|
||||
|
||||
# Mots-clés suspects à côté d'assignations
|
||||
if grep -rnI $EXCLUDES -E '(password|passwd|secret|api_key|apikey|token)\s*=\s*["\x27][a-zA-Z0-9_\-!@#\$%]{12,}["\x27]' . 2>/dev/null \
|
||||
| grep -viE '(example|dummy|placeholder|test|fake|xxx|changeme|\$\{)' 2>/dev/null; then
|
||||
echo "::warning::Assignation suspecte d'un secret détectée"
|
||||
FOUND=1
|
||||
fi
|
||||
|
||||
if [ "$FOUND" -eq 0 ]; then
|
||||
echo "Aucun secret détecté par les patterns de base."
|
||||
else
|
||||
echo ""
|
||||
echo "::notice::Vérifier manuellement les occurrences ci-dessus."
|
||||
echo "::notice::Si faux positif : ajouter le fichier aux exclusions ou reformater."
|
||||
fi
|
||||
|
||||
# Toujours succès (job non bloquant).
|
||||
exit 0
|
||||
214
.gitea/workflows/tests.yml
Normal file
214
.gitea/workflows/tests.yml
Normal file
@@ -0,0 +1,214 @@
|
||||
# ------------------------------------------------------------------
|
||||
# CI principale — Tests unitaires + lint léger
|
||||
# ------------------------------------------------------------------
|
||||
# Déclenchement : push / pull_request sur n'importe quelle branche.
|
||||
# Objectif : feedback rapide (< 3 min) sans GPU ni Ollama.
|
||||
# Runner : self-hosted (label "ubuntu-latest" ou équivalent).
|
||||
#
|
||||
# Les tests marqués `slow`, `gpu`, `integration`, `performance`,
|
||||
# `visual` et `smoke` sont exclus volontairement — ils nécessitent
|
||||
# CUDA, Ollama, ou des captures d'écran réelles.
|
||||
# ------------------------------------------------------------------
|
||||
name: tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "**"
|
||||
pull_request:
|
||||
branches:
|
||||
- "**"
|
||||
|
||||
# Permet à une nouvelle exécution d'annuler les précédentes
|
||||
# sur la même branche (évite l'engorgement du runner local).
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
# Réglages Python/pip pour la CI : pas de fichiers .pyc, pas d'avertissements de version pip.
|
||||
PYTHONDONTWRITEBYTECODE: "1"
|
||||
PIP_DISABLE_PIP_VERSION_CHECK: "1"
|
||||
PIP_NO_PYTHON_VERSION_WARNING: "1"
|
||||
# Les modules d'exécution lisent parfois ces vars ; valeurs neutres en CI.
|
||||
RPA_VISION_CI: "1"
|
||||
RPA_AUTH_VAULT_PATH: "/tmp/ci_vault.enc"
|
||||
# api_stream.py a un fail-closed P0-C : si RPA_API_TOKEN absent, sys.exit(1)
|
||||
# au module load. On fournit un token bidon pour que les imports passent en CI.
|
||||
# (Le token n'est jamais utilisé réellement — les tests mockent les requêtes.)
|
||||
RPA_API_TOKEN: "ci_test_token_not_used_for_real_auth_just_to_pass_import_check_0123456789"
|
||||
|
||||
jobs:
|
||||
# ----------------------------------------------------------------
|
||||
# Job 1 — Lint (ruff + black --check)
|
||||
# ----------------------------------------------------------------
|
||||
# Non-bloquant : si ruff/black ne sont pas installables, on log
|
||||
# un warning et on continue. L'objectif ici est d'alerter, pas de
|
||||
# casser la CI pour des espaces en trop.
|
||||
# ----------------------------------------------------------------
|
||||
lint:
|
||||
name: Lint (ruff + black)
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
continue-on-error: true
|
||||
|
||||
steps:
|
||||
- name: Checkout du code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
cache: "pip"
|
||||
|
||||
- name: Installation des linters
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install "ruff==0.6.9" "black==23.12.1" || {
|
||||
echo "::warning::Impossible d'installer ruff/black — job ignoré"
|
||||
exit 0
|
||||
}
|
||||
|
||||
- name: Ruff (lint rapide)
|
||||
run: |
|
||||
if command -v ruff >/dev/null 2>&1; then
|
||||
# Ruff : erreurs critiques uniquement (E9 syntax, F63 invalid print,
|
||||
# F7 syntax, F82 undefined in __all__).
|
||||
# ATTENTION : F821 (undefined name) est couvert par le préfixe F82 ci-dessous ; pour l'exclure, ajouter --ignore=F821 le temps de nettoyer
|
||||
# la dette technique préexistante (voir docs/STATUS.md).
|
||||
# Dossiers legacy exclus :
|
||||
# - agent_v0/deploy/windows_client/ : clone obsolète (marqué OBSOLÈTE)
|
||||
# - tests/property/ : tests cassés connus (cf. MEMORY.md)
|
||||
ruff check --select=E9,F63,F7,F82 --output-format=github \
|
||||
--exclude "agent_v0/deploy/windows_client" \
|
||||
--exclude "tests/property" \
|
||||
--exclude "tests/integration/test_visual_rpa_checkpoint.py" \
|
||||
core/ agent_v0/ tests/ || {
|
||||
echo "::warning::Ruff a trouvé des erreurs critiques"
|
||||
exit 1
|
||||
}
|
||||
else
|
||||
echo "::warning::ruff indisponible — skip"
|
||||
fi
|
||||
|
||||
- name: Black (format check)
|
||||
run: |
|
||||
if command -v black >/dev/null 2>&1; then
|
||||
# --check : ne modifie pas, signale juste.
|
||||
# Dossiers legacy exclus (cohérent avec ruff).
|
||||
black --check --diff \
|
||||
--exclude "agent_v0/deploy/windows_client|tests/property" \
|
||||
core/ agent_v0/ tests/ || {
|
||||
echo "::warning::Black suggère un reformatage — non bloquant"
|
||||
exit 0
|
||||
}
|
||||
else
|
||||
echo "::warning::black indisponible — skip"
|
||||
fi
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Job 2 — Tests unitaires
|
||||
# ----------------------------------------------------------------
|
||||
# Exclut tous les marqueurs lourds. Utilise requirements-ci.txt
|
||||
# pour éviter torch/CUDA (économie ~3 Go + ~2 min).
|
||||
# ----------------------------------------------------------------
|
||||
unit-tests:
|
||||
name: Tests unitaires (sans GPU)
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
|
||||
steps:
|
||||
- name: Checkout du code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
requirements-ci.txt
|
||||
requirements.txt
|
||||
|
||||
- name: Installation des dépendances CI
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
if [ -f requirements-ci.txt ]; then
|
||||
echo "Utilisation de requirements-ci.txt (léger, sans torch)"
|
||||
pip install -r requirements-ci.txt
|
||||
else
|
||||
echo "::warning::requirements-ci.txt absent — fallback requirements.txt (lourd)"
|
||||
pip install -r requirements.txt
|
||||
fi
|
||||
|
||||
- name: Vérification imports critiques
|
||||
run: |
|
||||
python -c "import pytest; print(f'pytest {pytest.__version__}')"
|
||||
python -c "import sys; sys.path.insert(0, '.'); import core; print('core OK')" || {
|
||||
echo "::error::Impossible d'importer core.*"
|
||||
exit 1
|
||||
}
|
||||
|
||||
- name: Tests unitaires (hors slow/gpu/integration)
|
||||
run: |
|
||||
python -m pytest tests/unit/ \
|
||||
-m "not slow and not gpu and not integration and not performance and not visual" \
|
||||
--tb=short \
|
||||
--strict-markers \
|
||||
-q \
|
||||
--maxfail=10 \
|
||||
-o cache_dir=/tmp/.pytest_cache_ci
|
||||
|
||||
- name: Upload logs si échec
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: pytest-logs
|
||||
path: |
|
||||
/tmp/.pytest_cache_ci
|
||||
logs/
|
||||
retention-days: 3
|
||||
if-no-files-found: ignore
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Job 3 — Tests sécurité (bloquant)
|
||||
# ----------------------------------------------------------------
|
||||
# Les tests `test_security_*` valident des invariants critiques
|
||||
# (évaluation sûre, sérialisation signée). Aucune régression tolérée.
|
||||
# ----------------------------------------------------------------
|
||||
security-tests:
|
||||
name: Tests sécurité (critique)
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
needs: [unit-tests]
|
||||
|
||||
steps:
|
||||
- name: Checkout du code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
requirements-ci.txt
|
||||
requirements.txt
|
||||
|
||||
- name: Installation des dépendances CI
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
if [ -f requirements-ci.txt ]; then
|
||||
pip install -r requirements-ci.txt
|
||||
else
|
||||
pip install -r requirements.txt
|
||||
fi
|
||||
|
||||
- name: Tests sécurité (test_security_*)
|
||||
run: |
|
||||
python -m pytest tests/unit/test_security_*.py \
|
||||
--tb=long \
|
||||
--strict-markers \
|
||||
-v \
|
||||
-o cache_dir=/tmp/.pytest_cache_ci_sec
|
||||
28
.gitignore
vendored
28
.gitignore
vendored
@@ -83,3 +83,31 @@ backups/
|
||||
# === Legacy / Triage ===
|
||||
_a_trier/
|
||||
archives/
|
||||
|
||||
# === Claude Code — worktrees et données locales ===
|
||||
# Worktrees générés par la CLI Claude Code lors d'exécutions d'agents
|
||||
# parallèles. Peuvent atteindre plusieurs centaines de Mo chacun.
|
||||
# Ne jamais committer — gérer via `git worktree list` / `git worktree remove`.
|
||||
.claude/
|
||||
.kiro/
|
||||
.mcp.json
|
||||
.snapshots/
|
||||
|
||||
# === Données runtime (sessions, learning, buffer, config local) ===
|
||||
data/
|
||||
**/capture_library.json
|
||||
.hypothesis/
|
||||
.deps_installed
|
||||
# Buffers SQLite locaux (streamer, cache)
|
||||
**/buffer/
|
||||
**/pending_events.db
|
||||
# Databases applicatives (instance Flask)
|
||||
**/instance/*.db
|
||||
**/instance/*.sqlite
|
||||
**/instance/*.sqlite3
|
||||
# Caches et index locaux
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
*.db-journal
|
||||
*.db-wal
|
||||
*.db-shm
|
||||
|
||||
@@ -21,7 +21,12 @@ ollama serve
|
||||
### 3. Télécharger le modèle VLM
|
||||
|
||||
```bash
|
||||
ollama pull qwen3-vl:8b
|
||||
# Modèle par défaut du projet (voir .env.example)
|
||||
ollama pull gemma4:latest
|
||||
|
||||
# Alternatives supportées
|
||||
# ollama pull qwen3-vl:8b
|
||||
# ollama pull 0000/ui-tars-1.5-7b-q8_0:7b # grounder visuel
|
||||
```
|
||||
|
||||
## Utilisation
|
||||
|
||||
331
README.md
331
README.md
@@ -1,207 +1,204 @@
|
||||
# RPA Vision V3 - 100% Vision-Based Workflow Automation
|
||||
# RPA Vision V3 — Automatisation basée sur la compréhension visuelle des interfaces
|
||||
|
||||
## 📊 Status
|
||||
> ⚠️ **Projet en phase POC** — voir [`docs/STATUS.md`](docs/STATUS.md) pour l'état
|
||||
> réel par module. Certaines briques sont opérationnelles bout en bout,
|
||||
> d'autres sont en cours de stabilisation. Ce dépôt n'est pas production-ready.
|
||||
|
||||
🚀 **PRODUCTION-READY** - Phase 12 Complete (77% System Completion) ✅
|
||||
*Dernière mise à jour : 14 avril 2026*
|
||||
|
||||
**Latest Update**: 14 Décembre 2024
|
||||
- ✅ **10/13 Phases Complétées** - Système mature et fonctionnel
|
||||
- ✅ **Performance Exceptionnelle** - 500-6250x plus rapide que requis
|
||||
- ✅ **Architecture Entreprise** - 148k+ lignes, 19 modules, 6 specs complètes
|
||||
- ✅ **Innovations Techniques** - Self-healing, Multi-modal, GPU management
|
||||
- 📊 **Audit Complet** - [Rapport détaillé](AUDIT_COMPLET_SYSTEME_RPA_VISION_V3.md)
|
||||
## Intention
|
||||
|
||||
**Quick Test**: `bash test_clip.sh`
|
||||
Automatiser des workflows métier par **compréhension sémantique de l'écran**
|
||||
plutôt que par coordonnées de clic fixes. Le système observe l'utilisateur,
|
||||
reconstruit un graphe d'états de l'interface, et cherche à rejouer la
|
||||
procédure en reconnaissant visuellement les éléments cibles — y compris
|
||||
quand l'UI change légèrement.
|
||||
|
||||
## 🎯 Vision
|
||||
Terrain cible principal : postes hospitaliers (Citrix, applications métier
|
||||
web et desktop). Contrainte forte : **100 % local**, pas d'appel à un LLM
|
||||
cloud dans le pipeline par défaut.
|
||||
|
||||
RPA basé sur la **compréhension sémantique** des interfaces, pas sur des coordonnées de clics.
|
||||
|
||||
Le système apprend des workflows en observant l'utilisateur et les automatise de manière robuste grâce à une architecture en 5 couches.
|
||||
|
||||
## 🏗️ Architecture en 5 Couches
|
||||
## Architecture en couches
|
||||
|
||||
```
|
||||
RawSession (Couche 0)
|
||||
RawSession (couche 0) — capture événements + screenshots
|
||||
↓
|
||||
ScreenState (Couche 1) - 4 niveaux d'abstraction
|
||||
ScreenState (couche 1) — états d'écran à plusieurs niveaux d'abstraction
|
||||
↓
|
||||
UIElement Detection (Couche 2) - Types + Rôles sémantiques
|
||||
UIElement (couche 2) — détection sémantique (cascade OCR + templates + VLM)
|
||||
↓
|
||||
State Embedding (Couche 3) - Fusion multi-modale
|
||||
State Embedding (couche 3) — fusion multi-modale + index FAISS
|
||||
↓
|
||||
Workflow Graph (Couche 4) - Nodes + Edges + Learning States
|
||||
Workflow Graph (couche 4) — nœuds, transitions, résolution de cibles
|
||||
```
|
||||
|
||||
## 📁 Structure
|
||||
## État des fonctionnalités (synthèse)
|
||||
|
||||
```
|
||||
rpa_vision_v3/
|
||||
├── core/
|
||||
│ ├── models/ # Couches 0-4 : Structures de données
|
||||
│ ├── capture/ # Couche 0 : Capture événements + screenshots
|
||||
│ ├── detection/ # Couche 2 : Détection UI sémantique
|
||||
│ ├── embedding/ # Couche 3 : Fusion multi-modale + FAISS
|
||||
│ ├── graph/ # Couche 4 : Construction + Matching + Exécution
|
||||
│ └── persistence/ # Sauvegarde/Chargement
|
||||
├── data/
|
||||
│ ├── sessions/ # RawSessions
|
||||
│ ├── screen_states/ # ScreenStates
|
||||
│ ├── embeddings/ # Vecteurs .npy
|
||||
│ ├── faiss_index/ # Index FAISS
|
||||
│ └── workflows/ # Workflow Graphs
|
||||
└── tests/ # Tests unitaires + intégration
|
||||
```
|
||||
Le détail par module est dans [`docs/STATUS.md`](docs/STATUS.md).
|
||||
|
||||
## 🚀 Démarrage Rapide
|
||||
**Opérationnel**
|
||||
- Capture Windows (Agent V1) + streaming vers serveur Linux
|
||||
- Stockage des sessions brutes (screenshots + événements)
|
||||
- Streaming server FastAPI, sessions en mémoire
|
||||
- Build du package Windows (`deploy/build_package.sh`)
|
||||
|
||||
**Alpha (fonctionnel sur un cas de référence, encore peu généralisé)**
|
||||
- Détection UI par cascade VLM + OCR + templates
|
||||
- Construction de workflow graph depuis une session
|
||||
- Replay E2E supervisé — premier succès sur Notepad le 13 avril 2026
|
||||
- Mode apprentissage : pause et demande d'aide humaine quand la résolution échoue
|
||||
- Embeddings CLIP + index FAISS
|
||||
- Module auth (Fernet + TOTP), federation (LearningPack)
|
||||
- Web Dashboard, Agent Chat
|
||||
|
||||
**En cours**
|
||||
- Visual Workflow Builder (VWB) — bugs DB runtime connus
|
||||
- Self-healing / recovery global
|
||||
- Analytics / reporting
|
||||
- Worker de compilation sessions → ExecutionPlan
|
||||
- Tests E2E multi-applications
|
||||
|
||||
## Limitations connues
|
||||
|
||||
- Le pipeline de replay est validé sur un nombre très restreint d'applications.
|
||||
- `TargetMemoryStore` (apprentissage Phase 1) est câblé mais sa base reste
|
||||
vide tant qu'un replay complet n'a pas été cristallisé.
|
||||
- Certaines asymétries entre chemins stricts et legacy dans le serveur de
|
||||
streaming peuvent provoquer des arrêts au lieu de pauses d'apprentissage.
|
||||
- VWB n'est pas encore stable en écriture ; un outil dédié plus simple est
|
||||
envisagé.
|
||||
|
||||
## Démarrage
|
||||
|
||||
### Prérequis
|
||||
|
||||
- Python 3.10 à 3.12
|
||||
- [Ollama](https://ollama.ai) installé et démarré localement
|
||||
- Recommandé : GPU NVIDIA pour l'inférence VLM
|
||||
- Windows 10/11 uniquement pour le client Agent V1
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# 1. Installer Ollama
|
||||
curl -fsSL https://ollama.ai/install.sh | sh # Linux
|
||||
# ou
|
||||
brew install ollama # macOS
|
||||
|
||||
# 2. Démarrer Ollama
|
||||
ollama serve
|
||||
|
||||
# 3. Télécharger le modèle VLM
|
||||
ollama pull qwen3-vl:8b
|
||||
|
||||
# 4. Installer dépendances Python
|
||||
# 1) Cloner puis créer le venv
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 2) Démarrer Ollama et récupérer le modèle VLM par défaut
|
||||
ollama serve &
|
||||
ollama pull gemma4:latest # défaut du projet
|
||||
# Alternatives supportées :
|
||||
# ollama pull qwen3-vl:8b
|
||||
# ollama pull 0000/ui-tars-1.5-7b-q8_0:7b # grounder visuel
|
||||
|
||||
# 3) Copier et ajuster la configuration
|
||||
cp .env.example .env
|
||||
# éditer .env pour vérifier RPA_VLM_MODEL, VLM_ENDPOINT, ports, etc.
|
||||
```
|
||||
|
||||
### Test Rapide
|
||||
### Lancer les services
|
||||
|
||||
Tous les services sont pilotés par `svc.sh` (source de vérité des ports :
|
||||
`services.conf`).
|
||||
|
||||
```bash
|
||||
# Diagnostic système
|
||||
python3 rpa_vision_v3/examples/diagnostic_vlm.py
|
||||
|
||||
# Test de détection
|
||||
./rpa_vision_v3/test_quick.sh
|
||||
./svc.sh status # État de tous les services
|
||||
./svc.sh start # Tout démarrer
|
||||
./svc.sh start streaming # Streaming server uniquement (port 5005)
|
||||
./svc.sh restart api # Redémarrer l'API (port 8000)
|
||||
./svc.sh stop # Tout arrêter
|
||||
```
|
||||
|
||||
### Utilisation - Détection UI
|
||||
| Port | Service |
|
||||
|---|---|
|
||||
| 8000 | API Server (upload / traitement core) |
|
||||
| 5001 | Web Dashboard |
|
||||
| 5002 | VWB Backend (Flask) |
|
||||
| 5003 | Monitoring |
|
||||
| 5004 | Agent Chat |
|
||||
| 5005 | Streaming Server (Agent V1 → pipeline core) |
|
||||
| 5006 | Session Cleaner |
|
||||
| 5099 | Worker de compilation (optionnel) |
|
||||
| 3002 | VWB Frontend (Vite/React) |
|
||||
|
||||
```python
|
||||
from rpa_vision_v3.core.detection import create_detector
|
||||
### Client Windows (Agent V1)
|
||||
|
||||
# Créer le détecteur
|
||||
detector = create_detector()
|
||||
|
||||
# Détecter les éléments UI
|
||||
elements = detector.detect("screenshot.png")
|
||||
|
||||
# Utiliser les résultats
|
||||
for elem in elements:
|
||||
print(f"{elem.type:15s} | {elem.role:20s} | {elem.label}")
|
||||
```
|
||||
|
||||
### Utilisation - Workflow (Phase 4 - À venir)
|
||||
|
||||
```python
|
||||
from rpa_vision_v3.core.models import RawSession, ScreenState, Workflow
|
||||
from rpa_vision_v3.core.graph import GraphBuilder, NodeMatcher
|
||||
|
||||
# 1. Capturer une session
|
||||
session = RawSession(...)
|
||||
# ... capturer événements et screenshots
|
||||
|
||||
# 2. Construire workflow automatiquement
|
||||
builder = GraphBuilder(...)
|
||||
workflow = builder.build_from_session(session)
|
||||
|
||||
# 3. Matcher état actuel
|
||||
matcher = NodeMatcher(...)
|
||||
current_state = ScreenState(...)
|
||||
match = matcher.match(current_state, workflow)
|
||||
|
||||
# 4. Exécuter action
|
||||
if match:
|
||||
edge = workflow.get_outgoing_edges(match.node.node_id)[0]
|
||||
executor.execute_edge(edge, current_state)
|
||||
```
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
### Guides Principaux
|
||||
- **Quick Start** : `QUICK_START.md` - Démarrage rapide
|
||||
- **Prochaines Étapes** : `NEXT_STEPS.md` - Roadmap et Phase 4
|
||||
- **Phase 3 Complète** : `PHASE3_COMPLETE.md` - Résumé Phase 3
|
||||
|
||||
### Documentation Technique
|
||||
- **Spec complète** : `.kiro/specs/workflow-graph-implementation/`
|
||||
- **Architecture** : `docs/reference/ARCHITECTURE_VISION_COMPLETE.md`
|
||||
- **Détection Hybride** : `HYBRID_DETECTION_SUMMARY.md`
|
||||
- **Intégration Ollama** : `docs/OLLAMA_INTEGRATION.md`
|
||||
|
||||
## 🎓 Concepts Clés
|
||||
|
||||
### RPA 100% Vision
|
||||
|
||||
- ❌ Pas de coordonnées (x, y) fixes
|
||||
- ✅ Rôles sémantiques (primary_action, form_input, etc.)
|
||||
- ✅ Matching par similarité visuelle et textuelle
|
||||
- ✅ Robuste aux changements d'UI
|
||||
|
||||
### Apprentissage Progressif
|
||||
|
||||
```
|
||||
OBSERVATION (5+ exécutions)
|
||||
↓
|
||||
COACHING (10+ assistances, succès >90%)
|
||||
↓
|
||||
AUTO_CANDIDATE (20+ exécutions, succès >95%)
|
||||
↓
|
||||
AUTO_CONFIRMÉ (validation utilisateur)
|
||||
```
|
||||
|
||||
### State Embedding
|
||||
|
||||
Fusion multi-modale :
|
||||
- 50% Image (screenshot complet)
|
||||
- 30% Texte (texte détecté)
|
||||
- 10% Titre (fenêtre)
|
||||
- 10% UI (éléments détectés)
|
||||
|
||||
## 🧪 Tests
|
||||
Le client capture souris, clavier et écran sur le poste Windows et envoie
|
||||
les données au streaming server Linux.
|
||||
|
||||
```bash
|
||||
# Tests unitaires
|
||||
pytest tests/unit/
|
||||
|
||||
# Tests d'intégration
|
||||
pytest tests/integration/
|
||||
|
||||
# Tests de performance
|
||||
pytest tests/performance/ --benchmark-only
|
||||
# Build du package Windows depuis le repo Linux
|
||||
./deploy/build_package.sh
|
||||
# produit deploy/Lea_v<version>.zip
|
||||
```
|
||||
|
||||
## 📈 Roadmap - 77% Complété (10/13 Phases)
|
||||
Voir [`docs/DEV_SETUP.md`](docs/DEV_SETUP.md) pour la maintenance du dépôt
|
||||
(worktrees, build, services).
|
||||
|
||||
### ✅ **Phases Complétées**
|
||||
- [x] **Phase 1-2** : Fondations + Embeddings FAISS ✅
|
||||
- [x] **Phase 4-6** : Détection UI + Workflow Graphs + Action Execution ✅
|
||||
- [x] **Phase 7-8** : Learning System + Training System ✅
|
||||
- [x] **Phase 10-12** : GPU Management + Performance + Monitoring ✅
|
||||
## Arborescence du dépôt
|
||||
|
||||
### 🎯 **Phases Restantes**
|
||||
- [ ] **Phase 3** : Checkpoint Final (tests storage)
|
||||
- [ ] **Phase 9** : Visual Workflow Builder (90% → 100%)
|
||||
- [ ] **Phase 13** : Tests End-to-End + Documentation finale
|
||||
```
|
||||
rpa_vision_v3/
|
||||
├── agent_v0/ # Agent V1 (client Windows) + serveur de streaming
|
||||
│ ├── agent_v1/ # Source de l'agent (capture, UI tray, exécution)
|
||||
│ └── server_v1/ # FastAPI streaming + processeurs
|
||||
├── core/ # Pipeline core
|
||||
│ ├── detection/ # Cascade VLM + OCR + templates
|
||||
│ ├── embedding/ # CLIP + FAISS
|
||||
│ ├── graph/ # Construction / matching de workflow graphs
|
||||
│ ├── execution/ # Résolution de cibles, actions LLM
|
||||
│ ├── learning/ # TargetMemoryStore (apprentissage)
|
||||
│ ├── auth/ # Vault Fernet + TOTP
|
||||
│ └── federation/ # Export/import de LearningPacks
|
||||
├── visual_workflow_builder/ # VWB (backend Flask + frontend React Vite)
|
||||
├── web_dashboard/ # Dashboard Flask + SocketIO
|
||||
├── agent_chat/ # Interface conversationnelle + planner
|
||||
├── deploy/ # Scripts de build et unités systemd
|
||||
├── data/ # Sessions, embeddings, index FAISS, apprentissage
|
||||
├── docs/ # Documentation technique
|
||||
├── tests/ # pytest (unit, integration, e2e)
|
||||
├── services.conf # Source de vérité des ports
|
||||
├── svc.sh # Orchestrateur des services
|
||||
└── run.sh # Démarrage tout-en-un (legacy, préférer svc.sh)
|
||||
```
|
||||
|
||||
### 🚀 **Composants Production-Ready**
|
||||
- **Agent V0** : Capture cross-platform + Encryption ✅
|
||||
- **Server API** : Processing pipeline + Web dashboard ✅
|
||||
- **Analytics System** : Monitoring + Insights + Reporting ✅
|
||||
- **Self-Healing** : Automatic adaptation + Recovery ✅
|
||||
## Tests
|
||||
|
||||
## 🤝 Contribution
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
|
||||
Voir `.kiro/specs/workflow-graph-implementation/tasks.md` pour les tâches en cours.
|
||||
# Tests rapides (hors marqueur slow)
|
||||
pytest -m "not slow" -q
|
||||
|
||||
## 📄 Licence
|
||||
# Tests d'intégration (streaming, pipeline)
|
||||
pytest tests/integration/ -q
|
||||
|
||||
Propriétaire - Tous droits réservés
|
||||
# Tests E2E
|
||||
pytest tests/test_pipeline_e2e.py -q
|
||||
```
|
||||
|
||||
Quelques tests legacy sont connus comme cassés — voir la mémoire projet et
|
||||
`docs/` pour la liste.
|
||||
|
||||
## Documentation
|
||||
|
||||
- [`docs/STATUS.md`](docs/STATUS.md) — état réel par module
|
||||
- [`docs/DEV_SETUP.md`](docs/DEV_SETUP.md) — tâches d'administration (worktrees, build)
|
||||
- [`docs/EXECUTION_LOOP_FLAGS.md`](docs/EXECUTION_LOOP_FLAGS.md) — flags C1 vision-aware (`enable_ui_detection`, `enable_ocr`, `analyze_timeout_ms`, `window_info_provider`)
|
||||
- [`docs/VISION_RPA_INTELLIGENT.md`](docs/VISION_RPA_INTELLIGENT.md) — cahier des charges
|
||||
- [`docs/PLAN_ACTEUR_V1.md`](docs/PLAN_ACTEUR_V1.md) — architecture 3 niveaux (Macro / Méso / Micro)
|
||||
- [`docs/CONFORMITE_AI_ACT.md`](docs/CONFORMITE_AI_ACT.md) — journalisation, floutage, rétention
|
||||
|
||||
## Concepts clés
|
||||
|
||||
- **RPA 100 % vision** : pas de coordonnées fixes ; l'agent localise un
|
||||
élément par ce qu'il voit (label + contexte visuel), pas par `x,y`.
|
||||
- **Apprentissage progressif** : mode shadow → assisté → autonome, validé
|
||||
par supervision humaine sur les échecs.
|
||||
- **LLM 100 % local** : Ollama sur la machine. Aucun appel cloud dans le
|
||||
pipeline par défaut (cf. feedback projet `feedback_local_only.md`).
|
||||
|
||||
## Licence
|
||||
|
||||
Propriétaire — tous droits réservés.
|
||||
|
||||
@@ -125,18 +125,19 @@ class WorkflowPipelineEnhanced:
|
||||
current_node_id = match_result["node_id"]
|
||||
logger.info(f"Matched current state to node: {current_node_id} (confidence: {match_result['confidence']:.3f})")
|
||||
|
||||
# 2. Obtenir la prochaine action
|
||||
# 2. Obtenir la prochaine action (contrat dict avec status explicite)
|
||||
action_info = self.get_next_action(workflow_id, current_node_id)
|
||||
action_status = action_info.get("status")
|
||||
|
||||
if not action_info:
|
||||
# Workflow terminé
|
||||
if action_status == "terminal":
|
||||
# Workflow terminé (aucun outgoing_edge = fin légitime)
|
||||
performance_metrics.total_execution_time_ms = (datetime.now() - start_time).total_seconds() * 1000
|
||||
|
||||
result = WorkflowExecutionResult.workflow_complete(
|
||||
execution_id=execution_id,
|
||||
workflow_id=workflow_id,
|
||||
current_node=current_node_id,
|
||||
performance_metrics=performance_metrics
|
||||
performance_metrics=performance_metrics,
|
||||
)
|
||||
result.correlation_id = correlation_id
|
||||
result.match_result = match_result
|
||||
@@ -144,6 +145,27 @@ class WorkflowPipelineEnhanced:
|
||||
logger.info(f"Workflow {workflow_id} completed at node {current_node_id}")
|
||||
return result
|
||||
|
||||
if action_status == "blocked":
|
||||
# Des edges existent mais aucun ne passe les filtres :
|
||||
# c'est un blocage, pas une fin de workflow.
|
||||
performance_metrics.total_execution_time_ms = (datetime.now() - start_time).total_seconds() * 1000
|
||||
|
||||
result = WorkflowExecutionResult.error(
|
||||
execution_id=execution_id,
|
||||
workflow_id=workflow_id,
|
||||
error_message=f"No valid edge: {action_info.get('reason', 'unknown')}",
|
||||
step_type="action_selection",
|
||||
current_node=current_node_id,
|
||||
performance_metrics=performance_metrics,
|
||||
)
|
||||
result.correlation_id = correlation_id
|
||||
|
||||
logger.warning(
|
||||
f"Workflow {workflow_id} blocked at node {current_node_id}: "
|
||||
f"{action_info.get('reason')}"
|
||||
)
|
||||
return result
|
||||
|
||||
logger.info(f"Next action: {action_info['action']['type']} -> {action_info['target_node']}")
|
||||
|
||||
# 3. Charger le workflow pour obtenir l'edge complet
|
||||
@@ -14,8 +14,9 @@ import asyncio
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
import pickle
|
||||
import gzip
|
||||
import pickle # noqa: S403 - usage legacy restreint au fallback de migration
|
||||
import io
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
@@ -24,6 +25,12 @@ import numpy as np
|
||||
|
||||
from core.visual.visual_target_manager import VisualTarget, VisualTargetManager
|
||||
from core.visual.screenshot_validation_manager import ScreenshotValidationManager, ValidationResult
|
||||
from core.security.signed_serializer import (
|
||||
SignatureVerificationError,
|
||||
UnsupportedFormatError,
|
||||
dumps_signed,
|
||||
loads_signed,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -435,7 +442,7 @@ class VisualPersistenceManager:
|
||||
return None
|
||||
|
||||
async def _serialize_workflow_data(self, workflow_data: VisualWorkflowData) -> bytes:
|
||||
"""Sérialise les données d'un workflow"""
|
||||
"""Sérialise les données d'un workflow en JSON signé HMAC."""
|
||||
# Convertir en dictionnaire
|
||||
data_dict = asdict(workflow_data)
|
||||
|
||||
@@ -456,13 +463,28 @@ class VisualPersistenceManager:
|
||||
]
|
||||
data_dict['validation_history'] = serialized_history
|
||||
|
||||
# Convertir en bytes
|
||||
return pickle.dumps(data_dict)
|
||||
# JSON signé HMAC (cf. core.security.signed_serializer)
|
||||
return dumps_signed(data_dict)
|
||||
|
||||
async def _deserialize_workflow_data(self, data: bytes) -> VisualWorkflowData:
|
||||
"""Désérialise les données d'un workflow"""
|
||||
# Désérialiser le dictionnaire
|
||||
data_dict = pickle.loads(data)
|
||||
"""Désérialise les données d'un workflow (JSON signé HMAC ;
|
||||
fallback pickle legacy avec WARNING pour migrer les anciens fichiers)."""
|
||||
try:
|
||||
data_dict = loads_signed(data)
|
||||
except SignatureVerificationError:
|
||||
# Fichier altéré ou clé différente : on refuse sans fallback.
|
||||
logger.error("Workflow visuel : signature HMAC invalide — refus.")
|
||||
raise
|
||||
except UnsupportedFormatError:
|
||||
# Ancien format pickle : fallback explicite et bruyant.
|
||||
import os
|
||||
if os.getenv("RPA_ALLOW_PICKLE_FALLBACK", "1") == "0":
|
||||
raise
|
||||
logger.warning(
|
||||
"Workflow visuel au format pickle legacy — lecture de compat, "
|
||||
"ré-écrire en JSON signé dès que possible."
|
||||
)
|
||||
data_dict = pickle.loads(data) # noqa: S301 - fallback legacy
|
||||
|
||||
# Reconstruire les objets
|
||||
workflow_data = VisualWorkflowData(
|
||||
@@ -133,6 +133,28 @@ def _streaming_headers() -> dict:
|
||||
headers["Authorization"] = f"Bearer {_STREAMING_API_TOKEN}"
|
||||
return headers
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Feedback Bus — events 'lea:*' temps réel vers ChatWindow
|
||||
# ============================================================
|
||||
LEA_FEEDBACK_BUS = os.environ.get("LEA_FEEDBACK_BUS", "0").lower() in ("1", "true", "yes", "on")
|
||||
|
||||
|
||||
def _emit_lea(event: str, payload: Dict[str, Any]) -> None:
|
||||
"""Émet 'lea:{event}' sur le bus SocketIO. No-op silencieux si flag off ou erreur."""
|
||||
if not LEA_FEEDBACK_BUS:
|
||||
return
|
||||
try:
|
||||
socketio.emit(f"lea:{event}", payload)
|
||||
except Exception:
|
||||
logger.debug("_emit_lea silenced", exc_info=True)
|
||||
|
||||
|
||||
def _emit_dual(legacy_event: str, lea_event: str, payload: Dict[str, Any], **kwargs) -> None:
|
||||
"""Émet l'event legacy (compat dashboard) ET l'alias lea:* (ChatWindow tkinter)."""
|
||||
socketio.emit(legacy_event, payload, **kwargs)
|
||||
_emit_lea(lea_event, payload)
|
||||
|
||||
execution_status = {
|
||||
"running": False,
|
||||
"workflow": None,
|
||||
@@ -623,7 +645,7 @@ def api_execute():
|
||||
}
|
||||
|
||||
# Notifier via WebSocket
|
||||
socketio.emit('execution_started', {
|
||||
_emit_dual('execution_started', 'action_started', {
|
||||
"workflow": match.workflow_name,
|
||||
"params": all_params
|
||||
})
|
||||
@@ -1181,28 +1203,28 @@ def _execute_gesture(gesture):
|
||||
)
|
||||
|
||||
if resp.status_code == 200:
|
||||
socketio.emit('execution_completed', {
|
||||
_emit_dual('execution_completed', 'done', {
|
||||
"workflow": gesture.name,
|
||||
"success": True,
|
||||
"message": f"Geste '{gesture.name}' ({'+'.join(gesture.keys)}) envoyé",
|
||||
})
|
||||
else:
|
||||
error = resp.text[:200]
|
||||
socketio.emit('execution_completed', {
|
||||
_emit_dual('execution_completed', 'done', {
|
||||
"workflow": gesture.name,
|
||||
"success": False,
|
||||
"message": f"Erreur: {error}",
|
||||
})
|
||||
|
||||
except http_requests.ConnectionError:
|
||||
socketio.emit('execution_completed', {
|
||||
_emit_dual('execution_completed', 'done', {
|
||||
"workflow": gesture.name,
|
||||
"success": False,
|
||||
"message": "Serveur de streaming non disponible (port 5005).",
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Gesture execution error: {e}")
|
||||
socketio.emit('execution_completed', {
|
||||
_emit_dual('execution_completed', 'done', {
|
||||
"workflow": gesture.name,
|
||||
"success": False,
|
||||
"message": f"Erreur: {str(e)}",
|
||||
@@ -1661,6 +1683,52 @@ def handle_copilot_abort():
|
||||
})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Bulle paused_need_help — handlers SocketIO depuis ChatWindow (J3.5)
|
||||
# =============================================================================
|
||||
|
||||
@socketio.on('lea:replay_resume')
|
||||
def handle_lea_replay_resume(data):
|
||||
"""Bouton Continuer : relayer le resume vers le streaming server."""
|
||||
replay_id = (data or {}).get("replay_id")
|
||||
if not replay_id:
|
||||
_emit_lea("resume_acked", {"status": "error", "detail": "replay_id manquant"})
|
||||
return
|
||||
try:
|
||||
resp = http_requests.post(
|
||||
f"{STREAMING_SERVER_URL}/api/v1/traces/stream/replay/{replay_id}/resume",
|
||||
headers=_streaming_headers(),
|
||||
timeout=5,
|
||||
)
|
||||
if resp.ok:
|
||||
logger.info(f"Replay {replay_id} resume relayé OK")
|
||||
_emit_lea("resume_acked", {"replay_id": replay_id, "status": "ok"})
|
||||
else:
|
||||
detail = resp.text[:200]
|
||||
logger.warning(f"Resume échoué (HTTP {resp.status_code}): {detail}")
|
||||
_emit_lea("resume_acked", {
|
||||
"replay_id": replay_id, "status": "error",
|
||||
"http_status": resp.status_code, "detail": detail,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Resume relay error: {e}")
|
||||
_emit_lea("resume_acked", {
|
||||
"replay_id": replay_id, "status": "error", "detail": str(e),
|
||||
})
|
||||
|
||||
|
||||
@socketio.on('lea:replay_abort')
|
||||
def handle_lea_replay_abort(data):
|
||||
"""Bouton Annuler : arrêter le polling local. Le replay côté streaming sera
|
||||
cleaned up naturellement au prochain replay (cf api_stream._replay_states stale)."""
|
||||
global execution_status
|
||||
replay_id = (data or {}).get("replay_id")
|
||||
execution_status["running"] = False
|
||||
execution_status["message"] = "Annulé par l'utilisateur"
|
||||
logger.info(f"Replay {replay_id or '?'} abort par l'utilisateur (paused bubble)")
|
||||
_emit_lea("abort_acked", {"replay_id": replay_id, "status": "ok"})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Exécution de workflow
|
||||
# =============================================================================
|
||||
@@ -1730,14 +1798,20 @@ def _poll_replay_progress(replay_id: str, workflow_name: str, total_actions: int
|
||||
"""Suivre la progression d'un replay distant via polling."""
|
||||
import time
|
||||
|
||||
max_wait = 120 # 2 minutes max
|
||||
max_wait_running = 120 # 2 min en exécution active
|
||||
max_wait_paused = 600 # 10 min en pause supervisée (humain peut prendre son temps)
|
||||
poll_interval = 2.0
|
||||
elapsed = 0
|
||||
was_paused = False
|
||||
|
||||
while elapsed < max_wait and execution_status.get("running"):
|
||||
while execution_status.get("running"):
|
||||
time.sleep(poll_interval)
|
||||
elapsed += poll_interval
|
||||
|
||||
cap = max_wait_paused if was_paused else max_wait_running
|
||||
if elapsed >= cap:
|
||||
break
|
||||
|
||||
try:
|
||||
resp = http_requests.get(
|
||||
f"{STREAMING_SERVER_URL}/api/v1/traces/stream/replay/{replay_id}",
|
||||
@@ -1753,7 +1827,26 @@ def _poll_replay_progress(replay_id: str, workflow_name: str, total_actions: int
|
||||
failed = data.get("failed_actions", 0)
|
||||
progress = int(10 + (completed / max(total_actions, 1)) * 80)
|
||||
|
||||
socketio.emit('execution_progress', {
|
||||
if status == "paused_need_help" and not was_paused:
|
||||
_emit_lea("paused", {
|
||||
"workflow": workflow_name,
|
||||
"replay_id": replay_id,
|
||||
"completed": completed,
|
||||
"total": total_actions,
|
||||
"failed_action": data.get("failed_action"),
|
||||
"reason": data.get("error") or "Action incertaine",
|
||||
})
|
||||
was_paused = True
|
||||
elapsed = 0
|
||||
elif was_paused and status != "paused_need_help":
|
||||
_emit_lea("resumed", {
|
||||
"workflow": workflow_name,
|
||||
"replay_id": replay_id,
|
||||
"status_after": status,
|
||||
})
|
||||
was_paused = False
|
||||
|
||||
_emit_dual('execution_progress', 'action_progress', {
|
||||
"progress": progress,
|
||||
"step": f"Action {completed}/{total_actions} exécutée",
|
||||
"current": completed,
|
||||
@@ -1922,7 +2015,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
|
||||
|
||||
actions = _build_actions_from_workflow(match, params)
|
||||
if not actions:
|
||||
socketio.emit('copilot_complete', {
|
||||
_emit_dual('copilot_complete', 'done', {
|
||||
"workflow": workflow_name,
|
||||
"status": "error",
|
||||
"message": "Aucune action exécutable dans ce workflow.",
|
||||
@@ -1959,7 +2052,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
|
||||
break
|
||||
|
||||
copilot_state["status"] = "waiting_approval"
|
||||
socketio.emit('copilot_step', {
|
||||
_emit_dual('copilot_step', 'need_confirm', {
|
||||
"workflow": workflow_name,
|
||||
"step_index": idx,
|
||||
"total": total,
|
||||
@@ -1982,7 +2075,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
|
||||
|
||||
if waited >= max_wait:
|
||||
copilot_state["status"] = "aborted"
|
||||
socketio.emit('copilot_complete', {
|
||||
_emit_dual('copilot_complete', 'done', {
|
||||
"workflow": workflow_name,
|
||||
"status": "timeout",
|
||||
"message": f"Timeout : pas de réponse après {max_wait}s.",
|
||||
@@ -1999,7 +2092,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
|
||||
elif decision == "skipped":
|
||||
copilot_state["skipped"] += 1
|
||||
logger.info(f"Copilot skip étape {idx + 1}/{total}")
|
||||
socketio.emit('copilot_step_result', {
|
||||
_emit_dual('copilot_step_result', 'step_result', {
|
||||
"step_index": idx,
|
||||
"total": total,
|
||||
"status": "skipped",
|
||||
@@ -2034,7 +2127,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
|
||||
|
||||
if action_success:
|
||||
copilot_state["completed"] += 1
|
||||
socketio.emit('copilot_step_result', {
|
||||
_emit_dual('copilot_step_result', 'step_result', {
|
||||
"step_index": idx,
|
||||
"total": total,
|
||||
"status": "completed",
|
||||
@@ -2042,7 +2135,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
|
||||
})
|
||||
else:
|
||||
copilot_state["failed"] += 1
|
||||
socketio.emit('copilot_step_result', {
|
||||
_emit_dual('copilot_step_result', 'step_result', {
|
||||
"step_index": idx,
|
||||
"total": total,
|
||||
"status": "failed",
|
||||
@@ -2051,7 +2144,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
|
||||
else:
|
||||
error = resp.text[:200]
|
||||
copilot_state["failed"] += 1
|
||||
socketio.emit('copilot_step_result', {
|
||||
_emit_dual('copilot_step_result', 'step_result', {
|
||||
"step_index": idx,
|
||||
"total": total,
|
||||
"status": "failed",
|
||||
@@ -2060,7 +2153,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
|
||||
|
||||
except http_requests.ConnectionError:
|
||||
copilot_state["failed"] += 1
|
||||
socketio.emit('copilot_step_result', {
|
||||
_emit_dual('copilot_step_result', 'step_result', {
|
||||
"step_index": idx,
|
||||
"total": total,
|
||||
"status": "failed",
|
||||
@@ -2070,7 +2163,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
|
||||
except Exception as e:
|
||||
copilot_state["failed"] += 1
|
||||
logger.error(f"Copilot action error: {e}")
|
||||
socketio.emit('copilot_step_result', {
|
||||
_emit_dual('copilot_step_result', 'step_result', {
|
||||
"step_index": idx,
|
||||
"total": total,
|
||||
"status": "failed",
|
||||
@@ -2098,7 +2191,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
|
||||
f"Copilot terminé : {completed} réussies, "
|
||||
f"{skipped} passées, {failed} échouées sur {total} étapes."
|
||||
)
|
||||
socketio.emit('copilot_complete', {
|
||||
_emit_dual('copilot_complete', 'done', {
|
||||
"workflow": workflow_name,
|
||||
"status": "completed" if success else "partial",
|
||||
"message": message,
|
||||
@@ -2175,7 +2268,7 @@ def execute_workflow(match, params):
|
||||
execution_status["progress"] = 10
|
||||
execution_status["message"] = f"Envoyé à l'Agent V1 ({target_session})"
|
||||
|
||||
socketio.emit('execution_progress', {
|
||||
_emit_dual('execution_progress', 'action_progress', {
|
||||
"progress": 10,
|
||||
"step": f"Replay envoyé à l'Agent V1 — {total_actions} actions en attente",
|
||||
"current": 0,
|
||||
@@ -2523,7 +2616,7 @@ def update_progress(progress: int, message: str, current: int, total: int):
|
||||
execution_status["progress"] = progress
|
||||
execution_status["message"] = message
|
||||
|
||||
socketio.emit('execution_progress', {
|
||||
_emit_dual('execution_progress', 'action_progress', {
|
||||
"progress": progress,
|
||||
"step": message,
|
||||
"current": current,
|
||||
@@ -2543,7 +2636,7 @@ def finish_execution(workflow_name: str, success: bool, message: str):
|
||||
if command_history:
|
||||
command_history[-1]["status"] = "completed" if success else "failed"
|
||||
|
||||
socketio.emit('execution_completed', {
|
||||
_emit_dual('execution_completed', 'done', {
|
||||
"workflow": workflow_name,
|
||||
"success": success,
|
||||
"message": message
|
||||
|
||||
@@ -147,8 +147,10 @@ class AutonomousPlanner:
|
||||
"""Initialise le client VLM pour analyse intelligente."""
|
||||
if VLM_AVAILABLE and OllamaClient:
|
||||
try:
|
||||
self._vlm_client = OllamaClient(model="qwen2.5vl:7b")
|
||||
logger.info("VLM client initialized (qwen2.5vl:7b)")
|
||||
from core.detection.vlm_config import get_vlm_model
|
||||
_planner_vlm = get_vlm_model()
|
||||
self._vlm_client = OllamaClient(model=_planner_vlm)
|
||||
logger.info("VLM client initialized (%s)", _planner_vlm)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not initialize VLM client: {e}")
|
||||
self._vlm_client = None
|
||||
|
||||
@@ -40,10 +40,18 @@ MACHINE_ID = os.environ.get(
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
|
||||
# Endpoint du serveur Streaming (port 5005)
|
||||
# SERVER_URL contient TOUJOURS /api/v1 à la fin (convention unifiée).
|
||||
SERVER_URL = os.getenv("RPA_SERVER_URL", "http://localhost:5005/api/v1")
|
||||
# Base sans /api/v1 — pour les routes à la racine (/health)
|
||||
SERVER_BASE = SERVER_URL.rsplit("/api/v1", 1)[0]
|
||||
UPLOAD_ENDPOINT = f"{SERVER_URL}/traces/upload"
|
||||
STREAMING_ENDPOINT = f"{SERVER_URL}/traces/stream"
|
||||
|
||||
# Host Ollama — SÉPARÉ du serveur RPA.
|
||||
# Ollama tourne en local sur la machine serveur, jamais exposé via le reverse proxy.
|
||||
# Défaut : localhost (exécution locale ou accès LAN direct).
|
||||
OLLAMA_HOST = os.getenv("RPA_OLLAMA_HOST", "localhost")
|
||||
|
||||
# Token d'authentification API (doit correspondre au token du serveur)
|
||||
# Configurable via variable d'environnement RPA_API_TOKEN
|
||||
API_TOKEN = os.environ.get("RPA_API_TOKEN", "")
|
||||
|
||||
@@ -20,6 +20,7 @@ import os
|
||||
import threading
|
||||
import time
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
# Forcer l'import de config AVANT pynput/mss pour garantir que le
|
||||
# DPI awareness est configure (SetProcessDpiAwareness(2) sur Windows).
|
||||
@@ -88,6 +89,11 @@ class ActionExecutorV1:
|
||||
self._api_token = os.environ.get("RPA_API_TOKEN", "")
|
||||
# Gestionnaire de notifications toast (pour les messages utilisateur)
|
||||
self._notification_manager = None
|
||||
# Drapeau sécurité : positionné quand on détecte un dialogue système
|
||||
# (UAC, CredUI, SmartScreen…). Lu par le caller pour signaler une
|
||||
# pause supervisée au serveur (`paused_need_help`).
|
||||
# Cf. core/system_dialog_guard.py
|
||||
self._system_dialog_pause: Optional[Dict[str, Any]] = None
|
||||
# Log de la resolution physique pour le diagnostic DPI
|
||||
self._log_screen_info()
|
||||
|
||||
@@ -471,9 +477,15 @@ class ActionExecutorV1:
|
||||
},
|
||||
headers=headers,
|
||||
timeout=10,
|
||||
allow_redirects=False,
|
||||
)
|
||||
|
||||
if resp.ok:
|
||||
if resp.status_code in (301, 302, 307, 308):
|
||||
logger.warning(
|
||||
f"Redirection {resp.status_code} sur POST {url} — "
|
||||
f"verifiez RPA_SERVER_URL (https:// si redirect)"
|
||||
)
|
||||
elif resp.ok:
|
||||
data = resp.json()
|
||||
state = data.get("screen_state", "ok")
|
||||
if state != "ok":
|
||||
@@ -537,6 +549,11 @@ class ActionExecutorV1:
|
||||
"visual_resolved": False,
|
||||
}
|
||||
|
||||
# Réinitialiser le drapeau dialogue système à chaque action
|
||||
# (sinon une détection lors d'une action précédente ferait bail-out
|
||||
# immédiat sur toutes les suivantes).
|
||||
self._system_dialog_pause = None
|
||||
|
||||
# ── Bloc conditionnel : skip si le dialogue n'est pas apparu ──
|
||||
# Les actions marquées conditional_on_window ne s'exécutent que
|
||||
# si la fenêtre attendue est effectivement présente. Sinon → skip.
|
||||
@@ -594,6 +611,23 @@ class ActionExecutorV1:
|
||||
f"{int(action.get('y_pct', 0) * height)})"
|
||||
)
|
||||
|
||||
# ── SÉCURITÉ : check proactif AVANT toute action ──
|
||||
# Si un UAC / CredUI / SmartScreen est déjà à l'écran (apparu
|
||||
# spontanément entre deux actions), on pause IMMÉDIATEMENT
|
||||
# sans rien tenter. Clic / type / key_combo : tous bloqués.
|
||||
# Cf. core/system_dialog_guard.py
|
||||
if action_type in ("click", "type", "key_combo", "double_click", "right_click"):
|
||||
if self._check_and_pause_on_system_dialog(context=f"pre_action_{action_type}"):
|
||||
pause_info = self._system_dialog_pause or {}
|
||||
result["success"] = False
|
||||
result["error"] = (
|
||||
f"system_dialog:{pause_info.get('category', 'unknown')}"
|
||||
)
|
||||
result["system_dialog"] = pause_info
|
||||
result["needs_human"] = True
|
||||
result["screenshot"] = self._capture_screenshot_b64()
|
||||
return result
|
||||
|
||||
# Resolution visuelle des coordonnees si demande
|
||||
x_pct = action.get("x_pct", 0.0)
|
||||
y_pct = action.get("y_pct", 0.0)
|
||||
@@ -675,7 +709,11 @@ class ActionExecutorV1:
|
||||
f"attendu '{expected_title}' → mode apprentissage"
|
||||
)
|
||||
try:
|
||||
self.notifier.replay_wrong_window(current_title, expected_title)
|
||||
self.notifier.replay_learning_mode(
|
||||
raison="wrong_window",
|
||||
target_description=expected_title,
|
||||
window_title=current_title,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -737,6 +775,27 @@ class ActionExecutorV1:
|
||||
popup_coords = observation.get("popup_coords")
|
||||
print(f" [OBSERVER] Popup détectée : '{popup_label}' — fermeture")
|
||||
logger.info(f"Observer : popup '{popup_label}' détectée avant résolution")
|
||||
|
||||
# ── SÉCURITÉ : refuser de cliquer sur un dialogue système ──
|
||||
# Avant de suivre les coordonnées du serveur (VLM-based,
|
||||
# donc faillible) ou de rappeler le VLM local, on
|
||||
# vérifie que la popup n'est PAS un UAC/CredUI/SmartScreen.
|
||||
if self._check_and_pause_on_system_dialog(
|
||||
context="observer_popup"
|
||||
):
|
||||
# Dialogue système → on remonte la pause au replay.
|
||||
# On renvoie le résultat immédiatement pour que le
|
||||
# serveur passe en paused_need_help.
|
||||
pause_info = self._system_dialog_pause or {}
|
||||
result["success"] = False
|
||||
result["error"] = (
|
||||
f"system_dialog:{pause_info.get('category', 'unknown')}"
|
||||
)
|
||||
result["system_dialog"] = pause_info
|
||||
result["needs_human"] = True
|
||||
result["screenshot"] = self._capture_screenshot_b64()
|
||||
return result
|
||||
|
||||
if popup_coords:
|
||||
real_x = int(popup_coords["x_pct"] * width)
|
||||
real_y = int(popup_coords["y_pct"] * height)
|
||||
@@ -745,7 +804,20 @@ class ActionExecutorV1:
|
||||
print(f" [OBSERVER] Popup fermée — reprise du flow normal")
|
||||
else:
|
||||
# Pas de coordonnées → fallback sur handle_popup_vlm classique
|
||||
# (qui re-vérifie aussi system_dialog en interne)
|
||||
self._handle_popup_vlm()
|
||||
# Si _handle_popup_vlm a détecté un dialogue système,
|
||||
# on remonte la pause au replay.
|
||||
if self._system_dialog_pause:
|
||||
pause_info = self._system_dialog_pause
|
||||
result["success"] = False
|
||||
result["error"] = (
|
||||
f"system_dialog:{pause_info.get('category', 'unknown')}"
|
||||
)
|
||||
result["system_dialog"] = pause_info
|
||||
result["needs_human"] = True
|
||||
result["screenshot"] = self._capture_screenshot_b64()
|
||||
return result
|
||||
|
||||
elif obs_state == "unexpected":
|
||||
# État inattendu (pas la bonne page/écran)
|
||||
@@ -840,6 +912,24 @@ class ActionExecutorV1:
|
||||
f"({policy_decision.reason})"
|
||||
)
|
||||
|
||||
# ── SÉCURITÉ : si Policy a détecté un dialogue système
|
||||
# pendant son _try_close_popup, on remonte la pause au
|
||||
# serveur SANS tenter aucune action supplémentaire.
|
||||
if self._system_dialog_pause:
|
||||
pause_info = self._system_dialog_pause
|
||||
logger.critical(
|
||||
f"[POLICY] Dialogue système détecté par popup handler "
|
||||
f"({pause_info.get('category')}) — pause supervisée"
|
||||
)
|
||||
result["success"] = False
|
||||
result["error"] = (
|
||||
f"system_dialog:{pause_info.get('category', 'unknown')}"
|
||||
)
|
||||
result["system_dialog"] = pause_info
|
||||
result["needs_human"] = True
|
||||
result["screenshot"] = self._capture_screenshot_b64()
|
||||
return result
|
||||
|
||||
if policy_decision.decision == Decision.RETRY:
|
||||
resolved2 = self._resolve_target_visual(
|
||||
server_url, target_spec, x_pct, y_pct, width, height
|
||||
@@ -855,9 +945,10 @@ class ActionExecutorV1:
|
||||
# et ne trouve toujours pas. L'humain doit montrer.
|
||||
print(f" [POLICY] Retry échoué → mode apprentissage")
|
||||
try:
|
||||
self.notifier.replay_target_not_found(
|
||||
target_desc,
|
||||
target_spec.get("window_title", ""),
|
||||
self.notifier.replay_learning_mode(
|
||||
raison="retry_failed",
|
||||
target_description=target_desc,
|
||||
window_title=target_spec.get("window_title", ""),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -913,9 +1004,10 @@ class ActionExecutorV1:
|
||||
# passe en mode capture et enregistre ce que
|
||||
# l'humain fait (mini-workflow de correction).
|
||||
try:
|
||||
self.notifier.replay_target_not_found(
|
||||
target_desc,
|
||||
target_spec.get("window_title", ""),
|
||||
self.notifier.replay_learning_mode(
|
||||
raison="supervise",
|
||||
target_description=target_desc,
|
||||
window_title=target_spec.get("window_title", ""),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -1141,7 +1233,9 @@ class ActionExecutorV1:
|
||||
f"je demande de l'aide"
|
||||
)
|
||||
try:
|
||||
self.notifier.replay_no_screen_change(action_type)
|
||||
self.notifier.replay_learning_mode(
|
||||
raison="no_screen_change",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -1297,7 +1391,13 @@ class ActionExecutorV1:
|
||||
|
||||
try:
|
||||
print(f" [SERVER-RESOLVE] Appel serveur {server_url}...")
|
||||
resp = _requests.post(url, json=payload, headers=headers, timeout=30)
|
||||
resp = _requests.post(url, json=payload, headers=headers, timeout=30, allow_redirects=False)
|
||||
if resp.status_code in (301, 302, 307, 308):
|
||||
logger.warning(
|
||||
f"Redirection {resp.status_code} sur POST {url} — "
|
||||
f"verifiez RPA_SERVER_URL (https:// si redirect)"
|
||||
)
|
||||
return None
|
||||
if not resp.ok:
|
||||
logger.warning(f"Server resolve HTTP {resp.status_code}")
|
||||
return None
|
||||
@@ -1441,7 +1541,7 @@ class ActionExecutorV1:
|
||||
if not vlm_description:
|
||||
return None
|
||||
|
||||
ollama_host = os.environ.get("RPA_SERVER_HOST", "localhost")
|
||||
ollama_host = os.environ.get("RPA_OLLAMA_HOST", "localhost")
|
||||
ollama_url = f"http://{ollama_host}:11434/api/chat"
|
||||
|
||||
prompt = (
|
||||
@@ -1577,7 +1677,7 @@ Example: x_pct=0.50, y_pct=0.30"""
|
||||
if anchor_b64:
|
||||
images.append(anchor_b64)
|
||||
|
||||
ollama_host = os.environ.get("RPA_SERVER_HOST", "localhost")
|
||||
ollama_host = os.environ.get("RPA_OLLAMA_HOST", "localhost")
|
||||
ollama_url = f"http://{ollama_host}:11434/api/chat"
|
||||
|
||||
# Prefill pour les modèles thinking (qwen3) — évite le mode réflexion >180s
|
||||
@@ -1771,6 +1871,9 @@ Example: x_pct=0.50, y_pct=0.30"""
|
||||
"target_spec": result.get("target_spec"),
|
||||
# Correction humaine (mode apprentissage supervisé)
|
||||
"correction": result.get("correction"),
|
||||
# Sécurité : dialogue système critique détecté (UAC, CredUI, SmartScreen)
|
||||
"system_dialog": result.get("system_dialog"),
|
||||
"needs_human": result.get("needs_human"),
|
||||
}
|
||||
try:
|
||||
resp2 = requests.post(
|
||||
@@ -1778,8 +1881,14 @@ Example: x_pct=0.50, y_pct=0.30"""
|
||||
json=report,
|
||||
headers=self._auth_headers(),
|
||||
timeout=10,
|
||||
allow_redirects=False,
|
||||
)
|
||||
if resp2.ok:
|
||||
if resp2.status_code in (301, 302, 307, 308):
|
||||
logger.warning(
|
||||
f"Redirection {resp2.status_code} sur POST {replay_result_url} — "
|
||||
f"verifiez RPA_SERVER_URL (https:// si redirect)"
|
||||
)
|
||||
elif resp2.ok:
|
||||
server_resp = resp2.json()
|
||||
msg = (
|
||||
f"Resultat rapporte : replay_status={server_resp.get('replay_status')}, "
|
||||
@@ -1796,6 +1905,129 @@ Example: x_pct=0.50, y_pct=0.30"""
|
||||
|
||||
return True
|
||||
|
||||
# =========================================================================
|
||||
# Garde-fou sécurité : dialogues système Windows (UAC, CredUI, SmartScreen)
|
||||
# =========================================================================
|
||||
|
||||
def _check_and_pause_on_system_dialog(self, context: str = "") -> bool:
|
||||
"""Détecter un dialogue système critique et positionner la pause.
|
||||
|
||||
Si un dialogue UAC, CredUI, SmartScreen (etc.) est actif, on :
|
||||
- N'appelle JAMAIS le VLM sur l'image (évite de lui faire suggérer "Oui")
|
||||
- Ne clique JAMAIS automatiquement
|
||||
- Positionne `self._system_dialog_pause` pour que le caller signale
|
||||
une pause supervisée au serveur
|
||||
- Notifie l'utilisateur via systray
|
||||
- Log l'événement pour audit
|
||||
|
||||
Args:
|
||||
context: Chaîne d'origine pour les logs (ex: "handle_popup_vlm",
|
||||
"observer_popup_click").
|
||||
|
||||
Returns:
|
||||
True si un dialogue système a été détecté (le caller doit
|
||||
stopper toute action automatique). False sinon.
|
||||
"""
|
||||
try:
|
||||
from .system_dialog_guard import detect_current_system_dialog
|
||||
detection = detect_current_system_dialog()
|
||||
except Exception as e:
|
||||
# Fix P0-D : fail-closed (principe "faux positif tolérable,
|
||||
# faux négatif catastrophique"). Si la détection échoue, on ne
|
||||
# peut PAS affirmer que l'écran est sûr — on pause par précaution
|
||||
# et on demande à l'humain. Un UAC non détecté à cause d'un bug
|
||||
# de détection = vecteur d'attaque ransomware.
|
||||
logger.critical(
|
||||
f"[SYS-DIALOG] Erreur détection dialogue système "
|
||||
f"(context={context}) : {e} — PAUSE SUPERVISÉE par précaution "
|
||||
f"(fail-closed : impossible de garantir l'absence de dialogue "
|
||||
f"système critique)"
|
||||
)
|
||||
print(
|
||||
f" [SÉCURITÉ] Vérification du garde-fou système a échoué "
|
||||
f"— pause supervisée par précaution ({type(e).__name__})"
|
||||
)
|
||||
# Positionner le flag de pause avec une catégorie dédiée pour que
|
||||
# le caller (execute_replay_action) remonte "paused_need_help".
|
||||
self._system_dialog_pause = {
|
||||
"category": "unknown_check_failed",
|
||||
"matched_signal": "exception",
|
||||
"matched_value": type(e).__name__,
|
||||
"reason": f"system_dialog_guard détection exception: {e}",
|
||||
"context": context,
|
||||
}
|
||||
# Notification utilisateur best-effort.
|
||||
try:
|
||||
notifier = self.notifier
|
||||
msg = (
|
||||
"Vérification du garde-fou système a échoué — "
|
||||
"pause supervisée par précaution. Léa ne clique pas."
|
||||
)
|
||||
if hasattr(notifier, "notify"):
|
||||
notifier.notify(
|
||||
title="Léa — sécurité",
|
||||
message=msg,
|
||||
timeout=10,
|
||||
)
|
||||
elif hasattr(notifier, "error"):
|
||||
notifier.error(msg)
|
||||
except Exception as notify_err:
|
||||
logger.debug(f"[SYS-DIALOG] Notification échouée : {notify_err}")
|
||||
return True
|
||||
|
||||
if not detection.is_system_dialog:
|
||||
return False
|
||||
|
||||
# Audit log : TOUJOURS tracer, même si la pause est redondante.
|
||||
logger.critical(
|
||||
f"[SYS-DIALOG] REFUS D'INTERACTION — {detection.category} "
|
||||
f"détecté via {detection.matched_signal}='{detection.matched_value}' "
|
||||
f"(context={context}). Pause supervisée demandée."
|
||||
)
|
||||
print(
|
||||
f" [SÉCURITÉ] Dialogue système détecté : {detection.category} "
|
||||
f"— Léa NE CLIQUE PAS, intervention humaine requise"
|
||||
)
|
||||
|
||||
# Positionner le flag pour le caller (execute_replay_action)
|
||||
self._system_dialog_pause = {
|
||||
"category": detection.category,
|
||||
"matched_signal": detection.matched_signal,
|
||||
"matched_value": detection.matched_value,
|
||||
"reason": detection.reason,
|
||||
"context": context,
|
||||
}
|
||||
|
||||
# Notification systray (best-effort, ne jamais planter dessus)
|
||||
try:
|
||||
cat_fr = {
|
||||
"uac_consent": "élévation de privilèges (UAC)",
|
||||
"windows_credential_prompt": "demande de mot de passe Windows",
|
||||
"smartscreen": "alerte SmartScreen",
|
||||
"windows_defender": "alerte Windows Defender",
|
||||
"driver_install": "installation de pilote",
|
||||
"security_toast": "notification de sécurité",
|
||||
"unknown_system_dialog": "dialogue système inconnu",
|
||||
}.get(detection.category, detection.category)
|
||||
msg = (
|
||||
f"Dialogue système détecté ({cat_fr}) — "
|
||||
f"intervention humaine requise. Léa ne clique pas."
|
||||
)
|
||||
# On essaie d'abord un formateur explicite ; sinon fallback error
|
||||
notifier = self.notifier
|
||||
if hasattr(notifier, "notify"):
|
||||
notifier.notify(
|
||||
title="Léa — sécurité",
|
||||
message=msg,
|
||||
timeout=10,
|
||||
)
|
||||
elif hasattr(notifier, "error"):
|
||||
notifier.error(msg)
|
||||
except Exception as e:
|
||||
logger.debug(f"[SYS-DIALOG] Notification échouée : {e}")
|
||||
|
||||
return True
|
||||
|
||||
# =========================================================================
|
||||
# Gestion intelligente des popups imprévues (VLM)
|
||||
# =========================================================================
|
||||
@@ -1817,9 +2049,22 @@ Example: x_pct=0.50, y_pct=0.30"""
|
||||
|
||||
Une seule tentative par action (pas de boucle infinie).
|
||||
|
||||
**SÉCURITÉ** : avant toute interaction, on détecte les dialogues
|
||||
système Windows critiques (UAC, CredUI, SmartScreen). Si un tel
|
||||
dialogue est actif → pause supervisée immédiate, pas de VLM, pas
|
||||
de clic automatique. Cf. system_dialog_guard.py.
|
||||
|
||||
Returns:
|
||||
True si une popup a été gérée (fermée), False sinon.
|
||||
False aussi en cas de dialogue système → le caller doit traiter
|
||||
`self._system_dialog_pause` pour signaler la pause au serveur.
|
||||
"""
|
||||
# ── SÉCURITÉ : refus absolu de cliquer sur un dialogue système ──
|
||||
# Un UAC / CredUI / SmartScreen ne doit JAMAIS recevoir de clic
|
||||
# automatique. On détecte AVANT le VLM (coût minimal ~20ms UIA).
|
||||
if self._check_and_pause_on_system_dialog(context="handle_popup_vlm"):
|
||||
return False
|
||||
|
||||
# Capturer le screenshot actuel (résolution native pour template matching)
|
||||
screenshot_b64 = self._capture_screenshot_b64(max_width=0, quality=75)
|
||||
if not screenshot_b64:
|
||||
@@ -1909,7 +2154,7 @@ Example: x_pct=0.50, y_pct=0.30"""
|
||||
"""
|
||||
import requests as _requests
|
||||
|
||||
ollama_host = os.environ.get("RPA_SERVER_HOST", "localhost")
|
||||
ollama_host = os.environ.get("RPA_OLLAMA_HOST", "localhost")
|
||||
ollama_url = f"http://{ollama_host}:11434/api/chat"
|
||||
|
||||
prompt = (
|
||||
@@ -1935,8 +2180,11 @@ Example: x_pct=0.50, y_pct=0.30"""
|
||||
},
|
||||
{"role": "user", "content": prompt, "images": [screenshot_b64]},
|
||||
]
|
||||
# Prefill pour les modèles "thinking" (qwen3-vl) : force la sortie à commencer
|
||||
# par cette chaîne, évite les longs blocs de raisonnement interne.
|
||||
prefill = "The button to click is: " if _is_thinking_popup else ""
|
||||
if _is_thinking_popup:
|
||||
messages_popup.append({"role": "assistant", "content": "The button to click is: "})
|
||||
messages_popup.append({"role": "assistant", "content": prefill})
|
||||
|
||||
payload = {
|
||||
"model": _vlm_model_popup,
|
||||
@@ -2353,8 +2601,8 @@ Example: x_pct=0.50, y_pct=0.30"""
|
||||
f"inactivité={INACTIVITY_TIMEOUT}s, hotkey=Ctrl+Shift+L)"
|
||||
)
|
||||
print(
|
||||
f" [APPRENTISSAGE] Montre-moi comment faire.\n"
|
||||
f" Quand tu as fini → Ctrl+Shift+L\n"
|
||||
f" [APPRENTISSAGE] Je n'y arrive pas, montrez-moi comment faire.\n"
|
||||
f" Quand vous avez fini → Ctrl+Shift+L\n"
|
||||
f" (ou j'attends {INACTIVITY_TIMEOUT}s sans action)"
|
||||
)
|
||||
|
||||
|
||||
@@ -85,6 +85,10 @@ class PolicyEngine:
|
||||
2. Si retry déjà fait → demander à l'acteur gemma4
|
||||
3. Selon gemma4 : SKIP, ABORT, ou SUPERVISE
|
||||
|
||||
**SÉCURITÉ** : si, pendant l'étape 1, le handler popup détecte un
|
||||
dialogue système Windows (UAC, CredUI, SmartScreen…), on bascule
|
||||
immédiatement en SUPERVISE. Cf. system_dialog_guard.py.
|
||||
|
||||
Args:
|
||||
action: L'action qui a échoué
|
||||
target_spec: La cible non trouvée
|
||||
@@ -96,6 +100,22 @@ class PolicyEngine:
|
||||
# ── Étape 1 : Tentative de fermeture popup (premier essai) ──
|
||||
if retry_count == 0:
|
||||
popup_handled = self._try_close_popup()
|
||||
|
||||
# Si le popup handler a détecté un dialogue système, on
|
||||
# bascule immédiatement en SUPERVISE — pas de retry, pas de
|
||||
# gemma4 : on rend la main à l'humain.
|
||||
if getattr(self._executor, "_system_dialog_pause", None):
|
||||
sd = self._executor._system_dialog_pause
|
||||
return PolicyDecision(
|
||||
decision=Decision.SUPERVISE,
|
||||
reason=(
|
||||
f"Dialogue système détecté ({sd.get('category', '?')}) — "
|
||||
f"refus d'interaction automatique"
|
||||
),
|
||||
action_taken="system_dialog_blocked",
|
||||
elapsed_ms=(time.time() - t_start) * 1000,
|
||||
)
|
||||
|
||||
if popup_handled:
|
||||
return PolicyDecision(
|
||||
decision=Decision.RETRY,
|
||||
|
||||
448
agent_v0/agent_v1/core/system_dialog_guard.py
Normal file
448
agent_v0/agent_v1/core/system_dialog_guard.py
Normal file
@@ -0,0 +1,448 @@
|
||||
# agent_v1/core/system_dialog_guard.py
|
||||
"""
|
||||
Garde-fou sécurité : détection des dialogues système Windows critiques.
|
||||
|
||||
==============================================================================
|
||||
POURQUOI ?
|
||||
==============================================================================
|
||||
|
||||
Pendant un replay, si un dialogue UAC, CredUI (mot de passe Windows),
|
||||
SmartScreen ou une notification de sécurité Windows apparaît, Léa pourrait
|
||||
demander au VLM "quel bouton cliquer" et recevoir "Oui" en réponse.
|
||||
|
||||
→ **Léa cliquerait OUI sur une élévation UAC** → vecteur d'attaque ransomware.
|
||||
|
||||
Ce module fournit la détection de ces dialogues pour que l'exécuteur
|
||||
**ne clique JAMAIS dessus automatiquement**. La décision est renvoyée à
|
||||
l'humain (pause supervisée).
|
||||
|
||||
==============================================================================
|
||||
PRINCIPE
|
||||
==============================================================================
|
||||
|
||||
- **Faux positif tolérable** : on préfère pauser pour rien plutôt que cliquer
|
||||
sur un UAC.
|
||||
- **Faux négatif catastrophique** : mieux vaut être trop prudent.
|
||||
- **Multi-signal** : titre, ClassName UIA, nom de processus, parent_path.
|
||||
Un seul signal suffit à bloquer.
|
||||
- **Compatible Citrix** : les dialogues UAC d'un client Citrix apparaissent
|
||||
aussi dans la VM distante — la détection par classe UIA fonctionne.
|
||||
|
||||
==============================================================================
|
||||
PATTERNS DE DÉTECTION (ordre de criticité décroissant)
|
||||
==============================================================================
|
||||
|
||||
1. UAC Consent (élévation de privilèges)
|
||||
- ClassName : `$$$Secure UAP Dummy Window Class$$$`
|
||||
- Process : `consent.exe`
|
||||
- Titre : "Contrôle de compte d'utilisateur", "User Account Control"
|
||||
|
||||
2. CredUI (prompt mot de passe Windows)
|
||||
- ClassName : `Credential Dialog Xaml Host`
|
||||
- Process : `credentialuibroker.exe`, `credui.exe`
|
||||
- Titre : "Sécurité Windows", "Windows Security"
|
||||
|
||||
3. SmartScreen (protection contre applications inconnues)
|
||||
- Process : `smartscreen.exe`
|
||||
- Titre : "Windows a protégé votre ordinateur", "Windows protected your PC"
|
||||
|
||||
4. Windows Defender / Security Center
|
||||
- Process : `securityhealthhost.exe`, `msmpeng.exe`
|
||||
- Titre : "Sécurité Windows", "Windows Defender"
|
||||
|
||||
5. Signatures pilotes / driver install
|
||||
- Titre : "Installer ce pilote", "Driver signature"
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Catégories de dialogues système (pour logging + messages)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SystemDialogCategory:
    """String identifiers of the system-dialog categories that must always be blocked.

    The values are plain strings so they can be used directly in log lines
    and user-facing messages.
    """

    # Privilege elevation prompt (UAC consent).
    UAC = "uac_consent"
    # Windows password / credential prompt.
    CREDUI = "windows_credential_prompt"
    # SmartScreen protection screen.
    SMARTSCREEN = "smartscreen"
    # Windows Defender alert.
    DEFENDER = "windows_defender"
    # Signed driver installation prompt.
    DRIVER = "driver_install"
    # Windows security toast notification.
    SECURITY_TOAST = "security_toast"
    # "#32770" dialog that does not belong to a known application.
    UNKNOWN_DIALOG = "unknown_system_dialog"
|
||||
|
||||
|
||||
@dataclass
class SystemDialogDetection:
    """Outcome of a system-dialog analysis."""

    # True when one of the detection signals matched.
    is_system_dialog: bool
    # One of the SystemDialogCategory values.
    category: str = ""
    # Name of the signal that matched, e.g. "class_name" or "window_title".
    matched_signal: str = ""
    # The value that triggered the match.
    matched_value: str = ""
    # Human-readable explanation of the match.
    reason: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return the detection as a plain dict, one key per field."""
        exported = (
            "is_system_dialog",
            "category",
            "matched_signal",
            "matched_value",
            "reason",
        )
        return {field_name: getattr(self, field_name) for field_name in exported}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Signatures de détection
|
||||
# =============================================================================
|
||||
|
||||
|
||||
# UIA ClassName values of critical system dialogs, written with the casing
# Windows exposes through UIA. _check_class_name() tries these verbatim first
# and then compares them case-insensitively.
_CLASS_NAMES_SYSTEM = {
    # UAC consent window
    "$$$Secure UAP Dummy Window Class$$$": SystemDialogCategory.UAC,
    # Windows credential prompt
    "Credential Dialog Xaml Host": SystemDialogCategory.CREDUI,
    # Windows credential prompt, older class name
    "CredentialDialogXamlHost": SystemDialogCategory.CREDUI,
}
|
||||
|
||||
# Executable names of the processes that own critical system dialogs.
# Keys are lowercase and carry the ".exe" suffix; lookups go through
# _normalize_process(), which lowercases the candidate name first.
_PROCESS_NAMES_SYSTEM = {
    exe_name: category
    for category, exe_names in (
        (SystemDialogCategory.UAC, ("consent.exe",)),
        (
            SystemDialogCategory.CREDUI,
            ("credentialuibroker.exe", "credui.exe", "credwiz.exe"),
        ),
        (SystemDialogCategory.SMARTSCREEN, ("smartscreen.exe",)),
        (
            SystemDialogCategory.DEFENDER,
            (
                "securityhealthhost.exe",
                "securityhealthui.exe",
                "securityhealthsystray.exe",
                "msmpeng.exe",
                "windowsdefender.exe",
            ),
        ),
        # msiexec.exe is listed for signed-driver prompts.
        (SystemDialogCategory.DRIVER, ("msiexec.exe", "drvinst.exe")),
    )
    for exe_name in exe_names
}
|
||||
|
||||
# Window-title patterns (case-insensitive regexes, some with word boundaries).
# Overly generic titles are deliberately left out to limit false positives on
# OSIRIS/OBSIUS/MEDSPHERE. Order matters: _check_title() returns the first
# pattern that matches, so "Sécurité Windows" / "Windows Security" titles are
# reported as CREDUI.
_TITLE_PATTERNS_SYSTEM: Tuple[Tuple[re.Pattern, str], ...] = tuple(
    (re.compile(title_regex, re.IGNORECASE), category)
    for title_regex, category in (
        # UAC
        (r"contr[oô]le\s+de\s+compte\s+d'?utilisateur", SystemDialogCategory.UAC),
        (r"\buser\s+account\s+control\b", SystemDialogCategory.UAC),
        (r"voulez-vous\s+autoriser\s+cette\s+application", SystemDialogCategory.UAC),
        (r"do\s+you\s+want\s+to\s+allow\s+this\s+app", SystemDialogCategory.UAC),

        # CredUI / Windows Security
        (r"\bs[eé]curit[eé]\s+windows\b", SystemDialogCategory.CREDUI),
        (r"\bwindows\s+security\b", SystemDialogCategory.CREDUI),
        (r"entrer\s+les\s+informations\s+d'?identification", SystemDialogCategory.CREDUI),
        (r"enter\s+(?:your\s+)?credentials?", SystemDialogCategory.CREDUI),
        (r"connectez-vous\s+[aà]\s+votre\s+compte", SystemDialogCategory.CREDUI),
        (r"\bsign\s+in\s+to\s+your\s+account\b", SystemDialogCategory.CREDUI),

        # SmartScreen
        (r"windows\s+a\s+prot[eé]g[eé]", SystemDialogCategory.SMARTSCREEN),
        (r"windows\s+protected\s+your\s+pc", SystemDialogCategory.SMARTSCREEN),
        (r"\bsmartscreen\b", SystemDialogCategory.SMARTSCREEN),
        (r"\b[eé]diteur\s+inconnu\b", SystemDialogCategory.SMARTSCREEN),
        (r"\bunknown\s+publisher\b", SystemDialogCategory.SMARTSCREEN),

        # Windows Defender
        (r"windows\s+defender", SystemDialogCategory.DEFENDER),
        (r"menace\s+d[eé]tect[eé]e", SystemDialogCategory.DEFENDER),
        (r"threat\s+detected", SystemDialogCategory.DEFENDER),

        # Driver installation
        (r"installer\s+ce\s+pilote", SystemDialogCategory.DRIVER),
        (r"install\s+this\s+driver", SystemDialogCategory.DRIVER),
        (r"signature\s+num[eé]rique\s+du\s+pilote", SystemDialogCategory.DRIVER),
    )
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fonctions de détection
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _normalize_process(name: str) -> str:
    """Normalize a process name so it can be looked up in _PROCESS_NAMES_SYSTEM.

    The name is stripped, lowercased and reduced to its last path component
    (both "\\" and "/" are treated as separators). A name without the ".exe"
    suffix gets the suffix only when the suffixed form is a known system
    process; otherwise it is returned as is.

    Returns:
        The normalized name, or "" for an empty input.
    """
    if not name:
        return ""

    # Keep only the file name, whichever separator style the path uses.
    base = name.strip().lower().replace("\\", "/").rsplit("/", 1)[-1]
    if not base or base.endswith(".exe"):
        return base

    # Process names may arrive without ".exe"; add it only when that yields
    # a key of the signature table.
    candidate = base + ".exe"
    return candidate if candidate in _PROCESS_NAMES_SYSTEM else base
|
||||
|
||||
|
||||
def _check_class_name(class_name: str) -> Optional[Tuple[str, str, str]]:
    """Check whether a UIA ClassName belongs to a system dialog.

    Four tests run in order: exact table match, case-insensitive match on the
    stripped name, loose UAC substring match, loose Credential+Xaml match.

    Returns:
        (category, matched_class, reason) on a match, None otherwise.
    """
    if not class_name:
        return None

    # 1. Exact match against the signature table.
    cat = _CLASS_NAMES_SYSTEM.get(class_name)
    if cat is not None:
        return (cat, class_name, f"ClassName UIA '{class_name}' = dialogue système {cat}")

    lowered = class_name.lower()
    trimmed = class_name.strip().lower()

    # 2. Same table, ignoring case and surrounding whitespace.
    for known, cat in _CLASS_NAMES_SYSTEM.items():
        if known.lower() == trimmed:
            return (cat, class_name, f"ClassName UIA ~= '{known}' ({cat})")

    # 3. Loose UAC detection: the secure window class has a few variants.
    if "secure uap" in lowered or "uap dummy" in lowered:
        return (
            SystemDialogCategory.UAC,
            class_name,
            f"ClassName '{class_name}' contient 'Secure UAP' → UAC",
        )

    # 4. Credential XAML host variants.
    if "credential" in lowered and "xaml" in lowered:
        return (
            SystemDialogCategory.CREDUI,
            class_name,
            f"ClassName '{class_name}' contient Credential+Xaml → CredUI",
        )

    return None
|
||||
|
||||
|
||||
def _check_process_name(process_name: str) -> Optional[Tuple[str, str, str]]:
    """Check whether a process name belongs to a system dialog.

    The name is normalized with _normalize_process() before the lookup; the
    returned tuple carries the name exactly as it was passed in.

    Returns:
        (category, matched_process, reason) on a match, None otherwise.
    """
    if not process_name:
        return None

    norm = _normalize_process(process_name)
    cat = _PROCESS_NAMES_SYSTEM.get(norm)
    if cat is None:
        return None
    return (cat, process_name, f"Processus '{norm}' = {cat}")
|
||||
|
||||
|
||||
# Typographic apostrophe variants, mapped to the ASCII apostrophe the title
# patterns are written with (e.g. r"d'?utilisateur").
_TITLE_APOSTROPHE_TABLE = str.maketrans({"\u2018": "'", "\u2019": "'", "\u02bc": "'"})


def _check_title(title: str) -> Optional[Tuple[str, str, str]]:
    """Check whether a window title matches a system dialog.

    Typographic apostrophes are folded to "'" before matching, so that a
    title such as "Contrôle de compte d’utilisateur" is not missed: a missed
    title is a false negative, which this guard must avoid.

    Args:
        title: Window title (or UIA element name) to test.

    Returns:
        (category, matched_text, reason) for the first pattern of
        _TITLE_PATTERNS_SYSTEM that matches, None otherwise. matched_text is
        taken from the apostrophe-normalized title; the reason quotes the
        first 60 characters of the title as received.
    """
    if not title:
        return None

    # The mapping is one character to one character, so offsets are unchanged.
    normalized = title.translate(_TITLE_APOSTROPHE_TABLE)

    for pattern, cat in _TITLE_PATTERNS_SYSTEM:
        m = pattern.search(normalized)
        if m:
            return (cat, m.group(0),
                    f"Titre '{title[:60]}' matche '{pattern.pattern}' → {cat}")
    return None
|
||||
|
||||
|
||||
def is_system_dialog(
    uia_snapshot: Optional[Dict[str, Any]] = None,
    window_info: Optional[Dict[str, Any]] = None,
) -> SystemDialogDetection:
    """Decide whether the active window is a critical system dialog.

    Several signals are probed in a fixed order and **a single match is
    enough to block**: UIA class name, class names of the UIA parents, UIA
    process name, window app name, window title, UIA element name. A false
    positive (useless pause) is preferred to a false negative (click on UAC).

    Malformed input never aborts the analysis: `parent_path` entries that are
    not dicts and field values that are not strings are skipped, so the
    remaining signals are still checked.

    Args:
        uia_snapshot: Dict with the fields `class_name`, `process_name`,
            `parent_path`, `name`. May be None when UIA is unavailable.
        window_info: Dict with the fields `title`, `app_name`. May be None.

    Returns:
        A SystemDialogDetection whose is_system_dialog is True when a system
        dialog is detected.

    Examples::

        det = is_system_dialog(window_info={"title": "User Account Control"})
        assert det.is_system_dialog  # UAC detected

        det = is_system_dialog(uia_snapshot={"class_name": "$$$Secure UAP Dummy Window Class$$$"})
        assert det.is_system_dialog  # UAC through the ClassName

        det = is_system_dialog(window_info={"title": "OSIRIS - Patient Dupont"})
        assert not det.is_system_dialog  # business application, fine
    """
    uia = uia_snapshot or {}
    window = window_info or {}

    # The clicked element may be an inner control (ClassName "Button") while
    # the dialog class only shows up on one of its parents.
    parents = uia.get("parent_path")
    if not isinstance(parents, (list, tuple)):
        parents = ()

    # (signal name, checker, raw value, prefix for the reason), in probe order.
    probes = [("class_name", _check_class_name, uia.get("class_name"), "")]
    probes.extend(
        ("parent_class_name", _check_class_name, parent.get("class_name"), "Parent : ")
        for parent in parents
        if isinstance(parent, dict)
    )
    probes.extend((
        ("process_name", _check_process_name, uia.get("process_name"), ""),
        ("app_name", _check_process_name, window.get("app_name"), ""),
        ("window_title", _check_title, window.get("title"), ""),
        # Some system dialogs expose their title through the UIA name.
        ("uia_name", _check_title, uia.get("name"), ""),
    ))

    for signal, check, raw_value, reason_prefix in probes:
        # Anything that is not a string is treated as "no value".
        hit = check(raw_value if isinstance(raw_value, str) else "")
        if hit:
            cat, matched, reason = hit
            return SystemDialogDetection(
                is_system_dialog=True,
                category=cat,
                matched_signal=signal,
                matched_value=matched,
                reason=f"{reason_prefix}{reason}",
            )

    return SystemDialogDetection(is_system_dialog=False)
|
||||
|
||||
|
||||
def detect_current_system_dialog() -> SystemDialogDetection:
    """Inspect the current screen and detect a system dialog.

    Standalone helper that gathers both `get_active_window_info()` and, when
    the UIA helper reports itself available, a snapshot of the focused UIA
    element, then feeds them to is_system_dialog(). Each source is
    best-effort: a failure is logged at debug level and that source is
    simply left out.

    Returns:
        SystemDialogDetection. is_system_dialog is True as soon as one signal
        matches. When no source is available (Linux, UIA missing) it is
        False, and the caller may still fall back on a title-based analysis.
    """
    active_window: Optional[Dict[str, Any]] = None
    focused_snapshot: Optional[Dict[str, Any]] = None

    # Active window (cross-platform).
    try:
        from ..window_info_crossplatform import get_active_window_info
        active_window = get_active_window_info()
    except Exception as e:  # pragma: no cover — best-effort
        logger.debug(f"[SYS-DIALOG] window_info indisponible : {e}")

    # Local UIA (Windows only, through lea_uia.exe).
    try:
        from .uia_helper import get_shared_helper
        uia = get_shared_helper()
        # Capture the focused element (root = active window).
        focused = uia.capture_focused(max_depth=2) if uia.available else None
        if focused is not None:
            focused_snapshot = focused.to_dict()
    except Exception as e:  # pragma: no cover
        logger.debug(f"[SYS-DIALOG] UIA indisponible : {e}")

    detection = is_system_dialog(
        uia_snapshot=focused_snapshot, window_info=active_window,
    )
    if not detection.is_system_dialog:
        return detection

    logger.warning(
        f"[SYS-DIALOG] BLOCAGE — dialogue système détecté "
        f"[{detection.category}] via {detection.matched_signal}='{detection.matched_value}' "
        f"— {detection.reason}"
    )
    return detection
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SystemDialogCategory",
|
||||
"SystemDialogDetection",
|
||||
"is_system_dialog",
|
||||
"detect_current_system_dialog",
|
||||
]
|
||||
@@ -17,6 +17,7 @@ import threading
|
||||
from .config import (
|
||||
SESSIONS_ROOT, AGENT_VERSION, SERVER_URL, MACHINE_ID, LOG_RETENTION_DAYS,
|
||||
SCREEN_RESOLUTION, DPI_SCALE, OS_THEME, API_TOKEN, MAX_SESSION_DURATION_S,
|
||||
STREAMING_ENDPOINT,
|
||||
)
|
||||
from .core.captor import EventCaptorV1
|
||||
from .core.executor import ActionExecutorV1
|
||||
@@ -86,22 +87,23 @@ class AgentV1:
|
||||
self._state.set_on_stop(self.stop_session)
|
||||
|
||||
# Client serveur pour le chat et les workflows
|
||||
# Plus de RPA_SERVER_HOST : le LeaServerClient derive tout de SERVER_URL
|
||||
self._server_client = None
|
||||
if LeaServerClient is not None:
|
||||
# Forcer le token API pour éviter les 401
|
||||
# (le token est set par start.bat dans l'environnement)
|
||||
from .config import API_TOKEN as _token
|
||||
server_host = os.getenv("RPA_SERVER_HOST", "localhost")
|
||||
self._server_client = LeaServerClient(server_host=server_host)
|
||||
self._server_client = LeaServerClient()
|
||||
if _token and not self._server_client._api_token:
|
||||
self._server_client._api_token = _token
|
||||
logger.info("Token API forcé dans LeaServerClient")
|
||||
|
||||
# Fenetre de chat Lea (tkinter natif)
|
||||
# Le host est derive de SERVER_URL (plus de RPA_SERVER_HOST)
|
||||
server_host = (
|
||||
self._server_client.server_host
|
||||
if self._server_client is not None
|
||||
else os.getenv("RPA_SERVER_HOST", "localhost")
|
||||
else "localhost"
|
||||
)
|
||||
self._chat_window = ChatWindow(
|
||||
server_client=self._server_client,
|
||||
@@ -363,11 +365,11 @@ class AgentV1:
|
||||
continue
|
||||
self._last_bg_hash = img_hash
|
||||
|
||||
# Envoyer au streaming server (avec token auth)
|
||||
# Envoyer au streaming server (via STREAMING_ENDPOINT unifié)
|
||||
headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}
|
||||
with open(full_path, 'rb') as f:
|
||||
req.post(
|
||||
f"{SERVER_URL}/traces/stream/image",
|
||||
f"{STREAMING_ENDPOINT}/image",
|
||||
params={
|
||||
"session_id": bg_session,
|
||||
"shot_id": f"heartbeat_{int(time.time())}",
|
||||
@@ -376,6 +378,7 @@ class AgentV1:
|
||||
headers=headers,
|
||||
files={"file": ("screenshot.png", f, "image/png")},
|
||||
timeout=10,
|
||||
allow_redirects=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"[HEARTBEAT] Erreur: {e}")
|
||||
|
||||
149
agent_v0/agent_v1/network/feedback_bus.py
Normal file
149
agent_v0/agent_v1/network/feedback_bus.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# agent_v1/network/feedback_bus.py
|
||||
"""Client SocketIO pour le bus feedback Léa.
|
||||
|
||||
Consomme les events 'lea:*' émis par agent_chat (port 5004) et les dispatche
|
||||
vers ChatWindow pour affichage en bulles temps réel.
|
||||
|
||||
Events écoutés :
|
||||
lea:action_started — début d'un workflow ou d'une action
|
||||
lea:action_progress — progression dans le workflow
|
||||
lea:done — fin d'un workflow ou d'un copilot
|
||||
lea:need_confirm — étape copilot en attente de validation
|
||||
lea:step_result — résultat d'une étape copilot
|
||||
lea:paused — basculement en paused_need_help (asset démo)
|
||||
lea:resumed — sortie de pause supervisée
|
||||
|
||||
Fail-safe : toute erreur de connexion ou de dispatch est silencieusement
|
||||
loggée. Le ChatWindow continue de fonctionner même si le bus est mort
|
||||
(comportement strictement identique au pré-J3).
|
||||
|
||||
Usage :
|
||||
bus = FeedbackBusClient(
|
||||
server_url="http://localhost:5004",
|
||||
token=os.environ.get("RPA_API_TOKEN", ""),
|
||||
on_event=lambda event, payload: print(event, payload),
|
||||
)
|
||||
bus.start() # connexion en arrière-plan, non-bloquant
|
||||
# ... ChatWindow tourne ...
|
||||
bus.stop()
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Callable, Optional
|
||||
|
||||
import socketio
|
||||
|
||||
logger = logging.getLogger(__name__)

# Events of the 'lea:*' bus that are forwarded to the on_event callback.
LEA_EVENTS = (
    'lea:action_started',
    'lea:action_progress',
    'lea:done',
    'lea:need_confirm',
    'lea:step_result',
    'lea:paused',
    'lea:resumed',
)

# Signature of the consumer callback: (event name, payload dict).
EventCallback = Callable[[str, dict], None]


class FeedbackBusClient:
    """Non-blocking SocketIO client for the 'lea:*' bus.

    The connection lives in a daemon thread. Every connection, dispatch or
    emit error is logged and swallowed so that the consumer keeps working
    when the bus is down.
    """

    # Back-off (seconds) between attempts while the *initial* connection keeps
    # failing; same bounds as the reconnection delays given to socketio.Client.
    _RETRY_DELAY_S = 2
    _RETRY_DELAY_MAX_S = 30

    def __init__(
        self,
        server_url: str,
        token: Optional[str] = None,
        on_event: Optional[EventCallback] = None,
    ):
        """Prepare the client; no network activity happens before start().

        Args:
            server_url: Base URL of the SocketIO server (trailing '/' removed).
            token: Optional bearer token sent in the Authorization header.
            on_event: Callback invoked as on_event(event, payload) for every
                event of LEA_EVENTS; defaults to a no-op.
        """
        self._url = server_url.rstrip('/')
        self._token = token or None
        self._on_event: EventCallback = on_event or (lambda e, p: None)
        self._sio = socketio.Client(
            reconnection=True,
            reconnection_attempts=0,  # 0 = unlimited
            reconnection_delay=2,
            reconnection_delay_max=30,
            logger=False,
            engineio_logger=False,
        )
        self._thread: Optional[threading.Thread] = None
        # Set by stop() so that the retry loop of _run() ends promptly.
        self._stop_requested = threading.Event()
        self._register_handlers()

    def _register_handlers(self) -> None:
        """Register the connect/disconnect loggers and one handler per event."""
        @self._sio.event
        def connect():
            logger.info("FeedbackBus connecté à %s", self._url)

        @self._sio.event
        def disconnect():
            logger.info("FeedbackBus déconnecté")

        for ev in LEA_EVENTS:
            self._sio.on(ev, self._make_handler(ev))

    def _make_handler(self, event: str) -> Callable[..., None]:
        """Build the SocketIO handler bound to one event name.

        The handler accepts any number of positional arguments: an event
        emitted without payload is dispatched with an empty dict, and extra
        arguments beyond the first are ignored instead of clobbering the
        event name.
        """
        def handler(*args):
            self._dispatch(event, args[0] if args else None)

        return handler

    def _dispatch(self, event: str, payload: Optional[dict]) -> None:
        """Forward one event to the callback; callback errors are swallowed."""
        try:
            self._on_event(event, payload or {})
        except Exception:
            logger.debug("FeedbackBus dispatch silenced", exc_info=True)

    def start(self) -> None:
        """Start the connection in the background (idempotent, non-blocking)."""
        if self._thread is not None and self._thread.is_alive():
            return
        self._stop_requested.clear()
        self._thread = threading.Thread(
            target=self._run, daemon=True, name="LeaFeedbackBus",
        )
        self._thread.start()

    def _run(self) -> None:
        """Thread body: connect, then block until the connection ends.

        socketio's reconnection logic only covers a connection that was
        established once; a failed initial connect() raises instead. The
        attempt is therefore retried here with an exponential back-off until
        it succeeds or stop() is called.
        """
        headers = {}
        if self._token:
            headers['Authorization'] = f'Bearer {self._token}'

        delay = self._RETRY_DELAY_S
        warned = False
        while not self._stop_requested.is_set():
            try:
                self._sio.connect(self._url, headers=headers, wait=True)
                if self._stop_requested.is_set():
                    # stop() ran while the connection was being established.
                    self._sio.disconnect()
                    return
                self._sio.wait()
                return
            except Exception as e:
                # Warn once, then stay quiet to avoid flooding the log.
                log = logger.debug if warned else logger.warning
                log(
                    "FeedbackBus connect échoué (%s) — ChatWindow continue normalement", e,
                )
                warned = True
            # Event.wait() returns True as soon as stop() is called.
            if self._stop_requested.wait(delay):
                return
            delay = min(delay * 2, self._RETRY_DELAY_MAX_S)

    def stop(self) -> None:
        """Stop the connection cleanly (idempotent, fail-safe)."""
        self._stop_requested.set()
        try:
            if self._sio.connected:
                self._sio.disconnect()
        except Exception:
            logger.debug("FeedbackBus stop silenced", exc_info=True)

    @property
    def connected(self) -> bool:
        """True while the underlying SocketIO client reports a connection."""
        return bool(self._sio.connected)

    # ------------------------------------------------------------------
    # User actions from the paused_need_help bubble (J3.5)
    # ------------------------------------------------------------------

    def resume_replay(self, replay_id: str) -> bool:
        """"Continue" button: emit 'lea:replay_resume'.

        Returns True when the event could be emitted, False otherwise
        (disconnected or error).
        """
        return self._safe_emit("lea:replay_resume", {"replay_id": replay_id})

    def abort_replay(self, replay_id: str) -> bool:
        """"Cancel" button: emit 'lea:replay_abort'; same return contract."""
        return self._safe_emit("lea:replay_abort", {"replay_id": replay_id})

    def _safe_emit(self, event: str, payload: dict) -> bool:
        """Emit an event if connected; never raises."""
        try:
            if not self._sio.connected:
                return False
            self._sio.emit(event, payload)
            return True
        except Exception:
            logger.debug("FeedbackBus _safe_emit silenced", exc_info=True)
            return False
|
||||
380
agent_v0/agent_v1/network/persistent_buffer.py
Normal file
380
agent_v0/agent_v1/network/persistent_buffer.py
Normal file
@@ -0,0 +1,380 @@
|
||||
# agent_v1/network/persistent_buffer.py
|
||||
"""
|
||||
Buffer persistant SQLite pour les événements/images qui n'ont pas pu être envoyés.
|
||||
|
||||
Résout le bloquant AI Act Article 12 : en cas de coupure serveur ou de queue pleine,
|
||||
les événements prioritaires (click, key, action, screenshot) sont persistés sur disque
|
||||
au lieu d'être silencieusement perdus. Ils sont rejoués à la reconnexion.
|
||||
|
||||
Caractéristiques :
|
||||
- SQLite fichier unique (agent_v1/buffer/pending_events.db), thread-safe
|
||||
- Async : les écritures se font depuis un thread daemon, jamais bloquant
|
||||
- Quota : compteur d'attempts par item, abandon après MAX_ATTEMPTS
|
||||
- Robustesse : un fichier corrompu est renommé et recréé vide
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Nombre max de tentatives avant abandon définitif d'un item
|
||||
MAX_ATTEMPTS = 10
|
||||
|
||||
# Taille max du buffer en items pour éviter une explosion disque
|
||||
# (typiquement : 1000 events + 1000 images = quelques Mo de SQLite)
|
||||
MAX_BUFFER_ITEMS = 2000
|
||||
|
||||
|
||||
class PersistentBuffer:
|
||||
"""Buffer SQLite pour événements/images en attente d'envoi.
|
||||
|
||||
Deux tables :
|
||||
- pending_events (id, session_id, payload, attempts, created_at)
|
||||
- pending_images (id, session_id, shot_id, image_path, attempts, created_at)
|
||||
|
||||
Usage :
|
||||
buf = PersistentBuffer(base_dir / "buffer")
|
||||
buf.add_event(session_id, event_dict) # persiste un event
|
||||
buf.add_image(session_id, image_path, shot_id) # persiste une image
|
||||
for row in buf.drain_events(): # itère sur les events
|
||||
if envoyer(row): buf.delete_event(row["id"])
|
||||
else: buf.mark_attempt(row["id"], "event")
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_dir: Path):
|
||||
self.buffer_dir = Path(buffer_dir)
|
||||
self.buffer_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.db_path = self.buffer_dir / "pending_events.db"
|
||||
self._lock = threading.Lock()
|
||||
self._init_db()
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Initialisation / gestion corruption
|
||||
# ---------------------------------------------------------------
|
||||
|
||||
def _init_db(self):
|
||||
"""Crée les tables si elles n'existent pas.
|
||||
|
||||
En cas de fichier corrompu, on le renomme en .corrupted et on recrée
|
||||
un buffer vide. On préfère perdre un buffer non lisible plutôt que
|
||||
de crasher l'agent au démarrage.
|
||||
"""
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS pending_events (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL,
|
||||
payload TEXT NOT NULL,
|
||||
attempts INTEGER NOT NULL DEFAULT 0,
|
||||
created_at REAL NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS pending_images (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL,
|
||||
shot_id TEXT NOT NULL,
|
||||
image_path TEXT NOT NULL,
|
||||
attempts INTEGER NOT NULL DEFAULT 0,
|
||||
created_at REAL NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_events_created "
|
||||
"ON pending_events(created_at)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_images_created "
|
||||
"ON pending_images(created_at)"
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.DatabaseError as e:
|
||||
logger.warning(
|
||||
f"Buffer SQLite corrompu ({e}) — renommage en .corrupted "
|
||||
f"et recréation d'un buffer vide"
|
||||
)
|
||||
try:
|
||||
corrupted = self.db_path.with_suffix(
|
||||
f".corrupted.{int(time.time())}"
|
||||
)
|
||||
os.rename(self.db_path, corrupted)
|
||||
except OSError:
|
||||
# Si le rename échoue, on tente la suppression directe
|
||||
try:
|
||||
os.remove(self.db_path)
|
||||
except OSError:
|
||||
pass
|
||||
# Nouvelle tentative (table vide)
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS pending_events ("
|
||||
"id INTEGER PRIMARY KEY AUTOINCREMENT, "
|
||||
"session_id TEXT NOT NULL, payload TEXT NOT NULL, "
|
||||
"attempts INTEGER NOT NULL DEFAULT 0, "
|
||||
"created_at REAL NOT NULL)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS pending_images ("
|
||||
"id INTEGER PRIMARY KEY AUTOINCREMENT, "
|
||||
"session_id TEXT NOT NULL, shot_id TEXT NOT NULL, "
|
||||
"image_path TEXT NOT NULL, "
|
||||
"attempts INTEGER NOT NULL DEFAULT 0, "
|
||||
"created_at REAL NOT NULL)"
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
"""Connexion SQLite en mode WAL (meilleure concurrence)."""
|
||||
conn = sqlite3.connect(
|
||||
str(self.db_path),
|
||||
timeout=5.0,
|
||||
check_same_thread=False,
|
||||
isolation_level=None, # autocommit — on gère les transactions
|
||||
)
|
||||
try:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
except sqlite3.DatabaseError:
|
||||
pass
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Écriture — persiste un item
|
||||
# ---------------------------------------------------------------
|
||||
|
||||
def add_event(self, session_id: str, event: dict) -> bool:
|
||||
"""Persiste un événement. Retourne True si écrit, False sinon.
|
||||
|
||||
Si le buffer dépasse MAX_BUFFER_ITEMS, on drop l'insertion (plutôt
|
||||
que saturer le disque). On log un warning au premier dépassement.
|
||||
"""
|
||||
with self._lock:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
count = conn.execute(
|
||||
"SELECT COUNT(*) FROM pending_events"
|
||||
).fetchone()[0]
|
||||
if count >= MAX_BUFFER_ITEMS:
|
||||
logger.warning(
|
||||
f"Buffer persistant saturé ({count} events) "
|
||||
f"— event droppé"
|
||||
)
|
||||
return False
|
||||
conn.execute(
|
||||
"INSERT INTO pending_events "
|
||||
"(session_id, payload, attempts, created_at) "
|
||||
"VALUES (?, ?, 0, ?)",
|
||||
(session_id, json.dumps(event), time.time()),
|
||||
)
|
||||
return True
|
||||
except (sqlite3.DatabaseError, TypeError, ValueError) as e:
|
||||
logger.error(f"Buffer add_event échoué : {e}")
|
||||
return False
|
||||
|
||||
def add_image(
|
||||
self, session_id: str, image_path: str, shot_id: str
|
||||
) -> bool:
|
||||
"""Persiste une référence image (chemin fichier + shot_id).
|
||||
|
||||
On ne stocke PAS les bytes de l'image (risque de faire gonfler la DB) :
|
||||
uniquement le chemin. Donc l'image doit rester présente sur disque
|
||||
tant qu'elle n'a pas été envoyée avec succès au serveur.
|
||||
"""
|
||||
with self._lock:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
count = conn.execute(
|
||||
"SELECT COUNT(*) FROM pending_images"
|
||||
).fetchone()[0]
|
||||
if count >= MAX_BUFFER_ITEMS:
|
||||
logger.warning(
|
||||
f"Buffer persistant saturé ({count} images) "
|
||||
f"— image droppée"
|
||||
)
|
||||
return False
|
||||
conn.execute(
|
||||
"INSERT INTO pending_images "
|
||||
"(session_id, shot_id, image_path, attempts, created_at) "
|
||||
"VALUES (?, ?, ?, 0, ?)",
|
||||
(session_id, shot_id, image_path, time.time()),
|
||||
)
|
||||
return True
|
||||
except sqlite3.DatabaseError as e:
|
||||
logger.error(f"Buffer add_image échoué : {e}")
|
||||
return False
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Lecture — drain dans l'ordre chronologique
|
||||
# ---------------------------------------------------------------
|
||||
|
||||
def drain_events(self, limit: int = 100) -> list:
|
||||
"""Retourne les events en attente, triés par date de création."""
|
||||
with self._lock:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT id, session_id, payload, attempts "
|
||||
"FROM pending_events "
|
||||
"ORDER BY created_at ASC LIMIT ?",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
except sqlite3.DatabaseError as e:
|
||||
logger.error(f"Buffer drain_events échoué : {e}")
|
||||
return []
|
||||
|
||||
def drain_images(self, limit: int = 50) -> list:
|
||||
"""Retourne les images en attente, triées par date de création."""
|
||||
with self._lock:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT id, session_id, shot_id, image_path, attempts "
|
||||
"FROM pending_images "
|
||||
"ORDER BY created_at ASC LIMIT ?",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
except sqlite3.DatabaseError as e:
|
||||
logger.error(f"Buffer drain_images échoué : {e}")
|
||||
return []
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Marquage — succès, échec, abandon
|
||||
# ---------------------------------------------------------------
|
||||
|
||||
def delete_event(self, row_id: int):
|
||||
"""Supprime un event après envoi réussi."""
|
||||
with self._lock:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"DELETE FROM pending_events WHERE id = ?", (row_id,)
|
||||
)
|
||||
except sqlite3.DatabaseError as e:
|
||||
logger.error(f"Buffer delete_event échoué : {e}")
|
||||
|
||||
def delete_image(self, row_id: int):
|
||||
"""Supprime une image après envoi réussi."""
|
||||
with self._lock:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"DELETE FROM pending_images WHERE id = ?", (row_id,)
|
||||
)
|
||||
except sqlite3.DatabaseError as e:
|
||||
logger.error(f"Buffer delete_image échoué : {e}")
|
||||
|
||||
def increment_attempts(self, row_id: int, kind: str) -> int:
|
||||
"""Incrémente le compteur d'attempts. Retourne la nouvelle valeur.
|
||||
|
||||
kind : "event" ou "image"
|
||||
"""
|
||||
table = "pending_events" if kind == "event" else "pending_images"
|
||||
with self._lock:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
f"UPDATE {table} SET attempts = attempts + 1 "
|
||||
"WHERE id = ?",
|
||||
(row_id,),
|
||||
)
|
||||
row = conn.execute(
|
||||
f"SELECT attempts FROM {table} WHERE id = ?", (row_id,)
|
||||
).fetchone()
|
||||
return int(row["attempts"]) if row else MAX_ATTEMPTS
|
||||
except sqlite3.DatabaseError as e:
|
||||
logger.error(f"Buffer increment_attempts échoué : {e}")
|
||||
return MAX_ATTEMPTS
|
||||
|
||||
def abandon_exceeded(self) -> int:
|
||||
"""Supprime les items ayant dépassé MAX_ATTEMPTS.
|
||||
|
||||
Un item abandonné est logué en erreur (trace AI Act) puis supprimé.
|
||||
Retourne le nombre d'items abandonnés.
|
||||
"""
|
||||
abandoned = 0
|
||||
with self._lock:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
# Events abandonnés
|
||||
rows = conn.execute(
|
||||
"SELECT id, session_id, payload FROM pending_events "
|
||||
"WHERE attempts >= ?",
|
||||
(MAX_ATTEMPTS,),
|
||||
).fetchall()
|
||||
for r in rows:
|
||||
try:
|
||||
event_type = json.loads(r["payload"]).get(
|
||||
"type", "?"
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
event_type = "?"
|
||||
logger.error(
|
||||
f"Buffer : event abandonné après {MAX_ATTEMPTS} "
|
||||
f"tentatives — session={r['session_id']} "
|
||||
f"type={event_type}"
|
||||
)
|
||||
abandoned += 1
|
||||
conn.execute(
|
||||
"DELETE FROM pending_events WHERE attempts >= ?",
|
||||
(MAX_ATTEMPTS,),
|
||||
)
|
||||
|
||||
# Images abandonnées
|
||||
rows = conn.execute(
|
||||
"SELECT id, session_id, shot_id FROM pending_images "
|
||||
"WHERE attempts >= ?",
|
||||
(MAX_ATTEMPTS,),
|
||||
).fetchall()
|
||||
for r in rows:
|
||||
logger.error(
|
||||
f"Buffer : image abandonnée après {MAX_ATTEMPTS} "
|
||||
f"tentatives — session={r['session_id']} "
|
||||
f"shot_id={r['shot_id']}"
|
||||
)
|
||||
abandoned += 1
|
||||
conn.execute(
|
||||
"DELETE FROM pending_images WHERE attempts >= ?",
|
||||
(MAX_ATTEMPTS,),
|
||||
)
|
||||
except sqlite3.DatabaseError as e:
|
||||
logger.error(f"Buffer abandon_exceeded échoué : {e}")
|
||||
return abandoned
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Introspection
|
||||
# ---------------------------------------------------------------
|
||||
|
||||
def counts(self) -> dict:
|
||||
"""Retourne (events_count, images_count) pour diagnostic."""
|
||||
with self._lock:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
ev = conn.execute(
|
||||
"SELECT COUNT(*) FROM pending_events"
|
||||
).fetchone()[0]
|
||||
im = conn.execute(
|
||||
"SELECT COUNT(*) FROM pending_images"
|
||||
).fetchone()[0]
|
||||
return {"events": ev, "images": im}
|
||||
except sqlite3.DatabaseError:
|
||||
return {"events": 0, "images": 0}
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
c = self.counts()
|
||||
return c["events"] == 0 and c["images"] == 0
|
||||
@@ -14,10 +14,19 @@ Robustesse (P0-2) :
|
||||
- Health-check périodique (30s) pour recovery du flag _server_available
|
||||
- Compression JPEG qualité 85 pour les images (réduction ~5-10x)
|
||||
- Backpressure : queue bornée (maxsize=100), drop des heartbeat si pleine
|
||||
|
||||
Conformité AI Act (Article 12 — journalisation automatique) :
|
||||
- Purge après ACK : les screenshots locaux sont supprimés après HTTP 200
|
||||
du serveur (par défaut). Le serveur devient la source de vérité.
|
||||
- Buffer persistant : les events/images prioritaires non envoyés sont
|
||||
persistés dans un SQLite local (agent_v1/buffer/pending_events.db)
|
||||
et rejoués au démarrage et à la reconnexion.
|
||||
"""
|
||||
|
||||
import enum
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
@@ -25,7 +34,18 @@ import time
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from ..config import API_TOKEN, STREAMING_ENDPOINT
|
||||
from ..config import API_TOKEN, BASE_DIR, STREAMING_ENDPOINT
|
||||
from .persistent_buffer import MAX_ATTEMPTS, PersistentBuffer
|
||||
|
||||
|
||||
# Fix P0-E : résultat d'envoi d'image trivaleur (succès / échec réseau / fichier
|
||||
# disparu). On ne doit PAS considérer un FileNotFoundError comme un succès
|
||||
# HTTP 200 — sinon le buffer SQLite supprime l'entrée alors que le serveur n'a
|
||||
# jamais reçu l'image (perte silencieuse).
|
||||
class ImageSendResult(enum.Enum):
    """Three-valued outcome of an image upload attempt.

    Only ``OK`` is truthy, so code that still tests the result as a boolean
    keeps treating ``FAILED`` and ``FILE_GONE`` as failures. (Plain ``Enum``
    members are all truthy unless ``__bool__`` is defined.)
    """

    OK = "ok"  # HTTP 200, the server acknowledged receipt
    FAILED = "failed"  # recoverable network/server error (retry is fine)
    FILE_GONE = "file_gone"  # local file not found (abandon, do not retry)

    def __bool__(self) -> bool:
        """Return True only for ``OK``."""
        return self is ImageSendResult.OK
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -45,6 +65,20 @@ QUEUE_MAX_SIZE = 100
|
||||
# Types d'événements à ne jamais dropper
|
||||
PRIORITY_EVENT_TYPES = {"click", "key", "scroll", "action", "screenshot"}
|
||||
|
||||
# Purge locale après ACK serveur (Partie A de l'audit)
|
||||
# Activé par défaut : le serveur conserve déjà les screenshots 180 jours
|
||||
# (conformité AI Act Article 12). Désactivable via RPA_PURGE_AFTER_ACK=0
|
||||
# pour debugging local.
|
||||
PURGE_AFTER_ACK = os.environ.get("RPA_PURGE_AFTER_ACK", "1").lower() in (
|
||||
"1", "true", "yes",
|
||||
)
|
||||
|
||||
# Chemin du buffer persistant (Partie B de l'audit)
|
||||
BUFFER_DIR = BASE_DIR / "buffer"
|
||||
|
||||
# Intervalle entre deux tentatives de drain du buffer (secondes)
|
||||
BUFFER_DRAIN_INTERVAL_S = 15
|
||||
|
||||
|
||||
class TraceStreamer:
|
||||
def __init__(self, session_id: str, machine_id: str = "default"):
|
||||
@@ -54,8 +88,20 @@ class TraceStreamer:
|
||||
self.running = False
|
||||
self._thread = None
|
||||
self._health_thread = None
|
||||
self._drain_thread = None
|
||||
self._server_available = True # Désactivé après trop d'échecs
|
||||
|
||||
# Buffer persistant — partagé entre sessions (survit au redémarrage)
|
||||
# Initialisé paresseusement pour ne pas payer le coût SQLite en dehors
|
||||
# d'un streaming actif.
|
||||
self._buffer: PersistentBuffer | None = None
|
||||
|
||||
def _get_buffer(self) -> PersistentBuffer:
    """Return the persistent buffer, creating it on first use.

    The instance is cached in ``self._buffer`` and backed by ``BUFFER_DIR``.
    """
    cached = self._buffer
    if cached is not None:
        return cached
    self._buffer = PersistentBuffer(BUFFER_DIR)
    return self._buffer
|
||||
|
||||
@staticmethod
|
||||
def _auth_headers() -> dict:
|
||||
"""Headers d'authentification Bearer pour les requêtes API."""
|
||||
@@ -75,6 +121,11 @@ class TraceStreamer:
|
||||
target=self._health_check_loop, daemon=True
|
||||
)
|
||||
self._health_thread.start()
|
||||
# Thread de drain du buffer persistant (rejoue les items en attente)
|
||||
self._drain_thread = threading.Thread(
|
||||
target=self._buffer_drain_loop, daemon=True
|
||||
)
|
||||
self._drain_thread.start()
|
||||
logger.info(f"Streamer pour {self.session_id} démarré")
|
||||
|
||||
def stop(self):
|
||||
@@ -99,6 +150,9 @@ class TraceStreamer:
|
||||
if self._health_thread:
|
||||
self._health_thread.join(timeout=2.0)
|
||||
|
||||
if self._drain_thread:
|
||||
self._drain_thread.join(timeout=2.0)
|
||||
|
||||
self._finalize_session()
|
||||
logger.info(f"Streamer pour {self.session_id} arrêté")
|
||||
|
||||
@@ -126,11 +180,21 @@ class TraceStreamer:
|
||||
|
||||
Quand la queue est pleine :
|
||||
- Les événements prioritaires (click, key, action, screenshot) sont
|
||||
ajoutés en bloquant brièvement (0.5s)
|
||||
- Les heartbeat sont silencieusement droppés
|
||||
ajoutés en bloquant brièvement (0.5s). Si toujours pleine → persistés
|
||||
dans le buffer SQLite pour rejeu ultérieur.
|
||||
- Les heartbeat sont silencieusement droppés.
|
||||
- Si le serveur est marqué indisponible, on persiste immédiatement les
|
||||
items prioritaires (évite de remplir la queue inutilement).
|
||||
"""
|
||||
is_priority = self._is_priority_item(item_type, data)
|
||||
|
||||
# Serveur indisponible + item prioritaire → on persiste directement
|
||||
# sans polluer la queue RAM (qui ne sera jamais vidée tant que le
|
||||
# serveur est down).
|
||||
if is_priority and not self._server_available:
|
||||
self._persist_to_buffer(item_type, data)
|
||||
return
|
||||
|
||||
try:
|
||||
self.queue.put_nowait((item_type, data))
|
||||
except queue.Full:
|
||||
@@ -139,9 +203,17 @@ class TraceStreamer:
|
||||
try:
|
||||
self.queue.put((item_type, data), timeout=0.5)
|
||||
except queue.Full:
|
||||
# Persistance disque (ne JAMAIS dropper un prioritaire)
|
||||
persisted = self._persist_to_buffer(item_type, data)
|
||||
if persisted:
|
||||
logger.warning(
|
||||
f"Queue pleine — événement prioritaire droppé "
|
||||
f"(type={item_type})"
|
||||
f"Queue pleine — événement prioritaire persisté "
|
||||
f"sur disque (type={item_type})"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Queue pleine ET buffer saturé — événement "
|
||||
f"prioritaire perdu (type={item_type})"
|
||||
)
|
||||
else:
|
||||
# Heartbeat ou événement non-critique : on drop silencieusement
|
||||
@@ -163,6 +235,23 @@ class TraceStreamer:
|
||||
return event_type in PRIORITY_EVENT_TYPES
|
||||
return False
|
||||
|
||||
def _persist_to_buffer(self, item_type: str, data) -> bool:
    """Store one item in the SQLite buffer and report whether it was written.

    Called when the in-memory queue is full or the server is unavailable.
    ``"event"`` items must carry a dict; ``"image"`` items carry a
    ``(path, shot_id)`` pair. Anything else is not stored.

    Returns:
        The result of the buffer's ``add_event`` / ``add_image`` call, or
        False for an unsupported item or when any exception occurs.
    """
    try:
        store = self._get_buffer()
        if item_type == "image":
            image_path, shot = data
            return store.add_image(self.session_id, image_path, shot)
        if item_type == "event" and isinstance(data, dict):
            return store.add_event(self.session_id, data)
    except Exception as e:
        # A buffer failure must never stop the agent.
        logger.error(f"Persistance buffer échouée : {e}")
    return False
|
||||
|
||||
# =========================================================================
|
||||
# Boucle d'envoi
|
||||
# =========================================================================
|
||||
@@ -174,16 +263,36 @@ class TraceStreamer:
|
||||
try:
|
||||
item_type, data = self.queue.get(timeout=0.5)
|
||||
success = False
|
||||
is_file_gone = False
|
||||
if item_type == "event":
|
||||
success = self._send_with_retry(self._send_event, data)
|
||||
elif item_type == "image":
|
||||
success = self._send_with_retry(self._send_image, *data)
|
||||
result = self._send_with_retry(self._send_image, *data)
|
||||
# Fix P0-E : distinguer FILE_GONE du vrai succès HTTP.
|
||||
if result is ImageSendResult.OK:
|
||||
success = True
|
||||
elif result is ImageSendResult.FILE_GONE:
|
||||
# Fichier disparu : pas de retry, pas de persistance
|
||||
# (on ne peut plus le renvoyer). On considère l'item
|
||||
# comme traité sans comptabiliser un succès réseau.
|
||||
is_file_gone = True
|
||||
success = False
|
||||
else:
|
||||
success = False
|
||||
self.queue.task_done()
|
||||
|
||||
if success:
|
||||
consecutive_failures = 0
|
||||
elif is_file_gone:
|
||||
# Fichier introuvable — déjà logué ERROR dans _send_image.
|
||||
# On ne persiste PAS dans le buffer (retry voué à échouer).
|
||||
consecutive_failures = 0
|
||||
else:
|
||||
consecutive_failures += 1
|
||||
# Après 3 retries infructueux, si l'item est prioritaire,
|
||||
# on le persiste pour ne pas le perdre définitivement.
|
||||
if self._is_priority_item(item_type, data):
|
||||
self._persist_to_buffer(item_type, data)
|
||||
if consecutive_failures >= 10:
|
||||
logger.warning(
|
||||
"10 échecs consécutifs — serveur marqué indisponible"
|
||||
@@ -200,15 +309,22 @@ class TraceStreamer:
|
||||
# Retry avec backoff exponentiel
|
||||
# =========================================================================
|
||||
|
||||
def _send_with_retry(self, send_fn, *args) -> bool:
|
||||
def _send_with_retry(self, send_fn, *args):
|
||||
"""Tente l'envoi avec retry et backoff exponentiel.
|
||||
|
||||
3 tentatives max avec délais de 1s, 2s, 4s entre chaque.
|
||||
Retourne True si l'envoi a réussi, False sinon.
|
||||
Retourne :
|
||||
- True / ImageSendResult.OK si l'envoi a réussi
|
||||
- ImageSendResult.FILE_GONE (images uniquement) — pas de retry
|
||||
- False / ImageSendResult.FAILED sinon
|
||||
"""
|
||||
# Première tentative (sans délai)
|
||||
if send_fn(*args):
|
||||
return True
|
||||
first = send_fn(*args)
|
||||
if first is ImageSendResult.OK or first is True:
|
||||
return first
|
||||
# Fix P0-E : FILE_GONE → pas de retry, l'erreur est permanente.
|
||||
if first is ImageSendResult.FILE_GONE:
|
||||
return first
|
||||
|
||||
# Retries avec backoff
|
||||
for attempt, delay in enumerate(RETRY_DELAYS, start=1):
|
||||
@@ -219,9 +335,13 @@ class TraceStreamer:
|
||||
f"Retry {attempt}/{MAX_RETRIES} dans {delay}s..."
|
||||
)
|
||||
time.sleep(delay)
|
||||
if send_fn(*args):
|
||||
result = send_fn(*args)
|
||||
if result is ImageSendResult.OK or result is True:
|
||||
logger.debug(f"Retry {attempt} réussi")
|
||||
return True
|
||||
return result
|
||||
# FILE_GONE pendant un retry — idem, on arrête
|
||||
if result is ImageSendResult.FILE_GONE:
|
||||
return result
|
||||
|
||||
logger.debug(f"Envoi échoué après {MAX_RETRIES} retries")
|
||||
return False
|
||||
@@ -260,6 +380,115 @@ class TraceStreamer:
|
||||
except Exception:
|
||||
logger.debug("Health-check échoué — serveur toujours indisponible")
|
||||
|
||||
# =========================================================================
|
||||
# Drain du buffer persistant (Partie B)
|
||||
# =========================================================================
|
||||
|
||||
def _buffer_drain_loop(self):
    """Replay persisted items in the background while ``self.running``.

    A drain is attempted every ``BUFFER_DRAIN_INTERVAL_S`` seconds, but only
    when the server is flagged available and the buffer is not empty. The
    first pass runs immediately, so items persisted before this streamer
    started are replayed without waiting.

    The wait between passes is sliced into short sleeps so the thread
    notices ``self.running`` turning False quickly; ``stop()`` joins this
    thread with a 2 s timeout.
    """
    poll_s = 0.5
    first_pass = True
    while self.running:
        if not first_pass:
            deadline = time.monotonic() + BUFFER_DRAIN_INTERVAL_S
            while self.running:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                time.sleep(min(poll_s, remaining))
            if not self.running:
                break
        first_pass = False

        if not self._server_available:
            continue

        try:
            buf = self._get_buffer()
            # Drop exhausted items first so they are not retried again.
            abandoned = buf.abandon_exceeded()
            if abandoned:
                logger.warning(
                    f"Buffer : {abandoned} items abandonnés "
                    f"après {MAX_ATTEMPTS} tentatives"
                )

            counts = buf.counts()
            if counts["events"] == 0 and counts["images"] == 0:
                continue

            logger.info(
                f"Buffer drain : {counts['events']} events, "
                f"{counts['images']} images en attente — rejeu"
            )
            self._drain_buffer_once(buf)
        except Exception as e:
            # Top-level boundary of a daemon thread: log and keep looping.
            logger.error(f"Buffer drain loop échoué : {e}")
|
||||
|
||||
def _drain_buffer_once(self, buf: PersistentBuffer):
    """Run one drain pass: send what can be sent, count a failure otherwise.

    Events go first (up to 50), then images (up to 20). The pass stops at
    the first failed send, or as soon as the server is flagged unavailable.
    Rows that can never be sent (undecodable payload, missing image file)
    are deleted with an ERROR log.

    Args:
        buf: the persistent buffer to drain.
    """
    # Imported once per pass rather than once per row.
    import json as _json

    # NOTE(review): _send_event stamps the payload with self.session_id, so
    # rows persisted by another session are replayed under the current
    # session id (row["session_id"] is not used) — confirm this is intended.
    for row in buf.drain_events(limit=50):
        if not self._server_available:
            return
        try:
            event = _json.loads(row["payload"])
        except (ValueError, TypeError):
            logger.error(
                f"Buffer : payload event #{row['id']} corrompu, suppression"
            )
            buf.delete_event(row["id"])
            continue
        if self._send_event(event):
            buf.delete_event(row["id"])
        else:
            buf.increment_attempts(row["id"], "event")
            # The server is answering badly: stop this pass.
            return

    # Then images
    for row in buf.drain_images(limit=20):
        if not self._server_available:
            return
        image_path = row["image_path"]
        shot_id = row["shot_id"]
        if not os.path.exists(image_path):
            # The local file is gone: the entry can never be sent. Logged
            # at ERROR level because this is a data loss.
            logger.error(
                f"Buffer : image #{row['id']} introuvable sur disque "
                f"({image_path}) — entrée abandonnée (le serveur n'a "
                f"jamais reçu cette image, session={row['session_id']}, "
                f"shot={shot_id})"
            )
            buf.delete_image(row["id"])
            continue
        result = self._send_image(image_path, shot_id)
        if result is ImageSendResult.OK or result is True:
            buf.delete_image(row["id"])
        elif result is ImageSendResult.FILE_GONE:
            # The file vanished during the send. This is not a success,
            # but a retry cannot work either, so the entry is dropped
            # with an explicit ERROR log.
            logger.error(
                f"Buffer : image #{row['id']} disparue pendant l'envoi "
                f"({image_path}) — entrée abandonnée, pas de retry "
                f"(session={row['session_id']}, shot={shot_id})"
            )
            buf.delete_image(row["id"])
        else:
            buf.increment_attempts(row["id"], "image")
            return
|
||||
|
||||
# =========================================================================
|
||||
# Compression JPEG
|
||||
# =========================================================================
|
||||
@@ -287,6 +516,56 @@ class TraceStreamer:
|
||||
logger.warning(f"Compression JPEG échouée, envoi PNG brut: {e}")
|
||||
return None, None, None
|
||||
|
||||
# =========================================================================
|
||||
# Purge locale après ACK (Partie A)
|
||||
# =========================================================================
|
||||
|
||||
@staticmethod
def _purge_local_image(path: str):
    """Delete a local screenshot once the server has acknowledged it.

    Does nothing when ``PURGE_AFTER_ACK`` is off. Never raises for a file
    that is missing, locked, or otherwise not removable: those cases are
    swallowed, with a debug log for everything except a missing file.
    """
    if PURGE_AFTER_ACK:
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already removed, or invalid path: stay silent.
            return
        except PermissionError as locked:
            # Windows sometimes keeps files locked (antivirus, indexing...).
            logger.debug(
                f"Purge différée (fichier verrouillé) : {path} — {locked}"
            )
        except OSError as err:
            logger.debug(f"Purge échouée : {path} — {err}")
        else:
            logger.debug(f"Screenshot local purgé après ACK : {path}")
|
||||
|
||||
# =========================================================================
|
||||
# Protection redirect POST→GET (INC-7)
|
||||
# =========================================================================
|
||||
|
||||
@staticmethod
def _check_redirect(resp, url: str):
    """Detect and log a redirect answered to a POST request.

    The POST calls in this class pass ``allow_redirects=False``, so a
    redirect comes back as a plain 3xx response. ``requests`` reports such a
    response as ``ok`` (status below 400), so without this check a redirected
    POST would be treated as delivered. 303 is included because it also asks
    the client to re-issue the request elsewhere.

    Args:
        resp: response object exposing ``status_code`` and ``headers``.
        url: URL the POST was sent to (used in the log message only).

    Returns:
        True when ``resp`` is a 301, 302, 303, 307 or 308 redirect (a
        WARNING is logged so the admin can fix the URL), False otherwise.
    """
    if resp.status_code not in (301, 302, 303, 307, 308):
        return False
    location = resp.headers.get("Location", "?")
    logger.warning(
        f"Redirection {resp.status_code} detectee sur POST {url} "
        f"→ {location}. Verifiez que RPA_SERVER_URL utilise "
        f"https:// si le serveur redirige."
    )
    return True
|
||||
|
||||
# =========================================================================
|
||||
# Envois HTTP
|
||||
# =========================================================================
|
||||
@@ -294,15 +573,20 @@ class TraceStreamer:
|
||||
def _register_session(self):
|
||||
"""Enregistrer la session auprès du serveur (avec identifiant machine)."""
|
||||
try:
|
||||
url = f"{STREAMING_ENDPOINT}/register"
|
||||
resp = requests.post(
|
||||
f"{STREAMING_ENDPOINT}/register",
|
||||
url,
|
||||
params={
|
||||
"session_id": self.session_id,
|
||||
"machine_id": self.machine_id,
|
||||
},
|
||||
headers=self._auth_headers(),
|
||||
timeout=3,
|
||||
allow_redirects=False,
|
||||
)
|
||||
if self._check_redirect(resp, url):
|
||||
logger.warning("Enregistrement session échoué (redirect)")
|
||||
return
|
||||
if resp.ok:
|
||||
logger.info(
|
||||
f"Session {self.session_id} enregistrée sur le serveur "
|
||||
@@ -322,28 +606,32 @@ class TraceStreamer:
|
||||
C'est la dernière chance de sauver les données de la session.
|
||||
"""
|
||||
try:
|
||||
url = f"{STREAMING_ENDPOINT}/finalize"
|
||||
resp = requests.post(
|
||||
f"{STREAMING_ENDPOINT}/finalize",
|
||||
url,
|
||||
params={
|
||||
"session_id": self.session_id,
|
||||
"machine_id": self.machine_id,
|
||||
},
|
||||
headers=self._auth_headers(),
|
||||
timeout=30, # Le build workflow peut prendre du temps
|
||||
allow_redirects=False,
|
||||
)
|
||||
self._check_redirect(resp, url)
|
||||
if resp.ok:
|
||||
result = resp.json()
|
||||
logger.info(f"Session finalisée: {result}")
|
||||
else:
|
||||
logger.warning(f"Finalisation échouée: {resp.status_code}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Finalisation échouée: {e}")
|
||||
logger.warning(f"Finalisation échouée: {e}")
|
||||
|
||||
def _send_event(self, event: dict) -> bool:
|
||||
"""Envoyer un événement au serveur (avec identifiant machine)."""
|
||||
if not self._server_available:
|
||||
return False
|
||||
try:
|
||||
url = f"{STREAMING_ENDPOINT}/event"
|
||||
payload = {
|
||||
"session_id": self.session_id,
|
||||
"timestamp": time.time(),
|
||||
@@ -351,24 +639,36 @@ class TraceStreamer:
|
||||
"machine_id": self.machine_id,
|
||||
}
|
||||
resp = requests.post(
|
||||
f"{STREAMING_ENDPOINT}/event",
|
||||
url,
|
||||
json=payload,
|
||||
headers=self._auth_headers(),
|
||||
timeout=2,
|
||||
allow_redirects=False,
|
||||
)
|
||||
if self._check_redirect(resp, url):
|
||||
return False
|
||||
return resp.ok
|
||||
except Exception as e:
|
||||
logger.debug(f"Streaming Event échoué: {e}")
|
||||
return False
|
||||
|
||||
def _send_image(self, path: str, shot_id: str) -> bool:
|
||||
def _send_image(self, path: str, shot_id: str):
|
||||
"""Envoyer un screenshot au serveur, compressé en JPEG.
|
||||
|
||||
Utilise un context manager pour le fallback PNG afin d'éviter
|
||||
les fuites de descripteurs de fichier.
|
||||
|
||||
Partie A (purge après ACK) : en cas de HTTP 200 confirmé, le fichier
|
||||
local est supprimé (le serveur devient la source de vérité).
|
||||
|
||||
Fix P0-E : retourne `ImageSendResult` (OK / FAILED / FILE_GONE).
|
||||
Les appelants historiques qui attendaient un bool continuent de
|
||||
fonctionner grâce à la truthiness du enum (OK → True, reste → False),
|
||||
MAIS le drain du buffer doit désormais discriminer FILE_GONE pour
|
||||
ne pas confondre "fichier disparu" avec "envoyé avec succès".
|
||||
"""
|
||||
if not self._server_available:
|
||||
return False
|
||||
return ImageSendResult.FAILED
|
||||
try:
|
||||
# Tenter la compression JPEG (réduction ~5-10x vs PNG)
|
||||
jpeg_buf, content_type, suffix = self._compress_image_to_jpeg(path)
|
||||
@@ -379,19 +679,26 @@ class TraceStreamer:
|
||||
"machine_id": self.machine_id,
|
||||
}
|
||||
|
||||
url = f"{STREAMING_ENDPOINT}/image"
|
||||
if jpeg_buf is not None:
|
||||
# Envoi du JPEG compressé (BytesIO, pas de fuite possible)
|
||||
files = {
|
||||
"file": (f"{shot_id}{suffix}", jpeg_buf, content_type)
|
||||
}
|
||||
resp = requests.post(
|
||||
f"{STREAMING_ENDPOINT}/image",
|
||||
url,
|
||||
files=files,
|
||||
params=params,
|
||||
headers=self._auth_headers(),
|
||||
timeout=5,
|
||||
allow_redirects=False,
|
||||
)
|
||||
return resp.ok
|
||||
if self._check_redirect(resp, url):
|
||||
return ImageSendResult.FAILED
|
||||
if resp.ok:
|
||||
self._purge_local_image(path)
|
||||
return ImageSendResult.OK
|
||||
return ImageSendResult.FAILED
|
||||
else:
|
||||
# Fallback : envoi PNG original avec context manager
|
||||
with open(path, "rb") as f:
|
||||
@@ -399,13 +706,29 @@ class TraceStreamer:
|
||||
"file": (f"{shot_id}.png", f, "image/png")
|
||||
}
|
||||
resp = requests.post(
|
||||
f"{STREAMING_ENDPOINT}/image",
|
||||
url,
|
||||
files=files,
|
||||
params=params,
|
||||
headers=self._auth_headers(),
|
||||
timeout=5,
|
||||
allow_redirects=False,
|
||||
)
|
||||
return resp.ok
|
||||
if self._check_redirect(resp, url):
|
||||
return ImageSendResult.FAILED
|
||||
if resp.ok:
|
||||
self._purge_local_image(path)
|
||||
return ImageSendResult.OK
|
||||
return ImageSendResult.FAILED
|
||||
except FileNotFoundError:
|
||||
# Fix P0-E : fichier local disparu. On NE doit PAS considérer ça
|
||||
# comme un succès HTTP 200. Le serveur n'a rien reçu. On signale
|
||||
# `FILE_GONE` pour que le drain du buffer supprime l'entrée
|
||||
# (pas de retry possible) tout en loguant ERROR (pas debug).
|
||||
logger.error(
|
||||
f"Image {shot_id} introuvable sur disque ({path}) — "
|
||||
f"abandon (serveur n'a rien reçu)"
|
||||
)
|
||||
return ImageSendResult.FILE_GONE
|
||||
except Exception as e:
|
||||
logger.debug(f"Streaming Image échoué: {e}")
|
||||
return False
|
||||
return ImageSendResult.FAILED
|
||||
|
||||
@@ -3,6 +3,7 @@ mss>=9.0.1 # Capture d'écran haute performance
|
||||
pynput>=1.7.7 # Clavier/Souris Cross-plateforme
|
||||
Pillow>=10.0.0 # Crops et processing image
|
||||
requests>=2.31.0 # Streaming réseau
|
||||
python-socketio[client]>=5.10,<6.0 # Bus feedback Léa 'lea:*' (compat Flask-SocketIO 5.3.x serveur)
|
||||
psutil>=5.9.0 # Monitoring CPU/RAM
|
||||
pystray>=0.19.5 # Icône Tray UI
|
||||
plyer>=2.1.0 # Notifications toast natives (remplace PyQt5)
|
||||
|
||||
@@ -3,15 +3,25 @@ Mini serveur HTTP sur l'agent Windows pour les captures d'ecran a la demande
|
||||
et les operations fichiers.
|
||||
|
||||
Ecoute sur le port 5006 (configurable via RPA_CAPTURE_PORT).
|
||||
Bind par defaut sur 127.0.0.1 (configurable via RPA_CAPTURE_BIND).
|
||||
Endpoints :
|
||||
GET /capture -> screenshot frais en base64 (JPEG)
|
||||
GET /health -> {"status": "ok"}
|
||||
GET /health -> {"status": "ok"} (pas d'auth — sonde liveness)
|
||||
POST /file-action -> operations fichiers (list, create, move, copy, sort)
|
||||
|
||||
Securite :
|
||||
- Authentification Bearer obligatoire (RPA_API_TOKEN) pour /capture et
|
||||
/file-action. Sans token configure, ces endpoints sont desactives.
|
||||
- Les tentatives non authentifiees sont loguees (WARNING) avec l'IP source.
|
||||
- Bind defaut localhost. Pour exposer sur le LAN (cas VWB backend qui
|
||||
appelle l'agent a distance), definir explicitement
|
||||
RPA_CAPTURE_BIND=0.0.0.0. L'auth reste alors la seule protection.
|
||||
"""
|
||||
import threading
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
import hmac
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
@@ -20,6 +30,17 @@ from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TCP port the capture server listens on (override with RPA_CAPTURE_PORT).
CAPTURE_PORT = int(os.getenv("RPA_CAPTURE_PORT", "5006"))

# Bind address. Defaults to loopback as defense in depth. For the VWB
# deployment (Linux backend -> Windows agent) set RPA_CAPTURE_BIND=0.0.0.0
# explicitly; token authentication is still required in that case.
CAPTURE_BIND = os.getenv("RPA_CAPTURE_BIND", "127.0.0.1")

# Authentication token (shared with the streaming client). It must be set
# for /capture and /file-action to be reachable.
CAPTURE_TOKEN = os.getenv("RPA_API_TOKEN", "")

# Paths served without authentication (technical probes only).
_PUBLIC_PATHS = {"/health"}

# Blurring of sensitive data (AI Act compliance); on unless disabled via env.
BLUR_SENSITIVE = os.getenv("RPA_BLUR_SENSITIVE", "true").lower() in ("true", "1", "yes")
|
||||
@@ -33,6 +54,8 @@ class CaptureHandler(BaseHTTPRequestHandler):
|
||||
|
||||
def do_GET(self):
|
||||
if self.path == "/capture":
|
||||
if not self._check_auth():
|
||||
return
|
||||
self._handle_capture()
|
||||
elif self.path == "/health":
|
||||
self._send_json(200, {"status": "ok"})
|
||||
@@ -41,10 +64,56 @@ class CaptureHandler(BaseHTTPRequestHandler):
|
||||
|
||||
def do_POST(self):
    """Route POST requests: only ``/file-action`` is served, and it requires auth.

    Any other path gets a 404 JSON response.
    """
    if self.path != "/file-action":
        self._send_json(404, {"error": "not found"})
        return

    # _check_auth() writes the error response itself when it returns False,
    # so there is nothing left to do in that case.
    if self._check_auth():
        self._handle_file_action()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _check_auth(self) -> bool:
    """Validate the request's Bearer token; answer 401/503 and return False on failure.

    - Paths listed in ``_PUBLIC_PATHS`` are always allowed.
    - When no token is configured server-side (``RPA_API_TOKEN`` empty),
      every protected request is refused with 503 (fail-closed).
    - Otherwise the presented token is compared in constant time with
      ``hmac.compare_digest``.
    - Rejected attempts are logged together with the peer address.

    Returns:
        True when the caller may proceed. False when an error response has
        already been written and the caller must simply return.
    """
    # Public endpoints never require a token.
    if self.path in _PUBLIC_PATHS:
        return True

    peer = self.client_address[0] if self.client_address else "?"

    if not CAPTURE_TOKEN:
        logger.error(
            "Refus %s depuis %s : RPA_API_TOKEN non configure "
            "(capture server en mode fail-closed)",
            self.path, peer,
        )
        self._send_json(503, {
            "error": "capture server non configure (token manquant)",
        })
        return False

    auth_header = self.headers.get("Authorization", "")
    token = ""
    if auth_header.startswith("Bearer "):
        token = auth_header[len("Bearer "):].strip()

    # hmac.compare_digest() raises TypeError when given str arguments that
    # contain non-ASCII characters. Comparing bytes keeps a malformed header
    # from crashing the handler: the client gets a clean 401 instead.
    # NOTE(review): http.server is expected to decode header values as
    # ISO-8859-1, so encoding back with that codec recovers the raw bytes
    # the client sent - confirm if the handler base class ever changes.
    presented = token.encode("iso-8859-1", errors="replace")
    expected = CAPTURE_TOKEN.encode("utf-8")

    if not token or not hmac.compare_digest(presented, expected):
        logger.warning(
            "Tentative d'acces non autorisee a %s depuis %s "
            "(token %s)",
            self.path, peer,
            "absent" if not token else "invalide",
        )
        self._send_json(401, {"error": "unauthorized"})
        return False

    return True
|
||||
|
||||
def do_OPTIONS(self):
|
||||
"""Gestion CORS preflight."""
|
||||
self.send_response(200)
|
||||
@@ -351,21 +420,46 @@ class _FileActionHandlerLocal:
|
||||
class CaptureServer:
|
||||
"""Serveur de capture d'ecran en temps reel (thread daemon)."""
|
||||
|
||||
def __init__(self, port: int = CAPTURE_PORT):
|
||||
def __init__(self, port: int = CAPTURE_PORT, bind: str = CAPTURE_BIND):
|
||||
self._port = port
|
||||
self._bind = bind
|
||||
self._server: HTTPServer | None = None
|
||||
self._thread: threading.Thread | None = None
|
||||
|
||||
def start(self):
|
||||
"""Demarre le serveur dans un thread daemon."""
|
||||
"""Demarre le serveur dans un thread daemon.
|
||||
|
||||
Avertit si le serveur est expose sur le LAN sans token configure.
|
||||
"""
|
||||
# Defense en profondeur : refus de demarrer si expose LAN sans auth
|
||||
exposed_lan = self._bind not in ("127.0.0.1", "localhost", "::1")
|
||||
if exposed_lan and not CAPTURE_TOKEN:
|
||||
logger.error(
|
||||
"REFUS demarrage capture server : bind=%s (LAN) sans "
|
||||
"RPA_API_TOKEN. Definir le token ou RPA_CAPTURE_BIND=127.0.0.1.",
|
||||
self._bind,
|
||||
)
|
||||
print(
|
||||
f"[CAPTURE] REFUS demarrage : bind={self._bind} sans token. "
|
||||
f"Definir RPA_API_TOKEN ou RPA_CAPTURE_BIND=127.0.0.1."
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
self._server = HTTPServer(("0.0.0.0", self._port), CaptureHandler)
|
||||
self._server = HTTPServer((self._bind, self._port), CaptureHandler)
|
||||
self._thread = threading.Thread(
|
||||
target=self._server.serve_forever, daemon=True
|
||||
)
|
||||
self._thread.start()
|
||||
logger.info(f"Capture server demarre sur le port {self._port}")
|
||||
print(f"[CAPTURE] Serveur de capture demarre sur le port {self._port}")
|
||||
auth_mode = "token requis" if CAPTURE_TOKEN else "token absent (fail-closed)"
|
||||
logger.info(
|
||||
"Capture server demarre sur %s:%s (%s)",
|
||||
self._bind, self._port, auth_mode,
|
||||
)
|
||||
print(
|
||||
f"[CAPTURE] Serveur de capture demarre sur "
|
||||
f"{self._bind}:{self._port} ({auth_mode})"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Impossible de demarrer le capture server : {e}")
|
||||
print(f"[CAPTURE] ERREUR demarrage : {e}")
|
||||
|
||||
@@ -16,6 +16,15 @@ from typing import Any, Callable, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# FeedbackBus: fail-safe import. The ChatWindow must keep working even when
# python-socketio is not installed on the client machine (e.g. an older
# installation), so any import failure just disables the feature.
try:
    from ..network.feedback_bus import FeedbackBusClient
except Exception:
    FeedbackBusClient = None  # type: ignore
    _HAS_FEEDBACK_BUS = False
else:
    _HAS_FEEDBACK_BUS = True
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Theme — palette professionnelle claire
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -42,6 +51,25 @@ SCROLLBAR_BG = "#E5E7EB" # Fond scrollbar
|
||||
SCROLLBAR_FG = "#9CA3AF" # Curseur scrollbar
|
||||
MSG_BORDER_COLOR = "#D1D5DB" # Bordure subtile des bulles de messages
|
||||
|
||||
# "paused_need_help" bubble (J3.5): non-blocking alert.
PAUSED_BG = "#FEF3C7"  # pale yellow
PAUSED_BORDER = "#F59E0B"  # amber orange
PAUSED_FG = "#92400E"  # dark brown (readable on the yellow background)
PAUSED_BTN_RESUME_BG = "#22C55E"  # green
PAUSED_BTN_RESUME_HOVER = "#16A34A"
PAUSED_BTN_ABORT_BG = "#9CA3AF"  # neutral grey (deliberately not dramatic)
PAUSED_BTN_ABORT_HOVER = "#6B7280"

# "Léa is executing" bubble (J3.4): distinct from regular chat bubbles.
ACTION_BG = "#F1F5F9"  # very light grey (sets it apart from a chat reply)
ACTION_BORDER = "#CBD5E1"  # pale grey
ACTION_FG = "#1E293B"  # dark grey
ACTION_META_FG = "#94A3B8"  # discreet grey for metadata
ACTION_ICON_RUN = "#3B82F6"  # blue (in progress)
ACTION_ICON_OK = "#22C55E"  # green (success)
ACTION_ICON_ERR = "#EF4444"  # red (failure)
ACTION_ICON_INFO = "#64748B"  # grey (neutral)
|
||||
|
||||
# Dimensions — confortables
|
||||
WIN_WIDTH = 600
|
||||
WIN_HEIGHT = 800
|
||||
@@ -62,6 +90,80 @@ FONT_SEND_BTN = ("Segoe UI", 13)
|
||||
FONT_RESIZE_GRIP = ("Segoe UI", 10)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Templates de bulles "Léa exécute" (J3.4)
|
||||
# Chaque template prend un payload et retourne (icon, icon_color, title).
|
||||
# Les libellés sont volontairement neutres : le contexte métier vient du
|
||||
# payload (workflow, action, message), pas de hardcoding.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _tpl_action_started(payload: Dict[str, Any]) -> tuple:
    """Template for 'lea:action_started': (icon, icon_color, title).

    The title carries the payload's ``workflow`` value, or "?" when it is
    missing or empty.
    """
    workflow_name = payload.get("workflow") or "?"
    title = f"Démarrage : {workflow_name}"
    return ("▶", ACTION_ICON_RUN, title)
|
||||
|
||||
|
||||
def _tpl_action_progress(payload: Dict[str, Any]) -> tuple:
    """Template for 'lea:action_progress': (icon, icon_color, title).

    The title is the payload's ``step`` label when present, otherwise a
    generic "Étape current/total" counter.
    """
    step = payload.get("step")
    if step:
        return ("⋯", ACTION_ICON_RUN, str(step))

    current = payload.get("current")
    total = payload.get("total")
    # A key that is present but set to None must show "?" like a missing key,
    # not the literal text "None". 0 is a legitimate value, hence the explicit
    # None test rather than `or`.
    if current is None:
        current = "?"
    if total is None:
        total = "?"
    return ("⋯", ACTION_ICON_RUN, f"Étape {current}/{total}")
|
||||
|
||||
|
||||
def _tpl_done(payload: Dict[str, Any]) -> tuple:
    """Template for 'lea:done': (icon, icon_color, title).

    ``success`` defaults to True when absent. The title is the payload's
    ``message`` when non-empty, otherwise a default depending on the outcome.
    """
    succeeded = bool(payload.get("success", True))
    if succeeded:
        icon, color, fallback = "✓", ACTION_ICON_OK, "Terminé"
    else:
        icon, color, fallback = "✗", ACTION_ICON_ERR, "Échec"

    text = payload.get("message") or fallback
    return (icon, color, str(text))
|
||||
|
||||
|
||||
def _tpl_need_confirm(payload: Dict[str, Any]) -> tuple:
    """Template for 'lea:need_confirm': (icon, icon_color, title).

    The title is ``payload["action"]["description"]`` when ``action`` is a
    dict with a non-empty description, otherwise a generic prompt.
    """
    action = payload.get("action") or {}

    description = None
    if isinstance(action, dict):
        description = action.get("description")
    if not description:
        description = "Validation requise"

    return ("?", ACTION_ICON_RUN, str(description))
|
||||
|
||||
|
||||
def _tpl_step_result(payload: Dict[str, Any]) -> tuple:
    """Template for 'lea:step_result': (icon, icon_color, title).

    The icon depends on the payload's ``status`` (case-insensitive):
    "ok"/"success"/"approved" map to success, "error"/"failed" to failure,
    anything else to a neutral dot. The title is ``message`` when non-empty,
    else the status text, else a generic label.
    """
    raw_status = payload.get("status")
    # Coerce through str() so a non-string status (e.g. a bool or an int)
    # cannot raise AttributeError on .lower().
    status = str(raw_status).lower() if raw_status else ""
    msg = payload.get("message") or status or "Étape terminée"

    if status in ("ok", "success", "approved"):
        return ("✓", ACTION_ICON_OK, str(msg))
    if status in ("error", "failed"):
        return ("✗", ACTION_ICON_ERR, str(msg))
    return ("·", ACTION_ICON_INFO, str(msg))
|
||||
|
||||
|
||||
def _tpl_resumed(payload: Dict[str, Any]) -> tuple:
    """Template for 'lea:resumed': fixed (icon, icon_color, title) triple.

    ``payload`` is accepted only to match the common template signature; it
    is not read.
    """
    icon = "→"
    title = "Reprise"
    return (icon, ACTION_ICON_OK, title)
|
||||
|
||||
|
||||
# Dispatch table: 'lea:*' event name -> template function. Each template
# takes the event payload and returns an (icon, icon_color, title) triple.
_ACTION_TEMPLATES: Dict[str, Callable[[Dict[str, Any]], tuple]] = {
    "lea:action_started": _tpl_action_started,
    "lea:action_progress": _tpl_action_progress,
    "lea:done": _tpl_done,
    "lea:need_confirm": _tpl_need_confirm,
    "lea:step_result": _tpl_step_result,
    "lea:resumed": _tpl_resumed,
}
|
||||
|
||||
|
||||
def _extract_meta(payload: Dict[str, Any]) -> str:
|
||||
"""Métadonnées techniques en pied de bulle (workflow, étape, replay_id court)."""
|
||||
parts = []
|
||||
wf = payload.get("workflow")
|
||||
if wf:
|
||||
parts.append(str(wf))
|
||||
cur, tot = payload.get("current"), payload.get("total")
|
||||
if cur is not None and tot is not None:
|
||||
parts.append(f"étape {cur}/{tot}")
|
||||
rid = payload.get("replay_id")
|
||||
if rid:
|
||||
parts.append(f"#{str(rid)[-6:]}")
|
||||
return " • ".join(parts)
|
||||
|
||||
|
||||
class ChatWindow:
|
||||
"""Fenetre de chat Lea en tkinter natif.
|
||||
|
||||
@@ -91,6 +193,8 @@ class ChatWindow:
|
||||
self._root = None
|
||||
self._ready = threading.Event()
|
||||
self._messages = [] # historique local
|
||||
self._bus: Optional[Any] = None # FeedbackBusClient (J3.3, peut rester None)
|
||||
self._active_paused_bubble: Optional[Dict[str, Any]] = None # bulle paused active (J3.5)
|
||||
|
||||
# S'abonner aux changements de l'etat partage
|
||||
if self._shared_state is not None:
|
||||
@@ -266,6 +370,9 @@ class ChatWindow:
|
||||
# Signaler que la fenetre est prete
|
||||
self._ready.set()
|
||||
|
||||
# Demarrer le bus feedback Lea (events 'lea:*' temps reel)
|
||||
self._start_feedback_bus()
|
||||
|
||||
# Boucle tkinter
|
||||
root.mainloop()
|
||||
|
||||
@@ -608,6 +715,12 @@ class ChatWindow:
|
||||
|
||||
def _do_destroy(self) -> None:
|
||||
"""Detruit la fenetre (appele dans le thread tkinter)."""
|
||||
if self._bus is not None:
|
||||
try:
|
||||
self._bus.stop()
|
||||
except Exception:
|
||||
pass
|
||||
self._bus = None
|
||||
if self._root is not None:
|
||||
try:
|
||||
self._root.quit()
|
||||
@@ -617,6 +730,232 @@ class ChatWindow:
|
||||
self._root = None
|
||||
self._visible = False
|
||||
|
||||
# ======================================================================
|
||||
# FeedbackBus — bulles temps reel pendant l'execution (J3.3)
|
||||
# ======================================================================
|
||||
|
||||
def _start_feedback_bus(self) -> None:
    """Connect to the 'lea:*' feedback bus when the library is there and the flag is on.

    The feature is enabled by the LEA_FEEDBACK_BUS environment variable
    ("1", "true", "yes" or "on", case-insensitive). Any error while creating
    or starting the client is logged at debug level and leaves ``self._bus``
    set to None.
    """
    if not _HAS_FEEDBACK_BUS:
        logger.debug("FeedbackBus non disponible (python-socketio manquant)")
        return

    enabled = os.environ.get("LEA_FEEDBACK_BUS", "0").lower() in ("1", "true", "yes", "on")
    if not enabled:
        return

    try:
        url = f"http://{self._server_host}:{self._chat_port}"
        # An empty token is passed as None.
        token = os.environ.get("RPA_API_TOKEN", "") or None
        client = FeedbackBusClient(url, token=token, on_event=self._on_lea_event)
        self._bus = client
        client.start()
        logger.info("FeedbackBus demarre : %s", url)
    except Exception:
        logger.debug("FeedbackBus init silenced", exc_info=True)
        self._bus = None
|
||||
|
||||
def _on_lea_event(self, event: str, payload: Dict[str, Any]) -> None:
|
||||
"""Callback bus → bulle Lea. Thread-safe : helpers utilisent root.after."""
|
||||
payload = payload or {}
|
||||
|
||||
# J3.5 : la pause supervisée a sa propre bulle interactive
|
||||
if event == "lea:paused":
|
||||
self._add_paused_bubble(payload)
|
||||
return
|
||||
if event in ("lea:resumed", "lea:done"):
|
||||
self._close_active_paused_bubble(reason=event)
|
||||
# on continue pour afficher la bulle d'action (cf. dispatch ci-dessous)
|
||||
|
||||
# Acks bus (resume_acked, abort_acked) : silencieux côté UI
|
||||
if event in ("lea:resume_acked", "lea:abort_acked"):
|
||||
return
|
||||
|
||||
# J3.4 : bulle "Léa exécute" stylisée (séparée des bulles chat normales)
|
||||
rendered = _ACTION_TEMPLATES.get(event)
|
||||
if rendered is None:
|
||||
# Event inconnu : on affiche en bulle d'action neutre
|
||||
self._add_action_bubble(
|
||||
icon="·", icon_color=ACTION_ICON_INFO,
|
||||
title=event.removeprefix("lea:"),
|
||||
meta=_extract_meta(payload),
|
||||
)
|
||||
return
|
||||
icon, icon_color, title = rendered(payload)
|
||||
self._add_action_bubble(
|
||||
icon=icon, icon_color=icon_color, title=title,
|
||||
meta=_extract_meta(payload),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Bulle "Léa exécute" stylisée (J3.4)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _add_action_bubble(
|
||||
self, icon: str, icon_color: str, title: str, meta: str = "",
|
||||
) -> None:
|
||||
if self._root is None:
|
||||
return
|
||||
self._root.after(0, lambda: self._render_action_bubble(icon, icon_color, title, meta))
|
||||
|
||||
def _render_action_bubble(
|
||||
self, icon: str, icon_color: str, title: str, meta: str,
|
||||
) -> None:
|
||||
tk = self._tk
|
||||
if getattr(self, "_msg_frame", None) is None:
|
||||
return
|
||||
now = datetime.now().strftime("%H:%M")
|
||||
|
||||
container = tk.Frame(self._msg_frame, bg=BG_COLOR)
|
||||
container.pack(fill=tk.X, padx=MARGIN, pady=3)
|
||||
|
||||
inner = tk.Frame(
|
||||
container, bg=ACTION_BG, padx=10, pady=6,
|
||||
highlightbackground=ACTION_BORDER, highlightthickness=1,
|
||||
)
|
||||
inner.pack(anchor=tk.W, padx=(0, 70), fill=tk.X)
|
||||
|
||||
row = tk.Frame(inner, bg=ACTION_BG)
|
||||
row.pack(fill=tk.X, anchor=tk.W)
|
||||
|
||||
tk.Label(
|
||||
row, text=icon, bg=ACTION_BG, fg=icon_color,
|
||||
font=("Segoe UI", 13, "bold"), padx=4,
|
||||
).pack(side=tk.LEFT)
|
||||
|
||||
tk.Label(
|
||||
row, text=title, bg=ACTION_BG, fg=ACTION_FG,
|
||||
font=FONT_MSG, anchor="w", justify=tk.LEFT,
|
||||
wraplength=MSG_WRAP_WIDTH - 60,
|
||||
).pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(2, 0))
|
||||
|
||||
if meta:
|
||||
tk.Label(
|
||||
inner, text=f"{meta} • {now}",
|
||||
bg=ACTION_BG, fg=ACTION_META_FG,
|
||||
font=FONT_TIMESTAMP, anchor="w",
|
||||
).pack(fill=tk.X, anchor=tk.W, pady=(2, 0))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Bulle paused_need_help interactive (J3.5)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _add_paused_bubble(self, payload: Dict[str, Any]) -> None:
|
||||
"""Ajouter une bulle paused interactive (asset démo : Léa demande de l'aide)."""
|
||||
if self._root is None:
|
||||
return
|
||||
self._root.after(0, lambda: self._render_paused_bubble(payload))
|
||||
|
||||
def _render_paused_bubble(self, payload: Dict[str, Any]) -> None:
|
||||
tk = self._tk
|
||||
if getattr(self, "_msg_frame", None) is None:
|
||||
return
|
||||
|
||||
replay_id = str(payload.get("replay_id", "") or "")
|
||||
workflow = payload.get("workflow", "?")
|
||||
reason = payload.get("reason") or "Action incertaine — j'ai besoin de votre validation."
|
||||
completed = payload.get("completed", 0)
|
||||
total = payload.get("total", "?")
|
||||
now = datetime.now().strftime("%H:%M")
|
||||
|
||||
container = tk.Frame(self._msg_frame, bg=BG_COLOR)
|
||||
container.pack(fill=tk.X, padx=MARGIN, pady=6)
|
||||
|
||||
inner = tk.Frame(
|
||||
container, bg=PAUSED_BG, padx=14, pady=12,
|
||||
highlightbackground=PAUSED_BORDER, highlightthickness=2,
|
||||
)
|
||||
inner.pack(anchor=tk.W, padx=(0, 50), fill=tk.X)
|
||||
|
||||
tk.Label(
|
||||
inner, text=f"⏸ Pause supervisée • {now}",
|
||||
bg=PAUSED_BG, fg=PAUSED_FG,
|
||||
font=("Segoe UI", 12, "bold"), anchor="w",
|
||||
).pack(fill=tk.X, anchor=tk.W)
|
||||
|
||||
tk.Label(
|
||||
inner, text=reason, bg=PAUSED_BG, fg=PAUSED_FG,
|
||||
font=FONT_MSG, wraplength=MSG_WRAP_WIDTH - 30,
|
||||
anchor="w", justify=tk.LEFT,
|
||||
).pack(fill=tk.X, anchor=tk.W, pady=(6, 0))
|
||||
|
||||
tk.Label(
|
||||
inner, text=f"{workflow} — étape {completed}/{total}",
|
||||
bg=PAUSED_BG, fg=TIMESTAMP_FG, font=FONT_TIMESTAMP, anchor="w",
|
||||
).pack(fill=tk.X, anchor=tk.W, pady=(4, 8))
|
||||
|
||||
btn_frame = tk.Frame(inner, bg=PAUSED_BG)
|
||||
btn_frame.pack(fill=tk.X, anchor=tk.W)
|
||||
|
||||
btn_resume = tk.Button(
|
||||
btn_frame, text="Continuer",
|
||||
bg=PAUSED_BTN_RESUME_BG, fg="white", font=FONT_QUICK_BTN,
|
||||
padx=14, pady=4, bd=0, cursor="hand2",
|
||||
activebackground=PAUSED_BTN_RESUME_HOVER, activeforeground="white",
|
||||
command=lambda: self._on_paused_resume(replay_id),
|
||||
)
|
||||
btn_resume.pack(side=tk.LEFT, padx=(0, 8))
|
||||
|
||||
btn_abort = tk.Button(
|
||||
btn_frame, text="Annuler",
|
||||
bg=PAUSED_BTN_ABORT_BG, fg="white", font=FONT_QUICK_BTN,
|
||||
padx=14, pady=4, bd=0, cursor="hand2",
|
||||
activebackground=PAUSED_BTN_ABORT_HOVER, activeforeground="white",
|
||||
command=lambda: self._on_paused_abort(replay_id),
|
||||
)
|
||||
btn_abort.pack(side=tk.LEFT)
|
||||
|
||||
self._active_paused_bubble = {
|
||||
"container": container, "inner": inner,
|
||||
"btn_resume": btn_resume, "btn_abort": btn_abort,
|
||||
"replay_id": replay_id,
|
||||
}
|
||||
|
||||
def _close_active_paused_bubble(self, reason: str) -> None:
|
||||
if self._active_paused_bubble is None or self._root is None:
|
||||
return
|
||||
self._root.after(0, lambda: self._do_close_paused_bubble(reason))
|
||||
|
||||
def _do_close_paused_bubble(self, reason: str) -> None:
|
||||
bubble = self._active_paused_bubble
|
||||
if bubble is None:
|
||||
return
|
||||
try:
|
||||
bubble["btn_resume"].config(state="disabled")
|
||||
bubble["btn_abort"].config(state="disabled")
|
||||
label_text = {
|
||||
"lea:resumed": "→ Reprise",
|
||||
"lea:done": "→ Terminé",
|
||||
}.get(reason, f"→ {reason}")
|
||||
self._tk.Label(
|
||||
bubble["inner"], text=label_text,
|
||||
bg=PAUSED_BG, fg=PAUSED_FG, font=FONT_TIMESTAMP, anchor="w",
|
||||
).pack(fill="x", anchor="w", pady=(6, 0))
|
||||
except Exception:
|
||||
logger.debug("close paused bubble silenced", exc_info=True)
|
||||
self._active_paused_bubble = None
|
||||
|
||||
def _on_paused_resume(self, replay_id: str) -> None:
|
||||
if not replay_id or self._bus is None or not self._bus.connected:
|
||||
self._add_lea_message("⚠ Bus indisponible — impossible de relancer")
|
||||
return
|
||||
self._bus.resume_replay(replay_id)
|
||||
if self._active_paused_bubble:
|
||||
try:
|
||||
self._active_paused_bubble["btn_resume"].config(state="disabled")
|
||||
self._active_paused_bubble["btn_abort"].config(state="disabled")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _on_paused_abort(self, replay_id: str) -> None:
|
||||
if self._bus is None or not self._bus.connected:
|
||||
self._add_lea_message("⚠ Bus indisponible — impossible d'annuler")
|
||||
return
|
||||
self._bus.abort_replay(replay_id)
|
||||
if self._active_paused_bubble:
|
||||
try:
|
||||
self._active_paused_bubble["btn_resume"].config(state="disabled")
|
||||
self._active_paused_bubble["btn_abort"].config(state="disabled")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ======================================================================
|
||||
# Ajout de messages dans la zone de chat
|
||||
# ======================================================================
|
||||
|
||||
@@ -293,6 +293,49 @@ def formatter_ecran_inchange(action_type: str = "") -> MessageUtilisateur:
|
||||
)
|
||||
|
||||
|
||||
def formatter_mode_apprentissage(
    raison: str = "",
    description_cible: str = "",
    titre_fenetre: Optional[str] = None,
) -> MessageUtilisateur:
    """Build the message shown when Léa enters learning mode (supervised pause).

    The user must understand that:
      1. Léa is stuck and needs help;
      2. they should take over and show how to do it;
      3. Ctrl+Shift+L signals that they are done.

    The tone is humble, clear and actionable, not technical.

    Args:
        raison: Accepted for the caller's convenience; not used in the text.
        description_cible: Raw description of the target; cleaned and quoted
            in the message when non-empty.
        titre_fenetre: Window title; the application name extracted from it
            is mentioned only when a target description is also available.

    Returns:
        A persistent BLOCAGE-level message titled "Léa a besoin d'aide".
    """
    cible = _nettoyer_description_cible(description_cible) if description_cible else ""
    app = _extraire_nom_application(titre_fenetre) if titre_fenetre else ""

    # Short context suffix, only when a target is known.
    if not cible:
        contexte = ""
    elif app:
        contexte = f" (« {cible} » dans {app})"
    else:
        contexte = f" (« {cible} »)"

    phrase_aide = f"Je n'y arrive pas{contexte}, montrez-moi comment faire. "
    phrase_fin = "Quand vous avez fini, appuyez sur Ctrl+Shift+L."

    niveau = NiveauMessage.BLOCAGE
    return MessageUtilisateur(
        niveau=niveau,
        titre="Léa a besoin d'aide",
        corps=phrase_aide + phrase_fin,
        duree_s=DUREE_PAR_NIVEAU[niveau],
        persistent=True,
    )
|
||||
|
||||
|
||||
def formatter_connexion_perdue(hote_serveur: str = "") -> MessageUtilisateur:
|
||||
"""Message quand la connexion avec le serveur est perdue.
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from .messages import (
|
||||
formatter_etape_workflow,
|
||||
formatter_fenetre_incorrecte,
|
||||
formatter_fin_workflow,
|
||||
formatter_mode_apprentissage,
|
||||
formatter_ralentissement,
|
||||
formatter_retry,
|
||||
)
|
||||
@@ -273,6 +274,20 @@ class NotificationManager:
|
||||
msg = formatter_ecran_inchange(action_type)
|
||||
return self.notify_message(msg)
|
||||
|
||||
def replay_learning_mode(
    self,
    raison: str = "",
    target_description: str = "",
    window_title: Optional[str] = None,
) -> bool:
    """Notify the user that Léa has entered learning mode.

    Léa is stuck and asks the user to show how to proceed. The message is
    built by ``formatter_mode_apprentissage`` and sent through
    ``notify_message``, whose result is returned.
    """
    message = formatter_mode_apprentissage(
        raison=raison,
        description_cible=target_description,
        titre_fenetre=window_title,
    )
    return self.notify_message(message)
|
||||
|
||||
def replay_retry(self, action_type: str = "", tentative: int = 2) -> bool:
|
||||
"""Notification quand Léa retente une action."""
|
||||
msg = formatter_retry(action_type, tentative)
|
||||
|
||||
@@ -2,12 +2,20 @@
|
||||
"""
|
||||
Gestionnaire de vision avancé pour Agent V1.
|
||||
Optimisé pour le streaming fibre avec détection de changement.
|
||||
|
||||
Captures disponibles :
|
||||
- Plein écran (full) : contexte global 1920x1080+
|
||||
- Crop ciblé (crop) : 80x80 autour du clic (apprentissage VLM)
|
||||
- Fenêtre active (window) : image isolée de la fenêtre + métadonnées
|
||||
(titre, rect, coordonnées clic relatives) — cross-platform
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import hashlib
|
||||
import platform
|
||||
from typing import Any, Dict, Optional
|
||||
from PIL import Image, ImageFilter, ImageStat
|
||||
import mss
|
||||
from ..config import TARGETED_CROP_SIZE, SCREENSHOT_QUALITY, BLUR_SENSITIVE
|
||||
@@ -15,6 +23,9 @@ from .blur_sensitive import blur_sensitive_regions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OS courant (détecté une seule fois)
|
||||
_SYSTEM = platform.system()
|
||||
|
||||
class VisionCapturer:
|
||||
def __init__(self, session_dir: str):
|
||||
self.session_dir = session_dir
|
||||
@@ -27,6 +38,9 @@ class VisionCapturer:
|
||||
"""
|
||||
Capture l'écran complet.
|
||||
Si force=False, vérifie d'abord si l'écran a changé.
|
||||
|
||||
Enrichit les métadonnées avec le titre de la fenêtre active
|
||||
(utile pour le contextualisation des heartbeats côté serveur).
|
||||
"""
|
||||
try:
|
||||
with mss.mss() as sct:
|
||||
@@ -52,8 +66,24 @@ class VisionCapturer:
|
||||
logger.error(f"Erreur Context Capture: {e}")
|
||||
return ""
|
||||
|
||||
def get_active_window_title(self) -> str:
    """Return the active window's title (used to enrich heartbeats).

    Graceful fallback: any failure, including a failed import of the
    window-info helper, yields an empty string.
    """
    try:
        # Imported inside the try so an import error also falls back to "".
        from ..window_info_crossplatform import get_active_window_info

        window = get_active_window_info()
        title = window.get("title", "")
    except Exception:
        title = ""
    return title
|
||||
|
||||
def capture_dual(self, x: int, y: int, screenshot_id: str, anonymize=False) -> dict:
|
||||
"""Capture duale (Full + Crop) systématique (forcée car liée à une action)."""
|
||||
"""Capture triple (Full + Crop + Fenêtre active) systématique.
|
||||
|
||||
La fenêtre active est un AJOUT — en cas d'échec, le full + crop
|
||||
sont toujours retournés (fallback gracieux).
|
||||
"""
|
||||
try:
|
||||
with mss.mss() as sct:
|
||||
full_path = os.path.join(self.shots_dir, f"{screenshot_id}_full.png")
|
||||
@@ -82,11 +112,130 @@ class VisionCapturer:
|
||||
# Mise à jour du hash pour le prochain heartbeat
|
||||
self.last_img_hash = self._compute_quick_hash(img)
|
||||
|
||||
return {"full": full_path, "crop": crop_path}
|
||||
result = {"full": full_path, "crop": crop_path}
|
||||
|
||||
# --- Capture de la fenêtre active ---
|
||||
# Ajout non-bloquant : enrichit le résultat avec l'image
|
||||
# de la fenêtre seule + métadonnées (titre, rect, clic relatif)
|
||||
window_info = self.capture_active_window(x, y, screenshot_id, full_img=img)
|
||||
if window_info:
|
||||
result["window_capture"] = window_info
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur Dual Capture: {e}")
|
||||
return {}
|
||||
|
||||
def capture_active_window(
    self,
    x: int,
    y: int,
    screenshot_id: str,
    full_img: Optional[Image.Image] = None,
) -> Optional[Dict[str, Any]]:
    """Capture the active window on its own, together with its metadata.

    Strategy:
        1. Get the active window rectangle from ``get_active_window_rect``.
        2. Crop that rectangle out of the full-screen screenshot.
        3. Express the click position relative to the window origin.

    Args:
        x, y: click coordinates in screen pixels.
        screenshot_id: identifier used to build the output file name.
        full_img: full-screen screenshot that was already captured
            (optional). When omitted, a new one is grabbed with ``mss``.

    Returns:
        Dict with the keys ``window_image``, ``window_title``, ``app_name``,
        ``window_rect``, ``window_size``, ``click_in_window`` and
        ``click_inside_window``; or None when the window cannot be found,
        is smaller than 50x50, lies outside the screenshot, or any error
        occurs (errors are logged, never raised).
    """
    try:
        from ..window_info_crossplatform import get_active_window_rect

        rect_info = get_active_window_rect()
        if not rect_info:
            logger.debug("Fenêtre active introuvable — skip capture fenêtre")
            return None

        win_rect = rect_info["rect"]  # [left, top, right, bottom]
        win_left, win_top, win_right, win_bottom = win_rect
        win_w, win_h = rect_info["size"]  # [width, height]
        title = rect_info.get("title", "unknown_window")
        app_name = rect_info.get("app_name", "unknown_app")

        # Skip windows that are too small (taskbars, system popups).
        if win_w < 50 or win_h < 50:
            logger.debug(f"Fenêtre trop petite ({win_w}x{win_h}) — skip")
            return None

        # Click coordinates relative to the window origin.
        click_rel_x = x - win_left
        click_rel_y = y - win_top

        # A click outside the window is only flagged; the capture continues.
        # The right/bottom edges are exclusive (right == left + width), so a
        # click at offset == width is the first pixel OUTSIDE the window.
        click_inside = 0 <= click_rel_x < win_w and 0 <= click_rel_y < win_h

        # --- Crop the window out of the full-screen image ---
        if full_img is None:
            # No screenshot supplied: grab one (standalone use).
            try:
                with mss.mss() as sct:
                    monitor = sct.monitors[1]
                    sct_img = sct.grab(monitor)
                    full_img = Image.frombytes(
                        "RGB", sct_img.size, sct_img.bgra, "raw", "BGRX"
                    )
            except Exception as e:
                logger.error(f"Erreur capture plein écran pour fenêtre : {e}")
                return None

        # Clamp the crop box to the bounds of the full-screen image.
        # NOTE(review): when clamping applies, the saved image starts at
        # (crop_left, crop_top) while click_in_window stays relative to
        # (win_left, win_top) — confirm consumers expect window-relative
        # coordinates rather than image-relative ones.
        img_w, img_h = full_img.size
        crop_left = max(0, win_left)
        crop_top = max(0, win_top)
        crop_right = min(img_w, win_right)
        crop_bottom = min(img_h, win_bottom)

        if crop_right <= crop_left or crop_bottom <= crop_top:
            logger.debug("Fenêtre hors écran — skip capture fenêtre")
            return None

        window_img = full_img.crop((crop_left, crop_top, crop_right, crop_bottom))

        # Blurring of sensitive regions (labelled "AI Act compliance" by the
        # original author). The return value is ignored — presumably the
        # helper mutates the image in place; TODO confirm.
        if BLUR_SENSITIVE:
            blur_sensitive_regions(window_img)

        # Save to disk.
        # NOTE(review): `quality` is not a documented option of Pillow's PNG
        # writer — confirm whether SCREENSHOT_QUALITY has any effect here.
        window_path = os.path.join(
            self.shots_dir, f"{screenshot_id}_window.png"
        )
        window_img.save(window_path, "PNG", quality=SCREENSHOT_QUALITY)

        result = {
            "window_image": window_path,
            "window_title": title,
            "app_name": app_name,
            "window_rect": win_rect,
            "window_size": [win_w, win_h],
            "click_in_window": [click_rel_x, click_rel_y],
            "click_inside_window": click_inside,
        }

        logger.debug(
            f"Fenêtre capturée : {title} ({win_w}x{win_h}) — "
            f"clic relatif ({click_rel_x}, {click_rel_y})"
        )
        return result

    except ImportError as e:
        # The cross-platform window helper module is not importable.
        logger.debug(f"Module fenêtre indisponible : {e}")
        return None
    except Exception as e:
        logger.error(f"Erreur capture fenêtre active : {e}")
        return None
|
||||
|
||||
def _compute_quick_hash(self, img: Image) -> str:
|
||||
"""Calcule un hash rapide basé sur une vignette réduite pour détecter les changements."""
|
||||
# On réduit l'image à 64x64 pour comparer les masses de couleurs (très rapide)
|
||||
|
||||
@@ -17,7 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import platform
|
||||
import subprocess
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
def _run_cmd(cmd: list[str]) -> Optional[str]:
|
||||
@@ -51,6 +51,32 @@ def get_active_window_info() -> Dict[str, str]:
|
||||
return {"title": "unknown_window", "app_name": "unknown_app"}
|
||||
|
||||
|
||||
def get_active_window_rect() -> Optional[Dict[str, Any]]:
    """Return the rectangle of the active window for the current OS.

    The platform is detected with ``platform.system()`` and the matching
    private helper is called. On success the helpers return::

        {
            "title": "...",
            "app_name": "...",
            "rect": [left, top, right, bottom],
            "position": [left, top],
            "size": [width, height],
            "hwnd": int,  # Windows helper only
        }

    Returns:
        The dict above, or None when the platform is not Windows, Linux or
        Darwin, or when the platform helper itself returns None.
    """
    backends = {
        "Windows": _get_window_rect_windows,
        "Linux": _get_window_rect_linux,
        "Darwin": _get_window_rect_macos,
    }
    backend = backends.get(platform.system())
    if backend is None:
        return None
    return backend()
|
||||
|
||||
|
||||
def _get_window_info_linux() -> Dict[str, str]:
|
||||
"""
|
||||
Linux: utilise xdotool (X11)
|
||||
@@ -178,6 +204,163 @@ def _get_window_info_macos() -> Dict[str, str]:
|
||||
}
|
||||
|
||||
|
||||
def _get_window_rect_windows() -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Windows : utilise pywin32 pour obtenir le rectangle de la fenêtre active.
|
||||
|
||||
Retourne None si la fenêtre est minimisée (icônifiée) ou si pywin32 manque.
|
||||
"""
|
||||
try:
|
||||
import win32gui
|
||||
import win32process
|
||||
import psutil
|
||||
|
||||
hwnd = win32gui.GetForegroundWindow()
|
||||
if not hwnd:
|
||||
return None
|
||||
|
||||
# Ignorer les fenêtres minimisées (pas de contenu visible)
|
||||
if win32gui.IsIconic(hwnd):
|
||||
return None
|
||||
|
||||
title = win32gui.GetWindowText(hwnd) or "unknown_window"
|
||||
|
||||
# Rectangle de la fenêtre (coordonnées écran absolues)
|
||||
left, top, right, bottom = win32gui.GetWindowRect(hwnd)
|
||||
width = right - left
|
||||
height = bottom - top
|
||||
|
||||
# Ignorer les fenêtres de taille nulle ou absurde
|
||||
if width <= 0 or height <= 0:
|
||||
return None
|
||||
|
||||
# Nom du processus
|
||||
_, pid = win32process.GetWindowThreadProcessId(hwnd)
|
||||
try:
|
||||
app_name = psutil.Process(pid).name()
|
||||
except Exception:
|
||||
app_name = "unknown_app"
|
||||
|
||||
return {
|
||||
"title": title,
|
||||
"app_name": app_name,
|
||||
"rect": [left, top, right, bottom],
|
||||
"position": [left, top],
|
||||
"size": [width, height],
|
||||
"hwnd": hwnd,
|
||||
}
|
||||
|
||||
except ImportError:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _get_window_rect_linux() -> Optional[Dict[str, Any]]:
    """Linux (X11): read the active window rectangle through xdotool.

    The window id is resolved once with ``xdotool getactivewindow`` and every
    later query (name, pid, geometry) targets that same id, so the returned
    title, application name and rectangle all describe the same window even
    if focus changes while the commands run.

    The original docstring lists the requirement as:
    ``sudo apt-get install xdotool x11-utils``.

    Returns:
        Dict with ``title``, ``app_name``, ``rect``, ``position`` and
        ``size``; or None when the window id or geometry cannot be obtained,
        the size is not positive, or any error occurs.
    """
    try:
        # Identifier of the active window.
        wid = _run_cmd(["xdotool", "getactivewindow"])
        wid = wid.strip() if wid else ""
        if not wid:
            return None

        title = _run_cmd(["xdotool", "getwindowname", wid]) or "unknown_window"
        pid_str = _run_cmd(["xdotool", "getwindowpid", wid])
        app_name = "unknown_app"
        if pid_str:
            app_name = (
                _run_cmd(["ps", "-p", pid_str.strip(), "-o", "comm="])
                or "unknown_app"
            )

        # Geometry as KEY=VALUE lines (position + size).
        geom_raw = _run_cmd(["xdotool", "getwindowgeometry", "--shell", wid])
        if not geom_raw:
            return None

        vals: Dict[str, int] = {}
        for line in geom_raw.strip().splitlines():
            key, sep, value = line.partition("=")
            if not sep:
                continue
            try:
                vals[key.strip()] = int(value.strip())
            except ValueError:
                # Non-numeric value: not one of the fields needed below.
                continue

        if not {"X", "Y", "WIDTH", "HEIGHT"} <= vals.keys():
            return None

        x, y = vals["X"], vals["Y"]
        w, h = vals["WIDTH"], vals["HEIGHT"]

        # Same guard as the Windows and macOS helpers: no empty rectangles.
        if w <= 0 or h <= 0:
            return None

        return {
            "title": title,
            "app_name": app_name,
            "rect": [x, y, x + w, y + h],
            "position": [x, y],
            "size": [w, h],
        }

    except Exception:
        return None
|
||||
|
||||
|
||||
def _get_window_rect_macos() -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
macOS : utilise Quartz (CGWindowListCopyWindowInfo) pour obtenir le rectangle.
|
||||
|
||||
Nécessite : pip install pyobjc-framework-Quartz
|
||||
"""
|
||||
try:
|
||||
from AppKit import NSWorkspace
|
||||
from Quartz import (
|
||||
CGWindowListCopyWindowInfo,
|
||||
kCGWindowListOptionOnScreenOnly,
|
||||
kCGNullWindowID,
|
||||
)
|
||||
|
||||
active_app = NSWorkspace.sharedWorkspace().activeApplication()
|
||||
app_name = active_app.get("NSApplicationName", "unknown_app")
|
||||
|
||||
window_list = CGWindowListCopyWindowInfo(
|
||||
kCGWindowListOptionOnScreenOnly, kCGNullWindowID
|
||||
)
|
||||
|
||||
for window in window_list:
|
||||
owner_name = window.get("kCGWindowOwnerName", "")
|
||||
if owner_name != app_name:
|
||||
continue
|
||||
|
||||
bounds = window.get("kCGWindowBounds")
|
||||
if not bounds:
|
||||
continue
|
||||
|
||||
x = int(bounds.get("X", 0))
|
||||
y = int(bounds.get("Y", 0))
|
||||
w = int(bounds.get("Width", 0))
|
||||
h = int(bounds.get("Height", 0))
|
||||
if w <= 0 or h <= 0:
|
||||
continue
|
||||
|
||||
title = window.get("kCGWindowName", "unknown_window") or "unknown_window"
|
||||
|
||||
return {
|
||||
"title": title,
|
||||
"app_name": app_name,
|
||||
"rect": [x, y, x + w, y + h],
|
||||
"position": [x, y],
|
||||
"size": [w, h],
|
||||
}
|
||||
|
||||
except ImportError:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Test rapide
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
@@ -188,5 +371,10 @@ if __name__ == "__main__":
|
||||
|
||||
for i in range(5):
|
||||
info = get_active_window_info()
|
||||
rect = get_active_window_rect()
|
||||
print(f"[{i+1}] App: {info['app_name']:20s} | Title: {info['title']}")
|
||||
if rect:
|
||||
print(f" Rect: {rect['rect']} | Size: {rect['size']}")
|
||||
else:
|
||||
print(" Rect: non disponible")
|
||||
time.sleep(1)
|
||||
|
||||
@@ -3,6 +3,7 @@ mss>=9.0.1 # Capture d'écran haute performance
|
||||
pynput>=1.7.7 # Clavier/Souris Cross-plateforme
|
||||
Pillow>=10.0.0 # Crops et processing image
|
||||
requests>=2.31.0 # Streaming réseau
|
||||
python-socketio[client]>=5.10,<6.0 # Bus feedback Léa 'lea:*' (compat Flask-SocketIO 5.3.x serveur)
|
||||
psutil>=5.9.0 # Monitoring CPU/RAM
|
||||
pystray>=0.19.5 # Icône Tray UI
|
||||
plyer>=2.1.0 # Notifications toast natives (remplace PyQt5)
|
||||
|
||||
@@ -2,6 +2,17 @@
|
||||
"""
|
||||
deploy_windows.py — Script de packaging du client Windows pour Agent V1.
|
||||
|
||||
⚠️ OBSOLÈTE (avril 2026)
|
||||
Le build officiel du package Windows passe par ``deploy/build_package.sh``
|
||||
(à la racine du repo) qui lit directement ``agent_v0/agent_v1/`` et évite
|
||||
les clones intermédiaires. Ce script est conservé pour référence mais son
|
||||
manifeste ``FILE_MANIFEST`` est incomplet : il n'inclut pas
|
||||
``system_dialog_guard.py``, ``persistent_buffer.py``, ``recovery.py``,
|
||||
``uia_helper.py``, ``grounding.py``, ``policy.py``,
|
||||
``vision/blur_sensitive.py``, ``vision/system_info.py``,
|
||||
``ui/chat_window.py``, ``ui/capture_server.py``, ``ui/shared_state.py``.
|
||||
Ne PAS l'utiliser pour un packaging réel.
|
||||
|
||||
Copie uniquement les fichiers nécessaires au fonctionnement de l'agent
|
||||
sur le PC cible (Windows), sans le serveur ni les dépendances lourdes.
|
||||
|
||||
|
||||
@@ -21,36 +21,33 @@ from typing import Any, Callable, Dict, List, Optional
|
||||
logger = logging.getLogger("lea_ui.server_client")
|
||||
|
||||
|
||||
def _get_server_host() -> str:
|
||||
"""Recuperer l'adresse du serveur Linux.
|
||||
def _get_server_url() -> str:
|
||||
"""Recuperer l'URL du serveur RPA (avec /api/v1).
|
||||
|
||||
Ordre de resolution :
|
||||
1. Variable d'environnement RPA_SERVER_HOST
|
||||
2. Fichier de config agent_config.json (cle "server_host")
|
||||
3. Fallback localhost
|
||||
1. Import depuis agent_v1.config (source de verite unique)
|
||||
2. Variable d'environnement RPA_SERVER_URL
|
||||
3. Fallback http://localhost:5005/api/v1
|
||||
"""
|
||||
# 1. Variable d'environnement
|
||||
host = os.environ.get("RPA_SERVER_HOST", "").strip()
|
||||
if host:
|
||||
return host
|
||||
|
||||
# 2. Fichier de config
|
||||
config_paths = [
|
||||
os.path.join(os.path.dirname(__file__), "..", "agent_config.json"),
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "agent_config.json"),
|
||||
]
|
||||
for config_path in config_paths:
|
||||
# 1. Import depuis config.py (source de verite)
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
host = cfg.get("server_host", "").strip()
|
||||
if host:
|
||||
return host
|
||||
except (OSError, json.JSONDecodeError):
|
||||
continue
|
||||
from agent_v1.config import SERVER_URL
|
||||
return SERVER_URL
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 2. Variable d'environnement directe
|
||||
url = os.environ.get("RPA_SERVER_URL", "").strip().rstrip("/")
|
||||
if url:
|
||||
return url
|
||||
|
||||
# 3. Fallback
|
||||
return "localhost"
|
||||
return "http://localhost:5005/api/v1"
|
||||
|
||||
|
||||
def _get_server_base(server_url: str) -> str:
|
||||
"""Extraire la base URL (sans /api/v1) pour les routes racine (/health)."""
|
||||
return server_url.rsplit("/api/v1", 1)[0]
|
||||
|
||||
|
||||
class LeaServerClient:
|
||||
@@ -67,12 +64,23 @@ class LeaServerClient:
|
||||
chat_port: int = 5004,
|
||||
stream_port: int = 5005,
|
||||
) -> None:
|
||||
self._host = server_host or _get_server_host()
|
||||
# URL unifiée : SERVER_URL contient TOUJOURS /api/v1 (convention INC-1).
|
||||
# _stream_url = URL avec /api/v1 (pour les routes API)
|
||||
# _stream_base = URL sans /api/v1 (pour /health uniquement)
|
||||
self._stream_url = _get_server_url()
|
||||
self._stream_base = _get_server_base(self._stream_url)
|
||||
|
||||
# Extraire le host depuis l'URL pour le chat et pour l'affichage
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(self._stream_base)
|
||||
self._host = parsed.hostname or "localhost"
|
||||
except Exception:
|
||||
self._host = server_host or "localhost"
|
||||
|
||||
self._chat_port = chat_port
|
||||
self._stream_port = stream_port
|
||||
|
||||
self._chat_base = f"http://{self._host}:{self._chat_port}"
|
||||
self._stream_base = f"http://{self._host}:{self._stream_port}"
|
||||
|
||||
# Etat de connexion
|
||||
self._connected = False
|
||||
@@ -95,8 +103,8 @@ class LeaServerClient:
|
||||
self._api_token = os.environ.get("RPA_API_TOKEN", "")
|
||||
|
||||
logger.info(
|
||||
"LeaServerClient initialise : chat=%s, stream=%s",
|
||||
self._chat_base, self._stream_base,
|
||||
"LeaServerClient initialise : chat=%s, stream_url=%s, stream_base=%s",
|
||||
self._chat_base, self._stream_url, self._stream_base,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -146,7 +154,11 @@ class LeaServerClient:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def check_connection(self) -> bool:
|
||||
"""Tester la connexion au serveur streaming (port 5005)."""
|
||||
"""Tester la connexion au serveur streaming (port 5005).
|
||||
|
||||
Le health check utilise _stream_base (sans /api/v1) car la route
|
||||
/health est a la racine du serveur FastAPI, pas sous /api/v1.
|
||||
"""
|
||||
try:
|
||||
import requests
|
||||
resp = requests.get(
|
||||
@@ -219,7 +231,7 @@ class LeaServerClient:
|
||||
import requests
|
||||
headers = self._auth_headers()
|
||||
resp = requests.get(
|
||||
f"{self._stream_base}/api/v1/traces/stream/workflows",
|
||||
f"{self._stream_url}/traces/stream/workflows",
|
||||
headers=headers,
|
||||
timeout=10,
|
||||
)
|
||||
@@ -276,7 +288,7 @@ class LeaServerClient:
|
||||
while self._polling:
|
||||
try:
|
||||
resp = req_lib.get(
|
||||
f"{self._stream_base}/api/v1/traces/stream/replay/next",
|
||||
f"{self._stream_url}/traces/stream/replay/next",
|
||||
params={"session_id": self._poll_session_id},
|
||||
headers=self._auth_headers(),
|
||||
timeout=5,
|
||||
@@ -310,7 +322,7 @@ class LeaServerClient:
|
||||
try:
|
||||
import requests
|
||||
resp = requests.get(
|
||||
f"{self._stream_base}/api/v1/traces/stream/replays",
|
||||
f"{self._stream_url}/traces/stream/replays",
|
||||
headers=self._auth_headers(),
|
||||
timeout=5,
|
||||
)
|
||||
@@ -338,7 +350,7 @@ class LeaServerClient:
|
||||
try:
|
||||
import requests
|
||||
requests.post(
|
||||
f"{self._stream_base}/api/v1/traces/stream/replay/result",
|
||||
f"{self._stream_url}/traces/stream/replay/result",
|
||||
json={
|
||||
"session_id": session_id,
|
||||
"action_id": action_id,
|
||||
|
||||
296
agent_v0/server_v1/agent_registry.py
Normal file
296
agent_v0/server_v1/agent_registry.py
Normal file
@@ -0,0 +1,296 @@
|
||||
# agent_v0/server_v1/agent_registry.py
|
||||
"""
|
||||
Registre des agents Lea enrolles sur le parc.
|
||||
|
||||
Alimente par les endpoints /api/v1/agents/enroll et /api/v1/agents/uninstall
|
||||
que l'installeur Inno Setup (`deploy/installer/Lea.iss`) appelle a
|
||||
l'installation et a la desinstallation sur chaque poste collaborateur.
|
||||
|
||||
Stockage : SQLite simple, cohabite avec rpa_data.db dans data/databases/.
|
||||
Aucune dependance GPU/LLM — ce module doit rester leger (juste sqlite3 +
|
||||
stdlib) pour pouvoir etre importe par le serveur HTTP.
|
||||
|
||||
Schema de la table `enrolled_agents` :
|
||||
id INTEGER PK AUTOINCREMENT
|
||||
machine_id TEXT UNIQUE NOT NULL — identifiant genere par l'installeur
|
||||
user_name TEXT — nom affichage collaborateur
|
||||
user_email TEXT
|
||||
user_id TEXT — identifiant metier (ex: AIVA-001)
|
||||
hostname TEXT
|
||||
os_info TEXT
|
||||
version TEXT — version du client Lea
|
||||
status TEXT DEFAULT 'active' — 'active' | 'uninstalled'
|
||||
enrolled_at TEXT NOT NULL — ISO 8601 UTC
|
||||
last_seen_at TEXT — ISO 8601 UTC (heartbeat / stream)
|
||||
uninstalled_at TEXT
|
||||
uninstall_reason TEXT
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global lock: every registry operation below runs while holding it, which
# serializes _init_db and concurrent upserts issued from different threads.
_DB_LOCK = threading.Lock()


def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO 8601 string."""
    return datetime.now(timezone.utc).isoformat()


class AgentAlreadyEnrolledError(Exception):
    """Raised when enrolling a machine_id that cannot be (re-)enrolled.

    Attributes:
        existing: the current database row for that machine, as a dict.
    """

    def __init__(self, existing_row: Dict[str, Any]):
        self.existing = existing_row
        super().__init__(
            f"machine_id={existing_row.get('machine_id')} deja enrole "
            f"(status={existing_row.get('status')})"
        )


class AgentRegistry:
    """CRUD access to the ``enrolled_agents`` table (SQLite).

    Every public method opens a connection, runs inside a transaction while
    holding ``_DB_LOCK``, and closes the connection before returning.
    """

    def __init__(self, db_path: str | Path = "data/databases/rpa_data.db"):
        """Create the parent directory and the table if they are missing.

        Args:
            db_path: path of the SQLite database file.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    # ------------------------------------------------------------------
    # SQLite plumbing
    # ------------------------------------------------------------------

    def _connect(self) -> sqlite3.Connection:
        """Open a configured connection; the caller is responsible for closing it."""
        # check_same_thread=False is kept from the original code, whose
        # comment explains that access is protected by _DB_LOCK because the
        # HTTP endpoints run on different threads.
        conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
        try:
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA foreign_keys=ON")
        except Exception:
            # Do not leak the handle when configuration fails.
            conn.close()
            raise
        return conn

    def _session(self):
        """Return a context manager yielding a connection closed on exit.

        ``with conn:`` on a sqlite3 connection only commits or rolls back;
        it does not close the connection. Wrapping it in ``closing`` releases
        the file handle deterministically at the end of each operation.
        """
        from contextlib import closing

        return closing(self._connect())

    def _init_db(self) -> None:
        """Create the table and its indexes if absent (idempotent)."""
        with _DB_LOCK, self._session() as conn, conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS enrolled_agents (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    machine_id TEXT NOT NULL UNIQUE,
                    user_name TEXT,
                    user_email TEXT,
                    user_id TEXT,
                    hostname TEXT,
                    os_info TEXT,
                    version TEXT,
                    status TEXT NOT NULL DEFAULT 'active',
                    enrolled_at TEXT NOT NULL,
                    last_seen_at TEXT,
                    uninstalled_at TEXT,
                    uninstall_reason TEXT
                )
                """
            )
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_enrolled_agents_status "
                "ON enrolled_agents(status)"
            )
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_enrolled_agents_machine "
                "ON enrolled_agents(machine_id)"
            )

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def get(self, machine_id: str) -> Optional[Dict[str, Any]]:
        """Return the agent row for ``machine_id`` as a dict, or None."""
        with _DB_LOCK, self._session() as conn, conn:
            row = conn.execute(
                "SELECT * FROM enrolled_agents WHERE machine_id = ?",
                (machine_id,),
            ).fetchone()
            return dict(row) if row else None

    def list_by_status(self, status: str) -> List[Dict[str, Any]]:
        """List agents with the given status, most recently enrolled first.

        The statuses written by this class are 'active' and 'uninstalled'.
        """
        with _DB_LOCK, self._session() as conn, conn:
            rows = conn.execute(
                "SELECT * FROM enrolled_agents WHERE status = ? "
                "ORDER BY enrolled_at DESC",
                (status,),
            ).fetchall()
            return [dict(r) for r in rows]

    def count_by_status(self, status: str) -> int:
        """Return how many agents have the given status."""
        with _DB_LOCK, self._session() as conn, conn:
            row = conn.execute(
                "SELECT COUNT(*) AS n FROM enrolled_agents WHERE status = ?",
                (status,),
            ).fetchone()
            return int(row["n"]) if row else 0

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def enroll(
        self,
        *,
        machine_id: str,
        user_name: str | None = None,
        user_email: str | None = None,
        user_id: str | None = None,
        hostname: str | None = None,
        os_info: str | None = None,
        version: str | None = None,
        allow_reactivate: bool = True,
    ) -> Dict[str, Any]:
        """Register a new agent, or reactivate a previously uninstalled one.

        On reactivation, fields passed as None keep their stored value, the
        enrollment and last-seen timestamps are reset to now, and the
        uninstall fields are cleared.

        Returns:
            dict with keys {"created": bool, "reactivated": bool, "agent": row}

        Raises:
            ValueError: if machine_id is empty or blank.
            AgentAlreadyEnrolledError: if the machine is already active, or
                if it is uninstalled and ``allow_reactivate`` is False.
        """
        if not machine_id or not machine_id.strip():
            raise ValueError("machine_id est obligatoire")
        machine_id = machine_id.strip()

        now = _utc_now_iso()

        with _DB_LOCK, self._session() as conn, conn:
            existing = conn.execute(
                "SELECT * FROM enrolled_agents WHERE machine_id = ?",
                (machine_id,),
            ).fetchone()

            if existing is not None:
                # Already active, or reactivation not permitted: explicit
                # conflict carrying the current row.
                if existing["status"] == "active" or not allow_reactivate:
                    raise AgentAlreadyEnrolledError(dict(existing))

                conn.execute(
                    """
                    UPDATE enrolled_agents
                    SET user_name = COALESCE(?, user_name),
                        user_email = COALESCE(?, user_email),
                        user_id = COALESCE(?, user_id),
                        hostname = COALESCE(?, hostname),
                        os_info = COALESCE(?, os_info),
                        version = COALESCE(?, version),
                        status = 'active',
                        enrolled_at = ?,
                        last_seen_at = ?,
                        uninstalled_at = NULL,
                        uninstall_reason = NULL
                    WHERE machine_id = ?
                    """,
                    (
                        user_name, user_email, user_id,
                        hostname, os_info, version,
                        now, now, machine_id,
                    ),
                )
                conn.commit()
                row = conn.execute(
                    "SELECT * FROM enrolled_agents WHERE machine_id = ?",
                    (machine_id,),
                ).fetchone()
                return {"created": False, "reactivated": True, "agent": dict(row)}

            # New enrollment.
            conn.execute(
                """
                INSERT INTO enrolled_agents (
                    machine_id, user_name, user_email, user_id,
                    hostname, os_info, version,
                    status, enrolled_at, last_seen_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, 'active', ?, ?)
                """,
                (
                    machine_id, user_name, user_email, user_id,
                    hostname, os_info, version,
                    now, now,
                ),
            )
            conn.commit()
            row = conn.execute(
                "SELECT * FROM enrolled_agents WHERE machine_id = ?",
                (machine_id,),
            ).fetchone()
            return {"created": True, "reactivated": False, "agent": dict(row)}

    def uninstall(
        self,
        *,
        machine_id: str,
        reason: str | None = None,
    ) -> Optional[Dict[str, Any]]:
        """Mark an agent as uninstalled (soft delete).

        Returns:
            The updated row as a dict, or None if the agent does not exist.

        Raises:
            ValueError: if machine_id is empty or blank.
        """
        if not machine_id or not machine_id.strip():
            raise ValueError("machine_id est obligatoire")
        machine_id = machine_id.strip()

        now = _utc_now_iso()
        with _DB_LOCK, self._session() as conn, conn:
            existing = conn.execute(
                "SELECT * FROM enrolled_agents WHERE machine_id = ?",
                (machine_id,),
            ).fetchone()
            if existing is None:
                return None

            conn.execute(
                """
                UPDATE enrolled_agents
                SET status = 'uninstalled',
                    uninstalled_at = ?,
                    uninstall_reason = ?
                WHERE machine_id = ?
                """,
                (now, reason, machine_id),
            )
            conn.commit()
            row = conn.execute(
                "SELECT * FROM enrolled_agents WHERE machine_id = ?",
                (machine_id,),
            ).fetchone()
            return dict(row)

    def touch_last_seen(self, machine_id: str) -> None:
        """Update ``last_seen_at`` for an agent.

        Does nothing when machine_id is empty or blank, and silently updates
        zero rows when the agent is unknown. The id is stripped so that it
        matches the form stored by ``enroll``.
        """
        if not machine_id:
            return
        machine_id = machine_id.strip()
        if not machine_id:
            return
        now = _utc_now_iso()
        with _DB_LOCK, self._session() as conn, conn:
            conn.execute(
                "UPDATE enrolled_agents SET last_seen_at = ? WHERE machine_id = ?",
                (now, machine_id),
            )
            conn.commit()
|
||||
@@ -30,6 +30,7 @@ from .replay_failure_logger import log_replay_failure
|
||||
from .replay_verifier import ReplayVerifier, VerificationResult
|
||||
from .replay_learner import ReplayLearner
|
||||
from .audit_trail import AuditTrail, AuditEntry
|
||||
from .agent_registry import AgentRegistry, AgentAlreadyEnrolledError
|
||||
from .stream_processor import StreamProcessor, build_replay_from_raw_events, enrich_click_from_screenshot
|
||||
from .worker_stream import StreamWorker
|
||||
from .execution_plan_runner import (
|
||||
@@ -37,6 +38,13 @@ from .execution_plan_runner import (
|
||||
inject_plan_into_queue,
|
||||
)
|
||||
|
||||
# Pipeline d'anonymisation PII (OCR + NER côté serveur).
|
||||
# Import paresseux : on ne charge pas docTR tant qu'aucune image n'est reçue.
|
||||
try:
|
||||
from core.anonymisation import blur_pii_on_image as _blur_pii_on_image
|
||||
except ImportError:
|
||||
_blur_pii_on_image = None
|
||||
|
||||
# Instance globale du vérificateur de replay (comparaison screenshots avant/après)
|
||||
_replay_verifier = ReplayVerifier()
|
||||
_replay_learner = ReplayLearner()
|
||||
@@ -82,25 +90,77 @@ logger = logging.getLogger("api_stream")
|
||||
# =========================================================================
|
||||
# Authentification par token Bearer (sécurité HIGH)
|
||||
# =========================================================================
|
||||
# Le token est lu depuis l'environnement ou généré au démarrage.
|
||||
# Le token est lu depuis l'environnement obligatoirement.
|
||||
# Tous les endpoints requièrent le header Authorization: Bearer <token>,
|
||||
# sauf /health, /docs et /openapi.json (publics).
|
||||
API_TOKEN = os.environ.get("RPA_API_TOKEN", secrets.token_hex(32))
|
||||
#
|
||||
# Fail-closed P0-C :
|
||||
# - En production (défaut), RPA_API_TOKEN DOIT être défini.
|
||||
# - Pour désactiver l'auth en dev local : RPA_AUTH_DISABLED=true
|
||||
# Dans ce mode, aucun token n'est requis et l'API log un WARNING au boot.
|
||||
# - Sans token ET sans RPA_AUTH_DISABLED=true → arrêt immédiat du process
|
||||
# (sys.exit 1) avec message fatal clair. On NE génère PLUS de token
|
||||
# aléatoire en silence : cela cassait tous les agents clients sans bruit.
|
||||
# Bearer-token bootstrap, evaluated once at import time.
# _AUTH_DISABLED is true only when RPA_AUTH_DISABLED is "1", "true" or "yes"
# (case-insensitive); _API_TOKEN_ENV is the stripped RPA_API_TOKEN value.
_AUTH_DISABLED = os.environ.get("RPA_AUTH_DISABLED", "").lower() in (
    "1", "true", "yes",
)
_API_TOKEN_ENV = os.environ.get("RPA_API_TOKEN", "").strip()

if _AUTH_DISABLED:
    # Explicit dev mode: a missing token is tolerated, but logged loudly.
    logger.warning(
        "[SÉCURITÉ] RPA_AUTH_DISABLED=true — authentification Bearer DÉSACTIVÉE. "
        "NE JAMAIS utiliser cette configuration en production. Tous les "
        "endpoints sont accessibles sans token."
    )
    # A random token is generated only as a placeholder when none is set.
    API_TOKEN = _API_TOKEN_ENV or secrets.token_hex(32)
elif not _API_TOKEN_ENV:
    # Fail-closed: no silent token generation. The server is stopped.
    _FATAL_MSG = (
        "[SÉCURITÉ] FATAL — RPA_API_TOKEN est absent ou vide. "
        "Refus de démarrer le serveur de streaming : générer un token "
        "aléatoire interne casserait tous les agents clients qui utilisent "
        "le token persistant (.env.local). "
        "Pour fixer : définir RPA_API_TOKEN=<32 hex chars> dans l'environnement. "
        "Pour désactiver l'auth en dev local : RPA_AUTH_DISABLED=true."
    )
    logger.critical(_FATAL_MSG)
    print(_FATAL_MSG, flush=True)
    # sys.exit is used for a clean stop. Per the original author's note, a
    # raised RuntimeError gets caught by uvicorn on Python 3.11, whereas
    # sys.exit propagates as a BaseException.
    import sys as _sys
    _sys.exit(1)
else:
    API_TOKEN = _API_TOKEN_ENV
    # Diagnostic log: only the first 8 characters of the token are written.
    # NOTE(review): this still exposes a token prefix in the logs — confirm
    # that this is acceptable for the deployment.
    logger.info(
        f"[SÉCURITÉ] Token API chargé (8 premiers caractères : "
        f"{API_TOKEN[:8]}…) — auth Bearer obligatoire"
    )
|
||||
|
||||
# Endpoints publics (pas besoin de token)
|
||||
# En production, /docs et /redoc sont désactivés (voir ci-dessous)
|
||||
# Paths publics : pas de token requis
|
||||
# /replay/next est public car l'agent Rust legacy n'envoie pas de token
|
||||
# et c'est un endpoint read-only (polling, pas d'écriture)
|
||||
#
|
||||
# Fix P0-B : /api/v1/traces/stream/image RETIRÉ de la liste publique.
|
||||
# L'upload d'image écrit sur disque + déclenche du travail VLM : exiger
|
||||
# un token Bearer. Tous les agents V1 déployés envoient déjà le token
|
||||
# (cf. agent_v0/agent_v1/network/streamer.py:_auth_headers).
|
||||
_PUBLIC_PATHS = {
|
||||
"/health", "/docs", "/openapi.json", "/redoc",
|
||||
"/api/v1/traces/stream/replay/next",
|
||||
"/api/v1/traces/stream/image",
|
||||
}
|
||||
|
||||
|
||||
async def _verify_token(request: Request):
|
||||
"""Middleware de vérification du token API Bearer."""
|
||||
"""Middleware de vérification du token API Bearer.
|
||||
|
||||
Bypass si RPA_AUTH_DISABLED=true (mode dev local uniquement).
|
||||
"""
|
||||
if _AUTH_DISABLED:
|
||||
return
|
||||
if request.url.path in _PUBLIC_PATHS:
|
||||
return
|
||||
auth = request.headers.get("Authorization", "")
|
||||
@@ -159,6 +219,10 @@ from .replay_engine import (
|
||||
_is_learned_workflow,
|
||||
_edge_to_normalized_actions,
|
||||
_substitute_variables,
|
||||
_resolve_runtime_vars,
|
||||
_SERVER_SIDE_ACTION_TYPES,
|
||||
_handle_extract_text_action,
|
||||
_handle_t2a_decision_action,
|
||||
_expand_compound_steps,
|
||||
_pre_check_screen_state as _pre_check_screen_state_impl,
|
||||
_detect_popup_hint as _detect_popup_hint_impl,
|
||||
@@ -232,6 +296,20 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
|
||||
@app.middleware("http")
async def url_compat_rewrite(request: Request, call_next):
    """Backward compatibility: rewrite legacy URLs that lack the /api/v1 prefix.

    Some client agents (per the original note: the frozen "Léa V1") post to
    /traces/stream/... instead of /api/v1/traces/stream/... This middleware
    rewrites the ASGI scope path in place before handing the request on, so
    the client sees no redirect.
    """
    incoming_path = request.url.path
    is_legacy_route = incoming_path.startswith("/traces/stream/")
    has_api_prefix = incoming_path.startswith("/api/v1/")
    if is_legacy_route and not has_api_prefix:
        # Downstream handlers read the path from the scope.
        request.scope["path"] = "/api/v1" + incoming_path
    response = await call_next(request)
    return response
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def security_headers_middleware(request: Request, call_next):
|
||||
"""Ajouter les headers de sécurité sur toutes les réponses."""
|
||||
@@ -281,6 +359,14 @@ REPLAY_LOCK_FILE = _DATA_DIR / "_replay_active.lock"
|
||||
processor = StreamProcessor(data_dir=str(LIVE_SESSIONS_DIR))
|
||||
worker = StreamWorker(live_dir=str(LIVE_SESSIONS_DIR), processor=processor)
|
||||
|
||||
# Registry of enrolled Lea workstations (enrolled_agents table in rpa_data.db).
# The database location can be overridden with RPA_AGENTS_DB_PATH (for tests).
_AGENTS_DB_PATH = os.getenv(
    "RPA_AGENTS_DB_PATH",
    str(ROOT_DIR.joinpath("data", "databases", "rpa_data.db")),
)
agent_registry = AgentRegistry(db_path=_AGENTS_DB_PATH)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Flush garanti à l'arrêt — signal handler + atexit (ceinture et bretelles)
|
||||
@@ -490,6 +576,12 @@ class ReplayResultReport(BaseModel):
|
||||
target_spec: Optional[Dict[str, Any]] = None # Spec complete de la cible
|
||||
# Correction humaine (mode apprentissage supervisé)
|
||||
correction: Optional[Dict[str, Any]] = None # {x_pct, y_pct, uia_snapshot, crop_b64}
|
||||
# Sécurité : signalement d'un dialogue système critique détecté
|
||||
# (UAC, CredUI, SmartScreen...). Quand ce champ est présent, l'agent
|
||||
# refuse toute interaction et le serveur bascule en paused_need_help.
|
||||
# Cf. agent_v1/core/system_dialog_guard.py
|
||||
system_dialog: Optional[Dict[str, Any]] = None # {category, matched_signal, matched_value, reason, context}
|
||||
needs_human: Optional[bool] = None
|
||||
|
||||
|
||||
class ErrorCallbackConfig(BaseModel):
|
||||
@@ -498,6 +590,28 @@ class ErrorCallbackConfig(BaseModel):
|
||||
callback_url: str # URL à appeler en cas d'erreur non-récupérable
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Agent Fleet — enrollment / desinstallation
|
||||
# Consommes par l'installeur Lea.iss (voir deploy/installer/)
|
||||
# -------------------------------------------------------------------------
|
||||
class AgentEnrollRequest(BaseModel):
    """Request body for registering a workstation during Lea installation.

    Only ``machine_id`` is required; every other field is optional metadata
    that is passed on to the agent registry as-is.
    """

    # Identifier of the workstation (required).
    machine_id: str

    # Optional information about the user of the workstation.
    user_name: Optional[str] = None
    user_email: Optional[str] = None
    user_id: Optional[str] = None

    # Optional information about the machine and the installed agent.
    hostname: Optional[str] = None
    os_info: Optional[str] = None
    version: Optional[str] = None
|
||||
|
||||
|
||||
class AgentUninstallRequest(BaseModel):
    """Request body notifying the server that a workstation was uninstalled."""

    # Identifier of the workstation being uninstalled (required).
    machine_id: str

    # Free-form reason; conventional values are
    # user_uninstall | admin_revoke | machine_retired.
    reason: Optional[str] = None
|
||||
|
||||
|
||||
# Thread de nettoyage périodique des replays terminés et sessions expirées
|
||||
_cleanup_thread: Optional[threading.Thread] = None
|
||||
_cleanup_running = False
|
||||
@@ -837,6 +951,40 @@ _som_enrichment_executor = ThreadPoolExecutor(
|
||||
max_workers=1, thread_name_prefix="som_enrich",
|
||||
)
|
||||
|
||||
# Thread pool dedicated to PII anonymisation (OCR + NER).
# Toggled with RPA_PII_BLUR_SERVER (default: true). A single worker is used:
# per the original note the pipeline is fast (<2 s per screenshot), and the
# blur may lag behind capture without blocking replay or grounding, because
# both of those read the raw _full.png file.
_PII_BLUR_ENABLED = (
    os.environ.get("RPA_PII_BLUR_SERVER", "true").lower() in {"true", "1", "yes"}
)
_pii_blur_executor = ThreadPoolExecutor(
    thread_name_prefix="pii_blur",
    max_workers=1,
)
|
||||
|
||||
|
||||
def _produce_blurred_version(raw_path: str, shot_id: str) -> None:
    """Run the PII blur pipeline on one raw screenshot (meant for a worker thread).

    Writes ``<stem>_blurred<suffix>`` next to the raw file (``.png`` when the
    raw path has no suffix). The raw file itself is never modified.

    Args:
        raw_path: Path of the raw screenshot on disk.
        shot_id: Screenshot identifier, used only in log messages.

    Any exception is caught and logged as a warning; nothing is raised.
    """
    # Blur backend not available: nothing to do.
    if _blur_pii_on_image is None:
        return

    try:
        source = Path(raw_path)
        suffix = source.suffix or '.png'
        blurred = source.with_name(f"{source.stem}_blurred{suffix}")

        # Skip work when an up-to-date blurred file already exists
        # (the same screenshot may be received twice).
        if blurred.exists():
            if blurred.stat().st_mtime >= source.stat().st_mtime:
                return

        outcome = _blur_pii_on_image(source, blurred)
        logger.debug(
            "pii_blur : %s → %d PII (%.0fms, ner=%s)",
            shot_id, outcome.count, outcome.elapsed_ms, outcome.ner_engine,
        )
    except Exception as exc:  # noqa: BLE001
        logger.warning("pii_blur : échec sur %s (%s)", shot_id, exc)
|
||||
|
||||
# Clics en attente d'enrichissement (le screenshot n'est pas encore arrivé)
|
||||
# Clé : (session_id, screenshot_id) → dict avec les infos nécessaires
|
||||
_pending_click_enrichments: Dict[tuple, Dict[str, Any]] = {}
|
||||
@@ -1163,6 +1311,20 @@ async def stream_image(
|
||||
|
||||
file_path_str = str(file_path)
|
||||
|
||||
# Anonymisation PII côté serveur (OCR + NER + blur ciblé).
|
||||
# On ne floute QUE les screenshots affichés dans le dashboard / cleaner :
|
||||
# shot_XXXX_full (screenshots d'action) et heartbeats (vue live).
|
||||
# Les crops, focus, window sont utilisés pour le grounding/template — pas
|
||||
# d'affichage humain direct donc pas besoin de version floutée.
|
||||
# Le fichier brut (shot_XXXX_full.png) reste intact pour le replay,
|
||||
# le grounding VLM et l'entraînement. La version floutée est écrite en
|
||||
# parallèle sous shot_XXXX_full_blurred.png.
|
||||
if _PII_BLUR_ENABLED and _blur_pii_on_image is not None and (
|
||||
("_full" in shot_id and shot_id.startswith("shot_"))
|
||||
or shot_id.startswith("heartbeat_")
|
||||
):
|
||||
_pii_blur_executor.submit(_produce_blurred_version, file_path_str, shot_id)
|
||||
|
||||
# Crops : traitement léger (pas d'analyse ScreenAnalyzer)
|
||||
if "_crop" in shot_id:
|
||||
result = worker.process_crop_direct(session_id, shot_id, file_path_str)
|
||||
@@ -2600,8 +2762,29 @@ async def get_next_action(session_id: str, machine_id: str = "default"):
|
||||
|
||||
Si la session de l'agent n'a pas d'actions en attente, cherche dans les
|
||||
autres queues de la MÊME machine (pas cross-machine).
|
||||
|
||||
Acquire timeout : si une action serveur lente (extract_text OCR,
|
||||
t2a_decision LLM) tient le lock, on retourne immédiatement
|
||||
{action: None, server_busy: True} avant que le client ne timeout à 5s.
|
||||
Sans cela, des actions seraient popped serveur puis envoyées sur des
|
||||
sockets clients déjà fermées par timeout — perdues silencieusement.
|
||||
|
||||
L'acquire et les actions serveur lentes sont exécutés via
|
||||
run_in_executor : sinon l'appel synchrone bloque l'event loop FastAPI
|
||||
(single-threaded) et même les polls qui devraient recevoir server_busy
|
||||
sont bloqués jusqu'à libération — ce qui annule l'effet du timeout.
|
||||
"""
|
||||
with _replay_lock:
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
acquired = await loop.run_in_executor(None, _replay_lock.acquire, True, 4.5)
|
||||
if not acquired:
|
||||
return {
|
||||
"action": None,
|
||||
"session_id": session_id,
|
||||
"machine_id": machine_id,
|
||||
"server_busy": True,
|
||||
}
|
||||
try:
|
||||
# Verifier si le replay est en pause supervisee (target_not_found).
|
||||
# Dans ce cas, NE PAS envoyer d'action — attendre l'intervention utilisateur.
|
||||
for state in _replay_states.values():
|
||||
@@ -2666,6 +2849,7 @@ async def get_next_action(session_id: str, machine_id: str = "default"):
|
||||
break
|
||||
if target_state:
|
||||
queue = target_queue
|
||||
owning_replay = target_state
|
||||
_replay_queues[session_id] = target_queue
|
||||
del _replay_queues[target_sid]
|
||||
target_state["session_id"] = session_id
|
||||
@@ -2682,6 +2866,7 @@ async def get_next_action(session_id: str, machine_id: str = "default"):
|
||||
other_queue = _replay_queues.get(other_sid, [])
|
||||
if other_queue:
|
||||
queue = other_queue
|
||||
owning_replay = state
|
||||
_replay_queues[session_id] = other_queue
|
||||
del _replay_queues[other_sid]
|
||||
state["session_id"] = session_id
|
||||
@@ -2692,9 +2877,81 @@ async def get_next_action(session_id: str, machine_id: str = "default"):
|
||||
if not queue:
|
||||
return {"action": None, "session_id": session_id, "machine_id": machine_id}
|
||||
|
||||
# Peek à la prochaine action SANS la retirer (pour le pre-check)
|
||||
# ── Boucle de traitement : actions serveur (extract_text, t2a_decision)
|
||||
# exécutées entièrement côté serveur jusqu'à trouver une action visuelle
|
||||
# à transmettre à l'Agent V1 ou un pause_for_human qui bloque le replay.
|
||||
action = None
|
||||
while queue:
|
||||
action = queue[0]
|
||||
|
||||
# Résoudre les variables runtime ({{var}} et {{var.field}})
|
||||
if owning_replay is not None:
|
||||
runtime_vars = owning_replay.get("variables") or {}
|
||||
if runtime_vars:
|
||||
action = _resolve_runtime_vars(action, runtime_vars)
|
||||
|
||||
type_ = action.get("type")
|
||||
|
||||
# pause_for_human : no-op en mode autonome — on saute et on continue
|
||||
if type_ == "pause_for_human":
|
||||
logger.info(
|
||||
"pause_for_human ignorée (mode autonome) — replay %s continue",
|
||||
owning_replay["replay_id"] if owning_replay else "?"
|
||||
)
|
||||
queue.pop(0)
|
||||
_replay_queues[session_id] = queue
|
||||
continue
|
||||
|
||||
# Actions serveur : exécuter HORS event loop pour ne pas bloquer
|
||||
# les autres polls (extract_text OCR ~5s, t2a_decision LLM ~8-13s).
|
||||
# Le lock reste tenu (queue cohérente) mais l'event loop est libre,
|
||||
# donc les polls concurrents peuvent recevoir {server_busy: True}.
|
||||
if type_ in _SERVER_SIDE_ACTION_TYPES and owning_replay is not None:
|
||||
try:
|
||||
if type_ == "extract_text":
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
_handle_extract_text_action,
|
||||
action, owning_replay, session_id, _last_heartbeat,
|
||||
)
|
||||
elif type_ == "t2a_decision":
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
_handle_t2a_decision_action,
|
||||
action, owning_replay,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Action serveur {type_} a levé : {e}")
|
||||
queue.pop(0)
|
||||
_replay_queues[session_id] = queue
|
||||
continue # action suivante
|
||||
|
||||
# Clic conditionnel : si l'action a un paramètre "condition", évaluer la variable
|
||||
# Format : "dec.critere1_valide" → runtime_vars["dec"]["critere1_valide"]
|
||||
condition_key = (action.get("parameters") or {}).get("condition")
|
||||
if condition_key and owning_replay is not None:
|
||||
runtime_vars = owning_replay.get("variables") or {}
|
||||
parts = condition_key.split(".", 1)
|
||||
if len(parts) == 2:
|
||||
val = (runtime_vars.get(parts[0]) or {}).get(parts[1])
|
||||
else:
|
||||
val = runtime_vars.get(parts[0])
|
||||
if not val:
|
||||
logger.info("Clic conditionnel ignoré (%s=%s) — action %s",
|
||||
condition_key, val, action.get("action_id", "?"))
|
||||
queue.pop(0)
|
||||
_replay_queues[session_id] = queue
|
||||
continue
|
||||
|
||||
# Action visuelle : sortir de la boucle pour la transmettre à l'Agent V1
|
||||
break
|
||||
|
||||
# Si la queue s'est vidée après les exécutions serveur, rien à transmettre
|
||||
if not queue or action is None:
|
||||
return {"action": None, "session_id": session_id, "machine_id": machine_id}
|
||||
finally:
|
||||
_replay_lock.release()
|
||||
|
||||
# ---- Pre-check écran (optionnel, non bloquant) ----
|
||||
# Ne s'applique qu'aux actions qui ont un from_node (actions de workflow,
|
||||
# pas les wait/retry auto-injectés ni les actions Copilot/Agent Libre)
|
||||
@@ -3212,6 +3469,92 @@ async def report_action_result(report: ReplayResultReport):
|
||||
replay_state["completed_actions"] += 1
|
||||
replay_state["current_action_index"] += 1
|
||||
|
||||
elif not report.success and (report.system_dialog or (report.error or "").startswith("system_dialog:")):
|
||||
# ── SÉCURITÉ : dialogue système Windows détecté (UAC / CredUI / SmartScreen) ──
|
||||
# L'agent REFUSE de cliquer automatiquement sur ces dialogues.
|
||||
# On bascule immédiatement en paused_need_help — l'humain doit
|
||||
# valider manuellement (saisir mdp, autoriser l'élévation…).
|
||||
# Cf. agent_v1/core/system_dialog_guard.py
|
||||
_sys_info = report.system_dialog or {}
|
||||
_sys_category = (
|
||||
_sys_info.get("category")
|
||||
or (report.error or "system_dialog:unknown").split(":", 1)[-1]
|
||||
)
|
||||
_sys_reason = _sys_info.get("reason", "")
|
||||
_tspec_sys = (original_action or {}).get("target_spec") or report.target_spec or {}
|
||||
|
||||
# Message utilisateur adapté à la catégorie
|
||||
_cat_messages = {
|
||||
"uac_consent": (
|
||||
"Une demande d'élévation de privilèges (UAC) est apparue. "
|
||||
"Je ne clique jamais automatiquement dessus — merci de valider "
|
||||
"ou refuser toi-même, puis relance-moi."
|
||||
),
|
||||
"windows_credential_prompt": (
|
||||
"Windows me demande un mot de passe / identifiants. "
|
||||
"Merci de remplir toi-même, puis relance-moi."
|
||||
),
|
||||
"smartscreen": (
|
||||
"SmartScreen a bloqué l'application. "
|
||||
"Merci de vérifier et débloquer manuellement si légitime."
|
||||
),
|
||||
"windows_defender": (
|
||||
"Windows Defender signale une alerte. "
|
||||
"Merci de vérifier manuellement."
|
||||
),
|
||||
"driver_install": (
|
||||
"Une installation de pilote est demandée. "
|
||||
"Merci de valider manuellement."
|
||||
),
|
||||
}
|
||||
_pause_msg_sys = _cat_messages.get(
|
||||
_sys_category,
|
||||
"Un dialogue système Windows est apparu. "
|
||||
"Je ne clique pas automatiquement dessus — merci de gérer manuellement."
|
||||
)
|
||||
|
||||
replay_state["status"] = "paused_need_help"
|
||||
replay_state["failed_action"] = {
|
||||
"action_id": action_id,
|
||||
"type": (original_action or {}).get("type", "unknown"),
|
||||
"target_description": f"Dialogue système : {_sys_category}",
|
||||
"screenshot_b64": screenshot_after or report.screenshot,
|
||||
"target_spec": _tspec_sys,
|
||||
"reason": "system_dialog",
|
||||
"system_dialog": _sys_info,
|
||||
"error_detail": _sys_reason or (report.error or ""),
|
||||
}
|
||||
replay_state["pause_message"] = _pause_msg_sys
|
||||
error_entry = {
|
||||
"action_id": action_id,
|
||||
"error": f"system_dialog:{_sys_category}",
|
||||
"retry_count": retry_count,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
replay_state["error_log"].append(error_entry)
|
||||
logger.critical(
|
||||
f"[SECURITE] Replay PAUSE supervisee (dialogue systeme) : "
|
||||
f"{action_id} — categorie={_sys_category} — "
|
||||
f"signal={_sys_info.get('matched_signal', '?')}='{_sys_info.get('matched_value', '?')}' "
|
||||
f"— reason={_sys_reason}"
|
||||
)
|
||||
try:
|
||||
log_replay_failure(
|
||||
replay_id=replay_state["replay_id"],
|
||||
action_id=action_id,
|
||||
target_spec=_tspec_sys,
|
||||
screenshot_b64=screenshot_after or report.screenshot,
|
||||
error=f"system_dialog:{_sys_category}",
|
||||
extra={
|
||||
"system_dialog": _sys_info,
|
||||
"category": _sys_category,
|
||||
"matched_signal": _sys_info.get("matched_signal", ""),
|
||||
"matched_value": _sys_info.get("matched_value", ""),
|
||||
},
|
||||
)
|
||||
except Exception as _log_exc:
|
||||
logger.debug("log_replay_failure skip (system_dialog): %s", _log_exc)
|
||||
|
||||
elif not report.success and agent_warning == "wrong_window":
|
||||
# L'agent a détecté en pré-vérification que la fenêtre active
|
||||
# n'est pas celle attendue. Même philosophie que no_screen_change :
|
||||
@@ -3635,7 +3978,9 @@ async def resume_replay(replay_id: str):
|
||||
state["pause_message"] = None
|
||||
|
||||
# Reinjecter l'action echouee en tete de queue (sera re-tentee)
|
||||
if failed_action and failed_action.get("action_id"):
|
||||
# pause_for_human est une pause intentionnelle, pas une erreur — ne pas réinjecter
|
||||
if (failed_action and failed_action.get("action_id")
|
||||
and failed_action.get("reason") != "user_request"):
|
||||
# Reconstruire l'action a partir du retry_pending ou de l'original
|
||||
original_action_id = failed_action["action_id"]
|
||||
# Chercher l'action originale dans les retry_pending
|
||||
@@ -3676,6 +4021,26 @@ async def resume_replay(replay_id: str):
|
||||
}
|
||||
|
||||
|
||||
@app.post("/api/v1/traces/stream/replay/{replay_id}/cancel")
async def cancel_replay(replay_id: str):
    """Cancel a replay (whatever its status) and empty its action queue.

    Marks the replay state as ``cancelled``, clears its failure/pause info,
    empties the queue of the session it is attached to and drops every
    pending retry that belongs to this replay.

    Args:
        replay_id: Identifier of the replay to cancel.

    Returns:
        A dict with ``status``, ``replay_id`` and ``session_id``.

    Raises:
        HTTPException: 404 when ``replay_id`` is unknown.
    """
    import asyncio

    # Never block the event loop on _replay_lock: get_next_action keeps the
    # lock held while it awaits slow server-side actions in an executor. A
    # blocking ``with _replay_lock:`` here would freeze the loop, the lock
    # holder could never resume to release it, and the server would hang.
    # Poll with a non-blocking acquire and yield to the loop between attempts.
    while not _replay_lock.acquire(blocking=False):
        await asyncio.sleep(0.05)
    try:
        state = _replay_states.get(replay_id)
        if not state:
            raise HTTPException(status_code=404, detail=f"Replay '{replay_id}' non trouvé")
        session_id = state["session_id"]
        state["status"] = "cancelled"
        state["failed_action"] = None
        state["pause_message"] = None
        _replay_queues[session_id] = []
        # Snapshot the keys first: the dict is mutated while cleaning up.
        keys_to_del = [k for k, v in _retry_pending.items() if v.get("replay_id") == replay_id]
        for k in keys_to_del:
            _retry_pending.pop(k, None)
    finally:
        _replay_lock.release()

    logger.info("Replay %s annulé manuellement", replay_id)
    return {"status": "cancelled", "replay_id": replay_id, "session_id": session_id}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Visual Replay — Résolution visuelle des cibles (module resolve_engine)
|
||||
# =========================================================================
|
||||
@@ -4495,6 +4860,149 @@ async def list_chat_sessions():
|
||||
}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Fleet management — enrollment des postes collaborateurs
|
||||
# Consommes par deploy/installer/Lea.iss et deploy/installer/uninstall_lea.ps1
|
||||
# =========================================================================
|
||||
|
||||
def _agent_row_public(row: Dict[str, Any]) -> Dict[str, Any]:
    """Project an ``enrolled_agents`` row onto the fields exposed by the API.

    Only the whitelisted columns below are copied, so the internal SQL id is
    never returned; ``machine_id`` is the public identifier. Columns missing
    from ``row`` come out as ``None``.
    """
    exposed_columns = (
        "machine_id",
        "user_name",
        "user_email",
        "user_id",
        "hostname",
        "os_info",
        "version",
        "status",
        "enrolled_at",
        "last_seen_at",
        "uninstalled_at",
        "uninstall_reason",
    )
    return {column: row.get(column) for column in exposed_columns}
|
||||
|
||||
|
||||
@app.post("/api/v1/agents/enroll", status_code=201)
async def agents_enroll(request: AgentEnrollRequest):
    """Register a collaborator workstation (called by the installer).

    Behaviour:
      - ``machine_id`` is mandatory (blank after stripping -> 400).
      - Already enrolled and active -> 409 Conflict, with the existing
        enrollment in the error detail.
      - Otherwise the registry result is returned with its ``created`` /
        ``reactivated`` flags (per the original note, a previously
        uninstalled machine is reactivated automatically).
      - The response carries the global API token (phase 1: one token shared
        by every workstation; a later phase may issue one token per machine).

    Raises:
        HTTPException: 400 for a blank ``machine_id`` or a ``ValueError``
            from the registry, 409 when the machine is already enrolled.
    """
    machine_id = (request.machine_id or "").strip()
    if not machine_id:
        raise HTTPException(status_code=400, detail="machine_id est obligatoire")

    enrollment_fields = {
        "machine_id": machine_id,
        "user_name": request.user_name,
        "user_email": request.user_email,
        "user_id": request.user_id,
        "hostname": request.hostname,
        "os_info": request.os_info,
        "version": request.version,
    }

    try:
        outcome = agent_registry.enroll(**enrollment_fields)
    except AgentAlreadyEnrolledError as exc:
        existing = _agent_row_public(exc.existing)
        logger.warning(
            f"[FLEET] Tentative de reenrollement machine_id={machine_id} "
            f"(deja actif depuis {existing.get('enrolled_at')})"
        )
        conflict_detail = {
            "error": "already_enrolled",
            "message": "machine_id deja enrole et actif",
            "existing": existing,
        }
        raise HTTPException(status_code=409, detail=conflict_detail)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))

    agent = _agent_row_public(outcome["agent"])
    if outcome["reactivated"]:
        event_kind = "reactivated"
    else:
        event_kind = "created"
    logger.info(
        f"[FLEET] Agent enrole ({event_kind}) : machine_id={machine_id} "
        f"user={request.user_name!r} hostname={request.hostname!r} "
        f"version={request.version!r}"
    )

    return {
        "status": "enrolled",
        "created": outcome["created"],
        "reactivated": outcome["reactivated"],
        "machine_id": machine_id,
        # Phase 1: the global token is echoed back so the client can check it
        # is aligned with the server. Phase 2 may issue a per-machine token
        # (issued_token != global API_TOKEN).
        "api_token": API_TOKEN,
        "agent": agent,
    }
|
||||
|
||||
|
||||
@app.post("/api/v1/agents/uninstall")
async def agents_uninstall(request: AgentUninstallRequest):
    """Mark a workstation as uninstalled in the agent registry.

    Per the original note this is a soft delete (history is kept) and the
    caller is deploy/installer/uninstall_lea.ps1 on a best-effort basis.

    Raises:
        HTTPException: 400 for a blank ``machine_id`` or a ``ValueError``
            from the registry, 404 when the registry returns no row for
            this ``machine_id``.
    """
    machine_id = (request.machine_id or "").strip()
    if not machine_id:
        raise HTTPException(status_code=400, detail="machine_id est obligatoire")

    # A blank or whitespace-only reason is normalised to None.
    cleaned_reason = (request.reason or "").strip()
    reason = cleaned_reason if cleaned_reason else None

    try:
        row = agent_registry.uninstall(machine_id=machine_id, reason=reason)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))

    if row is not None:
        logger.info(
            f"[FLEET] Agent desinstalle : machine_id={machine_id} reason={reason!r}"
        )
        return {
            "status": "uninstalled",
            "machine_id": machine_id,
            "agent": _agent_row_public(row),
        }

    # Unknown machine: warn and report 404.
    logger.warning(
        f"[FLEET] Desinstallation d'un machine_id inconnu : {machine_id}"
    )
    raise HTTPException(
        status_code=404,
        detail=f"machine_id={machine_id} introuvable dans le registre",
    )
|
||||
|
||||
|
||||
@app.get("/api/v1/agents/fleet")
async def agents_fleet():
    """List enrolled agents, grouped by status (active / uninstalled).

    Returns the public projection of each row together with the number of
    rows in each group.
    """
    rows_by_status = {
        status: agent_registry.list_by_status(status)
        for status in ("active", "uninstalled")
    }
    active_rows = rows_by_status["active"]
    uninstalled_rows = rows_by_status["uninstalled"]

    payload = {
        "active": [_agent_row_public(entry) for entry in active_rows],
        "uninstalled": [_agent_row_public(entry) for entry in uninstalled_rows],
    }
    payload["total_active"] = len(active_rows)
    payload["total_uninstalled"] = len(uninstalled_rows)
    return payload
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
|
||||
@@ -65,7 +65,8 @@ class LiveSessionState:
|
||||
class LiveSessionManager:
|
||||
"""Gère les sessions live en mémoire côté serveur avec persistance disque."""
|
||||
|
||||
def __init__(self, persist_dir: str = "data/streaming_sessions"):
|
||||
def __init__(self, persist_dir: str = "data/streaming_sessions",
|
||||
live_sessions_dir: Optional[str] = None):
|
||||
self._sessions: Dict[str, LiveSessionState] = {}
|
||||
self._lock = threading.Lock()
|
||||
self._persist_dir = Path(persist_dir)
|
||||
@@ -74,11 +75,16 @@ class LiveSessionManager:
|
||||
self._persist_counter = 0 # Compteur pour limiter la fréquence de persistance
|
||||
self._persist_interval = 10 # Persister toutes les N modifications
|
||||
|
||||
# Dossier des sessions live (JSONL + screenshots)
|
||||
self._live_sessions_dir = Path(live_sessions_dir) if live_sessions_dir else None
|
||||
|
||||
# Charger les sessions persistées au démarrage
|
||||
self._load_persisted_sessions()
|
||||
# Reconstruire les sessions depuis les live_events.jsonl sur disque
|
||||
self._discover_sessions_from_disk()
|
||||
|
||||
def _load_persisted_sessions(self):
|
||||
"""Charger les sessions sauvegardées au démarrage."""
|
||||
"""Charger les sessions sauvegardées au démarrage (JSON state files)."""
|
||||
count = 0
|
||||
for session_file in sorted(self._persist_dir.glob("sess_*.json")):
|
||||
try:
|
||||
@@ -92,6 +98,66 @@ class LiveSessionManager:
|
||||
if count:
|
||||
logger.info(f"{count} session(s) restaurée(s) depuis {self._persist_dir}")
|
||||
|
||||
def _discover_sessions_from_disk(self):
    """Rebuild missing in-memory sessions from live_events.jsonl files on disk.

    Recursively scans ``self._live_sessions_dir`` for ``live_events.jsonl``
    files, which covers both layouts:
      - live_sessions/sess_*/live_events.jsonl (root sessions)
      - live_sessions/{machine_id}/sess_*/live_events.jsonl (multi-machine)

    Sessions already present in ``self._sessions`` are left untouched. The
    method is a no-op when no live sessions directory is configured or when
    the directory does not exist.
    """
    live_dir = self._live_sessions_dir
    if live_dir is None or not live_dir.exists():
        return

    discovered = 0
    for jsonl_file in sorted(live_dir.glob("**/live_events.jsonl")):
        session_dir = jsonl_file.parent
        session_id = session_dir.name
        if not session_id.startswith("sess_"):
            continue
        if session_id in self._sessions:
            continue

        # Derive machine_id from the parent folder: a session sitting
        # directly under the live directory belongs to "default".
        parent_name = session_dir.parent.name
        machine_id = "default" if parent_name == live_dir.name else parent_name

        # Count full-screen shots without building a throwaway list. The
        # events file is not read: its line count was never used.
        shots_dir = session_dir / "shots"
        if shots_dir.exists():
            shots_count = sum(1 for _ in shots_dir.glob("shot_*_full.png"))
        else:
            shots_count = 0

        # Create the in-memory session.
        session = LiveSessionState(
            session_id=session_id,
            machine_id=machine_id,
            finalized=False,
        )
        # Placeholder entries (empty paths) so the number of shots found on
        # disk is reflected in shot_paths.
        session.shot_paths = {f"shot_{i:04d}": "" for i in range(shots_count)}
        self._sessions[session_id] = session
        discovered += 1

    if discovered:
        logger.info(
            f"{discovered} session(s) découverte(s) depuis {live_dir} "
            f"(total: {len(self._sessions)} sessions en mémoire)"
        )
|
||||
|
||||
def _persist_session(self, session_id: str):
|
||||
"""Sauvegarder une session sur disque (appelé périodiquement)."""
|
||||
session = self._sessions.get(session_id)
|
||||
@@ -102,7 +168,7 @@ class LiveSessionManager:
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(session.to_dict(), f, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.debug(f"Erreur persistance session {session_id}: {e}")
|
||||
logger.warning(f"Erreur persistance session {session_id}: {e}")
|
||||
|
||||
def _maybe_persist(self, session_id: str):
|
||||
"""Persister si le compteur atteint l'intervalle."""
|
||||
@@ -180,6 +246,17 @@ class LiveSessionManager:
|
||||
if meta_val is not None:
|
||||
info[meta_key] = meta_val
|
||||
session.last_window_info = info
|
||||
# Exploiter window_capture (envoyé par l'agent avec la capture fenêtre)
|
||||
# pour enrichir last_window_info avec le titre précis de la fenêtre cliquée
|
||||
window_capture = event_data.get("window_capture")
|
||||
if window_capture and isinstance(window_capture, dict):
|
||||
wc_title = window_capture.get("title", "").strip()
|
||||
wc_app = window_capture.get("app_name", "").strip()
|
||||
if wc_title:
|
||||
session.last_window_info["title"] = wc_title
|
||||
if wc_app:
|
||||
session.last_window_info["app_name"] = wc_app
|
||||
|
||||
# Accumuler les titres/apps pour le nommage automatique
|
||||
title = session.last_window_info.get("title", "").strip()
|
||||
app_name = session.last_window_info.get("app_name", "").strip()
|
||||
@@ -221,18 +298,41 @@ class LiveSessionManager:
|
||||
import socket
|
||||
|
||||
# Construire les événements au format RawSession
|
||||
# Important : copier TOUTES les données de l'événement (pos, text, keys, button...)
|
||||
# car Event.from_dict() met tout sauf t/type/window/screenshot_id dans event.data,
|
||||
# et le GraphBuilder utilise event.data pour construire les actions.
|
||||
events = []
|
||||
for evt in session.events:
|
||||
# Extraire window info (plusieurs formats possibles)
|
||||
window_raw = evt.get("window")
|
||||
if isinstance(window_raw, dict):
|
||||
window_info = {
|
||||
"title": window_raw.get("title", session.last_window_info.get("title", "")),
|
||||
"app_name": window_raw.get("app_name", session.last_window_info.get("app_name", "unknown")),
|
||||
}
|
||||
else:
|
||||
window_info = {
|
||||
"title": evt.get("window_title", session.last_window_info.get("title", "")),
|
||||
"app_name": evt.get("app_name", session.last_window_info.get("app_name", "unknown")),
|
||||
}
|
||||
events.append({
|
||||
|
||||
raw_event = {
|
||||
"t": evt.get("timestamp", 0),
|
||||
"type": evt.get("type", "unknown"),
|
||||
"window": window_info,
|
||||
"screenshot_id": evt.get("screenshot_id"),
|
||||
})
|
||||
}
|
||||
|
||||
# Copier les données spécifiques au type d'événement
|
||||
# (pos, button, text, keys, etc.) — indispensable pour le replay
|
||||
_skip_keys = {"type", "timestamp", "window", "window_title",
|
||||
"app_name", "screenshot_id", "machine_id",
|
||||
"screen_metadata", "vision_info"}
|
||||
for key, value in evt.items():
|
||||
if key not in _skip_keys and key not in raw_event:
|
||||
raw_event[key] = value
|
||||
|
||||
events.append(raw_event)
|
||||
|
||||
# Construire les screenshots au format RawSession
|
||||
screenshots = []
|
||||
|
||||
@@ -33,7 +33,15 @@ _ALLOWED_ACTION_TYPES = {
|
||||
"file_open", "file_save", "file_close", "file_new", "file_dialog",
|
||||
"double_click", "right_click", "drag",
|
||||
"verify_screen", # Replay hybride : vérification visuelle entre groupes
|
||||
"pause_for_human", # Pause supervisée explicite (interceptée par /replay/next)
|
||||
"extract_text", # OCR serveur sur dernier heartbeat → variable workflow
|
||||
"t2a_decision", # Analyse LLM facturation T2A → variable workflow
|
||||
}
|
||||
|
||||
# Types d'actions exécutées CÔTÉ SERVEUR (jamais transmises à l'Agent V1).
|
||||
# Le pipeline /replay/next les traite en boucle interne et passe à l'action
|
||||
# suivante jusqu'à trouver une action visuelle (à transmettre au client).
|
||||
_SERVER_SIDE_ACTION_TYPES = {"extract_text", "t2a_decision"}
|
||||
_MAX_ACTION_TEXT_LENGTH = 10000
|
||||
_MAX_KEYS_PER_COMBO = 10
|
||||
# Touches autorisées dans les key_combo (modificateurs + touches spéciales + caractères simples)
|
||||
@@ -852,6 +860,30 @@ def _edge_to_normalized_actions(edge, params: Dict[str, Any]) -> List[Dict[str,
|
||||
keys = [action_params["key"]]
|
||||
normalized["keys"] = keys
|
||||
|
||||
elif action_type == "pause_for_human":
|
||||
normalized["type"] = "pause_for_human"
|
||||
normalized["parameters"] = {
|
||||
"message": action_params.get("message", "Validation requise"),
|
||||
}
|
||||
return [normalized] # pas de target/coords pour cette action logique
|
||||
|
||||
elif action_type == "extract_text":
|
||||
normalized["type"] = "extract_text"
|
||||
normalized["parameters"] = {
|
||||
"output_var": action_params.get("output_var", "extracted_text"),
|
||||
"paragraph": bool(action_params.get("paragraph", True)),
|
||||
}
|
||||
return [normalized]
|
||||
|
||||
elif action_type == "t2a_decision":
|
||||
normalized["type"] = "t2a_decision"
|
||||
normalized["parameters"] = {
|
||||
"input_template": action_params.get("input_template", ""),
|
||||
"output_var": action_params.get("output_var", "t2a_result"),
|
||||
"model": action_params.get("model"),
|
||||
}
|
||||
return [normalized]
|
||||
|
||||
else:
|
||||
logger.warning(f"Type d'action inconnu : {action_type}")
|
||||
return []
|
||||
@@ -886,6 +918,143 @@ def _substitute_variables(text: str, params: Dict[str, Any], defaults: Dict[str,
|
||||
return re.sub(r'\$\{(\w+)\}', replacer, text)
|
||||
|
||||
|
||||
# Regex pour le templating runtime : {{var}} ou {{var.champ}} ou {{var.champ.sous}}
|
||||
_RUNTIME_VAR_PATTERN = re.compile(r'\{\{\s*(\w+)(?:\.([\w.]+))?\s*\}\}')
|
||||
|
||||
|
||||
def _resolve_runtime_vars_in_str(text: str, variables: Dict[str, Any]) -> str:
|
||||
"""Remplace {{var}} et {{var.field}} par leur valeur depuis le dict variables.
|
||||
|
||||
Variables/champs absents : laissés tels quels (ne casse pas le pipeline).
|
||||
Pour les valeurs non-str (dict, list), str() est appelé.
|
||||
"""
|
||||
def replacer(match):
|
||||
var_name = match.group(1)
|
||||
path = match.group(2)
|
||||
if var_name not in variables:
|
||||
return match.group(0)
|
||||
value = variables[var_name]
|
||||
if path:
|
||||
for field in path.split('.'):
|
||||
if isinstance(value, dict) and field in value:
|
||||
value = value[field]
|
||||
else:
|
||||
return match.group(0)
|
||||
return str(value)
|
||||
|
||||
return _RUNTIME_VAR_PATTERN.sub(replacer, text)
|
||||
|
||||
|
||||
def _resolve_runtime_vars(value: Any, variables: Dict[str, Any]) -> Any:
|
||||
"""Résout récursivement les {{var}} et {{var.field}} dans une valeur.
|
||||
|
||||
Supporte str, dict, list. Les autres types sont retournés tels quels.
|
||||
Si variables est vide ou None, value est retournée inchangée.
|
||||
"""
|
||||
if not variables:
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return _resolve_runtime_vars_in_str(value, variables)
|
||||
if isinstance(value, dict):
|
||||
return {k: _resolve_runtime_vars(v, variables) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_resolve_runtime_vars(item, variables) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Handlers pour les actions exécutées côté serveur (extract_text, t2a_decision)
|
||||
# =========================================================================
|
||||
|
||||
def _handle_extract_text_action(
    action: Dict[str, Any],
    replay_state: Dict[str, Any],
    session_id: str,
    last_heartbeat: Dict[str, Dict[str, Any]],
) -> bool:
    """Run a server-side ``extract_text`` action.

    The image referenced by the session's last heartbeat is passed to
    ``core.llm.extract_text_from_image`` and the resulting text is stored in
    ``replay_state["variables"][output_var]``.

    Failure-tolerant: with no heartbeat path, or when the OCR call raises,
    the variable is set to ``""`` and False is returned, so the replay
    pipeline is never blocked by this action.

    Args:
        action: Normalized action; ``parameters`` may carry ``output_var``
            (default ``"extracted_text"``) and ``paragraph`` (default True).
        replay_state: Mutable replay state; its ``"variables"`` dict is
            created on demand.
        session_id: Key used to look up the heartbeat.
        last_heartbeat: Mapping session_id -> heartbeat info; only its
            ``"path"`` entry is read here.

    Returns:
        True when non-empty text was stored, False otherwise.
    """
    params = action.get("parameters") or {}
    # A missing, empty or whitespace-only name falls back to the default so
    # the text is never stored under an empty variable name.
    output_var = (params.get("output_var") or "").strip() or "extracted_text"
    paragraph = bool(params.get("paragraph", True))

    heartbeat = last_heartbeat.get(session_id) or {}
    path = heartbeat.get("path")
    text = ""

    if path:
        try:
            from core.llm import extract_text_from_image
            # `or ""` keeps the documented "empty string on failure" contract
            # if the OCR helper yields None, which would otherwise crash the
            # len(text) call below.
            text = extract_text_from_image(path, paragraph=paragraph) or ""
        except Exception as e:
            # Deliberate best-effort: OCR problems must not stop the replay.
            logger.warning("extract_text OCR échoué (%s) — variable '%s' = ''", e, output_var)
    else:
        logger.warning(
            "extract_text : pas de heartbeat pour session %s — variable '%s' = ''",
            session_id, output_var,
        )

    replay_state.setdefault("variables", {})[output_var] = text
    logger.info(
        "extract_text → variable '%s' (%d chars) replay %s",
        output_var, len(text), replay_state.get("replay_id", "?"),
    )
    return bool(text)
|
||||
|
||||
|
||||
def _handle_t2a_decision_action(
    action: Dict[str, Any],
    replay_state: Dict[str, Any],
) -> bool:
    """Run a server-side ``t2a_decision`` action.

    The text to analyse is read from ``parameters["input_template"]`` (or
    ``parameters["dpi"]`` as a fallback) and handed to
    ``core.llm.analyze_dpi``. The result is stored in
    ``replay_state["variables"][output_var]``.

    An empty input, or an exception raised by the analysis, stores an
    ``INDETERMINE`` placeholder result carrying an ``_error`` entry.

    Args:
        action: Normalized action; ``parameters`` may carry
            ``input_template``/``dpi``, ``output_var`` (default
            ``"t2a_result"``) and ``model`` (falsy means ``DEFAULT_MODEL``).
        replay_state: Mutable replay state; its ``"variables"`` dict is
            created on demand.

    Returns:
        True when the stored result has no ``_error`` key, False otherwise.
    """
    settings = action.get("parameters") or {}
    output_var = (settings.get("output_var") or "t2a_result").strip()
    dpi_text = (settings.get("input_template") or settings.get("dpi") or "").strip()
    model = settings.get("model") or None  # None -> DEFAULT_MODEL

    def _indeterminate(justification, error):
        # Placeholder result shared by both failure paths.
        return {
            "decision": "INDETERMINE",
            "justification": justification,
            "confiance": "faible",
            "_error": error,
        }

    # Guard clause: nothing to analyse.
    if not dpi_text:
        logger.warning(
            "t2a_decision : input vide — variable '%s' = {decision: 'INDETERMINE'}", output_var,
        )
        replay_state.setdefault("variables", {})[output_var] = _indeterminate(
            "DPI vide ou non extrait", "empty_input",
        )
        return False

    try:
        from core.llm import analyze_dpi, DEFAULT_MODEL
        result = analyze_dpi(dpi_text, model=model or DEFAULT_MODEL)
    except Exception as e:
        logger.warning("t2a_decision : analyze_dpi exception %s", e)
        result = _indeterminate(f"Erreur analyse : {e}", str(e))

    replay_state.setdefault("variables", {})[output_var] = result
    logger.info(
        "t2a_decision → variable '%s' decision=%s (%ss) replay %s",
        output_var,
        result.get("decision", "?"),
        result.get("_elapsed_s", "?"),
        replay_state.get("replay_id", "?"),
    )
    return "_error" not in result
|
||||
|
||||
|
||||
def _expand_compound_steps(
|
||||
steps: List[Dict[str, Any]], base: Dict[str, Any], params: Dict[str, Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
@@ -1208,6 +1377,10 @@ def _create_replay_state(
|
||||
# Champs pour pause supervisée (target_not_found)
|
||||
"failed_action": None, # Contexte de l'action en echec (quand paused_need_help)
|
||||
"pause_message": None, # Message a afficher a l'utilisateur
|
||||
# Variables d'exécution produites en cours de workflow (extract_text,
|
||||
# t2a_decision, etc.). Résolues via templating {{var}} ou {{var.field}}
|
||||
# dans les paramètres des actions suivantes.
|
||||
"variables": {},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -248,7 +248,14 @@ def memory_record_success(
|
||||
try:
|
||||
from core.learning.target_memory_store import TargetFingerprint
|
||||
|
||||
# Stripper les préfixes "memory_" empilés pour ne garder que
|
||||
# la méthode de résolution originale (ex: template_matching).
|
||||
# Sans ça, le cycle lookup → record → lookup empile "memory_"
|
||||
# indéfiniment : memory_memory_memory_template_matching.
|
||||
method_clean = method or "v4_unknown"
|
||||
while method_clean.startswith("memory_"):
|
||||
method_clean = method_clean[len("memory_"):]
|
||||
method_clean = method_clean or "v4_unknown"
|
||||
fingerprint = TargetFingerprint(
|
||||
element_id=f"v4_{method_clean}",
|
||||
bbox=(x_pct, y_pct, 0.0, 0.0),
|
||||
|
||||
@@ -2193,22 +2193,33 @@ def _validate_resolution_quality(
|
||||
dx = abs(resolved_x - fallback_x_pct)
|
||||
dy = abs(resolved_y - fallback_y_pct)
|
||||
if dx > _RESOLUTION_MAX_DRIFT or dy > _RESOLUTION_MAX_DRIFT:
|
||||
logger.warning(
|
||||
"[REPLAY] Resolution REJETÉE (drift trop grand) : "
|
||||
"method=%s resolved=(%.3f, %.3f) expected=(%.3f, %.3f) "
|
||||
"drift=(%.3f, %.3f) max=%.2f",
|
||||
method, resolved_x, resolved_y,
|
||||
fallback_x_pct, fallback_y_pct,
|
||||
dx, dy, _RESOLUTION_MAX_DRIFT,
|
||||
# Exception : si le template matching trouve l'image avec une
|
||||
# similarité quasi parfaite, on fait confiance à la position
|
||||
# visuelle peu importe le drift. Une image retrouvée à >= 0.95
|
||||
# de score est SUR l'écran à l'endroit indiqué — le drift par
|
||||
# rapport à l'enregistrement ne reflète qu'un changement de
|
||||
# layout (scroll, redimensionnement, F11, devtools), pas une
|
||||
# erreur de résolution.
|
||||
_HIGH_CONFIDENCE = 0.95
|
||||
if score >= _HIGH_CONFIDENCE and method.startswith("template_matching"):
|
||||
logger.info(
|
||||
"[REPLAY] Drift (%.3f, %.3f) > %.2f IGNORÉ : score=%.3f >= %.2f "
|
||||
"sur %s — résultat visuel fiable, on l'utilise",
|
||||
dx, dy, _RESOLUTION_MAX_DRIFT, score, _HIGH_CONFIDENCE, method,
|
||||
)
|
||||
return result
|
||||
|
||||
logger.warning(
|
||||
"[REPLAY] Drift trop grand (%.3f, %.3f) > %.2f — fallback coords enregistrées (%.3f, %.3f)",
|
||||
dx, dy, _RESOLUTION_MAX_DRIFT, fallback_x_pct, fallback_y_pct,
|
||||
)
|
||||
# Fallback : coordonnées enregistrées lors de la capture (écran identique = safe)
|
||||
return {
|
||||
"resolved": False,
|
||||
"method": f"rejected_drift_{method}",
|
||||
"reason": f"drift_dx{dx:.3f}_dy{dy:.3f}_max{_RESOLUTION_MAX_DRIFT:.2f}",
|
||||
"resolved": True,
|
||||
"method": "fallback_recorded_coords",
|
||||
"reason": f"drift_dx{dx:.3f}_dy{dy:.3f}_using_recorded",
|
||||
"original_method": method,
|
||||
"original_score": score,
|
||||
"drift_dx": round(dx, 3),
|
||||
"drift_dy": round(dy, 3),
|
||||
"x_pct": fallback_x_pct,
|
||||
"y_pct": fallback_y_pct,
|
||||
}
|
||||
|
||||
@@ -1791,6 +1791,10 @@ class StreamProcessor:
|
||||
# Workflows construits (pour le matching)
|
||||
self._workflows: Dict[str, Any] = {}
|
||||
|
||||
# Shadow learning : dernier pattern UI détecté par session
|
||||
# Stocke {session_id: {"pattern": str, "ocr_text": str, "screen_state": obj, "shot_id": str}}
|
||||
self._pending_ui_patterns: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Charger les workflows existants depuis le disque
|
||||
self._load_persisted_workflows()
|
||||
|
||||
@@ -1975,6 +1979,9 @@ class StreamProcessor:
|
||||
- key_combo/key_press avec uniquement des modificateurs seuls (ctrl, alt, shift, etc.)
|
||||
- key_combo/key_press avec liste de touches vide
|
||||
- text_input avec texte vide
|
||||
|
||||
Shadow learning : quand un clic suit un pattern UI détecté,
|
||||
on apprend l'association dialogue→bouton.
|
||||
"""
|
||||
if _is_parasitic_event(event_data):
|
||||
logger.debug(
|
||||
@@ -1982,9 +1989,119 @@ class StreamProcessor:
|
||||
f"type={event_data.get('type')}, data={event_data.get('keys', event_data.get('text', ''))}"
|
||||
)
|
||||
return {"status": "event_filtered", "session_id": session_id, "reason": "parasitic"}
|
||||
|
||||
# Shadow learning : si un pattern UI est en attente et qu'on reçoit un clic
|
||||
if event_data.get("type") == "mouse_click":
|
||||
self._try_shadow_learn(session_id, event_data)
|
||||
|
||||
self.session_manager.add_event(session_id, event_data)
|
||||
return {"status": "event_recorded", "session_id": session_id}
|
||||
|
||||
def _try_shadow_learn(self, session_id: str, click_event: Dict[str, Any]):
    """Try to learn a dialog -> button association from an observed click.

    If a UI pattern is pending for *session_id* it is consumed, the label of
    the element at the click position is looked up in the pending
    screen state, and a learned pattern (trigger text -> clicked label) is
    saved through ``UIPatternLibrary.save_learned_pattern``.

    The method returns silently when there is no pending pattern, no
    screen state, no usable 2-item ``pos`` in the event, no label at the
    click position, or no trigger text. A failure while saving is only
    logged.

    Args:
        session_id: Session whose pending UI pattern should be consumed.
        click_event: Click event; ``pos`` is read as ``[x, y]``.
    """
    # NOTE(review): indentation was lost in the reviewed view; this assumes
    # only the pop runs under the lock — confirm against the real file.
    with self._data_lock:
        pending = self._pending_ui_patterns.pop(session_id, None)

    screen_state = pending.get("screen_state") if pending else None
    if screen_state is None:
        return

    # Click position as reported by the event.
    pos = click_event.get("pos", [])
    if not (pos and len(pos) == 2):
        return

    # Label of the UI element under (or nearest to) the click.
    clicked_label = self._find_label_at_position(screen_state, pos[0], pos[1])
    if not clicked_label:
        return

    # Trigger = first 80 characters of the dialog's OCR text, trimmed and
    # lower-cased.
    trigger_text = pending.get("ocr_text", "")[:80].strip().lower()
    if not trigger_text:
        return

    logger.info(
        f"Shadow learning: pattern '{pending['pattern_name']}' "
        f"→ utilisateur a cliqué '{clicked_label}' | trigger='{trigger_text[:40]}...'"
    )

    learned_pattern = {
        "category": "dialog",
        "triggers": [trigger_text],
        "action": "click",
        "target": clicked_label,
        "os": "windows",
        "confidence": 0.8,
    }

    # Persist the learned pattern; failures are logged, never raised.
    try:
        from core.knowledge.ui_patterns import UIPatternLibrary
        UIPatternLibrary().save_learned_pattern(learned_pattern)
    except Exception as e:
        logger.warning(f"Shadow learning: échec sauvegarde pattern: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _find_label_at_position(screen_state, click_x: int, click_y: int) -> Optional[str]:
|
||||
"""Trouve le label de l'élément UI le plus proche du point de clic.
|
||||
|
||||
Parcourt les ui_elements du ScreenState et retourne le label de
|
||||
l'élément dont la bbox contient le point, ou le plus proche si aucun
|
||||
ne contient exactement le point.
|
||||
"""
|
||||
ui_elements = getattr(screen_state, "ui_elements", [])
|
||||
if not ui_elements:
|
||||
return None
|
||||
|
||||
best_label = None
|
||||
best_dist = float("inf")
|
||||
|
||||
for elem in ui_elements:
|
||||
bbox = getattr(elem, "bbox", None)
|
||||
label = getattr(elem, "label", "")
|
||||
if not bbox or not label:
|
||||
continue
|
||||
|
||||
# BBox = (x, y, width, height) — extraire les coordonnées
|
||||
try:
|
||||
bx, by = bbox.x, bbox.y
|
||||
bw, bh = bbox.width, bbox.height
|
||||
except AttributeError:
|
||||
# Fallback si bbox est une liste/tuple
|
||||
if hasattr(bbox, '__len__') and len(bbox) >= 4:
|
||||
bx, by, bw, bh = bbox[0], bbox[1], bbox[2], bbox[3]
|
||||
else:
|
||||
continue
|
||||
|
||||
# Vérifier si le clic est dans la bbox
|
||||
if bx <= click_x <= bx + bw and by <= click_y <= by + bh:
|
||||
return label.strip()
|
||||
|
||||
# Sinon calculer la distance au centre
|
||||
cx = bx + bw / 2
|
||||
cy = by + bh / 2
|
||||
dist = ((click_x - cx) ** 2 + (click_y - cy) ** 2) ** 0.5
|
||||
if dist < best_dist:
|
||||
best_dist = dist
|
||||
best_label = label.strip()
|
||||
|
||||
# Ne retourner le plus proche que s'il est raisonnablement proche (< 100px)
|
||||
if best_label and best_dist < 100:
|
||||
return best_label
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Screenshots
|
||||
# =========================================================================
|
||||
@@ -2042,6 +2159,37 @@ class StreamProcessor:
|
||||
self._screen_states[session_id] = []
|
||||
self._screen_states[session_id].append(screen_state)
|
||||
|
||||
# Enrichir avec les patterns UI connus
|
||||
try:
|
||||
from core.knowledge.ui_patterns import UIPatternLibrary
|
||||
detected_text = getattr(screen_state.perception, "detected_text", [])
|
||||
if detected_text:
|
||||
ocr_text = " ".join(str(t) for t in detected_text) if isinstance(detected_text, list) else str(detected_text)
|
||||
lib = UIPatternLibrary()
|
||||
pattern = lib.find_pattern(ocr_text)
|
||||
if pattern:
|
||||
result["ui_pattern"] = pattern["pattern"]
|
||||
result["ui_pattern_action"] = pattern["action"]
|
||||
result["ui_pattern_target"] = pattern["target"]
|
||||
logger.info(f"Pattern UI détecté: {pattern['pattern']} → {pattern['target']}")
|
||||
|
||||
# Shadow learning : mémoriser le pattern en attente du clic utilisateur
|
||||
with self._data_lock:
|
||||
self._pending_ui_patterns[session_id] = {
|
||||
"pattern_name": pattern["pattern"],
|
||||
"ocr_text": ocr_text,
|
||||
"screen_state": screen_state,
|
||||
"shot_id": shot_id,
|
||||
}
|
||||
else:
|
||||
# Pas de pattern connu → effacer le pending (l'écran a changé)
|
||||
with self._data_lock:
|
||||
self._pending_ui_patterns.pop(session_id, None)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Pattern check: {e}")
|
||||
|
||||
logger.info(
|
||||
f"Screenshot analysé: {shot_id} | "
|
||||
f"{result['ui_elements_count']} UI elements, "
|
||||
|
||||
@@ -76,6 +76,15 @@ class StepMetrics:
|
||||
confidence_score: float
|
||||
retry_count: int = 0
|
||||
error_details: Optional[str] = None
|
||||
# C1 — Instrumentation vision-aware (ExecutionLoop)
|
||||
# Ces champs proviennent de `StepResult` (core/execution/execution_loop.py).
|
||||
# Tous optionnels avec valeurs par défaut pour rétrocompatibilité.
|
||||
ocr_ms: float = 0.0 # Temps OCR sur ce step
|
||||
ui_ms: float = 0.0 # Temps détection UI sur ce step
|
||||
analyze_ms: float = 0.0 # Temps analyse ScreenState (OCR + UI + reste)
|
||||
total_ms: float = 0.0 # Temps total du step (alias duration_ms)
|
||||
cache_hit: bool = False # True si ScreenState vient du cache perceptuel
|
||||
degraded: bool = False # True si mode dégradé (timeout analyse)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for storage."""
|
||||
@@ -92,7 +101,13 @@ class StepMetrics:
|
||||
'status': self.status,
|
||||
'confidence_score': self.confidence_score,
|
||||
'retry_count': self.retry_count,
|
||||
'error_details': self.error_details
|
||||
'error_details': self.error_details,
|
||||
'ocr_ms': self.ocr_ms,
|
||||
'ui_ms': self.ui_ms,
|
||||
'analyze_ms': self.analyze_ms,
|
||||
'total_ms': self.total_ms,
|
||||
'cache_hit': self.cache_hit,
|
||||
'degraded': self.degraded,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -111,7 +126,13 @@ class StepMetrics:
|
||||
status=data['status'],
|
||||
confidence_score=data['confidence_score'],
|
||||
retry_count=data.get('retry_count', 0),
|
||||
error_details=data.get('error_details')
|
||||
error_details=data.get('error_details'),
|
||||
ocr_ms=float(data.get('ocr_ms') or 0.0),
|
||||
ui_ms=float(data.get('ui_ms') or 0.0),
|
||||
analyze_ms=float(data.get('analyze_ms') or 0.0),
|
||||
total_ms=float(data.get('total_ms') or 0.0),
|
||||
cache_hit=bool(data.get('cache_hit') or False),
|
||||
degraded=bool(data.get('degraded') or False),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Integration of analytics with ExecutionLoop."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import uuid
|
||||
|
||||
from ..analytics_system import get_analytics_system
|
||||
@@ -14,17 +14,35 @@ logger = logging.getLogger(__name__)
|
||||
class AnalyticsExecutionIntegration:
|
||||
"""Integrate analytics collection with workflow execution."""
|
||||
|
||||
def __init__(self, enabled: bool = True):
|
||||
def __init__(self, analytics_system: Any = True, enabled: Optional[bool] = None):
|
||||
"""
|
||||
Initialize analytics integration.
|
||||
|
||||
Args:
|
||||
enabled: Whether analytics collection is enabled
|
||||
"""
|
||||
self.enabled = enabled
|
||||
self.analytics = None
|
||||
Accepte deux formes d'appel pour la rétrocompatibilité :
|
||||
- ``AnalyticsExecutionIntegration(enabled=True)`` → auto-load du système
|
||||
- ``AnalyticsExecutionIntegration(analytics_system_instance)`` →
|
||||
utilise l'instance fournie (utilisé par ExecutionLoop)
|
||||
|
||||
if enabled:
|
||||
Args:
|
||||
analytics_system: Instance d'AnalyticsSystem pré-construite, ou
|
||||
True/False pour activer/désactiver (legacy).
|
||||
enabled: Legacy — si défini, prime sur analytics_system.
|
||||
"""
|
||||
# Détection de la forme d'appel
|
||||
if enabled is not None:
|
||||
# Appel legacy explicite: AnalyticsExecutionIntegration(enabled=...)
|
||||
self.enabled = bool(enabled)
|
||||
self.analytics = None
|
||||
elif isinstance(analytics_system, bool):
|
||||
# Appel legacy: AnalyticsExecutionIntegration(True/False)
|
||||
self.enabled = analytics_system
|
||||
self.analytics = None
|
||||
else:
|
||||
# Nouvelle forme: instance injectée
|
||||
self.enabled = analytics_system is not None
|
||||
self.analytics = analytics_system
|
||||
|
||||
if self.enabled and self.analytics is None:
|
||||
try:
|
||||
self.analytics = get_analytics_system()
|
||||
logger.info("Analytics integration enabled")
|
||||
@@ -36,18 +54,21 @@ class AnalyticsExecutionIntegration:
|
||||
self,
|
||||
workflow_id: str,
|
||||
execution_id: Optional[str] = None,
|
||||
total_steps: int = 0
|
||||
total_steps: int = 0,
|
||||
mode: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Called when workflow execution starts.
|
||||
Appelé au démarrage d'une exécution de workflow.
|
||||
|
||||
Args:
|
||||
workflow_id: Workflow identifier
|
||||
execution_id: Execution identifier (generated if None)
|
||||
total_steps: Total number of steps
|
||||
workflow_id: Identifiant du workflow
|
||||
execution_id: Identifiant d'exécution (généré si None)
|
||||
total_steps: Nombre total d'étapes prévues
|
||||
mode: Mode d'exécution (OBSERVATION / COACHING / SUPERVISED /
|
||||
AUTOMATIC). Propagé en contexte pour MetricsCollector.
|
||||
|
||||
Returns:
|
||||
Execution ID
|
||||
Identifiant d'exécution (celui fourni ou nouvellement généré).
|
||||
"""
|
||||
if not self.enabled or not self.analytics:
|
||||
return execution_id or str(uuid.uuid4())
|
||||
@@ -56,11 +77,21 @@ class AnalyticsExecutionIntegration:
|
||||
execution_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
# Start real-time tracking
|
||||
# Démarrage du tracking temps réel
|
||||
self.analytics.realtime_analytics.track_execution(
|
||||
execution_id=execution_id,
|
||||
workflow_id=workflow_id,
|
||||
total_steps=total_steps
|
||||
total_steps=total_steps,
|
||||
)
|
||||
|
||||
# Ouverture de l'ExecutionMetrics côté collector (état "running").
|
||||
# Cela permet à `on_execution_complete` d'appeler
|
||||
# `record_execution_complete` qui clôture proprement.
|
||||
context = {"mode": mode} if mode else {}
|
||||
self.analytics.metrics_collector.record_execution_start(
|
||||
execution_id=execution_id,
|
||||
workflow_id=workflow_id,
|
||||
context=context,
|
||||
)
|
||||
|
||||
logger.debug(f"Started tracking execution: {execution_id}")
|
||||
@@ -101,108 +132,247 @@ class AnalyticsExecutionIntegration:
|
||||
execution_id: str,
|
||||
workflow_id: str,
|
||||
node_id: str,
|
||||
action_type: str,
|
||||
started_at: datetime,
|
||||
completed_at: datetime,
|
||||
duration: float,
|
||||
*,
|
||||
duration_ms: float,
|
||||
success: bool,
|
||||
error_message: Optional[str] = None
|
||||
action_type: str = "",
|
||||
started_at: Optional[datetime] = None,
|
||||
completed_at: Optional[datetime] = None,
|
||||
error_message: Optional[str] = None,
|
||||
confidence: float = 0.0,
|
||||
target_element: str = "",
|
||||
retry_count: int = 0,
|
||||
ocr_ms: float = 0.0,
|
||||
ui_ms: float = 0.0,
|
||||
analyze_ms: float = 0.0,
|
||||
total_ms: float = 0.0,
|
||||
cache_hit: bool = False,
|
||||
degraded: bool = False,
|
||||
step_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Called when a step completes.
|
||||
Appelé à la fin d'un step.
|
||||
|
||||
Contrat normalisé (Lot A — avril 2026) : ``duration_ms`` est
|
||||
obligatoire et en millisecondes. Plus de rétrocompat silencieuse
|
||||
sur ``duration`` en secondes.
|
||||
|
||||
Args:
|
||||
execution_id: Execution identifier
|
||||
workflow_id: Workflow identifier
|
||||
node_id: Node identifier
|
||||
action_type: Type of action
|
||||
started_at: Start timestamp
|
||||
completed_at: Completion timestamp
|
||||
duration: Duration in seconds
|
||||
success: Whether step succeeded
|
||||
error_message: Error message if failed
|
||||
execution_id: Identifiant d'exécution
|
||||
workflow_id: Identifiant du workflow
|
||||
node_id: Identifiant du node
|
||||
duration_ms: Durée du step en millisecondes (obligatoire)
|
||||
success: Vrai si le step a réussi
|
||||
action_type: Type d'action (``click``, ``type``, …)
|
||||
started_at: Timestamp de début (déduit de duration_ms si None)
|
||||
completed_at: Timestamp de fin (``now()`` si None)
|
||||
error_message: Message d'erreur si ``success=False``
|
||||
confidence: Score de matching [0, 1]
|
||||
target_element: Élément ciblé (optionnel)
|
||||
retry_count: Nombre de retries
|
||||
ocr_ms: Temps OCR (C1)
|
||||
ui_ms: Temps détection UI (C1)
|
||||
analyze_ms: Temps analyse ScreenState (C1)
|
||||
total_ms: Temps total du step (C1, alias duration_ms)
|
||||
cache_hit: ScreenState depuis cache perceptuel (C1)
|
||||
degraded: Mode dégradé activé (C1)
|
||||
step_id: ID unique du step (généré si None)
|
||||
"""
|
||||
if not self.enabled or not self.analytics:
|
||||
return
|
||||
|
||||
try:
|
||||
# Record step metrics
|
||||
duration_ms_final = float(duration_ms)
|
||||
|
||||
# Normaliser les timestamps
|
||||
if completed_at is None:
|
||||
completed_at = datetime.now()
|
||||
if started_at is None:
|
||||
started_at = completed_at - timedelta(milliseconds=duration_ms_final)
|
||||
|
||||
step_metrics = StepMetrics(
|
||||
step_id=step_id or f"{execution_id}:{node_id}:{completed_at.isoformat()}",
|
||||
execution_id=execution_id,
|
||||
workflow_id=workflow_id,
|
||||
node_id=node_id,
|
||||
action_type=action_type,
|
||||
action_type=action_type or "unknown",
|
||||
target_element=target_element,
|
||||
started_at=started_at,
|
||||
completed_at=completed_at,
|
||||
duration=duration,
|
||||
success=success,
|
||||
error_message=error_message
|
||||
duration_ms=duration_ms_final,
|
||||
status="completed" if success else "failed",
|
||||
confidence_score=float(confidence),
|
||||
retry_count=retry_count,
|
||||
error_details=error_message,
|
||||
# C1 — vision-aware
|
||||
ocr_ms=float(ocr_ms or 0.0),
|
||||
ui_ms=float(ui_ms or 0.0),
|
||||
analyze_ms=float(analyze_ms or 0.0),
|
||||
total_ms=float(total_ms or duration_ms_final),
|
||||
cache_hit=bool(cache_hit),
|
||||
degraded=bool(degraded),
|
||||
)
|
||||
|
||||
self.analytics.metrics_collector.record_step(step_metrics)
|
||||
|
||||
# Update real-time tracking
|
||||
# Tracking temps réel
|
||||
try:
|
||||
self.analytics.realtime_analytics.record_step_complete(
|
||||
execution_id=execution_id,
|
||||
success=success
|
||||
success=success,
|
||||
)
|
||||
except Exception as rt_err:
|
||||
logger.debug(f"Realtime tracking skipped: {rt_err}")
|
||||
|
||||
logger.debug(f"Recorded step: {node_id} ({'success' if success else 'failed'})")
|
||||
logger.debug(
|
||||
f"Recorded step: {node_id} "
|
||||
f"({'success' if success else 'failed'}, "
|
||||
f"analyze_ms={analyze_ms:.0f}, cache_hit={cache_hit}, "
|
||||
f"degraded={degraded})"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error recording step completion: {e}")
|
||||
|
||||
def on_step_result(
    self,
    execution_id: str,
    workflow_id: str,
    step_result: Any,
) -> None:
    """Record a complete step result in one call (C1 shortcut).

    Reads the relevant attributes from *step_result* with ``getattr``
    (missing ones fall back to neutral defaults) and forwards them to
    ``on_step_complete``, so callers do not have to extract the
    vision-aware fields themselves.

    Does nothing when the integration is disabled or has no analytics
    system.

    Args:
        execution_id: Execution identifier.
        workflow_id: Workflow identifier.
        step_result: Step result object; attributes read here are
            ``node_id``, ``success``, ``message``, ``duration_ms``,
            ``match_confidence``, ``ocr_ms``, ``ui_ms``, ``analyze_ms``,
            ``total_ms``, ``cache_hit``, ``degraded`` and ``action_result``.
    """
    if not (self.enabled and self.analytics):
        return

    def _number(attr: str) -> float:
        # Missing or falsy attribute -> 0.0
        return float(getattr(step_result, attr, 0.0) or 0.0)

    def _flag(attr: str) -> bool:
        return bool(getattr(step_result, attr, False))

    # The action type lives on the nested action result, under either
    # `action_type` or `action` depending on the code path.
    kind = "unknown"
    try:
        outcome = getattr(step_result, "action_result", None)
        if outcome is not None:
            kind = (
                getattr(outcome, "action_type", None)
                or getattr(outcome, "action", None)
                or "unknown"
            )
    except Exception:
        kind = "unknown"

    # The step message is only reported as an error for failed steps.
    if getattr(step_result, "success", False):
        failure_message = None
    else:
        failure_message = getattr(step_result, "message", None)

    self.on_step_complete(
        execution_id=execution_id,
        workflow_id=workflow_id,
        node_id=getattr(step_result, "node_id", "unknown"),
        action_type=str(kind),
        success=_flag("success"),
        error_message=failure_message,
        duration_ms=_number("duration_ms"),
        confidence=_number("match_confidence"),
        ocr_ms=_number("ocr_ms"),
        ui_ms=_number("ui_ms"),
        analyze_ms=_number("analyze_ms"),
        total_ms=_number("total_ms"),
        cache_hit=_flag("cache_hit"),
        degraded=_flag("degraded"),
    )
|
||||
|
||||
def on_execution_complete(
|
||||
self,
|
||||
execution_id: str,
|
||||
workflow_id: str,
|
||||
started_at: datetime,
|
||||
completed_at: datetime,
|
||||
duration: float,
|
||||
*,
|
||||
duration_ms: float,
|
||||
status: str,
|
||||
error_message: Optional[str] = None,
|
||||
steps_total: Optional[int] = None,
|
||||
steps_completed: int = 0,
|
||||
steps_failed: int = 0
|
||||
steps_failed: int = 0,
|
||||
error_message: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Called when workflow execution completes.
|
||||
Appelé à la fin d'une exécution de workflow.
|
||||
|
||||
Contrat normalisé (Lot A — avril 2026) :
|
||||
- ``duration_ms`` en millisecondes, toujours. Plus de rétrocompat
|
||||
silencieuse sur ``duration`` en secondes.
|
||||
- ``status`` est une chaîne libre (``"completed"``, ``"failed"``,
|
||||
``"stopped"``, ``"timeout"``, …). L'appelant décide.
|
||||
- ``steps_total`` / ``steps_completed`` / ``steps_failed`` : noms
|
||||
alignés sur le dataclass ``ExecutionMetrics``. Si ``steps_total``
|
||||
n'est pas fourni, on le déduit par somme.
|
||||
|
||||
Args:
|
||||
execution_id: Execution identifier
|
||||
workflow_id: Workflow identifier
|
||||
started_at: Start timestamp
|
||||
completed_at: Completion timestamp
|
||||
duration: Duration in seconds
|
||||
status: Final status (success, failed, timeout)
|
||||
error_message: Error message if failed
|
||||
steps_completed: Number of steps completed
|
||||
steps_failed: Number of steps failed
|
||||
execution_id: Identifiant d'exécution
|
||||
workflow_id: Identifiant du workflow
|
||||
duration_ms: Durée totale en millisecondes
|
||||
status: Statut final (``"completed"`` / ``"failed"`` / ``"stopped"``)
|
||||
steps_total: Nombre total de steps exécutés (tous statuts confondus)
|
||||
steps_completed: Nombre de steps réussis
|
||||
steps_failed: Nombre de steps en échec
|
||||
error_message: Message d'erreur si ``status != "completed"``
|
||||
"""
|
||||
if not self.enabled or not self.analytics:
|
||||
return
|
||||
|
||||
# steps_total dérivé si non fourni explicitement
|
||||
if steps_total is None:
|
||||
steps_total = int(steps_completed) + int(steps_failed)
|
||||
|
||||
try:
|
||||
# Record execution metrics
|
||||
collector = self.analytics.metrics_collector
|
||||
|
||||
# record_execution_complete clôture proprement un ExecutionMetrics
|
||||
# ouvert par record_execution_start (chemin nominal via
|
||||
# on_execution_start). Si l'état n'est pas présent (tests, legacy),
|
||||
# on pousse un ExecutionMetrics synthétique directement.
|
||||
completed_at = datetime.now()
|
||||
started_at = completed_at - timedelta(milliseconds=float(duration_ms))
|
||||
|
||||
active = getattr(collector, "_active_executions", None)
|
||||
if active is not None and execution_id in active:
|
||||
collector.record_execution_complete(
|
||||
execution_id=execution_id,
|
||||
status=status,
|
||||
steps_total=int(steps_total),
|
||||
steps_completed=int(steps_completed),
|
||||
steps_failed=int(steps_failed),
|
||||
error_message=error_message,
|
||||
)
|
||||
else:
|
||||
# Fallback explicite : on construit directement un ExecutionMetrics
|
||||
# aligné sur le dataclass (duration_ms, status, steps_*).
|
||||
execution_metrics = ExecutionMetrics(
|
||||
execution_id=execution_id,
|
||||
workflow_id=workflow_id,
|
||||
started_at=started_at,
|
||||
completed_at=completed_at,
|
||||
duration=duration,
|
||||
duration_ms=float(duration_ms),
|
||||
status=status,
|
||||
steps_total=int(steps_total),
|
||||
steps_completed=int(steps_completed),
|
||||
steps_failed=int(steps_failed),
|
||||
error_message=error_message,
|
||||
steps_completed=steps_completed,
|
||||
steps_failed=steps_failed
|
||||
)
|
||||
# Le collector n'expose pas record_execution(...) : on pousse
|
||||
# dans le buffer protégé par lock pour rester cohérent.
|
||||
with collector._lock:
|
||||
collector._buffer.append(execution_metrics)
|
||||
|
||||
self.analytics.metrics_collector.record_execution(execution_metrics)
|
||||
# Flush pour garantir la persistance immédiate
|
||||
collector.flush()
|
||||
|
||||
# Flush to ensure persistence
|
||||
self.analytics.metrics_collector.flush()
|
||||
|
||||
# Complete real-time tracking
|
||||
# Clôture du tracking temps réel
|
||||
self.analytics.realtime_analytics.complete_execution(
|
||||
execution_id=execution_id,
|
||||
status=status
|
||||
status=status,
|
||||
)
|
||||
|
||||
logger.info(f"Recorded execution: {execution_id} ({status})")
|
||||
@@ -216,39 +386,54 @@ class AnalyticsExecutionIntegration:
|
||||
node_id: str,
|
||||
strategy: str,
|
||||
success: bool,
|
||||
duration: float
|
||||
duration_ms: float,
|
||||
) -> None:
|
||||
"""
|
||||
Called when self-healing attempts recovery.
|
||||
Appelé quand le self-healing tente une récupération.
|
||||
|
||||
Contrat normalisé (Lot A — avril 2026) : ``duration_ms`` en
|
||||
millisecondes, cohérent avec ``on_execution_complete`` et
|
||||
``on_step_complete``. Le StepMetrics construit respecte strictement
|
||||
le dataclass (``status``, ``duration_ms``, ``error_details``,
|
||||
``confidence_score``, ``target_element``, ``step_id``).
|
||||
|
||||
Args:
|
||||
execution_id: Execution identifier
|
||||
workflow_id: Workflow identifier
|
||||
node_id: Node identifier
|
||||
strategy: Recovery strategy used
|
||||
success: Whether recovery succeeded
|
||||
duration: Recovery duration
|
||||
execution_id: Identifiant d'exécution
|
||||
workflow_id: Identifiant du workflow
|
||||
node_id: Node où la récupération est tentée
|
||||
strategy: Stratégie de récupération employée
|
||||
success: Vrai si la récupération a réussi
|
||||
duration_ms: Durée de la tentative en millisecondes
|
||||
"""
|
||||
if not self.enabled or not self.analytics:
|
||||
return
|
||||
|
||||
try:
|
||||
# Record as a special step metric
|
||||
now = datetime.now()
|
||||
started_at = now - timedelta(milliseconds=float(duration_ms))
|
||||
|
||||
recovery_metrics = StepMetrics(
|
||||
step_id=f"{execution_id}:{node_id}:recovery:{now.isoformat()}",
|
||||
execution_id=execution_id,
|
||||
workflow_id=workflow_id,
|
||||
node_id=f"{node_id}_recovery",
|
||||
action_type=f"recovery_{strategy}",
|
||||
started_at=datetime.now(),
|
||||
completed_at=datetime.now(),
|
||||
duration=duration,
|
||||
success=success,
|
||||
error_message=None if success else f"Recovery failed: {strategy}"
|
||||
target_element="",
|
||||
started_at=started_at,
|
||||
completed_at=now,
|
||||
duration_ms=float(duration_ms),
|
||||
status="completed" if success else "failed",
|
||||
confidence_score=0.0,
|
||||
retry_count=0,
|
||||
error_details=None if success else f"Recovery failed: {strategy}",
|
||||
)
|
||||
|
||||
self.analytics.metrics_collector.record_step(recovery_metrics)
|
||||
|
||||
logger.debug(f"Recorded recovery: {strategy} ({'success' if success else 'failed'})")
|
||||
logger.debug(
|
||||
f"Recorded recovery: {strategy} "
|
||||
f"({'success' if success else 'failed'})"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error recording recovery attempt: {e}")
|
||||
|
||||
|
||||
643
core/analytics/process_mining_bridge.py
Normal file
643
core/analytics/process_mining_bridge.py
Normal file
@@ -0,0 +1,643 @@
|
||||
"""
|
||||
Bridge entre les workflows Lea (core) et PM4Py pour le process mining.
|
||||
Genere des diagrammes BPMN et KPIs depuis les traces Shadow.
|
||||
|
||||
Usage:
|
||||
from core.analytics.process_mining_bridge import (
|
||||
sessions_to_event_log,
|
||||
workflow_to_event_log,
|
||||
discover_bpmn,
|
||||
compute_kpis,
|
||||
)
|
||||
|
||||
# Depuis des sessions JSONL brutes
|
||||
df = sessions_to_event_log(sessions_data)
|
||||
result = discover_bpmn(df, output_dir="data/analytics/bpmn")
|
||||
kpis = compute_kpis(df)
|
||||
|
||||
# Depuis un workflow core (dict JSON)
|
||||
df = workflow_to_event_log(workflow_dict)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---- Import conditionnel PM4Py -----------------------------------------
|
||||
|
||||
try:
|
||||
import pm4py
|
||||
PM4PY_AVAILABLE = True
|
||||
except ImportError:
|
||||
PM4PY_AVAILABLE = False
|
||||
logger.warning("pm4py non installe -- le process mining est desactive")
|
||||
|
||||
|
||||
def _sanitize_label(label: str) -> str:
|
||||
"""
|
||||
Supprime les caracteres de controle (0x00-0x1F sauf tab/newline)
|
||||
qui sont invalides en XML et font planter PM4Py.
|
||||
"""
|
||||
return "".join(
|
||||
c if c in ("\t", "\n", "\r") or ord(c) >= 0x20 else f"<0x{ord(c):02x}>"
|
||||
for c in label
|
||||
)
|
||||
|
||||
|
||||
# ---- Types d'evenements a ignorer (bruit) --------------------------------
|
||||
|
||||
_NOISE_EVENT_TYPES = frozenset({
|
||||
"heartbeat",
|
||||
"action_result",
|
||||
"screenshot",
|
||||
})
|
||||
|
||||
# Types d'evenements significatifs pour le process mining
|
||||
_RELEVANT_EVENT_TYPES = frozenset({
|
||||
"mouse_click",
|
||||
"text_input",
|
||||
"key_press",
|
||||
"key_combo",
|
||||
"window_focus_change",
|
||||
})
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Conversion sessions JSONL -> event log PM4Py
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def _build_activity_label(event: dict) -> Optional[str]:
|
||||
"""
|
||||
Construit un label d'activite lisible depuis un event JSONL brut.
|
||||
|
||||
Regles :
|
||||
- mouse_click -> "Clic - <app_name> (<window_title tronque>)"
|
||||
- text_input -> "Saisie '<text>' - <app_name>"
|
||||
- key_press -> "Touche <key> - <app_name>"
|
||||
- key_combo -> "Raccourci <keys> - <app_name>"
|
||||
- window_focus_change -> "Fenetre <to.title> (<to.app_name>)"
|
||||
|
||||
Tous les labels sont sanitises pour supprimer les caracteres de controle
|
||||
(ex: \\x13 pour Ctrl+S) qui sont invalides en XML/BPMN.
|
||||
"""
|
||||
evt = event.get("event", event)
|
||||
etype = evt.get("type", "")
|
||||
|
||||
if etype in _NOISE_EVENT_TYPES:
|
||||
return None
|
||||
|
||||
# Extraction fenetre
|
||||
window = evt.get("window", {})
|
||||
app_name = window.get("app_name", "inconnu")
|
||||
win_title = window.get("title", "")
|
||||
# Tronquer le titre a 40 caracteres
|
||||
short_title = (win_title[:40] + "...") if len(win_title) > 40 else win_title
|
||||
|
||||
label: Optional[str] = None
|
||||
|
||||
if etype == "mouse_click":
|
||||
label = f"Clic - {app_name} ({short_title})"
|
||||
|
||||
elif etype == "text_input":
|
||||
text = evt.get("text", "")
|
||||
# Tronquer le texte a 20 caracteres pour rester lisible
|
||||
short_text = (text[:20] + "...") if len(text) > 20 else text
|
||||
label = f"Saisie '{short_text}' - {app_name}"
|
||||
|
||||
elif etype == "key_press":
|
||||
key = evt.get("key", "?")
|
||||
label = f"Touche {key} - {app_name}"
|
||||
|
||||
elif etype == "key_combo":
|
||||
keys = evt.get("keys", [])
|
||||
combo = "+".join(str(k) for k in keys)
|
||||
label = f"Raccourci {combo} - {app_name}"
|
||||
|
||||
elif etype == "window_focus_change":
|
||||
to_info = evt.get("to", {})
|
||||
if not to_info:
|
||||
return None
|
||||
to_title = to_info.get("title", "?")
|
||||
to_app = to_info.get("app_name", "?")
|
||||
label = f"Fenetre {to_title} ({to_app})"
|
||||
|
||||
else:
|
||||
# Types non reconnus : label generique
|
||||
label = f"{etype} - {app_name}"
|
||||
|
||||
return _sanitize_label(label) if label else None
|
||||
|
||||
|
||||
def _extract_timestamp(event: dict) -> Optional[float]:
|
||||
"""Extrait le timestamp unix depuis un event JSONL."""
|
||||
# Le timestamp peut etre au niveau racine ou dans event.timestamp
|
||||
evt = event.get("event", event)
|
||||
ts = evt.get("timestamp") or event.get("timestamp")
|
||||
if ts is not None:
|
||||
return float(ts)
|
||||
# Fallback sur le champ 't' (format simplifie)
|
||||
t = evt.get("t") or event.get("t")
|
||||
if t is not None:
|
||||
return float(t)
|
||||
return None
|
||||
|
||||
|
||||
def sessions_to_event_log(
    sessions_data: List[dict],
    deduplicate_windows: bool = True,
) -> pd.DataFrame:
    """
    Convert raw session traces (parsed JSONL events) into a PM4Py event log.

    Each event that yields a label and a timestamp becomes one row:
        - ``case:concept:name`` = session id (``"unknown"`` when absent)
        - ``concept:name``      = activity label from ``_build_activity_label``
        - ``time:timestamp``    = UTC timestamp
        - ``event_type``        = raw event type
        - ``app_name``          = application name of the event window

    Args:
        sessions_data: List of dicts, one per parsed JSONL line.
        deduplicate_windows: When True, drop consecutive
            ``window_focus_change`` events that produce the same label.

    Returns:
        DataFrame sorted by case then timestamp. It has the five columns
        above even when no event qualifies.
    """
    columns = [
        "case:concept:name",
        "concept:name",
        "time:timestamp",
        "event_type",
        "app_name",
    ]

    # Group by session so ordering and de-duplication are done per case.
    sessions: Dict[str, List[dict]] = {}
    for event in sessions_data:
        sessions.setdefault(event.get("session_id", "unknown"), []).append(event)

    rows: List[Dict[str, Any]] = []
    for sid, events in sessions.items():
        # Events without a timestamp sort first; they are skipped below.
        events.sort(key=lambda e: _extract_timestamp(e) or 0.0)
        last_window_label: Optional[str] = None

        for event in events:
            label = _build_activity_label(event)
            if label is None:
                continue

            ts = _extract_timestamp(event)
            if ts is None:
                continue

            nested = event.get("event")
            evt = nested if isinstance(nested, dict) else event
            event_type = evt.get("type", "")

            # Only back-to-back focus changes are de-duplicated: any other
            # event in between resets the memory.
            if deduplicate_windows and event_type == "window_focus_change":
                if label == last_window_label:
                    continue
                last_window_label = label
            else:
                last_window_label = None

            # A null "window" is treated as absent instead of crashing.
            window = evt.get("window")
            app_name = window.get("app_name", "") if isinstance(window, dict) else ""

            rows.append({
                "case:concept:name": sid,
                "concept:name": label,
                "time:timestamp": pd.Timestamp(
                    datetime.fromtimestamp(ts, tz=timezone.utc)
                ),
                "event_type": event_type,
                "app_name": app_name,
            })

    if not rows:
        logger.warning("Aucun evenement pertinent trouve dans les sessions")
        return pd.DataFrame(columns=columns)

    df = pd.DataFrame(rows)
    df = df.sort_values(["case:concept:name", "time:timestamp"]).reset_index(drop=True)
    logger.info(
        "Event log cree : %d evenements, %d sessions, %d activites distinctes",
        len(df),
        df["case:concept:name"].nunique(),
        df["concept:name"].nunique(),
    )
    return df
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Conversion workflow core (dict JSON) -> event log PM4Py
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def workflow_to_event_log(workflow_dict: dict) -> pd.DataFrame:
    """
    Convert a core workflow (JSON dict) into a PM4Py-style DataFrame.

    Nodes and edges are walked depth-first; each path from a start node to
    a node without outgoing edges (or to the point where a cycle closes)
    becomes one case.

    Mapping:
        - ``case:concept:name`` = ``"<workflow_id>_path_<index>"``
        - ``concept:name``      = ``node["name"]``, falling back to the node id
        - ``time:timestamp``    = ``created_at`` plus one second per step

    Start nodes are the ``entry_nodes`` that exist in ``nodes``. When they
    yield no path, up to three nodes without incoming edges are tried.

    Args:
        workflow_dict: Workflow with ``nodes`` (each carrying ``node_id``),
            ``edges`` (``from_node``/``to_node`` or
            ``source_node``/``target_node``) and optionally ``workflow_id``,
            ``entry_nodes`` and an ISO-format ``created_at``.

    Returns:
        DataFrame with the three columns above. The columns are present even
        when the workflow is empty or no path can be built.
    """
    columns = ["case:concept:name", "concept:name", "time:timestamp"]

    wf_id = workflow_dict.get("workflow_id", "wf_unknown")
    nodes = {n["node_id"]: n for n in workflow_dict.get("nodes", [])}
    edges = workflow_dict.get("edges", [])
    entry_nodes = workflow_dict.get("entry_nodes", [])
    created_at = workflow_dict.get("created_at", datetime.now(timezone.utc).isoformat())

    if not nodes or not edges:
        logger.warning("Workflow vide ou sans edges : %s", wf_id)
        return pd.DataFrame(columns=columns)

    # Adjacency list: source node id -> outgoing edges.
    adjacency: Dict[str, List[dict]] = {}
    for edge in edges:
        from_node = edge.get("from_node") or edge.get("source_node", "")
        adjacency.setdefault(from_node, []).append(edge)

    # Caps that keep the path enumeration bounded on branching graphs.
    MAX_PATHS = 100
    MAX_FALLBACK_ROOTS = 3
    paths: List[List[str]] = []

    def _dfs(current: str, path: List[str], visited: set) -> None:
        if len(paths) >= MAX_PATHS:
            return
        if current in visited:
            # Cycle closed: keep the path walked so far.
            paths.append(path[:])
            return
        visited.add(current)
        path.append(current)

        outgoing = adjacency.get(current, [])
        if not outgoing:
            # No outgoing edge: this is an end node.
            paths.append(path[:])
        else:
            for edge in outgoing:
                to_node = edge.get("to_node") or edge.get("target_node", "")
                if to_node:
                    _dfs(to_node, path, visited)

        # Backtrack so sibling branches start from a clean state.
        path.pop()
        visited.discard(current)

    for entry in entry_nodes:
        if entry in nodes:
            _dfs(entry, [], set())

    # Fallback: start from nodes that no edge points to.
    if not paths:
        target_nodes = {
            edge.get("to_node") or edge.get("target_node", "") for edge in edges
        }
        root_nodes = [nid for nid in nodes if nid not in target_nodes]
        for root in root_nodes[:MAX_FALLBACK_ROOTS]:
            _dfs(root, [], set())

    # NOTE(review): a naive ``created_at`` gives naive timestamps, while the
    # fallback below is UTC-aware - confirm what downstream code expects.
    try:
        base_time = pd.Timestamp(datetime.fromisoformat(created_at))
    except (ValueError, TypeError):
        base_time = pd.Timestamp(datetime.now(timezone.utc))

    rows: List[Dict[str, Any]] = []
    for i, path in enumerate(paths):
        case_id = f"{wf_id}_path_{i}"
        for step_idx, node_id in enumerate(path):
            node = nodes.get(node_id, {})
            rows.append({
                "case:concept:name": case_id,
                "concept:name": node.get("name", node_id),
                # Synthetic spacing: one second between consecutive steps.
                "time:timestamp": base_time + pd.Timedelta(seconds=step_idx),
            })

    # Passing ``columns`` keeps the schema stable when ``rows`` is empty.
    df = pd.DataFrame(rows, columns=columns)
    if not df.empty:
        df = df.sort_values(["case:concept:name", "time:timestamp"]).reset_index(drop=True)
    logger.info(
        "Event log depuis workflow : %d evenements, %d chemins",
        len(df), len(paths),
    )
    return df
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Decouverte BPMN
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def discover_bpmn(
    event_log_df: pd.DataFrame,
    output_dir: str = "data/analytics/bpmn",
    name: str = "process",
) -> dict:
    """
    Discover a BPMN model from an event log and export it with its diagrams.

    The model comes from ``pm4py.discover_bpmn_inductive``. The function
    then writes the BPMN XML and tries to render three PNG files (BPMN,
    directly-follows graph, Petri net). A rendering failure is logged and
    reported as ``None`` in the result; it does not abort the call.

    Args:
        event_log_df: Event log in PM4Py DataFrame format.
        output_dir: Directory for generated files; created when missing.
        name: Prefix of the generated file names.

    Returns:
        {
            'bpmn_xml_path': str,
            'bpmn_image_path': str or None,
            'petri_net_image_path': str or None,
            'dfg_image_path': str or None,
            'stats': {'activities': int, 'variants': int, 'cases': int},
        }

    Raises:
        ImportError: If pm4py is not installed.
        ValueError: If *event_log_df* is empty.
    """
    if not PM4PY_AVAILABLE:
        raise ImportError("pm4py n'est pas installe. Installez-le : pip install pm4py")

    if event_log_df.empty:
        raise ValueError("Event log vide, impossible de decouvrir un BPMN")

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    bpmn_model = pm4py.discover_bpmn_inductive(event_log_df)

    # ---- BPMN XML ----
    bpmn_xml_path = str(out / f"{name}.bpmn")
    try:
        pm4py.write_bpmn(bpmn_model, bpmn_xml_path)
    except Exception as e:
        # Per the original author, the PM4Py layout step can fail on labels
        # with special characters; fall back to the internal exporter.
        logger.warning("Layout BPMN echoue (%s), export sans layout", e)
        from pm4py.objects.bpmn.exporter import exporter as bpmn_exporter
        bpmn_exporter.apply(bpmn_model, bpmn_xml_path)
    logger.info("BPMN XML exporte : %s", bpmn_xml_path)

    # ---- BPMN image (PNG), large canvas for readability ----
    # graphviz ``render`` appends the format extension itself, so it gets the
    # extension-less base. Building the base explicitly avoids the old
    # ``str.replace(".png", "")``, which also altered any ".png" found
    # elsewhere in the path.
    bpmn_render_base = out / f"{name}_bpmn"
    bpmn_image_path: Optional[str] = f"{bpmn_render_base}.png"
    try:
        from pm4py.visualization.bpmn import visualizer as bpmn_vis
        gviz = bpmn_vis.apply(bpmn_model, parameters={
            "rankdir": "TB",
            "font_size": "12",
        })
        gviz.graph_attr["dpi"] = "150"
        gviz.graph_attr["size"] = "40,20!"
        gviz.graph_attr["rankdir"] = "TB"
        gviz.render(filename=str(bpmn_render_base), format="png", cleanup=True)
        logger.info("BPMN PNG exporte : %s", bpmn_image_path)
    except Exception as e:
        logger.warning("BPMN image fallback : %s", e)
        try:
            pm4py.save_vis_bpmn(bpmn_model, bpmn_image_path)
        except Exception as fallback_error:
            logger.warning("Export image BPMN impossible : %s", fallback_error)
            bpmn_image_path = None

    # ---- DFG (directly-follows graph) image, large canvas ----
    dfg_render_base = out / f"{name}_dfg"
    dfg_image_path: Optional[str] = f"{dfg_render_base}.png"
    try:
        from pm4py.visualization.dfg import visualizer as dfg_vis
        dfg, sa, ea = pm4py.discover_dfg(event_log_df)
        # ``activities_count`` takes per-activity occurrence counts. The
        # previous code passed the start-activity counts here.
        activities_count = event_log_df["concept:name"].value_counts().to_dict()
        gviz = dfg_vis.apply(dfg, activities_count=activities_count, parameters={
            "start_activities": sa,
            "end_activities": ea,
            "rankdir": "TB",
            "font_size": "11",
        })
        gviz.graph_attr["dpi"] = "150"
        gviz.graph_attr["size"] = "40,20!"
        gviz.graph_attr["rankdir"] = "TB"
        gviz.render(filename=str(dfg_render_base), format="png", cleanup=True)
        logger.info("DFG PNG exporte : %s", dfg_image_path)
    except Exception as e:
        logger.warning("DFG image fallback : %s", e)
        try:
            pm4py.save_vis_dfg(*pm4py.discover_dfg(event_log_df), file_path=dfg_image_path)
        except Exception as fallback_error:
            logger.warning("Export image DFG impossible : %s", fallback_error)
            dfg_image_path = None

    # ---- Petri net (alternative view, also via Inductive Miner) ----
    petri_image_path: Optional[str] = str(out / f"{name}_petri.png")
    try:
        net, im, fm = pm4py.discover_petri_net_inductive(event_log_df)
        pm4py.save_vis_petri_net(net, im, fm, file_path=petri_image_path)
        logger.info("Petri net PNG exporte : %s", petri_image_path)
    except Exception as e:
        logger.warning("Impossible de generer le Petri net : %s", e)
        petri_image_path = None

    # ---- Basic statistics ----
    variants = pm4py.get_variants(event_log_df)
    n_cases = event_log_df["case:concept:name"].nunique()
    n_activities = event_log_df["concept:name"].nunique()

    result = {
        "bpmn_xml_path": bpmn_xml_path,
        "bpmn_image_path": bpmn_image_path,
        "petri_net_image_path": petri_image_path,
        "dfg_image_path": dfg_image_path,
        "stats": {
            "activities": n_activities,
            "variants": len(variants),
            "cases": n_cases,
        },
    }
    logger.info("Decouverte BPMN terminee : %s", result["stats"])
    return result
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# KPIs de process mining
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def compute_kpis(
    event_log_df: pd.DataFrame,
    *,
    max_gap_seconds: float = 300.0,
    top_bottlenecks: int = 3,
) -> dict:
    """
    Compute process-mining KPIs from an event log.

    Args:
        event_log_df: Event log with ``case:concept:name``, ``concept:name``
            and ``time:timestamp`` columns; ``app_name`` is optional.
        max_gap_seconds: Gaps to the next event longer than this are left
            out of per-activity durations (the original author treats them
            as pauses). Defaults to 300 s.
        top_bottlenecks: Number of slowest activities to report. Defaults
            to 3.

    Returns:
        {
            'total_cases': int,
            'total_events': int,
            'unique_activities': int,
            'variants_count': int,          # 0 when pm4py is unavailable
            'variants_top5': list,
            'avg_case_duration_seconds': float,
            'median_case_duration_seconds': float,
            'avg_events_per_case': float,
            'activity_stats': {
                '<activity_name>': {
                    'count': int,
                    'avg_duration_seconds': float,
                    'min_duration_seconds': float,
                    'max_duration_seconds': float,
                }
            },
            'bottlenecks': [...],           # slowest activities, by average
            'app_distribution': {'<app_name>': int},
        }
        An empty log yields the same keys with zero or empty values.
    """
    if event_log_df.empty:
        return {
            "total_cases": 0,
            "total_events": 0,
            "unique_activities": 0,
            "variants_count": 0,
            "variants_top5": [],
            "avg_case_duration_seconds": 0.0,
            "median_case_duration_seconds": 0.0,
            "avg_events_per_case": 0.0,
            "activity_stats": {},
            "bottlenecks": [],
            "app_distribution": {},
        }

    # Work on a copy so the caller's frame is never modified.
    df = event_log_df.copy()

    # ---- Global metrics ----
    total_cases = df["case:concept:name"].nunique()
    total_events = len(df)
    unique_activities = df["concept:name"].nunique()

    # ---- Variants (requires PM4Py) ----
    if PM4PY_AVAILABLE:
        variants = pm4py.get_variants(df)
        variants_count = len(variants)
        ranked = sorted(variants.items(), key=lambda item: item[1], reverse=True)
        variants_top5 = [
            {"variant": " -> ".join(v), "count": c}
            for v, c in ranked[:5]
        ]
    else:
        variants_count = 0
        variants_top5 = []

    # ---- Case durations (only cases with at least two events) ----
    case_times = df.groupby("case:concept:name")["time:timestamp"]
    spans = (case_times.max() - case_times.min())[case_times.size() >= 2]
    case_seconds = spans.dt.total_seconds()
    avg_case_dur = float(case_seconds.mean()) if len(case_seconds) else 0.0
    median_case_dur = float(case_seconds.median()) if len(case_seconds) else 0.0
    avg_events_per_case = total_events / total_cases if total_cases > 0 else 0.0

    # ---- Per-activity statistics ----
    # An activity's duration is the time until the next event of its case.
    df_sorted = df.sort_values(["case:concept:name", "time:timestamp"])
    df_sorted["next_timestamp"] = df_sorted.groupby("case:concept:name")[
        "time:timestamp"
    ].shift(-1)
    df_sorted["duration_to_next"] = (
        df_sorted["next_timestamp"] - df_sorted["time:timestamp"]
    ).dt.total_seconds()

    activity_stats: Dict[str, Dict[str, Any]] = {}
    for activity, grp in df_sorted.groupby("concept:name"):
        durations = grp["duration_to_next"].dropna()
        durations = durations[durations <= max_gap_seconds]
        has_data = len(durations) > 0
        activity_stats[activity] = {
            "count": len(grp),
            "avg_duration_seconds": round(float(durations.mean()), 2) if has_data else 0.0,
            "min_duration_seconds": round(float(durations.min()), 2) if has_data else 0.0,
            "max_duration_seconds": round(float(durations.max()), 2) if has_data else 0.0,
        }

    # ---- Bottlenecks: slowest activities by average duration ----
    bottlenecks = sorted(
        [
            {"activity": act, "avg_duration_seconds": s["avg_duration_seconds"]}
            for act, s in activity_stats.items()
            if s["avg_duration_seconds"] > 0
        ],
        key=lambda entry: entry["avg_duration_seconds"],
        reverse=True,
    )[:top_bottlenecks]

    # ---- Distribution per application ----
    app_distribution: Dict[str, int] = {}
    if "app_name" in df.columns:
        app_distribution = df["app_name"].value_counts().to_dict()

    return {
        "total_cases": total_cases,
        "total_events": total_events,
        "unique_activities": unique_activities,
        "variants_count": variants_count,
        "variants_top5": variants_top5,
        "avg_case_duration_seconds": round(avg_case_dur, 2),
        "median_case_duration_seconds": round(median_case_dur, 2),
        "avg_events_per_case": round(avg_events_per_case, 1),
        "activity_stats": activity_stats,
        "bottlenecks": bottlenecks,
        "app_distribution": app_distribution,
    }
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Helpers : chargement sessions JSONL
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def load_jsonl_session(jsonl_path: str) -> List[dict]:
    """
    Load a ``live_events.jsonl`` file as a list of dicts.

    Blank lines are skipped silently. Lines that are not valid JSON, or that
    decode to something other than a JSON object, are skipped with a
    warning, so the result matches the declared ``List[dict]``.

    Args:
        jsonl_path: Path of the JSONL file.

    Returns:
        The decoded events, in file order.

    Raises:
        FileNotFoundError: If *jsonl_path* does not exist.
    """
    events: List[dict] = []
    path = Path(jsonl_path)
    if not path.exists():
        raise FileNotFoundError(f"Fichier JSONL introuvable : {jsonl_path}")

    # "utf-8-sig" reads plain UTF-8 and also drops a leading BOM, which
    # would otherwise make the first line fail to parse.
    with open(path, "r", encoding="utf-8-sig") as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                logger.warning("Ligne %d invalide dans %s : %s", line_num, jsonl_path, e)
                continue
            if not isinstance(record, dict):
                # Valid JSON but not an event object (number, list, string...).
                logger.warning(
                    "Ligne %d ignoree dans %s : objet JSON attendu, recu %s",
                    line_num, jsonl_path, type(record).__name__,
                )
                continue
            events.append(record)

    logger.info("Charge %d evenements depuis %s", len(events), jsonl_path)
    return events
|
||||
|
||||
|
||||
def load_multiple_sessions(session_dirs: List[str]) -> List[dict]:
    """
    Load the events of several sessions from their directories.

    Each directory is expected to hold a ``live_events.jsonl`` file, read
    with ``load_jsonl_session``. Directories without that file are skipped
    with a warning.

    Args:
        session_dirs: Session directory paths.

    Returns:
        All events concatenated, in the order of *session_dirs*.
    """
    all_events: List[dict] = []
    for session_dir in session_dirs:
        jsonl_path = Path(session_dir) / "live_events.jsonl"
        # ``is_file`` also rejects a directory that carries that name.
        if not jsonl_path.is_file():
            logger.warning("Pas de live_events.jsonl dans %s", session_dir)
            continue
        all_events.extend(load_jsonl_session(str(jsonl_path)))
    return all_events
|
||||
60
core/analytics/screen_change_detector.py
Normal file
60
core/analytics/screen_change_detector.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Détection rapide de changement d'écran via perceptual hash (pHash).
|
||||
|
||||
Utilise imagehash pour calculer un hash perceptuel par screenshot.
|
||||
La distance de Hamming entre deux hashes indique le degré de changement :
|
||||
- < 5 : même écran (bruit, curseur déplacé)
|
||||
- 5-14 : changement mineur (scroll, popup, champ rempli)
|
||||
- >= 15 : nouvel écran (nouvelle fenêtre, navigation)
|
||||
|
||||
Performance : ~15ms par hash sur CPU pour des screenshots 2560x1600.
|
||||
"""
|
||||
|
||||
from PIL import Image
|
||||
import imagehash
|
||||
from typing import Tuple, Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ScreenChangeLevel(Enum):
|
||||
SAME = "same" # distance < 5
|
||||
MINOR = "minor" # 5 <= distance < 15
|
||||
MAJOR = "major" # distance >= 15
|
||||
|
||||
|
||||
def compute_phash(image: Image.Image, hash_size: int = 8) -> imagehash.ImageHash:
|
||||
"""Calcule le pHash d'une image PIL."""
|
||||
return imagehash.phash(image, hash_size=hash_size)
|
||||
|
||||
|
||||
def compare_screenshots(img1: Image.Image, img2: Image.Image, hash_size: int = 8) -> Tuple[int, ScreenChangeLevel]:
|
||||
"""
|
||||
Compare deux screenshots et retourne la distance + le niveau de changement.
|
||||
|
||||
Returns:
|
||||
(distance, level) — distance de Hamming et niveau de changement
|
||||
"""
|
||||
h1 = compute_phash(img1, hash_size)
|
||||
h2 = compute_phash(img2, hash_size)
|
||||
distance = h1 - h2
|
||||
|
||||
if distance < 5:
|
||||
level = ScreenChangeLevel.SAME
|
||||
elif distance < 15:
|
||||
level = ScreenChangeLevel.MINOR
|
||||
else:
|
||||
level = ScreenChangeLevel.MAJOR
|
||||
|
||||
return distance, level
|
||||
|
||||
|
||||
def compare_hashes(hash1: imagehash.ImageHash, hash2: imagehash.ImageHash) -> Tuple[int, ScreenChangeLevel]:
    """Classify the change between two pre-computed hashes.

    Returns:
        ``(distance, level)`` where ``distance`` is ``hash1 - hash2`` and
        ``level`` is SAME below 5, MINOR below 15 and MAJOR otherwise.
    """
    distance = hash1 - hash2
    # Guard clauses ordered from the largest threshold down.
    if distance >= 15:
        return distance, ScreenChangeLevel.MAJOR
    if distance >= 5:
        return distance, ScreenChangeLevel.MINOR
    return distance, ScreenChangeLevel.SAME
|
||||
@@ -42,6 +42,8 @@ class TimeSeriesStore:
|
||||
ON execution_metrics(started_at);
|
||||
|
||||
-- Step metrics table
|
||||
-- Les colonnes ocr_ms, ui_ms, analyze_ms, total_ms, cache_hit, degraded
|
||||
-- proviennent de l'instrumentation vision-aware (C1) de ExecutionLoop.
|
||||
CREATE TABLE IF NOT EXISTS step_metrics (
|
||||
step_id TEXT PRIMARY KEY,
|
||||
execution_id TEXT NOT NULL,
|
||||
@@ -56,6 +58,12 @@ class TimeSeriesStore:
|
||||
confidence_score REAL,
|
||||
retry_count INTEGER DEFAULT 0,
|
||||
error_details TEXT,
|
||||
ocr_ms REAL DEFAULT 0.0,
|
||||
ui_ms REAL DEFAULT 0.0,
|
||||
analyze_ms REAL DEFAULT 0.0,
|
||||
total_ms REAL DEFAULT 0.0,
|
||||
cache_hit INTEGER DEFAULT 0,
|
||||
degraded INTEGER DEFAULT 0,
|
||||
FOREIGN KEY (execution_id) REFERENCES execution_metrics(execution_id)
|
||||
);
|
||||
|
||||
@@ -101,12 +109,41 @@ class TimeSeriesStore:
|
||||
|
||||
logger.info(f"TimeSeriesStore initialized at {self.db_path}")
|
||||
|
||||
# Columns added after the initial schema — applied through ALTER TABLE when
# they are missing. Each entry is (column name, column DDL).
# (C1 — vision-aware instrumentation, April 2026)
_STEP_METRICS_MIGRATIONS = [
    ("ocr_ms", "REAL DEFAULT 0.0"),
    ("ui_ms", "REAL DEFAULT 0.0"),
    ("analyze_ms", "REAL DEFAULT 0.0"),
    ("total_ms", "REAL DEFAULT 0.0"),
    ("cache_hit", "INTEGER DEFAULT 0"),
    ("degraded", "INTEGER DEFAULT 0"),
]
|
||||
|
||||
def _init_database(self) -> None:
|
||||
"""Initialize database schema."""
|
||||
"""Initialize database schema and apply lightweight migrations."""
|
||||
with self._get_connection() as conn:
|
||||
conn.executescript(self.SCHEMA)
|
||||
self._migrate_step_metrics(conn)
|
||||
conn.commit()
|
||||
|
||||
def _migrate_step_metrics(self, conn: sqlite3.Connection) -> None:
    """Add the C1 columns to a pre-existing ``step_metrics`` table.

    Columns listed in ``_STEP_METRICS_MIGRATIONS`` that are absent from the
    table are added with ``ALTER TABLE``.

    Args:
        conn: Open connection on which the migration statements run.

    Raises:
        sqlite3.OperationalError: When adding a column fails for any reason
            other than the column already existing.
    """
    rows = conn.execute("PRAGMA table_info(step_metrics)").fetchall()
    existing = {row[1] for row in rows}  # index 1 holds the column name
    for column, ddl in self._STEP_METRICS_MIGRATIONS:
        if column in existing:
            continue
        try:
            conn.execute(
                f"ALTER TABLE step_metrics ADD COLUMN {column} {ddl}"
            )
        except sqlite3.OperationalError as e:
            if "duplicate column name" not in str(e).lower():
                # A real failure (locked database, malformed DDL, ...) must
                # not be mistaken for the benign race handled below.
                logger.error(f"Migration colonne {column} échouée: {e}")
                raise
            # Benign collision (column already added by another process)
            logger.debug(f"Migration colonne {column} ignorée: {e}")
        else:
            logger.info(
                f"Migration step_metrics: ajout colonne {column}"
            )
|
||||
|
||||
@contextmanager
|
||||
def _get_connection(self):
|
||||
"""Get database connection context manager."""
|
||||
@@ -164,13 +201,14 @@ class TimeSeriesStore:
|
||||
))
|
||||
|
||||
def _write_step_metric(self, conn: sqlite3.Connection, metric: StepMetrics) -> None:
|
||||
"""Write step metric."""
|
||||
"""Write step metric (inclut les champs vision-aware C1)."""
|
||||
conn.execute("""
|
||||
INSERT OR REPLACE INTO step_metrics
|
||||
(step_id, execution_id, workflow_id, node_id, action_type, target_element,
|
||||
started_at, completed_at, duration_ms, status, confidence_score,
|
||||
retry_count, error_details)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
retry_count, error_details,
|
||||
ocr_ms, ui_ms, analyze_ms, total_ms, cache_hit, degraded)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
metric.step_id,
|
||||
metric.execution_id,
|
||||
@@ -184,7 +222,13 @@ class TimeSeriesStore:
|
||||
metric.status,
|
||||
metric.confidence_score,
|
||||
metric.retry_count,
|
||||
metric.error_details
|
||||
metric.error_details,
|
||||
getattr(metric, 'ocr_ms', 0.0),
|
||||
getattr(metric, 'ui_ms', 0.0),
|
||||
getattr(metric, 'analyze_ms', 0.0),
|
||||
getattr(metric, 'total_ms', 0.0),
|
||||
1 if getattr(metric, 'cache_hit', False) else 0,
|
||||
1 if getattr(metric, 'degraded', False) else 0,
|
||||
))
|
||||
|
||||
def _write_resource_metric(self, conn: sqlite3.Connection, metric: ResourceMetrics) -> None:
|
||||
|
||||
31
core/anonymisation/__init__.py
Normal file
31
core/anonymisation/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# core/anonymisation/__init__.py
|
||||
"""Module de floutage ciblé des PII côté serveur.
|
||||
|
||||
Remplace l'ancien blur client-side (`agent_v0/agent_v1/vision/blur_sensitive.py`)
|
||||
qui floutait toutes les zones de texte claires, cassant les codes CIM, les
|
||||
montants PMSI et les boutons.
|
||||
|
||||
Stratégie :
|
||||
1. OCR (docTR) sur le screenshot → texte + bounding boxes
|
||||
2. NER (EDS-NLP si disponible, sinon regex) → détection des PII
|
||||
3. Filtrage : ne conserver que PERSON / LOCATION / PHONE / NIR / EMAIL
|
||||
4. Blur gaussien uniquement sur les bbox des PII filtrées
|
||||
|
||||
Usage :
|
||||
from core.anonymisation import blur_pii_on_image
|
||||
blurred_path = blur_pii_on_image("shot_0001_full.png")
|
||||
"""
|
||||
|
||||
from .pii_blur import (
|
||||
PIIBlurResult,
|
||||
PIIEntity,
|
||||
PIIBlurrer,
|
||||
blur_pii_on_image,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PIIBlurResult",
|
||||
"PIIEntity",
|
||||
"PIIBlurrer",
|
||||
"blur_pii_on_image",
|
||||
]
|
||||
650
core/anonymisation/pii_blur.py
Normal file
650
core/anonymisation/pii_blur.py
Normal file
@@ -0,0 +1,650 @@
|
||||
# core/anonymisation/pii_blur.py
|
||||
"""Floutage ciblé des PII côté serveur (Personal Identifiable Information).
|
||||
|
||||
Contexte
|
||||
--------
|
||||
L'ancien blur côté client (`agent_v0/agent_v1/vision/blur_sensitive.py`) était
|
||||
trop agressif : il floutait TOUTES les zones blanches avec texte, ce qui
|
||||
détruisait les codes CIM-10, les montants PMSI, les boutons et rendait les
|
||||
screenshots inutilisables pour le replay ou le grounding VLM. De plus,
|
||||
`opencv-python` n'était pas listé dans les dépendances de l'agent, donc le blur
|
||||
échouait silencieusement en production.
|
||||
|
||||
Stratégie retenue (avril 2026)
|
||||
------------------------------
|
||||
1. Agent = zéro blur → envoie les screenshots bruts via TLS.
|
||||
2. Serveur = OCR (docTR) + NER (EDS-NLP avec fallback regex).
|
||||
3. On floute UNIQUEMENT les entités :
|
||||
- PERSON → noms, prénoms
|
||||
- LOCATION → adresses, villes
|
||||
- PHONE → numéros de téléphone
|
||||
- NIR → numéro de sécurité sociale
|
||||
- EMAIL → adresses électroniques
|
||||
Et on préserve :
|
||||
- codes CIM-10 / CCAM
|
||||
- montants (1250€, 31,50 €)
|
||||
- dates (pas PII au sens RGPD santé)
|
||||
- identifiants techniques (shot_0001, session IDs…)
|
||||
4. Deux fichiers sont stockés :
|
||||
- `shot_XXXX_full.png` → version brute (accès restreint)
|
||||
- `shot_XXXX_full_blurred.png` → version pour affichage
|
||||
|
||||
Performance
|
||||
-----------
|
||||
Objectif : < 2 s par screenshot sur RTX 5070.
|
||||
docTR (db_mobilenet_v3_large + crnn_mobilenet_v3_large) : ~800 ms CPU, ~300 ms GPU.
|
||||
EDS-NLP pipeline minimal : ~100 ms pour un texte d'écran typique.
|
||||
Fallback regex : < 10 ms.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Types
|
||||
# =============================================================================
|
||||
|
||||
# Type d'entité PII reconnu. Aligné sur les labels EDS-NLP (`nlp.pipes.eds`)
|
||||
# et enrichi par nos propres patterns regex.
|
||||
PII_LABELS = frozenset({
|
||||
"PERSON", # noms de patient, médecin
|
||||
"LOCATION", # adresses, ville, code postal
|
||||
"ADDRESS", # alias de LOCATION (certains pipelines le produisent)
|
||||
"PHONE", # téléphone
|
||||
"NIR", # numéro sécu FR (15 chiffres)
|
||||
"SECURITY_NUMBER", # alias de NIR
|
||||
"EMAIL", # adresse email
|
||||
})
|
||||
|
||||
# Motifs qu'on NE DOIT PAS flouter même s'ils ressemblent à des PII :
|
||||
# - codes CIM-10 : 1 lettre + 2 chiffres + optionnellement .xx
|
||||
# - codes CCAM : 4 lettres + 3 chiffres
|
||||
# - montants (€, euros)
|
||||
# - dates format fr (dd/mm/yyyy, dd-mm-yy)
|
||||
# - identifiants techniques (ex: shot_0001, session_xxxxx)
|
||||
_RE_ICD10 = re.compile(r"\b[A-Z]\d{2}(\.\d{1,3})?\b")
|
||||
_RE_CCAM = re.compile(r"\b[A-Z]{4}\d{3}\b")
|
||||
_RE_MONEY = re.compile(r"\b\d{1,3}(?:[.,\s]\d{3})*(?:[.,]\d{1,2})?\s?€\b", re.IGNORECASE)
|
||||
_RE_DATE = re.compile(r"\b(0?[1-9]|[12]\d|3[01])[/.-](0?[1-9]|1[0-2])[/.-](\d{2}|\d{4})\b")
|
||||
_RE_TECH_ID = re.compile(r"\b(?:shot|session|sess|frame|trace|req|msg)_[\w-]+\b", re.IGNORECASE)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Entités PII
|
||||
# =============================================================================
|
||||
|
||||
@dataclass(frozen=True)
class PIIEntity:
    """A PII entity detected in a screenshot."""
    label: str                          # PERSON, LOCATION, PHONE, NIR, EMAIL
    text: str                           # Raw detected text
    bbox: Tuple[int, int, int, int]     # (x1, y1, x2, y2) in pixels
    confidence: float = 1.0             # NER score (1.0 for regex hits)
    source: str = "ner"                 # "ner" (EDS-NLP) or "regex"
|
||||
|
||||
|
||||
@dataclass
class PIIBlurResult:
    """Result of the blur pipeline."""
    raw_path: Path
    blurred_path: Path
    entities: List[PIIEntity] = field(default_factory=list)
    elapsed_ms: float = 0.0
    ocr_ms: float = 0.0
    ner_ms: float = 0.0
    blur_ms: float = 0.0
    ocr_engine: str = "doctr"
    ner_engine: str = "regex"   # or "edsnlp"

    @property
    def count(self) -> int:
        """Number of entities in ``entities``."""
        return len(self.entities)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fallback NER par regex (utilisé si EDS-NLP indisponible)
|
||||
# =============================================================================
|
||||
|
||||
# Caution: PHONE only matches a French 10-digit number (leading 0) or its
# "33"-prefixed international form, written as one leading group followed by
# four pairs of digits. Short 3-4 digit codes are ignored.
_RE_PHONE = re.compile(
    r"\b(?:(?:\+?33|0)\s?[1-9])(?:[\s.-]?\d{2}){4}\b"
)
# French social-security number: 13 digits with an optional 2-digit key,
# optional spaces between groups, "2A"/"2B" accepted in the department slot.
_RE_NIR = re.compile(
    r"\b[12]\s?\d{2}\s?(?:0[1-9]|1[0-2]|20)\s?(?:\d{2}|2A|2B)\s?\d{3}\s?\d{3}(?:\s?\d{2})?\b"
)
_RE_EMAIL = re.compile(
    r"\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b", re.IGNORECASE
)
# Person name introduced by a title: "Mme Dupont", "M. Martin", "Dr. Bernard",
# optionally followed by a second capitalised word.
# NOTE(review): the original comment also mentioned bare "Firstname Lastname",
# but the pattern only matches when one of the title prefixes is present.
# [^\S\n] (whitespace WITHOUT newline) keeps the match on a single line —
# lines are typically distinct fields in a business UI.
_RE_PERSON = re.compile(
    r"\b(?:M\.?|Mme|Mlle|Dr\.?|Pr\.?|Prof\.?)[^\S\n]+"
    r"[A-ZÉÈÀÂÎÔÛÇ][a-zéèàâîôûç\-]+"
    r"(?:[^\S\n]+[A-ZÉÈÀÂÎÔÛÇ][a-zéèàâîôûç\-]+)?"
)
# Street address: "12 rue de la Paix", "3, avenue Victor Hugo".
# Same principle: [^\S\n] is used for the separators after the street type so
# the match does not run across line breaks there.
# NOTE(review): the `[,\s]+` right after the street number can still match a
# newline — confirm whether that is intended.
_RE_ADDRESS = re.compile(
    r"\b\d{1,4}(?:[^\S\n]?(?:bis|ter|quater))?[,\s]+(?:rue|avenue|av\.?|bd|boulevard|"
    r"allée|all\.?|place|impasse|chemin|route|rte\.?|quai|cours|voie|passage)"
    r"[^\S\n]+(?:de[^\S\n]+|du[^\S\n]+|des[^\S\n]+|la[^\S\n]+|le[^\S\n]+|les[^\S\n]+|l'|de[^\S\n]+la[^\S\n]+|d')?"
    r"[A-Za-zÀ-ÿ\-' ]{2,40}",
    re.IGNORECASE,
)
|
||||
|
||||
|
||||
def _regex_find_pii(text: str) -> List[Tuple[str, int, int]]:
    """Find PII candidates in *text* with the regex fallback.

    A candidate is dropped as soon as it overlaps, even partially, a span
    matched by one of the "technical" patterns (ICD-10, CCAM, amounts, dates,
    technical identifiers).

    Returns:
        ``(label, start_offset, end_offset)`` tuples, grouped by detector in
        the order NIR, EMAIL, PHONE, PERSON, LOCATION.
    """
    # 1. Spans that must never be blurred.
    shielded: List[Tuple[int, int]] = [
        match.span()
        for pattern in (_RE_ICD10, _RE_CCAM, _RE_MONEY, _RE_DATE, _RE_TECH_ID)
        for match in pattern.finditer(text)
    ]

    def _overlaps_shielded(start: int, end: int) -> bool:
        # Non-empty intersection with any shielded span.
        return any(start < hi and end > lo for lo, hi in shielded)

    # 2. PII detectors, evaluated in a fixed order.
    detectors = (
        ("NIR", _RE_NIR),
        ("EMAIL", _RE_EMAIL),
        ("PHONE", _RE_PHONE),
        ("PERSON", _RE_PERSON),
        ("LOCATION", _RE_ADDRESS),
    )
    return [
        (label, found.start(), found.end())
        for label, pattern in detectors
        for found in pattern.finditer(text)
        if not _overlaps_shielded(found.start(), found.end())
    ]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# NER via EDS-NLP (optionnel)
|
||||
# =============================================================================
|
||||
|
||||
_edsnlp_pipeline = None
# Set after a failed load so later calls neither retry the import / pipeline
# construction nor repeat the log line for every screenshot.
_edsnlp_load_failed = False


def _get_edsnlp_pipeline():
    """Return the cached EDS-NLP pipeline, building it on first use.

    Returns:
        The pipeline, or ``None`` when EDS-NLP is not installed or could not
        be initialised — callers then rely on the regex NER only. A failure
        is remembered: subsequent calls return ``None`` immediately.
    """
    global _edsnlp_pipeline, _edsnlp_load_failed
    if _edsnlp_pipeline is not None:
        return _edsnlp_pipeline
    if _edsnlp_load_failed:
        return None

    try:
        import edsnlp  # type: ignore
    except ImportError:
        _edsnlp_load_failed = True
        logger.info(
            "EDS-NLP non installé — fallback regex utilisé pour la détection PII. "
            "Pour activer EDS-NLP : pip install edsnlp"
        )
        return None

    try:
        nlp = edsnlp.blank("eds")
        nlp.add_pipe("eds.sentences")
        nlp.add_pipe("eds.normalizer")
        # Available components depend on the installed version, so each one
        # is added on a best-effort basis.
        for pipe_name in ("eds.names", "eds.dates", "eds.addresses"):
            try:
                nlp.add_pipe(pipe_name)
            except Exception as e:  # noqa: BLE001
                logger.debug("EDS-NLP : composant %s indisponible (%s)", pipe_name, e)
        _edsnlp_pipeline = nlp
        logger.info("EDS-NLP : pipeline chargée")
        return _edsnlp_pipeline
    except Exception as e:  # noqa: BLE001
        _edsnlp_load_failed = True
        logger.warning("EDS-NLP non utilisable (%s) — fallback regex", e)
        return None
|
||||
|
||||
|
||||
def _edsnlp_find_pii(text: str, nlp) -> List[Tuple[str, int, int]]:
|
||||
"""Utilise EDS-NLP pour trouver des entités PII.
|
||||
|
||||
Les labels EDS-NLP sont mappés vers nos labels canoniques.
|
||||
"""
|
||||
try:
|
||||
doc = nlp(text)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug("EDS-NLP : échec sur texte de %d chars (%s)", len(text), e)
|
||||
return []
|
||||
|
||||
mapping = {
|
||||
"person": "PERSON",
|
||||
"name": "PERSON",
|
||||
"patient": "PERSON",
|
||||
"doctor": "PERSON",
|
||||
"location": "LOCATION",
|
||||
"address": "LOCATION",
|
||||
"city": "LOCATION",
|
||||
}
|
||||
hits: List[Tuple[str, int, int]] = []
|
||||
for ent in getattr(doc, "ents", []):
|
||||
raw_label = str(getattr(ent, "label_", "")).lower()
|
||||
mapped = mapping.get(raw_label)
|
||||
if mapped is None:
|
||||
# On accepte aussi si le label EDS-NLP est déjà l'un de nos labels
|
||||
upper = raw_label.upper()
|
||||
if upper in PII_LABELS:
|
||||
mapped = upper
|
||||
if mapped:
|
||||
hits.append((mapped, ent.start_char, ent.end_char))
|
||||
return hits
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OCR avec bounding boxes par mot (docTR)
|
||||
# =============================================================================
|
||||
|
||||
_ocr_predictor = None


def _get_ocr_predictor():
    """Return the cached docTR predictor, creating it on the first call.

    The predictor uses the mobilenet detection / recognition architectures
    and is moved to CUDA when ``torch`` reports an available device.
    """
    global _ocr_predictor
    if _ocr_predictor is None:
        from doctr.models import ocr_predictor  # type: ignore

        _ocr_predictor = ocr_predictor(
            det_arch="db_mobilenet_v3_large",
            reco_arch="crnn_mobilenet_v3_large",
            pretrained=True,
        )
        # Use the GPU when there is one; any failure here keeps the
        # predictor as it was created.
        try:
            import torch  # type: ignore

            on_gpu = torch.cuda.is_available()
            if on_gpu:
                _ocr_predictor = _ocr_predictor.cuda()
                logger.info("pii_blur : docTR chargé sur CUDA")
            else:
                logger.info("pii_blur : docTR chargé sur CPU")
        except Exception:  # noqa: BLE001
            logger.info("pii_blur : docTR chargé (device indéterminé)")
    return _ocr_predictor
|
||||
|
||||
|
||||
def _doctr_ocr(image_path: Path) -> Tuple[List[dict], int, int]:
    """Run docTR on an image and return its words with pixel bounding boxes.

    Returns:
        ``(words, width, height)`` where each word is a dict with the keys
        ``text``, ``x1``, ``y1``, ``x2``, ``y2`` and ``line`` (running index
        of the OCR line the word belongs to).
    """
    from doctr.io import DocumentFile  # type: ignore
    from PIL import Image

    predictor = _get_ocr_predictor()
    result = predictor(DocumentFile.from_images([str(image_path)]))

    # docTR geometry is normalised to 0..1; the real size is needed to scale.
    with Image.open(image_path) as img:
        width, height = img.size

    ocr_lines = (
        line
        for page in result.pages
        for block in page.blocks
        for line in block.lines
    )
    words: List[dict] = []
    for line_no, line in enumerate(ocr_lines):
        for word in line.words:
            value = word.value
            if not value or not value.strip():
                continue
            (left, top), (right, bottom) = word.geometry
            words.append({
                "text": value,
                "x1": max(0, int(left * width)),
                "y1": max(0, int(top * height)),
                "x2": min(width, int(right * width)),
                "y2": min(height, int(bottom * height)),
                "line": line_no,
            })
    return words, width, height
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pipeline principal
|
||||
# =============================================================================
|
||||
|
||||
class PIIBlurrer:
    """Reusable PII blurring pipeline: OCR, entity detection, Gaussian blur.

    Example:
        blurrer = PIIBlurrer()
        res = blurrer.blur_image("shot_0001_full.png")
        print(res.count, res.elapsed_ms)
    """

    def __init__(
        self,
        blur_kernel: Tuple[int, int] = (31, 31),
        blur_sigma: float = 15.0,
        bbox_padding: int = 2,
        use_edsnlp: bool = True,
    ) -> None:
        """Store the blur settings.

        Args:
            blur_kernel: Gaussian kernel size passed to the blur step.
            blur_sigma: Gaussian sigma passed to the blur step.
            bbox_padding: Pixels added around each entity bounding box.
            use_edsnlp: Try EDS-NLP in addition to the regex detectors.
        """
        self._blur_kernel = blur_kernel
        self._blur_sigma = blur_sigma
        self._bbox_padding = bbox_padding
        self._use_edsnlp = use_edsnlp

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------
    def blur_image(
        self,
        input_path: Union[str, Path],
        output_path: Optional[Union[str, Path]] = None,
    ) -> PIIBlurResult:
        """Blur the detected PII and write the blurred version to disk.

        Args:
            input_path: Path of the raw screenshot (PNG/JPG).
            output_path: Destination path. Defaults to
                ``<stem>_blurred<suffix>`` next to the input (``.png`` when
                the input has no suffix).

        Returns:
            PIIBlurResult with the timings and the list of blurred entities.

        Raises:
            FileNotFoundError: When ``input_path`` is not an existing file.
        """
        input_path = Path(input_path)
        if not input_path.is_file():
            raise FileNotFoundError(f"Screenshot introuvable : {input_path}")

        if output_path is None:
            output_path = input_path.with_name(
                f"{input_path.stem}_blurred{input_path.suffix or '.png'}"
            )
        else:
            output_path = Path(output_path)

        t_start = time.perf_counter()

        # 1. OCR
        t_ocr = time.perf_counter()
        try:
            words, W, H = _doctr_ocr(input_path)
        except Exception as e:  # noqa: BLE001
            logger.warning("pii_blur : OCR docTR échoué (%s) — pas de blur appliqué", e)
            # NOTE(review): fail-open — the "blurred" file is then an
            # unredacted copy of the original. Confirm this is acceptable
            # for the consumers of the blurred version.
            _copy_file(input_path, output_path)
            return PIIBlurResult(
                raw_path=input_path,
                blurred_path=output_path,
                entities=[],
                elapsed_ms=(time.perf_counter() - t_start) * 1000,
            )
        ocr_ms = (time.perf_counter() - t_ocr) * 1000

        if not words:
            _copy_file(input_path, output_path)
            return PIIBlurResult(
                raw_path=input_path,
                blurred_path=output_path,
                entities=[],
                elapsed_ms=(time.perf_counter() - t_start) * 1000,
                ocr_ms=ocr_ms,
            )

        # 2. Rebuild the text line by line, keeping the (char offset → word)
        #    mapping needed to locate the bounding boxes of the entities.
        text, char_to_word = _build_text_with_map(words)

        # 3. Entity detection: EDS-NLP when available, always completed by
        #    the regexes (EDS-NLP does not cover e-mail, NIR, French phones).
        t_ner = time.perf_counter()
        ner_engine = "regex"
        ner_spans: List[Tuple[str, int, int]] = []
        if self._use_edsnlp:
            nlp = _get_edsnlp_pipeline()
            if nlp is not None:
                ner_spans = list(_edsnlp_find_pii(text, nlp))
                ner_engine = "edsnlp"
        regex_spans = _regex_find_pii(text)
        ner_ms = (time.perf_counter() - t_ner) * 1000

        # Deduplicate and normalise (EDS-NLP spans first, as before).
        entities_spans = _merge_spans([*ner_spans, *regex_spans])

        # 4. (label, start, end) → PIIEntity(label, text, pixel bbox)
        pii_entities: List[PIIEntity] = []
        for label, start, end in entities_spans:
            if label not in PII_LABELS:
                continue
            bbox = _spans_to_bbox(start, end, char_to_word, words, self._bbox_padding, W, H)
            if bbox is None:
                continue
            # An entity is attributed to the NER only when an EDS-NLP span
            # contributed to it; regex-only hits are reported as "regex"
            # even when EDS-NLP is loaded.
            from_ner = any(
                start < n_end and end > n_start for _, n_start, n_end in ner_spans
            )
            pii_entities.append(PIIEntity(
                label=label,
                text=text[start:end],
                bbox=bbox,
                confidence=1.0,
                source=("ner" if from_ner else "regex"),
            ))

        # 5. Gaussian blur on the bounding boxes
        t_blur = time.perf_counter()
        _apply_blur(input_path, output_path, pii_entities,
                    kernel=self._blur_kernel, sigma=self._blur_sigma)
        blur_ms = (time.perf_counter() - t_blur) * 1000

        elapsed_ms = (time.perf_counter() - t_start) * 1000
        if pii_entities:
            logger.info(
                "pii_blur : %d PII floutés sur %s (%.0fms : ocr=%.0f ner=%.0f blur=%.0f, ner=%s)",
                len(pii_entities), input_path.name, elapsed_ms,
                ocr_ms, ner_ms, blur_ms, ner_engine,
            )
        else:
            logger.debug(
                "pii_blur : aucune PII détectée dans %s (%.0fms)",
                input_path.name, elapsed_ms,
            )

        return PIIBlurResult(
            raw_path=input_path,
            blurred_path=output_path,
            entities=pii_entities,
            elapsed_ms=elapsed_ms,
            ocr_ms=ocr_ms,
            ner_ms=ner_ms,
            blur_ms=blur_ms,
            ner_engine=ner_engine,
        )
|
||||
|
||||
|
||||
# Lazily created shared instance
_default_blurrer: Optional[PIIBlurrer] = None


def blur_pii_on_image(
    input_path: Union[str, Path],
    output_path: Optional[Union[str, Path]] = None,
) -> PIIBlurResult:
    """Functional helper: blur *input_path* with the shared ``PIIBlurrer``.

    The shared instance is created with default settings on the first call
    and reused afterwards.
    """
    global _default_blurrer
    blurrer = _default_blurrer
    if blurrer is None:
        blurrer = _default_blurrer = PIIBlurrer()
    return blurrer.blur_image(input_path, output_path)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helpers internes
|
||||
# =============================================================================
|
||||
|
||||
def _copy_file(src: Path, dst: Path) -> None:
|
||||
"""Copie bytewise (utilisé quand aucun PII n'est détecté / OCR KO)."""
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(src, "rb") as f_in, open(dst, "wb") as f_out:
|
||||
f_out.write(f_in.read())
|
||||
|
||||
|
||||
def _build_text_with_map(words: Sequence[dict]) -> Tuple[str, List[int]]:
|
||||
"""Concatène les mots en texte + mappe chaque caractère vers son index de mot.
|
||||
|
||||
Quand deux mots consécutifs appartiennent à des lignes différentes (champ
|
||||
`line` dans le dict), on insère un `\n` au lieu d'un espace. Cela empêche
|
||||
les regex gloutons (PERSON, LOCATION…) de matcher à travers des lignes
|
||||
logiques, qui sont typiquement des champs distincts dans une UI métier.
|
||||
|
||||
Returns:
|
||||
text : str concaténé (mots séparés par un espace ou un \n)
|
||||
char_to_word : list[int] len == len(text), char_to_word[i] = index du mot
|
||||
(ou -1 pour les séparateurs).
|
||||
"""
|
||||
parts: List[str] = []
|
||||
char_to_word: List[int] = []
|
||||
prev_line: Optional[int] = None
|
||||
for i, w in enumerate(words):
|
||||
cur_line = w.get("line")
|
||||
if i > 0:
|
||||
if prev_line is not None and cur_line is not None and cur_line != prev_line:
|
||||
sep = "\n"
|
||||
else:
|
||||
sep = " "
|
||||
parts.append(sep)
|
||||
char_to_word.append(-1)
|
||||
txt = w["text"]
|
||||
parts.append(txt)
|
||||
char_to_word.extend([i] * len(txt))
|
||||
prev_line = cur_line
|
||||
return "".join(parts), char_to_word
|
||||
|
||||
|
||||
def _spans_to_bbox(
|
||||
start: int,
|
||||
end: int,
|
||||
char_to_word: Sequence[int],
|
||||
words: Sequence[dict],
|
||||
padding: int,
|
||||
image_w: int,
|
||||
image_h: int,
|
||||
) -> Optional[Tuple[int, int, int, int]]:
|
||||
"""Convertit une plage [start, end[ dans le texte en bbox englobant les mots."""
|
||||
if end <= start or start >= len(char_to_word):
|
||||
return None
|
||||
word_ids = set()
|
||||
for i in range(start, min(end, len(char_to_word))):
|
||||
wid = char_to_word[i]
|
||||
if wid >= 0:
|
||||
word_ids.add(wid)
|
||||
if not word_ids:
|
||||
return None
|
||||
xs1, ys1, xs2, ys2 = [], [], [], []
|
||||
for wid in word_ids:
|
||||
w = words[wid]
|
||||
xs1.append(w["x1"]); ys1.append(w["y1"])
|
||||
xs2.append(w["x2"]); ys2.append(w["y2"])
|
||||
x1 = max(0, min(xs1) - padding)
|
||||
y1 = max(0, min(ys1) - padding)
|
||||
x2 = min(image_w, max(xs2) + padding)
|
||||
y2 = min(image_h, max(ys2) + padding)
|
||||
if x2 <= x1 or y2 <= y1:
|
||||
return None
|
||||
return (x1, y1, x2, y2)
|
||||
|
||||
|
||||
def _merge_spans(
|
||||
spans: Sequence[Tuple[str, int, int]],
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
"""Déduplique et fusionne les plages qui se chevauchent sur un même label.
|
||||
|
||||
En cas de conflit inter-labels, on garde celui qui couvre le plus large.
|
||||
"""
|
||||
if not spans:
|
||||
return []
|
||||
# Trier par start puis par -width (le plus long d'abord pour les ties)
|
||||
sorted_spans = sorted(spans, key=lambda s: (s[1], -(s[2] - s[1])))
|
||||
merged: List[Tuple[str, int, int]] = []
|
||||
for label, s, e in sorted_spans:
|
||||
if not merged:
|
||||
merged.append((label, s, e))
|
||||
continue
|
||||
last_label, ls, le = merged[-1]
|
||||
if s < le: # chevauchement
|
||||
# On garde l'étendue fusionnée avec le label du plus large
|
||||
new_start = min(ls, s)
|
||||
new_end = max(le, e)
|
||||
new_label = last_label if (le - ls) >= (e - s) else label
|
||||
merged[-1] = (new_label, new_start, new_end)
|
||||
else:
|
||||
merged.append((label, s, e))
|
||||
return merged
|
||||
|
||||
|
||||
def _apply_blur(
    src: Path,
    dst: Path,
    entities: Sequence[PIIEntity],
    kernel: Tuple[int, int],
    sigma: float,
) -> None:
    """Blur the bounding boxes of *entities* in *src* and write the result to *dst*.

    The image is converted to RGB. OpenCV is used when it can be imported,
    otherwise PIL's ``ImageFilter.GaussianBlur``. The file is encoded as JPEG
    when *dst* ends in ``.jpg`` / ``.jpeg`` and as PNG otherwise, so the
    content always agrees with the extension.

    Args:
        src: Raw screenshot.
        dst: Output file; its parent directory is created when missing.
        entities: Entities whose ``bbox`` must be blurred (may be empty).
        kernel: Gaussian kernel size for OpenCV; forced to odd values >= 3.
        sigma: Gaussian sigma (OpenCV) — the PIL radius is derived from it.
    """
    from PIL import Image

    # Boxes without any area are ignored on both backends.
    boxes = [
        ent.bbox
        for ent in entities
        if ent.bbox[2] > ent.bbox[0] and ent.bbox[3] > ent.bbox[1]
    ]

    with Image.open(src) as opened:
        # convert() returns a fully loaded copy, so `img` stays usable once
        # the underlying file has been closed.
        img = opened.convert("RGB")

    if boxes:
        # OpenCV is preferred when available (faster).
        try:
            import cv2  # type: ignore
            import numpy as np  # type: ignore
        except ImportError:
            cv2 = None

        if cv2 is not None:
            ksize = (max(3, kernel[0] | 1), max(3, kernel[1] | 1))  # odd sizes
            bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
            for x1, y1, x2, y2 in boxes:
                roi = bgr[y1:y2, x1:x2]
                if roi.size == 0:
                    continue
                bgr[y1:y2, x1:x2] = cv2.GaussianBlur(roi, ksize, sigma)
            img = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        else:
            from PIL import ImageFilter

            radius = max(sigma / 2, 4.0)
            for x1, y1, x2, y2 in boxes:
                region = img.crop((x1, y1, x2, y2))
                if region.size[0] == 0 or region.size[1] == 0:
                    continue
                img.paste(region.filter(ImageFilter.GaussianBlur(radius=radius)), (x1, y1))

    dst.parent.mkdir(parents=True, exist_ok=True)
    image_format = "JPEG" if dst.suffix.lower() in (".jpg", ".jpeg") else "PNG"
    img.save(dst, format=image_format, optimize=True)
|
||||
0
core/cognition/__init__.py
Normal file
0
core/cognition/__init__.py
Normal file
191
core/cognition/vram_orchestrator.py
Normal file
191
core/cognition/vram_orchestrator.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
Orchestrateur VRAM — gère le chargement/déchargement des modèles selon le mode.
|
||||
|
||||
Deux modes :
|
||||
- SHADOW : streaming server + agent_chat actifs, VLM raisonnement déchargé
|
||||
- REPLAY : VLM raisonnement (qwen2.5vl:7b) chargé, services non-essentiels stoppés
|
||||
|
||||
Bascule automatique ou manuelle selon le contexte.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
|
||||
REASONING_MODEL = os.environ.get("RPA_REASONING_MODEL", "qwen2.5vl:7b")
|
||||
MIN_VRAM_FOR_REASONING = 5.0 # Go minimum pour charger le modèle de raisonnement
|
||||
|
||||
|
||||
class VRAMMode(Enum):
    """The two VRAM modes the orchestrator can be in."""

    SHADOW = "shadow"
    REPLAY = "replay"
|
||||
|
||||
|
||||
class VRAMOrchestrator:
|
||||
"""Gère la VRAM pour éviter les conflits entre modèles."""
|
||||
|
||||
def __init__(self):
    # Mode recorded by the orchestrator; None at construction.
    self._current_mode: Optional[VRAMMode] = None
    # (service name, pid) tuples appended when a process is killed
    # during the switch to replay mode.
    self._stopped_services: list = []
|
||||
|
||||
def get_free_vram_gb(self) -> float:
    """Return the free VRAM reported by ``nvidia-smi``, divided by 1024.

    ``nvidia-smi`` prints one value per GPU; only the first value is used,
    so a multi-GPU machine no longer makes the parsing fail and fall back
    to 0.0.
    NOTE(review): using the first GPU is an assumption — confirm which
    device the reasoning model is actually loaded on.

    Returns:
        The first reported value / 1024, or 0.0 when ``nvidia-smi`` is
        unavailable, times out or prints nothing parseable.
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"],
            capture_output=True, text=True, timeout=5
        )
        readings = result.stdout.split()
        return float(readings[0]) / 1024
    except Exception:
        # Best effort: any failure is reported as "no free VRAM".
        return 0.0
|
||||
|
||||
def get_used_vram_gb(self) -> float:
    """Return the used VRAM reported by ``nvidia-smi``, divided by 1024.

    ``nvidia-smi`` prints one value per GPU; only the first value is used,
    so a multi-GPU machine no longer makes the parsing fail and fall back
    to 0.0.
    NOTE(review): using the first GPU is an assumption — confirm which
    device should be monitored.

    Returns:
        The first reported value / 1024, or 0.0 when ``nvidia-smi`` is
        unavailable, times out or prints nothing parseable.
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
            capture_output=True, text=True, timeout=5
        )
        readings = result.stdout.split()
        return float(readings[0]) / 1024
    except Exception:
        # Best effort: any failure is reported as 0.0.
        return 0.0
|
||||
|
||||
def switch_to_replay(self) -> bool:
|
||||
"""Bascule en mode replay : libère la VRAM pour le VLM de raisonnement.
|
||||
|
||||
1. Stoppe les services non-essentiels (agent_chat)
|
||||
2. Redémarre Ollama pour libérer les modèles chargés
|
||||
3. Précharge le modèle de raisonnement
|
||||
"""
|
||||
if self._current_mode == VRAMMode.REPLAY:
|
||||
logger.info("Déjà en mode REPLAY")
|
||||
return True
|
||||
|
||||
logger.info("Bascule en mode REPLAY...")
|
||||
|
||||
# Stopper agent_chat si il tourne
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["pgrep", "-f", "agent_chat"],
|
||||
capture_output=True, text=True, timeout=5
|
||||
)
|
||||
pids = result.stdout.strip().split('\n')
|
||||
for pid in pids:
|
||||
if pid.strip():
|
||||
subprocess.run(["kill", pid.strip()], timeout=5)
|
||||
self._stopped_services.append(("agent_chat", pid.strip()))
|
||||
logger.info(f"agent_chat stoppé (PID {pid.strip()})")
|
||||
except Exception as e:
|
||||
logger.debug(f"Pas d'agent_chat à stopper: {e}")
|
||||
|
||||
# Redémarrer Ollama pour libérer la mémoire
|
||||
try:
|
||||
subprocess.run(["sudo", "systemctl", "restart", "ollama"],
|
||||
timeout=10, check=True)
|
||||
time.sleep(2)
|
||||
logger.info("Ollama redémarré")
|
||||
except Exception as e:
|
||||
logger.warning(f"Impossible de redémarrer Ollama: {e}")
|
||||
|
||||
# Vérifier la VRAM disponible
|
||||
free = self.get_free_vram_gb()
|
||||
logger.info(f"VRAM libre: {free:.1f} Go")
|
||||
|
||||
if free < MIN_VRAM_FOR_REASONING:
|
||||
logger.warning(f"VRAM insuffisante ({free:.1f} Go < {MIN_VRAM_FOR_REASONING} Go)")
|
||||
return False
|
||||
|
||||
# Précharger le modèle de raisonnement
|
||||
try:
|
||||
import requests
|
||||
logger.info(f"Préchargement {REASONING_MODEL}...")
|
||||
resp = requests.post(f"{OLLAMA_URL}/api/generate", json={
|
||||
"model": REASONING_MODEL,
|
||||
"prompt": "test",
|
||||
"stream": False,
|
||||
"options": {"num_predict": 1}
|
||||
}, timeout=60)
|
||||
if resp.status_code == 200:
|
||||
logger.info(f"{REASONING_MODEL} chargé en VRAM")
|
||||
free_after = self.get_free_vram_gb()
|
||||
logger.info(f"VRAM libre après chargement: {free_after:.1f} Go")
|
||||
except Exception as e:
|
||||
logger.warning(f"Préchargement échoué: {e}")
|
||||
|
||||
self._current_mode = VRAMMode.REPLAY
|
||||
return True
|
||||
|
||||
def switch_to_shadow(self) -> bool:
|
||||
"""Bascule en mode shadow : relance les services d'observation.
|
||||
|
||||
1. Redémarre Ollama (décharge le VLM de raisonnement)
|
||||
2. Relance les services stoppés
|
||||
"""
|
||||
if self._current_mode == VRAMMode.SHADOW:
|
||||
logger.info("Déjà en mode SHADOW")
|
||||
return True
|
||||
|
||||
logger.info("Bascule en mode SHADOW...")
|
||||
|
||||
# Redémarrer Ollama
|
||||
try:
|
||||
subprocess.run(["sudo", "systemctl", "restart", "ollama"],
|
||||
timeout=10, check=True)
|
||||
time.sleep(2)
|
||||
except Exception as e:
|
||||
logger.warning(f"Impossible de redémarrer Ollama: {e}")
|
||||
|
||||
# Relancer les services stoppés
|
||||
for service_name, _pid in self._stopped_services:
|
||||
try:
|
||||
if service_name == "agent_chat":
|
||||
subprocess.Popen(
|
||||
["python3", "-m", "agent_chat.app"],
|
||||
cwd="/home/dom/ai/rpa_vision_v3",
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL
|
||||
)
|
||||
logger.info(f"{service_name} relancé")
|
||||
except Exception as e:
|
||||
logger.warning(f"Impossible de relancer {service_name}: {e}")
|
||||
|
||||
self._stopped_services.clear()
|
||||
self._current_mode = VRAMMode.SHADOW
|
||||
return True
|
||||
|
||||
def ensure_reasoning_ready(self) -> bool:
|
||||
"""Vérifie que le VLM de raisonnement est prêt. Bascule si nécessaire."""
|
||||
free = self.get_free_vram_gb()
|
||||
if free >= MIN_VRAM_FOR_REASONING:
|
||||
return True
|
||||
return self.switch_to_replay()
|
||||
|
||||
@property
|
||||
def current_mode(self) -> Optional[str]:
|
||||
return self._current_mode.value if self._current_mode else None
|
||||
|
||||
def status(self) -> dict:
|
||||
return {
|
||||
"mode": self.current_mode,
|
||||
"vram_free_gb": round(self.get_free_vram_gb(), 1),
|
||||
"vram_used_gb": round(self.get_used_vram_gb(), 1),
|
||||
"reasoning_model": REASONING_MODEL,
|
||||
"stopped_services": [s[0] for s in self._stopped_services],
|
||||
}
|
||||
|
||||
|
||||
# Process-wide orchestrator instance, created lazily by get_orchestrator().
_orchestrator: Optional[VRAMOrchestrator] = None


def get_orchestrator() -> VRAMOrchestrator:
    """Return the shared VRAMOrchestrator, creating it on first use."""
    global _orchestrator
    if _orchestrator is not None:
        return _orchestrator
    _orchestrator = VRAMOrchestrator()
    return _orchestrator
|
||||
260
core/cognition/working_memory.py
Normal file
260
core/cognition/working_memory.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
Mémoire de travail de Léa — contexte cognitif pendant l'exécution.
|
||||
|
||||
Donne à Léa la conscience de "où elle en est" :
|
||||
- Quel objectif elle poursuit
|
||||
- Quel écran elle voit
|
||||
- Ce qu'elle vient de faire
|
||||
- Ce qu'elle doit faire ensuite
|
||||
- Ce qu'elle a appris en cours de route
|
||||
|
||||
Sans ça, chaque étape est indépendante — Léa est amnésique entre
|
||||
deux actions. Avec ça, elle raisonne en contexte.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class Observation:
    """What Léa observes on the screen at a given instant."""

    # Only required field; every other field has a neutral default.
    timestamp: datetime

    # Textual description of the screen.
    window_title: str = ""
    application: str = ""
    ocr_text: str = ""
    ui_pattern: Optional[str] = None
    screen_description: str = ""

    confidence: float = 0.0
|
||||
|
||||
|
||||
@dataclass
class ActionRecord:
    """One action Léa has performed."""

    # Required fields.
    timestamp: datetime
    action_type: str

    # Optional details about the action and its outcome.
    target: str = ""
    result: str = ""
    success: bool = True
    duration_ms: float = 0
|
||||
|
||||
|
||||
@dataclass
class CognitiveContext:
    """Full cognitive context: Léa's "train of thought" at a given instant.

    This is the internal notepad re-injected at every decision; the VLM
    receives it (see ``to_prompt_context``) so it can reason with knowledge
    of the current situation.
    """

    # Overall objective (what Léa is trying to accomplish).
    objective: str = ""

    # Current position in the plan.
    current_step: int = 0
    total_steps: int = 0
    current_step_description: str = ""

    # What Léa sees right now.
    current_observation: Optional[Observation] = None

    # Short-term memory: the most recent actions, capped at max_history.
    action_history: List[ActionRecord] = field(default_factory=list)
    max_history: int = 10

    # Facts learned during this session.
    learned_facts: List[str] = field(default_factory=list)

    # Plan: the steps still to be done.
    remaining_steps: List[str] = field(default_factory=list)

    # Confidence level (0.0 to 1.0) and help request state.
    confidence: float = 1.0
    needs_help: bool = False
    help_reason: str = ""

    # Session identity and timing.
    session_id: str = ""
    machine_id: str = ""
    started_at: Optional[datetime] = None
    step_started_at: Optional[datetime] = None
    # Durations in seconds of past runs, keyed by step description.
    step_durations: Dict[str, List[float]] = field(default_factory=dict)

    # What Léa should see on screen (expected vs. actual comparison).
    expected_screen: str = ""

    def record_action(self, action_type: str, target: str = "",
                      result: str = "", success: bool = True,
                      duration_ms: float = 0):
        """Append an action to the history and adjust the confidence.

        The history keeps only the newest ``max_history`` entries. A failure
        lowers the confidence by 0.2, a success raises it by 0.05; the value
        stays within [0.0, 1.0].
        """
        self.action_history.append(ActionRecord(
            timestamp=datetime.now(),
            action_type=action_type,
            target=target,
            result=result,
            success=success,
            duration_ms=duration_ms,
        ))

        # Trim from the front by an explicit count: `history[-max_history:]`
        # is the *whole* list when max_history is 0, so nothing was trimmed.
        overflow = len(self.action_history) - max(self.max_history, 0)
        if overflow > 0:
            self.action_history = self.action_history[overflow:]

        if success:
            self.confidence = min(1.0, self.confidence + 0.05)
        else:
            self.confidence = max(0.0, self.confidence - 0.2)

    def observe(self, window_title: str = "", application: str = "",
                ocr_text: str = "", ui_pattern: Optional[str] = None,
                screen_description: str = ""):
        """Replace the current observation with a freshly timestamped one."""
        self.current_observation = Observation(
            timestamp=datetime.now(),
            window_title=window_title,
            application=application,
            ocr_text=ocr_text,
            ui_pattern=ui_pattern,
            screen_description=screen_description,
        )

    def advance_step(self):
        """Move on to the next step of the plan.

        Records how long the step that just ended took, then takes the next
        description from ``remaining_steps`` when one is available.
        """
        now = datetime.now()
        if self.step_started_at:
            duration = (now - self.step_started_at).total_seconds()
            step_key = self.current_step_description or f"step_{self.current_step}"
            self.step_durations.setdefault(step_key, []).append(duration)

        self.current_step += 1
        self.step_started_at = now
        if self.remaining_steps:
            self.current_step_description = self.remaining_steps.pop(0)

    def get_step_timing(self) -> Optional[Dict[str, Any]]:
        """Return timing information for the step in progress.

        Returns:
            None when no step has been started; otherwise a dict with
            ``elapsed_seconds`` and, when a non-zero average of earlier runs
            of the same step exists, ``avg_previous`` and ``is_slow``
            (elapsed more than twice the average).
        """
        if not self.step_started_at:
            return None

        elapsed = (datetime.now() - self.step_started_at).total_seconds()
        step_key = self.current_step_description or f"step_{self.current_step}"
        history = self.step_durations.get(step_key, [])
        avg = sum(history) / len(history) if history else None

        result: Dict[str, Any] = {"elapsed_seconds": elapsed}
        # A zero average is skipped on purpose: any elapsed time would
        # otherwise be flagged as slow.
        if avg:
            result["avg_previous"] = avg
            result["is_slow"] = elapsed > avg * 2
        return result

    def set_expected_screen(self, description: str):
        """Set what Léa should see on screen for the current step."""
        self.expected_screen = description

    def check_screen_matches_expected(self) -> Optional[bool]:
        """Compare the current observation with the expected screen.

        Returns:
            None when there is nothing to compare (no expected description,
            a description without any word, or no observation); otherwise
            True when more than 30% of the expected words occur as substrings
            of the lower-cased window title + OCR text.
        """
        if not self.expected_screen or not self.current_observation:
            return None

        expected_words = self.expected_screen.lower().split()
        if not expected_words:
            # Whitespace-only description: no basis for a verdict.
            return None

        obs = self.current_observation
        obs_text = (obs.window_title + " " + obs.ocr_text).lower()
        matches = sum(1 for w in expected_words if w in obs_text)
        return matches / len(expected_words) > 0.3

    def learn(self, fact: str):
        """Record a fact learned during execution (duplicates are ignored)."""
        if fact not in self.learned_facts:
            self.learned_facts.append(fact)
            logger.info(f"Fait appris: {fact}")

    def ask_for_help(self, reason: str):
        """Flag that Léa needs help and lower the confidence by 0.3."""
        self.needs_help = True
        self.help_reason = reason
        self.confidence = max(0.0, self.confidence - 0.3)
        logger.warning(f"Léa demande de l'aide: {reason}")

    def to_prompt_context(self) -> str:
        """Build the text block injected into the VLM prompt.

        Sections are only emitted when the corresponding data is present;
        the confidence line is always emitted.
        """
        lines = []

        if self.objective:
            lines.append(f"OBJECTIF : {self.objective}")

        if self.current_step > 0:
            lines.append(f"PROGRESSION : étape {self.current_step}/{self.total_steps}")
            if self.current_step_description:
                lines.append(f"ÉTAPE EN COURS : {self.current_step_description}")

        if self.current_observation:
            obs = self.current_observation
            if obs.window_title:
                lines.append(f"FENÊTRE ACTIVE : {obs.window_title}")
            if obs.application:
                lines.append(f"APPLICATION : {obs.application}")
            if obs.ui_pattern:
                lines.append(f"DIALOGUE DÉTECTÉ : {obs.ui_pattern}")

        if self.action_history:
            lines.append("DERNIÈRES ACTIONS :")
            # Only the three most recent actions are shown.
            for a in self.action_history[-3:]:
                status = "OK" if a.success else "ÉCHEC"
                lines.append(f"  - {a.action_type} '{a.target}' → {status}")

        if self.learned_facts:
            lines.append("FAITS APPRIS :")
            for fact in self.learned_facts[-5:]:
                lines.append(f"  - {fact}")

        if self.remaining_steps:
            lines.append("PROCHAINES ÉTAPES :")
            for step in self.remaining_steps[:3]:
                lines.append(f"  - {step}")

        timing = self.get_step_timing()
        if timing:
            lines.append(f"TEMPS ÉTAPE : {timing['elapsed_seconds']:.1f}s")
            if timing.get('avg_previous'):
                lines.append(f"MOYENNE PRÉCÉDENTE : {timing['avg_previous']:.1f}s")
            if timing.get('is_slow'):
                lines.append("⚠ ÉTAPE ANORMALEMENT LENTE")

        if self.expected_screen:
            match = self.check_screen_matches_expected()
            if match is False:
                lines.append(f"⚠ ÉCRAN INATTENDU (attendu: {self.expected_screen})")
            elif match is True:
                lines.append(f"ÉCRAN CONFORME : {self.expected_screen}")

        lines.append(f"CONFIANCE : {self.confidence:.0%}")

        if self.needs_help:
            lines.append(f"BESOIN D'AIDE : {self.help_reason}")

        return "\n".join(lines)

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the context for storage/transport.

        The lists are copied so that changes to the returned dict do not
        alter this context.
        """
        obs = self.current_observation
        return {
            "objective": self.objective,
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "current_step_description": self.current_step_description,
            "confidence": self.confidence,
            "needs_help": self.needs_help,
            "help_reason": self.help_reason,
            "action_count": len(self.action_history),
            "learned_facts": list(self.learned_facts),
            "remaining_steps": list(self.remaining_steps),
            "last_observation": {
                "window_title": obs.window_title,
                "application": obs.application,
                "ui_pattern": obs.ui_pattern,
            } if obs else None,
        }
|
||||
@@ -68,7 +68,7 @@ class SystemConfig:
|
||||
clip_model: str = "ViT-B-32"
|
||||
clip_pretrained: str = "openai"
|
||||
clip_device: str = "cpu"
|
||||
vlm_model: str = "qwen3-vl:8b"
|
||||
vlm_model: str = "gemma4:latest"
|
||||
vlm_endpoint: str = "http://localhost:11434"
|
||||
owl_model: str = "google/owlv2-base-patch16-ensemble"
|
||||
owl_confidence_threshold: float = 0.1
|
||||
@@ -211,7 +211,7 @@ class ConfigurationManager:
|
||||
clip_model=os.getenv("CLIP_MODEL", "ViT-B-32"),
|
||||
clip_pretrained=os.getenv("CLIP_PRETRAINED", "openai"),
|
||||
clip_device=os.getenv("CLIP_DEVICE", "cpu"),
|
||||
vlm_model=os.getenv("VLM_MODEL", "qwen3-vl:8b"),
|
||||
vlm_model=os.getenv("RPA_VLM_MODEL", os.getenv("VLM_MODEL", "gemma4:latest")),
|
||||
vlm_endpoint=os.getenv("VLM_ENDPOINT", "http://localhost:11434"),
|
||||
owl_model=os.getenv("OWL_MODEL", "google/owlv2-base-patch16-ensemble"),
|
||||
owl_confidence_threshold=float(os.getenv("OWL_CONFIDENCE_THRESHOLD", "0.1")),
|
||||
@@ -435,7 +435,7 @@ class ModelConfig:
|
||||
clip_model: str = "ViT-B-32"
|
||||
clip_pretrained: str = "openai"
|
||||
clip_device: str = "cpu"
|
||||
vlm_model: str = "qwen3-vl:8b"
|
||||
vlm_model: str = "gemma4:latest"
|
||||
vlm_endpoint: str = "http://localhost:11434"
|
||||
owl_model: str = "google/owlv2-base-patch16-ensemble"
|
||||
owl_confidence_threshold: float = 0.1
|
||||
@@ -510,7 +510,7 @@ class FAISSConfig:
|
||||
class GPUResourceConfig:
|
||||
"""Configuration for GPU resource management - DEPRECATED: Use SystemConfig instead"""
|
||||
ollama_endpoint: str = "http://localhost:11434"
|
||||
vlm_model: str = "qwen3-vl:8b"
|
||||
vlm_model: str = "gemma4:latest"
|
||||
clip_model: str = "ViT-B-32"
|
||||
idle_timeout_seconds: int = 300
|
||||
vram_threshold_for_clip_gpu_mb: int = 1024
|
||||
@@ -599,7 +599,7 @@ UPLOADS_PATH=data/training/uploads
|
||||
CLIP_MODEL=ViT-B-32
|
||||
CLIP_PRETRAINED=openai
|
||||
CLIP_DEVICE=cpu
|
||||
VLM_MODEL=qwen3-vl:8b
|
||||
VLM_MODEL=gemma4:latest
|
||||
VLM_ENDPOINT=http://localhost:11434
|
||||
OWL_MODEL=google/owlv2-base-patch16-ensemble
|
||||
OWL_CONFIDENCE_THRESHOLD=0.1
|
||||
|
||||
@@ -25,7 +25,7 @@ class OllamaClient:
|
||||
|
||||
def __init__(self,
|
||||
endpoint: str = "http://localhost:11434",
|
||||
model: str = "qwen3-vl:8b",
|
||||
model: str = None,
|
||||
timeout: int = 180):
|
||||
"""
|
||||
Initialiser le client Ollama
|
||||
@@ -36,7 +36,12 @@ class OllamaClient:
|
||||
timeout: Timeout en secondes
|
||||
"""
|
||||
self.endpoint = endpoint.rstrip('/')
|
||||
# Résolution du modèle : paramètre explicite > config centralisée
|
||||
if model is not None:
|
||||
self.model = model
|
||||
else:
|
||||
from core.detection.vlm_config import get_vlm_model
|
||||
self.model = get_vlm_model(endpoint=self.endpoint)
|
||||
self.timeout = timeout
|
||||
self._check_connection()
|
||||
|
||||
@@ -126,7 +131,12 @@ class OllamaClient:
|
||||
messages.append(user_message)
|
||||
|
||||
# Déterminer si le modèle est un modèle thinking (qwen3)
|
||||
is_thinking_model = "qwen3" in self.model.lower()
|
||||
# Les modèles non-thinking (gemma4, qwen2.5vl) n'ont pas besoin
|
||||
# du workaround prefill et supportent le rôle system natif.
|
||||
from core.detection.vlm_config import is_thinking_model as _is_thinking
|
||||
from core.detection.vlm_config import needs_think_false as _needs_think_false
|
||||
is_thinking_model = _is_thinking(self.model)
|
||||
requires_think_false = _needs_think_false(self.model)
|
||||
|
||||
# WORKAROUND Ollama 0.18.x : think=false est ignoré par le
|
||||
# renderer qwen3-vl-thinking. On utilise un assistant prefill
|
||||
@@ -168,9 +178,9 @@ class OllamaClient:
|
||||
}
|
||||
}
|
||||
|
||||
# Garder think=false au cas où une future version d'Ollama le
|
||||
# corrige — le prefill reste le mécanisme principal
|
||||
if is_thinking_model:
|
||||
# think=false : requis pour qwen3 (prefill reste le mécanisme
|
||||
# principal) ET pour gemma4 (sinon tokens vides sur Ollama >=0.20)
|
||||
if is_thinking_model or requires_think_false:
|
||||
payload["think"] = False
|
||||
|
||||
if force_json:
|
||||
@@ -575,7 +585,7 @@ Your answer:"""
|
||||
# Fonctions utilitaires
|
||||
# ============================================================================
|
||||
|
||||
def create_ollama_client(model: str = "qwen3-vl:8b",
|
||||
def create_ollama_client(model: str = None,
|
||||
endpoint: str = "http://localhost:11434") -> OllamaClient:
|
||||
"""
|
||||
Créer un client Ollama
|
||||
|
||||
@@ -72,9 +72,9 @@ class BoundingBox:
|
||||
class DetectionConfig:
|
||||
"""Configuration de la détection UI hybride"""
|
||||
# VLM — modèle configurable via variable d'environnement RPA_VLM_MODEL
|
||||
# Production (local) : "qwen3-vl:8b" — GPU local, pas de réseau
|
||||
# Tests (cloud) : "qwen3-vl:235b-cloud" — pas de GPU, plus lent mais libère la VRAM
|
||||
vlm_model: str = os.environ.get("RPA_VLM_MODEL", "qwen3-vl:8b")
|
||||
# Par défaut : gemma4:e4b (meilleur grounding + contextualisation)
|
||||
# Fallback : qwen3-vl:8b si gemma4 non disponible
|
||||
vlm_model: str = os.environ.get("RPA_VLM_MODEL", os.environ.get("VLM_MODEL", "gemma4:e4b"))
|
||||
vlm_endpoint: str = "http://localhost:11434"
|
||||
use_vlm_classification: bool = True # Utiliser VLM pour classifier
|
||||
|
||||
@@ -865,7 +865,7 @@ JSON array: [{{"id":0,"type":"...","role":"...","text":"..."}}]"""
|
||||
# ============================================================================
|
||||
|
||||
def create_detector(
|
||||
vlm_model: str = "qwen3-vl:8b",
|
||||
vlm_model: str = None,
|
||||
confidence_threshold: float = 0.7,
|
||||
use_vlm: bool = True
|
||||
) -> UIDetector:
|
||||
@@ -873,13 +873,16 @@ def create_detector(
|
||||
Créer un détecteur avec configuration personnalisée
|
||||
|
||||
Args:
|
||||
vlm_model: Modèle VLM à utiliser
|
||||
vlm_model: Modèle VLM à utiliser (None = résolution automatique via vlm_config)
|
||||
confidence_threshold: Seuil de confiance
|
||||
use_vlm: Utiliser le VLM pour la classification
|
||||
|
||||
Returns:
|
||||
UIDetector configuré
|
||||
"""
|
||||
if vlm_model is None:
|
||||
from core.detection.vlm_config import get_vlm_model
|
||||
vlm_model = get_vlm_model()
|
||||
config = DetectionConfig(
|
||||
vlm_model=vlm_model,
|
||||
confidence_threshold=confidence_threshold,
|
||||
|
||||
194
core/detection/vlm_config.py
Normal file
194
core/detection/vlm_config.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Configuration centralisée du modèle VLM (Vision-Language Model).
|
||||
|
||||
Point unique de configuration pour le modèle VLM utilisé dans tout le pipeline.
|
||||
Gère la variable d'environnement RPA_VLM_MODEL avec fallback automatique
|
||||
si le modèle configuré n'est pas disponible dans Ollama.
|
||||
|
||||
Ordre de résolution du modèle :
|
||||
1. Variable d'env RPA_VLM_MODEL (prioritaire)
|
||||
2. Variable d'env VLM_MODEL (compatibilité)
|
||||
3. Modèle par défaut : gemma4:latest
|
||||
|
||||
Fallback automatique :
|
||||
Si le modèle choisi n'est pas trouvé dans Ollama, on essaie les
|
||||
modèles de fallback dans l'ordre (FALLBACK_VLM_MODELS).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)

# Default VLM: Gemma 4 "latest" tag (8B dense, Q4_K_M).
# Needs think=false in the request payload (otherwise empty tokens on Ollama >=0.20).
DEFAULT_VLM_MODEL: str = "gemma4:latest"

# Fallback models, tried in this order when the primary model is not available.
FALLBACK_VLM_MODELS: List[str] = ["qwen3-vl:8b", "0000/ui-tars-1.5-7b-q8_0:7b"]

# Default Ollama endpoint.
DEFAULT_OLLAMA_ENDPOINT: str = "http://localhost:11434"

# Cache of the resolved model, so Ollama is not queried on every call.
_resolved_model: Optional[str] = None
_resolved_model_checked: bool = False
|
||||
|
||||
|
||||
def get_vlm_model(
    endpoint: str = DEFAULT_OLLAMA_ENDPOINT,
    force_check: bool = False,
) -> str:
    """Return the name of the VLM to use, with automatic fallback.

    The model's availability is checked against Ollama on the first call;
    the outcome is cached and returned by later calls.

    Args:
        endpoint: URL of the Ollama API.
        force_check: Ignore the cache and resolve again.

    Returns:
        The resolved model name (e.g. "gemma4:latest"). When Ollama cannot be
        reached, or neither the configured model nor any fallback is found,
        the configured name is returned unverified.
    """
    global _resolved_model, _resolved_model_checked

    if _resolved_model_checked and not force_check:
        return _resolved_model

    # Resolution order: RPA_VLM_MODEL, then VLM_MODEL, then the default.
    configured = (
        os.environ.get("RPA_VLM_MODEL")
        or os.environ.get("VLM_MODEL")
        or DEFAULT_VLM_MODEL
    )
    available = _list_ollama_models(endpoint)

    # Unless a fallback is picked below, the configured model is used.
    chosen = configured

    if available is None:
        # Ollama unreachable: nothing can be verified.
        logger.warning(
            "Ollama non joignable (%s) — utilisation de '%s' sans vérification",
            endpoint, configured,
        )
    elif _model_available(configured, available):
        logger.info("VLM model: %s (configuré, disponible)", configured)
    else:
        logger.warning(
            "Modèle VLM '%s' non trouvé dans Ollama. Recherche d'un fallback...",
            configured,
        )
        # First available candidate, skipping the one already tested.
        fallback = next(
            (
                candidate
                for candidate in [DEFAULT_VLM_MODEL] + FALLBACK_VLM_MODELS
                if candidate != configured and _model_available(candidate, available)
            ),
            None,
        )
        if fallback is not None:
            logger.info(
                "VLM model: %s (fallback, '%s' non disponible)",
                fallback, configured,
            )
            chosen = fallback
        else:
            # No fallback either: keep the configured name regardless.
            logger.warning(
                "Aucun modèle VLM trouvé dans Ollama. "
                "Modèles disponibles : %s. Utilisation de '%s' par défaut.",
                [m for m in available if "vl" in m.lower() or "gemma" in m.lower()],
                configured,
            )

    _resolved_model = chosen
    _resolved_model_checked = True
    return _resolved_model
|
||||
|
||||
|
||||
def reset_vlm_model_cache():
    """Clear the cached model resolution.

    Useful after a configuration change or after pulling a model: the next
    call to get_vlm_model() resolves the model again.
    """
    global _resolved_model, _resolved_model_checked
    _resolved_model, _resolved_model_checked = None, False
|
||||
|
||||
|
||||
def is_thinking_model(model_name: str) -> bool:
    """Tell whether a model is a 'thinking' model (qwen3).

    Thinking models need an assistant prefill to avoid the internal
    reflection mode, which can last >180s with images.

    Args:
        model_name: Model name (e.g. "qwen3-vl:8b", "gemma4:e4b").

    Returns:
        True if the lower-cased name contains "qwen3".
    """
    lowered = model_name.lower()
    return "qwen3" in lowered
|
||||
|
||||
|
||||
def needs_think_false(model_name: str) -> bool:
    """Tell whether a model needs think=false in the request payload.

    On Ollama >=0.20, gemma4 produces empty tokens unless think is
    explicitly disabled, so the flag must be sent in the chat payload.

    Args:
        model_name: Model name (e.g. "gemma4:latest", "gemma4:e4b").

    Returns:
        True if the lower-cased name contains "gemma4".
    """
    lowered = model_name.lower()
    return "gemma4" in lowered
|
||||
|
||||
|
||||
def _list_ollama_models(endpoint: str) -> Optional[List[str]]:
    """List the models known to Ollama via ``GET {endpoint}/api/tags``.

    Args:
        endpoint: Base URL of the Ollama API; a trailing slash is tolerated.

    Returns:
        The list of model names, or None when Ollama cannot be reached or
        does not answer with a usable HTTP 200 response.
    """
    url = f"{endpoint.rstrip('/')}/api/tags"
    try:
        resp = requests.get(url, timeout=5)
        if resp.status_code != 200:
            logger.debug("Ollama %s a répondu HTTP %s", url, resp.status_code)
            return None
        return [m["name"] for m in resp.json().get("models", [])]
    except Exception as exc:
        # Deliberately broad: this is a best-effort probe, and any failure
        # (network, invalid JSON, unexpected schema) means "cannot tell".
        # The reason is logged instead of being silently dropped.
        logger.debug("Impossible de lister les modèles Ollama (%s) : %s", url, exc)
        return None
|
||||
|
||||
|
||||
def _model_available(model_name: str, available_models: List[str]) -> bool:
|
||||
"""Vérifie si un modèle est disponible dans la liste Ollama.
|
||||
|
||||
Supporte la correspondance exacte et le match sans tag de version
|
||||
(ex: "gemma4:e4b" match "gemma4:e4b" ou "gemma4:e4b-q4_0").
|
||||
"""
|
||||
# Match exact
|
||||
if model_name in available_models:
|
||||
return True
|
||||
|
||||
# Match par préfixe (sans tag) — "gemma4:e4b" match "gemma4:e4b"
|
||||
base_name = model_name.split(":")[0] if ":" in model_name else model_name
|
||||
for m in available_models:
|
||||
if m.startswith(base_name + ":"):
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -58,8 +58,18 @@ class CLIPEmbedder(EmbedderBase):
|
||||
"Install it with: pip install open-clip-torch"
|
||||
)
|
||||
|
||||
# Default to CPU to save GPU for vision models (Qwen3-VL, etc.)
|
||||
if device is None:
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
free_vram = torch.cuda.mem_get_info()[0] / 1024**3
|
||||
if free_vram > 1.5:
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
except Exception:
|
||||
device = "cpu"
|
||||
|
||||
self.model_name = model_name
|
||||
|
||||
@@ -11,7 +11,12 @@ from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import json
|
||||
import pickle
|
||||
|
||||
from core.security.signed_serializer import (
|
||||
SignatureVerificationError,
|
||||
load_signed,
|
||||
save_signed,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -500,21 +505,23 @@ class FAISSManager:
|
||||
# Sauvegarder index FAISS
|
||||
faiss.write_index(index_to_save, str(index_path))
|
||||
|
||||
# Sauvegarder métadonnées
|
||||
# Sauvegarder métadonnées (JSON signé HMAC — cf. core.security.signed_serializer)
|
||||
metadata = {
|
||||
"dimensions": self.dimensions,
|
||||
"index_type": self.index_type,
|
||||
"metric": self.metric,
|
||||
"next_id": self.next_id,
|
||||
"metadata_store": self.metadata_store,
|
||||
# Les clés dict sont des int côté Python ; on les sérialise en str
|
||||
# puis on les reconvertit au chargement. JSON n'autorise pas de
|
||||
# clés non-string.
|
||||
"metadata_store": {str(k): v for k, v in self.metadata_store.items()},
|
||||
"nlist": self.nlist,
|
||||
"nprobe": self.nprobe,
|
||||
"is_trained": self.is_trained,
|
||||
"auto_optimize": self.auto_optimize
|
||||
"auto_optimize": self.auto_optimize,
|
||||
}
|
||||
|
||||
with open(metadata_path, 'wb') as f:
|
||||
pickle.dump(metadata, f)
|
||||
save_signed(metadata_path, metadata)
|
||||
|
||||
@classmethod
|
||||
def load(cls, index_path: Path, metadata_path: Path, use_gpu: bool = False) -> 'FAISSManager':
|
||||
@@ -529,11 +536,22 @@ class FAISSManager:
|
||||
Returns:
|
||||
FAISSManager chargé
|
||||
"""
|
||||
# Charger métadonnées
|
||||
with open(metadata_path, 'rb') as f:
|
||||
metadata = pickle.load(f)
|
||||
# Charger métadonnées (JSON signé ; fallback legacy pickle avec migration).
|
||||
try:
|
||||
metadata = load_signed(metadata_path)
|
||||
except SignatureVerificationError:
|
||||
logger.error(
|
||||
"Signature HMAC invalide pour %s — refus de chargement.",
|
||||
metadata_path,
|
||||
)
|
||||
raise
|
||||
|
||||
# Créer instance
|
||||
# Reconvertir les clés int du metadata_store (JSON force des clés str).
|
||||
if isinstance(metadata.get("metadata_store"), dict):
|
||||
metadata["metadata_store"] = {
|
||||
int(k) if isinstance(k, str) and k.lstrip("-").isdigit() else k: v
|
||||
for k, v in metadata["metadata_store"].items()
|
||||
}
|
||||
manager = cls(
|
||||
dimensions=metadata["dimensions"],
|
||||
index_type=metadata["index_type"],
|
||||
|
||||
@@ -10,6 +10,7 @@ from .error_handler import ErrorHandler, ErrorType, RecoveryStrategy
|
||||
from .workflow_runner import WorkflowRunner, RunResult, RunStatus, RunnerConfig
|
||||
from .dag_executor import DAGExecutor, WorkflowStep, StepType, StepStatus, DAGExecutionResult
|
||||
from .llm_actions import LLMActionHandler
|
||||
from .observe_reason_act import ORALoop, Observation, Decision, VerificationResult, LoopResult
|
||||
|
||||
# Import tardif pour éviter import circulaire avec pipeline
|
||||
def _get_execution_loop():
|
||||
@@ -34,5 +35,11 @@ __all__ = [
|
||||
'StepStatus',
|
||||
'DAGExecutionResult',
|
||||
'LLMActionHandler',
|
||||
# ORA — boucle Observe-Raisonne-Agit avec vérification
|
||||
'ORALoop',
|
||||
'Observation',
|
||||
'Decision',
|
||||
'VerificationResult',
|
||||
'LoopResult',
|
||||
# ExecutionLoop accessible via import direct du module
|
||||
]
|
||||
|
||||
@@ -654,7 +654,8 @@ class ActionExecutor:
|
||||
if PYAUTOGUI_AVAILABLE:
|
||||
pyautogui.click(click_x, click_y)
|
||||
time.sleep(0.2)
|
||||
pyautogui.write(text, interval=0.05)
|
||||
from .input_handler import safe_type_text
|
||||
safe_type_text(text)
|
||||
else:
|
||||
logger.info(f" (Simulated click at {click_x:.0f}, {click_y:.0f})")
|
||||
logger.info(f" (Simulated typing: {text[:50]}...)")
|
||||
|
||||
@@ -525,11 +525,25 @@ class DAGExecutor:
|
||||
True/False selon le résultat de la condition
|
||||
"""
|
||||
condition = action.get("condition", "True")
|
||||
# Contexte d'évaluation sécurisé : uniquement les résultats
|
||||
# Contexte d'évaluation sécurisé : uniquement les résultats.
|
||||
# NB : on utilise un évaluateur AST restreint (pas d'eval/exec),
|
||||
# seuls literals, comparaisons, booléens et indexations sont permis.
|
||||
eval_context = {"results": dict(self._results)}
|
||||
|
||||
# Import local pour éviter une dépendance circulaire au chargement.
|
||||
from core.execution.safe_condition_evaluator import (
|
||||
UnsafeExpressionError,
|
||||
safe_eval_condition,
|
||||
)
|
||||
|
||||
try:
|
||||
result = bool(eval(condition, {"__builtins__": {}}, eval_context))
|
||||
result = bool(safe_eval_condition(condition, eval_context))
|
||||
except UnsafeExpressionError as exc:
|
||||
logger.error(
|
||||
"Condition refusée pour '%s' (expression non sûre) : %s",
|
||||
step.step_id, exc,
|
||||
)
|
||||
result = False
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Erreur d'évaluation de condition pour '%s' : %s",
|
||||
|
||||
@@ -151,6 +151,13 @@ class StepResult:
|
||||
duration_ms: float
|
||||
message: str
|
||||
screenshot_path: Optional[str] = None
|
||||
# C1 — Instrumentation vision-aware
|
||||
ocr_ms: float = 0.0 # Temps OCR du ScreenState de ce step
|
||||
ui_ms: float = 0.0 # Temps détection UI de ce step
|
||||
total_ms: float = 0.0 # Temps total (alias de duration_ms pour cohérence)
|
||||
analyze_ms: float = 0.0 # Temps total analyse ScreenState (OCR + UI + reste)
|
||||
cache_hit: bool = False # True si ScreenState vient du cache
|
||||
degraded: bool = False # True si mode dégradé activé (timeout analyse)
|
||||
|
||||
|
||||
class ExecutionLoop:
|
||||
@@ -175,7 +182,13 @@ class ExecutionLoop:
|
||||
capture_interval_ms: int = 500,
|
||||
max_no_match_retries: int = 5,
|
||||
confirmation_callback: Optional[Callable[[str, Dict], bool]] = None,
|
||||
coaching_callback: Optional[Callable[[str, Dict], "CoachingResponse"]] = None
|
||||
coaching_callback: Optional[Callable[[str, Dict], "CoachingResponse"]] = None,
|
||||
screen_analyzer: Optional[Any] = None,
|
||||
screen_state_cache: Optional[Any] = None,
|
||||
enable_ui_detection: bool = True,
|
||||
enable_ocr: bool = True,
|
||||
analyze_timeout_ms: int = 8000,
|
||||
window_info_provider: Optional[Callable[[], Optional[Dict[str, Any]]]] = None,
|
||||
):
|
||||
"""
|
||||
Initialiser la boucle d'exécution.
|
||||
@@ -188,6 +201,15 @@ class ExecutionLoop:
|
||||
max_no_match_retries: Nombre max de tentatives si pas de match
|
||||
confirmation_callback: Callback pour demander confirmation (SUPERVISED)
|
||||
coaching_callback: Callback pour décisions coaching (COACHING)
|
||||
screen_analyzer: ScreenAnalyzer pour construire un ScreenState enrichi
|
||||
(lazy init via singleton si None)
|
||||
screen_state_cache: Cache perceptuel (lazy init via singleton si None)
|
||||
enable_ui_detection: Active la détection UI (True par défaut, flag d'urgence)
|
||||
enable_ocr: Active l'OCR (True par défaut)
|
||||
analyze_timeout_ms: Timeout soft pour l'analyse d'un ScreenState.
|
||||
Au-delà, on active le mode dégradé pour les steps suivants.
|
||||
window_info_provider: Callable renvoyant un dict window_info. Si None,
|
||||
on tente `screen_capturer.get_active_window()`.
|
||||
"""
|
||||
self.pipeline = pipeline
|
||||
self.action_executor = action_executor or ActionExecutor()
|
||||
@@ -204,6 +226,27 @@ class ExecutionLoop:
|
||||
self.confirmation_callback = confirmation_callback
|
||||
self.coaching_callback = coaching_callback
|
||||
|
||||
# C1 — Vision-aware execution
|
||||
self._screen_analyzer = screen_analyzer # lazy init si None
|
||||
self._screen_state_cache = screen_state_cache # lazy init si None
|
||||
self.enable_ui_detection = enable_ui_detection
|
||||
self.enable_ocr = enable_ocr
|
||||
self.analyze_timeout_ms = analyze_timeout_ms
|
||||
self._window_info_provider = window_info_provider
|
||||
# Mode dégradé déclenché par un timeout analyse — persiste tant qu'un
|
||||
# probe n'a pas démontré la récupération (voir ci-dessous).
|
||||
self._degraded_mode = False
|
||||
# Auto-rétablissement : compteur de steps rapides consécutifs.
|
||||
# Si l'analyse tourne vite (< analyze_timeout_ms / 2) pendant
|
||||
# _fast_steps_recovery_threshold steps → on quitte le mode dégradé.
|
||||
self._successive_fast_steps = 0
|
||||
self._fast_steps_recovery_threshold = 3
|
||||
# En mode dégradé, on retente l'analyse tous les _probe_interval steps
|
||||
# pour détecter la récupération (les autres steps restent en stub pour
|
||||
# éviter de re-saturer le GPU). 10 par défaut = ~5s à 500ms/step.
|
||||
self._probe_interval = 10
|
||||
self._degraded_step_counter = 0
|
||||
|
||||
# État interne
|
||||
self.state = ExecutionState.IDLE
|
||||
self.context: Optional[ExecutionContext] = None
|
||||
@@ -464,15 +507,15 @@ class ExecutionLoop:
|
||||
})
|
||||
|
||||
# Notify Analytics about step completion
|
||||
# C1 — transmet tous les champs vision-aware (ocr_ms, ui_ms,
|
||||
# analyze_ms, cache_hit, degraded) au système analytics via
|
||||
# on_step_result qui accepte un StepResult complet.
|
||||
if self._analytics_integration and step_result:
|
||||
try:
|
||||
self._analytics_integration.on_step_complete(
|
||||
workflow_id=self.context.workflow_id,
|
||||
self._analytics_integration.on_step_result(
|
||||
execution_id=self.context.execution_id,
|
||||
step_id=step_result.node_id,
|
||||
success=step_result.success,
|
||||
duration_ms=step_result.duration_ms,
|
||||
confidence=step_result.match_confidence
|
||||
workflow_id=self.context.workflow_id,
|
||||
step_result=step_result,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Analytics step notification failed: {e}")
|
||||
@@ -505,10 +548,32 @@ class ExecutionLoop:
|
||||
self._notify_state_change(ExecutionState.STOPPED)
|
||||
|
||||
# Notify Analytics about execution completion
|
||||
# Contrat normalisé (Lot A) : duration_ms + status explicite
|
||||
# au lieu du booléen success + duration ambigu.
|
||||
if self._analytics_integration and self.context:
|
||||
try:
|
||||
success = self.state == ExecutionState.COMPLETED
|
||||
duration_ms = (datetime.now() - self.context.started_at).total_seconds() * 1000
|
||||
duration_ms = (
|
||||
datetime.now() - self.context.started_at
|
||||
).total_seconds() * 1000
|
||||
|
||||
# Mapping ExecutionState → status analytics
|
||||
if self.state == ExecutionState.COMPLETED:
|
||||
status = "completed"
|
||||
elif self.state == ExecutionState.FAILED:
|
||||
status = "failed"
|
||||
elif self.state == ExecutionState.STOPPED:
|
||||
status = "stopped"
|
||||
elif self.state == ExecutionState.PAUSED:
|
||||
# Pause non résolue à la sortie = blocage non récupéré
|
||||
status = "blocked"
|
||||
else:
|
||||
status = self.state.value
|
||||
|
||||
error_message = (
|
||||
None
|
||||
if status == "completed"
|
||||
else f"Execution ended in state: {self.state.value}"
|
||||
)
|
||||
|
||||
# Stop resource monitoring
|
||||
self._analytics_integration.stop_resource_monitoring(
|
||||
@@ -518,12 +583,12 @@ class ExecutionLoop:
|
||||
self._analytics_integration.on_execution_complete(
|
||||
workflow_id=self.context.workflow_id,
|
||||
execution_id=self.context.execution_id,
|
||||
success=success,
|
||||
duration_ms=duration_ms,
|
||||
steps_executed=self.context.steps_executed,
|
||||
steps_succeeded=self.context.steps_succeeded,
|
||||
status=status,
|
||||
steps_total=self.context.steps_executed,
|
||||
steps_completed=self.context.steps_succeeded,
|
||||
steps_failed=self.context.steps_failed,
|
||||
error_message=None if success else f"Execution ended in state: {self.state.value}"
|
||||
error_message=error_message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Analytics completion notification failed: {e}")
|
||||
@@ -547,10 +612,23 @@ class ExecutionLoop:
|
||||
|
||||
self.context.last_screenshot_path = screenshot_path
|
||||
|
||||
# 1bis. Construire un ScreenState enrichi (C1) — avec cache perceptuel
|
||||
screen_state, timings = self._build_screen_state(screenshot_path)
|
||||
logger.debug(
|
||||
f"[Step] ScreenState analyze={timings['analyze_ms']:.0f}ms "
|
||||
f"ocr={timings['ocr_ms']:.0f}ms ui={timings['ui_ms']:.0f}ms "
|
||||
f"cache_hit={timings['cache_hit']} degraded={timings['degraded']}"
|
||||
)
|
||||
|
||||
# 2. Identifier l'état actuel (matching)
|
||||
match = self.pipeline.match_current_state(
|
||||
screenshot_path,
|
||||
workflow_id=self.context.workflow_id
|
||||
#
|
||||
# Lot E — on consomme le ScreenState enrichi déjà construit en 1bis
|
||||
# (avec ui_elements, detected_text, window_title réels) au lieu de
|
||||
# laisser le pipeline reconstruire un stub avec window_title="Unknown".
|
||||
# Premier vrai matching context-aware.
|
||||
match = self.pipeline.match_current_state_from_state(
|
||||
screen_state,
|
||||
workflow_id=self.context.workflow_id,
|
||||
)
|
||||
|
||||
if not match:
|
||||
@@ -564,25 +642,98 @@ class ExecutionLoop:
|
||||
|
||||
logger.info(f"Matched node: {current_node_id} (confidence: {confidence:.3f})")
|
||||
|
||||
# 3. Obtenir la prochaine action
|
||||
# 3. Obtenir la prochaine action (C3 : sélection d'edge robuste)
|
||||
#
|
||||
# Lot A — contrat dict avec status explicite :
|
||||
# "terminal" → fin légitime du workflow (success=True)
|
||||
# "blocked" → pause supervisée (plus JAMAIS traité comme un succès
|
||||
# pour ne pas déclencher un faux _is_workflow_complete)
|
||||
# "selected" → action à exécuter
|
||||
#
|
||||
# Lot B — on propage la confidence du match courant (source_similarity)
|
||||
# pour que l'EdgeScorer puisse vérifier la précondition
|
||||
# `min_source_similarity` de chaque edge. Sans cette propagation, la
|
||||
# contrainte était silencieusement désactivée (hardcodé à 1.0).
|
||||
next_action = self.pipeline.get_next_action(
|
||||
self.context.workflow_id,
|
||||
current_node_id
|
||||
current_node_id,
|
||||
screen_state=screen_state,
|
||||
source_similarity=confidence,
|
||||
)
|
||||
|
||||
if not next_action:
|
||||
# Pas d'action suivante = fin du workflow ou node terminal
|
||||
# Rétrocompat défensive : si un pipeline legacy renvoie None ou un dict
|
||||
# sans status, on considère ça comme un blocage (safe default).
|
||||
if not isinstance(next_action, dict) or "status" not in next_action:
|
||||
logger.error(
|
||||
"get_next_action a renvoyé un résultat sans status "
|
||||
f"(legacy?). Valeur reçue: {next_action!r}"
|
||||
)
|
||||
next_action = {"status": "blocked", "reason": "legacy_none_return"}
|
||||
|
||||
action_status = next_action.get("status")
|
||||
|
||||
if action_status == "terminal":
|
||||
# Fin légitime : aucun outgoing_edge sur le node courant
|
||||
total_ms = (time.time() - start_time) * 1000
|
||||
return StepResult(
|
||||
success=True,
|
||||
node_id=current_node_id,
|
||||
edge_id=None,
|
||||
action_result=None,
|
||||
match_confidence=confidence,
|
||||
duration_ms=(time.time() - start_time) * 1000,
|
||||
message="No next action (terminal node)",
|
||||
screenshot_path=screenshot_path
|
||||
duration_ms=total_ms,
|
||||
message="Workflow terminated (terminal node)",
|
||||
screenshot_path=screenshot_path,
|
||||
ocr_ms=timings["ocr_ms"],
|
||||
ui_ms=timings["ui_ms"],
|
||||
analyze_ms=timings["analyze_ms"],
|
||||
total_ms=total_ms,
|
||||
cache_hit=timings["cache_hit"],
|
||||
degraded=timings["degraded"],
|
||||
)
|
||||
|
||||
if action_status == "blocked":
|
||||
# Blocage : des edges existent mais aucun n'est valide.
|
||||
# On déclenche une pause supervisée (paused_need_help) et on
|
||||
# remonte l'erreur. On ne retourne PAS success=True.
|
||||
reason = next_action.get("reason", "unknown")
|
||||
logger.warning(
|
||||
f"ExecutionLoop bloqué sur {current_node_id}: {reason} "
|
||||
f"→ pause supervisée demandée"
|
||||
)
|
||||
# On bascule en PAUSED et on arme _pause_requested pour que la
|
||||
# boucle principale attende un resume() humain.
|
||||
self.state = ExecutionState.PAUSED
|
||||
self._pause_requested = True
|
||||
self._notify_state_change(ExecutionState.PAUSED)
|
||||
if self._on_error:
|
||||
try:
|
||||
self._on_error(
|
||||
"blocked",
|
||||
Exception(f"No valid edge from {current_node_id}: {reason}"),
|
||||
)
|
||||
except Exception as cb_err:
|
||||
logger.debug(f"on_error callback failed: {cb_err}")
|
||||
|
||||
total_ms = (time.time() - start_time) * 1000
|
||||
return StepResult(
|
||||
success=False,
|
||||
node_id=current_node_id,
|
||||
edge_id=None,
|
||||
action_result=None,
|
||||
match_confidence=confidence,
|
||||
duration_ms=total_ms,
|
||||
message=f"Blocked: {reason}",
|
||||
screenshot_path=screenshot_path,
|
||||
ocr_ms=timings["ocr_ms"],
|
||||
ui_ms=timings["ui_ms"],
|
||||
analyze_ms=timings["analyze_ms"],
|
||||
total_ms=total_ms,
|
||||
cache_hit=timings["cache_hit"],
|
||||
degraded=timings["degraded"],
|
||||
)
|
||||
|
||||
# À partir d'ici, on est forcément en status="selected"
|
||||
edge_id = next_action["edge_id"]
|
||||
self.context.current_edge_id = edge_id
|
||||
|
||||
@@ -604,7 +755,7 @@ class ExecutionLoop:
|
||||
if coaching_response.decision == CoachingDecision.ACCEPT:
|
||||
# Utilisateur accepte : exécuter l'action suggérée
|
||||
self._coaching_stats['accepted'] += 1
|
||||
action_result = self._execute_action(next_action)
|
||||
action_result = self._execute_action(next_action, screen_state=screen_state)
|
||||
self._record_coaching_feedback(
|
||||
next_action, coaching_response, action_result, success=True
|
||||
)
|
||||
@@ -615,15 +766,22 @@ class ExecutionLoop:
|
||||
self._record_coaching_feedback(
|
||||
next_action, coaching_response, None, success=False
|
||||
)
|
||||
total_ms = (time.time() - start_time) * 1000
|
||||
return StepResult(
|
||||
success=False,
|
||||
node_id=current_node_id,
|
||||
edge_id=edge_id,
|
||||
action_result=None,
|
||||
match_confidence=confidence,
|
||||
duration_ms=(time.time() - start_time) * 1000,
|
||||
duration_ms=total_ms,
|
||||
message="Action rejected by user in COACHING mode",
|
||||
screenshot_path=screenshot_path
|
||||
screenshot_path=screenshot_path,
|
||||
ocr_ms=timings["ocr_ms"],
|
||||
ui_ms=timings["ui_ms"],
|
||||
analyze_ms=timings["analyze_ms"],
|
||||
total_ms=total_ms,
|
||||
cache_hit=timings["cache_hit"],
|
||||
degraded=timings["degraded"],
|
||||
)
|
||||
|
||||
elif coaching_response.decision == CoachingDecision.CORRECT:
|
||||
@@ -632,7 +790,7 @@ class ExecutionLoop:
|
||||
corrected_action = self._apply_coaching_correction(
|
||||
next_action, coaching_response.correction
|
||||
)
|
||||
action_result = self._execute_action(corrected_action)
|
||||
action_result = self._execute_action(corrected_action, screen_state=screen_state)
|
||||
self._record_coaching_feedback(
|
||||
next_action, coaching_response, action_result,
|
||||
success=action_result.status == ExecutionStatus.SUCCESS if action_result else False
|
||||
@@ -658,23 +816,30 @@ class ExecutionLoop:
|
||||
# Mode supervisé : demander confirmation
|
||||
if not self._request_confirmation(next_action):
|
||||
logger.info("Action rejected by user")
|
||||
total_ms = (time.time() - start_time) * 1000
|
||||
return StepResult(
|
||||
success=False,
|
||||
node_id=current_node_id,
|
||||
edge_id=edge_id,
|
||||
action_result=None,
|
||||
match_confidence=confidence,
|
||||
duration_ms=(time.time() - start_time) * 1000,
|
||||
duration_ms=total_ms,
|
||||
message="Action rejected by user",
|
||||
screenshot_path=screenshot_path
|
||||
screenshot_path=screenshot_path,
|
||||
ocr_ms=timings["ocr_ms"],
|
||||
ui_ms=timings["ui_ms"],
|
||||
analyze_ms=timings["analyze_ms"],
|
||||
total_ms=total_ms,
|
||||
cache_hit=timings["cache_hit"],
|
||||
degraded=timings["degraded"],
|
||||
)
|
||||
|
||||
# Exécuter l'action
|
||||
action_result = self._execute_action(next_action)
|
||||
action_result = self._execute_action(next_action, screen_state=screen_state)
|
||||
|
||||
elif self.context.mode == ExecutionMode.AUTOMATIC:
|
||||
# Mode automatique : exécuter directement
|
||||
action_result = self._execute_action(next_action)
|
||||
action_result = self._execute_action(next_action, screen_state=screen_state)
|
||||
|
||||
# 5. Mettre à jour les compteurs
|
||||
self.context.steps_executed += 1
|
||||
@@ -693,7 +858,13 @@ class ExecutionLoop:
|
||||
match_confidence=confidence,
|
||||
duration_ms=duration_ms,
|
||||
message=action_result.message if action_result else "Observed",
|
||||
screenshot_path=screenshot_path
|
||||
screenshot_path=screenshot_path,
|
||||
ocr_ms=timings["ocr_ms"],
|
||||
ui_ms=timings["ui_ms"],
|
||||
analyze_ms=timings["analyze_ms"],
|
||||
total_ms=duration_ms,
|
||||
cache_hit=timings["cache_hit"],
|
||||
degraded=timings["degraded"],
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
@@ -718,8 +889,18 @@ class ExecutionLoop:
|
||||
logger.error(f"Screen capture failed: {e}")
|
||||
return None
|
||||
|
||||
def _execute_action(self, action_info: Dict[str, Any]) -> ExecutionResult:
|
||||
"""Exécuter une action via l'ActionExecutor."""
|
||||
def _execute_action(
|
||||
self,
|
||||
action_info: Dict[str, Any],
|
||||
screen_state: Optional[Any] = None,
|
||||
) -> ExecutionResult:
|
||||
"""
|
||||
Exécuter une action via l'ActionExecutor.
|
||||
|
||||
Args:
|
||||
action_info: dict action {edge_id, action, target_node, ...}
|
||||
screen_state: ScreenState enrichi (si None, fallback stub minimal)
|
||||
"""
|
||||
try:
|
||||
# Charger le workflow et l'edge
|
||||
workflow = self.pipeline.load_workflow(self.context.workflow_id)
|
||||
@@ -732,36 +913,10 @@ class ExecutionLoop:
|
||||
duration_ms=0
|
||||
)
|
||||
|
||||
# Créer un ScreenState minimal pour l'exécution
|
||||
from core.models.screen_state import (
|
||||
ScreenState, WindowContext, RawLevel, PerceptionLevel,
|
||||
ContextLevel, EmbeddingRef
|
||||
)
|
||||
|
||||
screen_state = ScreenState(
|
||||
screen_state_id=f"exec_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
||||
timestamp=datetime.now(),
|
||||
session_id=self.context.execution_id,
|
||||
window=WindowContext(
|
||||
app_name="unknown",
|
||||
window_title="Unknown",
|
||||
screen_resolution=[1920, 1080],
|
||||
workspace="main"
|
||||
),
|
||||
raw=RawLevel(
|
||||
screenshot_path=self.context.last_screenshot_path or "",
|
||||
capture_method="execution",
|
||||
file_size_bytes=0
|
||||
),
|
||||
perception=PerceptionLevel(
|
||||
embedding=EmbeddingRef(provider="", vector_id="", dimensions=512),
|
||||
detected_text=[],
|
||||
text_detection_method="none",
|
||||
confidence_avg=0.0
|
||||
),
|
||||
context=ContextLevel(),
|
||||
ui_elements=[]
|
||||
)
|
||||
# Utiliser le ScreenState enrichi fourni par le loop ; fallback minimal
|
||||
# uniquement si on n'en a pas (legacy, tests).
|
||||
if screen_state is None:
|
||||
screen_state = self._build_stub_screen_state()
|
||||
|
||||
# Exécuter l'action
|
||||
result = self.action_executor.execute_edge(
|
||||
@@ -782,6 +937,286 @@ class ExecutionLoop:
|
||||
error=e
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# C1 — Construction du ScreenState (vision-aware)
|
||||
# =========================================================================
|
||||
|
||||
def _get_screen_analyzer(self):
    """
    Return the ScreenAnalyzer used by this loop, resolving it on first use.

    An analyzer already stored on the instance is returned as-is. Otherwise
    the shared one is fetched from ``core.pipeline.get_screen_analyzer()``
    and remembered for later calls.

    No ``session_id`` is handed to the factory here; it is supplied per call
    to ``analyze()`` so that loops sharing one analyzer do not interfere.

    Returns:
        The analyzer, or None when it cannot be obtained (import error,
        factory failure, ...) — the caller then falls back to a stub state.
    """
    analyzer = self._screen_analyzer
    if analyzer is None:
        try:
            # Local import: resolved lazily, only when actually needed.
            from core.pipeline import get_screen_analyzer
            analyzer = self._screen_analyzer = get_screen_analyzer()
        except Exception as e:
            logger.warning(f"ScreenAnalyzer indisponible: {e}")
            return None
    return analyzer
|
||||
|
||||
def _get_screen_state_cache(self):
    """
    Return the ScreenState cache used by this loop, resolving it on first use.

    A cache already stored on the instance is returned as-is. Otherwise the
    shared one is fetched from ``core.pipeline.get_screen_state_cache()`` and
    remembered for later calls.

    Returns:
        The cache, or None when it cannot be obtained (a warning is logged).
    """
    cache = self._screen_state_cache
    if cache is None:
        try:
            # Local import: resolved lazily, only when actually needed.
            from core.pipeline import get_screen_state_cache
            cache = self._screen_state_cache = get_screen_state_cache()
        except Exception as e:
            logger.warning(f"ScreenStateCache indisponible: {e}")
            return None
    return cache
|
||||
|
||||
def _resolve_window_info(self) -> Optional[Dict[str, Any]]:
    """
    Look up information about the currently active window.

    Sources, in order of preference:
        1. the ``window_info_provider`` given to the constructor — its result
           is returned unchanged (even when it is None) unless it raises;
        2. ``screen_capturer.get_active_window()``, normalised to the keys
           ``title`` / ``app_name`` / ``window_bounds``;
        3. None when neither source yields anything.

    Failures of either source are logged at debug level and never propagate.
    """
    provider = self._window_info_provider
    if provider is not None:
        try:
            return provider()
        except Exception as e:
            # A broken provider falls through to the capturer.
            logger.debug(f"window_info_provider failed: {e}")

    try:
        active = self.screen_capturer.get_active_window()
        if not active:
            return None
        # Normalise to the format expected by the ScreenAnalyzer.
        return {
            "title": active.get("title", "Unknown"),
            "app_name": active.get("app", "unknown"),
            "window_bounds": [
                active.get(key, 0) for key in ("x", "y", "width", "height")
            ],
        }
    except Exception as e:
        logger.debug(f"get_active_window failed: {e}")
        return None
|
||||
|
||||
def _build_screen_state(
    self,
    screenshot_path: str,
) -> tuple:
    """
    Build an enriched ScreenState from a screenshot.

    Logic:
        - If enable_ui_detection=False AND enable_ocr=False → stub.
        - If the analyzer is unavailable → stub.
        - While in degraded mode → stub, except for one real "probe"
          analysis every `_probe_interval` calls.
        - Otherwise: cache.get_or_compute(...) around analyzer.analyze(...),
          or a direct analyze() call when no cache is available.
        - Soft timeout: if a non-cached analysis takes longer than
          `analyze_timeout_ms`, a warning is logged and degraded mode is
          switched on for the following steps.
        - Recovery: `_fast_steps_recovery_threshold` consecutive non-cached
          analyses faster than `analyze_timeout_ms / 2` switch degraded
          mode off again.

    Args:
        screenshot_path: path of the screenshot to analyse.

    Returns:
        (screen_state, timings_dict)
        timings_dict: {
            "analyze_ms", "ocr_ms", "ui_ms", "cache_hit", "degraded"
        }
    """
    timings = {
        "analyze_ms": 0.0,
        "ocr_ms": 0.0,
        "ui_ms": 0.0,
        "cache_hit": False,
        "degraded": False,
    }

    # "Everything disabled" mode (emergency flag) → stub
    if not self.enable_ui_detection and not self.enable_ocr:
        timings["degraded"] = True
        return self._build_stub_screen_state(screenshot_path), timings

    analyzer = self._get_screen_analyzer()
    if analyzer is None:
        timings["degraded"] = True
        return self._build_stub_screen_state(screenshot_path), timings

    # Degraded mode: stay on the stub, except for a periodic "probe" that
    # tests whether analysis has become fast again. Fast probes accumulate;
    # after _fast_steps_recovery_threshold consecutive fast probes we go
    # back to full mode.
    if self._degraded_mode:
        self._degraded_step_counter += 1
        if self._degraded_step_counter < self._probe_interval:
            timings["degraded"] = True
            return self._build_stub_screen_state(screenshot_path), timings
        # Otherwise attempt a real probe below
        self._degraded_step_counter = 0

    cache = self._get_screen_state_cache()

    # Proactive invalidation: if the screen changed massively since the
    # last cache entry, purge it. The TTL alone (2s) would let stale entries
    # through on fast changes (popup, navigation).
    if cache is not None:
        try:
            cache.invalidate_if_changed(screenshot_path, threshold=0.3)
        except Exception as e:
            # Best effort: a failed invalidation must not abort the step.
            logger.debug(f"invalidate_if_changed a échoué: {e}")

    window_info = self._resolve_window_info()

    # Compute function (cache miss).
    # The runtime flags (enable_ocr, enable_ui_detection) and the session_id
    # are passed as keyword arguments to analyze(): the shared analyzer is
    # never mutated here, so two ExecutionLoop instances can share it.
    execution_id = self.context.execution_id if self.context else ""

    def compute(path: str):
        t_start = time.time()
        state = analyzer.analyze(
            path,
            window_info=window_info,
            enable_ocr=self.enable_ocr,
            enable_ui_detection=self.enable_ui_detection,
            session_id=execution_id,
        )
        elapsed = (time.time() - t_start) * 1000
        # Record the analysis time in the state's metadata.
        # NOTE(review): assumes state.metadata is a mutable mapping when the
        # attribute exists; a None value would raise and be handled by the
        # caller's fallback-to-stub path — confirm this is intended.
        if hasattr(state, "metadata"):
            state.metadata["analyze_ms"] = elapsed
        return state

    t0 = time.time()
    try:
        if cache is not None:
            # Composite, context-aware key: two different contexts sharing
            # the same screenshot no longer collide. workflow_id isolates
            # replays per workflow, the flags distinguish analysis modes
            # (OCR on/off, UI on/off), and (window_title, app_name)
            # distinguishes applications with a similar visual rendering.
            ctx_window_title = (window_info or {}).get("title", "") or ""
            ctx_app_name = (window_info or {}).get("app_name", "") or ""
            ctx_workflow_id = (
                self.context.workflow_id if self.context else ""
            )
            state, cache_hit, _ = cache.get_or_compute(
                screenshot_path,
                compute,
                window_title=ctx_window_title,
                app_name=ctx_app_name,
                enable_ocr=self.enable_ocr,
                enable_ui_detection=self.enable_ui_detection,
                workflow_id=ctx_workflow_id,
            )
        else:
            state = compute(screenshot_path)
            cache_hit = False
    except Exception as e:
        # Any analysis/cache failure degrades this step to the stub state.
        logger.warning(f"ScreenState build failed: {e} — fallback stub")
        timings["degraded"] = True
        return self._build_stub_screen_state(screenshot_path), timings

    # Wall-clock time of the whole lookup (cache access included).
    analyze_ms = (time.time() - t0) * 1000
    timings["analyze_ms"] = analyze_ms
    timings["cache_hit"] = cache_hit

    # Split OCR vs UI time when the state's metadata provides it
    meta = getattr(state, "metadata", {}) or {}
    timings["ocr_ms"] = float(meta.get("ocr_ms", 0.0))
    timings["ui_ms"] = float(meta.get("ui_ms", 0.0))

    # Soft timeout: switch to degraded mode above the threshold
    # (cache hits are ignored: a hit proves nothing about analysis speed)
    if analyze_ms > self.analyze_timeout_ms and not cache_hit:
        logger.warning(
            f"ScreenState analysis slow: {analyze_ms:.0f}ms > "
            f"{self.analyze_timeout_ms}ms → activation mode dégradé"
        )
        self._degraded_mode = True
        self._successive_fast_steps = 0
        timings["degraded"] = True
    else:
        # "Fast" step: bump the counter when below timeout / 2.
        # Cache hits are ignored (not representative of analysis speed).
        fast_threshold_ms = self.analyze_timeout_ms / 2
        if not cache_hit and analyze_ms < fast_threshold_ms:
            self._successive_fast_steps += 1

            # Self-recovery: if we were degraded and chained enough fast
            # steps → back to full mode.
            if (
                self._degraded_mode
                and self._successive_fast_steps
                >= self._fast_steps_recovery_threshold
            ):
                logger.info(
                    "Mode complet restauré après %d steps rapides "
                    "(dernier analyze_ms=%.0fms < seuil=%.0fms)",
                    self._successive_fast_steps,
                    analyze_ms,
                    fast_threshold_ms,
                )
                self._degraded_mode = False
                self._successive_fast_steps = 0
        elif not cache_hit:
            # Neither slow nor fast (between timeout/2 and timeout): reset
            self._successive_fast_steps = 0

        # Propagate the current degraded flag into the timings (used by the
        # StepResult: until enough fast steps have been seen, the step keeps
        # being reported as "degraded=True").
        timings["degraded"] = self._degraded_mode

    return state, timings
|
||||
|
||||
def _build_stub_screen_state(self, screenshot_path: Optional[str] = None):
    """
    Build a minimal placeholder ScreenState (legacy fallback).

    Used when no analysed state is available: the result has no detected
    text, no UI elements and hard-coded "unknown" window information.

    Args:
        screenshot_path: screenshot to reference. When empty, the context's
            last screenshot path is used if a context exists, else "".

    Returns:
        A ScreenState whose session_id is the current execution id, or
        "stub" when there is no execution context.
    """
    from core.models.screen_state import (
        ScreenState, WindowContext, RawLevel, PerceptionLevel,
        ContextLevel, EmbeddingRef
    )

    # Resolve the screenshot path: explicit argument, then context, then "".
    path = screenshot_path
    if not path and self.context:
        path = self.context.last_screenshot_path
    path = path or ""

    state_id = f"exec_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
    created_at = datetime.now()
    session = self.context.execution_id if self.context else "stub"

    window = WindowContext(
        app_name="unknown",
        window_title="Unknown",
        screen_resolution=[1920, 1080],
        workspace="main",
    )
    raw = RawLevel(
        screenshot_path=path,
        capture_method="execution",
        file_size_bytes=0,
    )
    # Empty perception level: no text was extracted for a stub.
    perception = PerceptionLevel(
        embedding=EmbeddingRef(provider="", vector_id="", dimensions=512),
        detected_text=[],
        text_detection_method="none",
        confidence_avg=0.0,
    )

    return ScreenState(
        screen_state_id=state_id,
        timestamp=created_at,
        session_id=session,
        window=window,
        raw=raw,
        perception=perception,
        context=ContextLevel(),
        ui_elements=[],
    )
|
||||
|
||||
def _request_confirmation(self, action_info: Dict[str, Any]) -> bool:
|
||||
"""Demander confirmation à l'utilisateur."""
|
||||
if self.confirmation_callback:
|
||||
|
||||
708
core/execution/input_handler.py
Normal file
708
core/execution/input_handler.py
Normal file
@@ -0,0 +1,708 @@
|
||||
"""
|
||||
Module partagé de saisie texte et gestion des dialogues.
|
||||
|
||||
Utilisé par les deux executors :
|
||||
- VWB executor (visual_workflow_builder/backend/api_v3/execute.py)
|
||||
- Core executor (core/execution/action_executor.py)
|
||||
|
||||
Garantit le même comportement AZERTY/VM/Citrix partout.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import shutil
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import pyautogui
|
||||
PYAUTOGUI_AVAILABLE = True
|
||||
except ImportError:
|
||||
PYAUTOGUI_AVAILABLE = False
|
||||
|
||||
|
||||
def safe_type_text(text: str):
    """Type *text* in a way that survives VM/Citrix sessions and AZERTY/QWERTY layouts.

    Strategies are tried in order; the first one that succeeds ends the call:
        1. ``xdotool type`` after re-applying the keyboard layout with
           ``setxkbmap fr`` (requires both tools on PATH);
        2. clipboard via ``xclip`` followed by Ctrl+V (requires xclip and
           pyautogui);
        3. ``pyautogui.write()`` as a last resort (layout not guaranteed).

    If no strategy is available, a warning is logged and nothing is typed.

    Args:
        text: the text to type. Empty or None text is a no-op.

    Returns:
        None.
    """
    if not text:
        return

    # Method 1: xdotool type with a keyboard layout refresh
    if shutil.which('xdotool') and shutil.which('setxkbmap'):
        try:
            subprocess.run(['setxkbmap', 'fr'], timeout=2)
            subprocess.run(
                ['xdotool', 'type', '--delay', '0', '--clearmodifiers', '--', text],
                timeout=max(30, len(text) * 0.05),
                check=True
            )
            logger.debug(f"Saisie via xdotool type ({len(text)} car.)")
            return
        except Exception as e:
            # Any failure (timeout, non-zero exit) falls through to method 2.
            logger.debug(f"xdotool type échoué: {e}")

    # Method 2: clipboard
    xclip = shutil.which('xclip')
    if xclip and PYAUTOGUI_AVAILABLE:
        try:
            # run() feeds stdin, waits for xclip and checks its exit status.
            # Without that check a failed xclip would still be followed by
            # Ctrl+V, pasting whatever stale content the clipboard held.
            subprocess.run(
                ['xclip', '-selection', 'clipboard'],
                input=text.encode('utf-8'),
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                timeout=5,
                check=True
            )
            time.sleep(0.2)
            pyautogui.hotkey('ctrl', 'v')
            # Give the target application time to consume the paste.
            time.sleep(0.3)
            logger.debug(f"Saisie via presse-papier ({len(text)} car.)")
            return
        except Exception as e:
            logger.debug(f"xclip échoué: {e}")

    # Method 3: pyautogui
    if PYAUTOGUI_AVAILABLE:
        logger.warning("Saisie via pyautogui.write() (AZERTY non garanti)")
        pyautogui.write(text, interval=0.02)
    else:
        logger.warning(f"Aucune méthode de saisie disponible pour: {text[:50]}")
|
||||
|
||||
|
||||
def check_screen_for_patterns() -> Optional[Dict[str, Any]]:
    """Look for a known dialog/popup pattern in the current screen content.

    Grabs ``monitors[0]`` with mss, extracts its text with the OCR service
    (falling back to ``FieldExtractor`` when the service cannot be imported)
    and asks ``UIPatternLibrary.find_pattern`` for a match.

    Returns:
        The pattern dict when one is found and its ``'category'`` is
        ``'dialog'`` or ``'popup'``; otherwise ``None``. ``None`` is also
        returned when OCR is unavailable, when fewer than 5 characters were
        recognised, or when any step raises.
    """
    try:
        from core.knowledge.ui_patterns import UIPatternLibrary
        import mss
        from PIL import Image

        library = UIPatternLibrary()

        with mss.mss() as grabber:
            area = grabber.monitors[0]
            shot = grabber.grab(area)
            frame = Image.frombytes('RGB', shot.size, shot.bgra, 'raw', 'BGRX')

        try:
            # Prefer the shared OCR service; the extractor is the fallback path.
            try:
                from services.ocr_service import ocr_extract_text
            except ImportError:
                from core.extraction.field_extractor import FieldExtractor
                fallback_extractor = FieldExtractor()

                def ocr_extract_text(img):
                    return fallback_extractor.extract_text_from_image(img)

            recognized = ocr_extract_text(frame)
        except ImportError:
            logger.debug("OCR non disponible pour pattern check")
            return None

        # Too little text to match anything meaningful.
        if not recognized or len(recognized) < 5:
            return None

        found = library.find_pattern(recognized)
        if not found or found['category'] not in ('dialog', 'popup'):
            return None

        print(f"🧠 [PatternCheck] Détecté: '{found['pattern']}' → {found['action']} '{found['target']}'")
        return found

    except Exception as err:
        print(f"⚠️ [PatternCheck] Erreur: {err}")
        return None
|
||||
|
||||
|
||||
def handle_detected_pattern(pattern: Dict[str, Any]) -> bool:
    """Act on a detected UI pattern (click a button or send a hotkey).

    For ``action == 'click'`` the screen is captured, OCR'd (EasyOCR, with
    the project OCR service as fallback when EasyOCR is not importable) and
    the on-screen word that best matches ``target`` or one of
    ``alternatives`` is clicked. For ``action == 'hotkey'`` ``target`` is
    split on ``+`` and sent through ``pyautogui.hotkey``.

    Args:
        pattern: Pattern dict; the keys read here are ``'action'``,
            ``'target'`` and ``'alternatives'``.

    Returns:
        True when a click or hotkey was sent, False otherwise (pyautogui
        missing, no matching word, unknown action, empty hotkey, or an error
        during the click path).
    """
    if not PYAUTOGUI_AVAILABLE:
        logger.warning("pyautogui non disponible — impossible de gérer le pattern")
        return False

    action = pattern.get('action')
    # `or` guards against keys that are present but explicitly None.
    target = pattern.get('target') or ''
    alternatives = pattern.get('alternatives') or []

    if action == 'click':
        candidates_labels = [target] + list(alternatives)
        print(f"🔧 [Réflexe/handle] Recherche bouton parmi: {candidates_labels}")

        try:
            import mss
            import numpy as np
            from PIL import Image

            with mss.mss() as sct:
                monitor = sct.monitors[0]
                screenshot = sct.grab(monitor)
                screen = Image.frombytes('RGB', screenshot.size, screenshot.bgra, 'raw', 'BGRX')

            # EasyOCR first, project OCR service as fallback.
            words = []
            try:
                reader = _get_easyocr_reader()
                for bbox_pts, text, _conf in reader.readtext(np.array(screen)):
                    if not text or len(text.strip()) < 1:
                        continue
                    xs = [p[0] for p in bbox_pts]
                    ys = [p[1] for p in bbox_pts]
                    words.append({
                        'text': text.strip(),
                        'bbox': [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))],
                    })
            except ImportError:
                try:
                    from services.ocr_service import ocr_extract_words
                    words = ocr_extract_words(screen) or []
                except ImportError:
                    pass

            print(f"🔧 [Réflexe/handle] {len(words)} mots OCR détectés")

            best = _pick_button_match(words, candidates_labels)
            if best is not None:
                # OCR coordinates are relative to the captured area; add the
                # capture origin so the click lands correctly when that area
                # does not start at (0, 0). NOTE(review): assumes the mss
                # monitor dict exposes 'left'/'top' — confirm against mss docs.
                click_x = best['x'] + monitor.get('left', 0)
                click_y = best['y'] + monitor.get('top', 0)
                print(f"✅ [Réflexe/handle] Clic sur '{best['text']}' à ({click_x}, {click_y})")
                pyautogui.click(click_x, click_y)
                time.sleep(1.0)
                return True

            print(f"⚠️ [Réflexe/handle] Bouton '{target}' introuvable parmi {[w['text'] for w in words[:15]]}")
            return False

        except Exception as e:
            print(f"⚠️ [Réflexe/handle] Erreur: {e}")
            return False

    elif action == 'hotkey':
        keys = [key.strip() for key in target.split('+') if key.strip()]
        if not keys:
            # Nothing to press: do not send an empty key name to pyautogui.
            return False
        logger.info(f"Raccourci automatique: {target}")
        pyautogui.hotkey(*keys)
        time.sleep(0.5)
        return True

    return False


def _get_easyocr_reader():
    """Return a process-wide EasyOCR reader, creating it on first use.

    Building the reader is done once and cached on the function object so
    repeated pattern handling does not reload it. Raises ImportError when
    easyocr is not installed, which callers use to select the fallback OCR.
    """
    reader = getattr(_get_easyocr_reader, '_reader', None)
    if reader is None:
        import easyocr
        reader = easyocr.Reader(['fr', 'en'], gpu=False, verbose=False)
        _get_easyocr_reader._reader = reader
    return reader


def _pick_button_match(words, labels):
    """Pick the OCR word that best matches one of ``labels``, or None.

    Matches are ranked exact > label contained in word > word contained in
    label; within the best rank the lowest word on screen wins (buttons sit at
    the bottom of a dialog). Ranking stops a short unrelated word that merely
    happens to be a substring of a label from beating the real button.
    """
    matches = []
    for label in labels:
        label_lower = str(label).lower()
        if len(label_lower) < 2:
            continue
        for word in words:
            word_text = word['text'].lower()
            if len(word_text) < 2:
                continue
            if word_text == label_lower:
                rank = 2
            elif label_lower in word_text:
                rank = 1
            elif word_text in label_lower:
                rank = 0
            else:
                continue
            x1, y1, x2, y2 = word['bbox']
            matches.append({
                'text': word['text'],
                'x': int((x1 + x2) / 2),
                'y': int((y1 + y2) / 2),
                'candidate': label,
                'rank': rank,
            })
    if not matches:
        return None
    return max(matches, key=lambda m: (m['rank'], m['y']))
|
||||
|
||||
|
||||
def vlm_reason_about_screen(objective: str = "", context: str = "") -> Optional[Dict[str, Any]]:
    """Ask the vision-language model what to do with the current screen.

    Captures ``monitors[0]``, sends it as a JPEG to the Ollama
    ``/api/generate`` endpoint (``OLLAMA_URL``, model ``RPA_REASONING_MODEL``)
    together with a prompt requesting a strict-JSON answer, and parses the
    first-to-last brace span of the reply as JSON.

    Args:
        objective: What the caller is trying to achieve; a generic default is
            used in the prompt when empty.
        context: Extra context for the prompt; a default is used when empty.

    Returns:
        The parsed JSON object (expected keys: ``'action'``, ``'target'``,
        ``'reasoning'``), or ``None`` when the HTTP status is not 200, the
        reply contains no JSON object, or any step raises.
    """
    try:
        import base64
        import io
        import json
        import os
        import re

        import mss
        import requests
        from PIL import Image

        with mss.mss() as sct:
            monitor = sct.monitors[0]
            screenshot = sct.grab(monitor)
            screen = Image.frombytes('RGB', screenshot.size, screenshot.bgra, 'raw', 'BGRX')

        buffer = io.BytesIO()
        screen.save(buffer, format='JPEG', quality=70)
        image_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

        goal = objective or "Interagir avec l'interface visible"
        extra = context or "Aucun contexte supplémentaire"
        prompt = f"""Analyse cet écran et dis-moi quoi faire.

Objectif : {goal}
Contexte : {extra}

Réponds en JSON strict :
{{
"action": "click" ou "type" ou "wait" ou "nothing",
"target": "texte exact du bouton ou champ à cliquer",
"reasoning": "explication courte de ton choix"
}}

Si tu vois un dialogue ou une popup, indique quel bouton cliquer.
Si l'écran est normal sans action nécessaire, réponds action="nothing".
Réponds UNIQUEMENT le JSON, pas d'explication."""

        ollama_url = os.environ.get("OLLAMA_URL", "http://localhost:11434")
        model = os.environ.get("RPA_REASONING_MODEL", "qwen2.5vl:7b")

        response = requests.post(
            f"{ollama_url}/api/generate",
            json={
                "model": model,
                "prompt": prompt,
                "images": [image_b64],
                "stream": False,
                "options": {"temperature": 0.1, "num_predict": 200},
            },
            timeout=30,
        )

        if response.status_code != 200:
            logger.warning(f"VLM reasoning failed: HTTP {response.status_code}")
            return None

        result = response.json()
        text = result.get('response', '').strip()

        match = re.search(r'\{[\s\S]*\}', text)
        if match:
            parsed = json.loads(match.group())
            if isinstance(parsed, dict):
                # 'reasoning' may be null or a non-string in the model output;
                # coerce it so the log line cannot discard a valid answer.
                reasoning = str(parsed.get('reasoning') or '')[:80]
                logger.info(f"VLM reasoning: {parsed.get('action')} '{parsed.get('target')}' — {reasoning}")
                return parsed

        logger.debug(f"VLM response not parseable: {text[:100]}")
        return None

    except Exception as e:
        logger.debug(f"VLM reasoning failed: {e}")
        return None
|
||||
|
||||
|
||||
def find_element_on_screen(
    target_text: str,
    target_description: str = "",
    anchor_image_base64: Optional[str] = None,
    anchor_bbox: Optional[Dict] = None,
) -> Optional[Dict[str, Any]]:
    """Locate an element on screen with a three-level cascade.

    Level 1 is OCR (``_grounding_ocr``), level 2 is UI-TARS grounding
    (``_grounding_ui_tars``), level 3 is VLM reasoning confirmed by OCR
    (``_grounding_vlm``). The first level that returns a result wins.

    When ``target_text`` is empty or is only an action-type placeholder
    (``'click_anchor'``, ``'hover_anchor'``, ...), the placeholder is dropped
    and, if an anchor image is supplied, ``_describe_anchor_image`` is asked
    for a short description that then serves as both text and description.

    Args:
        target_text: Visible text of the element (e.g. "Demo").
        target_description: Longer description of the element.
        anchor_image_base64: Reference crop of the anchor, used only to
            obtain a description when no useful text is available.
        anchor_bbox: Original anchor position, forwarded to the OCR level to
            disambiguate multiple matches.

    Returns:
        ``{'x', 'y', 'method', 'confidence'}`` from the first successful
        level, or ``None`` when nothing usable was supplied or every level
        failed.
    """
    action_types = {'click_anchor', 'double_click_anchor', 'right_click_anchor',
                    'hover_anchor', 'focus_anchor', 'scroll_to_anchor'}
    has_useful_text = bool(target_text) and target_text not in action_types

    if not has_useful_text:
        # An action-type name is not on-screen text: searching for it could
        # only produce false positives (e.g. the word "click" somewhere).
        target_text = ""
        if anchor_image_base64:
            desc = _describe_anchor_image(anchor_image_base64)
            if desc:
                logger.info(f"[Grounding] Ancre décrite par VLM: '{desc}'")
                target_description = desc
                target_text = desc

    if not target_text and not target_description:
        logger.debug("find_element_on_screen: ni target_text ni target_description fournis")
        return None

    search_label = target_description or target_text
    logger.info(f"[Grounding] Recherche élément: '{search_label}' (cascade 3 niveaux)")

    # Level 1 — OCR.
    result = _grounding_ocr(target_text, anchor_bbox=anchor_bbox)
    if result:
        return result

    # Level 2 — UI-TARS grounding.
    result = _grounding_ui_tars(target_text, target_description)
    if result:
        return result

    # Level 3 — VLM reasoning. Fall back to the label when no text remains so
    # the VLM context never mentions an empty element name.
    result = _grounding_vlm(target_text or search_label, target_description)
    if result:
        return result

    logger.warning(f"[Grounding] ÉCHEC total pour '{search_label}' — aucune méthode n'a trouvé l'élément")
    return None
|
||||
|
||||
|
||||
def _describe_anchor_image(anchor_image_base64: str) -> Optional[str]:
    """Ask the VLM for a short (5 words max) description of an anchor crop.

    A ``data:...,`` style prefix is removed by keeping only what follows the
    first comma. The image is posted to the Ollama ``/api/generate`` endpoint
    (``OLLAMA_URL``) with the ``qwen2.5vl:3b`` model.

    Returns:
        The description with surrounding whitespace and quotes stripped when
        the request succeeds and the text is longer than 2 characters;
        otherwise ``None`` (including when the request raises).
    """
    try:
        import requests
        import os

        _prefix, comma, payload = anchor_image_base64.partition(',')
        if comma:
            anchor_image_base64 = payload

        base_url = os.environ.get("OLLAMA_URL", "http://localhost:11434")
        model = "qwen2.5vl:3b"
        request_body = {
            "model": model,
            "prompt": "Describe this UI element in 5 words maximum. Just the element name, nothing else. Example: 'folder icon named Demo' or 'Save button' or 'Chrome browser icon'",
            "images": [anchor_image_base64],
            "stream": False,
            "options": {"temperature": 0.1, "num_predict": 20},
        }

        logger.info(f"[Grounding] Description ancre via {model}...")
        reply = requests.post(f"{base_url}/api/generate", json=request_body, timeout=30)

        if reply.status_code != 200:
            return None

        description = reply.json().get('response', '').strip().strip('"').strip("'")
        if description and len(description) > 2:
            return description
        return None

    except Exception as err:
        logger.warning(f"[Grounding] Description ancre échouée: {err}")
        return None
|
||||
|
||||
|
||||
def _capture_screen():
    """Grab ``monitors[0]`` with mss and return ``(image, width, height)``.

    The image is a PIL RGB image built from the BGRA buffer; width and height
    come from the monitor dict. NOTE(review): in mss, index 0 is presumably
    the combined area of all monitors rather than the primary one — confirm.

    Returns:
        ``(PIL.Image, int, int)`` on success, ``(None, 0, 0)`` when the
        capture raises.
    """
    try:
        import mss
        from PIL import Image as PILImage

        with mss.mss() as grabber:
            area = grabber.monitors[0]
            raw = grabber.grab(area)
            picture = PILImage.frombytes('RGB', raw.size, raw.bgra, 'raw', 'BGRX')
            return picture, area['width'], area['height']
    except Exception as err:
        logger.debug(f"Capture écran échouée: {err}")
        return None, 0, 0
|
||||
|
||||
|
||||
def _grounding_ocr(target_text: str, anchor_bbox: Optional[Dict] = None) -> Optional[Dict[str, Any]]:
    """Level 1 — locate ``target_text`` on screen through OCR words.

    All matching words are collected (exact, partial when both strings have
    at least 3 characters, and — only if nothing matched — the target with
    its first letter removed). When several candidates exist, the one closest
    to the centre of ``anchor_bbox`` is chosen, or closest to the screen
    centre when no bbox is given.

    Returns:
        ``{'x', 'y', 'method': 'ocr', 'confidence'}`` for the chosen word, or
        ``None`` when the target is empty, capture/OCR yields nothing, no word
        matches, or an error occurs.
    """
    logger.debug(f"[Grounding/OCR] target='{target_text}' bbox={anchor_bbox}")
    if not target_text:
        return None

    try:
        screen, screen_w, screen_h = _capture_screen()
        if screen is None:
            return None

        try:
            from services.ocr_service import ocr_extract_words
        except ImportError:
            from core.extraction.field_extractor import FieldExtractor
            extractor = FieldExtractor()

            def ocr_extract_words(img):
                return extractor.extract_words_from_image(img)

        words = ocr_extract_words(screen)
        if not words:
            logger.debug("[Grounding/OCR] Aucun mot détecté")
            return None

        def centre_of(bbox):
            left, top, right, bottom = bbox
            return int((left + right) / 2), int((top + bottom) / 2)

        wanted = target_text.lower()
        candidates = []

        # Pass 1: exact and partial matches.
        for entry in words:
            seen = entry['text'].lower()
            cx, cy = centre_of(entry['bbox'])

            if seen == wanted:
                candidates.append({'text': entry['text'], 'x': cx, 'y': cy, 'type': 'exact', 'conf': 0.95})
                continue
            if len(seen) < 3 or len(wanted) < 3:
                continue
            if wanted in seen or seen in wanted:
                # Partial matches much shorter than the target get less confidence.
                coverage = len(seen) / max(len(wanted), 1)
                score = 0.80 if coverage > 0.5 else 0.50
                candidates.append({'text': entry['text'], 'x': cx, 'y': cy, 'type': 'partial', 'conf': score})

        # Pass 2: target with its first letter dropped.
        if not candidates and len(wanted) > 3:
            tail = wanted[1:]
            for entry in words:
                if tail not in entry['text'].lower():
                    continue
                cx, cy = centre_of(entry['bbox'])
                candidates.append({'text': entry['text'], 'x': cx, 'y': cy, 'type': 'partial_cut', 'conf': 0.70})

        if not candidates:
            logger.debug(f"[Grounding/OCR] '{target_text}' non trouvé parmi {len(words)} mots")
            return None

        if len(candidates) == 1:
            best = candidates[0]
        else:
            if anchor_bbox:
                # Reference point: centre of the anchor's original position.
                ref_x = anchor_bbox.get('x', 0) + anchor_bbox.get('width', 0) / 2
                ref_y = anchor_bbox.get('y', 0) + anchor_bbox.get('height', 0) / 2
            else:
                # Reference point: centre of the screen.
                ref_x, ref_y = screen_w / 2, screen_h / 2
            best = min(candidates, key=lambda c: ((c['x'] - ref_x)**2 + (c['y'] - ref_y)**2))

        for cand in candidates:
            marker = " ← CHOISI" if cand is best else ""
            logger.info(f"  [OCR] Candidat: '{cand['text']}' à ({cand['x']}, {cand['y']}) [{cand['type']}]{marker}")

        return {'x': best['x'], 'y': best['y'], 'method': 'ocr', 'confidence': best['conf']}

    except Exception as err:
        logger.debug(f"[Grounding/OCR] Erreur: {err}")
        return None
|
||||
|
||||
|
||||
def _grounding_ui_tars(target_text: str, target_description: str = "") -> Optional[Dict[str, Any]]:
    """Level 2 — ask the UI-TARS model where to click.

    The current screen is sent as a JPEG to the Ollama ``/api/generate``
    endpoint (``OLLAMA_URL``) with the prompt ``"click on <label>"``; the
    reply is parsed with ``_parse_ui_tars_coordinates``.

    Args:
        target_text: Visible text of the element.
        target_description: Longer description; preferred over
            ``target_text`` for the prompt when non-empty.

    Returns:
        ``{'x', 'y', 'method': 'ui_tars', 'confidence': 0.85}`` when the
        parsed coordinates fall within the screen, otherwise ``None`` (capture
        failure, non-200 status, unparseable reply, off-screen coordinates or
        any exception).
    """
    try:
        import requests
        import base64
        import io
        import os

        screen, screen_w, screen_h = _capture_screen()
        if screen is None:
            return None

        # Encode the screenshot as base64 JPEG.
        buffer = io.BytesIO()
        screen.save(buffer, format='JPEG', quality=70)
        image_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

        click_target = target_description or target_text
        prompt = f"click on {click_target}"

        ollama_url = os.environ.get("OLLAMA_URL", "http://localhost:11434")
        model = "0000/ui-tars-1.5-7b-q8_0:7b"

        logger.info(f"[Grounding/UI-TARS] Envoi à {model}: '{prompt}'")

        response = requests.post(
            f"{ollama_url}/api/generate",
            json={
                "model": model,
                "prompt": prompt,
                "images": [image_b64],
                "stream": False,
                "options": {"temperature": 0.1, "num_predict": 50},
            },
            timeout=30,
        )

        if response.status_code != 200:
            logger.warning(f"[Grounding/UI-TARS] HTTP {response.status_code}")
            return None

        result = response.json()
        text = result.get('response', '').strip()
        logger.debug(f"[Grounding/UI-TARS] Réponse brute: {text[:200]}")

        coords = _parse_ui_tars_coordinates(text, screen_w, screen_h)
        if coords:
            x, y = coords
            if 0 <= x <= screen_w and 0 <= y <= screen_h:
                # A coordinate equal to the width/height (e.g. normalised
                # 1000/1000) is one pixel past the last valid pixel: pull it
                # back onto the screen instead of clicking off the edge.
                x = min(x, max(screen_w - 1, 0))
                y = min(y, max(screen_h - 1, 0))
                logger.info(f"[Grounding/UI-TARS] Grounding → ({x}, {y})")
                return {'x': x, 'y': y, 'method': 'ui_tars', 'confidence': 0.85}
            logger.warning(f"[Grounding/UI-TARS] Coordonnées hors écran: ({x}, {y}) pour {screen_w}x{screen_h}")
            return None

        logger.debug(f"[Grounding/UI-TARS] Pas de coordonnées parsées dans: {text[:100]}")
        return None

    except Exception as e:
        logger.debug(f"[Grounding/UI-TARS] Erreur: {e}")
        return None
|
||||
|
||||
|
||||
def _parse_ui_tars_coordinates(text: str, screen_w: int, screen_h: int) -> Optional[tuple]:
|
||||
"""Parse les coordonnées retournées par UI-TARS.
|
||||
|
||||
UI-TARS peut retourner :
|
||||
- Coordonnées normalisées (0-1000) : "click at (500, 300)"
|
||||
- Coordonnées en pixels : "click at (960, 540)"
|
||||
- Format (x, y) ou [x, y] ou x,y
|
||||
- Format "Action: click\nCoordinate: (500, 300)" ou "[500, 300]"
|
||||
|
||||
Returns:
|
||||
(x_pixel, y_pixel) ou None
|
||||
"""
|
||||
import re
|
||||
|
||||
# Chercher des patterns de coordonnées
|
||||
patterns = [
|
||||
r'Coordinate:\s*\[?\(?\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)?\]?',
|
||||
r'click\s+(?:at\s+)?\[?\(?\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)?\]?',
|
||||
r'\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)',
|
||||
r'\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]',
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, text, re.IGNORECASE)
|
||||
if match:
|
||||
raw_x = float(match.group(1))
|
||||
raw_y = float(match.group(2))
|
||||
|
||||
# UI-TARS utilise souvent des coordonnées normalisées 0-1000
|
||||
if raw_x <= 1000 and raw_y <= 1000 and (raw_x > 1 or raw_y > 1):
|
||||
# Probablement normalisées sur 1000
|
||||
x = int(raw_x * screen_w / 1000)
|
||||
y = int(raw_y * screen_h / 1000)
|
||||
elif raw_x <= 1.0 and raw_y <= 1.0:
|
||||
# Normalisées 0-1
|
||||
x = int(raw_x * screen_w)
|
||||
y = int(raw_y * screen_h)
|
||||
else:
|
||||
# Pixels directs
|
||||
x = int(raw_x)
|
||||
y = int(raw_y)
|
||||
|
||||
return (x, y)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _grounding_vlm(target_text: str, target_description: str = "") -> Optional[Dict[str, Any]]:
    """Level 3 — let the VLM name the element, then confirm its position by OCR.

    ``vlm_reason_about_screen`` is asked which element to click; its
    ``'target'`` string is then searched among the OCR words of a fresh
    capture. An exact (case-insensitive) word match is preferred; otherwise
    the first partial match is used, and only when both strings have at least
    3 characters, so that empty or one-letter OCR tokens cannot "confirm" an
    arbitrary target.

    Returns:
        ``{'x', 'y', 'method': 'vlm', 'confidence': 0.75}`` for the confirmed
        word, or ``None`` when the VLM gives no click target, the capture or
        OCR fails, or no word confirms the target.
    """
    try:
        search_label = target_description or target_text

        vlm_result = vlm_reason_about_screen(
            objective=f"Cliquer sur {search_label}",
            context=f"Je cherche l'élément '{target_text}' sur l'écran pour cliquer dessus"
        )

        if not vlm_result:
            logger.debug("[Grounding/VLM] VLM n'a pas retourné de résultat")
            return None

        if vlm_result.get('action') != 'click' or not vlm_result.get('target'):
            logger.debug(f"[Grounding/VLM] VLM action={vlm_result.get('action')}, pas un clic")
            return None

        vlm_target = vlm_result['target']
        logger.info(f"[Grounding/VLM] VLM suggère de cliquer sur: '{vlm_target}'")

        # OCR confirmation: look for the VLM target on a fresh capture.
        screen, screen_w, screen_h = _capture_screen()
        if screen is None:
            return None

        try:
            try:
                from services.ocr_service import ocr_extract_words
            except ImportError:
                from core.extraction.field_extractor import FieldExtractor
                extractor = FieldExtractor()

                def ocr_extract_words(img):
                    return extractor.extract_words_from_image(img)

            words = ocr_extract_words(screen) or []

            vlm_target_lower = str(vlm_target).strip().lower()
            confirmed = None
            for word in words:
                word_lower = word['text'].strip().lower()
                if not word_lower:
                    continue
                if word_lower == vlm_target_lower:
                    confirmed = word
                    break
                if (
                    confirmed is None
                    and len(word_lower) >= 3
                    and len(vlm_target_lower) >= 3
                    and (vlm_target_lower in word_lower or word_lower in vlm_target_lower)
                ):
                    # Remember the first partial match but keep looking for an exact one.
                    confirmed = word

            if confirmed is not None:
                x1, y1, x2, y2 = confirmed['bbox']
                x = int((x1 + x2) / 2)
                y = int((y1 + y2) / 2)
                logger.info(f"[Grounding/VLM] Confirmé par OCR: '{confirmed['text']}' à ({x}, {y})")
                return {'x': x, 'y': y, 'method': 'vlm', 'confidence': 0.75}

            logger.debug(f"[Grounding/VLM] Target VLM '{vlm_target}' non trouvé par OCR")
            return None

        except Exception as e:
            logger.debug(f"[Grounding/VLM] OCR de confirmation échoué: {e}")
            return None

    except Exception as e:
        logger.debug(f"[Grounding/VLM] Erreur: {e}")
        return None
|
||||
|
||||
|
||||
def post_execution_cleanup(execution_mode: str = 'debug'):
    """Check the screen after a workflow and dismiss leftover dialogs.

    Does nothing unless ``execution_mode`` is ``'intelligent'`` or
    ``'debug'``. Otherwise, up to three times: look for a known pattern with
    ``check_screen_for_patterns`` and hand it to ``handle_detected_pattern``.
    As soon as no pattern is found, the VLM is asked once about the screen;
    a suggested click/type is only logged, and the loop stops.
    """
    if execution_mode not in ('intelligent', 'debug'):
        return

    logger.info("Vérification écran final...")
    time.sleep(1.0)

    attempts_left = 3
    while attempts_left > 0:
        attempts_left -= 1

        leftover = check_screen_for_patterns()
        if not leftover:
            advice = vlm_reason_about_screen(
                objective="Vérifier que l'écran est propre après l'exécution",
                context="Le workflow vient de se terminer"
            )
            if advice and advice.get('action') in ('click', 'type'):
                logger.info(f"VLM post-workflow: {advice.get('action')} '{advice.get('target')}'")
            break

        logger.info(f"Dialogue résiduel détecté: {leftover.get('pattern')}")
        handle_detected_pattern(leftover)
        time.sleep(1.0)
|
||||
@@ -40,12 +40,16 @@ class LLMActionHandler:
|
||||
def __init__(
|
||||
self,
|
||||
ollama_endpoint: str = "http://localhost:11434",
|
||||
model: str = "qwen3-vl:8b",
|
||||
model: str = None,
|
||||
temperature: float = 0.1,
|
||||
timeout: int = 120,
|
||||
):
|
||||
self.endpoint = ollama_endpoint.rstrip("/")
|
||||
if model is not None:
|
||||
self.model = model
|
||||
else:
|
||||
from core.detection.vlm_config import get_vlm_model
|
||||
self.model = get_vlm_model()
|
||||
self.temperature = temperature
|
||||
self.timeout = timeout
|
||||
|
||||
|
||||
2006
core/execution/observe_reason_act.py
Normal file
2006
core/execution/observe_reason_act.py
Normal file
File diff suppressed because it is too large
Load Diff
228
core/execution/safe_condition_evaluator.py
Normal file
228
core/execution/safe_condition_evaluator.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Évaluateur de conditions sécurisé pour le DAGExecutor.
|
||||
|
||||
Remplace `eval()` (vulnérable à l'exécution de code arbitraire) par un
|
||||
parseur AST restreint :
|
||||
|
||||
- Seuls les noeuds AST nécessaires sont autorisés (literals, comparaisons,
|
||||
booléens, indexations, accès attribut limité, arithmétique simple).
|
||||
- Les appels de fonction sont interdits.
|
||||
- Les accès à des attributs « dunder » (`__class__`, `__import__`, etc.)
|
||||
sont systématiquement refusés pour éviter les évasions classiques.
|
||||
- Le contexte d'évaluation est fourni explicitement par l'appelant ;
|
||||
aucun builtins n'est exposé.
|
||||
|
||||
Usage typique :
|
||||
>>> evaluator = SafeConditionEvaluator()
|
||||
>>> evaluator.evaluate("results['step_1']['score'] >= 0.8",
|
||||
... {"results": {"step_1": {"score": 0.92}}})
|
||||
True
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import operator
|
||||
from typing import Any, Callable, Dict, Mapping
|
||||
|
||||
|
||||
class UnsafeExpressionError(ValueError):
    """Raised when an expression is rejected by the restricted evaluator."""


# Allowed arithmetic operators.
_BIN_OPS: Dict[type, Callable[[Any, Any], Any]] = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}

# Allowed boolean operators (evaluation itself short-circuits in _eval_node).
_BOOL_OPS: Dict[type, Callable[[Any, Any], Any]] = {
    ast.And: lambda a, b: a and b,
    ast.Or: lambda a, b: a or b,
}

# Allowed unary operators.
_UNARY_OPS: Dict[type, Callable[[Any], Any]] = {
    ast.Not: operator.not_,
    ast.USub: operator.neg,
    ast.UAdd: operator.pos,
}

# Allowed comparison operators.
_CMP_OPS: Dict[type, Callable[[Any, Any], bool]] = {
    ast.Eq: operator.eq,
    ast.NotEq: operator.ne,
    ast.Lt: operator.lt,
    ast.LtE: operator.le,
    ast.Gt: operator.gt,
    ast.GtE: operator.ge,
    ast.In: lambda a, b: a in b,
    ast.NotIn: lambda a, b: a not in b,
    ast.Is: operator.is_,
    ast.IsNot: operator.is_not,
}


class SafeConditionEvaluator:
    """Evaluate a condition expression through a restricted AST walker.

    Only constants, names present in the supplied context, non-underscore
    attribute access, indexing, comparisons, ``and``/``or``, unary and binary
    arithmetic, and tuple/list/set/dict displays are evaluated. Every other
    node type (calls, lambdas, comprehensions, ...) raises
    ``UnsafeExpressionError``.
    """

    # Maximum accepted source length, checked before parsing.
    MAX_EXPRESSION_LENGTH = 1024
    # Upper bound on the bit size of an integer produced by ``**``.
    MAX_POW_RESULT_BITS = 4096
    # Upper bound on the length of a sequence produced by ``seq * n``.
    MAX_SEQUENCE_LENGTH = 100_000

    def evaluate(
        self,
        expression: str,
        context: Mapping[str, Any],
    ) -> Any:
        """Parse and evaluate ``expression`` against ``context``.

        Args:
            expression: Source of a single Python expression.
            context: Names the expression may reference.

        Returns:
            The value of the expression.

        Raises:
            UnsafeExpressionError: If the expression is not a string, is too
                long, cannot be parsed, or uses a forbidden construct.
            Exception: Errors raised by the permitted operations themselves
                (e.g. ``KeyError``, ``ZeroDivisionError``) propagate unchanged.
        """
        if not isinstance(expression, str):
            raise UnsafeExpressionError(
                "L'expression doit être une chaîne de caractères."
            )
        if len(expression) > self.MAX_EXPRESSION_LENGTH:
            raise UnsafeExpressionError(
                "Expression trop longue (> 1024 caractères)."
            )

        try:
            tree = ast.parse(expression, mode="eval")
        except (SyntaxError, ValueError, RecursionError, MemoryError) as exc:
            # ast.parse can also fail with ValueError (e.g. null bytes on some
            # versions) or exhaust the parser on deeply nested input; report
            # all of these as a rejected expression.
            raise UnsafeExpressionError(
                f"Syntaxe d'expression invalide : {exc}"
            ) from exc

        return self._eval_node(tree.body, context)

    # ------------------------------------------------------------------
    # AST dispatch
    # ------------------------------------------------------------------

    def _eval_node(self, node: ast.AST, context: Mapping[str, Any]) -> Any:
        """Recursively evaluate one AST node; reject anything not whitelisted."""
        # Literals.
        if isinstance(node, ast.Constant):
            return node.value

        # Variables: only those present in `context`.
        if isinstance(node, ast.Name):
            if node.id not in context:
                raise UnsafeExpressionError(
                    f"Variable '{node.id}' non autorisée."
                )
            return context[node.id]

        # Attribute access: any name starting with "_" (dunders included) is refused.
        if isinstance(node, ast.Attribute):
            if node.attr.startswith("_"):
                raise UnsafeExpressionError(
                    f"Accès à l'attribut privé '{node.attr}' interdit."
                )
            value = self._eval_node(node.value, context)
            return getattr(value, node.attr)

        # Indexing (results['step_1']).
        if isinstance(node, ast.Subscript):
            value = self._eval_node(node.value, context)
            slice_node = node.slice
            # Older Pythons wrap the index in an `Index` node. Detect it by
            # class name so the deprecated `ast.Index` attribute is never touched.
            if type(slice_node).__name__ == "Index":
                slice_value = self._eval_node(slice_node.value, context)
            else:
                slice_value = self._eval_node(slice_node, context)
            return value[slice_value]

        # Chained comparisons (a < b <= c).
        if isinstance(node, ast.Compare):
            left = self._eval_node(node.left, context)
            for op_node, comparator in zip(node.ops, node.comparators):
                op_cls = type(op_node)
                if op_cls not in _CMP_OPS:
                    raise UnsafeExpressionError(
                        f"Opérateur de comparaison '{op_cls.__name__}' interdit."
                    )
                right = self._eval_node(comparator, context)
                if not _CMP_OPS[op_cls](left, right):
                    return False
                left = right
            return True

        # Boolean and / or with manual short-circuit.
        if isinstance(node, ast.BoolOp):
            op_cls = type(node.op)
            if op_cls not in _BOOL_OPS:
                raise UnsafeExpressionError(
                    f"Opérateur booléen '{op_cls.__name__}' interdit."
                )
            stop_on_truthy = not isinstance(node.op, ast.And)
            result: Any = not stop_on_truthy
            for sub in node.values:
                result = self._eval_node(sub, context)
                if bool(result) == stop_on_truthy:
                    return result
            return result

        # Unary operators (-x, +x, not x).
        if isinstance(node, ast.UnaryOp):
            op_cls = type(node.op)
            if op_cls not in _UNARY_OPS:
                raise UnsafeExpressionError(
                    f"Opérateur unaire '{op_cls.__name__}' interdit."
                )
            return _UNARY_OPS[op_cls](self._eval_node(node.operand, context))

        # Binary operators (+, -, *, /, %, **, //).
        if isinstance(node, ast.BinOp):
            op_cls = type(node.op)
            if op_cls not in _BIN_OPS:
                raise UnsafeExpressionError(
                    f"Opérateur binaire '{op_cls.__name__}' interdit."
                )
            left = self._eval_node(node.left, context)
            right = self._eval_node(node.right, context)
            self._check_bin_op_cost(op_cls, left, right)
            return _BIN_OPS[op_cls](left, right)

        # Composite literals.
        if isinstance(node, ast.Tuple):
            return tuple(self._eval_node(e, context) for e in node.elts)
        if isinstance(node, ast.List):
            return [self._eval_node(e, context) for e in node.elts]
        if isinstance(node, ast.Set):
            return {self._eval_node(e, context) for e in node.elts}
        if isinstance(node, ast.Dict):
            built: Dict[Any, Any] = {}
            for key_node, value_node in zip(node.keys, node.values):
                if key_node is None:
                    # `{**mapping}`: merge the mapping instead of storing it
                    # under a None key.
                    merged = self._eval_node(value_node, context)
                    if not isinstance(merged, Mapping):
                        raise UnsafeExpressionError(
                            "Seul un mapping peut être décompressé avec '**'."
                        )
                    built.update(merged)
                else:
                    key = self._eval_node(key_node, context)
                    built[key] = self._eval_node(value_node, context)
            return built

        # Everything else (Call, Lambda, comprehensions, ...) is refused.
        raise UnsafeExpressionError(
            f"Noeud AST '{type(node).__name__}' interdit dans les conditions."
        )

    def _check_bin_op_cost(self, op_cls: type, left: Any, right: Any) -> None:
        """Reject ``**`` and sequence repetition whose result would be huge.

        A short expression such as ``9 ** 9 ** 9 ** 9`` or ``'a' * 10 ** 9``
        passes the length limit yet can exhaust CPU or memory; the size of the
        result is estimated before the operation is performed.
        """
        if op_cls is ast.Pow:
            if (
                isinstance(left, int)
                and isinstance(right, int)
                and right > 0
                and abs(left) > 1
                and abs(left).bit_length() * right > self.MAX_POW_RESULT_BITS
            ):
                raise UnsafeExpressionError(
                    "Exponentiation trop coûteuse interdite."
                )
        elif op_cls is ast.Mult:
            for seq, count in ((left, right), (right, left)):
                if (
                    isinstance(seq, (str, bytes, list, tuple))
                    and isinstance(count, int)
                    and len(seq) * count > self.MAX_SEQUENCE_LENGTH
                ):
                    raise UnsafeExpressionError(
                        "Répétition de séquence trop coûteuse interdite."
                    )


def safe_eval_condition(
    expression: str,
    context: Mapping[str, Any],
) -> Any:
    """Functional helper: evaluate ``expression`` with the given context."""
    return SafeConditionEvaluator().evaluate(expression, context)


__all__ = [
    "SafeConditionEvaluator",
    "UnsafeExpressionError",
    "safe_eval_condition",
]
|
||||
@@ -1697,12 +1697,6 @@ class TargetResolver:
|
||||
|
||||
return best_elem, tie_break_criterion
|
||||
|
||||
# Spatial analyzer (lazy load) - Exigence 5.3
|
||||
self._spatial_analyzer: Optional[SpatialAnalyzer] = None
|
||||
self._spatial_relations_cache: Dict[str, List[SpatialRelation]] = {}
|
||||
|
||||
logger.info(f"TargetResolver initialized (threshold={similarity_threshold}, spatial={use_spatial_fallback})")
|
||||
|
||||
# =========================================================================
|
||||
# Résolution principale
|
||||
# =========================================================================
|
||||
|
||||
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration Ollama (coherente avec le reste du projet)
|
||||
OLLAMA_DEFAULT_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
|
||||
OLLAMA_DEFAULT_MODEL = os.environ.get("VLM_MODEL", "qwen3-vl:8b")
|
||||
OLLAMA_DEFAULT_MODEL = os.environ.get("RPA_VLM_MODEL", os.environ.get("VLM_MODEL", "gemma4:e4b"))
|
||||
|
||||
|
||||
class FieldExtractor:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
GPU Resource Management Module for RPA Vision V3
|
||||
|
||||
This module provides dynamic GPU resource allocation between ML models:
|
||||
- Ollama VLM (qwen3-vl:8b) for UI classification
|
||||
- Ollama VLM (gemma4:e4b par défaut, configurable via RPA_VLM_MODEL) for UI classification
|
||||
- CLIP (ViT-B-32) for embedding matching
|
||||
|
||||
The GPUResourceManager optimizes VRAM usage by:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
GPU Resource Manager - Central orchestrator for GPU resource allocation
|
||||
|
||||
Manages dynamic allocation of GPU resources between:
|
||||
- Ollama VLM (qwen3-vl:8b) - ~10.5 GB VRAM for UI classification
|
||||
- Ollama VLM (gemma4:e4b par défaut) - ~10 GB VRAM for UI classification
|
||||
- CLIP (ViT-B-32) - ~500 MB VRAM for embedding matching
|
||||
|
||||
Optimizes VRAM usage based on execution mode:
|
||||
@@ -12,13 +12,14 @@ Optimizes VRAM usage based on execution mode:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -53,7 +54,7 @@ class VRAMInfo:
|
||||
class GPUResourceConfig:
|
||||
"""Configuration for GPU resource management."""
|
||||
ollama_endpoint: str = "http://localhost:11434"
|
||||
vlm_model: str = "qwen3-vl:8b"
|
||||
vlm_model: str = "gemma4:e4b"
|
||||
clip_model: str = "ViT-B-32"
|
||||
idle_timeout_seconds: int = 300 # 5 minutes
|
||||
vram_threshold_for_clip_gpu_mb: int = 1024 # 1 GB
|
||||
@@ -127,6 +128,12 @@ class GPUResourceManager:
|
||||
self._operation_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._operation_lock = asyncio.Lock()
|
||||
|
||||
# Lock d'inférence synchrone : sérialise les appels GPU concurrents
|
||||
# (ScreenAnalyzer.analyze, UIDetector, CLIP.encode) entre
|
||||
# ExecutionLoop et stream_processor pour éviter la saturation VRAM
|
||||
# sur RTX 5070 (12 Go). Un seul analyze à la fois sur le GPU.
|
||||
self._inference_lock = threading.Lock()
|
||||
|
||||
# Event callbacks
|
||||
self._on_resource_changed: List[Callable[[ResourceChangedEvent], None]] = []
|
||||
self._on_mode_changed: List[Callable[[ExecutionMode], None]] = []
|
||||
@@ -208,6 +215,44 @@ class GPUResourceManager:
|
||||
"""Get the current execution mode."""
|
||||
return self._execution_mode
|
||||
|
||||
# =========================================================================
|
||||
# Inference serialization (sync)
|
||||
# =========================================================================
|
||||
|
||||
@contextlib.contextmanager
def acquire_inference(self, timeout: Optional[float] = None) -> Iterator[bool]:
    """Synchronous context manager that serializes GPU inference calls.

    Wraps ``self._inference_lock`` so that only one holder at a time runs
    the guarded section. The lock is released on exit only if it was
    actually obtained.

    Args:
        timeout: Maximum time to wait for the lock, in seconds.
            ``None`` blocks until the lock is available.

    Yields:
        ``True`` when the lock was acquired, ``False`` when the wait timed
        out. On ``False`` the caller does NOT hold the lock and should skip
        (or defer) the inference.

    Example:
        >>> with gpu_manager.acquire_inference(timeout=30.0) as acquired:
        ...     if acquired:
        ...         state = analyzer.analyze(path)
        ...     else:
        ...         logger.warning("GPU lock timeout")
    """
    lock = self._inference_lock
    # A plain acquire() blocks indefinitely and always reports success.
    got_lock = lock.acquire() if timeout is None else lock.acquire(timeout=timeout)

    try:
        yield got_lock
    finally:
        # Never release a lock we failed to obtain (timeout path).
        if got_lock:
            lock.release()
|
||||
|
||||
# =========================================================================
|
||||
# VLM Management
|
||||
# =========================================================================
|
||||
|
||||
@@ -32,7 +32,7 @@ class OllamaManager:
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str = "http://localhost:11434",
|
||||
model: str = "qwen3-vl:8b",
|
||||
model: str = "gemma4:e4b",
|
||||
default_keep_alive: str = "5m"
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -173,6 +173,10 @@ class GraphBuilder:
|
||||
clustering_eps: float = 0.08,
|
||||
clustering_min_samples: int = 2,
|
||||
enable_quality_validation: bool = True,
|
||||
ui_detector: Optional[Any] = None,
|
||||
screen_analyzer: Optional[Any] = None,
|
||||
enable_ui_enrichment: bool = True,
|
||||
element_proximity_max_px: float = 50.0,
|
||||
):
|
||||
"""
|
||||
Initialiser le GraphBuilder.
|
||||
@@ -185,6 +189,17 @@ class GraphBuilder:
|
||||
clustering_eps: Epsilon pour DBSCAN (distance max entre points)
|
||||
clustering_min_samples: Nombre minimum d'échantillons pour un cluster
|
||||
enable_quality_validation: Activer la validation de qualité
|
||||
ui_detector: UIDetector optionnel. Si fourni, sera utilisé par
|
||||
l'analyzer lazy-initialisé. Sinon, fallback sur le singleton
|
||||
partagé (`get_screen_analyzer()`).
|
||||
screen_analyzer: Instance ScreenAnalyzer à utiliser directement.
|
||||
Si None, lazy init via le singleton partagé C1.
|
||||
enable_ui_enrichment: Active l'enrichissement visuel des
|
||||
ScreenStates lors de `_create_screen_states` (OCR + UIDetector).
|
||||
False = comportement historique (ui_elements=[], detected_text=[]).
|
||||
element_proximity_max_px: Distance maximale (en pixels) entre un
|
||||
clic et le bbox le plus proche pour qu'un UIElement soit
|
||||
considéré comme cible. Au-delà, le clic reste sans ancre.
|
||||
"""
|
||||
self.embedding_builder = embedding_builder or StateEmbeddingBuilder()
|
||||
self.faiss_manager = faiss_manager
|
||||
@@ -193,22 +208,73 @@ class GraphBuilder:
|
||||
self.clustering_eps = clustering_eps
|
||||
self.clustering_min_samples = clustering_min_samples
|
||||
self.enable_quality_validation = enable_quality_validation
|
||||
self._screen_analyzer = None # ScreenAnalyzer (lazy import)
|
||||
self.enable_ui_enrichment = enable_ui_enrichment
|
||||
self.element_proximity_max_px = element_proximity_max_px
|
||||
# UIDetector explicite (optionnel) — injecté dans l'analyzer lazy.
|
||||
self._ui_detector = ui_detector
|
||||
# Instance ScreenAnalyzer. Si fournie, on l'utilise telle quelle ;
|
||||
# sinon, on bascule sur le singleton partagé (lazy init).
|
||||
self._screen_analyzer = screen_analyzer
|
||||
|
||||
logger.info(
|
||||
f"GraphBuilder initialized: "
|
||||
f"min_repetitions={min_pattern_repetitions}, "
|
||||
f"eps={clustering_eps}, "
|
||||
f"min_samples={clustering_min_samples}, "
|
||||
f"quality_validation={enable_quality_validation}"
|
||||
f"quality_validation={enable_quality_validation}, "
|
||||
f"ui_enrichment={enable_ui_enrichment}"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Résolution paresseuse du ScreenAnalyzer (singleton C1 par défaut)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_screen_analyzer(self):
|
||||
"""
|
||||
Retourner l'instance ScreenAnalyzer à utiliser.
|
||||
|
||||
Priorité :
|
||||
1. Instance injectée via le constructeur (`screen_analyzer=…`).
|
||||
2. Singleton partagé `get_screen_analyzer()` (C1) — évite le double
|
||||
chargement GPU quand ExecutionLoop et stream_processor tournent.
|
||||
3. En dernier recours (import circulaire, tests), création locale.
|
||||
"""
|
||||
if self._screen_analyzer is not None:
|
||||
return self._screen_analyzer
|
||||
|
||||
try:
|
||||
from core.pipeline import get_screen_analyzer
|
||||
|
||||
self._screen_analyzer = get_screen_analyzer(
|
||||
ui_detector=self._ui_detector,
|
||||
)
|
||||
return self._screen_analyzer
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Impossible d'obtenir le ScreenAnalyzer singleton "
|
||||
f"({e}); fallback sur une instance locale."
|
||||
)
|
||||
try:
|
||||
from core.pipeline.screen_analyzer import ScreenAnalyzer
|
||||
|
||||
self._screen_analyzer = ScreenAnalyzer(
|
||||
ui_detector=self._ui_detector,
|
||||
)
|
||||
return self._screen_analyzer
|
||||
except Exception as e2:
|
||||
logger.error(
|
||||
f"Impossible d'instancier ScreenAnalyzer: {e2}. "
|
||||
"Enrichissement UI désactivé."
|
||||
)
|
||||
return None
|
||||
|
||||
def build_from_session(
|
||||
self,
|
||||
session: RawSession,
|
||||
workflow_name: Optional[str] = None,
|
||||
precomputed_states: Optional[List["ScreenState"]] = None,
|
||||
precomputed_embeddings: Optional[List] = None,
|
||||
sequential: bool = False,
|
||||
) -> Workflow:
|
||||
"""
|
||||
Construire un Workflow complet depuis une RawSession.
|
||||
@@ -216,7 +282,7 @@ class GraphBuilder:
|
||||
Processus:
|
||||
1. Créer ScreenStates depuis screenshots (ou utiliser precomputed_states)
|
||||
2. Calculer embeddings pour chaque état (ou réutiliser precomputed_embeddings)
|
||||
3. Détecter patterns via clustering
|
||||
3. Détecter patterns via clustering (ou mode séquentiel)
|
||||
4. Construire nodes depuis clusters
|
||||
5. Construire edges depuis transitions
|
||||
|
||||
@@ -228,6 +294,10 @@ class GraphBuilder:
|
||||
precomputed_embeddings: Embeddings déjà calculés (streaming).
|
||||
Si fourni et de la bonne longueur (= len(screen_states)),
|
||||
saute l'étape 2 (pas de recalcul CLIP).
|
||||
sequential: Si True, crée un node par état d'écran (pas de
|
||||
clustering DBSCAN). Approprié pour les enregistrements
|
||||
single-pass d'un workflow — chaque screenshot est une étape
|
||||
distincte avec ses actions associées.
|
||||
|
||||
Returns:
|
||||
Workflow construit avec nodes et edges
|
||||
@@ -242,6 +312,7 @@ class GraphBuilder:
|
||||
f"Building workflow from session {session.session_id} "
|
||||
f"with {len(precomputed_states or session.screenshots)} "
|
||||
f"{'precomputed states' if precomputed_states else 'screenshots'}"
|
||||
f"{' (mode séquentiel)' if sequential else ''}"
|
||||
)
|
||||
|
||||
# Étape 1: Créer ScreenStates (ou réutiliser ceux pré-calculés)
|
||||
@@ -266,7 +337,16 @@ class GraphBuilder:
|
||||
embeddings = self._compute_embeddings(screen_states)
|
||||
logger.debug(f"Computed {len(embeddings)} embeddings")
|
||||
|
||||
# Étape 3: Détecter patterns
|
||||
# Étape 3: Détecter patterns ou mode séquentiel
|
||||
if sequential:
|
||||
# Mode séquentiel : chaque état d'écran est un node distinct.
|
||||
# Pas de clustering — essentiel pour les enregistrements single-pass
|
||||
# où l'on veut reproduire fidèlement la séquence des actions.
|
||||
clusters = {i: [i] for i in range(len(screen_states))}
|
||||
logger.info(
|
||||
f"Mode séquentiel: {len(clusters)} nodes (1 par état)"
|
||||
)
|
||||
else:
|
||||
clusters = self._detect_patterns(embeddings, screen_states)
|
||||
logger.info(f"Detected {len(clusters)} patterns")
|
||||
|
||||
@@ -275,7 +355,10 @@ class GraphBuilder:
|
||||
logger.info(f"Built {len(nodes)} workflow nodes")
|
||||
|
||||
# Étape 5: Construire edges (passer les embeddings pour éviter recalcul)
|
||||
edges = self._build_edges(nodes, screen_states, session, embeddings=embeddings)
|
||||
edges = self._build_edges(
|
||||
nodes, screen_states, session, embeddings=embeddings,
|
||||
sequential=sequential,
|
||||
)
|
||||
logger.info(f"Built {len(edges)} workflow edges")
|
||||
|
||||
# Créer Workflow
|
||||
@@ -395,11 +478,28 @@ class GraphBuilder:
|
||||
if event.screenshot_id:
|
||||
screenshot_to_event[event.screenshot_id] = event
|
||||
|
||||
# Récupérer (une seule fois) l'analyzer partagé si l'enrichissement est actif.
|
||||
# Le singleton C1 garantit qu'on ne recharge pas UIDetector/CLIP inutilement.
|
||||
analyzer = None
|
||||
if self.enable_ui_enrichment:
|
||||
analyzer = self._get_screen_analyzer()
|
||||
|
||||
# Cache partagé (C1) : réutiliser les analyses si même screenshot est
|
||||
# repassé plusieurs fois (peu fréquent en construction, utile en tests).
|
||||
try:
|
||||
from core.pipeline import get_screen_state_cache
|
||||
|
||||
state_cache = get_screen_state_cache()
|
||||
except Exception as e:
|
||||
logger.debug(f"ScreenStateCache indisponible ({e}); aucun cache utilisé.")
|
||||
state_cache = None
|
||||
|
||||
enriched_count = 0
|
||||
for i, screenshot in enumerate(session.screenshots):
|
||||
# Trouver l'événement associé
|
||||
event = screenshot_to_event.get(screenshot.screenshot_id)
|
||||
|
||||
# Créer WindowContext depuis l'événement
|
||||
# Construire WindowContext depuis l'événement (si dispo)
|
||||
screen_env = session.environment.get("screen", {})
|
||||
screen_res = screen_env.get("primary_resolution", [1920, 1080])
|
||||
if event and event.window:
|
||||
@@ -427,59 +527,127 @@ class GraphBuilder:
|
||||
os_language=session.environment.get("os_language", "unknown"),
|
||||
)
|
||||
|
||||
# Créer RawLevel
|
||||
# Construire chemin absolu : data/training/sessions/{session_id}/{session_id}/{relative_path}
|
||||
screenshot_absolute_path = f"data/training/sessions/{session.session_id}/{session.session_id}/{screenshot.relative_path}"
|
||||
# Chemin absolu du screenshot
|
||||
screenshot_absolute_path = (
|
||||
f"data/training/sessions/{session.session_id}/"
|
||||
f"{session.session_id}/{screenshot.relative_path}"
|
||||
)
|
||||
screenshot_path = Path(screenshot_absolute_path)
|
||||
|
||||
# Timestamp
|
||||
if isinstance(screenshot.captured_at, str):
|
||||
timestamp = datetime.fromisoformat(
|
||||
screenshot.captured_at.replace('Z', '+00:00')
|
||||
)
|
||||
else:
|
||||
timestamp = screenshot.captured_at
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Enrichissement visuel : déléguer au ScreenAnalyzer partagé
|
||||
# ------------------------------------------------------------
|
||||
# L'analyzer renvoie un ScreenState complet avec :
|
||||
# - raw (image + file_size)
|
||||
# - perception (OCR + embedding ref)
|
||||
# - ui_elements (détection UIDetector)
|
||||
# On récupère ces niveaux et on rebâtit un état final avec le
|
||||
# WindowContext et les metadata issus de la session brute (les
|
||||
# données "metier" que l'analyzer ignore).
|
||||
# ------------------------------------------------------------
|
||||
detected_text: List[str] = []
|
||||
text_method = "none"
|
||||
ui_elements: List = []
|
||||
raw = RawLevel(
|
||||
screenshot_path=str(screenshot_path),
|
||||
capture_method="mss",
|
||||
file_size_bytes=screenshot_path.stat().st_size if screenshot_path.exists() else 0
|
||||
file_size_bytes=(
|
||||
screenshot_path.stat().st_size
|
||||
if screenshot_path.exists()
|
||||
else 0
|
||||
),
|
||||
)
|
||||
|
||||
# Créer PerceptionLevel — enrichir avec OCR si le screenshot existe
|
||||
detected_text = []
|
||||
text_method = "none"
|
||||
|
||||
if screenshot_path.exists():
|
||||
if analyzer is not None and screenshot_path.exists():
|
||||
try:
|
||||
if self._screen_analyzer is None:
|
||||
from core.pipeline.screen_analyzer import ScreenAnalyzer
|
||||
self._screen_analyzer = ScreenAnalyzer(session_id=session.session_id)
|
||||
extracted = self._screen_analyzer._extract_text(str(screenshot_path))
|
||||
if extracted:
|
||||
detected_text = extracted
|
||||
text_method = self._screen_analyzer._get_ocr_method_name()
|
||||
except Exception as e:
|
||||
logger.debug(f"OCR échoué pour {screenshot_path}: {e}")
|
||||
# Construire l'info fenêtre pour donner le contexte à
|
||||
# l'UIDetector (certains détecteurs s'en servent pour
|
||||
# filtrer hors-fenêtre).
|
||||
window_info = {
|
||||
"app_name": window.app_name,
|
||||
"title": window.window_title,
|
||||
"screen_resolution": list(window.screen_resolution or []),
|
||||
}
|
||||
|
||||
analyzed = analyzer.analyze(
|
||||
str(screenshot_path),
|
||||
window_info=window_info,
|
||||
enable_ocr=True,
|
||||
enable_ui_detection=True,
|
||||
session_id=session.session_id,
|
||||
)
|
||||
detected_text = list(analyzed.perception.detected_text or [])
|
||||
text_method = (
|
||||
analyzed.perception.text_detection_method or "none"
|
||||
)
|
||||
ui_elements = list(analyzed.ui_elements or [])
|
||||
# Garder les métriques OCR/UI si présentes (debug)
|
||||
analyzer_metadata = dict(analyzed.metadata or {})
|
||||
raw = analyzed.raw # conserver file_size réel mesuré
|
||||
if ui_elements:
|
||||
enriched_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Enrichissement visuel échoué pour {screenshot_path}: {e}. "
|
||||
"Fallback sur ScreenState minimal."
|
||||
)
|
||||
analyzer_metadata = {"analyzer_error": str(e)}
|
||||
else:
|
||||
analyzer_metadata = {}
|
||||
if self.enable_ui_enrichment and not screenshot_path.exists():
|
||||
logger.debug(
|
||||
f"Screenshot introuvable: {screenshot_path} "
|
||||
"— ui_elements restera vide"
|
||||
)
|
||||
|
||||
# PerceptionLevel : vector_id calculé de façon déterministe.
|
||||
perception = PerceptionLevel(
|
||||
embedding=EmbeddingRef(
|
||||
provider="openclip_ViT-B-32",
|
||||
vector_id=f"data/embeddings/screens/{session.session_id}_state_{i:04d}.npy",
|
||||
dimensions=512
|
||||
vector_id=(
|
||||
f"data/embeddings/screens/"
|
||||
f"{session.session_id}_state_{i:04d}.npy"
|
||||
),
|
||||
dimensions=512,
|
||||
),
|
||||
detected_text=detected_text,
|
||||
text_detection_method=text_method,
|
||||
confidence_avg=0.85 if detected_text else 0.0
|
||||
confidence_avg=0.85 if detected_text else 0.0,
|
||||
)
|
||||
|
||||
# Créer ContextLevel
|
||||
# ContextLevel (métier)
|
||||
context = ContextLevel(
|
||||
current_workflow_candidate=None,
|
||||
workflow_step=i,
|
||||
user_id=session.user.get("id", "unknown"),
|
||||
tags=list(session.context.get("tags", [])) if isinstance(session.context.get("tags"), list) else [],
|
||||
business_variables={}
|
||||
tags=(
|
||||
list(session.context.get("tags", []))
|
||||
if isinstance(session.context.get("tags"), list)
|
||||
else []
|
||||
),
|
||||
business_variables={},
|
||||
)
|
||||
|
||||
# Parser timestamp
|
||||
if isinstance(screenshot.captured_at, str):
|
||||
timestamp = datetime.fromisoformat(screenshot.captured_at.replace('Z', '+00:00'))
|
||||
else:
|
||||
timestamp = screenshot.captured_at
|
||||
# Metadata : on garde le lien événement/session + éventuels
|
||||
# compteurs remontés par l'analyzer.
|
||||
metadata = {
|
||||
"screenshot_id": screenshot.screenshot_id,
|
||||
"event_type": event.type if event else None,
|
||||
"event_time": event.t if event else None,
|
||||
}
|
||||
# Propager les indicateurs utiles de l'analyzer sans écraser la base.
|
||||
for key in ("ocr_ms", "ui_ms", "analyzer_error"):
|
||||
if key in analyzer_metadata:
|
||||
metadata[key] = analyzer_metadata[key]
|
||||
|
||||
# Créer ScreenState complet
|
||||
state = ScreenState(
|
||||
screen_state_id=f"{session.session_id}_state_{i:04d}",
|
||||
timestamp=timestamp,
|
||||
@@ -488,17 +656,17 @@ class GraphBuilder:
|
||||
raw=raw,
|
||||
perception=perception,
|
||||
context=context,
|
||||
metadata={
|
||||
"screenshot_id": screenshot.screenshot_id,
|
||||
"event_type": event.type if event else None,
|
||||
"event_time": event.t if event else None
|
||||
},
|
||||
ui_elements=[] # Sera rempli par UIDetector si disponible
|
||||
metadata=metadata,
|
||||
ui_elements=ui_elements,
|
||||
)
|
||||
|
||||
screen_states.append(state)
|
||||
|
||||
logger.info(f"Created {len(screen_states)} enriched screen states")
|
||||
logger.info(
|
||||
f"Created {len(screen_states)} enriched screen states "
|
||||
f"({enriched_count} avec UI détectée, "
|
||||
f"ui_enrichment={self.enable_ui_enrichment})"
|
||||
)
|
||||
return screen_states
|
||||
|
||||
def _compute_embeddings(
|
||||
@@ -924,6 +1092,99 @@ class GraphBuilder:
|
||||
constraints.sort(key=lambda c: role_counts.get(c.get("role", ""), 0), reverse=True)
|
||||
return constraints[:8]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Association spatiale clic → UIElement
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _find_clicked_element(
|
||||
self,
|
||||
event: Event,
|
||||
ui_elements: List[Any],
|
||||
) -> Optional[Any]:
|
||||
"""
|
||||
Identifier l'UIElement cible d'un clic par proximité spatiale.
|
||||
|
||||
Règle :
|
||||
1. Si un bbox contient strictement la position du clic → match.
|
||||
2. Sinon, on prend le bbox le plus proche (distance euclidienne
|
||||
au bord) sous réserve qu'il soit à <= `element_proximity_max_px`.
|
||||
3. Sinon, aucun ancrage possible → None.
|
||||
|
||||
Cette association transforme un clic "aveugle" (coordonnées brutes)
|
||||
en un clic "intelligent" (rôle + label), permettant au matcher de
|
||||
retrouver l'élément même si la résolution ou la position change.
|
||||
|
||||
Args:
|
||||
event: Événement `mouse_click` (avec `data["pos"] = [x, y]`).
|
||||
ui_elements: Liste des UIElement détectés sur l'écran source.
|
||||
|
||||
Returns:
|
||||
UIElement le plus pertinent, ou None si rien ne correspond.
|
||||
"""
|
||||
if not ui_elements:
|
||||
return None
|
||||
if not event or event.type != "mouse_click":
|
||||
return None
|
||||
|
||||
pos = event.data.get("pos") if event.data else None
|
||||
if not pos or len(pos) < 2:
|
||||
return None
|
||||
|
||||
try:
|
||||
click_x = float(pos[0])
|
||||
click_y = float(pos[1])
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
best_contained = None
|
||||
best_contained_area = None
|
||||
best_near = None
|
||||
best_near_distance = None
|
||||
|
||||
for element in ui_elements:
|
||||
bbox = getattr(element, "bbox", None)
|
||||
if bbox is None:
|
||||
continue
|
||||
|
||||
# Extraction défensive des coordonnées (BBox Pydantic ou tuple)
|
||||
try:
|
||||
bx = int(getattr(bbox, "x", bbox[0]))
|
||||
by = int(getattr(bbox, "y", bbox[1]))
|
||||
bw = int(getattr(bbox, "width", bbox[2]))
|
||||
bh = int(getattr(bbox, "height", bbox[3]))
|
||||
except (AttributeError, IndexError, TypeError):
|
||||
continue
|
||||
|
||||
# Cas 1 : la position est strictement dans le bbox.
|
||||
if bx <= click_x <= bx + bw and by <= click_y <= by + bh:
|
||||
# Sélectionner le plus petit bbox qui contient (élément le plus spécifique)
|
||||
area = max(1, bw * bh)
|
||||
if best_contained is None or area < best_contained_area:
|
||||
best_contained = element
|
||||
best_contained_area = area
|
||||
continue
|
||||
|
||||
# Cas 2 : calculer la distance au bord le plus proche.
|
||||
dx = max(bx - click_x, 0, click_x - (bx + bw))
|
||||
dy = max(by - click_y, 0, click_y - (by + bh))
|
||||
distance = (dx * dx + dy * dy) ** 0.5
|
||||
|
||||
if best_near is None or distance < best_near_distance:
|
||||
best_near = element
|
||||
best_near_distance = distance
|
||||
|
||||
if best_contained is not None:
|
||||
return best_contained
|
||||
|
||||
if (
|
||||
best_near is not None
|
||||
and best_near_distance is not None
|
||||
and best_near_distance <= self.element_proximity_max_px
|
||||
):
|
||||
return best_near
|
||||
|
||||
return None
|
||||
|
||||
# Patterns d'erreur courants pour la détection fail_fast
|
||||
_ERROR_PATTERNS = [
|
||||
"erreur", "error", "échec", "failed", "impossible",
|
||||
@@ -937,12 +1198,14 @@ class GraphBuilder:
|
||||
screen_states: List[ScreenState],
|
||||
session: RawSession,
|
||||
embeddings: Optional[List[np.ndarray]] = None,
|
||||
sequential: bool = False,
|
||||
) -> List[WorkflowEdge]:
|
||||
"""
|
||||
Construire WorkflowEdges depuis les transitions observées.
|
||||
|
||||
Algorithme:
|
||||
1. Mapper chaque ScreenState vers son node (via embedding similarity)
|
||||
En mode séquentiel, le mapping est direct (state i → node i).
|
||||
2. Identifier les transitions (state_i -> state_j où node change)
|
||||
3. Extraire l'action depuis l'événement entre les deux états
|
||||
4. Créer WorkflowEdge avec action, pré-conditions et post-conditions
|
||||
@@ -960,6 +1223,7 @@ class GraphBuilder:
|
||||
screen_states: ScreenStates
|
||||
session: Session brute (pour événements)
|
||||
embeddings: Embeddings pré-calculés (évite un recalcul dans _map_states_to_nodes)
|
||||
sequential: Mode séquentiel — chaque paire consécutive = transition
|
||||
|
||||
Returns:
|
||||
Liste de WorkflowEdges
|
||||
@@ -975,7 +1239,19 @@ class GraphBuilder:
|
||||
node_by_id = {node.node_id: node for node in nodes}
|
||||
|
||||
# Étape 1: Mapper chaque état vers son node
|
||||
state_to_node = self._map_states_to_nodes(screen_states, nodes, embeddings=embeddings)
|
||||
if sequential:
|
||||
# Mode séquentiel : mapping direct state[i] → node[i]
|
||||
state_to_node = {}
|
||||
for i, state in enumerate(screen_states):
|
||||
if i < len(nodes):
|
||||
state_to_node[state.screen_state_id] = nodes[i].node_id
|
||||
logger.debug(
|
||||
f"Mode séquentiel: {len(state_to_node)} states mappés directement"
|
||||
)
|
||||
else:
|
||||
state_to_node = self._map_states_to_nodes(
|
||||
screen_states, nodes, embeddings=embeddings
|
||||
)
|
||||
|
||||
# Étape 2: Récupérer la résolution d'écran pour normaliser les coordonnées
|
||||
screen_env = session.environment.get("screen", {})
|
||||
@@ -989,8 +1265,11 @@ class GraphBuilder:
|
||||
current_node_id = state_to_node.get(current_state.screen_state_id)
|
||||
next_node_id = state_to_node.get(next_state.screen_state_id)
|
||||
|
||||
# Si les deux états sont dans des nodes différents, c'est une transition
|
||||
if current_node_id and next_node_id and current_node_id != next_node_id:
|
||||
# En mode séquentiel, chaque paire consécutive est une transition
|
||||
# En mode clustering, uniquement si les nodes sont différents
|
||||
if current_node_id and next_node_id and (
|
||||
sequential or current_node_id != next_node_id
|
||||
):
|
||||
# Trouver TOUS les événements entre les deux états
|
||||
transition_events = self._find_transition_events(
|
||||
current_state, next_state, session.events
|
||||
@@ -1012,6 +1291,7 @@ class GraphBuilder:
|
||||
target_node=target_node,
|
||||
all_events=transition_events,
|
||||
screen_resolution=screen_resolution,
|
||||
source_state=current_state,
|
||||
)
|
||||
edges.append(edge)
|
||||
|
||||
@@ -1094,6 +1374,32 @@ class GraphBuilder:
|
||||
|
||||
return state_to_node
|
||||
|
||||
def _get_state_time(self, state: ScreenState, fallback: float = 0) -> float:
|
||||
"""Extraire le timestamp d'un ScreenState.
|
||||
|
||||
Priorité :
|
||||
1. metadata['event_time'] (set par _create_screen_states)
|
||||
2. metadata['shot_timestamp'] (set par le reprocessing)
|
||||
3. state.timestamp converti en epoch si c'est un datetime
|
||||
4. fallback
|
||||
|
||||
Note : event_time peut être 0.0 (timestamps relatifs), donc on
|
||||
vérifie `is not None` et non `> 0`.
|
||||
"""
|
||||
if state.metadata:
|
||||
et = state.metadata.get("event_time")
|
||||
if et is not None:
|
||||
return float(et)
|
||||
st = state.metadata.get("shot_timestamp")
|
||||
if st is not None:
|
||||
return float(st)
|
||||
if state.timestamp:
|
||||
try:
|
||||
return state.timestamp.timestamp()
|
||||
except (AttributeError, OSError):
|
||||
pass
|
||||
return fallback
|
||||
|
||||
def _find_transition_events(
|
||||
self,
|
||||
current_state: ScreenState,
|
||||
@@ -1108,6 +1414,9 @@ class GraphBuilder:
|
||||
C'est essentiel pour le replay : une transition peut nécessiter
|
||||
plusieurs actions (ex: Win+R → taper "notepad" → Entrée).
|
||||
|
||||
Timestamps : utilise _get_state_time() qui supporte plusieurs
|
||||
sources (event_time, shot_timestamp, datetime).
|
||||
|
||||
Args:
|
||||
current_state: État source
|
||||
next_state: État cible
|
||||
@@ -1117,8 +1426,8 @@ class GraphBuilder:
|
||||
Liste ordonnée (par timestamp) de tous les événements d'action
|
||||
entre les deux états. Peut être vide.
|
||||
"""
|
||||
current_time = current_state.metadata.get("event_time", 0)
|
||||
next_time = next_state.metadata.get("event_time", float('inf'))
|
||||
current_time = self._get_state_time(current_state, fallback=0)
|
||||
next_time = self._get_state_time(next_state, fallback=float('inf'))
|
||||
|
||||
action_events = []
|
||||
for event in events:
|
||||
@@ -1155,6 +1464,7 @@ class GraphBuilder:
|
||||
target_node: Optional[WorkflowNode] = None,
|
||||
all_events: Optional[List[Event]] = None,
|
||||
screen_resolution: Tuple[int, int] = (1920, 1080),
|
||||
source_state: Optional[ScreenState] = None,
|
||||
) -> WorkflowEdge:
|
||||
"""
|
||||
Créer un WorkflowEdge depuis une transition observée.
|
||||
@@ -1180,12 +1490,24 @@ class GraphBuilder:
|
||||
# Si on a plusieurs événements, créer une action compound
|
||||
events_to_use = all_events or ([event] if event else [])
|
||||
|
||||
# UIElements de l'écran source — sert à ancrer les clics sur un vrai
|
||||
# élément UI (rôle, texte, bbox) plutôt que sur une coordonnée brute.
|
||||
source_ui_elements = (
|
||||
list(source_state.ui_elements)
|
||||
if source_state and source_state.ui_elements
|
||||
else []
|
||||
)
|
||||
|
||||
if len(events_to_use) > 1:
|
||||
action = self._build_compound_action(
|
||||
events_to_use, screen_resolution
|
||||
events_to_use, screen_resolution,
|
||||
source_ui_elements=source_ui_elements,
|
||||
)
|
||||
elif len(events_to_use) == 1:
|
||||
action = self._build_single_action(events_to_use[0])
|
||||
action = self._build_single_action(
|
||||
events_to_use[0],
|
||||
source_ui_elements=source_ui_elements,
|
||||
)
|
||||
else:
|
||||
action = Action(
|
||||
type="unknown",
|
||||
@@ -1235,15 +1557,29 @@ class GraphBuilder:
|
||||
metadata=edge_metadata,
|
||||
)
|
||||
|
||||
def _build_single_action(self, event: Event) -> Action:
|
||||
def _build_single_action(
|
||||
self,
|
||||
event: Event,
|
||||
source_ui_elements: Optional[List[Any]] = None,
|
||||
) -> Action:
|
||||
"""
|
||||
Construire une Action simple depuis un seul événement.
|
||||
|
||||
Rétrocompatible avec l'ancien format : un type d'action direct
|
||||
(mouse_click, key_press, text_input) avec ses paramètres.
|
||||
Pour un clic, si `source_ui_elements` est fourni, on tente d'ancrer
|
||||
l'action sur l'UIElement le plus proche (par proximité spatiale).
|
||||
Le TargetSpec devient alors discriminant :
|
||||
- `by_role` = rôle sémantique de l'élément (ex: "primary_action")
|
||||
- `by_text` = label détecté (ex: "Valider")
|
||||
- `selection_policy` = "by_similarity" (laisse le matcher scorer)
|
||||
- `context_hints["anchor_element_id"]` = traçabilité
|
||||
- `context_hints["anchor_bbox"]` = invariant spatial debug
|
||||
|
||||
À défaut d'ancrage (pas d'UIElement ou clic hors de toute bbox
|
||||
proche), on retombe sur `by_role="unknown_element"` (legacy).
|
||||
"""
|
||||
action_type = event.type
|
||||
action_params = {}
|
||||
action_params: Dict[str, Any] = {}
|
||||
target_spec: Optional[TargetSpec] = None
|
||||
|
||||
if action_type == "mouse_click":
|
||||
action_params = {
|
||||
@@ -1251,39 +1587,111 @@ class GraphBuilder:
|
||||
"position": event.data.get("pos", [0, 0]),
|
||||
"wait_after_ms": 500,
|
||||
}
|
||||
target_role = "unknown_element"
|
||||
target_spec = self._build_click_target_spec(
|
||||
event, source_ui_elements or []
|
||||
)
|
||||
|
||||
elif action_type == "key_press":
|
||||
action_params = {
|
||||
"keys": event.data.get("keys", []),
|
||||
"wait_after_ms": 200,
|
||||
}
|
||||
target_role = "keyboard_input"
|
||||
target_spec = TargetSpec(
|
||||
by_role="keyboard_input",
|
||||
selection_policy="first",
|
||||
fallback_strategy="visual_similarity",
|
||||
)
|
||||
|
||||
elif action_type == "text_input":
|
||||
action_params = {
|
||||
"text": event.data.get("text", ""),
|
||||
"wait_after_ms": 300,
|
||||
}
|
||||
target_role = "text_field"
|
||||
target_spec = TargetSpec(
|
||||
by_role="text_field",
|
||||
selection_policy="first",
|
||||
fallback_strategy="visual_similarity",
|
||||
)
|
||||
else:
|
||||
action_params = {}
|
||||
target_role = "unknown"
|
||||
target_spec = TargetSpec(
|
||||
by_role="unknown",
|
||||
selection_policy="first",
|
||||
fallback_strategy="visual_similarity",
|
||||
)
|
||||
|
||||
return Action(
|
||||
type=action_type,
|
||||
target=TargetSpec(
|
||||
by_role=target_role,
|
||||
target=target_spec,
|
||||
parameters=action_params,
|
||||
)
|
||||
|
||||
def _build_click_target_spec(
    self,
    event: Event,
    source_ui_elements: List[Any],
) -> TargetSpec:
    """
    Build the TargetSpec of a click, anchored to a detected UIElement if possible.

    The element is chosen by ``self._find_clicked_element``. When it returns
    ``None`` the historical fallback spec (``by_role="unknown_element"``,
    ``selection_policy="first"``) is returned. Otherwise the spec carries the
    element's role and label, uses ``selection_policy="by_similarity"`` and
    stores traceability data in ``context_hints``:

    - ``anchor_element_id``: ``str(element_id)`` when the element has one;
    - ``anchor_bbox``: ``{"x", "y", "width", "height"}`` as ints;
    - ``anchor_center``: ``[x, y]`` as ints.

    Args:
        event: The click event to anchor.
        source_ui_elements: UI elements detected on the source screen.

    Returns:
        A TargetSpec; this method never returns ``None``.
    """
    clicked = self._find_clicked_element(event, source_ui_elements)

    if clicked is None:
        return TargetSpec(
            by_role="unknown_element",
            selection_policy="first",
            fallback_strategy="visual_similarity",
        )

    # Defensive extraction: the element type is not enforced here.
    role = getattr(clicked, "role", None) or "unknown_element"
    label = getattr(clicked, "label", None) or None
    element_id = getattr(clicked, "element_id", None)

    # `context_hints` is the only free-form dict available on TargetSpec.
    context_hints: Dict[str, Any] = {}
    if element_id:
        context_hints["anchor_element_id"] = str(element_id)

    bbox = getattr(clicked, "bbox", None)
    if bbox is not None:

        def _bbox_field(name: str, index: int) -> int:
            # Attribute first, index second. The previous form
            # `getattr(bbox, name, bbox[index])` evaluated `bbox[index]`
            # eagerly, so an attribute-style bbox that is not subscriptable
            # raised TypeError and the hint was silently lost.
            try:
                value = getattr(bbox, name)
            except AttributeError:
                value = bbox[index]
            return int(value)

        try:
            context_hints["anchor_bbox"] = {
                "x": _bbox_field("x", 0),
                "y": _bbox_field("y", 1),
                "width": _bbox_field("width", 2),
                "height": _bbox_field("height", 3),
            }
        except (AttributeError, IndexError, KeyError, TypeError):
            # Best-effort debug hint: an unexpected bbox shape (including a
            # mapping without integer keys) must not break graph building.
            pass

    # Center of the element, kept as a fallback anchor for the matcher.
    center = getattr(clicked, "center", None)
    if center is not None:
        try:
            context_hints["anchor_center"] = [int(center[0]), int(center[1])]
        except (IndexError, TypeError):
            pass

    return TargetSpec(
        by_role=role,
        by_text=label,
        selection_policy="by_similarity",
        fallback_strategy="visual_similarity",
        context_hints=context_hints,
    )
|
||||
|
||||
def _build_compound_action(
|
||||
self,
|
||||
events: List[Event],
|
||||
screen_resolution: Tuple[int, int] = (1920, 1080),
|
||||
source_ui_elements: Optional[List[Any]] = None,
|
||||
) -> Action:
|
||||
"""
|
||||
Construire une Action compound (multi-étapes) depuis plusieurs événements.
|
||||
@@ -1360,21 +1768,33 @@ class GraphBuilder:
|
||||
# La cible du compound = cible de la dernière action (le clic final, etc.)
|
||||
last_event = events[-1]
|
||||
if last_event.type == "mouse_click":
|
||||
target_role = "unknown_element"
|
||||
# On tente d'ancrer le clic final aux UIElements détectés,
|
||||
# comme dans _build_single_action.
|
||||
target_spec = self._build_click_target_spec(
|
||||
last_event, source_ui_elements or []
|
||||
)
|
||||
elif last_event.type == "text_input":
|
||||
target_role = "text_field"
|
||||
target_spec = TargetSpec(
|
||||
by_role="text_field",
|
||||
selection_policy="first",
|
||||
fallback_strategy="visual_similarity",
|
||||
)
|
||||
elif last_event.type == "key_press":
|
||||
target_role = "keyboard_input"
|
||||
target_spec = TargetSpec(
|
||||
by_role="keyboard_input",
|
||||
selection_policy="first",
|
||||
fallback_strategy="visual_similarity",
|
||||
)
|
||||
else:
|
||||
target_role = "unknown"
|
||||
target_spec = TargetSpec(
|
||||
by_role="unknown",
|
||||
selection_policy="first",
|
||||
fallback_strategy="visual_similarity",
|
||||
)
|
||||
|
||||
return Action(
|
||||
type="compound",
|
||||
target=TargetSpec(
|
||||
by_role=target_role,
|
||||
selection_policy="first",
|
||||
fallback_strategy="visual_similarity",
|
||||
),
|
||||
target=target_spec,
|
||||
parameters={
|
||||
"steps": steps,
|
||||
"step_count": len(steps),
|
||||
|
||||
20
core/grounding/__init__.py
Normal file
20
core/grounding/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# core/grounding — Module de localisation d'éléments UI
|
||||
#
|
||||
# Centralise les méthodes de grounding visuel : template matching,
|
||||
# OCR, VLM, etc. Chaque méthode produit un GroundingResult uniforme.
|
||||
#
|
||||
# Le serveur de grounding (server.py) tourne dans un process séparé
|
||||
# sur le port 8200. Le client HTTP (UITarsGrounder) l'appelle via HTTP.
|
||||
# Le pipeline (GroundingPipeline) orchestre template → OCR → UI-TARS → static.
|
||||
|
||||
from core.grounding.template_matcher import TemplateMatcher, MatchResult
|
||||
from core.grounding.target import GroundingTarget, GroundingResult
|
||||
from core.grounding.ui_tars_grounder import UITarsGrounder
|
||||
from core.grounding.pipeline import GroundingPipeline
|
||||
|
||||
__all__ = [
|
||||
'TemplateMatcher', 'MatchResult',
|
||||
'GroundingTarget', 'GroundingResult',
|
||||
'UITarsGrounder',
|
||||
'GroundingPipeline',
|
||||
]
|
||||
256
core/grounding/dialog_handler.py
Normal file
256
core/grounding/dialog_handler.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
core/grounding/dialog_handler.py — Gestion intelligente des dialogues
|
||||
|
||||
Quand un dialogue inattendu apparaît (pHash change après une action) :
|
||||
1. Lire le texte visible à l'écran (EasyOCR full-screen, ~500ms)
|
||||
2. Si titre connu (Enregistrer sous, Confirmer, etc.) → action connue
|
||||
3. Demander à InfiGUI de cliquer sur le bon bouton (~3s)
|
||||
4. Vérifier que le dialogue a disparu (pHash)
|
||||
|
||||
Pas de patterns prédéfinis pour les boutons. InfiGUI comprend
|
||||
visuellement le dialogue et clique au bon endroit.
|
||||
|
||||
Utilisation :
|
||||
from core.grounding.dialog_handler import DialogHandler
|
||||
|
||||
handler = DialogHandler()
|
||||
result = handler.handle_if_dialog(screenshot_pil)
|
||||
if result['handled']:
|
||||
print(f"Dialogue '{result['title']}' géré → {result['action']}")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
# Known dialog keywords -> which button to ask the grounder to click.
#
# IMPORTANT: dict order is matching priority (DialogHandler.handle_if_dialog
# stops at the first key found in the lower-cased screen text).
# The OCR is full-screen and often captures the text of the parent dialog AND
# of the modal popup shown on top of it (e.g. "Enregistrer sous" stays visible
# behind "Confirmer l'enregistrement"). Modal popups MUST match before the
# main windows, otherwise the click goes to the parent's button, which does
# not have focus.
# NOTE(review): keys are matched as plain substrings of the whole screen text,
# so short generic keys ("error", "replace", "confirmer") can presumably match
# unrelated on-screen text — confirm this is acceptable.
KNOWN_DIALOGS = {
    # ── Modal confirmation popups (HIGH priority) ───────────────────────
    "voulez-vous le remplacer": {"target": "Oui", "description": "Clique sur Oui pour confirmer le remplacement du fichier"},
    "do you want to replace": {"target": "Yes", "description": "Click Yes to confirm file replacement"},
    "existe déjà": {"target": "Oui", "description": "Clique sur Oui, le fichier existe déjà et doit être remplacé"},
    "already exists": {"target": "Yes", "description": "Click Yes, the file already exists"},
    "remplacer": {"target": "Oui", "description": "Clique sur le bouton Oui pour confirmer le remplacement du fichier"},
    "replace": {"target": "Yes", "description": "Click Yes to confirm file replacement"},
    "écraser": {"target": "Oui", "description": "Clique sur Oui pour écraser le fichier"},
    "overwrite": {"target": "Yes", "description": "Click Yes to overwrite"},
    "confirmer l'enregistrement": {"target": "Oui", "description": "Clique sur Oui dans le popup de confirmation d'enregistrement"},
    "confirmer": {"target": "Oui", "description": "Clique sur le bouton Oui dans le dialogue de confirmation"},
    # ── Warnings/errors (high priority, single OK button) ───────────────
    "erreur": {"target": "OK", "description": "Clique sur OK pour fermer le message d'erreur"},
    "error": {"target": "OK", "description": "Click OK to close the error message"},
    "avertissement": {"target": "OK", "description": "Clique sur OK pour fermer l'avertissement"},
    "warning": {"target": "OK", "description": "Click OK to close the warning"},
    # ── Main save dialogs (LOW priority — parent windows) ───────────────
    "voulez-vous enregistrer": {"target": "Enregistrer", "description": "Clique sur Enregistrer pour sauvegarder les modifications"},
    "do you want to save": {"target": "Save", "description": "Click Save to save changes"},
    "enregistrer sous": {"target": "Enregistrer", "description": "Clique sur le bouton Enregistrer dans le dialogue Enregistrer sous"},
    "save as": {"target": "Save", "description": "Click the Save button in the Save As dialog"},
}
|
||||
|
||||
|
||||
class DialogHandler:
    """Handle known dialogs: read the screen text, then click the expected button.

    The recognised text is matched against the keys of ``KNOWN_DIALOGS``. The
    button is located through ``UITarsGrounder`` first and through a plain OCR
    text search as a fallback; the click itself is sent with ``pyautogui``.
    """

    def __init__(self):
        # EasyOCR reader, built lazily by _get_easyocr() and reused afterwards.
        self._easyocr_reader = None

    def handle_if_dialog(
        self,
        screenshot_pil,
        previous_title: str = "",
    ) -> Dict[str, Any]:
        """Check whether the screenshot shows a known dialog and click through it.

        Args:
            screenshot_pil: Current screenshot; it is converted with
                ``numpy.array`` before OCR (named for a PIL image).
            previous_title: Window title before the action. Currently unused.

        Returns:
            A dict that always has ``handled`` (bool) and ``title``.
            When a click was sent it also has ``dialog_type``, ``action``,
            ``position`` (x, y) and ``time_ms``. Otherwise it has ``reason``,
            plus ``dialog_type`` and ``time_ms`` when a known dialog was
            matched but its button could not be found.
        """
        # perf_counter is monotonic: the elapsed time cannot be skewed by a
        # wall-clock adjustment happening during the (multi-second) handling.
        started = time.perf_counter()

        # 1. Read the on-screen text.
        title = self._read_title(screenshot_pil)
        if not title or len(title) < 3:
            return {'handled': False, 'title': '', 'reason': 'Titre illisible'}

        print(f"🔍 [Dialog] Titre lu: '{title}'")

        # 2. First known keyword found in the text wins (dict order = priority).
        lowered_title = title.lower()  # hoisted: invariant across all keys
        matched_dialog = next(
            (
                (key, action_info)
                for key, action_info in KNOWN_DIALOGS.items()
                if key in lowered_title
            ),
            None,
        )

        if matched_dialog is None:
            # Not a known dialog: the workflow carries on normally.
            return {'handled': False, 'title': title, 'reason': 'Pas un dialogue connu'}

        dialog_key, action_info = matched_dialog
        target = action_info['target']
        description = action_info['description']

        print(f"🧠 [Dialog] Dialogue détecté: '{dialog_key}' → clic '{target}'")

        # 3. Ask the grounder to locate and click the button.
        click_result = self._click_via_infigui(target, description, screenshot_pil)
        dt = (time.perf_counter() - started) * 1000

        if click_result:
            print(f"✅ [Dialog] Clic '{target}' à ({click_result['x']}, {click_result['y']}) ({dt:.0f}ms)")
            return {
                'handled': True,
                'title': title,
                'dialog_type': dialog_key,
                'action': f"click '{target}'",
                'position': (click_result['x'], click_result['y']),
                'time_ms': dt,
            }

        # The grounder did not find the button: try a direct OCR text search.
        print(f"⚠️ [Dialog] InfiGUI n'a pas trouvé '{target}', essai OCR direct")
        ocr_result = self._click_via_ocr(target, screenshot_pil)
        dt = (time.perf_counter() - started) * 1000

        if ocr_result:
            print(f"✅ [Dialog] OCR clic '{target}' à ({ocr_result[0]}, {ocr_result[1]}) ({dt:.0f}ms)")
            return {
                'handled': True,
                'title': title,
                'dialog_type': dialog_key,
                'action': f"click '{target}' (OCR)",
                'position': ocr_result,
                'time_ms': dt,
            }

        print(f"❌ [Dialog] Impossible de cliquer '{target}' ({dt:.0f}ms)")
        return {
            'handled': False,
            'title': title,
            'dialog_type': dialog_key,
            'reason': f"Bouton '{target}' introuvable",
            'time_ms': dt,
        }

    # ------------------------------------------------------------------
    # Screen text reading
    # ------------------------------------------------------------------

    def _read_title(self, screenshot_pil) -> str:
        """Return ALL text recognised on the screenshot, joined with spaces.

        Despite its name this does not crop a title bar: OCR runs on the whole
        image and the caller searches the dialog keywords in the full text.
        Returns ``""`` when no reader is available or when OCR raises.
        """
        try:
            import numpy as np

            reader = self._get_easyocr()
            if reader is None:
                return ""

            results = reader.readtext(np.array(screenshot_pil))
            # Each result is indexed as (bbox points, text, confidence).
            return ' '.join(r[1] for r in results if r[1].strip())

        except Exception as e:
            print(f"⚠️ [Dialog] Erreur lecture écran: {e}")
            return ""

    # ------------------------------------------------------------------
    # Click through the grounder
    # ------------------------------------------------------------------

    def _click_via_infigui(
        self, target: str, description: str, screenshot_pil
    ) -> Optional[Dict]:
        """Locate the button with ``UITarsGrounder`` and click it.

        Returns:
            ``{'x': ..., 'y': ...}`` for the clicked position, or ``None`` when
            the grounder returns no usable result or anything raises.
        """
        try:
            from core.grounding.ui_tars_grounder import UITarsGrounder

            grounder = UITarsGrounder.get_instance()
            result = grounder.ground(
                target_text=target,
                target_description=description,
                screen_pil=screenshot_pil,
            )

            if result and result.x is not None:
                import pyautogui
                pyautogui.click(result.x, result.y)
                return {'x': result.x, 'y': result.y}

            return None

        except Exception as e:
            print(f"⚠️ [Dialog/InfiGUI] Erreur: {e}")
            return None

    # ------------------------------------------------------------------
    # Click through OCR (fallback)
    # ------------------------------------------------------------------

    def _click_via_ocr(self, target: str, screenshot_pil) -> Optional[tuple]:
        """Find the button label by OCR and click the lowest match on screen.

        A recognised string matches when it contains the target or is contained
        in it (case-insensitive, surrounding whitespace ignored).

        Returns:
            ``(x, y)`` of the click, or ``None`` when nothing matched, no
            reader is available, or anything raises.
        """
        try:
            import numpy as np

            reader = self._get_easyocr()
            if reader is None:
                return None

            results = reader.readtext(np.array(screenshot_pil))

            target_lower = target.lower()
            matches = []
            for bbox_pts, text, _conf in results:
                candidate = text.strip().lower()
                # A blank string is "contained" in every target, so it used to
                # be accepted as a match and could win the lowest-on-screen
                # selection below. Skip it.
                if not candidate:
                    continue
                if target_lower in candidate or candidate in target_lower:
                    # Center of the 4-point box returned by the reader.
                    x = int(sum(p[0] for p in bbox_pts) / 4)
                    y = int(sum(p[1] for p in bbox_pts) / 4)
                    matches.append((x, y, text))

            if matches:
                # Pick the lowest match: buttons sit at the bottom of a dialog.
                best = max(matches, key=lambda m: m[1])
                import pyautogui
                pyautogui.click(best[0], best[1])
                return (best[0], best[1])

            return None

        except Exception as e:
            print(f"⚠️ [Dialog/OCR] Erreur: {e}")
            return None

    # ------------------------------------------------------------------
    # EasyOCR reader (one per handler instance)
    # ------------------------------------------------------------------

    def _get_easyocr(self):
        """Return the cached EasyOCR reader, creating it on first use.

        Returns ``None`` when the ``easyocr`` package cannot be imported.
        """
        if self._easyocr_reader is not None:
            return self._easyocr_reader

        try:
            import easyocr
            self._easyocr_reader = easyocr.Reader(
                ['fr', 'en'], gpu=True, verbose=False
            )
            return self._easyocr_reader
        except ImportError:
            return None
|
||||
239
core/grounding/element_signature.py
Normal file
239
core/grounding/element_signature.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
core/grounding/element_signature.py — Signatures d'éléments UI apprises
|
||||
|
||||
Chaque élément cliqué avec succès enrichit sa signature :
|
||||
- texte OCR, type, position relative, voisins contextuels
|
||||
- nombre de succès/échecs, confiance moyenne
|
||||
- variantes observées (résolutions, positions)
|
||||
|
||||
Les signatures sont stockées en SQLite pour un lookup rapide.
|
||||
Pattern identique à TargetMemoryStore (validé en prod).
|
||||
|
||||
Utilisation :
|
||||
from core.grounding.element_signature import SignatureStore
|
||||
|
||||
store = SignatureStore()
|
||||
|
||||
# Après un clic réussi
|
||||
store.record_success("btn_valider", "notepad_1920x1080", element, confidence=0.92)
|
||||
|
||||
# Au replay
|
||||
sig = store.lookup("btn_valider", "notepad_1920x1080")
|
||||
if sig:
|
||||
print(f"Signature connue : {sig['text']} position={sig['relative_position']}")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from core.grounding.fast_types import DetectedUIElement
|
||||
|
||||
# Default database path: data/learning/element_signatures.db, resolved three
# directory levels above this file.
_DEFAULT_DB = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
    "data", "learning", "element_signatures.db",
)


class SignatureStore:
    """SQLite store of learned UI element signatures.

    One row per ``(target_key, screen_context)`` pair. Writes are serialised
    with an in-process lock; every operation opens its own short-lived
    connection and closes it before returning.
    """

    _SELECT_EXACT = (
        "SELECT * FROM signatures WHERE target_key = ? AND screen_context = ?"
    )

    def __init__(self, db_path: str = _DEFAULT_DB):
        """Open (and create if needed) the signature database at ``db_path``."""
        self.db_path = db_path
        self._lock = threading.Lock()
        self._ensure_db()

    def _open(self):
        """Return a context manager yielding a connection that is closed on exit.

        ``with sqlite3.connect(...)`` only commits or rolls back; it does not
        close the connection, so every call used to leave one open.
        """
        from contextlib import closing

        return closing(sqlite3.connect(self.db_path))

    def _ensure_db(self):
        """Create the parent directory, the table and the index if missing."""
        parent = os.path.dirname(self.db_path)
        # A bare file name has no directory part, and os.makedirs("") raises.
        if parent:
            os.makedirs(parent, exist_ok=True)
        with self._open() as conn, conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS signatures (
                    target_key TEXT NOT NULL,
                    screen_context TEXT NOT NULL,
                    text TEXT DEFAULT '',
                    element_type TEXT DEFAULT 'element',
                    relative_position TEXT DEFAULT '',
                    neighbors TEXT DEFAULT '[]',
                    success_count INTEGER DEFAULT 0,
                    fail_count INTEGER DEFAULT 0,
                    avg_confidence REAL DEFAULT 0.0,
                    last_seen TEXT DEFAULT '',
                    variants TEXT DEFAULT '[]',
                    PRIMARY KEY (target_key, screen_context)
                )
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_target_key
                ON signatures(target_key)
            """)

    @staticmethod
    def _row_to_dict(row) -> Dict[str, Any]:
        """Convert a ``signatures`` row to a plain dict, decoding the JSON columns."""
        return {
            "target_key": row["target_key"],
            "screen_context": row["screen_context"],
            "text": row["text"],
            "element_type": row["element_type"],
            "relative_position": row["relative_position"],
            "neighbors": json.loads(row["neighbors"]),
            "success_count": row["success_count"],
            "fail_count": row["fail_count"],
            "avg_confidence": row["avg_confidence"],
            "last_seen": row["last_seen"],
            "variants": json.loads(row["variants"]),
        }

    # ------------------------------------------------------------------
    # Lookup
    # ------------------------------------------------------------------

    def lookup(self, target_key: str, screen_context: str = "") -> Optional[Dict[str, Any]]:
        """Find a known signature.

        The exact ``(target_key, screen_context)`` row is preferred. When it
        does not exist and ``screen_context`` is non-empty, the row with the
        same ``target_key`` and the highest ``success_count`` is returned
        instead, whatever its context.

        Args:
            target_key: Key of the target (see ``make_target_key``).
            screen_context: Screen context (see ``make_screen_context``).

        Returns:
            Dict with the signature fields, or ``None`` if nothing matches.
        """
        with self._open() as conn:
            conn.row_factory = sqlite3.Row
            row = conn.execute(
                self._SELECT_EXACT, (target_key, screen_context)
            ).fetchone()

            if row is None and screen_context:
                row = conn.execute(
                    "SELECT * FROM signatures WHERE target_key = ? ORDER BY success_count DESC LIMIT 1",
                    (target_key,),
                ).fetchone()

        return None if row is None else self._row_to_dict(row)

    # ------------------------------------------------------------------
    # Recording
    # ------------------------------------------------------------------

    def record_success(
        self,
        target_key: str,
        screen_context: str,
        element: DetectedUIElement,
        confidence: float,
    ):
        """Record a successful click: create the signature or enrich it.

        Args:
            target_key: Key of the target.
            screen_context: Screen context the success was observed in.
            element: Matched element; ``relative_position``, ``center``,
                ``neighbors``, ``ocr_text`` and ``element_type`` are read.
            confidence: Confidence of the match, folded into the running mean.
        """
        now = time.strftime("%Y-%m-%dT%H:%M:%S")
        variant = {
            "position": element.relative_position,
            "center": list(element.center),
            "confidence": confidence,
            "timestamp": now,
        }

        with self._lock:
            with self._open() as conn, conn:
                conn.row_factory = sqlite3.Row
                # Exact match only. The cross-context fallback of lookup()
                # must not be used here: it made a success in a new context
                # look "existing", and the UPDATE below then matched no row,
                # so the observation was silently dropped.
                row = conn.execute(
                    self._SELECT_EXACT, (target_key, screen_context)
                ).fetchone()

                if row is not None:
                    n = row["success_count"]
                    new_avg = (row["avg_confidence"] * n + confidence) / (n + 1)

                    # Keep at most the 20 most recent variants.
                    variants = (json.loads(row["variants"]) + [variant])[-20:]

                    # Order-preserving union, so the 10 neighbours kept are
                    # deterministic and previously stored ones come first.
                    merged = json.loads(row["neighbors"]) + list(element.neighbors)
                    neighbors = list(dict.fromkeys(merged))[:10]

                    conn.execute("""
                        UPDATE signatures SET
                            success_count = success_count + 1,
                            avg_confidence = ?,
                            last_seen = ?,
                            neighbors = ?,
                            variants = ?,
                            relative_position = ?
                        WHERE target_key = ? AND screen_context = ?
                    """, (
                        new_avg, now,
                        json.dumps(neighbors),
                        json.dumps(variants),
                        element.relative_position,
                        target_key, screen_context,
                    ))
                else:
                    conn.execute("""
                        INSERT INTO signatures
                            (target_key, screen_context, text, element_type, relative_position,
                             neighbors, success_count, fail_count, avg_confidence, last_seen, variants)
                        VALUES (?, ?, ?, ?, ?, ?, 1, 0, ?, ?, ?)
                    """, (
                        target_key, screen_context,
                        element.ocr_text,
                        element.element_type,
                        element.relative_position,
                        json.dumps(element.neighbors[:10]),
                        confidence, now,
                        json.dumps([variant]),
                    ))

        print(f"📝 [Signature] '{target_key}' {'enrichie' if row is not None else 'créée'} "
              f"(conf={confidence:.2f}, ctx='{screen_context[:30]}')")

    def record_failure(self, target_key: str, screen_context: str):
        """Increment ``fail_count`` of the exact signature row, if it exists."""
        with self._lock:
            with self._open() as conn, conn:
                conn.execute("""
                    UPDATE signatures SET fail_count = fail_count + 1, last_seen = ?
                    WHERE target_key = ? AND screen_context = ?
                """, (time.strftime("%Y-%m-%dT%H:%M:%S"), target_key, screen_context))

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    @staticmethod
    def make_target_key(text: str, description: str = "") -> str:
        """Return a 16-hex-char key derived from the normalised text and description."""
        raw = f"{text.lower().strip()}|{description.lower().strip()}"
        return hashlib.md5(raw.encode()).hexdigest()[:16]

    @staticmethod
    def make_screen_context(window_title: str, resolution: tuple = (0, 0)) -> str:
        """Return a 12-hex-char context derived from the window title and resolution."""
        raw = f"{window_title.lower().strip()}|{resolution[0]}x{resolution[1]}"
        return hashlib.md5(raw.encode()).hexdigest()[:12]

    def get_stats(self) -> Dict[str, Any]:
        """Return the row count, the count of reliable rows, and the database path.

        A row counts as reliable with at least 3 successes and no failure.
        """
        with self._open() as conn:
            total = conn.execute("SELECT COUNT(*) FROM signatures").fetchone()[0]
            reliable = conn.execute(
                "SELECT COUNT(*) FROM signatures WHERE success_count >= 3 AND fail_count = 0"
            ).fetchone()[0]
        return {
            "total_signatures": total,
            "reliable": reliable,
            "db_path": self.db_path,
        }
|
||||
326
core/grounding/fast_detector.py
Normal file
326
core/grounding/fast_detector.py
Normal file
@@ -0,0 +1,326 @@
|
||||
"""
|
||||
core/grounding/fast_detector.py — Layer FAST : détection rapide des éléments UI
|
||||
|
||||
Capture l'écran, détecte tous les éléments UI via RF-DETR (~120ms),
|
||||
enrichit chaque élément avec le texte OCR et le contexte spatial.
|
||||
|
||||
Produit un ScreenSnapshot utilisable par le SmartMatcher.
|
||||
|
||||
Utilisation :
|
||||
from core.grounding.fast_detector import FastDetector
|
||||
|
||||
detector = FastDetector()
|
||||
snapshot = detector.detect()
|
||||
print(f"{len(snapshot.elements)} éléments en {snapshot.total_time_ms:.0f}ms")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from core.grounding.fast_types import DetectedUIElement, ScreenSnapshot
|
||||
|
||||
|
||||
class FastDetector:
|
||||
"""Détection rapide de tous les éléments UI visibles sur l'écran.
|
||||
|
||||
Combine RF-DETR (détection bbox) + EasyOCR (OCR, repli docTR) pour produire
|
||||
un ScreenSnapshot enrichi.
|
||||
|
||||
Le modèle RF-DETR est un singleton chargé au premier appel (~1s),
|
||||
puis les appels suivants sont rapides (~120ms).
|
||||
"""
|
||||
|
||||
def __init__(self, detection_threshold: float = 0.30):
|
||||
self.detection_threshold = detection_threshold
|
||||
self._last_snapshot: Optional[ScreenSnapshot] = None
|
||||
self._last_phash: str = ""
|
||||
|
||||
def detect(
|
||||
self,
|
||||
screenshot_pil: Optional[Any] = None,
|
||||
phash: str = "",
|
||||
window_title: str = "",
|
||||
) -> ScreenSnapshot:
|
||||
"""Détecte et enrichit tous les éléments UI de l'écran.
|
||||
|
||||
Args:
|
||||
screenshot_pil: Image PIL. Si None, capture via mss.
|
||||
phash: Hash perceptuel pour le cache. Si identique au dernier, réutilise le cache.
|
||||
window_title: Titre de la fenêtre active.
|
||||
|
||||
Returns:
|
||||
ScreenSnapshot avec tous les éléments enrichis.
|
||||
"""
|
||||
t0 = time.time()
|
||||
|
||||
# Cache : même écran → même résultat
|
||||
if phash and phash == self._last_phash and self._last_snapshot is not None:
|
||||
print(f"⚡ [FAST] Cache hit (pHash identique)")
|
||||
return self._last_snapshot
|
||||
|
||||
# Capture si pas fourni
|
||||
if screenshot_pil is None:
|
||||
screenshot_pil = self._capture_screen()
|
||||
if screenshot_pil is None:
|
||||
return ScreenSnapshot(elements=[], ocr_words=[], resolution=(0, 0))
|
||||
|
||||
w, h = screenshot_pil.size
|
||||
|
||||
# --- Détection RF-DETR (~120ms) ---
|
||||
t_det = time.time()
|
||||
raw_elements = self._detect_rfdetr(screenshot_pil)
|
||||
detection_ms = (time.time() - t_det) * 1000
|
||||
|
||||
# --- OCR sur les crops des éléments détectés (pas full screen) ---
|
||||
t_ocr = time.time()
|
||||
ocr_words = self._ocr_extract(screenshot_pil)
|
||||
ocr_ms = (time.time() - t_ocr) * 1000
|
||||
|
||||
# --- Enrichissement : attribuer texte + voisins + position ---
|
||||
enriched = self._enrich_elements(raw_elements, ocr_words, w, h)
|
||||
|
||||
total_ms = (time.time() - t0) * 1000
|
||||
|
||||
snapshot = ScreenSnapshot(
|
||||
elements=enriched,
|
||||
ocr_words=ocr_words,
|
||||
resolution=(w, h),
|
||||
window_title=window_title,
|
||||
phash=phash,
|
||||
detection_time_ms=detection_ms,
|
||||
ocr_time_ms=ocr_ms,
|
||||
total_time_ms=total_ms,
|
||||
)
|
||||
|
||||
# Mettre en cache
|
||||
if phash:
|
||||
self._last_phash = phash
|
||||
self._last_snapshot = snapshot
|
||||
|
||||
print(f"⚡ [FAST] {len(enriched)} éléments détectés en {total_ms:.0f}ms "
|
||||
f"(det={detection_ms:.0f}ms, ocr={ocr_ms:.0f}ms)")
|
||||
|
||||
return snapshot
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Détection RF-DETR
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _detect_rfdetr(self, image) -> List[DetectedUIElement]:
|
||||
"""Détecte les éléments via RF-DETR (réutilise le singleton existant)."""
|
||||
try:
|
||||
import sys
|
||||
sys.path.insert(0, 'visual_workflow_builder/backend')
|
||||
from services.ui_detection_service import detect_ui_elements
|
||||
|
||||
result = detect_ui_elements(image, threshold=self.detection_threshold)
|
||||
|
||||
elements = []
|
||||
for e in result.elements:
|
||||
x1 = e.bbox["x1"]
|
||||
y1 = e.bbox["y1"]
|
||||
x2 = e.bbox["x2"]
|
||||
y2 = e.bbox["y2"]
|
||||
elements.append(DetectedUIElement(
|
||||
id=e.id,
|
||||
bbox=(x1, y1, x2, y2),
|
||||
center=(e.center["x"], e.center["y"]),
|
||||
confidence=e.confidence,
|
||||
))
|
||||
|
||||
return elements
|
||||
|
||||
except Exception as ex:
|
||||
print(f"⚠️ [FAST/detect] RF-DETR erreur: {ex}")
|
||||
return []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# OCR
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_easyocr_reader = None # Singleton EasyOCR (chargé une fois)
|
||||
|
||||
def _ocr_extract(self, image) -> List[Dict[str, Any]]:
|
||||
"""Extrait les mots visibles via EasyOCR (GPU, ~500ms).
|
||||
|
||||
Fallback sur docTR si EasyOCR non disponible.
|
||||
"""
|
||||
try:
|
||||
import numpy as np
|
||||
import easyocr
|
||||
|
||||
# Singleton : charger le reader une seule fois
|
||||
if FastDetector._easyocr_reader is None:
|
||||
print(f"🔍 [FAST/ocr] Chargement EasyOCR (GPU)...")
|
||||
FastDetector._easyocr_reader = easyocr.Reader(
|
||||
['fr', 'en'], gpu=True, verbose=False
|
||||
)
|
||||
|
||||
results = FastDetector._easyocr_reader.readtext(np.array(image))
|
||||
|
||||
words = []
|
||||
for (bbox_pts, text, conf) in results:
|
||||
if not text or len(text.strip()) < 1:
|
||||
continue
|
||||
# bbox_pts = [[x1,y1],[x2,y1],[x2,y2],[x1,y2]]
|
||||
x1 = int(min(p[0] for p in bbox_pts))
|
||||
y1 = int(min(p[1] for p in bbox_pts))
|
||||
x2 = int(max(p[0] for p in bbox_pts))
|
||||
y2 = int(max(p[1] for p in bbox_pts))
|
||||
words.append({
|
||||
'text': text.strip(),
|
||||
'bbox': [x1, y1, x2, y2],
|
||||
'confidence': float(conf),
|
||||
})
|
||||
|
||||
return words
|
||||
|
||||
except ImportError:
|
||||
# Fallback docTR
|
||||
try:
|
||||
import sys
|
||||
sys.path.insert(0, 'visual_workflow_builder/backend')
|
||||
from services.ocr_service import ocr_extract_words
|
||||
return ocr_extract_words(image) or []
|
||||
except Exception:
|
||||
return []
|
||||
except Exception as ex:
|
||||
print(f"⚠️ [FAST/ocr] EasyOCR erreur: {ex}")
|
||||
return []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Enrichissement
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _enrich_elements(
|
||||
self,
|
||||
elements: List[DetectedUIElement],
|
||||
ocr_words: List[Dict[str, Any]],
|
||||
screen_w: int,
|
||||
screen_h: int,
|
||||
) -> List[DetectedUIElement]:
|
||||
"""Enrichit chaque élément avec texte OCR, voisins et position relative."""
|
||||
|
||||
for elem in elements:
|
||||
# 1. Attribuer le texte OCR par intersection bbox
|
||||
elem.ocr_text = self._assign_ocr_text(elem, ocr_words)
|
||||
|
||||
# 2. Position relative dans l'écran (grille 3x3)
|
||||
elem.relative_position = self._compute_relative_position(
|
||||
elem.center, screen_w, screen_h
|
||||
)
|
||||
|
||||
# 3. Classifier le type d'élément (heuristique taille + ratio)
|
||||
elem.element_type = self._classify_element_type(elem)
|
||||
|
||||
# 4. Calculer les voisins (texte des éléments proches)
|
||||
for elem in elements:
|
||||
elem.neighbors = self._find_neighbors(elem, elements)
|
||||
|
||||
return elements
|
||||
|
||||
def _assign_ocr_text(
|
||||
self,
|
||||
elem: DetectedUIElement,
|
||||
ocr_words: List[Dict[str, Any]],
|
||||
) -> str:
|
||||
"""Attribue le texte OCR à un élément par intersection géométrique."""
|
||||
x1, y1, x2, y2 = elem.bbox
|
||||
# Élargir la bbox de 20% pour capturer le texte autour
|
||||
margin_x = int((x2 - x1) * 0.2)
|
||||
margin_y = int((y2 - y1) * 0.2)
|
||||
ex1, ey1 = x1 - margin_x, y1 - margin_y
|
||||
ex2, ey2 = x2 + margin_x, y2 + margin_y
|
||||
|
||||
texts = []
|
||||
for word in ocr_words:
|
||||
wb = word.get('bbox', [0, 0, 0, 0])
|
||||
if len(wb) < 4:
|
||||
continue
|
||||
wx1, wy1, wx2, wy2 = wb[0], wb[1], wb[2], wb[3]
|
||||
# Intersection ?
|
||||
if wx1 < ex2 and wx2 > ex1 and wy1 < ey2 and wy2 > ey1:
|
||||
text = word.get('text', '').strip()
|
||||
if text and len(text) > 1:
|
||||
texts.append(text)
|
||||
|
||||
return ' '.join(texts)
|
||||
|
||||
@staticmethod
|
||||
def _compute_relative_position(
|
||||
center: Tuple[int, int],
|
||||
screen_w: int,
|
||||
screen_h: int,
|
||||
) -> str:
|
||||
"""Calcule la position relative dans une grille 3x3."""
|
||||
cx, cy = center
|
||||
col = "left" if cx < screen_w / 3 else ("right" if cx > 2 * screen_w / 3 else "center")
|
||||
row = "top" if cy < screen_h / 3 else ("bottom" if cy > 2 * screen_h / 3 else "middle")
|
||||
return f"{row}_{col}"
|
||||
|
||||
@staticmethod
|
||||
def _classify_element_type(elem: DetectedUIElement) -> str:
|
||||
"""Classifie le type d'élément par heuristique taille/ratio."""
|
||||
w, h = elem.width, elem.height
|
||||
if w == 0 or h == 0:
|
||||
return "element"
|
||||
ratio = w / h
|
||||
area = w * h
|
||||
|
||||
# Petit carré → icône
|
||||
if area < 5000 and 0.5 < ratio < 2.0:
|
||||
return "icon"
|
||||
# Large et fin → bouton ou champ
|
||||
if ratio > 3.0 and h < 60:
|
||||
return "input"
|
||||
if ratio > 2.0 and h < 50:
|
||||
return "button"
|
||||
# Grand bloc → zone de contenu
|
||||
if area > 50000:
|
||||
return "container"
|
||||
|
||||
return "element"
|
||||
|
||||
@staticmethod
|
||||
def _find_neighbors(
|
||||
elem: DetectedUIElement,
|
||||
all_elements: List[DetectedUIElement],
|
||||
max_neighbors: int = 5,
|
||||
) -> List[str]:
|
||||
"""Trouve les textes OCR des éléments proches (rayon 1.5x diagonale)."""
|
||||
diag = math.sqrt(elem.width**2 + elem.height**2)
|
||||
radius = max(diag * 1.5, 100) # minimum 100px
|
||||
|
||||
neighbors = []
|
||||
for other in all_elements:
|
||||
if other.id == elem.id or not other.ocr_text:
|
||||
continue
|
||||
dx = other.center[0] - elem.center[0]
|
||||
dy = other.center[1] - elem.center[1]
|
||||
dist = math.sqrt(dx**2 + dy**2)
|
||||
if dist < radius:
|
||||
neighbors.append(other.ocr_text)
|
||||
|
||||
return neighbors[:max_neighbors]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Capture écran
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
def _capture_screen():
    """Grab the screen with mss and return it as an RGB PIL image.

    Entry 0 of ``sct.monitors`` is captured (per the mss documentation this
    is the virtual screen spanning all monitors).

    Returns:
        The PIL image, or ``None`` when the capture fails for any reason
        (the error is printed).
    """
    try:
        import mss
        from PIL import Image

        with mss.mss() as grabber:
            shot = grabber.grab(grabber.monitors[0])
            # mss hands back BGRA bytes; decode as BGRX to get plain RGB.
            return Image.frombytes('RGB', shot.size, shot.bgra, 'raw', 'BGRX')
    except Exception as ex:
        print(f"⚠️ [FAST/capture] Erreur: {ex}")
        return None
|
||||
216
core/grounding/fast_pipeline.py
Normal file
216
core/grounding/fast_pipeline.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
core/grounding/fast_pipeline.py — Pipeline FAST → SMART → THINK
|
||||
|
||||
Orchestrateur central : détecte les éléments (FAST), matche avec la cible (SMART),
|
||||
et demande au VLM de trancher si le score est trop bas (THINK).
|
||||
|
||||
Seuils de confiance :
|
||||
≥ 0.90 → action directe (FAST/SMART)
|
||||
0.60-0.90 → VLM confirme (THINK)
|
||||
< 0.60 → VLM cherche seul (THINK)
|
||||
|
||||
L'ancien GroundingPipeline est utilisé en fallback si tout échoue.
|
||||
|
||||
Utilisation :
|
||||
from core.grounding.fast_pipeline import FastSmartThinkPipeline
|
||||
from core.grounding.target import GroundingTarget
|
||||
|
||||
pipeline = FastSmartThinkPipeline()
|
||||
result = pipeline.locate(GroundingTarget(text="Valider"))
|
||||
if result:
|
||||
print(f"({result.x}, {result.y}) via {result.method} en {result.time_ms:.0f}ms")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from core.grounding.target import GroundingTarget, GroundingResult
|
||||
from core.grounding.fast_types import LocateResult
|
||||
from core.grounding.fast_detector import FastDetector
|
||||
from core.grounding.smart_matcher import SmartMatcher
|
||||
from core.grounding.think_arbiter import ThinkArbiter
|
||||
from core.grounding.element_signature import SignatureStore
|
||||
|
||||
|
||||
# Singleton state: the shared pipeline returned by
# FastSmartThinkPipeline.get_instance(), and the lock held while it is created.
_instance: Optional[FastSmartThinkPipeline] = None
_instance_lock = threading.Lock()
|
||||
|
||||
|
||||
class FastSmartThinkPipeline:
    """FAST → SMART → THINK pipeline for locating UI elements.

    Every call to locate() walks the cascade:
      1. FAST  : RF-DETR detection + OCR enrichment (~120ms+1s)
      2. SMART : text/type/position/neighbour matching (< 1ms)
      3. THINK : the VLM arbitrates when the score is insufficient (~3-5s)
      4. Fallback: the legacy pipeline, when everything else fails
    """

    def __init__(
        self,
        confidence_direct: float = 0.90,
        confidence_think: float = 0.60,
        enable_think: bool = True,
        enable_learning: bool = True,
    ):
        """Build the pipeline and its collaborators.

        Args:
            confidence_direct: SMART score from which a match is returned
                without asking the VLM.
            confidence_think: SMART score from which the VLM is asked to
                confirm the candidate (below it, the VLM searches alone).
            enable_think: When False, the THINK stages are skipped entirely.
            enable_learning: When True, successful matches are recorded in
                the signature store.
        """
        self.confidence_direct = confidence_direct
        self.confidence_think = confidence_think
        self.enable_think = enable_think
        self.enable_learning = enable_learning

        self._detector = FastDetector()
        self._matcher = SmartMatcher()
        self._arbiter = ThinkArbiter()
        self._signatures = SignatureStore()
        self._fallback_pipeline = None

    @classmethod
    def get_instance(cls) -> FastSmartThinkPipeline:
        """Return the module-wide shared instance, creating it on first call."""
        global _instance
        if _instance is None:
            with _instance_lock:
                # Re-check under the lock: another thread may have created
                # the instance while this one was waiting.
                if _instance is None:
                    _instance = cls()
        return _instance

    def set_fallback_pipeline(self, pipeline) -> None:
        """Register the legacy pipeline used as a safety net."""
        self._fallback_pipeline = pipeline

    # ------------------------------------------------------------------
    # Main API
    # ------------------------------------------------------------------

    def locate(
        self,
        target: GroundingTarget,
        screenshot_pil=None,
        phash: str = "",
        window_title: str = "",
    ) -> Optional[GroundingResult]:
        """Locate a UI element through the FAST → SMART → THINK cascade.

        Args:
            target: What is being looked for (text, description, ...).
            screenshot_pil: PIL image forwarded to the detector and to the
                arbiter. May be None.
                NOTE(review): the parameter was originally documented as
                "captured via mss when None"; that capture is not visible
                in this class — confirm that detector and arbiter each
                handle None themselves.
            phash: Perceptual hash forwarded to the detector.
            window_title: Title of the active window; also part of the
                screen context used for signature lookup.

        Returns:
            A GroundingResult, or None when no stage (fallback included)
            finds the element.
        """
        t0 = time.time()

        # --- FAST: detect every element ---
        snapshot = self._detector.detect(
            screenshot_pil=screenshot_pil,
            phash=phash,
            window_title=window_title,
        )

        if not snapshot.elements:
            print("⚡ [Pipeline] FAST : aucun élément détecté")
            return self._try_fallback(target)

        # --- Learned signature lookup ---
        target_key = SignatureStore.make_target_key(
            target.text or "", target.description or ""
        )
        screen_ctx = SignatureStore.make_screen_context(
            window_title, snapshot.resolution
        )
        signature = self._signatures.lookup(target_key, screen_ctx)

        # --- SMART: match against the target ---
        candidate = self._matcher.match(snapshot, target, signature)

        if candidate:
            dt = (time.time() - t0) * 1000

            # High score → act directly
            if candidate.score >= self.confidence_direct:
                print(f"✅ [Pipeline] FAST→SMART direct : '{candidate.element.ocr_text}' "
                      f"score={candidate.score:.3f} ({candidate.method}) "
                      f"→ ({candidate.element.center[0]}, {candidate.element.center[1]}) "
                      f"en {dt:.0f}ms")

                if self.enable_learning:
                    self._signatures.record_success(
                        target_key, screen_ctx,
                        candidate.element, candidate.score,
                    )

                return GroundingResult(
                    x=candidate.element.center[0],
                    y=candidate.element.center[1],
                    method=f"fast_{candidate.method}",
                    confidence=candidate.score,
                    time_ms=dt,
                )

            # Medium score → ask the VLM to confirm
            if candidate.score >= self.confidence_think and self.enable_think:
                print(f"🤔 [Pipeline] SMART score={candidate.score:.3f} — THINK pour confirmer")
                # The caller's screenshot is forwarded unchanged (the former
                # "... if False else screenshot_pil" expression always
                # evaluated to exactly this value).
                think_result = self._arbiter.arbitrate(
                    target,
                    candidates=[candidate],
                    screenshot_pil=screenshot_pil,
                )
                dt = (time.time() - t0) * 1000

                if think_result:
                    # The VLM confirmed the candidate
                    if self.enable_learning:
                        self._signatures.record_success(
                            target_key, screen_ctx,
                            candidate.element, think_result.confidence,
                        )
                    return GroundingResult(
                        x=think_result.x, y=think_result.y,
                        method="smart_think_confirmed",
                        confidence=think_result.confidence,
                        time_ms=dt,
                    )
                # No confirmation: fall through to the full VLM search.

        # --- THINK: score too low or no candidate → the VLM searches alone ---
        if self.enable_think:
            score_info = f"score={candidate.score:.3f}" if candidate else "aucun candidat"
            print(f"🤔 [Pipeline] {score_info} — THINK recherche complète")
            think_result = self._arbiter.arbitrate(
                target, candidates=[], screenshot_pil=screenshot_pil,
            )
            dt = (time.time() - t0) * 1000

            if think_result:
                return GroundingResult(
                    x=think_result.x, y=think_result.y,
                    method="think_vlm",
                    confidence=think_result.confidence,
                    time_ms=dt,
                )

        # --- Fallback: legacy pipeline ---
        return self._try_fallback(target)

    # ------------------------------------------------------------------
    # Fallback
    # ------------------------------------------------------------------

    def _try_fallback(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """Try the legacy pipeline as a last resort.

        Returns:
            Whatever the fallback pipeline's ``locate(target)`` returns, or
            None when no fallback is configured or when it raises.
        """
        if self._fallback_pipeline is None:
            print(f"❌ [Pipeline] Aucune méthode n'a trouvé '{target.text}'")
            return None

        print(f"⚠️ [Pipeline] Fallback ancien pipeline pour '{target.text}'")
        try:
            return self._fallback_pipeline.locate(target)
        except Exception as ex:
            print(f"⚠️ [Pipeline] Fallback échoué: {ex}")
            return None
|
||||
81
core/grounding/fast_types.py
Normal file
81
core/grounding/fast_types.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
core/grounding/fast_types.py — Structures de données pour le pipeline FAST→SMART→THINK
|
||||
|
||||
Utilisées exclusivement par le pipeline de localisation rapide.
|
||||
Compatibles avec GroundingTarget/GroundingResult existants via conversion.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
class DetectedUIElement:
    """UI element detected by the FAST layer (RF-DETR), then enriched with OCR."""
    id: int
    bbox: Tuple[int, int, int, int]  # (x1, y1, x2, y2) in absolute pixels
    center: Tuple[int, int]  # (cx, cy)
    confidence: float  # detector confidence (0-1)
    element_type: str = "element"  # e.g. "button", "input", "icon", "container", "text", "element"
    ocr_text: str = ""  # OCR text extracted from the region
    neighbors: List[str] = field(default_factory=list)  # texts of nearby elements
    relative_position: str = ""  # grid cell such as "top_left", "middle_center", "bottom_right"

    @property
    def width(self) -> int:
        """Horizontal extent of the bounding box (x2 - x1)."""
        left, right = self.bbox[0], self.bbox[2]
        return right - left

    @property
    def height(self) -> int:
        """Vertical extent of the bounding box (y2 - y1)."""
        top, bottom = self.bbox[1], self.bbox[3]
        return bottom - top

    @property
    def area(self) -> int:
        """Bounding-box surface in square pixels."""
        w, h = self.width, self.height
        return w * h
|
||||
|
||||
|
||||
@dataclass
class ScreenSnapshot:
    """Full state of the screen at a given instant — output of the FAST layer."""
    elements: List[DetectedUIElement]
    ocr_words: List[Dict[str, Any]]  # raw OCR words [{text, bbox}]
    resolution: Tuple[int, int]  # (width, height)
    window_title: str = ""
    phash: str = ""  # presumably the perceptual hash of the screenshot — confirm with the detector
    # Timings in milliseconds (per the field names); all default to 0.0.
    detection_time_ms: float = 0.0
    ocr_time_ms: float = 0.0
    total_time_ms: float = 0.0
|
||||
|
||||
|
||||
@dataclass
class MatchCandidate:
    """Result of SMART matching for one candidate element."""
    element: DetectedUIElement
    score: float  # combined score (0-1)
    # Per-criterion scores; presumably keyed by criterion name — confirm with the matcher.
    score_detail: Dict[str, float] = field(default_factory=dict)
    method: str = ""  # "exact_text", "fuzzy_text", "position", etc.
|
||||
|
||||
|
||||
@dataclass
class LocateResult:
    """Final result of the FAST→SMART→THINK pipeline."""
    x: int
    y: int
    confidence: float
    method: str  # "fast_exact", "fast_fuzzy", "smart_vote", "think_vlm"
    time_ms: float
    tier: str = "fast"  # "fast", "smart", "think"
    element: Optional[DetectedUIElement] = None
    candidates_count: int = 0

    def to_grounding_result(self):
        """Convert to a GroundingResult for compatibility with existing callers.

        Only x, y, method, confidence and time_ms are carried over; tier,
        element and candidates_count are dropped.
        """
        # Imported here rather than at module level, as in the original.
        from core.grounding.target import GroundingResult

        carried = ("x", "y", "method", "confidence", "time_ms")
        return GroundingResult(**{name: getattr(self, name) for name in carried})
|
||||
210
core/grounding/infigui_worker.py
Normal file
210
core/grounding/infigui_worker.py
Normal file
@@ -0,0 +1,210 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
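InfiGUI worker — standalone process, one request per run.

Reads a JSON request on stdin, loads the model, runs a single inference and
prints the JSON result on stdout. Once the model is loaded, a readiness marker
is written to /tmp/infigui_ready. The /tmp/infigui_request.json and
/tmp/infigui_response.json paths are still declared as constants, but nothing
in this module reads or writes them.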
|
||||
|
||||
Lancement :
|
||||
cd ~/ai/rpa_vision_v3
|
||||
.venv/bin/python3 -m core.grounding.infigui_worker
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import gc
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
import torch
|
||||
|
||||
# File-based exchange paths. Only READY_FILE is used in this module (written
# by load_model() once the model is loaded); REQUEST_FILE and RESPONSE_FILE are
# not referenced by the code below.
REQUEST_FILE = "/tmp/infigui_request.json"
RESPONSE_FILE = "/tmp/infigui_response.json"
READY_FILE = "/tmp/infigui_ready"
|
||||
|
||||
|
||||
def load_model():
    """Load InfiGUI-G1-3B quantised to 4-bit NF4 on cuda:0.

    After loading, the allocated VRAM is printed and written to READY_FILE as
    a readiness marker.

    Returns:
        ``(model, processor)``.
    """
    torch.cuda.empty_cache()
    gc.collect()

    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig

    checkpoint = "InfiX-ai/InfiGUI-G1-3B"
    print(f"[infigui-worker] Chargement {checkpoint}...")

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        checkpoint,
        quantization_config=quant_config,
        device_map={"": "cuda:0"},
    )
    model.eval()

    # Image budget handed to the processor, expressed in 28x28 patches.
    patch_area = 28 * 28
    processor = AutoProcessor.from_pretrained(
        checkpoint,
        padding_side="left",
        min_pixels=100 * patch_area,
        max_pixels=5600 * patch_area,
    )

    vram_gb = torch.cuda.memory_allocated() / 1e9
    print(f"[infigui-worker] Prêt — VRAM: {vram_gb:.2f}GB")

    # "Ready" signal
    with open(READY_FILE, "w") as marker:
        marker.write(f"ready {vram_gb:.2f}GB")

    return model, processor
|
||||
|
||||
|
||||
# Image budget the processor is configured with in load_model(); the two must
# stay in sync for coordinates to be rescaled correctly.
_INFIGUI_MIN_PIXELS = 100 * 28 * 28
_INFIGUI_MAX_PIXELS = 5600 * 28 * 28


def _model_resolution(height, width, factor=28,
                      min_pixels=_INFIGUI_MIN_PIXELS, max_pixels=_INFIGUI_MAX_PIXELS):
    """Return ``(rH, rW)``: the image size after rounding to *factor* and fitting the pixel budget."""
    r_h = max(factor, round(height / factor) * factor)
    r_w = max(factor, round(width / factor) * factor)
    if height > 0 and width > 0:
        if r_h * r_w > max_pixels:
            beta = math.sqrt((height * width) / max_pixels)
            r_h = math.floor(height / beta / factor) * factor
            r_w = math.floor(width / beta / factor) * factor
        elif r_h * r_w < min_pixels:
            beta = math.sqrt(min_pixels / (height * width))
            r_h = math.ceil(height * beta / factor) * factor
            r_w = math.ceil(width * beta / factor) * factor
    return r_h, r_w


def _extract_point(raw):
    """Return the first ``point_2d`` of the model answer as ``(x, y)``, or None."""
    answer = raw.split("</think>")[-1] if "</think>" in raw else raw
    answer = answer.replace("```json", "").replace("```", "").strip()

    try:
        parsed = json.loads(answer)
    except json.JSONDecodeError:
        # Not valid JSON: look for a literal point anywhere in the raw output.
        found = re.search(r'"point_2d"\s*:\s*\[(\d+),\s*(\d+)\]', raw)
        return (int(found.group(1)), int(found.group(2))) if found else None

    if not (isinstance(parsed, list) and parsed and isinstance(parsed[0], dict)):
        return None
    point = parsed[0].get("point_2d", [])
    if not isinstance(point, (list, tuple)) or len(point) < 2:
        return None
    x, y = point[0], point[1]
    for value in (x, y):
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return None
    return x, y


def infer(model, processor, req):
    """Run one grounding inference.

    Modes:
      - text only (target/description): classic grounding.
      - fused (``anchor_image_path`` present and existing): the anchor crop is
        passed as a reference image and the model must find that element on
        the screenshot. Avoids a separate describe→ground double pass.

    Args:
        model: Loaded model; ``generate`` and ``device`` are used.
        processor: Matching processor (chat template, tokenisation, decode).
        req: Request dict with optional keys ``target``, ``description``,
            ``image_path`` (screenshot file; the screen is captured with mss
            when missing or absent on disk) and ``anchor_image_path``.

    Returns:
        ``{"x", "y", "method", "confidence", "time_ms"}`` where x/y are in
        screenshot pixels (None when no point could be parsed), or
        ``{"x": None, "y": None, "error": ...}`` when neither a label nor an
        anchor image is supplied.
    """
    from PIL import Image
    from qwen_vl_utils import process_vision_info

    target = req.get("target", "")
    description = req.get("description", "")
    label = f"{target} — {description}" if description else target

    # Main image (full screenshot)
    image_path = req.get("image_path", "")
    if image_path and os.path.exists(image_path):
        img = Image.open(image_path).convert("RGB")
    else:
        import mss
        with mss.mss() as sct:
            grab = sct.grab(sct.monitors[0])
            img = Image.frombytes("RGB", grab.size, grab.bgra, "raw", "BGRX")

    # Anchor image (optional) — fused describe+ground mode
    anchor_image_path = req.get("anchor_image_path", "")
    anchor_img = None
    if anchor_image_path and os.path.exists(anchor_image_path):
        anchor_img = Image.open(anchor_image_path).convert("RGB")

    if not label.strip() and anchor_img is None:
        return {"x": None, "y": None, "error": "target ou anchor_image requis"}

    W, H = img.size
    # Resolution announced to the model and used to scale its answer back.
    # It has to honour the processor's pixel budget: a screenshot above
    # max_pixels is downscaled by the processor, so rounding to a multiple of
    # 28 alone would mis-scale the returned coordinates.
    rH, rW = _model_resolution(H, W)

    system = (
        "You FIRST think about the reasoning process as an internal monologue "
        "and then provide the final answer.\n"
        "The reasoning process MUST BE enclosed within <think> </think> tags."
    )

    # Build the prompt for the selected mode
    if anchor_img is not None:
        # Fused mode: image 1 = anchor crop, image 2 = screenshot
        hint = f' Hint: this element looks like "{label}".' if label.strip() else ""
        user_text = (
            f"The first image is a small crop of a UI element captured previously. "
            f"The second image is the current screen ({rW}x{rH}).{hint}\n"
            f"Locate on the second image the UI element that visually matches the first image. "
            f"Output the coordinates using JSON format: "
            f'[{{"point_2d": [x, y]}}, ...]'
        )
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": [
                {"type": "image", "image": anchor_img},
                {"type": "image", "image": img},
                {"type": "text", "text": user_text},
            ]},
        ]
    else:
        # Classic mode: text only
        user_text = (
            f'The screen\'s resolution is {rW}x{rH}.\n'
            f'Locate the UI element(s) for "{label}", '
            f'output the coordinates using JSON format: '
            f'[{{"point_2d": [x, y]}}, ...]'
        )
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": user_text},
            ]},
        ]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text], images=image_inputs, videos=video_inputs,
        padding=True, return_tensors="pt",
    ).to(model.device)

    t0 = time.time()
    with torch.no_grad():
        gen = model.generate(**inputs, max_new_tokens=512)
    infer_ms = (time.time() - t0) * 1000

    # Keep only the generated tokens (drop the echoed prompt).
    trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, gen)]
    raw = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False,
    )[0].strip()

    mode_str = "fused" if anchor_img is not None else "text"
    print(f"[infigui-worker] [{mode_str}] '{label[:40]}' ({infer_ms:.0f}ms)")

    # Parse the point_2d JSON and map it back to screenshot pixels
    px, py = None, None
    point = _extract_point(raw)
    if point is not None:
        px = int(point[0] * W / rW)
        py = int(point[1] * H / rH)

    return {
        "x": px, "y": py,
        "method": "infigui",
        # `is not None`: x == 0 is a valid coordinate, not a miss.
        "confidence": 0.90 if px is not None else 0.0,
        "time_ms": round(infer_ms, 1),
    }
|
||||
|
||||
|
||||
def main():
    """One-shot mode: read one request on stdin, infer, print the result on stdout.

    An empty or non-JSON request produces an error object on stdout and the
    model is not loaded.
    """
    def reply(payload):
        # Every outcome, error or result, is emitted as one JSON document.
        print(json.dumps(payload))

    raw_request = sys.stdin.read().strip()
    if not raw_request:
        reply({"x": None, "y": None, "error": "pas de requête"})
        return

    try:
        req = json.loads(raw_request)
    except json.JSONDecodeError:
        reply({"x": None, "y": None, "error": "JSON invalide"})
        return

    model, processor = load_model()
    reply(infer(model, processor, req))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
190
core/grounding/pipeline.py
Normal file
190
core/grounding/pipeline.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
core/grounding/pipeline.py — Pipeline de grounding en cascade
|
||||
|
||||
Orchestre les methodes de localisation dans l'ordre :
|
||||
1. Template matching (TemplateMatcher, local, ~80ms)
|
||||
2. OCR (docTR via input_handler, local, ~1s)
|
||||
3. UI-TARS (HTTP vers serveur grounding, ~3s)
|
||||
4. Static fallback (coordonnees d'origine du workflow)
|
||||
|
||||
Chaque methode est essayee dans l'ordre. Des qu'une reussit, on retourne
|
||||
le resultat. Cela permet un equilibre entre vitesse (template) et robustesse
|
||||
(UI-TARS pour les elements qui ont change de position/apparence).
|
||||
|
||||
Utilisation :
|
||||
from core.grounding.pipeline import GroundingPipeline
|
||||
from core.grounding.target import GroundingTarget
|
||||
|
||||
pipeline = GroundingPipeline()
|
||||
result = pipeline.locate(GroundingTarget(
|
||||
text="Valider",
|
||||
description="bouton vert en bas",
|
||||
template_b64=screenshot_b64,
|
||||
original_bbox={"x": 100, "y": 200, "width": 80, "height": 30},
|
||||
))
|
||||
if result:
|
||||
print(f"Trouve a ({result.x}, {result.y}) via {result.method}")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from core.grounding.target import GroundingTarget, GroundingResult
|
||||
|
||||
|
||||
class GroundingPipeline:
    """Cascading locator: template -> OCR -> UI-TARS -> static."""

    def __init__(self, template_threshold: float = 0.75, enable_uitars: bool = True):
        """Store the template-matching threshold and whether UI-TARS may be used."""
        self.template_threshold = template_threshold
        self.enable_uitars = enable_uitars

    def locate(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """Locate a UI element by trying each method in turn.

        Args:
            target: description of the element to locate

        Returns:
            The first truthy result produced by a method, or None when every
            method comes back empty.
        """
        started = time.time()

        # Order matters: cheapest first, static coordinates last.
        strategies = [self._try_template, self._try_ocr]
        if self.enable_uitars:
            strategies.append(self._try_uitars)
        strategies.append(self._try_static)

        for attempt in strategies:
            result = attempt(target)
            if result:
                print(f"[GroundingPipeline] Localise via {result.method} en "
                      f"{(time.time() - started) * 1000:.0f}ms")
                return result

        print(f"[GroundingPipeline] ECHEC: '{target.text}' introuvable "
              f"(toutes methodes epuisees, {(time.time() - started) * 1000:.0f}ms)")
        return None

    # ------------------------------------------------------------------
    # Individual methods
    # ------------------------------------------------------------------

    def _try_template(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """Template matching; skipped when the target carries no template."""
        if not target.template_b64:
            return None

        try:
            from core.grounding.template_matcher import TemplateMatcher

            matcher = TemplateMatcher(threshold=self.template_threshold)
            hit = matcher.match_screen(anchor_b64=target.template_b64)
            if not hit:
                # Diagnostic output only: it does not change the outcome.
                diag = matcher.match_screen_diagnostic(anchor_b64=target.template_b64)
                print(f"[GroundingPipeline/template] pas de match — best={diag}")
                return None

            print(f"[GroundingPipeline/template] score={hit.score:.3f} "
                  f"pos=({hit.x},{hit.y}) ({hit.time_ms:.0f}ms)")
            return GroundingResult(
                x=hit.x,
                y=hit.y,
                method='template',
                confidence=hit.score,
                time_ms=hit.time_ms,
            )
        except Exception as e:
            print(f"[GroundingPipeline/template] ERREUR: {e}")
            return None

    def _try_ocr(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """OCR lookup of the target text; skipped when the target has no text."""
        if not target.text:
            return None

        try:
            from core.execution.input_handler import _grounding_ocr

            anchor = target.original_bbox or None
            found = _grounding_ocr(target.text, anchor_bbox=anchor)
            if not found:
                print(f"[GroundingPipeline/OCR] '{target.text}' non trouve")
                return None

            print(f"[GroundingPipeline/OCR] '{target.text}' -> ({found['x']}, {found['y']})")
            return GroundingResult(
                x=found['x'],
                y=found['y'],
                method='ocr',
                confidence=found.get('confidence', 0.80),
                time_ms=found.get('time_ms', 0),
            )
        except Exception as e:
            print(f"[GroundingPipeline/OCR] ERREUR: {e}")
            return None

    def _try_uitars(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """UI-TARS grounding; skipped when the target has neither text nor description."""
        if not (target.text or target.description):
            return None

        try:
            from core.grounding.ui_tars_grounder import UITarsGrounder

            found = UITarsGrounder.get_instance().ground(
                target_text=target.text,
                target_description=target.description,
            )
            if not found:
                print("[GroundingPipeline/UI-TARS] pas de resultat")
                return None

            print(f"[GroundingPipeline/UI-TARS] ({found.x}, {found.y}) "
                  f"conf={found.confidence:.2f} ({found.time_ms:.0f}ms)")
            return found
        except Exception as e:
            print(f"[GroundingPipeline/UI-TARS] ERREUR: {e}")
            return None

    def _try_static(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """Fallback: centre of the target's original bounding box.

        Returns None when the bbox is missing or has a zero/absent width or
        height. The result carries a fixed confidence of 0.30.
        """
        bbox = target.original_bbox
        if not bbox:
            return None

        width = bbox.get('width', 0)
        height = bbox.get('height', 0)
        if not (width and height):
            return None

        x = int(bbox.get('x', 0) + width / 2)
        y = int(bbox.get('y', 0) + height / 2)

        print(f"[GroundingPipeline/static] fallback ({x}, {y}) "
              f"depuis bbox {bbox}")

        return GroundingResult(
            x=x,
            y=y,
            method='static_fallback',
            confidence=0.30,
            time_ms=0.0,
        )
|
||||
113
core/grounding/server.py
Normal file
113
core/grounding/server.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Serveur grounding minimaliste — Flask single-thread, même contexte CUDA."""
|
||||
import base64, io, json, math, os, re, time, gc
|
||||
import torch
|
||||
from flask import Flask, request, jsonify
|
||||
from PIL import Image
|
||||
|
||||
app = Flask(__name__)

# Model checkpoint; overridable through the GROUNDING_MODEL environment variable.
MODEL_ID = os.environ.get("GROUNDING_MODEL", "InfiX-ai/InfiGUI-G1-3B")
# Image pixel budget given to the processor and used by _smart_resize().
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 5600 * 28 * 28
# Filled in by load_model(); both stay None until then.
_model = None
_processor = None
|
||||
|
||||
def _smart_resize(h, w, factor=28):
|
||||
h_bar = max(factor, round(h/factor)*factor)
|
||||
w_bar = max(factor, round(w/factor)*factor)
|
||||
if h_bar*w_bar > MAX_PIXELS:
|
||||
beta = math.sqrt((h*w)/MAX_PIXELS)
|
||||
h_bar = math.floor(h/beta/factor)*factor
|
||||
w_bar = math.floor(w/beta/factor)*factor
|
||||
elif h_bar*w_bar < MIN_PIXELS:
|
||||
beta = math.sqrt(MIN_PIXELS/(h*w))
|
||||
h_bar = math.ceil(h*beta/factor)*factor
|
||||
w_bar = math.ceil(w*beta/factor)*factor
|
||||
return h_bar, w_bar
|
||||
|
||||
def load_model():
    """Load the grounding model and its processor into the module globals.

    Does nothing when both are already loaded. The globals are only assigned
    once both objects exist, so a failure while loading the processor cannot
    leave ``_model`` set with ``_processor`` still None (which would make
    later calls return early and ``/ground`` fail on a None processor).
    """
    global _model, _processor
    if _model is not None and _processor is not None:
        return
    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig

    torch.cuda.empty_cache()
    gc.collect()
    print(f"[grounding] Chargement {MODEL_ID}...")

    # 4-bit NF4 quantisation with bfloat16 compute.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID, quantization_config=quant_config, device_map="auto")
    model.eval()
    processor = AutoProcessor.from_pretrained(
        MODEL_ID, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS, padding_side="left")

    # Publish both together.
    _model, _processor = model, processor
    print(f"[grounding] Prêt — VRAM: {torch.cuda.memory_allocated()/1e9:.2f}GB")
|
||||
|
||||
@app.route('/health')
def health():
    """Return a JSON status report: model id, load state, CUDA availability and allocated VRAM (GB)."""
    cuda_ok = torch.cuda.is_available()
    allocated_gb = round(torch.cuda.memory_allocated() / 1e9, 2)
    payload = {
        "status": "ok",
        "model": MODEL_ID,
        "model_loaded": _model is not None,
        "cuda_available": cuda_ok,
        "vram_allocated_gb": allocated_gb,
    }
    return jsonify(payload)
|
||||
|
||||
def _parse_point(raw, W, H, rW, rH):
    """Extract the first ``point_2d`` from the model output, in source-image pixels.

    The text after the last ``</think>`` tag is parsed as JSON (code fences
    removed). If that yields no usable point, a regex looks for a
    ``"point_2d": [x, y]`` pair, first in the answer and then in the whole
    output. Coordinates are scaled from the resized (rW x rH) space back to
    the original (W x H) image.

    Returns:
        ``(x, y)`` as ints, or ``(None, None)`` when no point can be read.
    """
    answer = raw.split("</think>")[-1]
    answer = answer.replace("```json", "").replace("```", "").strip()

    point = None
    try:
        parsed = json.loads(answer)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        parsed = [parsed]
    if isinstance(parsed, list) and parsed and isinstance(parsed[0], dict):
        candidate = parsed[0].get("point_2d")
        if isinstance(candidate, (list, tuple)) and len(candidate) >= 2:
            point = candidate[:2]

    if point is None:
        number = r'(-?\d+(?:\.\d+)?)'
        pattern = r'"point_2d"\s*:\s*\[\s*' + number + r'\s*,\s*' + number + r'\s*\]'
        found = re.search(pattern, answer) or re.search(pattern, raw)
        if found:
            point = (found.group(1), found.group(2))

    if point is None:
        return None, None
    try:
        return int(float(point[0]) * W / rW), int(float(point[1]) * H / rH)
    except (TypeError, ValueError, OverflowError):
        # Non-numeric, NaN or infinite coordinates: treat as "not found".
        return None, None


@app.route('/ground', methods=['POST'])
def ground():
    """Locate a UI element on an image and return its coordinates as JSON.

    Request JSON fields: ``target_text`` and/or ``target_description``
    (at least one must be non-blank) and optional ``image_b64`` (raw base64
    or a data URL). Without ``image_b64`` the screen is captured with mss.

    Responses:
        200: ``x``, ``y`` (None when nothing was parsed), ``method``,
             ``confidence``, ``time_ms``, ``raw_output`` (first 300 chars).
        400: missing target, invalid JSON body or undecodable image.
        503: model not loaded.
    """
    if _model is None or _processor is None:
        return jsonify({"error": "Modèle pas chargé"}), 503
    from qwen_vl_utils import process_vision_info

    # silent=True: a missing/invalid JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Corps JSON invalide"}), 400

    target = str(data.get('target_text') or '')
    desc = str(data.get('target_description') or '')
    label = f"{target} — {desc}" if desc else target
    if not label.strip():
        return jsonify({"error": "target_text requis"}), 400

    # Image: decode the provided one, otherwise grab the screen.
    image_b64 = data.get('image_b64')
    if image_b64:
        try:
            encoded = image_b64.split(',')[1] if ',' in image_b64 else image_b64
            img = Image.open(io.BytesIO(base64.b64decode(encoded))).convert('RGB')
        except (AttributeError, ValueError, OSError):
            # Not a string, bad base64, or not a decodable image.
            return jsonify({"error": "image_b64 invalide"}), 400
    else:
        import mss
        with mss.mss() as sct:
            grab = sct.grab(sct.monitors[0])
            img = Image.frombytes('RGB', grab.size, grab.bgra, 'raw', 'BGRX')

    W, H = img.size
    rH, rW = _smart_resize(H, W)

    user_text = f'The screen\'s resolution is {rW}x{rH}.\nLocate the UI element(s) for "{label}", output the coordinates using JSON format: [{{"point_2d": [x, y]}}, ...]'
    system = "You FIRST think about the reasoning process as an internal monologue and then provide the final answer.\nThe reasoning process MUST BE enclosed within <think> </think> tags."

    messages = [{"role": "system", "content": system},
                {"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": user_text}]}]

    text = _processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = _processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt").to(_model.device)

    t0 = time.time()
    with torch.no_grad():
        gen = _model.generate(**inputs, max_new_tokens=512)
    infer_ms = (time.time() - t0) * 1000

    # Drop the prompt tokens so only the newly generated text is decoded.
    trimmed = [out[len(prompt):] for prompt, out in zip(inputs.input_ids, gen)]
    raw = _processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0].strip()
    print(f"[grounding] '{label[:40]}' → {raw[:100]} ({infer_ms:.0f}ms)")

    px, py = _parse_point(raw, W, H, rW, rH)

    # `is not None`: x == 0 is a valid coordinate and must not zero the confidence.
    confidence = 0.90 if px is not None else 0.0
    return jsonify({"x": px, "y": py, "method": "infigui", "confidence": confidence,
                    "time_ms": round(infer_ms, 1), "raw_output": raw[:300]})
|
||||
|
||||
if __name__ == '__main__':
    # Load before serving: /ground answers 503 while no model is loaded.
    load_model()
    # threaded=False keeps request handling single-threaded.
    serve_options = {"host": '0.0.0.0', "port": 8200, "threaded": False}
    app.run(**serve_options)
|
||||
156
core/grounding/shadow_learning_hook.py
Normal file
156
core/grounding/shadow_learning_hook.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
core/grounding/shadow_learning_hook.py — Hook d'apprentissage Shadow
|
||||
|
||||
Connecte le ShadowObserver au SignatureStore : chaque clic observé pendant
|
||||
une session Shadow enrichit la base de signatures d'éléments.
|
||||
|
||||
L'humain clique quelque part → on détecte quel élément UI est sous le clic →
|
||||
on stocke sa signature (texte, type, position, voisins) pour le replay.
|
||||
|
||||
Ce module est un HOOK optionnel — il ne modifie pas le ShadowObserver,
|
||||
il s'y branche via callback.
|
||||
|
||||
Utilisation :
|
||||
from core.grounding.shadow_learning_hook import ShadowLearningHook
|
||||
|
||||
hook = ShadowLearningHook()
|
||||
|
||||
# Dans le ShadowObserver ou l'API de capture :
|
||||
hook.on_click_observed(
|
||||
click_x=542, click_y=318,
|
||||
screenshot_pil=screen,
|
||||
window_title="Bloc-notes",
|
||||
target_label="Bouton Valider",
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from core.grounding.element_signature import SignatureStore
|
||||
from core.grounding.fast_types import DetectedUIElement
|
||||
|
||||
|
||||
class ShadowLearningHook:
    """Learning hook for Shadow mode.

    For every observed human click, detects the UI element under the click
    and records its signature in the SignatureStore.
    """

    def __init__(self, signature_store: Optional[SignatureStore] = None):
        # `is None` rather than `or`: a caller-supplied store that happens
        # to be falsy (e.g. empty) must not be silently replaced.
        self._store = signature_store if signature_store is not None else SignatureStore()
        self._detector = None  # Lazy: do not load the detector at startup
        self._lock = threading.Lock()  # Guards the lazy detector creation

    def _get_detector(self):
        """Return the shared detector, creating it once under the lock."""
        if self._detector is None:
            with self._lock:
                if self._detector is None:
                    from core.grounding.fast_detector import FastDetector
                    self._detector = FastDetector()
        return self._detector

    def on_click_observed(
        self,
        click_x: int,
        click_y: int,
        screenshot_pil: Optional[Any] = None,
        window_title: str = "",
        target_label: str = "",
        target_description: str = "",
    ) -> Optional[Dict[str, Any]]:
        """Record the signature of the element under an observed human click.

        Args:
            click_x, click_y: Click position (screen pixels).
            screenshot_pil: PIL screenshot taken at click time, forwarded
                to the detector.
            window_title: Title of the active window.
            target_label: Label of the step (if known); falls back to the
                element's OCR text.
            target_description: Description of the element (if known).

        Returns:
            Dict describing the recorded element, or None when nothing was
            detected under the click or any error occurred (errors are
            printed, never raised).
        """
        t0 = time.time()

        try:
            snapshot = self._get_detector().detect(screenshot_pil=screenshot_pil)

            if not snapshot.elements:
                print(f"📝 [Shadow/learn] Aucun élément détecté à ({click_x}, {click_y})")
                return None

            clicked_element = self._find_element_at(click_x, click_y, snapshot.elements)
            if clicked_element is None:
                print(f"📝 [Shadow/learn] Aucun élément sous ({click_x}, {click_y})")
                return None

            target_key = SignatureStore.make_target_key(
                target_label or clicked_element.ocr_text,
                target_description,
            )
            screen_ctx = SignatureStore.make_screen_context(
                window_title, snapshot.resolution,
            )

            self._store.record_success(
                target_key=target_key,
                screen_context=screen_ctx,
                element=clicked_element,
                confidence=1.0,  # A human clicked it: maximum confidence
            )

            dt = (time.time() - t0) * 1000
            print(f"📝 [Shadow/learn] Signature '{clicked_element.ocr_text}' "
                  f"type={clicked_element.element_type} "
                  f"pos={clicked_element.relative_position} "
                  f"voisins={clicked_element.neighbors[:3]} ({dt:.0f}ms)")

            return {
                "target_key": target_key,
                "text": clicked_element.ocr_text,
                "element_type": clicked_element.element_type,
                "relative_position": clicked_element.relative_position,
                "neighbors": clicked_element.neighbors,
                "center": clicked_element.center,
            }

        except Exception as e:
            # Best-effort hook: learning failures must not break the observer.
            print(f"⚠️ [Shadow/learn] Erreur: {e}")
            return None

    @staticmethod
    def _find_element_at(
        x: int, y: int,
        elements: list,
        margin: int = 20,
    ) -> Optional[DetectedUIElement]:
        """Find the element at point (x, y).

        When several bounding boxes contain the point, the one with the
        smallest area wins, so a button is preferred over the panel that
        encloses it (ties keep list order). When no box contains the point,
        the element whose center is nearest is returned, provided that
        distance is strictly below `margin` pixels; otherwise None.
        """
        def _area(elem):
            x1, y1, x2, y2 = elem.bbox
            return (x2 - x1) * (y2 - y1)

        containing = [
            elem for elem in elements
            if elem.bbox[0] <= x <= elem.bbox[2] and elem.bbox[1] <= y <= elem.bbox[3]
        ]
        if containing:
            return min(containing, key=_area)

        # Proximity fallback: nearest center within `margin`.
        best_elem = None
        best_dist = float('inf')
        for elem in elements:
            dx = elem.center[0] - x
            dy = elem.center[1] - y
            dist = (dx ** 2 + dy ** 2) ** 0.5
            if dist < margin and dist < best_dist:
                best_dist = dist
                best_elem = elem

        return best_elem
|
||||
263
core/grounding/smart_matcher.py
Normal file
263
core/grounding/smart_matcher.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""
|
||||
core/grounding/smart_matcher.py — Layer SMART : matching déterministe/probabiliste
|
||||
|
||||
Étant donné un ScreenSnapshot (tous les éléments détectés) et un GroundingTarget
|
||||
(ce qu'on cherche), trouve l'élément correspondant avec un score de confiance.
|
||||
|
||||
Pipeline de matching (court-circuit au premier match haute confiance) :
|
||||
1. Texte exact (2ms) → score 0.95
|
||||
2. Texte fuzzy ratio (5ms) → score 0.70-0.90
|
||||
3. Type + position (2ms) → bonus/malus
|
||||
4. Voisins contextuels (5ms) → bonus
|
||||
5. Score combiné → MatchCandidate
|
||||
|
||||
Utilisation :
|
||||
from core.grounding.smart_matcher import SmartMatcher
|
||||
from core.grounding.fast_types import ScreenSnapshot
|
||||
from core.grounding.target import GroundingTarget
|
||||
|
||||
matcher = SmartMatcher()
|
||||
candidate = matcher.match(snapshot, GroundingTarget(text="Valider"))
|
||||
if candidate and candidate.score >= 0.90:
|
||||
print(f"Match direct : ({candidate.element.center}) score={candidate.score}")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from core.grounding.fast_types import DetectedUIElement, MatchCandidate, ScreenSnapshot
|
||||
from core.grounding.target import GroundingTarget
|
||||
|
||||
|
||||
class SmartMatcher:
    """Match a grounding target against the elements detected on screen.

    Combines several signals (text, type, position, neighbours) into one
    weighted confidence score per candidate element.
    """

    def __init__(
        self,
        weight_text: float = 0.50,
        weight_type: float = 0.10,
        weight_position: float = 0.15,
        weight_neighbors: float = 0.25,
    ):
        self.w_text = weight_text
        self.w_type = weight_type
        self.w_position = weight_position
        self.w_neighbors = weight_neighbors

    def match(
        self,
        snapshot: ScreenSnapshot,
        target: GroundingTarget,
        signature: Optional[Dict] = None,
    ) -> Optional[MatchCandidate]:
        """Return the highest-scoring candidate for *target*, or None if there is none."""
        candidates = self.match_all(snapshot, target, signature)
        return candidates[0] if candidates else None

    def match_all(
        self,
        snapshot: ScreenSnapshot,
        target: GroundingTarget,
        signature: Optional[Dict] = None,
    ) -> List[MatchCandidate]:
        """Return every candidate, sorted by decreasing combined score.

        Args:
            snapshot: Screen state (detected elements and their OCR text).
            target: What is being looked for (text, description, original bbox).
            signature: Optional learned signature; when given, its
                ``element_type``, ``relative_position`` and ``neighbors``
                refine the score.

        Returns:
            List of MatchCandidate. Elements whose text score is below 0.30
            are left out. Signals without reference data count as a neutral 0.5.
        """
        if not snapshot.elements:
            return []

        target_text = (target.text or "").strip()
        target_desc = (target.description or "").strip()
        search_text = target_text or target_desc
        if not search_text:
            return []

        search_norm = self._normalize(search_text)
        candidates = []

        for elem in snapshot.elements:
            # --- 1. Text score; also the gate, so it is evaluated first ---
            text_score = self._score_text(search_norm, elem.ocr_text)
            if text_score < 0.30:
                continue  # text does not match at all: not a candidate

            if text_score >= 0.95:
                method = "exact_text"
            elif text_score >= 0.70:
                method = "fuzzy_text"
            else:
                method = "combined"

            # --- 2. Type score (only with a known signature) ---
            type_score = 0.5
            if signature and signature.get("element_type"):
                if elem.element_type == signature["element_type"]:
                    type_score = 1.0
                elif elem.element_type == "element":
                    type_score = 0.5  # unclassified element: stay neutral
                else:
                    type_score = 0.2

            # --- 3. Position score (original bbox first, else signature) ---
            position_score = 0.5
            if target.original_bbox:
                position_score = self._score_position(
                    elem.center, target.original_bbox,
                    snapshot.resolution[0], snapshot.resolution[1],
                )
            elif signature and signature.get("relative_position"):
                same_zone = elem.relative_position == signature["relative_position"]
                position_score = 0.9 if same_zone else 0.3

            # --- 4. Neighbour score (only with a known signature) ---
            neighbor_score = 0.5
            if signature and signature.get("neighbors"):
                neighbor_score = self._score_neighbors(
                    elem.neighbors, signature["neighbors"]
                )

            combined = (
                self.w_text * text_score
                + self.w_type * type_score
                + self.w_position * position_score
                + self.w_neighbors * neighbor_score
            )

            candidates.append(MatchCandidate(
                element=elem,
                score=combined,
                score_detail={
                    "text": text_score,
                    "type": type_score,
                    "position": position_score,
                    "neighbors": neighbor_score,
                },
                method=method,
            ))

        candidates.sort(key=lambda c: c.score, reverse=True)
        return candidates

    # ------------------------------------------------------------------
    # Text scoring
    # ------------------------------------------------------------------

    def _score_text(self, search: str, ocr_text: str) -> float:
        """Text similarity in [0, 1] between a normalized *search* and raw *ocr_text*.

        1.0 for equality after normalization, 0.70-0.95 when one string
        contains the other (scaled by the length ratio), 0.74-0.90 for a
        SequenceMatcher ratio >= 0.60, and ``ratio * 0.3`` below that.
        Returns 0.0 when either side is empty after normalization.
        """
        if not ocr_text:
            return 0.0

        ocr_norm = self._normalize(ocr_text)
        # An empty string is "contained" in every string; without this guard
        # punctuation-only OCR text would earn an inclusion score.
        if not ocr_norm or not search:
            return 0.0

        if search == ocr_norm:
            return 1.0

        if search in ocr_norm or ocr_norm in search:
            shorter = min(len(search), len(ocr_norm))
            longer = max(len(search), len(ocr_norm))
            return 0.70 + 0.25 * (shorter / longer)

        ratio = SequenceMatcher(None, search, ocr_norm).ratio()
        if ratio >= 0.60:
            return 0.50 + 0.40 * ratio
        return ratio * 0.3

    # ------------------------------------------------------------------
    # Position scoring
    # ------------------------------------------------------------------

    @staticmethod
    def _score_position(
        center: tuple,
        original_bbox: dict,
        screen_w: int,
        screen_h: int,
    ) -> float:
        """Proximity in [0, 1] between *center* and the center of *original_bbox*.

        Offsets are normalized by the screen size; the score decreases
        linearly from 1.0 at distance 0 to 0.0 at a normalized distance of
        0.5 and beyond. Returns the neutral 0.5 when no bbox is given.
        """
        if not original_bbox:
            return 0.5

        orig_x = original_bbox.get("x", 0) + original_bbox.get("width", 0) / 2
        orig_y = original_bbox.get("y", 0) + original_bbox.get("height", 0) / 2

        dx = abs(center[0] - orig_x) / max(screen_w, 1)
        dy = abs(center[1] - orig_y) / max(screen_h, 1)
        distance_norm = (dx ** 2 + dy ** 2) ** 0.5

        return max(0.0, 1.0 - distance_norm * 2.0)

    # ------------------------------------------------------------------
    # Neighbour scoring
    # ------------------------------------------------------------------

    @staticmethod
    def _score_neighbors(
        current_neighbors: List[str],
        expected_neighbors: List[str],
    ) -> float:
        """Jaccard similarity in [0, 1] between two sets of neighbour labels.

        Labels are compared lower-cased and stripped; blank labels are
        ignored. Returns the neutral 0.5 when there is nothing to expect.
        """
        if not expected_neighbors:
            return 0.5

        def _clean(names):
            stripped = (n.lower().strip() for n in names if n)
            return {s for s in stripped if s}

        current_set = _clean(current_neighbors)
        expected_set = _clean(expected_neighbors)
        if not expected_set:
            return 0.5

        union = current_set | expected_set
        return len(current_set & expected_set) / len(union)

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize(text: str) -> str:
        """Lower-case *text*, turn ``_ - . / \\`` into spaces and collapse whitespace.

        Stripping happens last so punctuation at either end does not leave
        a leading or trailing space behind.
        """
        text = re.sub(r'[_\-\./\\]', ' ', text.lower())
        return re.sub(r'\s+', ' ', text).strip()
|
||||
48
core/grounding/target.py
Normal file
48
core/grounding/target.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
core/grounding/target.py — Types partagés pour le grounding visuel
|
||||
|
||||
Dataclasses décrivant une cible à localiser (GroundingTarget) et
|
||||
le résultat d'une localisation (GroundingResult).
|
||||
|
||||
Ces types sont la brique commune pour tous les modules de grounding :
|
||||
template matching, OCR, VLM, CLIP, etc.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
@dataclass
class GroundingTarget:
    """Description of a UI element to locate on screen.

    Every field is optional; callers fill in whatever they know.
    """

    # Visible text of the element (button caption, label, ...).
    text: str = ""
    # Free-form semantic description, e.g. "the Validate button, bottom right".
    description: str = ""
    # Visual crop of the element, base64-encoded PNG/JPEG.
    template_b64: str = ""
    # Position at capture time as {x, y, width, height}; None when unknown.
    original_bbox: Optional[Dict[str, int]] = None
|
||||
|
||||
|
||||
@dataclass
class GroundingResult:
    """Outcome of locating a UI element."""

    # X coordinate of the found element's center (screen pixels).
    x: int
    # Y coordinate of the found element's center (screen pixels).
    y: int
    # Method that produced the result ('template', 'ocr', 'vlm', 'clip', ...).
    method: str
    # Confidence score in [0.0, 1.0].
    confidence: float
    # Search time in milliseconds.
    time_ms: float
|
||||
350
core/grounding/template_matcher.py
Normal file
350
core/grounding/template_matcher.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
core/grounding/template_matcher.py — Template matching centralisé
|
||||
|
||||
Fournit une classe TemplateMatcher qui localise une ancre visuelle (image template)
|
||||
dans un screenshot via cv2.matchTemplate. Supporte single-scale et multi-scale.
|
||||
|
||||
Remplace les implémentations dupliquées dans :
|
||||
- core/execution/observe_reason_act.py (~1348-1375)
|
||||
- visual_workflow_builder/backend/api_v3/execute.py (~930-963)
|
||||
- visual_workflow_builder/backend/catalog_routes_v2_vlm.py (~339-381)
|
||||
- visual_workflow_builder/backend/services/intelligent_executor.py (~131-210)
|
||||
- core/detection/omniparser_adapter.py (~330)
|
||||
|
||||
Utilisation :
|
||||
from core.grounding import TemplateMatcher, MatchResult
|
||||
|
||||
matcher = TemplateMatcher(threshold=0.75)
|
||||
result = matcher.match_screen(anchor_b64="...")
|
||||
if result:
|
||||
print(f"Trouvé à ({result.x}, {result.y}) score={result.score:.3f}")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Imports optionnels — le module se charge même sans cv2/PIL/mss
|
||||
try:
|
||||
import cv2
|
||||
_CV2 = True
|
||||
except ImportError:
|
||||
_CV2 = False
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
_NP = True
|
||||
except ImportError:
|
||||
_NP = False
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
_PIL = True
|
||||
except ImportError:
|
||||
_PIL = False
|
||||
|
||||
try:
|
||||
import mss as mss_lib
|
||||
_MSS = True
|
||||
except ImportError:
|
||||
_MSS = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Résultat d'un match
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class MatchResult:
    """Outcome of one template-matching search."""

    # Center of the matched area, in pixels of the searched image.
    x: int
    y: int
    # Matching score of the accepted location.
    score: float
    # 'template' or 'template_multiscale'.
    method: str
    # Time spent matching, in milliseconds.
    time_ms: float
    # Anchor scale at which the best match was found.
    scale: float = 1.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TemplateMatcher
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TemplateMatcher:
    """Locate a visual anchor (template image) in a screenshot via ``cv2.matchTemplate``.

    Args:
        threshold: Minimum score for a match to be accepted (default 0.75).
        multiscale: Try several anchor scales instead of only 1.0 (default False).
        scales: Scales tried in multi-scale mode; defaults to ``DEFAULT_SCALES``.
        grayscale: Convert both images to grayscale before matching (default False).

    The comparison method is always ``cv2.TM_CCOEFF_NORMED``.
    """

    # Default scales for multi-scale mode. Every scale is evaluated and the
    # best score wins; there is no early exit on the first hit.
    DEFAULT_SCALES: List[float] = [1.0, 0.95, 1.05, 0.9, 1.1, 0.85, 1.15, 0.8, 1.2]

    def __init__(
        self,
        threshold: float = 0.75,
        multiscale: bool = False,
        scales: Optional[List[float]] = None,
        grayscale: bool = False,
    ):
        self.threshold = threshold
        self.multiscale = multiscale
        # Copy so mutating one instance's scales never touches the class
        # default or the caller's list.
        self.scales = list(scales) if scales else list(self.DEFAULT_SCALES)
        self.grayscale = grayscale
        self._cv2_method = cv2.TM_CCOEFF_NORMED if _CV2 else None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def match_screen(
        self,
        anchor_b64: Optional[str] = None,
        anchor_pil: Optional["Image.Image"] = None,
        screen_pil: Optional["Image.Image"] = None,
    ) -> Optional[MatchResult]:
        """Search for the anchor in the given (or freshly captured) screenshot.

        The anchor is passed either as base64 or as a PIL image
        (``anchor_pil`` wins when both are given). The screenshot is
        captured with mss when ``screen_pil`` is None.

        Returns:
            A MatchResult, or None when dependencies are missing, an input
            cannot be prepared, or no location reaches the threshold.
        """
        if not (_CV2 and _NP and _PIL):
            logger.debug("[TemplateMatcher] cv2/numpy/PIL non disponible")
            return None

        screen_cv, anchor_cv, _reason = self._prepare_arrays(anchor_b64, anchor_pil, screen_pil)
        if screen_cv is None:
            return None

        if self.multiscale:
            return self._match_multiscale(screen_cv, anchor_cv)
        return self._match_single(screen_cv, anchor_cv)

    def match_in_region(
        self,
        region_cv: "np.ndarray",
        anchor_cv: "np.ndarray",
        threshold: Optional[float] = None,
    ) -> Optional[MatchResult]:
        """Match inside an already-cropped region (arrays as used by cv2).

        For pipelines that do their own capture and cropping. Returned
        coordinates are relative to ``region_cv``. ``threshold`` overrides
        the instance threshold when given.
        """
        if not (_CV2 and _NP):
            return None

        thr = threshold if threshold is not None else self.threshold
        if self.multiscale:
            return self._match_multiscale(region_cv, anchor_cv, threshold_override=thr)
        return self._match_single(region_cv, anchor_cv, threshold_override=thr)

    def match_screen_diagnostic(
        self,
        anchor_b64: Optional[str] = None,
        anchor_pil: Optional["Image.Image"] = None,
        screen_pil: Optional["Image.Image"] = None,
    ) -> str:
        """Return a textual diagnostic (best single-scale score and position), even below threshold."""
        if not (_CV2 and _NP and _PIL):
            return "cv2/numpy/PIL non dispo"

        screen_cv, anchor_cv, reason = self._prepare_arrays(anchor_b64, anchor_pil, screen_pil)
        if reason is not None:
            return reason

        if anchor_cv.shape[0] >= screen_cv.shape[0] or anchor_cv.shape[1] >= screen_cv.shape[1]:
            return f"ancre {anchor_cv.shape[:2]} >= écran {screen_cv.shape[:2]}"

        s_img, a_img = self._maybe_grayscale(screen_cv, anchor_cv)
        result_tm = cv2.matchTemplate(s_img, a_img, self._cv2_method)
        _, max_val, _, max_loc = cv2.minMaxLoc(result_tm)
        return f"{max_val:.3f} pos={max_loc}"

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _prepare_arrays(self, anchor_b64, anchor_pil, screen_pil):
        """Decode the anchor, obtain the screenshot, and convert both to BGR arrays.

        Returns:
            ``(screen_cv, anchor_cv, None)`` on success, or
            ``(None, None, reason)`` where *reason* is the diagnostic text
            describing which input could not be prepared.
        """
        anchor_img = self._decode_anchor(anchor_b64, anchor_pil)
        if anchor_img is None:
            return None, None, "ancre non décodable"

        if screen_pil is None:
            screen_pil = self._capture_screen()
            if screen_pil is None:
                return None, None, "capture écran échouée"

        return self._to_bgr(screen_pil), self._to_bgr(anchor_img), None

    @staticmethod
    def _to_bgr(img: "Image.Image") -> "np.ndarray":
        """Convert a PIL image of any mode to a 3-channel BGR array.

        Going through RGB first makes grayscale, palette and alpha images
        safe: COLOR_RGB2BGR rejects the 1- or 2-channel arrays they would
        otherwise produce.
        """
        return cv2.cvtColor(np.array(img.convert('RGB')), cv2.COLOR_RGB2BGR)

    def _match_single(
        self,
        screen_cv: "np.ndarray",
        anchor_cv: "np.ndarray",
        threshold_override: Optional[float] = None,
    ) -> Optional[MatchResult]:
        """Single-scale template matching; returns the match center or None."""
        threshold = threshold_override if threshold_override is not None else self.threshold

        if anchor_cv.shape[0] >= screen_cv.shape[0] or anchor_cv.shape[1] >= screen_cv.shape[1]:
            logger.debug("[TemplateMatcher] Ancre plus grande que le screen")
            return None

        s_img, a_img = self._maybe_grayscale(screen_cv, anchor_cv)

        t0 = time.time()
        result_tm = cv2.matchTemplate(s_img, a_img, self._cv2_method)
        _, max_val, _, max_loc = cv2.minMaxLoc(result_tm)
        elapsed_ms = (time.time() - t0) * 1000

        logger.debug(
            "[TemplateMatcher] score=%.3f pos=%s (%.0fms)",
            max_val, max_loc, elapsed_ms,
        )

        if max_val < threshold:
            return None

        # max_loc is the top-left corner of the match: shift to its center.
        return MatchResult(
            x=max_loc[0] + anchor_cv.shape[1] // 2,
            y=max_loc[1] + anchor_cv.shape[0] // 2,
            score=float(max_val),
            method='template',
            time_ms=elapsed_ms,
            scale=1.0,
        )

    def _match_multiscale(
        self,
        screen_cv: "np.ndarray",
        anchor_cv: "np.ndarray",
        threshold_override: Optional[float] = None,
    ) -> Optional[MatchResult]:
        """Multi-scale template matching.

        Resizes the anchor to every scale in ``self.scales``, skipping
        scales that give an anchor smaller than 8 px on a side or not
        strictly smaller than the screen, and keeps the best score overall.
        """
        threshold = threshold_override if threshold_override is not None else self.threshold

        best_score = -1.0
        best_loc = None
        best_scale = 1.0
        best_anchor_shape = anchor_cv.shape

        t0 = time.time()
        screen_h, screen_w = screen_cv.shape[:2]
        # Loop-invariant: convert the screen once, not once per scale.
        s_img = self._gray_if_enabled(screen_cv)

        for scale in self.scales:
            if scale == 1.0:
                scaled = anchor_cv
            else:
                new_w = int(anchor_cv.shape[1] * scale)
                new_h = int(anchor_cv.shape[0] * scale)
                if new_w < 8 or new_h < 8:
                    continue
                if new_h >= screen_h or new_w >= screen_w:
                    continue
                # INTER_AREA is meant for shrinking; use bilinear to enlarge.
                interpolation = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
                scaled = cv2.resize(anchor_cv, (new_w, new_h), interpolation=interpolation)

            if scaled.shape[0] >= screen_h or scaled.shape[1] >= screen_w:
                continue

            a_img = self._gray_if_enabled(scaled)
            result_tm = cv2.matchTemplate(s_img, a_img, self._cv2_method)
            _, max_val, _, max_loc = cv2.minMaxLoc(result_tm)

            if max_val > best_score:
                best_score = max_val
                best_loc = max_loc
                best_scale = scale
                best_anchor_shape = scaled.shape

        elapsed_ms = (time.time() - t0) * 1000

        logger.debug(
            "[TemplateMatcher/multiscale] best_score=%.3f scale=%.2f (%.0fms)",
            best_score, best_scale, elapsed_ms,
        )

        if best_loc is None or best_score < threshold:
            return None

        return MatchResult(
            x=best_loc[0] + best_anchor_shape[1] // 2,
            y=best_loc[1] + best_anchor_shape[0] // 2,
            score=float(best_score),
            method='template_multiscale',
            time_ms=elapsed_ms,
            scale=best_scale,
        )

    def _gray_if_enabled(self, img: "np.ndarray") -> "np.ndarray":
        """Return *img* in grayscale when ``self.grayscale`` is set and it has a channel axis."""
        if self.grayscale and len(img.shape) == 3:
            return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        return img

    def _maybe_grayscale(
        self,
        screen: "np.ndarray",
        anchor: "np.ndarray",
    ) -> Tuple["np.ndarray", "np.ndarray"]:
        """Convert both images to grayscale if ``self.grayscale`` is True."""
        return self._gray_if_enabled(screen), self._gray_if_enabled(anchor)

    @staticmethod
    def _decode_anchor(
        anchor_b64: Optional[str],
        anchor_pil: Optional["Image.Image"],
    ) -> Optional["Image.Image"]:
        """Return ``anchor_pil`` if given, else decode ``anchor_b64`` (raw base64 or data URL).

        Returns None, with a debug log, when neither is provided or the
        data cannot be decoded.
        """
        if anchor_pil is not None:
            return anchor_pil

        if not anchor_b64:
            logger.debug("[TemplateMatcher] Ni anchor_b64 ni anchor_pil fourni")
            return None

        try:
            raw = anchor_b64.split(',')[1] if ',' in anchor_b64 else anchor_b64
            img = Image.open(io.BytesIO(base64.b64decode(raw)))
            # Image.open is lazy: force the decode here so truncated or
            # corrupt data is reported by this handler, not by the caller.
            img.load()
            return img
        except Exception as e:
            logger.debug("[TemplateMatcher] Erreur décodage ancre: %s", e)
            return None

    @staticmethod
    def _capture_screen() -> Optional["Image.Image"]:
        """Capture the full screen via mss (monitor 0 = all screens); None on failure."""
        if not _MSS:
            logger.debug("[TemplateMatcher] mss non disponible")
            return None

        try:
            with mss_lib.mss() as sct:
                grab = sct.grab(sct.monitors[0])
                return Image.frombytes('RGB', grab.size, grab.bgra, 'raw', 'BGRX')
        except Exception as e:
            logger.debug("[TemplateMatcher] Erreur capture écran: %s", e)
            return None
|
||||
103
core/grounding/think_arbiter.py
Normal file
103
core/grounding/think_arbiter.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
core/grounding/think_arbiter.py — Layer THINK : VLM arbitre (InfiGUI via subprocess)
|
||||
|
||||
Appelé UNIQUEMENT quand le SmartMatcher n'a pas assez confiance.
|
||||
Utilise le subprocess worker InfiGUI (pas de serveur HTTP).
|
||||
|
||||
Utilisation :
|
||||
from core.grounding.think_arbiter import ThinkArbiter
|
||||
|
||||
arbiter = ThinkArbiter()
|
||||
result = arbiter.arbitrate(target, candidates, screenshot)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from core.grounding.fast_types import LocateResult, MatchCandidate
|
||||
from core.grounding.target import GroundingTarget
|
||||
|
||||
|
||||
class ThinkArbiter:
    """VLM arbiter: delegates the final decision to the UITarsGrounder singleton."""

    def __init__(self):
        # Resolved lazily on first use so that constructing the arbiter
        # does not import the grounder module.
        self._grounder = None

    def _get_grounder(self):
        """Return the cached grounder, fetching the singleton on first call."""
        grounder = self._grounder
        if grounder is None:
            from core.grounding.ui_tars_grounder import UITarsGrounder

            grounder = UITarsGrounder.get_instance()
            self._grounder = grounder
        return grounder

    @property
    def available(self) -> bool:
        """Always True — the worker is launched on demand."""
        return True

    @staticmethod
    def _decode_template(template_b64):
        """Decode a base64 (optionally ``<header>,<payload>``) crop into an RGB PIL image.

        Returns ``None`` for a falsy input or when decoding fails; a failure
        is reported on stdout and never raised.
        """
        if not template_b64:
            return None
        try:
            import base64
            import io
            from PIL import Image

            payload = template_b64
            if ',' in payload:
                payload = payload.split(',', 1)[1]
            image_bytes = base64.b64decode(payload)
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as ex:
            print(f"⚠️ [THINK] Décodage anchor échoué: {ex}")
            return None

    def arbitrate(
        self,
        target: GroundingTarget,
        candidates: List[MatchCandidate],
        screenshot_pil: Optional[Any] = None,
    ) -> Optional[LocateResult]:
        """Ask the VLM to settle the match.

        When ``target.template_b64`` is set, the decoded crop is forwarded
        to the grounder as ``anchor_pil`` ("fused" mode) alongside the text
        and description.

        Args:
            target: Grounding target; ``text``, ``description`` and
                ``template_b64`` are read from it.
            candidates: Candidates from earlier tiers; only their count is
                recorded in the result.
            screenshot_pil: Screenshot forwarded to the grounder as-is.

        Returns:
            A ``LocateResult`` with ``tier="think"``, or ``None`` when the
            grounder finds nothing or any exception is raised while
            grounding (the error is printed, not propagated).
        """
        started = time.time()

        # Fused mode: decode the anchor crop when one is available.
        anchor_pil = self._decode_template(target.template_b64)

        try:
            result = self._get_grounder().ground(
                target_text=target.text or "",
                target_description=target.description or "",
                screen_pil=screenshot_pil,
                anchor_pil=anchor_pil,
            )
            dt = (time.time() - started) * 1000

            if result is None:
                label = target.text or "<crop>"
                print(f"🤔 [THINK] VLM n'a pas trouvé '{label}' ({dt:.0f}ms)")
                return None

            if anchor_pil is not None:
                method = "think_vlm_fused"
            else:
                method = "think_vlm"

            locate = LocateResult(
                x=result.x,
                y=result.y,
                confidence=result.confidence,
                method=method,
                time_ms=dt,
                tier="think",
                candidates_count=len(candidates),
            )
            print(f"🤔 [THINK/{method}] ({result.x}, {result.y}) conf={result.confidence:.2f} ({dt:.0f}ms)")
            return locate

        except Exception as ex:
            dt = (time.time() - started) * 1000
            print(f"⚠️ [THINK] Erreur: {ex} ({dt:.0f}ms)")
            return None
|
||||
174
core/grounding/title_verifier.py
Normal file
174
core/grounding/title_verifier.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
core/grounding/title_verifier.py — Vérification post-action par titre de fenêtre
|
||||
|
||||
Après chaque action (clic, double-clic), vérifie que la fenêtre active
|
||||
a changé de manière attendue en lisant le titre via OCR sur un crop
|
||||
de 45px en haut de l'écran.
|
||||
|
||||
Léger (~120ms), non-bloquant (échec = warning + retry, pas stop).
|
||||
|
||||
Utilisation :
|
||||
from core.grounding.title_verifier import TitleVerifier
|
||||
|
||||
verifier = TitleVerifier()
|
||||
title = verifier.read_title(screenshot_pil)
|
||||
changed = verifier.has_title_changed(title_before, title_after)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class TitleVerifier:
    """Check the active window's title by running OCR on a top-of-screen crop."""

    # Height (in pixels) of the crop taken for the Windows title bar.
    TITLE_BAR_HEIGHT = 45

    # EasyOCR reader shared by every TitleVerifier instance.
    _easyocr_reader = None

    # Relative directory added to sys.path for the docTR fallback import.
    _DOCTR_BACKEND_DIR = 'visual_workflow_builder/backend'

    def __init__(self):
        self._ocr_fn = None  # Resolved lazily by _get_ocr()

    def read_title(self, screenshot_pil) -> str:
        """Read the active window title by OCR on the upper crop.

        Args:
            screenshot_pil: PIL image of the full screenshot.

        Returns:
            The stripped title text. Empty when no OCR backend is
            available, when OCR yields nothing, or when any error occurs
            (errors are printed, never raised).
        """
        t0 = time.time()

        try:
            w, h = screenshot_pil.size
            # Title-bar crop: full width, TITLE_BAR_HEIGHT px from the top
            # (clamped to the image height).
            title_crop = screenshot_pil.crop((0, 0, w, min(self.TITLE_BAR_HEIGHT, h)))

            ocr_fn = self._get_ocr()
            if ocr_fn is None:
                return ""

            text = ocr_fn(title_crop)
            dt = (time.time() - t0) * 1000

            title = text.strip() if text else ""
            if title:
                print(f"📋 [TitleVerify] Titre lu: '{title[:60]}' ({dt:.0f}ms)")

            return title

        except Exception as e:
            print(f"⚠️ [TitleVerify] Erreur lecture titre: {e}")
            return ""

    def has_title_changed(self, title_before: str, title_after: str) -> bool:
        """Tell whether the title changed significantly.

        Two empty titles count as unchanged; exactly one empty title counts
        as a change. Otherwise the titles are compared case-insensitively
        and are considered changed below 85% similarity.
        """
        if not title_before and not title_after:
            return False
        if not title_before or not title_after:
            return True  # Exactly one side is empty => change

        # Fuzzy comparison — titles may carry minor variations.
        ratio = SequenceMatcher(None, title_before.lower(), title_after.lower()).ratio()
        return ratio < 0.85

    def verify_action(
        self,
        screenshot_before,
        screenshot_after,
        action_type: str,
    ) -> dict:
        """Check that an action produced the expected effect on the title.

        Args:
            screenshot_before: PIL screenshot taken before the action.
            screenshot_after: PIL screenshot taken after the action.
            action_type: Action identifier. ``'type_text'``,
                ``'keyboard_shortcut'``, ``'wait_for_anchor'`` and
                ``'hover'`` skip the check; ``'double_click_anchor'``
                requires a title change; any other value treats a change
                as optional.

        Returns:
            Dict with ``success``, ``title_before``, ``title_after``,
            ``changed`` and ``reason``.
        """
        # Actions that are not expected to change the title: no OCR at all.
        if action_type in ('type_text', 'keyboard_shortcut', 'wait_for_anchor', 'hover'):
            return {
                'success': True,
                'title_before': '',
                'title_after': '',
                'changed': False,
                'reason': f"Action '{action_type}' — vérification titre non requise",
            }

        title_before = self.read_title(screenshot_before)
        title_after = self.read_title(screenshot_after)
        changed = self.has_title_changed(title_before, title_after)

        # A double-click (opening a file/folder) MUST change the title, but
        # the verdict is only trusted when both titles are meaningful
        # (> 3 chars): OCR on a 45px crop can return noise ('o', 'a', ...).
        if action_type in ('double_click_anchor',) and not changed:
            if len(title_before) > 3 and len(title_after) > 3:
                return {
                    'success': False,
                    'title_before': title_before,
                    'title_after': title_after,
                    'changed': False,
                    'reason': f"Double-clic sans changement de titre ('{title_after[:40]}')",
                }
            # Titles too short = OCR noise, no conclusion possible.
            return {
                'success': True,
                'title_before': title_before,
                'title_after': title_after,
                'changed': False,
                'reason': f"Titre trop court pour vérifier ('{title_after}')",
            }

        # For every other action a title change is optional.
        return {
            'success': True,
            'title_before': title_before,
            'title_after': title_after,
            'changed': changed,
            'reason': 'Titre changé' if changed else 'Titre identique (acceptable)',
        }

    def _get_ocr(self):
        """Lazily resolve the OCR function (EasyOCR first, docTR fallback).

        Returns:
            A callable taking a PIL image and returning text, or ``None``
            when neither backend can be imported. A successful resolution
            is cached on the instance; a failed one is retried next call.
        """
        if self._ocr_fn is not None:
            return self._ocr_fn

        # EasyOCR (preferred)
        try:
            import easyocr
            import numpy as np

            if TitleVerifier._easyocr_reader is None:
                try:
                    TitleVerifier._easyocr_reader = easyocr.Reader(
                        ['fr', 'en'], gpu=True, verbose=False
                    )
                except Exception as e:
                    # Same GPU -> CPU fallback as core/llm/ocr_extractor.py.
                    print(f"⚠️ [TitleVerify] EasyOCR GPU indisponible ({e}), fallback CPU")
                    TitleVerifier._easyocr_reader = easyocr.Reader(
                        ['fr', 'en'], gpu=False, verbose=False
                    )

            def _easyocr_extract_text(img):
                results = TitleVerifier._easyocr_reader.readtext(np.array(img))
                return ' '.join(r[1] for r in results if r[1].strip())

            self._ocr_fn = _easyocr_extract_text
            return self._ocr_fn
        except ImportError:
            pass

        # docTR fallback
        try:
            import sys

            # Insert the backend directory only once: this method is retried
            # on every read while no backend resolves, and an unconditional
            # insert would grow sys.path on each call.
            if self._DOCTR_BACKEND_DIR not in sys.path:
                sys.path.insert(0, self._DOCTR_BACKEND_DIR)
            from services.ocr_service import ocr_extract_text

            self._ocr_fn = ocr_extract_text
            return self._ocr_fn
        except ImportError:
            return None
|
||||
161
core/grounding/ui_tars_grounder.py
Normal file
161
core/grounding/ui_tars_grounder.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
core/grounding/ui_tars_grounder.py — Grounding via script one-shot InfiGUI
|
||||
|
||||
Chaque appel lance un subprocess Python qui charge le modèle, infère, et quitte.
|
||||
Lent (~15s) mais fiable — pas de crash CUDA en process persistant.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from core.grounding.target import GroundingResult
|
||||
|
||||
_instance: Optional[UITarsGrounder] = None
|
||||
_instance_lock = threading.Lock()
|
||||
|
||||
|
||||
class UITarsGrounder:
    """Grounding through a one-shot InfiGUI worker subprocess."""

    def __init__(self):
        # Serialises ground() calls: they share fixed temp-file paths.
        self._lock = threading.Lock()
        # Two directory levels above this file; used as the worker's cwd so
        # that ``-m core.grounding.infigui_worker`` resolves.
        self._project_root = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "..")
        )

    @classmethod
    def get_instance(cls) -> UITarsGrounder:
        """Return the module-wide instance, creating it on first call.

        Creation is guarded by ``_instance_lock`` and re-checked inside the
        lock.
        """
        global _instance
        if _instance is None:
            with _instance_lock:
                if _instance is None:
                    _instance = cls()
        return _instance

    @property
    def available(self) -> bool:
        """Always True — the script is launched on demand."""
        return True

    @staticmethod
    def _parse_worker_output(stdout):
        """Extract the result object from the worker's stdout.

        Each non-empty line is tried as JSON; the LAST line that decodes to
        a JSON object containing an ``"x"`` key wins. Lines that are not
        JSON, or that decode to a non-object (number, string, list,
        ``null``), are ignored. Previously a bare number or ``null`` line
        raised TypeError and aborted the whole call, and a JSON string line
        containing "x" was mistaken for the result.

        Returns:
            The result dict, or ``None`` when no line qualifies.
        """
        data = None
        for line in (stdout or "").strip().split("\n"):
            line = line.strip()
            if not line:
                continue
            try:
                parsed = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict) and "x" in parsed:
                data = parsed
        return data

    def ground(
        self,
        target_text: str = "",
        target_description: str = "",
        screen_pil=None,
        anchor_pil=None,
    ) -> Optional[GroundingResult]:
        """Locate a UI element through a one-shot InfiGUI script.

        Args:
            target_text: Textual name of the target (may be empty when
                ``anchor_pil`` is given).
            target_description: Free-form semantic description.
            screen_pil: Full screenshot (PIL.Image).
            anchor_pil: Visual crop of the previously captured anchor
                (PIL.Image). When given, its path is sent to the worker
                ("fused" mode per the original note: image 1 = crop,
                image 2 = screen).

        Returns:
            A ``GroundingResult``, or ``None`` on a non-zero worker exit,
            missing/unusable output, a 60s timeout, or any other error
            (all failures are printed, never raised).
        """
        t0 = time.time()

        try:
            with self._lock:
                # Save the main image.
                # NOTE(review): when screen_pil is None the path is still
                # sent to the worker, so whatever an earlier call left at
                # this location may be used — confirm this is intended.
                image_path = "/tmp/infigui_screen.png"
                if screen_pil is not None:
                    screen_pil.save(image_path)

                # Save the anchor image (fused mode).
                anchor_image_path = ""
                if anchor_pil is not None:
                    anchor_image_path = "/tmp/infigui_anchor.png"
                    anchor_pil.save(anchor_image_path)

                # Build the JSON request, sent on the worker's stdin.
                req = json.dumps({
                    "target": target_text,
                    "description": target_description,
                    "image_path": image_path,
                    "anchor_image_path": anchor_image_path,
                })

                mode_str = "fused" if anchor_pil is not None else "text"
                label_short = target_text[:30] if target_text else "<crop only>"
                print(f"🎯 [InfiGUI] Lancement one-shot [{mode_str}]: '{label_short}'")

                # Launch the one-shot script.
                # IMPORTANT (original author's note): from a systemd service
                # whose parent already loaded CUDA, the subprocess inherits
                # a broken GPU state ("No CUDA GPUs available"). Mitigation:
                # start_new_session=True plus forcing CUDA_VISIBLE_DEVICES=0
                # explicitly to bypass what the parent passed down.
                _child_env = {**os.environ}
                _child_env["PYTHONDONTWRITEBYTECODE"] = "1"
                _child_env["CUDA_VISIBLE_DEVICES"] = "0"
                _child_env["NVIDIA_VISIBLE_DEVICES"] = "all"
                # Drop the variable that could point at the parent's state.
                _child_env.pop("PYTORCH_NVML_BASED_CUDA_CHECK", None)

                result = subprocess.run(
                    [sys.executable, "-m", "core.grounding.infigui_worker"],
                    input=req + "\n",
                    capture_output=True,
                    text=True,
                    timeout=60,
                    cwd=self._project_root,
                    env=_child_env,
                    start_new_session=True,  # new session, isolated from the parent
                    close_fds=True,
                )

                if result.returncode != 0:
                    stderr_lines = (result.stderr or '').strip().split('\n')
                    # Show the last non-blank lines of stderr (at most 5).
                    last_err = [l for l in stderr_lines[-5:] if l.strip()]
                    print(f"⚠️ [InfiGUI] Script échoué (code {result.returncode})")
                    for l in last_err:
                        print(f" ❌ {l}")
                    return None

                data = self._parse_worker_output(result.stdout)
                if data is None:
                    print(f"⚠️ [InfiGUI] Pas de réponse JSON dans la sortie")
                    return None

                dt = (time.time() - t0) * 1000

                if data.get("x") is not None:
                    method_name = "infigui_fused" if anchor_pil is not None else "infigui"
                    # Resolve the confidence once so the log line and the
                    # returned result agree; a missing or null value falls
                    # back to 0.90.
                    confidence = data.get("confidence")
                    if confidence is None:
                        confidence = 0.90
                    print(f"🎯 [InfiGUI/{method_name}] ({data['x']}, {data['y']}) "
                          f"conf={confidence:.2f} ({dt:.0f}ms)")
                    return GroundingResult(
                        x=data["x"], y=data["y"],
                        method=method_name,
                        confidence=confidence,
                        time_ms=dt,
                    )
                else:
                    print(f"⚠️ [InfiGUI] Pas trouvé ({dt:.0f}ms)")
                    return None

        except subprocess.TimeoutExpired:
            print(f"⚠️ [InfiGUI] Timeout 60s")
            return None
        except Exception as e:
            print(f"⚠️ [InfiGUI] Erreur: {e}")
            return None
|
||||
0
core/knowledge/__init__.py
Normal file
0
core/knowledge/__init__.py
Normal file
523
core/knowledge/ui_patterns.py
Normal file
523
core/knowledge/ui_patterns.py
Normal file
@@ -0,0 +1,523 @@
|
||||
"""
|
||||
Base de connaissances des patterns d'interface utilisateur.
|
||||
|
||||
Donne à Léa des "réflexes natifs" : quand elle reconnaît un pattern UI
|
||||
connu (dialogue OK/Annuler, menu, barre d'outils), elle sait immédiatement
|
||||
quoi faire sans avoir besoin de l'apprendre par observation.
|
||||
|
||||
Sources :
|
||||
- GUI-R1 dataset (3K exemples annotés, ritzzai/GUI-R1)
|
||||
- Patterns Windows/Linux courants
|
||||
- Conventions UI universelles
|
||||
|
||||
Utilisation :
|
||||
from core.knowledge.ui_patterns import UIPatternLibrary
|
||||
lib = UIPatternLibrary()
|
||||
match = lib.find_pattern("Voulez-vous enregistrer ?")
|
||||
# → {'action': 'click', 'target': 'Enregistrer', 'zone': 'dialog_center', ...}
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class UIPattern:
    """A known user-interface pattern."""
    # Identifier of the pattern (e.g. "dialog_save").
    name: str
    # Grouping label (e.g. "dialog", "window", "menu", "form", "shortcut").
    category: str
    # Text fragments whose presence in the analysed text selects this pattern.
    triggers: List[str]
    # What to do when the pattern is recognised (e.g. "click", "hotkey").
    action: str
    # Label to act on for a click, or the key combination for a hotkey.
    target: str
    # Named screen area where the target usually sits (e.g. "dialog_center").
    typical_zone: str
    # Four floats; presumably a normalised [x1, y1, x2, y2] box — TODO confirm.
    typical_bbox: Optional[List[float]] = None
    # Operating system the pattern applies to; "any" means no restriction.
    os: str = "any"
    # Base confidence attached to the pattern.
    confidence: float = 0.9
    # Free-form extras (the library stores "alternatives" and "source" here).
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Native Windows patterns — baseline reflexes.
# Each entry is a plain dict; UIPatternLibrary._load_builtin turns it into a
# UIPattern and moves the optional "alternatives" list into the metadata.
BUILTIN_PATTERNS: List[Dict[str, Any]] = [
    # === CONFIRMATION DIALOGS ===
    {
        "name": "dialog_save",
        "category": "dialog",
        "triggers": [
            "voulez-vous enregistrer", "do you want to save",
            "save changes", "enregistrer les modifications",
            "enregistrer sous", "save as",
            "sauvegarder", "unsaved changes",
        ],
        "action": "click",
        "target": "Enregistrer",
        "alternatives": ["Save", "Oui", "Yes"],
        "typical_zone": "dialog_center",
        "typical_bbox": [0.35, 0.55, 0.50, 0.65],
        "os": "any",
    },
    {
        "name": "dialog_cancel",
        "category": "dialog",
        "triggers": [
            "annuler", "cancel", "abandonner", "discard",
        ],
        "action": "click",
        "target": "Annuler",
        "alternatives": ["Cancel", "Non", "No"],
        "typical_zone": "dialog_center",
        "typical_bbox": [0.50, 0.55, 0.65, 0.65],
        "os": "any",
    },
    {
        "name": "dialog_ok",
        "category": "dialog",
        "triggers": [
            "ok", "d'accord", "compris", "information",
            "erreur", "error", "warning", "avertissement",
        ],
        "action": "click",
        "target": "OK",
        "alternatives": ["Fermer", "Close", "Compris"],
        "typical_zone": "dialog_center",
        "typical_bbox": [0.45, 0.60, 0.55, 0.70],
        "os": "any",
    },
    {
        "name": "dialog_yes_no",
        "category": "dialog",
        "triggers": [
            "êtes-vous sûr", "are you sure", "confirmer",
            "confirm", "supprimer", "delete",
        ],
        "action": "click",
        "target": "Oui",
        "alternatives": ["Yes", "Confirmer", "Confirm"],
        "typical_zone": "dialog_center",
        "typical_bbox": [0.35, 0.60, 0.45, 0.68],
        "os": "any",
    },
    {
        "name": "dialog_overwrite",
        "category": "dialog",
        "triggers": [
            "voulez-vous remplacer", "voulez-vous écraser",
            "remplacer le fichier", "replace existing",
            "fichier existe déjà", "already exists",
            "overwrite", "écraser",
        ],
        "action": "click",
        "target": "Oui",
        "alternatives": ["Yes", "Remplacer", "Replace", "Confirmer"],
        "typical_zone": "dialog_center",
        "os": "any",
    },
    {
        "name": "dialog_dont_save",
        "category": "dialog",
        "triggers": [
            "ne pas enregistrer", "don't save",
            "ne pas sauvegarder", "quitter sans enregistrer",
            "discard changes",
        ],
        "action": "click",
        "target": "Ne pas enregistrer",
        "alternatives": ["Don't Save", "Ne pas sauvegarder", "Non"],
        "typical_zone": "dialog_center",
        "os": "any",
    },

    # === WINDOW NAVIGATION ===
    {
        "name": "window_close",
        "category": "window",
        "triggers": ["fermer la fenêtre", "close window"],
        "action": "click",
        "target": "X",
        "typical_zone": "titlebar",
        "typical_bbox": [0.96, 0.0, 1.0, 0.04],
        "os": "windows",
    },
    {
        "name": "window_minimize",
        "category": "window",
        "triggers": ["minimiser", "minimize"],
        "action": "click",
        "target": "_",
        "typical_zone": "titlebar",
        "typical_bbox": [0.90, 0.0, 0.94, 0.04],
        "os": "windows",
    },
    {
        "name": "window_maximize",
        "category": "window",
        "triggers": ["maximiser", "maximize", "agrandir"],
        "action": "click",
        "target": "□",
        "typical_zone": "titlebar",
        "typical_bbox": [0.94, 0.0, 0.96, 0.04],
        "os": "windows",
    },

    # === MENUS ===
    {
        "name": "menu_file",
        "category": "menu",
        "triggers": ["menu fichier", "menu file", "ouvrir fichier", "open file"],
        "action": "click",
        "target": "Fichier",
        "alternatives": ["File"],
        "typical_zone": "menu_toolbar",
        "typical_bbox": [0.0, 0.03, 0.06, 0.06],
        "os": "any",
    },
    {
        "name": "menu_edit",
        "category": "menu",
        "triggers": ["édition", "edit", "modifier"],
        "action": "click",
        "target": "Édition",
        "alternatives": ["Edit"],
        "typical_zone": "menu_toolbar",
        "typical_bbox": [0.06, 0.03, 0.12, 0.06],
        "os": "any",
    },

    # === FORMS ===
    {
        "name": "form_submit",
        "category": "form",
        "triggers": [
            "valider", "submit", "envoyer", "send",
            "connexion", "login", "se connecter", "sign in",
        ],
        "action": "click",
        "target": "Valider",
        "alternatives": ["Submit", "Envoyer", "Connexion", "Login", "OK"],
        "typical_zone": "content",
        "typical_bbox": [0.35, 0.70, 0.65, 0.80],
        "os": "any",
    },
    {
        "name": "form_search",
        "category": "form",
        "triggers": ["rechercher", "search", "chercher", "find"],
        "action": "click",
        "target": "Rechercher",
        "alternatives": ["Search", "🔍", "Go"],
        "typical_zone": "menu_toolbar",
        "typical_bbox": [0.30, 0.03, 0.70, 0.06],
        "os": "any",
    },

    # === WEB NAVIGATION ===
    {
        "name": "cookie_accept",
        "category": "popup",
        "triggers": [
            "accepter les cookies", "accept cookies",
            "utilise des cookies", "uses cookies",
            "j'accepte", "accept all", "tout accepter",
            "consent", "consentement",
        ],
        "action": "click",
        "target": "Accepter",
        "alternatives": ["Accept", "Accept All", "Tout accepter", "J'accepte"],
        "typical_zone": "content",
        "typical_bbox": [0.30, 0.80, 0.70, 0.90],
        "os": "any",
    },

    # === UNIVERSAL SHORTCUTS ===
    {
        "name": "shortcut_save",
        "category": "shortcut",
        "triggers": ["sauvegarder", "enregistrer", "save"],
        "action": "hotkey",
        "target": "ctrl+s",
        "typical_zone": "keyboard",
        "os": "any",
    },
    {
        "name": "shortcut_undo",
        "category": "shortcut",
        "triggers": ["annuler action", "undo", "défaire"],
        "action": "hotkey",
        "target": "ctrl+z",
        "typical_zone": "keyboard",
        "os": "any",
    },
    {
        "name": "shortcut_copy",
        "category": "shortcut",
        "triggers": ["copier", "copy"],
        "action": "hotkey",
        "target": "ctrl+c",
        "typical_zone": "keyboard",
        "os": "any",
    },
    {
        "name": "shortcut_paste",
        "category": "shortcut",
        "triggers": ["coller", "paste"],
        "action": "hotkey",
        "target": "ctrl+v",
        "typical_zone": "keyboard",
        "os": "any",
    },
]
|
||||
|
||||
|
||||
class UIPatternLibrary:
|
||||
"""Bibliothèque de patterns UI connus.
|
||||
|
||||
Fournit des "réflexes natifs" à Léa : quand un pattern
|
||||
est reconnu dans le texte OCR ou le contexte visuel,
|
||||
elle sait immédiatement quoi faire.
|
||||
"""
|
||||
|
||||
# Chemins par défaut des fichiers de patterns additionnels
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
_GUI_R1_PATTERNS_PATH = _PROJECT_ROOT / "data" / "gui_r1_ui_patterns.json"
|
||||
_LEARNED_PATTERNS_PATH = _PROJECT_ROOT / "data" / "learned_patterns.json"
|
||||
|
||||
def __init__(self, extra_patterns_path: Optional[str] = None):
|
||||
self._patterns: List[UIPattern] = []
|
||||
self._load_builtin()
|
||||
|
||||
# Charger les patterns extraits de GUI-R1 (statiques, générés une fois)
|
||||
self._load_from_file(str(self._GUI_R1_PATTERNS_PATH))
|
||||
|
||||
# Charger les patterns appris par observation Shadow (dynamiques)
|
||||
self._load_from_file(str(self._LEARNED_PATTERNS_PATH))
|
||||
|
||||
# Fichier custom fourni explicitement
|
||||
if extra_patterns_path:
|
||||
self._load_from_file(extra_patterns_path)
|
||||
|
||||
logger.info(f"UIPatternLibrary: {len(self._patterns)} patterns chargés")
|
||||
|
||||
def _load_builtin(self):
|
||||
for p in BUILTIN_PATTERNS:
|
||||
self._patterns.append(UIPattern(
|
||||
name=p["name"],
|
||||
category=p["category"],
|
||||
triggers=p["triggers"],
|
||||
action=p["action"],
|
||||
target=p["target"],
|
||||
typical_zone=p.get("typical_zone", "content"),
|
||||
typical_bbox=p.get("typical_bbox"),
|
||||
os=p.get("os", "any"),
|
||||
metadata={
|
||||
"alternatives": p.get("alternatives", []),
|
||||
"source": "builtin",
|
||||
},
|
||||
))
|
||||
|
||||
def _load_from_file(self, path: str):
|
||||
filepath = Path(path)
|
||||
if not filepath.exists():
|
||||
logger.debug(f"Fichier patterns non trouvé (OK si premier lancement): {path}")
|
||||
return
|
||||
try:
|
||||
with open(filepath) as f:
|
||||
data = json.load(f)
|
||||
for p in data.get("patterns", []):
|
||||
# Construire metadata en incluant source/learned_at/gui_r1_id si présents
|
||||
meta = dict(p.get("metadata", {}))
|
||||
if "source" in p:
|
||||
meta["source"] = p["source"]
|
||||
if "learned_at" in p:
|
||||
meta["learned_at"] = p["learned_at"]
|
||||
if "gui_r1_id" in p:
|
||||
meta["gui_r1_id"] = p["gui_r1_id"]
|
||||
self._patterns.append(UIPattern(
|
||||
name=p["name"],
|
||||
category=p.get("category", "custom"),
|
||||
triggers=p.get("triggers", []),
|
||||
action=p.get("action", "click"),
|
||||
target=p.get("target", ""),
|
||||
typical_zone=p.get("typical_zone", "content"),
|
||||
typical_bbox=p.get("typical_bbox"),
|
||||
os=p.get("os", "any"),
|
||||
confidence=p.get("confidence", 0.9),
|
||||
metadata=meta,
|
||||
))
|
||||
logger.info(f"Chargé {len(data.get('patterns', []))} patterns depuis {path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur chargement patterns: {e}")
|
||||
|
||||
def find_pattern(
|
||||
self,
|
||||
text: str,
|
||||
os_filter: Optional[str] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Cherche un pattern UI dans du texte (OCR, titre fenêtre, etc.).
|
||||
|
||||
Args:
|
||||
text: Texte à analyser (peut contenir du bruit OCR)
|
||||
os_filter: Filtrer par OS ("windows", "linux", None=tous)
|
||||
|
||||
Returns:
|
||||
Dict avec action, target, confidence, etc. ou None
|
||||
"""
|
||||
text_lower = text.lower()
|
||||
best_match = None
|
||||
best_score = 0
|
||||
|
||||
for pattern in self._patterns:
|
||||
if os_filter and pattern.os not in ("any", os_filter):
|
||||
continue
|
||||
|
||||
score = 0
|
||||
matched_trigger = None
|
||||
for trigger in pattern.triggers:
|
||||
if len(trigger) <= 3:
|
||||
import re
|
||||
if re.search(r'\b' + re.escape(trigger) + r'\b', text_lower):
|
||||
trigger_score = len(trigger) / max(len(text_lower), 1)
|
||||
if trigger_score > score:
|
||||
score = trigger_score
|
||||
matched_trigger = trigger
|
||||
elif trigger in text_lower:
|
||||
trigger_score = len(trigger) / max(len(text_lower), 1)
|
||||
if trigger_score > score:
|
||||
score = trigger_score
|
||||
matched_trigger = trigger
|
||||
|
||||
if score > best_score and matched_trigger is not None:
|
||||
best_score = score
|
||||
best_match = {
|
||||
"pattern": pattern.name,
|
||||
"category": pattern.category,
|
||||
"action": pattern.action,
|
||||
"target": pattern.target,
|
||||
"alternatives": pattern.metadata.get("alternatives", []),
|
||||
"typical_zone": pattern.typical_zone,
|
||||
"typical_bbox": pattern.typical_bbox,
|
||||
"confidence": min(pattern.confidence * (1 + score), 1.0),
|
||||
"matched_trigger": matched_trigger,
|
||||
"os": pattern.os,
|
||||
}
|
||||
|
||||
return best_match
|
||||
|
||||
def find_by_category(self, category: str) -> List[Dict[str, Any]]:
|
||||
"""Retourne tous les patterns d'une catégorie."""
|
||||
return [
|
||||
{
|
||||
"name": p.name,
|
||||
"action": p.action,
|
||||
"target": p.target,
|
||||
"triggers": p.triggers,
|
||||
"typical_zone": p.typical_zone,
|
||||
}
|
||||
for p in self._patterns
|
||||
if p.category == category
|
||||
]
|
||||
|
||||
def get_dialog_handler(self, dialog_text: str) -> Optional[Dict[str, Any]]:
|
||||
"""Raccourci : cherche un pattern de dialogue."""
|
||||
match = self.find_pattern(dialog_text)
|
||||
if match and match["category"] == "dialog":
|
||||
return match
|
||||
return self.find_pattern(dialog_text)
|
||||
|
||||
def add_pattern(self, pattern_dict: Dict[str, Any]):
|
||||
"""Ajoute un pattern dynamiquement (ex: appris par observation)."""
|
||||
self._patterns.append(UIPattern(
|
||||
name=pattern_dict["name"],
|
||||
category=pattern_dict.get("category", "learned"),
|
||||
triggers=pattern_dict.get("triggers", []),
|
||||
action=pattern_dict.get("action", "click"),
|
||||
target=pattern_dict.get("target", ""),
|
||||
typical_zone=pattern_dict.get("typical_zone", "content"),
|
||||
typical_bbox=pattern_dict.get("typical_bbox"),
|
||||
os=pattern_dict.get("os", "any"),
|
||||
confidence=pattern_dict.get("confidence", 0.7),
|
||||
metadata={"source": "learned"},
|
||||
))
|
||||
|
||||
def save_to_file(self, path: str):
|
||||
"""Sauvegarde tous les patterns (builtin + appris) dans un fichier."""
|
||||
data = {
|
||||
"patterns": [
|
||||
{
|
||||
"name": p.name,
|
||||
"category": p.category,
|
||||
"triggers": p.triggers,
|
||||
"action": p.action,
|
||||
"target": p.target,
|
||||
"typical_zone": p.typical_zone,
|
||||
"typical_bbox": p.typical_bbox,
|
||||
"os": p.os,
|
||||
"confidence": p.confidence,
|
||||
"metadata": p.metadata,
|
||||
}
|
||||
for p in self._patterns
|
||||
]
|
||||
}
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Sauvegardé {len(self._patterns)} patterns dans {path}")
|
||||
|
||||
def save_learned_pattern(self, pattern_dict: Dict[str, Any]):
|
||||
"""Persiste un pattern appris par observation Shadow dans learned_patterns.json.
|
||||
|
||||
Le pattern est ajouté en mémoire ET sauvegardé sur disque.
|
||||
Le fichier est créé s'il n'existe pas, ou les patterns existants sont préservés.
|
||||
"""
|
||||
from datetime import datetime as dt
|
||||
|
||||
# Charger le fichier existant ou créer la structure
|
||||
filepath = self._LEARNED_PATTERNS_PATH
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
existing: Dict[str, Any] = {"patterns": []}
|
||||
if filepath.exists():
|
||||
try:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
existing = json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
logger.warning(f"Fichier {filepath} corrompu, recréation")
|
||||
|
||||
# Vérifier qu'on ne duplique pas (même trigger + même target)
|
||||
new_triggers = set(t.lower() for t in pattern_dict.get("triggers", []))
|
||||
new_target = pattern_dict.get("target", "").lower()
|
||||
for existing_p in existing.get("patterns", []):
|
||||
existing_triggers = set(t.lower() for t in existing_p.get("triggers", []))
|
||||
if existing_triggers == new_triggers and existing_p.get("target", "").lower() == new_target:
|
||||
logger.debug(f"Pattern déjà connu, skip: triggers={new_triggers}, target={new_target}")
|
||||
return
|
||||
|
||||
# Numéroter automatiquement et construire l'entrée complète
|
||||
count = len(existing.get("patterns", []))
|
||||
entry = {
|
||||
"name": pattern_dict.get("name", f"learned_dialog_{count + 1:03d}"),
|
||||
"category": pattern_dict.get("category", "dialog"),
|
||||
"triggers": pattern_dict.get("triggers", []),
|
||||
"action": pattern_dict.get("action", "click"),
|
||||
"target": pattern_dict.get("target", ""),
|
||||
"os": pattern_dict.get("os", "windows"),
|
||||
"source": "shadow_learning",
|
||||
"learned_at": dt.now().isoformat(timespec="seconds"),
|
||||
"confidence": pattern_dict.get("confidence", 0.8),
|
||||
}
|
||||
|
||||
# Ajouter en mémoire (avec le nom auto-généré)
|
||||
self.add_pattern(entry)
|
||||
existing.setdefault("patterns", []).append(entry)
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(existing, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Pattern appris sauvegardé: {entry['name']} → {entry['target']}")
|
||||
|
||||
@property
def stats(self) -> Dict[str, Any]:
    """Summarize the patterns currently held in memory.

    Returns:
        A dict with two keys:
            ``"total"``: number of entries in ``self._patterns``.
            ``"by_category"``: plain dict mapping each pattern's ``category``
            attribute to the number of patterns carrying it.
    """
    from collections import Counter

    # The return annotation is Dict[str, Any] (not Dict[str, int]) because
    # "by_category" holds a nested dict rather than an integer.
    per_category = Counter(p.category for p in self._patterns)
    return {"total": len(self._patterns), "by_category": dict(per_category)}
|
||||
15
core/llm/__init__.py
Normal file
15
core/llm/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Modules LLM (clients Ollama et décisionnels métier) + extracteur OCR."""
|
||||
|
||||
from .t2a_decision import (
|
||||
PROMPT_TEMPLATE,
|
||||
DEFAULT_MODEL,
|
||||
analyze_dpi,
|
||||
)
|
||||
from .ocr_extractor import extract_text_from_image
|
||||
|
||||
__all__ = [
|
||||
"PROMPT_TEMPLATE",
|
||||
"DEFAULT_MODEL",
|
||||
"analyze_dpi",
|
||||
"extract_text_from_image",
|
||||
]
|
||||
71
core/llm/ocr_extractor.py
Normal file
71
core/llm/ocr_extractor.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Extracteur OCR — texte depuis une image (screenshot d'écran).
|
||||
|
||||
Utilise EasyOCR fr+en. Singleton (chargement modèle ~3s au premier appel).
|
||||
|
||||
Conçu pour le pipeline streaming serveur (action `extract_text`) : récupère
|
||||
un screenshot fresh (dernier heartbeat ou capture forcée), applique l'OCR,
|
||||
retourne le texte concaténé pour analyse downstream (ex: t2a_decision).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_easyocr_reader = None
|
||||
|
||||
|
||||
def _get_reader():
    """Return the module-wide EasyOCR reader, building it on first use.

    The reader is created for French and English. GPU mode is attempted first;
    if that attempt raises, a warning is logged and a CPU reader is built
    instead. Later calls return the cached instance.
    """
    global _easyocr_reader

    # Already built: hand back the cached reader.
    if _easyocr_reader is not None:
        return _easyocr_reader

    import easyocr

    languages = ['fr', 'en']
    try:
        _easyocr_reader = easyocr.Reader(languages, gpu=True, verbose=False)
        logger.info("EasyOCR initialisé (fr+en, GPU)")
    except Exception as e:
        logger.warning("EasyOCR GPU indisponible (%s), fallback CPU", e)
        _easyocr_reader = easyocr.Reader(languages, gpu=False, verbose=False)
    return _easyocr_reader
|
||||
|
||||
|
||||
def extract_text_from_image(
    image_path: str,
    region: Optional[Tuple[int, int, int, int]] = None,
    paragraph: bool = True,
) -> str:
    """Extract the text of an image with EasyOCR.

    Args:
        image_path: path of the image file on disk.
        region: ``(x, y, w, h)`` box to crop before OCR. ``None`` means the
            whole image.
        paragraph: forwarded to ``reader.readtext``; ``True`` groups lines into
            paragraphs, ``False`` keeps separate blocks.

    Returns:
        The recognized fragments, stripped and joined with newlines. On any
        failure (missing file, decoding error, OCR error) a warning is logged
        and an empty string is returned.
    """
    path = Path(image_path)
    if not path.exists():
        logger.warning("extract_text: fichier introuvable %s", image_path)
        return ""

    try:
        from PIL import Image
        import numpy as np

        # Image.open() is lazy and keeps the file open; the ``with`` block
        # guarantees the handle is released even when cropping/decoding fails.
        with Image.open(path) as source:
            frame = source
            if region:
                x, y, w, h = region
                frame = source.crop((x, y, x + w, y + h))
            # Palette ("P"), bilevel ("1") and other exotic modes would reach
            # the OCR as raw indices/booleans; normalise those to RGB. Common
            # modes are passed through unchanged.
            if frame.mode not in ("L", "RGB", "RGBA"):
                frame = frame.convert("RGB")
            # Materialise the pixels while the file is still open.
            pixels = np.array(frame)

        reader = _get_reader()
        results = reader.readtext(pixels, detail=0, paragraph=paragraph)
        return "\n".join(str(r).strip() for r in results if r)
    except Exception as e:
        # Best-effort contract: never raise to the caller.
        logger.warning("extract_text échoué sur %s : %s", image_path, e)
        return ""
|
||||
168
core/llm/t2a_decision.py
Normal file
168
core/llm/t2a_decision.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Aide à la décision de facturation urgences T2A/PMSI via LLM local.
|
||||
|
||||
Décide si un passage aux urgences relève :
|
||||
- du FORFAIT_URGENCE (passage simple, retour à domicile)
|
||||
- de la REQUALIFICATION_HOSPITALISATION (séjour MCO, valorisation 1k-5k€+)
|
||||
|
||||
Le prompt impose une extraction littérale des faits du DPI (pas d'invention)
|
||||
et une modulation honnête de la confiance. Validé sur 15 DPI synthétiques :
|
||||
qwen2.5:7b atteint 100 % d'accuracy en ~5 s/cas avec 4,7 Go VRAM.
|
||||
|
||||
Voir docs/clients/ght_sud_95/ et demo/facturation_urgences/RESULTATS.md pour le
|
||||
bench comparatif des 11 LLMs évalués.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from typing import Any, Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434/api/generate")
|
||||
DEFAULT_MODEL = os.environ.get("T2A_MODEL", "qwen2.5:7b")
|
||||
DEFAULT_TIMEOUT = 60 # secondes
|
||||
|
||||
PROMPT_TEMPLATE = """Tu es médecin DIM (Département d'Information Médicale), expert en facturation T2A/PMSI aux urgences hospitalières en France.
|
||||
|
||||
Analyse le dossier patient ci-dessous pour déterminer si le passage relève :
|
||||
- FORFAIT_URGENCE : passage simple, retour à domicile, sans surveillance prolongée ni soins continus
|
||||
- REQUALIFICATION_HOSPITALISATION : séjour MCO requis selon les 3 critères PMSI/ATIH
|
||||
|
||||
LES 3 CRITÈRES UHCD (au moins 2 sur 3 validés ⇒ REQUALIFICATION) :
|
||||
1. Pathologie potentiellement évolutive (instabilité hémodynamique, terrain à risque, traitement nécessitant adaptation)
|
||||
2. Surveillance médicale et paramédicale prolongée (constantes itératives, observations IDE/médecin, durée > 6 h)
|
||||
3. Examens complémentaires ou actes thérapeutiques (biologie, imagerie, sutures, gestes techniques)
|
||||
|
||||
INSTRUCTIONS STRICTES :
|
||||
1. N'utilise QUE des éléments littéralement présents dans le dossier patient. N'invente AUCUN critère.
|
||||
2. Pour CHAQUE critère (1, 2, 3), tu DOIS produire un texte de preuve qui contient AU MOINS UNE CITATION LITTÉRALE du dossier entre guillemets français « ... ». Exemple : « FC à 110 bpm, TA 92/60 ».
|
||||
3. Si le critère est NON validé, ne renvoie JAMAIS un fallback creux : explique factuellement ce qui manque, en citant le dossier (ex: « Sortie à H+2 », « Aucun acte technique au compte-rendu »).
|
||||
4. Le texte de chaque preuve fait 2-3 phrases : (i) la citation littérale, (ii) l'analyse PMSI, (iii) la conclusion validé/non validé.
|
||||
5. Calcule la durée totale du passage en heures (admission → sortie/transfert) à partir des horaires du dossier.
|
||||
6. Module ta confiance honnêtement :
|
||||
- "elevee" uniquement si tous les indices convergent
|
||||
- "moyenne" si éléments ambivalents
|
||||
- "faible" si information manquante ou très atypique
|
||||
|
||||
Réponds STRICTEMENT en JSON valide, sans texte avant ni après :
|
||||
{{
|
||||
"duree_passage_heures": <nombre>,
|
||||
"elements_pour_hospitalisation": [<phrases littéralement extraites du dossier>],
|
||||
"elements_pour_forfait": [<phrases littéralement extraites du dossier>],
|
||||
"decision": "FORFAIT_URGENCE" | "REQUALIFICATION_HOSPITALISATION",
|
||||
"decision_court": "UHCD" | "Forfait Urgences",
|
||||
"preuve_critere1": "<2-3 phrases incluant AU MOINS UNE citation littérale entre « » (motif, symptôme, terrain à risque, traitement). Si non validé : factualise ce qui manque en citant le dossier.>",
|
||||
"critere1_valide": true | false,
|
||||
"preuve_critere2": "<2-3 phrases incluant AU MOINS UNE citation littérale entre « » (constantes, observations IDE, durée surveillance). Si non validé : factualise.>",
|
||||
"critere2_valide": true | false,
|
||||
"preuve_critere3": "<2-3 phrases incluant AU MOINS UNE citation littérale entre « » (actes/examens : biologie, imagerie, suture, etc.). Si non validé : factualise.>",
|
||||
"critere3_valide": true | false,
|
||||
"justification": "<2-3 phrases synthétiques s'appuyant explicitement sur les preuves ci-dessus, avec au moins une citation>",
|
||||
"confiance": "elevee" | "moyenne" | "faible"
|
||||
}}
|
||||
|
||||
DOSSIER PATIENT :
|
||||
{dpi}
|
||||
"""
|
||||
|
||||
|
||||
def analyze_dpi(
    dpi_text: str,
    model: str = DEFAULT_MODEL,
    timeout: int = DEFAULT_TIMEOUT,
    ollama_url: str = OLLAMA_URL,
) -> Dict[str, Any]:
    """Submit an emergency-room patient record to an Ollama LLM and return its JSON decision.

    Args:
        dpi_text: text of the patient record, substituted into ``PROMPT_TEMPLATE``.
        model: Ollama model name.
        timeout: HTTP timeout in seconds.
        ollama_url: Ollama ``generate`` endpoint.

    Returns:
        On success, the JSON object produced by the model (the prompt asks for
        keys such as ``decision``, ``justification``, ``confiance``, ...),
        augmented with ``_elapsed_s`` (latency, rounded to 0.1 s), ``_model``
        and ``_eval_count`` (copied from the Ollama response).

        On failure, one of:
            ``{"_error": str, "_elapsed_s": float, "_model": str}`` when the
            request fails or the HTTP body is not a JSON object.
            ``{"_parse_error": True, "_raw": str, "_elapsed_s": float,
            "_model": str}`` when no JSON object can be extracted from the
            model output (``_raw`` is truncated to 500 characters).
    """
    from http.client import HTTPException

    payload = {
        "model": model,
        "prompt": PROMPT_TEMPLATE.format(dpi=dpi_text),
        "stream": False,
        "format": "json",
        "keep_alive": "5m",
        "options": {
            "temperature": 0.1,
            "num_predict": 1500,
            "num_ctx": 16384,
        },
    }
    data = json.dumps(payload).encode("utf-8")
    req = urllib.request.Request(
        ollama_url,
        data=data,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

    t0 = time.time()
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            body = json.loads(resp.read().decode("utf-8"))
    except (OSError, HTTPException, ValueError) as e:
        # OSError covers URLError, TimeoutError and ConnectionError;
        # HTTPException covers truncated/invalid HTTP responses; ValueError
        # covers a body that is not valid UTF-8 or not valid JSON.
        elapsed = round(time.time() - t0, 1)
        logger.warning("analyze_dpi: Ollama indisponible (%s) après %.1fs", e, elapsed)
        return {"_error": str(e), "_elapsed_s": elapsed, "_model": model}

    elapsed = round(time.time() - t0, 1)

    if not isinstance(body, dict):
        # The endpoint answered with valid JSON that is not an object, so
        # there is nothing to read a response from.
        message = f"Réponse Ollama inattendue ({type(body).__name__})"
        logger.warning("analyze_dpi: %s", message)
        return {"_error": message, "_elapsed_s": elapsed, "_model": model}

    # ``or ""`` guards against explicit nulls in the response payload.
    raw_response = str(body.get("response") or "").strip()
    raw_thinking = str(body.get("thinking") or "").strip()

    candidates = [raw_response]
    if not raw_response and raw_thinking:
        # Empty answer: fall back to the last ``{...}`` span found in the
        # "thinking" field.
        last_close = raw_thinking.rfind("}")
        last_open = raw_thinking.rfind("{", 0, last_close)
        if last_open != -1 and last_close != -1:
            candidates.append(raw_thinking[last_open:last_close + 1])

    parsed = None
    for cand in candidates:
        cleaned = cand
        # Strip a surrounding Markdown code fence, if any.
        if cleaned.startswith("```"):
            cleaned = cleaned.split("\n", 1)[-1]
        if cleaned.endswith("```"):
            cleaned = cleaned.rsplit("```", 1)[0]
        cleaned = cleaned.strip()
        try:
            value = json.loads(cleaned)
        except json.JSONDecodeError:
            continue
        # Only a JSON object can carry the decision and the metadata keys
        # added below; lists, strings and numbers count as parse failures.
        if isinstance(value, dict):
            parsed = value
            break

    if parsed is None:
        return {
            "_parse_error": True,
            "_raw": (raw_response or raw_thinking)[:500],
            "_elapsed_s": elapsed,
            "_model": model,
        }

    parsed["_elapsed_s"] = elapsed
    parsed["_model"] = model
    parsed["_eval_count"] = body.get("eval_count")
    return parsed
|
||||
@@ -2,7 +2,140 @@
|
||||
Pipeline module - Orchestration du flux RPA Vision V3
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from .workflow_pipeline import WorkflowPipeline, create_pipeline
|
||||
from .screen_analyzer import ScreenAnalyzer
|
||||
from .screen_state_cache import ScreenStateCache, compute_perceptual_hash
|
||||
from .edge_scorer import EdgeScorer, EdgeScore
|
||||
|
||||
__all__ = ["WorkflowPipeline", "create_pipeline", "ScreenAnalyzer"]
|
||||
__all__ = [
|
||||
"WorkflowPipeline",
|
||||
"create_pipeline",
|
||||
"ScreenAnalyzer",
|
||||
"ScreenStateCache",
|
||||
"compute_perceptual_hash",
|
||||
"EdgeScorer",
|
||||
"EdgeScore",
|
||||
"get_screen_analyzer",
|
||||
"reset_screen_analyzer",
|
||||
"get_screen_state_cache",
|
||||
"reset_screen_state_cache",
|
||||
]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Singleton ScreenAnalyzer
|
||||
# =============================================================================
|
||||
#
|
||||
# Une seule instance est partagée entre ExecutionLoop, GraphBuilder et
|
||||
# stream_processor pour éviter le double chargement GPU (UIDetector + CLIP
|
||||
# = 6-10 Go VRAM, plafond 12 Go sur RTX 5070).
|
||||
#
|
||||
# Thread-safe : protégé par un lock.
|
||||
#
|
||||
# IMPORTANT (Lot C — avril 2026) :
|
||||
# Ce singleton ne porte plus AUCUN contexte d'exécution. Il détient
|
||||
# uniquement les ressources lourdes (modèles OCR, UIDetector, CLIP).
|
||||
# • Les flags runtime (`enable_ocr`, `enable_ui_detection`) et l'identité
|
||||
# de session (`session_id`) se passent en kwargs-only à `analyze()`,
|
||||
# jamais en mutant l'instance. Voir `ScreenAnalyzer.analyze()`.
|
||||
# • L'argument `session_id` de `get_screen_analyzer()` ne sert QUE de
|
||||
# valeur par défaut historique, ignorée après la première création.
|
||||
# À terme, prévoir sa suppression.
|
||||
# =============================================================================
|
||||
|
||||
|
||||
_SCREEN_ANALYZER_SINGLETON: Optional[ScreenAnalyzer] = None
|
||||
_SCREEN_ANALYZER_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def get_screen_analyzer(
    ui_detector=None,
    ocr_engine: Optional[str] = None,
    session_id: str = "",
    force_new: bool = False,
) -> ScreenAnalyzer:
    """
    Return the shared ScreenAnalyzer instance.

    The instance is created lazily on the first request. Later calls return
    the same instance whatever arguments they pass, unless ``force_new=True``.

    Args:
        ui_detector: optional UIDetector, only used when an instance is created
        ocr_engine: OCR engine name, only used when an instance is created
        session_id: session id, only used when an instance is created
        force_new: build a fresh instance and install it as the shared one (tests)

    Returns:
        The shared ScreenAnalyzer instance
    """
    global _SCREEN_ANALYZER_SINGLETON

    if not force_new:
        # Lock-free fast path. The global is read once into a local so that a
        # concurrent reset cannot turn the returned value into None between
        # the check and the return.
        instance = _SCREEN_ANALYZER_SINGLETON
        if instance is not None:
            return instance

    with _SCREEN_ANALYZER_LOCK:
        # Re-check under the lock: another thread may have created the
        # instance while this one was waiting.
        instance = _SCREEN_ANALYZER_SINGLETON
        if force_new or instance is None:
            instance = ScreenAnalyzer(
                ui_detector=ui_detector,
                ocr_engine=ocr_engine,
                session_id=session_id,
            )
            _SCREEN_ANALYZER_SINGLETON = instance
    # Return the local reference, not the global, which may already have been
    # reset by another thread once the lock is released.
    return instance
|
||||
|
||||
|
||||
def reset_screen_analyzer() -> None:
    """Clear the shared ScreenAnalyzer reference (intended for tests only)."""
    global _SCREEN_ANALYZER_SINGLETON

    # Taken under the same lock the getter uses when it creates the instance.
    with _SCREEN_ANALYZER_LOCK:
        _SCREEN_ANALYZER_SINGLETON = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Singleton ScreenStateCache (partagé)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
_SCREEN_STATE_CACHE_SINGLETON: Optional[ScreenStateCache] = None
|
||||
_SCREEN_STATE_CACHE_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def get_screen_state_cache(
    ttl_seconds: float = 2.0,
    max_entries: int = 16,
) -> ScreenStateCache:
    """
    Return the shared ScreenState cache, creating it on the first request.

    Args:
        ttl_seconds: forwarded to ``ScreenStateCache`` when the cache is created;
            ignored once the shared cache exists
        max_entries: forwarded to ``ScreenStateCache`` when the cache is created;
            ignored once the shared cache exists

    Returns:
        The shared ScreenStateCache instance
    """
    global _SCREEN_STATE_CACHE_SINGLETON

    # Lock-free fast path. Snapshot the global once so a concurrent reset
    # cannot make this function return None.
    cache = _SCREEN_STATE_CACHE_SINGLETON
    if cache is not None:
        return cache

    with _SCREEN_STATE_CACHE_LOCK:
        # Re-check under the lock before creating.
        cache = _SCREEN_STATE_CACHE_SINGLETON
        if cache is None:
            cache = ScreenStateCache(
                ttl_seconds=ttl_seconds,
                max_entries=max_entries,
            )
            _SCREEN_STATE_CACHE_SINGLETON = cache
    return cache
|
||||
|
||||
|
||||
def reset_screen_state_cache() -> None:
    """Clear the shared ScreenStateCache reference (intended for tests only)."""
    global _SCREEN_STATE_CACHE_SINGLETON

    # Taken under the same lock the getter uses when it creates the cache.
    with _SCREEN_STATE_CACHE_LOCK:
        _SCREEN_STATE_CACHE_SINGLETON = None
|
||||
|
||||
380
core/pipeline/edge_scorer.py
Normal file
380
core/pipeline/edge_scorer.py
Normal file
@@ -0,0 +1,380 @@
|
||||
"""
|
||||
EdgeScorer — Sélection robuste d'un edge parmi plusieurs candidats.
|
||||
|
||||
Au lieu de prendre "le premier edge sortant" (comportement legacy),
|
||||
ce module :
|
||||
|
||||
1. Applique un **filtre dur** : rejette les edges dont les `pre_conditions`
|
||||
(EdgeConstraints) échouent étant donné le ScreenState courant.
|
||||
2. Applique un **ranking léger** : score composite
|
||||
- `stats.success_rate` (pondéré fort)
|
||||
- match du `target_spec` (présence d'un UI element compatible)
|
||||
- récence (dernière exécution réussie)
|
||||
3. Retourne le meilleur edge, ou `None` si aucun ne passe le filtre.
|
||||
|
||||
API principale :
|
||||
>>> scorer = EdgeScorer()
|
||||
>>> edge = scorer.select_best(edges, screen_state=state)
|
||||
|
||||
Les scores individuels sont exposés via `score_edge()` pour les tests
|
||||
et la télémétrie.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
from core.models.screen_state import ScreenState
|
||||
from core.models.workflow_graph import WorkflowEdge
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Résultat de scoring (utile pour la télémétrie / debug)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
class EdgeScore:
    """Detailed outcome of scoring a single edge.

    ``<`` compares instances by ``total`` only, so a plain ``sorted()`` yields
    ascending totals (lowest score first).
    """

    # The edge that was scored.
    edge: WorkflowEdge
    # Weighted composite of the three components below.
    total: float
    success_rate: float
    target_match: float
    recency: float
    # Outcome of the hard pre-condition filter, with its explanation.
    passed_preconditions: bool
    precondition_reason: str = "OK"

    def __lt__(self, other: "EdgeScore") -> bool:
        # Ordering hook: this score is "smaller" when its total is lower.
        return self.total < other.total
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Scorer
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class EdgeScorer:
    """
    Pick the best outgoing edge for a given ScreenState.

    The default weights can be overridden at construction time.
    """

    def __init__(
        self,
        weight_success_rate: float = 0.55,
        weight_target_match: float = 0.35,
        weight_recency: float = 0.10,
        default_success_rate: float = 0.5,
    ):
        """
        Args:
            weight_success_rate: weight of `edge.stats.success_rate`
            weight_target_match: weight of the `target_spec` / `ui_elements` match
            weight_recency: weight of the recency of the last execution
            default_success_rate: value used when the edge was never executed

        Raises:
            ValueError: if the three weights do not sum to a positive value
        """
        weight_sum = weight_success_rate + weight_target_match + weight_recency
        if weight_sum <= 0:
            raise ValueError("La somme des poids doit être > 0")

        # Weights are silently normalised so that they add up to 1.
        self.w_success = weight_success_rate / weight_sum
        self.w_target = weight_target_match / weight_sum
        self.w_recency = weight_recency / weight_sum
        self.default_success_rate = default_success_rate

    # -------------------------------------------------------------------------
    # Public API
    # -------------------------------------------------------------------------

    def select_best(
        self,
        edges: Sequence[WorkflowEdge],
        screen_state: Optional[ScreenState] = None,
        strategy: str = "best",
        source_similarity: float = 1.0,
    ) -> Optional[WorkflowEdge]:
        """
        Select the best edge among the candidates.

        Args:
            edges: candidate edges
            screen_state: current state, used for pre_conditions and target match
            strategy: "best" (default, full scoring) or "first" (legacy: the
                first edge is returned without any check)
            source_similarity: matching confidence forwarded to the
                pre-condition check; defaults to ``1.0``

        Returns:
            The best edge, or None when `edges` is empty or no edge passes the
            pre_conditions
        """
        if not edges:
            return None

        if strategy == "first":
            # Legacy behaviour: first edge, unconditionally.
            return edges[0]

        ranked = self.rank(
            edges, screen_state=screen_state, source_similarity=source_similarity
        )

        # `ranked` is sorted best-first, so the first eligible entry wins.
        eligible = [entry for entry in ranked if entry.passed_preconditions]
        if eligible:
            best = eligible[0].edge
            logger.debug(
                f"[EdgeScorer] Sélection {best.edge_id} "
                f"(score={eligible[0].total:.3f}, parmi {len(eligible)} valides)"
            )
            return best

        # Nothing passed the hard filter: report the first few reasons.
        reasons = "; ".join(
            f"{s.edge.edge_id}: {s.precondition_reason}" for s in ranked[:5]
        )
        logger.warning(
            f"[EdgeScorer] Aucun edge valide parmi {len(edges)} candidats. "
            f"Raisons: {reasons}"
        )
        return None

    def rank(
        self,
        edges: Sequence[WorkflowEdge],
        screen_state: Optional[ScreenState] = None,
        source_similarity: float = 1.0,
    ) -> List[EdgeScore]:
        """
        Score every edge and return the scores sorted best-first.

        Ordering is by total score descending, ties broken by the higher
        `success_rate`. Edges that failed their pre_conditions are still
        present in the result.

        Args:
            edges: candidate edges
            screen_state: current state (pre_conditions + target match)
            source_similarity: matching confidence forwarded to the
                pre-condition check
        """
        return sorted(
            (
                self.score_edge(edge, screen_state, source_similarity=source_similarity)
                for edge in edges
            ),
            key=lambda s: (s.total, s.success_rate),
            reverse=True,
        )

    # -------------------------------------------------------------------------
    # Per-edge scoring
    # -------------------------------------------------------------------------

    def score_edge(
        self,
        edge: WorkflowEdge,
        screen_state: Optional[ScreenState] = None,
        source_similarity: float = 1.0,
    ) -> EdgeScore:
        """
        Compute the score of one edge.

        The pre_conditions act only as a hard filter: the total is computed in
        every case and the outcome is reported via `passed_preconditions`.

        Args:
            edge: edge to score
            screen_state: current state
            source_similarity: matching confidence forwarded to the
                pre-condition check
        """
        # Hard filter first, then the three weighted components.
        passed, reason = self._check_preconditions(
            edge, screen_state, source_similarity=source_similarity
        )
        success_rate = self._score_success_rate(edge)
        target_match = self._score_target_match(edge, screen_state)
        recency = self._score_recency(edge)

        total = (
            self.w_success * success_rate
            + self.w_target * target_match
            + self.w_recency * recency
        )

        return EdgeScore(
            edge=edge,
            total=total,
            success_rate=success_rate,
            target_match=target_match,
            recency=recency,
            passed_preconditions=passed,
            precondition_reason=reason,
        )

    # -------------------------------------------------------------------------
    # Score components
    # -------------------------------------------------------------------------

    def _check_preconditions(
        self,
        edge: WorkflowEdge,
        screen_state: Optional[ScreenState],
        source_similarity: float = 1.0,
    ) -> tuple[bool, str]:
        """
        Evaluate the pre_conditions of an edge.

        Without constraints the edge passes. Without a ScreenState the
        constraints are still consulted, with empty window/app/text values, and
        only a negative answer is honoured. Any exception raised by the
        constraint check is logged and the edge is let through.

        Args:
            edge: edge to evaluate
            screen_state: current state (None when unavailable)
            source_similarity: matching confidence passed on to
                ``constraints.check_preconditions``

        Returns:
            ``(passed, reason)``
        """
        constraints = edge.constraints
        if constraints is None:
            return True, "OK (pas de contraintes)"

        has_state = screen_state is not None
        if has_state:
            window = screen_state.window
            perception = screen_state.perception
            window_title = window.window_title if window else ""
            app_name = window.app_name if window else ""
            detected_texts = perception.detected_text if perception else []
        else:
            # No window or text to inspect; only the source similarity carries
            # real information in this call.
            window_title, app_name, detected_texts = "", "", []

        try:
            ok, reason = constraints.check_preconditions(
                window_title=window_title,
                app_name=app_name,
                detected_texts=detected_texts,
                source_similarity=source_similarity,
            )
            # With a real state the verdict is returned as is; without one,
            # only a rejection is propagated.
            if has_state or not ok:
                return ok, reason
        except Exception as e:
            logger.warning(f"[EdgeScorer] Erreur check_preconditions: {e}")
            # An error never blocks the edge.
            return True, f"Erreur ignorée: {e}"

        return True, "OK (pas de ScreenState pour évaluer)"

    def _score_success_rate(self, edge: WorkflowEdge) -> float:
        """Score from `edge.stats.success_rate`, clamped to [0, 1]; the default when never executed."""
        stats = edge.stats
        if stats is None or stats.execution_count == 0:
            return self.default_success_rate
        return max(0.0, min(1.0, stats.success_rate))

    def _score_target_match(
        self,
        edge: WorkflowEdge,
        screen_state: Optional[ScreenState],
    ) -> float:
        """
        Score how well the action target matches the `ui_elements` on screen.

        Returns:
            - 1.0 when an element label equals the target text
            - 0.9 when an element role or type equals the target role
            - 0.8 when the target text is contained in an element label
            - 0.6 when the target only specifies a position
            - 0.5 (neutral) when there is no screen_state, no target, no
              ui_elements, or the target specifies nothing
            - 0.0 when a text/role is requested and no element matches
        """
        if screen_state is None:
            return 0.5

        action = edge.action
        target = action.target if action else None
        if target is None:
            return 0.5

        ui_elements = screen_state.ui_elements or []
        if not ui_elements:
            # Nothing detected on screen: cannot decide, stay neutral.
            return 0.5

        wanted_text = (target.by_text or "").lower().strip()
        wanted_role = (target.by_role or "").lower().strip()

        if not wanted_text and not wanted_role:
            # Position-only targets are always considered a possible match;
            # a target that specifies nothing at all stays neutral.
            return 0.6 if target.by_position else 0.5

        best = 0.0
        for element in ui_elements:
            label = getattr(element, "label", "") or ""
            role = getattr(element, "role", "") or ""
            kind = getattr(element, "type", "") or ""

            candidate = 0.0
            if wanted_text:
                lowered = label.lower()
                if wanted_text == lowered.strip():
                    candidate = 1.0
                elif wanted_text in lowered:
                    candidate = 0.8

            if wanted_role and wanted_role in (role.lower(), kind.lower()):
                candidate = max(candidate, 0.9)

            best = max(best, candidate)

        # A text/role was requested: 0.0 here means no element matched.
        return best

    def _score_recency(self, edge: WorkflowEdge) -> float:
        """
        Recency score from `edge.stats.last_executed`.

        Scale:
            - executed less than 24h ago: 1.0
            - executed less than 7 days ago: 0.7
            - executed longer ago: 0.3
            - never executed: 0.5 (neutral)
        """
        stats = edge.stats
        if stats is None or stats.last_executed is None:
            return 0.5

        age_seconds = (datetime.now() - stats.last_executed).total_seconds()
        if age_seconds < 24 * 3600:
            return 1.0
        if age_seconds < 7 * 24 * 3600:
            return 0.7
        return 0.3
|
||||
@@ -9,13 +9,33 @@ Orchestre les 4 niveaux du ScreenState :
|
||||
|
||||
Ce module comble le chaînon manquant entre la capture brute (Couche 0)
|
||||
et la construction d'embeddings (Couche 3).
|
||||
|
||||
=============================================================================
|
||||
Thread-safety & partage multi-loops (Lot C — avril 2026)
|
||||
=============================================================================
|
||||
Cet analyseur peut être partagé entre plusieurs `ExecutionLoop` (singleton
|
||||
`get_screen_analyzer()`). Pour éviter la contamination croisée :
|
||||
|
||||
• `analyze()` NE MUTE JAMAIS `self._ocr`, `self._ui_detector`,
|
||||
`self._ocr_initialized`, `self._ui_detector_initialized` pour gérer les
|
||||
flags runtime (enable_ocr / enable_ui_detection). Ces flags sont par
|
||||
appel, résolus en variables locales.
|
||||
• `session_id` circule en paramètre d'appel et renseigne la metadata du
|
||||
ScreenState ; l'attribut `self.session_id` n'est qu'un défaut historique
|
||||
(rétrocompat) et n'est plus la source de vérité.
|
||||
• L'init lazy des composants lourds (OCR, UIDetector) est protégée par un
|
||||
`_init_lock` par instance pour empêcher une double initialisation
|
||||
concurrente.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@@ -32,6 +52,44 @@ from core.models.ui_element import UIElement
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Lock d'inférence local au module : sert de fallback si le GPUResourceManager
|
||||
# n'est pas disponible (import error, tests). Partagé entre toutes les instances
|
||||
# ScreenAnalyzer du process, cohérent avec le singleton get_screen_analyzer().
|
||||
_ANALYZE_FALLBACK_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _acquire_gpu_context(timeout: Optional[float] = None):
    """
    Return a context manager that serializes GPU-heavy calls.

    The global coordinator (``GPUResourceManager.acquire_inference()``) is
    preferred when it can be imported and called. If anything goes wrong while
    obtaining it, the function falls back to a module-level ``threading.Lock``.

    Args:
        timeout: Maximum number of seconds to wait for the lock. ``None``
            means wait indefinitely.

    Returns:
        A context manager. For the local fallback, the value bound by
        ``with ... as acquired`` is ``True`` when the lock was obtained and
        ``False`` when the timeout expired (the body still runs in that case,
        without holding the lock).
    """
    try:
        # Imported lazily so that a missing/broken GPU package only disables
        # the global coordination instead of breaking this module's import.
        from core.gpu import get_gpu_resource_manager

        manager = get_gpu_resource_manager()
        return manager.acquire_inference(timeout=timeout)
    except Exception as e:  # pragma: no cover - defensive fallback
        logger.debug(f"GPUResourceManager indisponible, fallback lock local: {e}")

    @contextlib.contextmanager
    def _fallback():
        if timeout is None:
            got = _ANALYZE_FALLBACK_LOCK.acquire()
        else:
            got = _ANALYZE_FALLBACK_LOCK.acquire(timeout=timeout)
        try:
            yield got
        finally:
            # Release on every exit path: an exception raised inside the
            # ``with`` body is re-raised at the ``yield`` above, and without
            # this ``finally`` the process-wide lock would stay held forever.
            if got:
                _ANALYZE_FALLBACK_LOCK.release()

    return _fallback()
|
||||
|
||||
|
||||
class ScreenAnalyzer:
|
||||
"""
|
||||
Construit un ScreenState complet (4 niveaux) depuis un screenshot.
|
||||
@@ -44,6 +102,14 @@ class ScreenAnalyzer:
|
||||
>>> state = analyzer.analyze("/path/to/screenshot.png")
|
||||
>>> print(state.perception.detected_text)
|
||||
>>> print(len(state.ui_elements))
|
||||
|
||||
Runtime overrides (kwargs-only) sur analyze() :
|
||||
>>> state = analyzer.analyze(
|
||||
... path,
|
||||
... enable_ocr=False, # bypass OCR pour cet appel
|
||||
... enable_ui_detection=False, # bypass UIDetector
|
||||
... session_id="session_42", # session par appel
|
||||
... )
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -56,18 +122,27 @@ class ScreenAnalyzer:
|
||||
Args:
|
||||
ui_detector: Instance de UIDetector (créé si None)
|
||||
ocr_engine: Moteur OCR à utiliser ("doctr", "tesseract", None=auto)
|
||||
session_id: ID de la session en cours
|
||||
session_id: ID de session par défaut (rétrocompat ; préférer passer
|
||||
`session_id` en kwarg de `analyze()` pour chaque appel).
|
||||
"""
|
||||
self._ui_detector = ui_detector
|
||||
self._ocr_engine_name = ocr_engine
|
||||
self._ocr = None
|
||||
# Session par défaut (rétrocompat). La source de vérité est désormais
|
||||
# le paramètre `session_id` de `analyze()`.
|
||||
self.session_id = session_id
|
||||
# Compteur d'états — protégé par _state_lock pour être safe en parallèle.
|
||||
self._state_counter = 0
|
||||
self._state_lock = threading.Lock()
|
||||
|
||||
# Initialisation lazy pour éviter les imports lourds au démarrage
|
||||
# Initialisation lazy pour éviter les imports lourds au démarrage.
|
||||
self._ui_detector_initialized = ui_detector is not None
|
||||
self._ocr_initialized = False
|
||||
|
||||
# Lock dédié à l'init lazy : empêche deux threads d'initialiser
|
||||
# simultanément OCR ou UIDetector (double chargement GPU).
|
||||
self._init_lock = threading.Lock()
|
||||
|
||||
# =========================================================================
|
||||
# API publique
|
||||
# =========================================================================
|
||||
@@ -77,28 +152,85 @@ class ScreenAnalyzer:
|
||||
screenshot_path: str,
|
||||
window_info: Optional[Dict[str, Any]] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
enable_ocr: bool = True,
|
||||
enable_ui_detection: bool = True,
|
||||
session_id: str = "",
|
||||
) -> ScreenState:
|
||||
"""
|
||||
Analyser un screenshot et construire un ScreenState complet.
|
||||
|
||||
Les flags `enable_ocr`, `enable_ui_detection` et `session_id` sont
|
||||
**par appel, kwargs-only**, pour ne pas polluer l'état partagé du
|
||||
singleton quand plusieurs `ExecutionLoop` se partagent l'analyseur.
|
||||
|
||||
Args:
|
||||
screenshot_path: Chemin vers le fichier image
|
||||
window_info: Infos fenêtre active {"title": ..., "app_name": ...}
|
||||
context: Contexte métier optionnel
|
||||
enable_ocr: Active l'OCR pour cet appel (True par défaut).
|
||||
False → `detected_text=[]`, aucune init d'OCR déclenchée.
|
||||
enable_ui_detection: Active la détection UI pour cet appel
|
||||
(True par défaut). False → `ui_elements=[]`.
|
||||
session_id: ID de session pour cet appel. Si vide, on retombe sur
|
||||
`self.session_id` (rétrocompat). Cette valeur est propagée
|
||||
dans `ScreenState.session_id` et `metadata["session_id"]`.
|
||||
|
||||
Returns:
|
||||
ScreenState avec les 4 niveaux remplis
|
||||
ScreenState avec les 4 niveaux remplis.
|
||||
"""
|
||||
screenshot_path = str(screenshot_path)
|
||||
|
||||
# Résolution de la session : priorité au kwarg, fallback sur l'état
|
||||
# interne (legacy). Variable locale uniquement — pas de mutation.
|
||||
effective_session_id = session_id or self.session_id
|
||||
|
||||
# Compteur incrémenté sous lock pour identifiants uniques même en
|
||||
# parallèle. C'est la seule mutation tolérée : elle n'impacte pas le
|
||||
# comportement OCR/UI.
|
||||
with self._state_lock:
|
||||
self._state_counter += 1
|
||||
state_counter = self._state_counter
|
||||
|
||||
state_id = f"{self.session_id}_state_{self._state_counter:04d}" if self.session_id else f"state_{self._state_counter:04d}"
|
||||
state_id = (
|
||||
f"{effective_session_id}_state_{state_counter:04d}"
|
||||
if effective_session_id
|
||||
else f"state_{state_counter:04d}"
|
||||
)
|
||||
|
||||
# Niveau 1 : Raw
|
||||
# Niveau 1 : Raw (léger, hors lock GPU)
|
||||
raw = self._build_raw_level(screenshot_path)
|
||||
|
||||
# Niveau 2 : Perception (OCR)
|
||||
detected_text = self._extract_text(screenshot_path)
|
||||
# Résolution locale des instances OCR / UIDetector selon les flags.
|
||||
# Aucune mutation de self ici : on décide simplement ce qu'on utilise.
|
||||
ocr_instance = self._resolve_ocr_instance(enable_ocr=enable_ocr)
|
||||
ui_detector_instance = self._resolve_ui_detector_instance(
|
||||
enable_ui_detection=enable_ui_detection
|
||||
)
|
||||
|
||||
# Niveaux 2 et 3 : OCR + détection UI sont les étapes lourdes en GPU.
|
||||
# On sérialise via GPUResourceManager.acquire_inference() pour éviter
|
||||
# que ExecutionLoop et stream_processor saturent simultanément la VRAM
|
||||
# sur RTX 5070 (12 Go). Timeout généreux : un appel peut prendre 15-20s.
|
||||
with _acquire_gpu_context(timeout=60.0) as acquired:
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
"Timeout en attendant le lock GPU pour ScreenAnalyzer.analyze() "
|
||||
"→ exécution sans sérialisation (risque saturation VRAM)"
|
||||
)
|
||||
|
||||
# Niveau 2 : Perception (OCR) — mesure du temps OCR
|
||||
ocr_t0 = time.time()
|
||||
detected_text = self._extract_text_with(ocr_instance, screenshot_path)
|
||||
ocr_ms = (time.time() - ocr_t0) * 1000
|
||||
|
||||
# Niveau 3 : UI Elements — mesure du temps détection
|
||||
ui_t0 = time.time()
|
||||
ui_elements = self._detect_ui_elements_with(
|
||||
ui_detector_instance, screenshot_path, window_info
|
||||
)
|
||||
ui_ms = (time.time() - ui_t0) * 1000
|
||||
|
||||
perception = PerceptionLevel(
|
||||
embedding=EmbeddingRef(
|
||||
provider="openclip_ViT-B-32",
|
||||
@@ -106,13 +238,10 @@ class ScreenAnalyzer:
|
||||
dimensions=512,
|
||||
),
|
||||
detected_text=detected_text,
|
||||
text_detection_method=self._get_ocr_method_name(),
|
||||
text_detection_method=self._get_ocr_method_name(ocr_instance),
|
||||
confidence_avg=0.85 if detected_text else 0.0,
|
||||
)
|
||||
|
||||
# Niveau 3 : UI Elements
|
||||
ui_elements = self._detect_ui_elements(screenshot_path, window_info)
|
||||
|
||||
# Niveau 4 : Contexte
|
||||
window_ctx = self._build_window_context(window_info)
|
||||
context_level = self._build_context_level(context)
|
||||
@@ -120,22 +249,28 @@ class ScreenAnalyzer:
|
||||
state = ScreenState(
|
||||
screen_state_id=state_id,
|
||||
timestamp=datetime.now(),
|
||||
session_id=self.session_id,
|
||||
session_id=effective_session_id,
|
||||
window=window_ctx,
|
||||
raw=raw,
|
||||
perception=perception,
|
||||
context=context_level,
|
||||
metadata={
|
||||
"analyzer_version": "1.0",
|
||||
"analyzer_version": "1.1",
|
||||
"session_id": effective_session_id,
|
||||
"ui_elements_count": len(ui_elements),
|
||||
"text_regions_count": len(detected_text),
|
||||
"ocr_ms": ocr_ms,
|
||||
"ui_ms": ui_ms,
|
||||
"ocr_enabled": enable_ocr,
|
||||
"ui_detection_enabled": enable_ui_detection,
|
||||
},
|
||||
ui_elements=ui_elements,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"ScreenState {state_id} construit: "
|
||||
f"{len(ui_elements)} éléments UI, {len(detected_text)} textes détectés"
|
||||
f"{len(ui_elements)} éléments UI, {len(detected_text)} textes détectés "
|
||||
f"(ocr={enable_ocr}, ui={enable_ui_detection})"
|
||||
)
|
||||
return state
|
||||
|
||||
@@ -145,11 +280,16 @@ class ScreenAnalyzer:
|
||||
save_dir: str = "data/screens",
|
||||
window_info: Optional[Dict[str, Any]] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
enable_ocr: bool = True,
|
||||
enable_ui_detection: bool = True,
|
||||
session_id: str = "",
|
||||
) -> ScreenState:
|
||||
"""
|
||||
Analyser une PIL Image (utile quand on a déjà l'image en mémoire).
|
||||
|
||||
Sauvegarde l'image sur disque puis appelle analyze().
|
||||
Sauvegarde l'image sur disque puis appelle analyze(). Les flags
|
||||
runtime sont propagés à `analyze()` en kwargs-only.
|
||||
"""
|
||||
save_path = Path(save_dir)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
@@ -159,7 +299,49 @@ class ScreenAnalyzer:
|
||||
filepath = save_path / filename
|
||||
|
||||
image.save(str(filepath))
|
||||
return self.analyze(str(filepath), window_info=window_info, context=context)
|
||||
return self.analyze(
|
||||
str(filepath),
|
||||
window_info=window_info,
|
||||
context=context,
|
||||
enable_ocr=enable_ocr,
|
||||
enable_ui_detection=enable_ui_detection,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Résolution des instances OCR / UI selon les flags d'appel
|
||||
# =========================================================================
|
||||
|
||||
def _resolve_ocr_instance(self, *, enable_ocr: bool):
    """
    Return the OCR callable to use for the current call.

    Args:
        enable_ocr: When ``False`` the OCR is bypassed for this call: ``None``
            is returned and no initialization is triggered.

    Returns:
        ``self._ocr`` (lazily initialized on first use), or ``None`` when OCR
        is disabled for this call. ``self._ocr`` itself may also be ``None``
        after initialization — presumably when no engine could be set up;
        confirm against ``_ensure_ocr_locked``.

    ``self._ocr`` / ``self._ocr_initialized`` are only mutated by the lazy
    initialization itself, never to bypass OCR for a single call.
    """
    if not enable_ocr:
        return None
    # Always read the instance under the init lock. ``_ensure_ocr_locked()``
    # raises ``_ocr_initialized`` *before* the engine is installed, so an
    # unlocked "already initialized?" check could succeed while ``self._ocr``
    # is still ``None`` and silently skip OCR for that call. Taking the lock
    # makes a concurrent caller wait until the initialization has finished.
    with self._init_lock:
        if not self._ocr_initialized:
            self._ensure_ocr_locked()
        return self._ocr
|
||||
|
||||
def _resolve_ui_detector_instance(self, *, enable_ui_detection: bool):
    """
    Return the UI detector to use for the current call.

    Args:
        enable_ui_detection: When ``False`` the detection is bypassed for this
            call: ``None`` is returned and no initialization is triggered.

    Returns:
        ``self._ui_detector`` (lazily initialized on first use), or ``None``
        when UI detection is disabled for this call.
    """
    if not enable_ui_detection:
        return None
    # Always read the instance under the init lock.
    # ``_ensure_ui_detector_locked()`` raises ``_ui_detector_initialized``
    # before the detector is actually built, so an unlocked check could hand
    # out a detector that is not installed yet. Holding the lock makes a
    # concurrent caller wait for the initialization to complete.
    with self._init_lock:
        if not self._ui_detector_initialized:
            self._ensure_ui_detector_locked()
        return self._ui_detector
|
||||
|
||||
# =========================================================================
|
||||
# Niveau 1 : Raw
|
||||
@@ -182,23 +364,24 @@ class ScreenAnalyzer:
|
||||
# Niveau 2 : Perception — OCR
|
||||
# =========================================================================
|
||||
|
||||
def _extract_text(self, screenshot_path: str) -> List[str]:
|
||||
"""Extraire le texte d'un screenshot via OCR."""
|
||||
self._ensure_ocr()
|
||||
|
||||
if self._ocr is None:
|
||||
def _extract_text_with(self, ocr_callable, screenshot_path: str) -> List[str]:
|
||||
"""Extraire le texte via un callable OCR donné (peut être None)."""
|
||||
if ocr_callable is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
return self._ocr(screenshot_path)
|
||||
return ocr_callable(screenshot_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"OCR échoué: {e}")
|
||||
return []
|
||||
|
||||
def _ensure_ocr(self) -> None:
|
||||
"""Initialiser le moteur OCR (lazy)."""
|
||||
if self._ocr_initialized:
|
||||
return
|
||||
def _ensure_ocr_locked(self) -> None:
|
||||
"""
|
||||
Initialiser le moteur OCR (appelé sous `self._init_lock`).
|
||||
|
||||
Ne doit PAS être appelé hors de `_resolve_ocr_instance()`.
|
||||
"""
|
||||
# Mutation intentionnelle : on installe l'instance OCR réelle.
|
||||
# Protégée par le lock d'init (pas le lock GPU).
|
||||
self._ocr_initialized = True
|
||||
|
||||
engine = self._ocr_engine_name
|
||||
@@ -257,8 +440,9 @@ class ScreenAnalyzer:
|
||||
|
||||
return ocr_func
|
||||
|
||||
def _get_ocr_method_name(self) -> str:
|
||||
if self._ocr is None:
|
||||
def _get_ocr_method_name(self, ocr_instance=None) -> str:
|
||||
"""Nom du moteur OCR effectivement utilisé pour cet appel."""
|
||||
if ocr_instance is None:
|
||||
return "none"
|
||||
if self._ocr_engine_name:
|
||||
return self._ocr_engine_name
|
||||
@@ -268,19 +452,18 @@ class ScreenAnalyzer:
|
||||
# Niveau 3 : UI Elements
|
||||
# =========================================================================
|
||||
|
||||
def _detect_ui_elements(
|
||||
def _detect_ui_elements_with(
|
||||
self,
|
||||
ui_detector,
|
||||
screenshot_path: str,
|
||||
window_info: Optional[Dict[str, Any]] = None,
|
||||
) -> List[UIElement]:
|
||||
"""Détecter les éléments UI dans le screenshot."""
|
||||
self._ensure_ui_detector()
|
||||
|
||||
if self._ui_detector is None:
|
||||
"""Détecter les éléments UI via un détecteur donné (peut être None)."""
|
||||
if ui_detector is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
elements = self._ui_detector.detect(
|
||||
elements = ui_detector.detect(
|
||||
screenshot_path, window_context=window_info
|
||||
)
|
||||
return elements
|
||||
@@ -288,10 +471,10 @@ class ScreenAnalyzer:
|
||||
logger.warning(f"Détection UI échouée: {e}")
|
||||
return []
|
||||
|
||||
def _ensure_ui_detector(self) -> None:
|
||||
"""Initialiser le UIDetector (lazy)."""
|
||||
if self._ui_detector_initialized:
|
||||
return
|
||||
def _ensure_ui_detector_locked(self) -> None:
|
||||
"""
|
||||
Initialiser le UIDetector (appelé sous `self._init_lock`).
|
||||
"""
|
||||
self._ui_detector_initialized = True
|
||||
|
||||
try:
|
||||
|
||||
409
core/pipeline/screen_state_cache.py
Normal file
409
core/pipeline/screen_state_cache.py
Normal file
@@ -0,0 +1,409 @@
|
||||
"""
|
||||
ScreenStateCache — Cache perceptuel de ScreenState (context-aware).
|
||||
|
||||
Objectif : éviter de réanalyser un screenshot identique (5-15s VLM/OCR)
|
||||
à chaque step de la boucle d'exécution.
|
||||
|
||||
Principe (Lot D — avril 2026) :
|
||||
- Clé = composite de 6 éléments pour éviter les collisions silencieuses
|
||||
entre contextes différents partageant un même screenshot :
|
||||
1. phash (dhash 8x8 du screenshot) — calculé en ~2-5ms
|
||||
2. window_title (titre fenêtre active)
|
||||
3. app_name (nom process actif)
|
||||
4. enable_ocr (flag runtime)
|
||||
5. enable_ui_detection (flag runtime)
|
||||
6. workflow_id (isolation inter-workflows)
|
||||
- TTL par défaut : 2 secondes (configurable)
|
||||
- Invalidation explicite possible (par clé composite ou globale)
|
||||
- invalidate_if_changed reste piloté par le phash seul (détection de
|
||||
changement visuel majeur, indépendant du contexte)
|
||||
- Thread-safe (lock interne)
|
||||
|
||||
API principale :
|
||||
>>> cache = ScreenStateCache(ttl_seconds=2.0)
|
||||
>>> state, hit, ms = cache.get_or_compute(
|
||||
... screenshot_path, compute_fn,
|
||||
... window_title="App", app_name="app.exe",
|
||||
... enable_ocr=True, enable_ui_detection=True,
|
||||
... workflow_id="wf_123",
|
||||
... )
|
||||
|
||||
La fonction `compute_fn` prend le chemin du screenshot et doit retourner
|
||||
un `ScreenState`. Elle n'est appelée qu'en cache miss.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from core.models.screen_state import ScreenState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Hash perceptuel (dhash simple, sans dépendance imagehash)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _hamming_distance_hex(a: str, b: str) -> int:
    """
    Count the bits that differ between two hexadecimal strings.

    Strings of different lengths are right-padded with ``"0"`` up to the
    longer one before being compared.

    Args:
        a: First hash, as a hexadecimal string.
        b: Second hash, as a hexadecimal string.

    Returns:
        The number of differing bits. If either string cannot be parsed as
        hexadecimal, the result is an estimate: 4 bits for every character
        position where the two strings differ.
    """
    width = max(len(a), len(b))
    # ``ljust`` leaves a string untouched when it already has the target width.
    left = a.ljust(width, "0")
    right = b.ljust(width, "0")
    try:
        diff = int(left, 16) ^ int(right, 16)
    except ValueError:
        # Non-hex input: fall back to a per-character comparison.
        mismatches = [1 for char_l, char_r in zip(left, right) if char_l != char_r]
        return len(mismatches) * 4
    return bin(diff).count("1")
|
||||
|
||||
|
||||
def compute_perceptual_hash(screenshot_path: str, size: int = 8) -> str:
    """
    Compute a difference hash (dhash) of a screenshot.

    Algorithm:
        1. Convert the image to grayscale.
        2. Resize it to ``(size + 1) x size`` pixels.
        3. For each row, compare every pixel with its right-hand neighbour
           (bit = 1 when the left pixel is brighter).
        4. Pack the ``size * size`` bits into a zero-padded hexadecimal string.

    Args:
        screenshot_path: Path to the image file.
        size: Side of the bit grid (8 gives a 64-bit hash).

    Returns:
        A hexadecimal string of ``size * size // 4`` characters. If the image
        cannot be processed, the first 16 hex characters of the MD5 of the raw
        file bytes are returned instead; if the file cannot be read either, a
        unique ``unhashable_<timestamp-ms>`` marker is returned.
    """
    try:
        # ``Image.open`` is lazy and holds the underlying file open; the
        # context manager guarantees the handle is closed even when decoding
        # fails. ``convert`` returns an independent in-memory image, so it
        # stays usable after the source has been closed.
        with Image.open(screenshot_path) as source:
            img = source.convert("L").resize((size + 1, size), Image.LANCZOS)
        pixels = list(img.getdata())

        # dhash: one bit per (pixel, right neighbour) pair, row by row.
        stride = size + 1
        value = 0
        for row in range(size):
            base = row * stride
            for col in range(size):
                bit = 1 if pixels[base + col] > pixels[base + col + 1] else 0
                value = (value << 1) | bit
        return format(value, f"0{size * size // 4}x")
    except Exception as e:
        logger.warning(f"Hash perceptuel échoué pour {screenshot_path}: {e}")
        # Fallback: hash of the raw file content (exact match only).
        try:
            data = Path(screenshot_path).read_bytes()
            return hashlib.md5(data).hexdigest()[:16]
        except Exception:
            # Last resort: a time-based marker that differs from call to call.
            return f"unhashable_{int(time.time() * 1000)}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Clé composite (Lot D)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _make_cache_key(
    phash: str,
    window_title: str,
    app_name: str,
    enable_ocr: bool,
    enable_ui_detection: bool,
    workflow_id: str,
) -> str:
    """
    Build the composite cache key for a screenshot analysed in a given context.

    The five context values are serialized in a fixed order, joined with the
    ASCII unit separator (``\\x1f``), hashed with MD5 and truncated to 16 hex
    characters. The perceptual hash is kept in clear as a prefix so the key
    stays recognisable in debug logs.

    Hashing the context (rather than concatenating it) bounds the key length
    and keeps window titles out of the key itself.

    Args:
        phash: Perceptual hash of the screenshot.
        window_title: Title of the active window (falsy values count as "").
        app_name: Name of the active process (falsy values count as "").
        enable_ocr: Runtime OCR flag (normalized to 0/1).
        enable_ui_detection: Runtime UI-detection flag (normalized to 0/1).
        workflow_id: Current workflow ID (falsy values count as "").

    Returns:
        A key of the form ``"{phash}|{ctx_hash}"``.
    """
    context_fields = (
        window_title or "",
        app_name or "",
        int(bool(enable_ocr)),
        int(bool(enable_ui_detection)),
        workflow_id or "",
    )
    # ``format(x)`` renders each field exactly as an f-string placeholder would.
    serialized = "\x1f".join(format(field) for field in context_fields)
    digest = hashlib.md5(serialized.encode("utf-8")).hexdigest()
    return f"{phash}|{digest[:16]}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Entry
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
class _CacheEntry:
    """A cached ``ScreenState`` together with its bookkeeping data."""

    # The cached analysis result.
    state: ScreenState
    # ``time.time()`` value recorded when the entry was stored; drives both
    # TTL expiry and oldest-first eviction.
    created_at: float
    # Bare perceptual hash of the screenshot (not the composite key); this is
    # what ``invalidate_if_changed`` compares against.
    phash: str
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cache
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ScreenStateCache:
    """
    TTL-bounded cache of ``ScreenState`` objects with a context-aware key.

    Entries are stored under the composite key produced by
    ``_make_cache_key``. Every access to the store and to the hit/miss/
    invalidation counters is made under an internal lock, so one instance can
    be shared between threads. ``compute_fn`` itself runs outside the lock:
    two threads missing on the same key at the same time will both compute.
    """

    def __init__(self, ttl_seconds: float = 2.0, max_entries: int = 16):
        """
        Args:
            ttl_seconds: Lifetime of an entry, in seconds.
            max_entries: Maximum number of entries; when full, the entry that
                was stored the longest ago is evicted. A value below 1 behaves
                like 1.
        """
        self.ttl_seconds = ttl_seconds
        self.max_entries = max_entries
        # Key = composite key (_make_cache_key), value = _CacheEntry.
        self._store: dict[str, _CacheEntry] = {}
        self._lock = threading.Lock()

        # Simple metrics (debug / logs); updated under ``self._lock``.
        self.hits = 0
        self.misses = 0
        self.invalidations = 0

    # -------------------------------------------------------------------------
    # Low-level API (by composite key)
    # -------------------------------------------------------------------------

    def _get(self, composite_key: str) -> Optional[ScreenState]:
        """Return the state stored under this key if it has not expired."""
        with self._lock:
            entry = self._store.get(composite_key)
            if entry is None:
                return None
            if time.time() - entry.created_at > self.ttl_seconds:
                # Expired: drop it so it no longer counts towards the size.
                self._store.pop(composite_key, None)
                return None
            return entry.state

    def _set(self, composite_key: str, phash: str, state: ScreenState) -> None:
        """Store a state under this composite key, evicting if the cache is full."""
        with self._lock:
            # Oldest-first eviction, only when inserting a *new* key into a
            # full cache. The ``self._store`` guard keeps ``min()`` from being
            # called on an empty store when ``max_entries`` is 0 or negative.
            if (
                self._store
                and len(self._store) >= self.max_entries
                and composite_key not in self._store
            ):
                oldest_key = min(
                    self._store, key=lambda k: self._store[k].created_at
                )
                self._store.pop(oldest_key, None)

            self._store[composite_key] = _CacheEntry(
                state=state,
                created_at=time.time(),
                phash=phash,
            )

    def invalidate(self, composite_key: Optional[str] = None) -> None:
        """
        Invalidate one entry or the whole cache.

        The invalidation counter is incremented on every call, whether or not
        an entry was actually removed.

        Args:
            composite_key: Key to invalidate. If ``None``, the whole cache is
                cleared.
        """
        with self._lock:
            if composite_key is None:
                self._store.clear()
            else:
                self._store.pop(composite_key, None)
            self.invalidations += 1

    def invalidate_if_changed(
        self,
        screenshot_path: str,
        threshold: float = 0.3,
    ) -> bool:
        """
        Clear the cache if the screen has changed enough.

        The perceptual hash of the current screenshot is compared with the
        bare ``phash`` of every entry (not with the composite key). If even
        the closest entry differs by more than ``threshold`` of the bits, the
        whole cache is cleared, regardless of context.

        Args:
            screenshot_path: Path of the current screenshot.
            threshold: Proportion of bits (0.0-1.0) that must differ.
                0.3 means 30 % (about 19 of 64 bits).

        Returns:
            ``True`` if the cache was cleared, ``False`` otherwise.
        """
        # Cheap unlocked pre-check to skip hashing when there is nothing to
        # invalidate; the authoritative check is repeated under the lock.
        if not self._store:
            return False

        current_phash = compute_perceptual_hash(screenshot_path)

        # Derive the bit count from the hash length (4 bits per hex character)
        # instead of hard-coding 64.
        total_bits = len(current_phash) * 4
        if total_bits == 0:
            return False

        threshold_bits = threshold * total_bits

        with self._lock:
            if not self._store:
                return False

            # Distance to the closest cached screenshot.
            min_distance = min(
                _hamming_distance_hex(current_phash, entry.phash)
                for entry in self._store.values()
            )

            if min_distance > threshold_bits:
                size_before = len(self._store)
                self._store.clear()
                self.invalidations += 1
                logger.debug(
                    f"[ScreenStateCache] invalidate_if_changed: "
                    f"distance={min_distance}/{total_bits} > "
                    f"threshold={threshold_bits:.1f} → {size_before} entrées purgées"
                )
                return True
            return False

    # -------------------------------------------------------------------------
    # High-level API (context-aware)
    # -------------------------------------------------------------------------

    def get_or_compute(
        self,
        screenshot_path: str,
        compute_fn: Callable[[str], ScreenState],
        *,
        window_title: str = "",
        app_name: str = "",
        enable_ocr: bool = True,
        enable_ui_detection: bool = True,
        workflow_id: str = "",
        force_refresh: bool = False,
    ) -> Tuple[ScreenState, bool, float]:
        """
        Return the cached ``ScreenState`` for a screenshot + context, or compute it.

        The cache key combines the screenshot's perceptual hash with
        ``window_title``, ``app_name``, ``enable_ocr``, ``enable_ui_detection``
        and ``workflow_id``, so two different contexts sharing the same
        screenshot do not collide. All context arguments have defaults:
        callers that pass none of them share the same entries.

        Args:
            screenshot_path: Path of the screenshot.
            compute_fn: Called with ``screenshot_path`` on a cache miss; must
                return the ``ScreenState`` to cache.
            window_title: Title of the active window.
            app_name: Name of the active process.
            enable_ocr: Runtime flag; separates states built with/without OCR.
            enable_ui_detection: Runtime flag; separates states built
                with/without UI detection.
            workflow_id: Workflow ID, to isolate workflows from each other.
            force_refresh: Skip the lookup and recompute (the result still
                replaces the cached entry).

        Returns:
            Tuple ``(state, cache_hit, elapsed_ms)``.
        """
        t0 = time.time()
        phash = compute_perceptual_hash(screenshot_path)
        composite_key = _make_cache_key(
            phash=phash,
            window_title=window_title,
            app_name=app_name,
            enable_ocr=enable_ocr,
            enable_ui_detection=enable_ui_detection,
            workflow_id=workflow_id,
        )

        if not force_refresh:
            cached = self._get(composite_key)
            if cached is not None:
                # Counter updates are read-modify-write: do them under the
                # lock so concurrent callers cannot lose increments.
                with self._lock:
                    self.hits += 1
                elapsed_ms = (time.time() - t0) * 1000
                logger.debug(
                    f"[ScreenStateCache] HIT key={composite_key[:24]}… "
                    f"({elapsed_ms:.1f}ms)"
                )
                return cached, True, elapsed_ms

        # Cache miss → full computation (outside the lock: it can be slow).
        with self._lock:
            self.misses += 1
        state = compute_fn(screenshot_path)
        self._set(composite_key, phash, state)
        elapsed_ms = (time.time() - t0) * 1000
        logger.debug(
            f"[ScreenStateCache] MISS key={composite_key[:24]}… "
            f"({elapsed_ms:.1f}ms)"
        )
        return state, False, elapsed_ms

    def stats(self) -> dict:
        """Return a snapshot of the cache metrics and configuration."""
        with self._lock:
            total = self.hits + self.misses
            return {
                "hits": self.hits,
                "misses": self.misses,
                "invalidations": self.invalidations,
                "hit_rate": self.hits / total if total > 0 else 0.0,
                "size": len(self._store),
                "max_entries": self.max_entries,
                "ttl_seconds": self.ttl_seconds,
            }

    def __len__(self) -> int:
        """Return the number of entries currently stored (expired ones included)."""
        with self._lock:
            return len(self._store)
|
||||
@@ -137,10 +137,14 @@ class WorkflowPipeline:
|
||||
else:
|
||||
logger.warning(f"UI Detector not available: {e}")
|
||||
|
||||
# 6. Graph Builder
|
||||
# 6. Graph Builder — reçoit l'UIDetector pour enrichir les
|
||||
# ScreenStates avec ui_elements + OCR pendant _create_screen_states.
|
||||
# Sans ça, les TargetSpec ne peuvent pas être ancrés (by_role=unknown).
|
||||
self.graph_builder = GraphBuilder(
|
||||
embedding_builder=self.embedding_builder,
|
||||
faiss_manager=self.faiss_manager
|
||||
faiss_manager=self.faiss_manager,
|
||||
ui_detector=self.ui_detector,
|
||||
enable_ui_enrichment=enable_ui_detection,
|
||||
)
|
||||
logger.info("✓ Graph Builder initialized")
|
||||
|
||||
@@ -355,87 +359,177 @@ class WorkflowPipeline:
|
||||
# Mode MATCHING : Reconnaissance de l'état actuel
|
||||
# =========================================================================
|
||||
|
||||
def match_current_state(
|
||||
def match_current_state_from_state(
|
||||
self,
|
||||
screenshot_path: str,
|
||||
screen_state: ScreenState,
|
||||
workflow_id: Optional[str] = None,
|
||||
window_title: Optional[str] = None
|
||||
*,
|
||||
min_similarity: float = 0.5,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Identifier dans quel node se trouve l'écran actuel.
|
||||
Matcher un ``ScreenState`` enrichi contre les nodes d'un workflow.
|
||||
|
||||
Lot E — premier vrai matching context-aware. Cette méthode consomme
|
||||
directement le ``ScreenState`` déjà construit par ``ExecutionLoop``
|
||||
(avec ``window_title``, ``detected_text`` et ``ui_elements``
|
||||
renseignés par le ``ScreenAnalyzer``) au lieu de reconstruire un
|
||||
stub vide avec ``window_title="Unknown"``.
|
||||
|
||||
Stratégie :
|
||||
1. Si le ``HierarchicalMatcher`` est disponible ET que le workflow
|
||||
cible est chargeable, on privilégie le matching multi-niveau
|
||||
(fenêtre → région → élément) qui exploite pleinement les
|
||||
``ui_elements`` et le ``window_title``.
|
||||
2. Sinon on retombe sur le matching par embedding via FAISS
|
||||
(même logique que l'ancien ``match_current_state``, mais avec
|
||||
le ``ScreenState`` fourni, pas un stub).
|
||||
|
||||
Args:
|
||||
screenshot_path: Chemin vers le screenshot actuel
|
||||
workflow_id: ID du workflow à matcher (tous si None)
|
||||
window_title: Titre de fenêtre pour contexte
|
||||
screen_state: ``ScreenState`` complet (ui_elements + detected_text
|
||||
+ window_info) construit en amont par l'``ExecutionLoop``.
|
||||
workflow_id: ID du workflow cible (tous si None).
|
||||
min_similarity: seuil minimum de confidence pour considérer un
|
||||
match valide. Conserve la sémantique historique (0.5 pour
|
||||
le hiérarchique, 0.85 pour le FAISS fallback).
|
||||
|
||||
Returns:
|
||||
Dict avec node_id, workflow_id, confidence, ou None si pas de match
|
||||
Dict avec ``node_id``, ``workflow_id``, ``confidence`` (+ détails
|
||||
du matching hiérarchique si applicable), ou ``None`` si aucun
|
||||
match ne dépasse le seuil.
|
||||
"""
|
||||
logger.debug(f"Matching screenshot: {screenshot_path}")
|
||||
|
||||
# Créer un ScreenState temporaire
|
||||
from core.models.screen_state import (
|
||||
WindowContext, RawLevel, PerceptionLevel, ContextLevel, EmbeddingRef
|
||||
logger.debug(
|
||||
"Matching ScreenState (app=%s, title=%s, ui_elements=%d, "
|
||||
"detected_text=%d)",
|
||||
screen_state.window.app_name,
|
||||
screen_state.window.window_title,
|
||||
len(screen_state.ui_elements),
|
||||
len(screen_state.perception.detected_text),
|
||||
)
|
||||
|
||||
screenshot_path = Path(screenshot_path)
|
||||
|
||||
window = WindowContext(
|
||||
app_name="unknown",
|
||||
window_title=window_title or "Unknown",
|
||||
screen_resolution=[1920, 1080],
|
||||
workspace="main"
|
||||
# --- Stratégie 1 : matching hiérarchique si workflow disponible ---
|
||||
if workflow_id:
|
||||
workflow = self.load_workflow(workflow_id)
|
||||
if workflow is not None and getattr(workflow, "nodes", None):
|
||||
try:
|
||||
hier_result = self._match_hierarchical_from_state(
|
||||
screen_state=screen_state,
|
||||
workflow=workflow,
|
||||
workflow_id=workflow_id,
|
||||
min_similarity=min_similarity,
|
||||
)
|
||||
if hier_result is not None:
|
||||
return hier_result
|
||||
except Exception as exc:
|
||||
# Ne jamais casser le matching sur une erreur du
|
||||
# matcher hiérarchique : on retombe sur FAISS.
|
||||
logger.debug(
|
||||
f"Hierarchical matching failed, fallback FAISS: {exc}"
|
||||
)
|
||||
|
||||
raw = RawLevel(
|
||||
screenshot_path=str(screenshot_path),
|
||||
capture_method="manual",
|
||||
file_size_bytes=screenshot_path.stat().st_size if screenshot_path.exists() else 0
|
||||
# --- Stratégie 2 : fallback embedding + FAISS ---
|
||||
return self._match_via_faiss(
|
||||
screen_state=screen_state,
|
||||
workflow_id=workflow_id,
|
||||
min_similarity=min_similarity,
|
||||
)
|
||||
|
||||
perception = PerceptionLevel(
|
||||
embedding=EmbeddingRef(
|
||||
provider="openclip_ViT-B-32",
|
||||
vector_id="temp",
|
||||
dimensions=512
|
||||
),
|
||||
detected_text=[],
|
||||
text_detection_method="pending",
|
||||
confidence_avg=0.0
|
||||
def _match_hierarchical_from_state(
    self,
    screen_state: ScreenState,
    workflow: Workflow,
    workflow_id: str,
    min_similarity: float,
) -> Optional[Dict[str, Any]]:
    """
    Run the ``HierarchicalMatcher`` against an already-built ``ScreenState``.

    The window information, the UI elements and the screenshot handed to
    the matcher are all taken from ``screen_state``. The screenshot file is
    opened whenever ``screen_state.raw.screenshot_path`` is set and is
    closed again once the matcher has returned.

    Args:
        screen_state: State whose ``window``, ``ui_elements`` and
            ``raw.screenshot_path`` feed the matcher.
        workflow: Workflow object passed through to the matcher.
        workflow_id: Key of the per-workflow temporal context; also echoed
            back in the returned dict.
        min_similarity: Minimum ``confidence`` required for a match to be
            returned.

    Returns:
        A dict describing the match (``match_type`` is ``"hierarchical"``),
        or ``None`` when the matcher's confidence is below
        ``min_similarity``.

    Raises:
        Whatever ``self.hierarchical_matcher.match`` raises; nothing is
        caught around that call here.
    """
    window = screen_state.window
    # The title is exposed under both "title" and "window_title".
    window_info = {
        "title": window.window_title,
        "app_name": window.app_name,
        "window_title": window.window_title,
    }
    detected_elements = list(screen_state.ui_elements)

    # Best effort: if the screenshot cannot be opened, the matcher receives
    # None and works with the window info and elements only.
    screenshot = None
    path = screen_state.raw.screenshot_path
    if path:
        try:
            from PIL import Image
            screenshot = Image.open(path)
        except Exception as exc:
            logger.debug(f"Screenshot unavailable for hierarchical match: {exc}")

    # One temporal context per workflow, created on first use.
    if workflow_id not in self._temporal_context:
        self._temporal_context[workflow_id] = TemporalContext()
    temporal_context = self._temporal_context[workflow_id]

    try:
        result: MatchResult = self.hierarchical_matcher.match(
            screenshot=screenshot,
            workflow=workflow,
            window_info=window_info,
            detected_elements=detected_elements,
            temporal_context=temporal_context,
        )
    finally:
        # Image.open() is lazy and keeps the file open until the pixel data
        # is loaded or the image is closed. Close it explicitly so the
        # handle is released even when the matcher never touched the image.
        # NOTE(review): assumes the matcher does not keep a reference to the
        # image after match() returns — confirm.
        if screenshot is not None:
            screenshot.close()

    if result.confidence < min_similarity:
        logger.debug(
            f"Hierarchical match below threshold: {result.confidence:.3f} "
            f"(min={min_similarity})"
        )
        return None

    # Only accepted matches are recorded in the temporal context.
    temporal_context.add_match(result.node_id, result.confidence)

    return {
        "node_id": result.node_id,
        "workflow_id": workflow_id,
        "confidence": result.confidence,
        "window_confidence": result.window_confidence,
        "region_confidence": result.region_confidence,
        "element_confidence": result.element_confidence,
        "temporal_boost": result.temporal_boost,
        "matched_variant": result.matched_variant,
        "alternatives": [
            {"node_id": alt.node_id, "confidence": alt.confidence}
            for alt in result.alternatives
        ],
        "match_time_ms": result.match_time_ms,
        "match_type": "hierarchical",
    }
|
||||
|
||||
def _match_via_faiss(
|
||||
self,
|
||||
screen_state: ScreenState,
|
||||
workflow_id: Optional[str],
|
||||
min_similarity: float,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Fallback embedding + recherche FAISS. On réutilise le ``ScreenState``
|
||||
fourni (donc ses ``ui_elements`` et son ``window_title`` réels)
|
||||
au lieu d'en recréer un stub.
|
||||
"""
|
||||
# The historical FAISS threshold was 0.85. It is kept as a hard floor:
|
||||
# because of ``max()`` a caller can only raise the threshold, so a more
|
||||
# permissive ``min_similarity`` (e.g. 0.5) is ignored on this fallback path.
|
||||
threshold = max(min_similarity, 0.85)
|
||||
|
||||
state_embedding = self.embedding_builder.build(screen_state)
|
||||
query_vector = state_embedding.get_vector()
|
||||
|
||||
# Rechercher dans FAISS
|
||||
results = self.faiss_manager.search(query_vector, k=5)
|
||||
|
||||
if not results:
|
||||
logger.debug("No match found in FAISS")
|
||||
return None
|
||||
|
||||
# Filtrer par workflow si spécifié
|
||||
for result in results:
|
||||
metadata = result.get("metadata", {})
|
||||
result_workflow_id = metadata.get("workflow_id")
|
||||
@@ -444,17 +538,136 @@ class WorkflowPipeline:
|
||||
continue
|
||||
|
||||
similarity = result.get("similarity", 0)
|
||||
if similarity >= 0.85: # Seuil de matching
|
||||
if similarity >= threshold:
|
||||
return {
|
||||
"node_id": metadata.get("node_id"),
|
||||
"workflow_id": result_workflow_id,
|
||||
"confidence": similarity,
|
||||
"state_embedding_id": state_embedding.embedding_id
|
||||
"state_embedding_id": state_embedding.embedding_id,
|
||||
"match_type": "faiss",
|
||||
}
|
||||
|
||||
logger.debug(f"Best match below threshold: {results[0].get('similarity', 0):.3f}")
|
||||
logger.debug(
|
||||
f"Best FAISS match below threshold: "
|
||||
f"{results[0].get('similarity', 0):.3f} (min={threshold})"
|
||||
)
|
||||
return None
|
||||
|
||||
def match_current_state(
    self,
    screenshot_path: str,
    workflow_id: Optional[str] = None,
    window_title: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """
    Find the workflow node matching a screenshot (legacy, path-based API).

    Backward-compatibility wrapper: it builds a ``ScreenState`` with
    ``_build_screen_state_for_matching`` and hands it to
    ``match_current_state_from_state``. No ``min_similarity`` is forwarded,
    so that method's own default applies.

    Args:
        screenshot_path: Path of the current screenshot.
        workflow_id: Workflow to match against; forwarded unchanged
            (may be ``None``).
        window_title: Window title forwarded to the state builder.

    Returns:
        Whatever ``match_current_state_from_state`` returns: a match dict,
        or ``None`` when nothing matched.
    """
    logger.debug(f"Matching screenshot: {screenshot_path}")

    # Step 1: turn the bare path into a ScreenState.
    enriched_state = self._build_screen_state_for_matching(
        screenshot_path=screenshot_path,
        workflow_id=workflow_id,
        window_title=window_title,
    )

    # Step 2: delegate to the state-based matcher.
    return self.match_current_state_from_state(
        screen_state=enriched_state,
        workflow_id=workflow_id,
    )
|
||||
|
||||
def _build_screen_state_for_matching(
    self,
    screenshot_path: str,
    workflow_id: Optional[str],
    window_title: Optional[str],
) -> ScreenState:
    """
    Build the ``ScreenState`` used by the legacy ``match_current_state`` API.

    The analyzer returned by ``get_screen_analyzer()`` is tried first. If it
    is ``None``, or if importing or calling it raises, a minimal placeholder
    state is returned instead (no UI elements, no detected text).

    Args:
        screenshot_path: Path of the screenshot to describe.
        workflow_id: Stored as ``current_workflow_candidate`` in the
            placeholder state's context.
        window_title: Passed to the analyzer as a window hint when
            non-empty; used as the placeholder window title otherwise
            (``"Unknown"`` when empty or ``None``).

    Returns:
        The analyzer's result, or the placeholder ``ScreenState``.
    """
    from core.models.screen_state import (
        ContextLevel,
        EmbeddingRef,
        PerceptionLevel,
        RawLevel,
        WindowContext,
    )

    shot = Path(screenshot_path)

    # Preferred path: the shared analyzer.
    try:
        from core.pipeline import get_screen_analyzer

        shared_analyzer = get_screen_analyzer()
        if shared_analyzer is not None:
            hint = (
                {"title": window_title, "app_name": "unknown"}
                if window_title
                else None
            )
            return shared_analyzer.analyze(str(shot), window_info=hint)
    except Exception as exc:
        # Deliberate best effort: any failure falls through to the stub.
        logger.debug(
            f"ScreenAnalyzer unavailable in match_current_state wrapper: {exc}"
        )

    # Fallback path: a placeholder state with fixed default values.
    window_ctx = WindowContext(
        app_name="unknown",
        window_title=window_title or "Unknown",
        screen_resolution=[1920, 1080],
        workspace="main",
    )
    raw_level = RawLevel(
        screenshot_path=str(shot),
        capture_method="manual",
        # A missing file is reported as size 0 rather than raising.
        file_size_bytes=shot.stat().st_size if shot.exists() else 0,
    )
    placeholder_embedding = EmbeddingRef(
        provider="openclip_ViT-B-32",
        vector_id="temp",
        dimensions=512,
    )
    perception_level = PerceptionLevel(
        embedding=placeholder_embedding,
        detected_text=[],
        text_detection_method="pending",
        confidence_avg=0.0,
    )
    context_level = ContextLevel(
        current_workflow_candidate=workflow_id,
        workflow_step=None,
        user_id="matcher",
        tags=[],
        business_variables={},
    )

    state_id = f"match_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    return ScreenState(
        screen_state_id=state_id,
        timestamp=datetime.now(),
        session_id="matching",
        window=window_ctx,
        raw=raw_level,
        perception=perception_level,
        context=context_level,
        ui_elements=[],
    )
|
||||
|
||||
def match_hierarchical(
|
||||
self,
|
||||
screenshot_path: str,
|
||||
@@ -548,17 +761,56 @@ class WorkflowPipeline:
|
||||
def get_next_action(
|
||||
self,
|
||||
workflow_id: str,
|
||||
current_node_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
current_node_id: str,
|
||||
screen_state: Optional[ScreenState] = None,
|
||||
strategy: str = "best",
|
||||
source_similarity: float = 1.0,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Obtenir la prochaine action à exécuter.
|
||||
|
||||
Contrat normalisé (Lot A — avril 2026) : retourne **toujours** un
|
||||
dict avec une clé ``status`` non-ambiguë. Le ``None`` ambigu qui
|
||||
confondait "workflow terminé" et "aucun edge valide" a été
|
||||
supprimé : l'appelant (ExecutionLoop) peut désormais distinguer
|
||||
ces cas pour déclencher une pause supervisée plutôt qu'une fin
|
||||
de workflow faux-positive.
|
||||
|
||||
Sélection d'edge (C3) :
|
||||
- Filtre dur sur ``pre_conditions`` (EdgeConstraints)
|
||||
- Ranking par score composite (success_rate, target_match, recency)
|
||||
- Tiebreak : success_rate le plus haut
|
||||
|
||||
Args:
|
||||
workflow_id: ID du workflow
|
||||
current_node_id: ID du node actuel
|
||||
screen_state: État courant, requis pour évaluer les
|
||||
``pre_conditions`` et le match ``target_spec``. Si None,
|
||||
fallback sur la logique sans filtre de contraintes.
|
||||
strategy: ``"best"`` (défaut, scoring complet) ou ``"first"``
|
||||
(mode legacy, premier edge sans tri)
|
||||
source_similarity: confiance du matching (``match_current_state``)
|
||||
qui a identifié ``current_node_id``. Propagée à l'EdgeScorer
|
||||
pour activer la précondition ``min_source_similarity`` des
|
||||
edges. Défaut ``1.0`` pour compat avec les appelants qui
|
||||
ne la fournissent pas encore (Lot B — avril 2026).
|
||||
|
||||
Returns:
|
||||
Dict avec action, target_node, confidence, ou None
|
||||
Dict avec l'une des formes suivantes :
|
||||
|
||||
- ``{"status": "selected", "edge_id": str, "action": dict,
|
||||
"target_node": str, "confidence": float, "score": float}``
|
||||
→ edge sélectionné, l'ExecutionLoop doit l'exécuter.
|
||||
|
||||
- ``{"status": "terminal"}`` → le node courant n'a pas
|
||||
d'outgoing_edge (fin légitime de workflow).
|
||||
|
||||
- ``{"status": "blocked", "reason": str}`` → il existe des
|
||||
outgoing_edges mais aucun ne satisfait les conditions
|
||||
(``reason="no_valid_edge"``), ou le workflow est introuvable
|
||||
(``reason="workflow_not_found"``). L'ExecutionLoop doit
|
||||
déclencher une pause supervisée et ne **jamais** traiter
|
||||
ce cas comme un succès.
|
||||
"""
|
||||
workflow = self._workflows.get(workflow_id)
|
||||
if not workflow:
|
||||
@@ -569,23 +821,44 @@ class WorkflowPipeline:
|
||||
self._workflows[workflow_id] = workflow
|
||||
else:
|
||||
logger.error(f"Workflow not found: {workflow_id}")
|
||||
return None
|
||||
return {"status": "blocked", "reason": "workflow_not_found"}
|
||||
|
||||
# Trouver les edges sortants du node actuel
|
||||
outgoing_edges = workflow.get_outgoing_edges(current_node_id)
|
||||
|
||||
if not outgoing_edges:
|
||||
# Aucun outgoing_edge = fin légitime du workflow
|
||||
logger.info(f"No outgoing edges from node {current_node_id}")
|
||||
return None
|
||||
return {"status": "terminal"}
|
||||
|
||||
# Pour l'instant, prendre le premier edge (TODO: logique de sélection)
|
||||
edge = outgoing_edges[0]
|
||||
# Sélection robuste via EdgeScorer (C3)
|
||||
from core.pipeline.edge_scorer import EdgeScorer
|
||||
|
||||
scorer = EdgeScorer()
|
||||
edge = scorer.select_best(
|
||||
outgoing_edges,
|
||||
screen_state=screen_state,
|
||||
strategy=strategy,
|
||||
source_similarity=source_similarity,
|
||||
)
|
||||
|
||||
if edge is None:
|
||||
# Il y avait des candidats mais aucun n'a passé les filtres.
|
||||
# On NE retourne PAS "terminal" : l'ExecutionLoop doit traiter
|
||||
# ce cas comme un blocage et demander de l'aide.
|
||||
logger.warning(
|
||||
f"No valid edge from {current_node_id} "
|
||||
f"({len(outgoing_edges)} candidates rejected)"
|
||||
)
|
||||
return {"status": "blocked", "reason": "no_valid_edge"}
|
||||
|
||||
return {
|
||||
"status": "selected",
|
||||
"edge_id": edge.edge_id,
|
||||
"action": edge.action.to_dict(),
|
||||
"target_node": edge.to_node,
|
||||
"confidence": edge.stats.success_rate if edge.stats else 1.0
|
||||
"confidence": edge.stats.success_rate if edge.stats else 1.0,
|
||||
"score": edge.stats.success_rate if edge.stats else 1.0,
|
||||
}
|
||||
|
||||
def should_execute_automatically(self, workflow_id: str) -> bool:
|
||||
@@ -759,10 +1032,11 @@ class WorkflowPipeline:
|
||||
current_node_id = match_result["node_id"]
|
||||
logger.info(f"Matched current state to node: {current_node_id} (confidence: {match_result['confidence']:.3f})")
|
||||
|
||||
# 2. Obtenir la prochaine action
|
||||
# 2. Obtenir la prochaine action (contrat dict avec status explicite)
|
||||
action_info = self.get_next_action(workflow_id, current_node_id)
|
||||
action_status = action_info.get("status")
|
||||
|
||||
if not action_info:
|
||||
if action_status == "terminal":
|
||||
return {
|
||||
"execution_id": execution_id,
|
||||
"workflow_id": workflow_id,
|
||||
@@ -771,7 +1045,19 @@ class WorkflowPipeline:
|
||||
"message": "Workflow completed - no more actions",
|
||||
"current_node": current_node_id,
|
||||
"execution_time_ms": (datetime.now() - start_time).total_seconds() * 1000,
|
||||
"correlation_id": execution_id
|
||||
"correlation_id": execution_id,
|
||||
}
|
||||
|
||||
if action_status == "blocked":
|
||||
return {
|
||||
"execution_id": execution_id,
|
||||
"workflow_id": workflow_id,
|
||||
"success": False,
|
||||
"step_type": "action_selection",
|
||||
"error": f"No valid edge: {action_info.get('reason', 'unknown')}",
|
||||
"current_node": current_node_id,
|
||||
"execution_time_ms": (datetime.now() - start_time).total_seconds() * 1000,
|
||||
"correlation_id": execution_id,
|
||||
}
|
||||
|
||||
logger.info(f"Next action: {action_info['action']['type']} -> {action_info['target_node']}")
|
||||
|
||||
@@ -1,327 +0,0 @@
|
||||
e)a, field_namg(datin_loggsanitize_fordator.valieturn r()
|
||||
or_validatet_inputalidator = g""
|
||||
v
|
||||
"iséesnées sanit Don
|
||||
Returns:
|
||||
amp
|
||||
chNom du ame: field_ntiser
|
||||
s à saniata: Donnée d
|
||||
|
||||
Args:ging.
|
||||
le loges pours donnéSanitise de """
|
||||
-> str:
|
||||
"data") me: str = nay, field_ta: An(da_loggingize_for sanita
|
||||
|
||||
|
||||
defarsed_dat return p
|
||||
")
|
||||
errors)}t.uljoin(res {'; '.ed:ion failalidator(f"JSON vlidationErrise InputVa ralid:
|
||||
is_vat.not resul if
|
||||
")
|
||||
"json_datafield_name=e, th=max_sizr, max_lengring(json_stalidate_stvalidator.vt =
|
||||
resuldata)s(parsed_on.dump = js json_strtor()
|
||||
put_validaet_in gidator =s
|
||||
vales injectionur lontenu poider le c
|
||||
# Valt")
|
||||
dicng orbe strimust N data "JSOionError(putValidat raise In se:
|
||||
|
||||
elson_data_data = jparsed")
|
||||
size}max_ze of { maximum siexceedsN data rror(f"JSOValidationEaise Input r_size:
|
||||
lized) > maxlen(seria if a)
|
||||
s(json_dat json.dumpalized =eri sialisée
|
||||
ére sla taillrifier # Véct):
|
||||
ata, di_de(jsonncsinsta elif i
|
||||
t: {e}") JSON formaidror(f"InvalErdationalise InputV raie:
|
||||
ror as JSONDecodeErt json. excep n_data)
|
||||
loads(jsojson.= d_data parse
|
||||
try:
|
||||
size}")
|
||||
{max_mum size of axiceeds m data exONor(f"JSrrtionEputValidaise In ra
|
||||
max_size:a) >(json_datf len i
|
||||
data, str):json_isinstance( if ""
|
||||
" invalides
|
||||
sont ess donnéSi letionError: InputValida s:
|
||||
Raise
|
||||
|
||||
ON validéess JS Donnéeurns:
|
||||
|
||||
Ret s
|
||||
n caractèremale exille maax_size: Tai mou dict)
|
||||
string nnées JSON (: Do_data json
|
||||
|
||||
Args: .
|
||||
nnées JSONdo Valide des "
|
||||
|
||||
"") -> dict:= 10000x_size: int t], man[str, dicnion_data: Uput(jsoe_json_inalidat
|
||||
|
||||
|
||||
def ved_pathurn normaliz ret
|
||||
|
||||
")ath}malized_pories: {norwed directllon apath not ior(f"File ionErratlide InputVa rais ):
|
||||
rslowed_di_dir in al for allowedr)d_diallowe.startswith(_obj)str(pathot any( if n)
|
||||
alized_pathPath(normpath_obj = :
|
||||
_dirsif allowed
|
||||
i spécifiésautorisés soires répertrifier lesVé
|
||||
# ")
|
||||
xt}n: {file_extensio engerous filer(f"DaolationErroyVi Securit raisensions:
|
||||
xtegerous_ext in danf file_e()
|
||||
ix.lowerath).suffied_pnormalizxt = Path( file_e p', '.sh'}
|
||||
.ph', ' '.jscr', '.vbs', '.s, '.cmd',xe', '.bat'{'.ensions = ngerous_exte dauses
|
||||
angereons densies exter l Vérifi
|
||||
#_path}")
|
||||
{file detected:attemptl raversa t"Pathrror(fationEyViol Securitise ra"/"):
|
||||
ith(path.startswd_or normalizelized_path in norma ".." ifl
|
||||
rsaraveh tives de patntat les teVérifier # )
|
||||
|
||||
_pathle.normpath(fih = os.pathpatrmalized_ noin
|
||||
ser le chem# Normali
|
||||
ng")
|
||||
t be a strile path mus"Fir(dationErroalise InputV raitr):
|
||||
th, se_pailsinstance(ft i if no
|
||||
"""
|
||||
ngereux dae chemin estError: Si lionnputValidat I
|
||||
aises:
|
||||
R
|
||||
sénormalit min validé e Che
|
||||
Returns:
|
||||
|
||||
orisésutres ars: Répertoilowed_di al valider
|
||||
n àhemie_path: C filgs:
|
||||
Ar
|
||||
chier.
|
||||
hemin de fialide un c V"
|
||||
" ":
|
||||
trne) -> s No] =str]List[ional[rs: Optwed_di: str, allole_pathath_input(fifile_plidate_vae
|
||||
|
||||
|
||||
def ized_valuresult.sanitreturn
|
||||
|
||||
.errors)}").join(resulte}: {'; 'field_named for {dation failf"ValinError(idatio InputValserai is_valid:
|
||||
t.ul not res
|
||||
if_name)
|
||||
_html, fieldength, allow, max_lring(valuealidate_stidator.vval = resultor()
|
||||
idatt_input_valator = ge"
|
||||
valid""ue
|
||||
échotionlidai la vaor: SdationErrnputVali Is:
|
||||
se
|
||||
Rai
|
||||
nitisée sa Valeureturns:
|
||||
R
|
||||
p
|
||||
du chamm d_name: No fiel HTML
|
||||
oriser leow_html: Aut all ximale
|
||||
Longueur mamax_length: r
|
||||
r à valideue: Valeu val Args:
|
||||
|
||||
|
||||
ée string.e une entranitisalide et s
|
||||
V"""r:
|
||||
t") -> st= "inpue: str e, field_namalsool = Fw_html: b allo
|
||||
1000, ength: int =max_lvalue: str, ut(ing_inpvalidate_str
|
||||
|
||||
|
||||
def r_instancern _validato)
|
||||
retudator(alie = InputVancinstalidator_ _v one:
|
||||
tance is Nor_insf _validat
|
||||
itancer_insal _validatolob"
|
||||
g""r
|
||||
alidateuu vstance d Inturns:
|
||||
Re
|
||||
r.
|
||||
teuida du valobaleinstance glourne l' Ret""
|
||||
"or:
|
||||
lidatputVa-> Inr() dato_valit_inputef geNone
|
||||
|
||||
|
||||
d= ] putValidatoronal[Inance: Optilidator_instidateur
|
||||
_va du val globalencesta
|
||||
# In )
|
||||
|
||||
}"
|
||||
_valuezedue: {saniti f"Val . "
|
||||
field_name}ype} in {ation_tvioltected: {iolation dey vf"Securit rning(
|
||||
ger.wa logame)
|
||||
e, field_ng(valuor_logginf.sanitize_f selalue =tized_v sani""
|
||||
té."ride sécuion violatg une Lo """:
|
||||
ny) -> Nonevalue: A_name: str, ldier, fn_type: stolatioon(self, viati_violitylog_secur _
|
||||
def _}]"
|
||||
e_(data).__namntable:{typeme}[unpri{field_nareturn f"
|
||||
ion:cept Except ex
|
||||
ata_str
|
||||
turn d re
|
||||
tr)
|
||||
scape(data_s html.e data_str =
|
||||
dangereuxres es caractèhapper l # Éc
|
||||
."
|
||||
"..r[:200] + ata_stata_str = d d
|
||||
0:r) > 20ata_st if len(d s
|
||||
our les log taille pr la # Limite
|
||||
|
||||
ta)r(dastr = st data_ else:
|
||||
|
||||
, ':')),'s=('eparatore, s_ascii=Trunsurea, e(dat.dumps json = data_str
|
||||
ct, list)): (dia,nstance(datsi if i
|
||||
try:le
|
||||
aila tter lg et limi en strinonvertir # C
|
||||
]"
|
||||
{len(data)}_}:size=a).__name_(dattypeme}[{{field_naturn f" re :
|
||||
))istta, (dict, ltance(daisinsif el )}]"
|
||||
lue(datave_vasensitish:{hash_e}[haield_namf"{f return
|
||||
> 20:d len(data)str) ane(data, sinstanc if is
|
||||
ensiblenées ss donhasher lerisé, En mode sécu # itive:
|
||||
ensself.log_s not if ""
|
||||
|
||||
"r logging pouestisénées saniDon
|
||||
Returns:
|
||||
|
||||
pom du chameld_name: N fi er
|
||||
itis sanes àata: Donné d gs:
|
||||
Ar
|
||||
sécurisé.
|
||||
le logging pouronnéess dnitise de Sa ""
|
||||
" ) -> str:
|
||||
ata"tr = "dd_name: sy, fiel: Anlf, dataging(seogze_for_lef saniti
|
||||
dngs)
|
||||
ors, warninitized, err sa_valid,ult(isationReslid return Va
|
||||
s) == 0error= len(valid is_
|
||||
itized)
|
||||
, san7F]', ''\x1F\x0C\x0E-\x0B8\x0-\x0r'[\x0e.sub(= r sanitized ôle
|
||||
ntrctères de cocaraoyer les # Nett
|
||||
|
||||
anitized).escape(s = html sanitized :
|
||||
allow_html if not ire
|
||||
si nécessatizer HTML# Sani
|
||||
)
|
||||
"SQL patternspicious Noains suntld_name} cofiepend(f"{ngs.ap warni else:
|
||||
|
||||
value)e,nam", field_ attemptionjectQL inlation("NoSecurity_vioog_s._l self ")
|
||||
ernection pattl NoSQL injs potentiae} containd_nam{fiel(f"penderrors.ap
|
||||
_mode:lf.strictse if lue):
|
||||
(vaern.searchif patt ns:
|
||||
atterf._nosql_prn in selte for patSQL
|
||||
njections Nofier les i # Véri
|
||||
")
|
||||
QL pattern Suspiciousontains seld_name} c{fiappend(f"arnings. w:
|
||||
else e)
|
||||
, valu_nameeld, fipt"ection attem"SQL injiolation(security_vg_loself._ )
|
||||
on pattern"L injectiotential SQontains p_name} c"{fieldppend(f.aors err e:
|
||||
.strict_modself if alue):
|
||||
rn.search(vatteif p patterns:
|
||||
sql_f._eln spattern i for ons SQL
|
||||
tir les injecVérifie #
|
||||
|
||||
x_length] value[:matized = sani ers")
|
||||
th} charact{max_lengcated to _name} trunf"{fieldend(s.app warning else:
|
||||
|
||||
}")ax_length{mf length oimum eeds maxe} exc"{field_nam(fpend errors.ap ct_mode:
|
||||
f self.stri ih:
|
||||
lengtalue) > max_ if len(vueur
|
||||
longVérifier la
|
||||
# s)
|
||||
ors, warningne, errt(False, NoonResulidati return Val tring")
|
||||
t be a smusd_name} f"{fielrs.append( erro
|
||||
, str):ce(valueisinstan if not
|
||||
ue
|
||||
d = valanitize sgs = []
|
||||
nin war
|
||||
errors = []"
|
||||
"" alidation
|
||||
vt de Résulta eturns:
|
||||
R
|
||||
s
|
||||
our les logdu champ pNom : ld_name fie HTML
|
||||
toriser le w_html: Au allo e
|
||||
aximalgueur mh: Lonengt max_lder
|
||||
valiue: Valeur à val:
|
||||
Args
|
||||
.
|
||||
tèresde carac chaîne Valide une"
|
||||
"" lt:
|
||||
esuValidationRput") -> : str = "infield_name= False, tml: bool allow_h ,
|
||||
000h: int = 1 max_lengtstr,f, value: (selring validate_st def
|
||||
ERNS]
|
||||
TTN_PAJECTIOlf.NOSQL_INttern in seor paE) fCASe.IGNOREttern, re(pa.compil= [rerns patteself._nosql_ RNS]
|
||||
TE_PATL_INJECTION in self.SQfor patternNORECASE) re.IGtern,compile(pate. = [rerns_sql_pattf. selformance
|
||||
pour pers patterns lepiler # Com
|
||||
ata
|
||||
ive_d.log_sensitive = configsit_sen self.log
|
||||
ationinput_valid.strict_se configels not None _mode istrictct_mode if striict_mode = self.str nfig()
|
||||
security_coig = get_ conf""
|
||||
"g)
|
||||
selon confi auto (None =strictde: Mode strict_mo
|
||||
Args:
|
||||
|
||||
ur.datese le vali Initiali """
|
||||
:
|
||||
one)l] = N[boo: Optionalt_mode stric_(self,it_def __in
|
||||
]
|
||||
)"
|
||||
\.|db\.is r"(th
|
||||
\})",\s*\$.* r"(\{
|
||||
meout\b)",etTil\b|\bs\(|\bevaction\s*"(funr nin)",
|
||||
in|\$gt|\$lt|\$\$e|\$regex|\$n"(\$where| r [
|
||||
TTERNS =CTION_PAL_INJEOSQ N n NoSQL
|
||||
ctiour injengereux poatterns da # P]
|
||||
|
||||
"
|
||||
b)\qlbsp_executes"(\
|
||||
r",dshell\b)bxp_cm r"(\
|
||||
)",[\'\";]r"( )\b)",
|
||||
ONERRORAD|T|ONLOBSCRIP|VIPTAVASCRSCRIPT|J(\b( r" */)",
|
||||
--|#|/\*|\ r"( ",
|
||||
+)s*=\s*\d\AND)\s+\d+(UNION|OR|\b r"(
|
||||
b)",\UTE)EXEC|EXECE|ALTER|OP|CREATDRELETE|ERT|UPDATE|Db(SELECT|INS r"(\
|
||||
RNS = [N_PATTE_INJECTIOSQL
|
||||
SQLnjection ereux pour irns dangtte# Pa
|
||||
|
||||
""teur."s utilisaeur d'entréeidatVal"" "ator:
|
||||
Valids Inputclas
|
||||
|
||||
pass
|
||||
""
|
||||
ée."tectécurité déolation de s"Vi"" Error):
|
||||
tValidationnError(InpuyViolatioSecurit
|
||||
|
||||
class pass
|
||||
"
|
||||
rée.""nton d'ealidatieur de v""Err "
|
||||
ion):r(ExceptidationErroputValass In= []
|
||||
|
||||
|
||||
clf.warnings sel:
|
||||
None isarnings self.w ifors = []
|
||||
elf.err sne:
|
||||
is Nororser if self.
|
||||
lf):init__(seost_def __p
|
||||
r]
|
||||
[sts: Listningwar[str]
|
||||
istrs: L erroue: Any
|
||||
ed_val sanitiz: bool
|
||||
lid
|
||||
is_va"""
|
||||
une entrée.dation d' de valitat"Résul""lt:
|
||||
ationResuclass Validaclass
|
||||
dat
|
||||
|
||||
@_)
|
||||
ame_etLogger(__ngging.g
|
||||
logger = lolue
|
||||
ive_vaash_sensitonfig, h_cecurityimport get_srity_config .secu
|
||||
|
||||
from dataclassrtpoimdataclasses
|
||||
from Union, SetOptional,, List, Any, Dict import ng
|
||||
from typirt Pathimpoib thlfrom pajson
|
||||
|
||||
import l htmortlogging
|
||||
impe
|
||||
import port r
|
||||
imrt ospo"
|
||||
|
||||
im"ggées
|
||||
"données loization des 7.4: Sanit
|
||||
Exigence s chiers de fin des chemintioalida3: VExigence 7.
|
||||
SQL/NoSQLonsti injeccontre lesion ectotence 7.2: PrExigé.
|
||||
a sécuritur lteur polisatrées utiion des envalidat
|
||||
Système de m
|
||||
stedation Syut Vali"""
|
||||
Inp
|
||||
308
core/security/signed_serializer.py
Normal file
308
core/security/signed_serializer.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
Sérialiseur signé — RPA Vision V3
|
||||
|
||||
Remplace les usages de `pickle.load` (vulnérables à la désérialisation arbitraire
|
||||
de code) par une sérialisation JSON signée via HMAC-SHA256.
|
||||
|
||||
Principes :
|
||||
- Les données sont sérialisées en JSON (avec support des types numpy / datetime
|
||||
via un encodeur custom).
|
||||
- Une signature HMAC-SHA256 est calculée sur le JSON avec une clé secrète
|
||||
dérivée de `RPA_SIGNING_KEY` (ou, à défaut, de `TOKEN_SECRET_KEY`).
|
||||
- À la lecture, la signature est vérifiée AVANT tout parsing applicatif.
|
||||
- Rétrocompatibilité : un fallback `pickle.load` est disponible pour migrer
|
||||
les anciens fichiers. Il logue un WARNING et doit être suivi d'une
|
||||
ré-écriture en JSON signé.
|
||||
|
||||
ATTENTION : n'utiliser le fallback pickle que sur des fichiers dont la source
|
||||
est réputée sûre (locale + protégée). Le fallback est désactivable via la
|
||||
variable d'environnement `RPA_ALLOW_PICKLE_FALLBACK=0`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Clé de signature
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
_SIGNATURE_ALGO = "sha256"
|
||||
_SIGNATURE_HEADER = b"RPA_SIGNED_V1\n" # Marqueur de format signé
|
||||
|
||||
|
||||
def _resolve_signing_key() -> bytes:
    """Return the HMAC signing key as bytes.

    Lookup order:
      1. ``RPA_SIGNING_KEY`` environment variable.
      2. ``TOKEN_SECRET_KEY`` environment variable.
      3. A development key derived from the host name; a WARNING is logged.

    Environment values are stripped of surrounding whitespace, and a value
    that is empty after stripping counts as unset.

    The development key is the SHA-256 digest of a fixed string containing
    the host name. It is stable on a given host and differs between hosts
    with different names. It is NOT a secret: anyone who knows the host
    name can recompute it.

    Returns:
        The key bytes (UTF-8 encoded environment value, or the 32-byte
        development digest).
    """
    for env_name in ("RPA_SIGNING_KEY", "TOKEN_SECRET_KEY"):
        configured = os.getenv(env_name, "").strip()
        if configured:
            return configured.encode("utf-8")

    # Derived dev key: not cryptographically safe, it only keeps local
    # development working without configuration. Logged explicitly.
    logger.warning(
        "RPA_SIGNING_KEY et TOKEN_SECRET_KEY non définis — "
        "utilisation d'une clé dérivée locale. "
        "Définir RPA_SIGNING_KEY en production."
    )

    # os.uname() only exists on Unix. Keep using it where available so
    # existing dev keys stay unchanged, and fall back to platform.node()
    # elsewhere (e.g. Windows) instead of raising AttributeError.
    uname = getattr(os, "uname", None)
    if uname is not None:
        nodename = uname().nodename
    else:
        import platform

        nodename = platform.node()

    seed = f"rpa-vision-v3::{nodename}::dev-signing"
    return hashlib.sha256(seed.encode("utf-8")).digest()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Encodage JSON étendu (numpy, datetime, Path, bytes)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
class _RPAJSONEncoder(json.JSONEncoder):
    """JSON encoder that also handles NumPy values, dates, paths, bytes and sets.

    Values with no native JSON form are written as a dict carrying a
    ``"__type__"`` tag plus the fields needed to rebuild them:

    * ``np.ndarray`` -> dtype string, shape list and base64 of ``tobytes()``
    * ``datetime``   -> ISO-8601 string
    * ``timedelta``  -> total seconds
    * ``Path``       -> string form of the path
    * ``bytes``      -> base64 string
    * ``set``        -> list of items

    NumPy integer, floating and boolean scalars are converted to the matching
    built-in Python scalar.
    """

    def default(self, obj: Any) -> Any:  # noqa: D401 - standard json API
        """Return a JSON-serializable stand-in for ``obj``.

        Raises:
            TypeError: via ``json.JSONEncoder.default`` for unsupported types.
        """
        if isinstance(obj, np.ndarray):
            return {
                "__type__": "ndarray",
                "dtype": str(obj.dtype),
                "shape": list(obj.shape),
                "data": base64.b64encode(obj.tobytes()).decode("ascii"),
            }
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, datetime):
            return {"__type__": "datetime", "iso": obj.isoformat()}
        if isinstance(obj, timedelta):
            return {"__type__": "timedelta", "seconds": obj.total_seconds()}
        if isinstance(obj, Path):
            return {"__type__": "path", "value": str(obj)}
        if isinstance(obj, bytes):
            return {
                "__type__": "bytes",
                "data": base64.b64encode(obj).decode("ascii"),
            }
        if isinstance(obj, set):
            # Set iteration order is not stable across interpreter runs (string
            # hashing is randomized), which would make the serialized bytes of
            # the same object vary. Sort when the items are mutually orderable;
            # otherwise keep the iteration order.
            try:
                items = sorted(obj)
            except TypeError:
                items = list(obj)
            return {"__type__": "set", "items": items}
        return super().default(obj)
|
||||
|
||||
|
||||
def _lists_to_tuples(value: Any) -> Any:
    """Recursively turn lists into tuples so the value can be a set member."""
    if isinstance(value, list):
        return tuple(_lists_to_tuples(item) for item in value)
    return value


def _json_object_hook(obj: Any) -> Any:
    """Rebuild tagged extended types while decoding JSON.

    Intended as the ``object_hook`` of ``json.loads``. A dict carrying a
    recognized ``"__type__"`` tag (``ndarray``, ``datetime``, ``timedelta``,
    ``path``, ``bytes``, ``set``) is replaced by the corresponding Python
    object. Anything else (non-dict values, untagged dicts, unknown tags) is
    returned unchanged.

    Args:
        obj: The value produced by the JSON decoder.

    Returns:
        The reconstructed object, or ``obj`` itself when no tag applies.
    """
    if not isinstance(obj, dict):
        return obj
    tag = obj.get("__type__")
    if tag is None:
        return obj
    if tag == "ndarray":
        raw = base64.b64decode(obj["data"])
        arr = np.frombuffer(raw, dtype=np.dtype(obj["dtype"]))
        # frombuffer over bytes gives a read-only view; copy to return an
        # owned, writable array.
        return arr.reshape(obj["shape"]).copy()
    if tag == "datetime":
        return datetime.fromisoformat(obj["iso"])
    if tag == "timedelta":
        return timedelta(seconds=float(obj["seconds"]))
    if tag == "path":
        return Path(obj["value"])
    if tag == "bytes":
        return base64.b64decode(obj["data"])
    if tag == "set":
        # JSON has no tuple type: tuples that were set members come back as
        # lists, which are unhashable. Convert them back to tuples so the set
        # can be rebuilt instead of raising TypeError.
        return {_lists_to_tuples(item) for item in obj.get("items", [])}
    return obj
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Erreurs dédiées
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
class SignedSerializerError(Exception):
    """Base class for every error raised by this module."""
|
||||
|
||||
|
||||
class SignatureVerificationError(SignedSerializerError):
    """The HMAC signature does not match: the file was altered or forged."""
|
||||
|
||||
|
||||
class UnsupportedFormatError(SignedSerializerError):
    """The file is neither in the signed format nor accepted as legacy pickle."""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# API publique
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
def _compute_hmac(payload: bytes, key: bytes) -> str:
    """Return the hex-encoded HMAC-SHA256 of ``payload`` computed with ``key``."""
    mac = hmac.new(key, msg=payload, digestmod=hashlib.sha256)
    return mac.hexdigest()
|
||||
|
||||
|
||||
def dumps_signed(data: Any, key: Optional[bytes] = None) -> bytes:
    """Serialize ``data`` to HMAC-SHA256-signed JSON.

    Returned binary layout::

        b"RPA_SIGNED_V1\\n" + utf8(json({"hmac": "<hex>", "payload_b64": "<base64>"}))

    ``payload_b64`` is the base64 of the canonical JSON encoding of ``data``
    (sorted keys, compact separators, non-ASCII characters kept as-is, UTF-8).
    The HMAC is computed over those exact canonical bytes, and they are
    embedded verbatim so a reader can verify the signature without having to
    re-serialize anything.

    Args:
        data: Object to serialize; extended types are handled by
            ``_RPAJSONEncoder``.
        key: Explicit HMAC key. When ``None`` the key is obtained from
            ``_resolve_signing_key()``.

    Returns:
        The signed blob, header included.
    """
    # Resolve the key before serializing, as the original ordering does.
    if key is None:
        hmac_key = _resolve_signing_key()
    else:
        hmac_key = key

    canonical_text = json.dumps(
        data,
        cls=_RPAJSONEncoder,
        sort_keys=True,
        separators=(",", ":"),
        ensure_ascii=False,
    )
    canonical = canonical_text.encode("utf-8")

    envelope = {
        "hmac": _compute_hmac(canonical, hmac_key),
        "payload_b64": base64.b64encode(canonical).decode("ascii"),
    }
    envelope_text = json.dumps(envelope, separators=(",", ":"), ensure_ascii=False)
    return _SIGNATURE_HEADER + envelope_text.encode("utf-8")
|
||||
|
||||
|
||||
def loads_signed(raw: bytes, key: Optional[bytes] = None) -> Any:
    """Deserialize a blob produced by ``dumps_signed`` after HMAC verification.

    Args:
        raw: Full blob, including the ``RPA_SIGNED_V1`` header line.
        key: Explicit HMAC key. When ``None`` the key is obtained from
            ``_resolve_signing_key()``.

    Returns:
        The decoded payload, with extended types rebuilt by
        ``_json_object_hook``.

    Raises:
        UnsupportedFormatError: ``raw`` does not start with the signed header.
        SignedSerializerError: the envelope is not valid UTF-8 JSON, is not a
            dict, lacks string ``hmac`` / ``payload_b64`` fields, or the
            payload is not valid base64.
        SignatureVerificationError: the HMAC does not match the payload.
    """
    if not raw.startswith(_SIGNATURE_HEADER):
        raise UnsupportedFormatError("Marqueur RPA_SIGNED_V1 absent.")
    signing_key = key if key is not None else _resolve_signing_key()
    body = raw[len(_SIGNATURE_HEADER):]
    try:
        envelope = json.loads(body.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
        raise SignedSerializerError(f"Enveloppe JSON invalide : {exc}") from exc

    if not isinstance(envelope, dict):
        raise SignedSerializerError("Enveloppe inattendue.")
    signature = envelope.get("hmac")
    payload_b64 = envelope.get("payload_b64")
    if not isinstance(signature, str) or not isinstance(payload_b64, str):
        raise SignedSerializerError("Enveloppe mal formée (hmac / payload_b64).")

    try:
        payload_bytes = base64.b64decode(payload_b64.encode("ascii"), validate=True)
    except ValueError as exc:
        # Covers both binascii.Error (bad base64) and UnicodeEncodeError
        # (non-ASCII input); both are ValueError subclasses.
        raise SignedSerializerError(f"Payload base64 invalide : {exc}") from exc

    expected = _compute_hmac(payload_bytes, signing_key)
    # Compare as bytes: hmac.compare_digest raises TypeError on str arguments
    # containing non-ASCII characters, and `signature` comes from the untrusted
    # file. "surrogatepass" keeps the encode from failing on lone surrogates
    # that JSON escapes can produce; such a value can never equal a hex digest.
    if not hmac.compare_digest(
        expected.encode("ascii"),
        signature.encode("utf-8", "surrogatepass"),
    ):
        raise SignatureVerificationError(
            "Signature HMAC invalide — fichier altéré ou clé différente."
        )

    # The payload is only parsed once its signature has been verified.
    return json.loads(payload_bytes.decode("utf-8"), object_hook=_json_object_hook)
|
||||
|
||||
|
||||
def _pickle_fallback_allowed() -> bool:
    """Tell whether the environment permits the legacy pickle fallback.

    The fallback is allowed unless ``RPA_ALLOW_PICKLE_FALLBACK`` is exactly
    the string ``"0"``; an unset variable counts as allowed.
    """
    setting = os.environ.get("RPA_ALLOW_PICKLE_FALLBACK", "1")
    return setting != "0"
|
||||
|
||||
|
||||
def save_signed(path: Union[str, Path], data: Any, key: Optional[bytes] = None) -> None:
    """Write ``data`` to ``path`` in the signed JSON format.

    Parent directories are created as needed. The blob is first written to a
    sibling file named ``<path>.tmp`` and then moved over ``path`` with
    ``os.replace``, so ``path`` is only touched once the blob is fully written.

    Args:
        path: Destination file.
        data: Object to serialize (see ``dumps_signed``).
        key: Explicit HMAC key, forwarded to ``dumps_signed``.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    blob = dumps_signed(data, key=key)
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        with open(tmp, "wb") as fp:
            fp.write(blob)
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a stale/partial temporary file behind when the write or
        # the rename fails; the original error is always re-raised.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise
|
||||
|
||||
|
||||
def load_signed(
    path: Union[str, Path],
    *,
    allow_pickle_fallback: bool = True,
    migrate_on_fallback: bool = True,
    pickle_loader: Optional[Callable[[io.BufferedReader], Any]] = None,
    key: Optional[bytes] = None,
) -> Any:
    """Load a file written by ``save_signed``, with optional legacy fallback.

    A file starting with the signed header is verified and decoded by
    ``loads_signed``. Any other file is treated as a legacy pickle, provided
    both ``allow_pickle_fallback`` is true and the
    ``RPA_ALLOW_PICKLE_FALLBACK`` environment variable is not ``"0"``. In that
    case a WARNING is logged, the file is unpickled and, when
    ``migrate_on_fallback`` is true, rewritten in the signed format. A failed
    migration is logged as an error and does not prevent the data from being
    returned.

    SECURITY: when no ``pickle_loader`` is supplied, the legacy file is handed
    to ``pickle.load`` with no restriction, and unpickling can execute
    arbitrary code. Pass a restricted unpickler, or disable the fallback, for
    files that are not fully trusted.

    Args:
        path: File to read.
        allow_pickle_fallback: Enable the legacy pickle compatibility path.
        migrate_on_fallback: Rewrite the file as signed JSON after a fallback.
        pickle_loader: Alternative callable applied to the open binary file
            (for tests or a restricted unpickler).
        key: Explicit HMAC key (otherwise resolved from the environment).

    Raises:
        SignatureVerificationError: Invalid HMAC (file altered).
        UnsupportedFormatError: Unknown format and fallback disabled.
    """
    path = Path(path)
    with open(path, "rb") as handle:
        raw = handle.read()

    if raw.startswith(_SIGNATURE_HEADER):
        return loads_signed(raw, key=key)

    # The environment is only consulted when the caller allows the fallback.
    fallback_enabled = allow_pickle_fallback and _pickle_fallback_allowed()
    if not fallback_enabled:
        raise UnsupportedFormatError(
            f"{path} n'est pas au format signé et le fallback pickle est désactivé."
        )

    logger.warning(
        "Chargement legacy pickle pour %s — ce format est obsolète et "
        "sera ré-écrit en JSON signé. Voir docs/SECURITY.md.",
        path,
    )

    # A caller-supplied loader (e.g. a restricted Unpickler) takes precedence;
    # otherwise plain, unrestricted pickle.load is used.
    if pickle_loader:
        loader = pickle_loader
    else:
        loader = pickle.load  # noqa: S301 - legacy usage
    with open(path, "rb") as handle:
        data = loader(handle)

    if not migrate_on_fallback:
        return data

    try:
        save_signed(path, data, key=key)
        logger.info("Fichier %s migré en JSON signé.", path)
    except Exception as exc:  # noqa: BLE001 - migration is best-effort
        logger.error(
            "Migration JSON signé échouée pour %s : %s", path, exc
        )
    return data
|
||||
|
||||
|
||||
# Public API of the module: the error hierarchy plus the in-memory
# (dumps/loads) and on-disk (save/load) entry points.
__all__ = [
    "SignedSerializerError",
    "SignatureVerificationError",
    "UnsupportedFormatError",
    "dumps_signed",
    "loads_signed",
    "save_signed",
    "load_signed",
]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user