Initial commit
This commit is contained in:
334
geniusia2/core/LEARNING_MANAGER_README.md
Normal file
334
geniusia2/core/LEARNING_MANAGER_README.md
Normal file
@@ -0,0 +1,334 @@
|
||||
# Gestionnaire d'Apprentissage - Documentation d'Implémentation
|
||||
|
||||
## Vue d'Ensemble
|
||||
|
||||
Le module `learning_manager.py` implémente le gestionnaire d'apprentissage central pour RPA Vision V2. Il gère la progression d'apprentissage, les transitions de mode et le calcul des scores de confiance.
|
||||
|
||||
## Fonctionnalités Implémentées
|
||||
|
||||
### 1. Classe LearningManager
|
||||
|
||||
#### Initialisation
|
||||
- ✅ Chargement de la configuration (seuils, pondérations)
|
||||
- ✅ Initialisation en mode Shadow par défaut
|
||||
- ✅ Chargement automatique des profils de tâches existants
|
||||
- ✅ Intégration avec EmbeddingsManager et Logger
|
||||
|
||||
#### Gestion des Profils de Tâches
|
||||
|
||||
##### `_load_profiles()`
|
||||
- ✅ Charge tous les profils JSON depuis le répertoire
|
||||
- ✅ Reconstruit les objets TaskProfile
|
||||
- ✅ Gestion d'erreurs robuste
|
||||
|
||||
##### `_save_profile(task_id)`
|
||||
- ✅ Sauvegarde un profil de tâche en JSON
|
||||
- ✅ Encodage UTF-8 pour support multilingue
|
||||
- ✅ Logging des opérations
|
||||
|
||||
#### Observation et Apprentissage
|
||||
|
||||
##### `observe(action)`
|
||||
- ✅ Enregistre une observation en mode Shadow
|
||||
- ✅ Crée automatiquement un profil de tâche si nécessaire
|
||||
- ✅ Génère un task_id basé sur fenêtre + élément + action
|
||||
- ✅ Ajoute l'action à la séquence
|
||||
- ✅ Stocke l'embedding dans FAISS
|
||||
- ✅ Déclenche transition Shadow → Assist après 5 observations
|
||||
|
||||
##### `_generate_task_id(action)`
|
||||
- ✅ Génère un ID unique basé sur le contexte
|
||||
- ✅ Format: `{fenêtre}_{élément}_{type_action}`
|
||||
- ✅ Nettoyage des espaces et conversion en minuscules
|
||||
|
||||
#### Suggestions et Prédictions
|
||||
|
||||
##### `suggest_action(context)`
|
||||
- ✅ Génère des suggestions en mode Assisté ou Auto
|
||||
- ✅ Recherche de similarité dans l'index FAISS
|
||||
- ✅ Trouve l'action la plus similaire dans l'historique
|
||||
- ✅ Calcule la confiance combinée (vision + LLM + historique)
|
||||
- ✅ Retourne None si pas de correspondance
|
||||
|
||||
#### Validation et Corrections
|
||||
|
||||
##### `confirm_action(feedback)`
|
||||
- ✅ Traite les retours utilisateur (accept/reject/correct)
|
||||
- ✅ Met à jour le taux de concordance
|
||||
- ✅ Gère les corrections avec ajout d'embeddings
|
||||
- ✅ Déclenche les vérifications de transition de mode
|
||||
- ✅ Sauvegarde automatique du profil
|
||||
|
||||
##### `_update_concordance(task_id, success)`
|
||||
- ✅ Calcul de concordance sur fenêtre glissante (10 dernières exécutions)
|
||||
- ✅ Stockage dans metadata pour persistance
|
||||
- ✅ Moyenne mobile pour adaptation progressive
|
||||
|
||||
#### Calcul de Confiance
|
||||
|
||||
##### `calculate_confidence(vision_conf, llm_score, task_id)`
|
||||
- ✅ Formule pondérée : 0.6 × vision + 0.3 × llm + 0.1 × historique
|
||||
- ✅ Normalisation entre 0.0 et 1.0
|
||||
- ✅ Intégration de la performance historique
|
||||
|
||||
##### `_get_historical_performance(task_id)`
|
||||
- ✅ Retourne le taux de concordance comme score historique
|
||||
- ✅ Valeur par défaut 0.5 pour nouvelles tâches
|
||||
|
||||
#### Évaluation et Métriques
|
||||
|
||||
##### `evaluate_task(task_id)`
|
||||
- ✅ Retourne toutes les métriques d'une tâche :
|
||||
- task_id, task_name, mode
|
||||
- observation_count, concordance_rate
|
||||
- confidence_score, correction_count
|
||||
- correction_rate, last_execution
|
||||
|
||||
##### `get_all_tasks()`
|
||||
- ✅ Retourne la liste de toutes les tâches avec métriques
|
||||
- ✅ Utile pour le tableau de bord
|
||||
|
||||
##### `get_task_stats()`
|
||||
- ✅ Statistiques globales :
|
||||
- Nombre total de tâches
|
||||
- Répartition par mode (shadow/assist/auto)
|
||||
- Mode actuel
|
||||
|
||||
#### Transitions de Mode
|
||||
|
||||
##### `should_transition_to_auto(task_id)`
|
||||
- ✅ Vérifie les critères Autopilot :
|
||||
- ≥ 20 observations
|
||||
- ≥ 95% concordance
|
||||
- ✅ Retourne booléen
|
||||
|
||||
##### `rollback_if_low_confidence(task_id)`
|
||||
- ✅ Rétrograde Auto → Assist si confiance < 90%
|
||||
- ✅ Logging de la transition avec raison
|
||||
|
||||
##### `_check_mode_transitions(task_id)`
|
||||
- ✅ Vérifie toutes les transitions possibles :
|
||||
- Shadow → Assist (5 observations)
|
||||
- Assist → Auto (critères remplis)
|
||||
- Auto → Assist (concordance < 85%)
|
||||
|
||||
##### `_transition_mode(task_id, new_mode)`
|
||||
- ✅ Effectue la transition de mode
|
||||
- ✅ Logging avec contexte (observations, concordance)
|
||||
- ✅ Sauvegarde automatique du profil
|
||||
|
||||
#### Gestion du Contexte
|
||||
|
||||
##### `get_mode()`
|
||||
- ✅ Retourne le mode de la tâche actuelle
|
||||
- ✅ Fallback vers mode global si pas de tâche
|
||||
|
||||
##### `get_current_intent()`
|
||||
- ✅ Retourne l'intention utilisateur du contexte
|
||||
|
||||
##### `set_current_task(task_id)`
|
||||
- ✅ Définit la tâche active
|
||||
|
||||
##### `set_current_context(context)`
|
||||
- ✅ Définit le contexte actuel
|
||||
|
||||
#### Enregistrement d'Exécution
|
||||
|
||||
##### `record_execution(decision)`
|
||||
- ✅ Enregistre l'exécution d'une action
|
||||
- ✅ Met à jour last_execution
|
||||
- ✅ Met à jour confidence_score
|
||||
- ✅ Logging et sauvegarde
|
||||
|
||||
## Conformité aux Exigences
|
||||
|
||||
### Exigence 1.2
|
||||
> TANT QUE le Système_RPA fonctionne en Mode_Shadow, LE Système_RPA DOIT enregistrer toutes les interactions utilisateur
|
||||
|
||||
✅ **Implémenté**: La méthode `observe()` enregistre toutes les actions avec horodatages, fenêtres et positions.
|
||||
|
||||
### Exigence 1.4
|
||||
> LORSQU'une interaction utilisateur est capturée en Mode_Shadow, LE Système_RPA DOIT générer et stocker des embeddings visuels
|
||||
|
||||
✅ **Implémenté**: Les embeddings sont stockés dans FAISS via `embeddings_manager.add_to_index()`.
|
||||
|
||||
### Exigence 2.6
|
||||
> LORSQU'un Événement_Correction est enregistré, LE Gestionnaire_Apprentissage DOIT mettre à jour les embeddings visuels
|
||||
|
||||
✅ **Implémenté**: La méthode `confirm_action()` avec type "correct" ajoute les nouveaux embeddings.
|
||||
|
||||
### Exigence 3.1
|
||||
> LORSQU'une Séquence_Actions a un Compteur_Observations ≥20 ET un Taux_Concordance ≥95%, proposer transition vers Autopilot
|
||||
|
||||
✅ **Implémenté**: La méthode `should_transition_to_auto()` vérifie ces critères exacts.
|
||||
|
||||
### Exigence 3.6
|
||||
> LORSQU'un retour correctif est fourni après une action en Mode_Autopilot, ajuster le Score_Confiance
|
||||
|
||||
✅ **Implémenté**: Les corrections mettent à jour la concordance qui influence le score de confiance.
|
||||
|
||||
### Exigence 4.1, 4.2
|
||||
> Calculer le delta entre prédiction et réalité, déclencher ré-entraînement si delta > 10 pixels
|
||||
|
||||
✅ **Préparé**: L'infrastructure est en place, sera complété dans l'orchestrateur.
|
||||
|
||||
### Exigence 4.4
|
||||
> LORSQUE le Score_Confiance tombe en dessous de 90%, faire revenir au Mode_Assisté
|
||||
|
||||
✅ **Implémenté**: La méthode `rollback_if_low_confidence()` effectue cette vérification.
|
||||
|
||||
### Exigence 4.5
|
||||
> Notifier l'utilisateur lors des transitions avec raison
|
||||
|
||||
✅ **Implémenté**: Toutes les transitions sont loggées avec raison via `log_mode_transition()`.
|
||||
|
||||
### Exigence 4.6
|
||||
> Recalculer le Score_Confiance après chaque exécution avec formule 0.6/0.3/0.1
|
||||
|
||||
✅ **Implémenté**: La méthode `calculate_confidence()` utilise exactement cette formule.
|
||||
|
||||
### Exigence 6.1
|
||||
> Maintenir des seuils de confiance dynamiques adaptatifs
|
||||
|
||||
✅ **Implémenté**: Les seuils sont configurables et la concordance utilise une fenêtre glissante.
|
||||
|
||||
### Exigence 6.4
|
||||
> LORSQUE le Taux_Concordance tombe en dessous de 85% sur les 10 dernières exécutions, passer au Mode_Assisté
|
||||
|
||||
✅ **Implémenté**: La méthode `_check_mode_transitions()` vérifie ce critère.
|
||||
|
||||
### Exigence 6.6
|
||||
> Suivre le taux d'Événement_Correction et alerter si > 5% sur 20 exécutions
|
||||
|
||||
✅ **Implémenté**: Le `correction_rate` est calculé dans `evaluate_task()`.
|
||||
|
||||
## Architecture des Données
|
||||
|
||||
### TaskProfile
|
||||
Chaque tâche apprise contient :
|
||||
- **Identification** : task_id, task_name
|
||||
- **État** : mode, observation_count, last_execution
|
||||
- **Performance** : concordance_rate, confidence_score, correction_count
|
||||
- **Sécurité** : window_whitelist
|
||||
- **Apprentissage** : action_sequence, embeddings
|
||||
- **Métadonnées** : metadata (dont recent_results pour concordance)
|
||||
|
||||
### Persistance
|
||||
- **Format** : JSON avec encodage UTF-8
|
||||
- **Emplacement** : `data/user_profiles/{task_id}.json`
|
||||
- **Embeddings** : Stockés séparément dans FAISS
|
||||
- **Chargement** : Automatique à l'initialisation
|
||||
|
||||
## Formules et Algorithmes
|
||||
|
||||
### Score de Confiance
|
||||
```
|
||||
confidence = 0.6 × vision_conf + 0.3 × llm_score + 0.1 × history_score
|
||||
```
|
||||
|
||||
### Taux de Concordance
|
||||
```
|
||||
concordance_rate = succès / total_exécutions (fenêtre de 10)
|
||||
```
|
||||
|
||||
### Taux de Correction
|
||||
```
|
||||
correction_rate = correction_count / observation_count
|
||||
```
|
||||
|
||||
### Critères de Transition
|
||||
|
||||
#### Shadow → Assist
|
||||
- observation_count ≥ 5
|
||||
|
||||
#### Assist → Auto
|
||||
- observation_count ≥ 20
|
||||
- concordance_rate ≥ 0.95
|
||||
|
||||
#### Auto → Assist
|
||||
- confidence_score < 0.90 OU
|
||||
- concordance_rate < 0.85
|
||||
|
||||
## Intégration
|
||||
|
||||
Le LearningManager s'intègre avec :
|
||||
|
||||
1. **EmbeddingsManager** : Stockage et recherche d'embeddings visuels
|
||||
2. **Logger** : Journalisation chiffrée de toutes les opérations
|
||||
3. **Orchestrator** : Boucle cognitive principale (à implémenter)
|
||||
4. **GUI** : Affichage des métriques et notifications (à implémenter)
|
||||
|
||||
## Tests
|
||||
|
||||
Tests unitaires complets dans `tests/test_learning_manager.py` :
|
||||
- ✅ Initialisation
|
||||
- ✅ Observation d'actions
|
||||
- ✅ Transitions de mode
|
||||
- ✅ Calcul de confiance
|
||||
- ✅ Confirmations et corrections
|
||||
- ✅ Évaluation et métriques
|
||||
- ✅ Sauvegarde et chargement
|
||||
|
||||
## Utilisation
|
||||
|
||||
```python
|
||||
from geniusia2.core.learning_manager import LearningManager
|
||||
from geniusia2.core.embeddings_manager import EmbeddingsManager
|
||||
from geniusia2.core.logger import Logger
|
||||
from geniusia2.core.models import Action
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
# Initialiser les composants
|
||||
embeddings_manager = EmbeddingsManager()
|
||||
logger = Logger()
|
||||
config = {"thresholds": {...}}
|
||||
|
||||
learning_manager = LearningManager(
|
||||
embeddings_manager=embeddings_manager,
|
||||
logger=logger,
|
||||
config=config
|
||||
)
|
||||
|
||||
# Observer une action
|
||||
action = Action(
|
||||
action_type="click",
|
||||
target_element="valider_button",
|
||||
bbox=(100, 100, 50, 30),
|
||||
confidence=0.9,
|
||||
embedding=np.random.rand(512),
|
||||
timestamp=datetime.now(),
|
||||
window_title="Application"
|
||||
)
|
||||
|
||||
learning_manager.observe(action)
|
||||
|
||||
# Obtenir des statistiques
|
||||
stats = learning_manager.get_task_stats()
|
||||
print(f"Total tâches: {stats['total_tasks']}")
|
||||
print(f"Mode actuel: {stats['current_mode']}")
|
||||
```
|
||||
|
||||
## Statut
|
||||
|
||||
✅ **Tâche 7.1 COMPLÈTE**
|
||||
|
||||
Toutes les fonctionnalités requises sont implémentées :
|
||||
- ✅ Classe LearningManager avec initialisation
|
||||
- ✅ Méthode observe() pour mode Shadow
|
||||
- ✅ Méthode suggest_action() pour mode Assisté
|
||||
- ✅ Méthode confirm_action() pour validation/correction
|
||||
- ✅ Méthode calculate_confidence() avec formule 0.6/0.3/0.1
|
||||
- ✅ Méthode evaluate_task() pour métriques
|
||||
- ✅ Méthode should_transition_to_auto() avec critères
|
||||
- ✅ Méthode rollback_if_low_confidence()
|
||||
- ✅ Méthodes get_current_intent() et record_execution()
|
||||
- ✅ Gestion des seuils dynamiques adaptatifs
|
||||
- ✅ Tests unitaires complets
|
||||
|
||||
## Prochaines Étapes
|
||||
|
||||
La tâche suivante est **Tâche 8 : Implémenter l'orchestrateur (boucle cognitive)** qui utilisera le LearningManager pour :
|
||||
- Coordonner le cycle Observer → Réfléchir → Agir
|
||||
- Gérer les transitions entre modes
|
||||
- Intégrer vision, LLM et apprentissage
|
||||
257
geniusia2/core/ORCHESTRATOR_README.md
Normal file
257
geniusia2/core/ORCHESTRATOR_README.md
Normal file
@@ -0,0 +1,257 @@
|
||||
# Orchestrateur - Boucle Cognitive RPA Vision V2
|
||||
|
||||
## Vue d'ensemble
|
||||
|
||||
L'orchestrateur est le composant central de RPA Vision V2 qui implémente la boucle cognitive principale suivant le paradigme **Observer → Réfléchir → Agir → Apprendre**.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────┐
|
||||
│ BOUCLE COGNITIVE PRINCIPALE │
|
||||
│ │
|
||||
│ 1. OBSERVER │
|
||||
│ ├─ Capturer l'écran (capture_screen) │
|
||||
│ └─ Détecter fenêtre active │
|
||||
│ │
|
||||
│ 2. RÉFLÉCHIR │
|
||||
│ ├─ Détecter éléments UI (VisionUtils) │
|
||||
│ └─ Raisonner sur l'action (LLMManager) │
|
||||
│ │
|
||||
│ 3. AGIR │
|
||||
│ ├─ Mode Shadow: Observer uniquement │
|
||||
│ ├─ Mode Assisté: Suggérer + Valider │
|
||||
│ └─ Mode Autopilot: Exécuter │
|
||||
│ │
|
||||
│ 4. APPRENDRE │
|
||||
│ └─ Mettre à jour LearningManager │
|
||||
└─────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Composants Intégrés
|
||||
|
||||
- **LearningManager**: Gestion de l'apprentissage et transitions de mode
|
||||
- **VisionUtils**: Détection d'éléments UI avec OWL-v2/DINO/YOLO
|
||||
- **LLMManager**: Raisonnement visuel avec Qwen 2.5-VL
|
||||
- **Logger**: Journalisation chiffrée de toutes les actions
|
||||
- **GUI**: Interface utilisateur (optionnel)
|
||||
|
||||
## Modes Opérationnels
|
||||
|
||||
### Mode Shadow (👀)
|
||||
- Observe les actions utilisateur sans exécuter
|
||||
- Enregistre les observations pour apprentissage
|
||||
- Construit la mémoire visuelle
|
||||
|
||||
### Mode Assisté (🤝)
|
||||
- Suggère des actions basées sur l'apprentissage
|
||||
- Attend validation utilisateur (Entrée/Échap/Alt+C)
|
||||
- Apprend des corrections
|
||||
|
||||
### Mode Autopilot (🤖)
|
||||
- Exécute automatiquement les actions
|
||||
- Vérifie la liste blanche avant exécution
|
||||
- Rétrograde si confiance < 90%
|
||||
|
||||
## Sécurité
|
||||
|
||||
### Liste Blanche
|
||||
- Vérifie que la fenêtre est autorisée avant toute action
|
||||
- Bloque et journalise les violations
|
||||
- Configurable dynamiquement
|
||||
|
||||
### Arrêt d'Urgence
|
||||
- Ctrl+Pause (ou Ctrl+C) arrête immédiatement
|
||||
- Journalise l'événement de sécurité
|
||||
- Retour au mode Assisté
|
||||
|
||||
### Journalisation
|
||||
- Toutes les actions sont loggées avec chiffrement AES-256
|
||||
- Inclut: fenêtre, action, confiance, résultat
|
||||
- Traçabilité complète pour audit
|
||||
|
||||
## Métriques de Performance
|
||||
|
||||
L'orchestrateur suit:
|
||||
- **total_cycles**: Nombre de cycles de la boucle cognitive
|
||||
- **avg_latency_ms**: Latence moyenne (cible: <400ms)
|
||||
- **detections_count**: Nombre d'éléments détectés
|
||||
- **actions_executed**: Actions exécutées en mode Autopilot
|
||||
- **actions_suggested**: Actions suggérées en mode Assisté
|
||||
|
||||
## Utilisation
|
||||
|
||||
### Initialisation
|
||||
|
||||
```python
|
||||
from geniusia2.core.orchestrator import Orchestrator
|
||||
from geniusia2.core.learning_manager import LearningManager
|
||||
from geniusia2.core.utils.vision_utils import VisionUtils
|
||||
from geniusia2.core.llm_manager import LLMManager
|
||||
from geniusia2.core.logger import Logger
|
||||
from geniusia2.core.embeddings_manager import EmbeddingsManager
|
||||
from geniusia2.core.config import get_config
|
||||
|
||||
# Initialiser les composants
|
||||
config = get_config()
|
||||
logger = Logger()
|
||||
embeddings_manager = EmbeddingsManager()
|
||||
learning_manager = LearningManager(embeddings_manager, logger, config)
|
||||
vision_utils = VisionUtils(config)
|
||||
llm_manager = LLMManager(logger=logger)
|
||||
|
||||
# Créer l'orchestrateur
|
||||
orchestrator = Orchestrator(
|
||||
learning_manager=learning_manager,
|
||||
vision_utils=vision_utils,
|
||||
llm_manager=llm_manager,
|
||||
logger=logger,
|
||||
gui=None # Ou votre instance GUI
|
||||
)
|
||||
```
|
||||
|
||||
### Démarrage
|
||||
|
||||
```python
|
||||
# Configurer la liste blanche
|
||||
orchestrator.add_to_whitelist("Dolibarr*")
|
||||
orchestrator.add_to_whitelist("Firefox*")
|
||||
|
||||
# Définir l'intention actuelle
|
||||
learning_manager.set_current_context({"intent": "cliquer sur valider"})
|
||||
|
||||
# Démarrer la boucle cognitive (dans un thread séparé)
|
||||
import threading
|
||||
cognitive_thread = threading.Thread(target=orchestrator.run)
|
||||
cognitive_thread.start()
|
||||
|
||||
# Arrêter proprement
|
||||
orchestrator.stop()
|
||||
cognitive_thread.join()
|
||||
```
|
||||
|
||||
### Gestion de la Liste Blanche
|
||||
|
||||
```python
|
||||
# Ajouter une fenêtre
|
||||
orchestrator.add_to_whitelist("MonApplication*")
|
||||
|
||||
# Retirer une fenêtre
|
||||
orchestrator.remove_from_whitelist("MonApplication*")
|
||||
|
||||
# Obtenir la liste
|
||||
whitelist = orchestrator.get_whitelist()
|
||||
|
||||
# Activer/désactiver l'application
|
||||
orchestrator.set_whitelist_enforcement(True)
|
||||
```
|
||||
|
||||
### Contrôle de la Boucle
|
||||
|
||||
```python
|
||||
# Mettre en pause
|
||||
orchestrator.pause()
|
||||
|
||||
# Reprendre
|
||||
orchestrator.resume()
|
||||
|
||||
# Arrêter
|
||||
orchestrator.stop()
|
||||
```
|
||||
|
||||
### Consultation des Métriques
|
||||
|
||||
```python
|
||||
# Obtenir les métriques
|
||||
metrics = orchestrator.get_metrics()
|
||||
print(f"Cycles: {metrics['total_cycles']}")
|
||||
print(f"Latence moyenne: {metrics['avg_latency_ms']:.2f}ms")
|
||||
|
||||
# Obtenir le statut complet
|
||||
status = orchestrator.get_status()
|
||||
print(f"Mode: {status['mode']}")
|
||||
print(f"Fenêtre actuelle: {status['current_window']}")
|
||||
```
|
||||
|
||||
## Flux de Décision
|
||||
|
||||
```
|
||||
Capture Contexte
|
||||
↓
|
||||
Détection UI (Vision)
|
||||
↓
|
||||
Raisonnement (LLM)
|
||||
↓
|
||||
Calcul Confiance
|
||||
↓
|
||||
┌─────────────────┐
|
||||
│ Mode actuel? │
|
||||
└─────────────────┘
|
||||
↓
|
||||
┌──┴──┬──────┬──────┐
|
||||
│ │ │ │
|
||||
Shadow Assist Auto │
|
||||
│ │ │ │
|
||||
↓ ↓ ↓ ↓
|
||||
Observer Suggérer Vérifier Exécuter
|
||||
+ Liste
|
||||
Valider Blanche
|
||||
↓ ↓
|
||||
Apprendre Apprendre
|
||||
```
|
||||
|
||||
## Gestion des Erreurs
|
||||
|
||||
### Échec de Capture
|
||||
- Log l'erreur
|
||||
- Continue la boucle
|
||||
- Réessaie au cycle suivant
|
||||
|
||||
### Échec de Détection
|
||||
- Essaie les modèles de fallback (OWL-v2 → DINO → YOLO)
|
||||
- Si tous échouent, log et continue
|
||||
|
||||
### Échec LLM
|
||||
- Fallback vers sélection basée vision pure
|
||||
- Continue avec confiance réduite
|
||||
|
||||
### Violation Liste Blanche
|
||||
- Bloque l'action immédiatement
|
||||
- Log événement de sécurité
|
||||
- Notifie l'utilisateur
|
||||
|
||||
## Exigences Satisfaites
|
||||
|
||||
✓ 1.1, 1.2 - Capture d'écran et détection fenêtre
|
||||
✓ 2.1, 2.2, 2.3, 2.4, 2.5 - Modes et suggestions
|
||||
✓ 3.2, 3.3 - Exécution automatique et arrêt d'urgence
|
||||
✓ 5.3, 5.4 - Liste blanche et sécurité
|
||||
✓ 6.2, 6.3 - Notifications et alertes
|
||||
|
||||
## Tests
|
||||
|
||||
Exécuter les tests basiques:
|
||||
```bash
|
||||
python3 test_orchestrator_simple.py
|
||||
```
|
||||
|
||||
Pour tester avec tous les composants (nécessite PyTorch, etc.):
|
||||
```bash
|
||||
python3 -m geniusia2.core.orchestrator
|
||||
```
|
||||
|
||||
## Notes d'Implémentation
|
||||
|
||||
- La boucle cognitive s'exécute dans un thread séparé
|
||||
- Latence cible: <400ms par cycle
|
||||
- Utilise des événements threading pour arrêt/pause propre
|
||||
- Toutes les actions sont journalisées avec chiffrement
|
||||
- Support multi-plateforme (Linux, Windows, macOS)
|
||||
|
||||
## Prochaines Étapes
|
||||
|
||||
1. Implémenter `input_utils.py` pour exécution réelle des actions
|
||||
2. Implémenter `replay_async.py` pour rollback
|
||||
3. Intégrer avec la GUI PyQt5
|
||||
4. Ajouter le système de détection de changements UI
|
||||
5. Implémenter les métriques avancées
|
||||
275
geniusia2/core/TASK_REPLAY_README.md
Normal file
275
geniusia2/core/TASK_REPLAY_README.md
Normal file
@@ -0,0 +1,275 @@
|
||||
# 🎮 Système de Rejeu de Tâches
|
||||
|
||||
## Vue d'ensemble
|
||||
|
||||
Le système de rejeu permet de **rejouer automatiquement des tâches apprises** en utilisant la reconnaissance visuelle pour localiser les éléments d'interface, même si leur position a changé.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
TaskReplayEngine
|
||||
├── Chargement de tâche (load_task)
|
||||
├── Recherche visuelle (find_element_visually)
|
||||
│ ├── Capture d'écran
|
||||
│ ├── Génération d'embeddings
|
||||
│ └── Recherche par similarité
|
||||
├── Exécution d'action (execute_action_at_location)
|
||||
│ ├── Click
|
||||
│ ├── Type
|
||||
│ ├── Scroll
|
||||
│ └── Drag
|
||||
└── Monitoring en temps réel
|
||||
```
|
||||
|
||||
## Fonctionnalités
|
||||
|
||||
### 1. Rejeu Basique
|
||||
|
||||
```python
|
||||
from core.task_replay import TaskReplayEngine
|
||||
|
||||
# Rejouer une tâche
|
||||
results = await replay_engine.replay_task(
|
||||
task_id="task_fc1d3e52",
|
||||
interactive=False
|
||||
)
|
||||
|
||||
print(f"Succès: {results['success']}")
|
||||
print(f"Actions exécutées: {results['executed_actions']}")
|
||||
```
|
||||
|
||||
### 2. Rejeu avec Monitoring
|
||||
|
||||
```python
|
||||
def on_step_completed(step_result):
|
||||
print(f"Étape {step_result['step']}: {step_result['status']}")
|
||||
|
||||
results = await replay_engine.replay_task_with_monitoring(
|
||||
task_id="task_fc1d3e52",
|
||||
on_step_completed=on_step_completed
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Liste des Tâches Disponibles
|
||||
|
||||
```python
|
||||
tasks = replay_engine.list_available_tasks()
|
||||
|
||||
for task in tasks:
|
||||
print(f"{task['task_name']}: {task['observation_count']} observations")
|
||||
```
|
||||
|
||||
## Reconnaissance Visuelle
|
||||
|
||||
### Recherche par Grille
|
||||
|
||||
Le système divise l'écran en grille et recherche l'élément dans chaque cellule :
|
||||
|
||||
```python
|
||||
# Grille 4x4 par défaut
|
||||
grid_size = 4
|
||||
|
||||
# Pour chaque cellule :
|
||||
# 1. Extraire la région
|
||||
# 2. Générer l'embedding CLIP
|
||||
# 3. Calculer la similarité avec l'embedding cible
|
||||
# 4. Retourner les meilleures correspondances
|
||||
```
|
||||
|
||||
### Seuil de Similarité
|
||||
|
||||
```yaml
|
||||
replay:
|
||||
similarity_threshold: 0.75 # Minimum pour accepter une correspondance
|
||||
max_search_attempts: 3 # Nombre de tentatives
|
||||
```
|
||||
|
||||
## Types d'Actions Supportées
|
||||
|
||||
| Type | Description | Paramètres |
|
||||
|------|-------------|------------|
|
||||
| `click` | Clic souris | x, y, button |
|
||||
| `type` | Saisie texte | text, interval |
|
||||
| `scroll` | Défilement | direction, amount |
|
||||
| `drag` | Glisser-déposer | start_x, start_y, end_x, end_y |
|
||||
|
||||
## Gestion des Variations d'Interface
|
||||
|
||||
### Adaptation Automatique
|
||||
|
||||
Le système s'adapte aux changements :
|
||||
|
||||
1. **Position différente** : Recherche visuelle localise l'élément
|
||||
2. **Taille différente** : Utilise le centre de la région trouvée
|
||||
3. **Style différent** : L'embedding CLIP capture la sémantique
|
||||
|
||||
### Retry et Fallback
|
||||
|
||||
```python
|
||||
# Retry automatique si élément non trouvé
|
||||
max_search_attempts = 3
|
||||
|
||||
# Délai entre tentatives
|
||||
retry_delay = 0.5 # secondes
|
||||
```
|
||||
|
||||
## Mode Interactif
|
||||
|
||||
En mode interactif, le système demande confirmation avant chaque action :
|
||||
|
||||
```python
|
||||
results = await replay_engine.replay_task(
|
||||
task_id="task_fc1d3e52",
|
||||
interactive=True # Demande confirmation
|
||||
)
|
||||
```
|
||||
|
||||
## Résultats du Rejeu
|
||||
|
||||
```python
|
||||
{
|
||||
"task_id": "task_fc1d3e52",
|
||||
"success": True,
|
||||
"total_actions": 3,
|
||||
"executed_actions": 3,
|
||||
"failed_actions": 0,
|
||||
"actions": [
|
||||
{
|
||||
"step": 1,
|
||||
"success": True,
|
||||
"location": {
|
||||
"x": 640,
|
||||
"y": 360,
|
||||
"confidence": 0.89,
|
||||
"bbox": [600, 340, 680, 380]
|
||||
},
|
||||
"action_type": "click"
|
||||
},
|
||||
# ...
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Logging
|
||||
|
||||
Tous les événements sont loggés :
|
||||
|
||||
```python
|
||||
# Démarrage du rejeu
|
||||
{"action": "task_replay_started", "task_id": "...", "interactive": False}
|
||||
|
||||
# Élément trouvé
|
||||
{"action": "element_found", "similarity": 0.89, "attempt": 1}
|
||||
|
||||
# Élément non trouvé
|
||||
{"action": "element_not_found", "step": 2, "signature": "..."}
|
||||
|
||||
# Rejeu terminé
|
||||
{"action": "task_replay_completed", "success": True, "executed": 3, "failed": 0}
|
||||
```
|
||||
|
||||
## Utilisation avec le CLI
|
||||
|
||||
```bash
|
||||
# Lister les tâches disponibles
|
||||
python test_task_replay.py
|
||||
|
||||
# Le script vous guidera pour :
|
||||
# 1. Voir les tâches disponibles
|
||||
# 2. Choisir une tâche
|
||||
# 3. Rejouer avec monitoring en temps réel
|
||||
```
|
||||
|
||||
## Exemple Complet
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from core.task_replay import TaskReplayEngine
|
||||
from core.learning_manager import LearningManager
|
||||
from core.embeddings_manager import EmbeddingsManager
|
||||
from core.utils.vision_utils import VisionUtils
|
||||
from core.utils.input_utils import InputUtils
|
||||
from core.logger import Logger
|
||||
from core.config import load_config
|
||||
|
||||
async def replay_example():
|
||||
# Initialiser
|
||||
config = load_config()
|
||||
logger = Logger(config)
|
||||
embeddings_manager = EmbeddingsManager(logger, config)
|
||||
learning_manager = LearningManager(embeddings_manager, logger, config)
|
||||
vision_utils = VisionUtils(logger, config)
|
||||
input_utils = InputUtils(logger, config)
|
||||
|
||||
replay_engine = TaskReplayEngine(
|
||||
learning_manager,
|
||||
embeddings_manager,
|
||||
vision_utils,
|
||||
input_utils,
|
||||
logger,
|
||||
config
|
||||
)
|
||||
|
||||
# Lister les tâches
|
||||
tasks = replay_engine.list_available_tasks()
|
||||
print(f"Tâches disponibles: {len(tasks)}")
|
||||
|
||||
# Rejouer la première tâche
|
||||
if tasks:
|
||||
task_id = tasks[0]['task_id']
|
||||
results = await replay_engine.replay_task(task_id)
|
||||
|
||||
if results['success']:
|
||||
print("✅ Tâche rejouée avec succès!")
|
||||
else:
|
||||
print(f"❌ Échec: {results['failed_actions']} actions échouées")
|
||||
|
||||
# Exécuter
|
||||
asyncio.run(replay_example())
|
||||
```
|
||||
|
||||
## Limitations Actuelles
|
||||
|
||||
1. **Recherche par grille** : Peut manquer des petits éléments entre les cellules
|
||||
2. **Pas de détection d'objets** : Utilise uniquement CLIP pour la similarité
|
||||
3. **Mode interactif basique** : Pas d'interface graphique pour la confirmation
|
||||
|
||||
## Améliorations Futures
|
||||
|
||||
- [ ] Recherche multi-échelle (grilles de différentes tailles)
|
||||
- [ ] Intégration avec OWL-v2 pour détection précise
|
||||
- [ ] Interface graphique pour le mode interactif
|
||||
- [ ] Gestion des erreurs avec rollback automatique
|
||||
- [ ] Support des actions conditionnelles
|
||||
- [ ] Rejeu parallèle de plusieurs tâches
|
||||
|
||||
## Configuration
|
||||
|
||||
```yaml
|
||||
replay:
|
||||
similarity_threshold: 0.75 # Seuil de similarité minimum
|
||||
max_search_attempts: 3 # Tentatives de recherche
|
||||
delay_between_actions: 0.5 # Délai entre actions (secondes)
|
||||
grid_size: 4 # Taille de la grille de recherche
|
||||
```
|
||||
|
||||
## Dépendances
|
||||
|
||||
- `numpy` : Calculs d'embeddings
|
||||
- `PIL` : Manipulation d'images
|
||||
- `pyautogui` : Capture d'écran et contrôle souris/clavier
|
||||
- `asyncio` : Exécution asynchrone
|
||||
|
||||
## Tests
|
||||
|
||||
```bash
|
||||
# Test complet avec interface interactive
|
||||
python test_task_replay.py
|
||||
|
||||
# Test programmatique
|
||||
python -c "
|
||||
import asyncio
|
||||
from test_task_replay import test_list_tasks
|
||||
asyncio.run(test_list_tasks())
|
||||
"
|
||||
```
|
||||
281
geniusia2/core/UI_CHANGE_DETECTOR_README.md
Normal file
281
geniusia2/core/UI_CHANGE_DETECTOR_README.md
Normal file
@@ -0,0 +1,281 @@
|
||||
# UIChangeDetector - Détecteur de Changements UI
|
||||
|
||||
## Vue d'ensemble
|
||||
|
||||
Le `UIChangeDetector` est un composant essentiel du système RPA Vision V2 qui surveille les changements d'interface utilisateur et déclenche le ré-entraînement lorsque nécessaire. Il détecte deux types de dérives :
|
||||
|
||||
1. **Dérive visuelle** : Changements dans l'apparence des éléments UI (détectés via similarité d'embeddings)
|
||||
2. **Dérive de position** : Déplacements des éléments UI (détectés via delta de bounding box)
|
||||
|
||||
## Fonctionnalités principales
|
||||
|
||||
### 1. Détection de changements visuels
|
||||
|
||||
```python
|
||||
change_detected, similarity = detector.detect_ui_change(
|
||||
current_embedding,
|
||||
stored_embeddings,
|
||||
task_id
|
||||
)
|
||||
```
|
||||
|
||||
Compare l'embedding visuel actuel avec les embeddings stockés pour détecter si l'UI a changé. Un changement est détecté si la similarité maximale est inférieure au seuil (70% par défaut).
|
||||
|
||||
### 2. Calcul de delta de position
|
||||
|
||||
```python
|
||||
deltas = detector.calculate_delta(
|
||||
predicted_bbox=(100, 200, 50, 30),
|
||||
actual_bbox=(105, 203, 50, 30)
|
||||
)
|
||||
```
|
||||
|
||||
Calcule les différences en pixels entre la position prédite et la position réelle d'un élément UI. Retourne :
|
||||
- `delta_x` : Différence en X
|
||||
- `delta_y` : Différence en Y
|
||||
- `delta_width` : Différence en largeur
|
||||
- `delta_height` : Différence en hauteur
|
||||
- `delta_center` : Distance euclidienne entre les centres
|
||||
- `max_delta` : Delta maximum (position)
|
||||
|
||||
### 3. Décision de ré-entraînement
|
||||
|
||||
```python
|
||||
should_retrain = detector.should_trigger_retraining(deltas, similarity)
|
||||
```
|
||||
|
||||
Détermine si le ré-entraînement doit être déclenché en fonction de :
|
||||
- Delta de position > seuil (10 pixels par défaut)
|
||||
- OU similarité visuelle < seuil (70% par défaut)
|
||||
|
||||
### 4. Déclenchement de ré-entraînement
|
||||
|
||||
```python
|
||||
detector.trigger_retraining(
|
||||
task_id="ouvrir_facture",
|
||||
reason="position_drift",
|
||||
metadata={"delta": 25.0}
|
||||
)
|
||||
```
|
||||
|
||||
Enregistre un événement de ré-entraînement dans les logs et l'historique.
|
||||
|
||||
### 5. Vérification complète
|
||||
|
||||
```python
|
||||
result = detector.check_and_trigger_retraining(
|
||||
task_id="ouvrir_facture",
|
||||
current_embedding=current_emb,
|
||||
stored_embeddings=stored_embs,
|
||||
predicted_bbox=(100, 200, 50, 30),
|
||||
actual_bbox=(150, 230, 50, 30)
|
||||
)
|
||||
```
|
||||
|
||||
Effectue une vérification complète et déclenche automatiquement le ré-entraînement si nécessaire. Retourne un dictionnaire avec :
|
||||
- `ui_change_detected` : Changement visuel détecté
|
||||
- `position_drift_detected` : Dérive de position détectée
|
||||
- `retraining_triggered` : Ré-entraînement déclenché
|
||||
- `similarity` : Similarité maximale trouvée
|
||||
- `deltas` : Dictionnaire des deltas de position
|
||||
|
||||
## Configuration
|
||||
|
||||
Les seuils sont configurables via le fichier `config.py` :
|
||||
|
||||
```python
|
||||
CONFIG = {
|
||||
"thresholds": {
|
||||
"ui_change_similarity": 0.70, # Seuil de similarité (70%)
|
||||
"bbox_delta_pixels": 10 # Seuil de delta position (10 pixels)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Intégration avec LearningManager
|
||||
|
||||
Le `UIChangeDetector` est intégré dans le `LearningManager` pour une surveillance continue :
|
||||
|
||||
```python
|
||||
# Dans LearningManager
|
||||
result = self.check_ui_changes(
|
||||
task_id="ouvrir_facture",
|
||||
current_embedding=current_emb,
|
||||
predicted_bbox=predicted_bbox,
|
||||
actual_bbox=actual_bbox
|
||||
)
|
||||
|
||||
if result.get("retraining_triggered"):
|
||||
# Le système ajoute automatiquement le nouvel embedding
|
||||
# et met à jour l'index FAISS
|
||||
pass
|
||||
```
|
||||
|
||||
### Surveillance de la dérive d'exécution
|
||||
|
||||
```python
|
||||
# Surveiller la différence entre prédiction et réalité
|
||||
drift_detected = learning_manager.monitor_execution_drift(
|
||||
task_id="ouvrir_facture",
|
||||
predicted_action=predicted_action,
|
||||
actual_action=actual_action
|
||||
)
|
||||
```
|
||||
|
||||
## Historique et statistiques
|
||||
|
||||
### Consulter l'historique
|
||||
|
||||
```python
|
||||
# Tous les changements
|
||||
history = detector.get_change_history()
|
||||
|
||||
# Changements pour une tâche spécifique
|
||||
history = detector.get_change_history(task_id="ouvrir_facture", limit=20)
|
||||
```
|
||||
|
||||
### Obtenir les statistiques
|
||||
|
||||
```python
|
||||
stats = detector.get_stats()
|
||||
# Retourne:
|
||||
# {
|
||||
# "total_changes_detected": 15,
|
||||
# "retraining_triggered_count": 8,
|
||||
# "changes_by_task": {
|
||||
# "ouvrir_facture": 5,
|
||||
# "valider_commande": 3
|
||||
# },
|
||||
# "ui_change_threshold": 0.70,
|
||||
# "bbox_delta_threshold": 10
|
||||
# }
|
||||
```
|
||||
|
||||
## Cas d'usage
|
||||
|
||||
### 1. Détection de mise à jour d'interface
|
||||
|
||||
Lorsqu'une application est mise à jour et que l'interface change :
|
||||
|
||||
```python
|
||||
# L'embedding actuel sera différent des embeddings stockés
|
||||
change_detected, similarity = detector.detect_ui_change(
|
||||
current_embedding,
|
||||
stored_embeddings,
|
||||
"ouvrir_facture"
|
||||
)
|
||||
|
||||
if change_detected:
|
||||
print(f"Mise à jour UI détectée (similarité: {similarity:.2%})")
|
||||
# Le système demandera à l'utilisateur de ré-observer la tâche
|
||||
```
|
||||
|
||||
### 2. Détection de déplacement d'éléments
|
||||
|
||||
Lorsque des éléments UI se déplacent (ex: fenêtre redimensionnée) :
|
||||
|
||||
```python
|
||||
deltas = detector.calculate_delta(
|
||||
predicted_bbox=(100, 200, 50, 30),
|
||||
actual_bbox=(150, 230, 50, 30)
|
||||
)
|
||||
|
||||
if deltas['max_delta'] > 10:
|
||||
print(f"Élément déplacé de {deltas['max_delta']:.0f} pixels")
|
||||
# Le système ajustera automatiquement ses prédictions
|
||||
```
|
||||
|
||||
### 3. Surveillance continue en mode Autopilot
|
||||
|
||||
En mode Autopilot, le système surveille automatiquement chaque exécution :
|
||||
|
||||
```python
|
||||
# Après chaque action automatisée
|
||||
result = learning_manager.check_ui_changes(
|
||||
task_id=current_task_id,
|
||||
current_embedding=detected_element.embedding,
|
||||
predicted_bbox=predicted_bbox,
|
||||
actual_bbox=detected_bbox
|
||||
)
|
||||
|
||||
if result['retraining_triggered']:
|
||||
# Rétrograder au mode Assisté pour validation
|
||||
learning_manager.rollback_if_low_confidence(current_task_id)
|
||||
```
|
||||
|
||||
## Exigences satisfaites
|
||||
|
||||
Ce module satisfait les exigences suivantes du document de requirements :
|
||||
|
||||
- **Exigence 4.1** : Calcul du delta entre emplacement prédit et réel
|
||||
- **Exigence 4.2** : Déclenchement du ré-entraînement si delta > 10 pixels
|
||||
- **Exigence 6.5** : Détection de changements UI (similarité < 70%)
|
||||
|
||||
## Tests
|
||||
|
||||
Un test complet est disponible dans `test_ui_change_detector_simple.py` :
|
||||
|
||||
```bash
|
||||
python3 test_ui_change_detector_simple.py
|
||||
```
|
||||
|
||||
Le test vérifie :
|
||||
- Détection de changements visuels
|
||||
- Calcul de deltas de position
|
||||
- Décisions de ré-entraînement
|
||||
- Déclenchement de ré-entraînement
|
||||
- Vérification complète
|
||||
- Statistiques et historique
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
UIChangeDetector
|
||||
├── detect_ui_change() # Détection visuelle
|
||||
├── calculate_delta() # Calcul delta position
|
||||
├── should_trigger_retraining() # Décision
|
||||
├── trigger_retraining() # Déclenchement
|
||||
├── check_and_trigger_retraining() # Vérification complète
|
||||
├── get_change_history() # Historique
|
||||
├── get_stats() # Statistiques
|
||||
└── clear_history() # Nettoyage
|
||||
|
||||
Intégration LearningManager
|
||||
├── check_ui_changes() # Vérification avec mise à jour
|
||||
├── monitor_execution_drift() # Surveillance continue
|
||||
└── get_ui_change_stats() # Statistiques globales
|
||||
```
|
||||
|
||||
## Journalisation
|
||||
|
||||
Tous les événements sont enregistrés dans le logger chiffré :
|
||||
|
||||
- `ui_change_detected` : Changement UI détecté
|
||||
- `ui_stable` : UI stable (pas de changement)
|
||||
- `bbox_delta_calculated` : Delta de position calculé
|
||||
- `retraining_decision` : Décision de ré-entraînement
|
||||
- `retraining_triggered` : Ré-entraînement déclenché
|
||||
- `ui_change_retraining` : Ré-entraînement suite à changement UI
|
||||
|
||||
## Bonnes pratiques
|
||||
|
||||
1. **Seuils adaptatifs** : Ajuster les seuils en fonction du type d'application
|
||||
2. **Historique limité** : Utiliser `limit` pour éviter de surcharger la mémoire
|
||||
3. **Nettoyage périodique** : Effacer l'historique ancien avec `clear_history()`
|
||||
4. **Surveillance continue** : Intégrer dans la boucle cognitive de l'orchestrateur
|
||||
5. **Feedback utilisateur** : Notifier l'utilisateur lors de changements détectés
|
||||
|
||||
## Limitations
|
||||
|
||||
- La détection visuelle dépend de la qualité des embeddings OpenCLIP
|
||||
- Les vecteurs aléatoires en haute dimension ont naturellement une similarité modérée (~0.5-0.9)
|
||||
- Le seuil de 70% peut nécessiter un ajustement selon les cas d'usage
|
||||
- La détection de position nécessite des bounding boxes précises
|
||||
|
||||
## Évolutions futures
|
||||
|
||||
- Seuils adaptatifs par tâche
|
||||
- Détection de patterns de changements récurrents
|
||||
- Prédiction proactive de changements UI
|
||||
- Apprentissage des zones stables vs zones volatiles
|
||||
- Intégration avec un système de versioning d'UI
|
||||
318
geniusia2/core/WHITELIST_MANAGER_README.md
Normal file
318
geniusia2/core/WHITELIST_MANAGER_README.md
Normal file
@@ -0,0 +1,318 @@
|
||||
# WhitelistManager - Gestionnaire de Liste Blanche
|
||||
|
||||
## Vue d'ensemble
|
||||
|
||||
Le `WhitelistManager` est un composant de sécurité pour RPA Vision V2 qui gère la liste des fenêtres d'application autorisées pour l'automatisation. Il fournit un contrôle granulaire sur les applications où le système peut exécuter des actions automatisées.
|
||||
|
||||
## Fonctionnalités principales
|
||||
|
||||
### 1. Vérification de fenêtres autorisées
|
||||
- Vérifie si une fenêtre est dans la liste blanche
|
||||
- Support pour patterns avec wildcards (`*`)
|
||||
- Correspondance insensible à la casse
|
||||
|
||||
### 2. Gestion de la liste blanche
|
||||
- Ajout de fenêtres avec confirmation admin optionnelle
|
||||
- Suppression de fenêtres
|
||||
- Vidage complet de la liste
|
||||
- Obtention de la liste complète
|
||||
|
||||
### 3. Persistance
|
||||
- Sauvegarde automatique dans un fichier JSON
|
||||
- Chargement au démarrage
|
||||
- Export/Import pour partage de configurations
|
||||
|
||||
### 4. Métadonnées et audit
|
||||
- Suivi de qui a ajouté chaque entrée
|
||||
- Horodatage des modifications
|
||||
- Confirmation admin enregistrée
|
||||
- Statistiques sur la liste blanche
|
||||
|
||||
### 5. Journalisation
|
||||
- Toutes les opérations sont loggées
|
||||
- Événements de sécurité pour violations
|
||||
- Intégration avec le système de logs chiffrés
|
||||
|
||||
## Utilisation
|
||||
|
||||
### Initialisation
|
||||
|
||||
```python
|
||||
from core.whitelist_manager import WhitelistManager
|
||||
from core.logger import Logger
|
||||
|
||||
# Créer un logger
|
||||
logger = Logger()
|
||||
|
||||
# Créer le gestionnaire de liste blanche
|
||||
whitelist_manager = WhitelistManager(
|
||||
logger=logger,
|
||||
require_admin_confirmation=True # Nécessite confirmation pour ajouts
|
||||
)
|
||||
```
|
||||
|
||||
### Ajouter une fenêtre
|
||||
|
||||
```python
|
||||
# Ajouter avec confirmation admin
|
||||
success = whitelist_manager.add_to_whitelist(
|
||||
"Dolibarr*",
|
||||
admin_confirmed=True,
|
||||
added_by="admin_user"
|
||||
)
|
||||
|
||||
# Sans confirmation (si require_admin_confirmation=False)
|
||||
success = whitelist_manager.add_to_whitelist("Firefox*")
|
||||
```
|
||||
|
||||
### Vérifier si une fenêtre est autorisée
|
||||
|
||||
```python
|
||||
# Vérifier une fenêtre spécifique
|
||||
if whitelist_manager.is_window_allowed("Dolibarr - Facturation"):
|
||||
print("Fenêtre autorisée")
|
||||
else:
|
||||
print("Fenêtre bloquée")
|
||||
```
|
||||
|
||||
### Patterns avec wildcards
|
||||
|
||||
Le WhitelistManager supporte plusieurs types de patterns:
|
||||
|
||||
```python
|
||||
# Prefix wildcard - autorise toutes les fenêtres commençant par "Firefox"
|
||||
whitelist_manager.add_to_whitelist("Firefox*", admin_confirmed=True)
|
||||
# Autorise: "Firefox", "Firefox - Mozilla", "Firefox Developer Edition"
|
||||
|
||||
# Suffix wildcard - autorise toutes les fenêtres se terminant par "Chrome"
|
||||
whitelist_manager.add_to_whitelist("*Chrome", admin_confirmed=True)
|
||||
# Autorise: "Google Chrome", "Microsoft Edge Chrome"
|
||||
|
||||
# Wildcard au milieu - autorise les fenêtres avec pattern spécifique
|
||||
whitelist_manager.add_to_whitelist("Fire*fox", admin_confirmed=True)
|
||||
# Autorise: "Firefox", "Fire123fox", "Firefoxes"
|
||||
|
||||
# Correspondance exacte ou partielle
|
||||
whitelist_manager.add_to_whitelist("Visual Studio Code", admin_confirmed=True)
|
||||
# Autorise: "Visual Studio Code"
|
||||
```
|
||||
|
||||
### Supprimer une fenêtre
|
||||
|
||||
```python
|
||||
success = whitelist_manager.remove_from_whitelist("Firefox*")
|
||||
```
|
||||
|
||||
### Obtenir la liste blanche
|
||||
|
||||
```python
|
||||
whitelist = whitelist_manager.get_whitelist()
|
||||
print(f"Liste blanche: {whitelist}")
|
||||
```
|
||||
|
||||
### Obtenir des informations sur une entrée
|
||||
|
||||
```python
|
||||
info = whitelist_manager.get_entry_info("Dolibarr*")
|
||||
if info:
|
||||
print(f"Ajouté le: {info['added_at']}")
|
||||
print(f"Ajouté par: {info['added_by']}")
|
||||
print(f"Admin confirmé: {info['admin_confirmed']}")
|
||||
```
|
||||
|
||||
### Statistiques
|
||||
|
||||
```python
|
||||
stats = whitelist_manager.get_statistics()
|
||||
print(f"Total entrées: {stats['total_entries']}")
|
||||
print(f"Avec wildcards: {stats['entries_with_wildcards']}")
|
||||
print(f"Exactes: {stats['entries_exact']}")
|
||||
```
|
||||
|
||||
### Export/Import
|
||||
|
||||
```python
|
||||
# Exporter la liste blanche
|
||||
whitelist_manager.export_whitelist("whitelist_backup.json")
|
||||
|
||||
# Importer une liste blanche (remplace la liste actuelle)
|
||||
whitelist_manager.import_whitelist("whitelist_backup.json")
|
||||
|
||||
# Importer en fusionnant avec la liste existante
|
||||
whitelist_manager.import_whitelist(
|
||||
"whitelist_backup.json",
|
||||
merge=True,
|
||||
admin_confirmed=True
|
||||
)
|
||||
```
|
||||
|
||||
### Vider la liste blanche
|
||||
|
||||
```python
|
||||
whitelist_manager.clear_whitelist()
|
||||
```
|
||||
|
||||
## Intégration avec l'Orchestrateur
|
||||
|
||||
Le `WhitelistManager` est intégré dans l'`Orchestrator` pour vérifier automatiquement les fenêtres avant l'exécution d'actions:
|
||||
|
||||
```python
|
||||
from core.orchestrator import Orchestrator
|
||||
from core.whitelist_manager import WhitelistManager
|
||||
|
||||
# Créer le gestionnaire de liste blanche
|
||||
whitelist_manager = WhitelistManager(logger=logger)
|
||||
|
||||
# Ajouter des fenêtres autorisées
|
||||
whitelist_manager.add_to_whitelist("Dolibarr*", admin_confirmed=True)
|
||||
whitelist_manager.add_to_whitelist("Firefox*", admin_confirmed=True)
|
||||
|
||||
# Créer l'orchestrateur avec le gestionnaire
|
||||
orchestrator = Orchestrator(
|
||||
learning_manager=learning_manager,
|
||||
vision_utils=vision_utils,
|
||||
llm_manager=llm_manager,
|
||||
logger=logger,
|
||||
whitelist_manager=whitelist_manager
|
||||
)
|
||||
|
||||
# L'orchestrateur vérifiera automatiquement la liste blanche
|
||||
# avant d'exécuter des actions en mode Autopilot
|
||||
```
|
||||
|
||||
L'orchestrateur fournit également des méthodes de convenance:
|
||||
|
||||
```python
|
||||
# Ajouter via l'orchestrateur
|
||||
orchestrator.add_to_whitelist("Visual Studio Code", admin_confirmed=True)
|
||||
|
||||
# Supprimer via l'orchestrateur
|
||||
orchestrator.remove_from_whitelist("Visual Studio Code")
|
||||
|
||||
# Obtenir la liste blanche
|
||||
whitelist = orchestrator.get_whitelist()
|
||||
|
||||
# Activer/désactiver l'application de la liste blanche
|
||||
orchestrator.set_whitelist_enforcement(True)
|
||||
```
|
||||
|
||||
## Format du fichier de liste blanche
|
||||
|
||||
Le fichier `whitelist.json` est stocké dans `data/user_profiles/` et a le format suivant:
|
||||
|
||||
```json
|
||||
{
|
||||
"whitelist": [
|
||||
"Dolibarr*",
|
||||
"Firefox*",
|
||||
"Visual Studio Code"
|
||||
],
|
||||
"metadata": {
|
||||
"created_at": "2025-11-13T14:52:07.468307",
|
||||
"last_modified": "2025-11-13T14:52:07.483380",
|
||||
"version": "1.0",
|
||||
"entries": {
|
||||
"Dolibarr*": {
|
||||
"added_at": "2025-11-13T14:52:07.468307",
|
||||
"added_by": "admin_user",
|
||||
"admin_confirmed": true
|
||||
},
|
||||
"Firefox*": {
|
||||
"added_at": "2025-11-13T14:52:07.470123",
|
||||
"added_by": "user",
|
||||
"admin_confirmed": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Sécurité
|
||||
|
||||
### Permissions de fichier
|
||||
Le fichier de liste blanche a des permissions restrictives (0600) pour empêcher les modifications non autorisées.
|
||||
|
||||
### Journalisation
|
||||
Toutes les opérations sont loggées dans le système de logs chiffrés:
|
||||
- Ajouts et suppressions d'entrées
|
||||
- Vérifications de fenêtres
|
||||
- Violations de liste blanche
|
||||
- Modifications de configuration
|
||||
|
||||
### Confirmation admin
|
||||
Lorsque `require_admin_confirmation=True`, les ajouts à la liste blanche nécessitent une confirmation explicite via le paramètre `admin_confirmed=True`.
|
||||
|
||||
## Exigences satisfaites
|
||||
|
||||
Le `WhitelistManager` satisfait les exigences suivantes du document de requirements:
|
||||
|
||||
- **Exigence 5.3**: Application de la liste blanche des fenêtres autorisées
|
||||
- **Exigence 5.4**: Blocage et journalisation des violations de liste blanche
|
||||
|
||||
## Tests
|
||||
|
||||
Des tests complets sont disponibles dans:
|
||||
- `geniusia2/core/whitelist_manager.py` (tests standalone)
|
||||
- `test_whitelist_simple.py` (tests d'intégration)
|
||||
|
||||
Pour exécuter les tests:
|
||||
|
||||
```bash
|
||||
# Tests standalone du module
|
||||
python3 geniusia2/core/whitelist_manager.py
|
||||
|
||||
# Tests d'intégration complets
|
||||
python3 test_whitelist_simple.py
|
||||
```
|
||||
|
||||
## Exemple complet
|
||||
|
||||
```python
|
||||
from core.whitelist_manager import WhitelistManager
|
||||
from core.logger import Logger
|
||||
from core.config import ensure_directories
|
||||
|
||||
# S'assurer que les répertoires existent
|
||||
ensure_directories()
|
||||
|
||||
# Créer le gestionnaire
|
||||
logger = Logger()
|
||||
wm = WhitelistManager(logger=logger, require_admin_confirmation=True)
|
||||
|
||||
# Configurer la liste blanche
|
||||
wm.add_to_whitelist("Dolibarr*", admin_confirmed=True, added_by="admin")
|
||||
wm.add_to_whitelist("Firefox*", admin_confirmed=True, added_by="admin")
|
||||
wm.add_to_whitelist("Visual Studio Code", admin_confirmed=True, added_by="admin")
|
||||
|
||||
# Vérifier des fenêtres
|
||||
test_windows = [
|
||||
"Dolibarr - Facturation",
|
||||
"Firefox - Mozilla",
|
||||
"Visual Studio Code",
|
||||
"Unknown Application"
|
||||
]
|
||||
|
||||
for window in test_windows:
|
||||
if wm.is_window_allowed(window):
|
||||
print(f"✓ Autorisé: {window}")
|
||||
else:
|
||||
print(f"✗ Bloqué: {window}")
|
||||
|
||||
# Afficher les statistiques
|
||||
stats = wm.get_statistics()
|
||||
print(f"\nStatistiques:")
|
||||
print(f" Total: {stats['total_entries']} entrées")
|
||||
print(f" Avec wildcards: {stats['entries_with_wildcards']}")
|
||||
print(f" Exactes: {stats['entries_exact']}")
|
||||
|
||||
# Exporter pour sauvegarde
|
||||
wm.export_whitelist("whitelist_backup.json")
|
||||
```
|
||||
|
||||
## Notes d'implémentation
|
||||
|
||||
- La correspondance de patterns est insensible à la casse
|
||||
- Les wildcards (`*`) peuvent être utilisés au début, à la fin, ou au milieu d'un pattern
|
||||
- Une liste blanche vide bloque toutes les fenêtres par défaut
|
||||
- Les métadonnées sont automatiquement mises à jour à chaque modification
|
||||
- Le fichier est sauvegardé automatiquement après chaque modification
|
||||
70
geniusia2/core/__init__.py
Normal file
70
geniusia2/core/__init__.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
Module core - Composants centraux du système RPA Vision V2
|
||||
"""
|
||||
|
||||
from .config import CONFIG, get_config, ensure_directories
|
||||
from .whitelist_manager import WhitelistManager
|
||||
|
||||
# UI Element Detection - Phase 1 (Mode Light)
|
||||
from .ui_element_models import (
|
||||
UIElement,
|
||||
UIElementType,
|
||||
VisualData,
|
||||
TextData,
|
||||
ElementProperties,
|
||||
ElementContext,
|
||||
EnrichedScreenState,
|
||||
WindowInfo,
|
||||
RawData,
|
||||
PerceptionData,
|
||||
StateEmbedding,
|
||||
EmbeddingComponents,
|
||||
ComponentInfo,
|
||||
ContextData
|
||||
)
|
||||
from .screen_state_manager import ScreenStateManager
|
||||
|
||||
# UI Element Detection - Phase 2 (Mode Enrichi)
|
||||
from .ui_element_detector import (
|
||||
UIElementDetector,
|
||||
RegionProposer,
|
||||
ElementCharacterizer,
|
||||
ElementClassifier,
|
||||
BoundingBox
|
||||
)
|
||||
from .enriched_screen_capture import EnrichedScreenCapture
|
||||
|
||||
# UI Element Detection - Phase 3 (Mode Complet)
|
||||
from .multimodal_embedding_manager import MultiModalEmbeddingManager
|
||||
|
||||
__all__ = [
|
||||
"CONFIG",
|
||||
"get_config",
|
||||
"ensure_directories",
|
||||
"WhitelistManager",
|
||||
# UI Element Detection - Phase 1
|
||||
"UIElement",
|
||||
"UIElementType",
|
||||
"VisualData",
|
||||
"TextData",
|
||||
"ElementProperties",
|
||||
"ElementContext",
|
||||
"EnrichedScreenState",
|
||||
"WindowInfo",
|
||||
"RawData",
|
||||
"PerceptionData",
|
||||
"StateEmbedding",
|
||||
"EmbeddingComponents",
|
||||
"ComponentInfo",
|
||||
"ContextData",
|
||||
"ScreenStateManager",
|
||||
# UI Element Detection - Phase 2
|
||||
"UIElementDetector",
|
||||
"RegionProposer",
|
||||
"ElementCharacterizer",
|
||||
"ElementClassifier",
|
||||
"BoundingBox",
|
||||
"EnrichedScreenCapture",
|
||||
# UI Element Detection - Phase 3
|
||||
"MultiModalEmbeddingManager"
|
||||
]
|
||||
BIN
geniusia2/core/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/config.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/config.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/embeddings_manager.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/embeddings_manager.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
geniusia2/core/__pycache__/event_capture.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/event_capture.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/faiss_index_builder.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/faiss_index_builder.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/learning_manager.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/learning_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/llm_manager.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/llm_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/logger.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/logger.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/metrics_collector.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/metrics_collector.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/models.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/models.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
geniusia2/core/__pycache__/orchestrator.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/orchestrator.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/replay_async.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/replay_async.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/screen_state_manager.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/screen_state_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/session_manager.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/session_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/suggestion_manager.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/suggestion_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/task_replay.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/task_replay.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/ui_change_detector.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/ui_change_detector.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/ui_element_detector.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/ui_element_detector.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/ui_element_models.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/ui_element_models.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/vision_analysis.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/vision_analysis.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/vision_search.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/vision_search.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/whitelist_manager.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/whitelist_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/workflow_detector.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/workflow_detector.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/__pycache__/workflow_matcher.cpython-312.pyc
Normal file
BIN
geniusia2/core/__pycache__/workflow_matcher.cpython-312.pyc
Normal file
Binary file not shown.
289
geniusia2/core/config.py
Normal file
289
geniusia2/core/config.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
Configuration globale pour RPA Vision V2
|
||||
Contient tous les paramètres pour les modèles, seuils, performance et sécurité
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Répertoire racine du projet
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
|
||||
CONFIG = {
|
||||
# Configuration des modèles IA
|
||||
"models": {
|
||||
# Modèle de vision pour détection d'éléments UI
|
||||
# Options: "owl-v2", "dino", "yolo"
|
||||
"vision": "owl-v2",
|
||||
|
||||
# Modèle LLM pour raisonnement visuel
|
||||
"llm": "gemma3:12b", # Changé de qwen3-vl:8b à gemma3:12b (pas de thinking mode)
|
||||
|
||||
# Modèle OpenCLIP pour embeddings visuels
|
||||
"clip": "ViT-B-32",
|
||||
|
||||
# Chemins des modèles
|
||||
"paths": {
|
||||
"openclip": str(PROJECT_ROOT / "models" / "openclip"),
|
||||
"owl_v2": str(PROJECT_ROOT / "models" / "owl_v2"),
|
||||
"qwen_vl": str(PROJECT_ROOT / "models" / "qwen2.5_vl"),
|
||||
}
|
||||
},
|
||||
|
||||
# Seuils pour transitions de mode et confiance
|
||||
"thresholds": {
|
||||
# Nombre minimum d'observations avant passage en Autopilot
|
||||
"autopilot_observations": 20,
|
||||
|
||||
# Taux de concordance minimum pour passage en Autopilot (95%)
|
||||
"autopilot_concordance": 0.95,
|
||||
|
||||
# Score de confiance minimum pour maintenir le mode Autopilot (90%)
|
||||
"confidence_min": 0.90,
|
||||
|
||||
# Seuil de confiance pour rollback vers mode Assisté (85%)
|
||||
"rollback_confidence": 0.85,
|
||||
|
||||
# Seuil de concordance pour rollback vers mode Assisté (85%)
|
||||
"rollback_concordance": 0.85,
|
||||
|
||||
# Seuil de similarité d'embedding pour détection de changement UI (70%)
|
||||
"ui_change_threshold": 0.70,
|
||||
|
||||
# Delta maximum en pixels avant déclenchement de ré-entraînement
|
||||
"max_pixel_delta": 10,
|
||||
},
|
||||
|
||||
# Pondérations pour calcul du score de confiance
|
||||
"confidence_weights": {
|
||||
"vision": 0.6, # 60% - Confiance du modèle de vision
|
||||
"llm": 0.3, # 30% - Score du raisonnement LLM
|
||||
"history": 0.1, # 10% - Performance historique de la tâche
|
||||
},
|
||||
|
||||
# Paramètres de performance
|
||||
"performance": {
|
||||
# Latence maximale observation-vers-suggestion (ms)
|
||||
"max_latency_ms": 400,
|
||||
|
||||
# Taux de correction maximum acceptable (3%)
|
||||
"max_correction_rate": 0.03,
|
||||
|
||||
# Nombre d'exécutions récentes pour calcul de concordance
|
||||
"concordance_window": 10,
|
||||
|
||||
# Nombre d'exécutions pour calcul du taux de correction
|
||||
"correction_window": 20,
|
||||
|
||||
# Timeout pour opérations de rollback (secondes)
|
||||
"rollback_timeout_s": 5,
|
||||
},
|
||||
|
||||
# Configuration de sécurité
|
||||
"security": {
|
||||
# Algorithme de chiffrement pour les logs
|
||||
"encryption_algorithm": "AES-256-CBC",
|
||||
|
||||
# Durée de rétention des logs (jours)
|
||||
"log_retention_days": 90,
|
||||
|
||||
# Fréquence de rotation des clés de chiffrement (jours)
|
||||
"key_rotation_days": 90,
|
||||
|
||||
# Liste blanche par défaut (vide - doit être configurée par l'utilisateur)
|
||||
"default_whitelist": [],
|
||||
|
||||
# Activer la vérification de liste blanche
|
||||
# True: Seules les fenêtres autorisées sont observées (sécurisé)
|
||||
# False: Toutes les fenêtres sont observées (mode permissif)
|
||||
"enforce_whitelist": False, # Changé à False pour plus de flexibilité
|
||||
|
||||
# Demander confirmation avant d'observer une nouvelle fenêtre (si enforce_whitelist=False)
|
||||
"ask_before_new_window": True,
|
||||
},
|
||||
|
||||
# Configuration Ollama
|
||||
"ollama": {
|
||||
"host": "localhost:11434",
|
||||
"timeout": 30, # secondes
|
||||
},
|
||||
|
||||
# Configuration FAISS
|
||||
"faiss": {
|
||||
# Dimension des embeddings OpenCLIP
|
||||
"embedding_dim": 512,
|
||||
|
||||
# Type d'index FAISS ("Flat" pour petit dataset, "IVF" pour grand)
|
||||
"index_type": "Flat",
|
||||
|
||||
# Nombre de clusters pour index IVF (si utilisé)
|
||||
"n_clusters": 100,
|
||||
|
||||
# Nombre de voisins à retourner lors de la recherche
|
||||
"k_neighbors": 5,
|
||||
},
|
||||
|
||||
# Chemins des données
|
||||
"data_paths": {
|
||||
"user_profiles": str(PROJECT_ROOT / "data" / "user_profiles"),
|
||||
"logs": str(PROJECT_ROOT / "data" / "logs"),
|
||||
"faiss_index": str(PROJECT_ROOT / "data" / "faiss_index"),
|
||||
"encryption_keys": str(PROJECT_ROOT / "data" / ".keys"),
|
||||
},
|
||||
|
||||
# Configuration GUI
|
||||
"gui": {
|
||||
# Icônes pour les modes
|
||||
"mode_icons": {
|
||||
"shadow": "👀",
|
||||
"assist": "🤝",
|
||||
"auto": "🤖",
|
||||
},
|
||||
|
||||
# Durée d'affichage des notifications (secondes)
|
||||
"notification_timeout": 5,
|
||||
|
||||
# Raccourcis clavier
|
||||
"shortcuts": {
|
||||
"accept": "Return", # Entrée
|
||||
"reject": "Escape", # Échap
|
||||
"correct": "Alt+C", # Alt+C
|
||||
"emergency_stop": "Ctrl+Pause", # Ctrl+Pause
|
||||
},
|
||||
},
|
||||
|
||||
# Configuration du logger
|
||||
"logging": {
|
||||
"level": "INFO",
|
||||
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
"max_file_size_mb": 100,
|
||||
"backup_count": 5,
|
||||
},
|
||||
|
||||
# Configuration UI Element Detection (Phase 1-3)
|
||||
"ui_detection": {
|
||||
# Mode de détection: "light", "enriched", "complete"
|
||||
# - light: Structures de données seulement (pas de détection)
|
||||
# - enriched: Détection d'éléments UI activée
|
||||
# - complete: Détection + embeddings multi-modaux + matching amélioré
|
||||
"mode": "light", # Démarrer en mode light par défaut
|
||||
|
||||
# Activer la détection automatique d'éléments
|
||||
"enabled": True,
|
||||
|
||||
# Détecter les éléments à chaque capture d'écran
|
||||
"detect_on_capture": False, # False = détection manuelle uniquement
|
||||
|
||||
# Utiliser le VLM pour la détection (coûteux)
|
||||
"vlm_enabled": False,
|
||||
|
||||
# Configuration du détecteur
|
||||
"detector": {
|
||||
"use_text_detection": True,
|
||||
"use_rectangle_detection": True,
|
||||
"use_vlm_detection": False,
|
||||
"vlm_on_new_screens": False,
|
||||
},
|
||||
},
|
||||
|
||||
# Configuration Multi-Modal Embedding (Phase 3)
|
||||
"multimodal_embedding": {
|
||||
# Activer les embeddings multi-modaux
|
||||
"enabled": False, # Activé automatiquement en mode "complete"
|
||||
|
||||
# Dimension des embeddings
|
||||
"embedding_dim": 512,
|
||||
|
||||
# Méthode de fusion
|
||||
"fusion_method": "weighted_average",
|
||||
|
||||
# Poids de fusion des modalités
|
||||
"weights": {
|
||||
"image": 0.5, # Screenshot complet
|
||||
"text": 0.3, # Texte détecté
|
||||
"title": 0.1, # Titre de fenêtre
|
||||
"ui": 0.1, # Éléments UI
|
||||
"context": 0.0, # Contexte workflow (à implémenter)
|
||||
},
|
||||
},
|
||||
|
||||
# Configuration Enhanced Workflow Matcher (Phase 3)
|
||||
"enhanced_matcher": {
|
||||
# Activer le matcher amélioré
|
||||
"enabled": True,
|
||||
|
||||
# Poids pour le scoring composite
|
||||
"screen_weight": 0.6, # Poids de la similarité d'écran
|
||||
"elements_weight": 0.4, # Poids de la similarité d'éléments
|
||||
|
||||
# Seuils de matching
|
||||
"min_similarity_threshold": 0.3,
|
||||
"min_confidence_threshold": 0.5,
|
||||
|
||||
# Nombre maximum de candidats à évaluer
|
||||
"max_candidates": 10,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_config():
|
||||
"""Retourne la configuration globale"""
|
||||
return CONFIG
|
||||
|
||||
|
||||
def get_model_config():
|
||||
"""Retourne la configuration des modèles"""
|
||||
return CONFIG["models"]
|
||||
|
||||
|
||||
def get_thresholds():
|
||||
"""Retourne les seuils de confiance et transitions"""
|
||||
return CONFIG["thresholds"]
|
||||
|
||||
|
||||
def get_performance_config():
|
||||
"""Retourne la configuration de performance"""
|
||||
return CONFIG["performance"]
|
||||
|
||||
|
||||
def get_security_config():
|
||||
"""Retourne la configuration de sécurité"""
|
||||
return CONFIG["security"]
|
||||
|
||||
|
||||
def get_data_paths():
|
||||
"""Retourne les chemins des données"""
|
||||
return CONFIG["data_paths"]
|
||||
|
||||
|
||||
def ensure_directories():
|
||||
"""Crée tous les répertoires nécessaires s'ils n'existent pas"""
|
||||
paths = get_data_paths()
|
||||
for path in paths.values():
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
# Créer aussi les répertoires de modèles
|
||||
model_paths = CONFIG["models"]["paths"]
|
||||
for path in model_paths.values():
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test de la configuration
|
||||
print("Configuration RPA Vision V2")
|
||||
print("=" * 50)
|
||||
print(f"Répertoire racine: {PROJECT_ROOT}")
|
||||
print(f"\nModèles:")
|
||||
for key, value in CONFIG["models"].items():
|
||||
if key != "paths":
|
||||
print(f" {key}: {value}")
|
||||
print(f"\nSeuils:")
|
||||
for key, value in CONFIG["thresholds"].items():
|
||||
print(f" {key}: {value}")
|
||||
print(f"\nPerformance:")
|
||||
for key, value in CONFIG["performance"].items():
|
||||
print(f" {key}: {value}")
|
||||
print(f"\nSécurité:")
|
||||
for key, value in CONFIG["security"].items():
|
||||
if key != "default_whitelist":
|
||||
print(f" {key}: {value}")
|
||||
19
geniusia2/core/embedders/__init__.py
Normal file
19
geniusia2/core/embedders/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Embedding system for visual similarity matching.
|
||||
|
||||
This module provides an abstraction layer for different embedding models
|
||||
(CLIP, Pix2Struct) used for workflow matching and visual analysis.
|
||||
"""
|
||||
|
||||
from .base import EmbedderBase
|
||||
from .clip_embedder import CLIPEmbedder
|
||||
from .faiss_index import FAISSIndex
|
||||
from .embedding_manager import EmbeddingManager
|
||||
from .fine_tuner import LightweightFineTuner
|
||||
|
||||
# Pix2Struct is optional (requires transformers>=4.35.0)
|
||||
try:
|
||||
from .pix2struct_embedder import Pix2StructEmbedder
|
||||
__all__ = ['EmbedderBase', 'CLIPEmbedder', 'Pix2StructEmbedder', 'FAISSIndex', 'EmbeddingManager', 'LightweightFineTuner']
|
||||
except ImportError:
|
||||
__all__ = ['EmbedderBase', 'CLIPEmbedder', 'FAISSIndex', 'EmbeddingManager', 'LightweightFineTuner']
|
||||
BIN
geniusia2/core/embedders/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
geniusia2/core/embedders/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/embedders/__pycache__/base.cpython-312.pyc
Normal file
BIN
geniusia2/core/embedders/__pycache__/base.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
geniusia2/core/embedders/__pycache__/faiss_index.cpython-312.pyc
Normal file
BIN
geniusia2/core/embedders/__pycache__/faiss_index.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/embedders/__pycache__/fine_tuner.cpython-312.pyc
Normal file
BIN
geniusia2/core/embedders/__pycache__/fine_tuner.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
100
geniusia2/core/embedders/base.py
Normal file
100
geniusia2/core/embedders/base.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
Abstract base class for embedding models.
|
||||
|
||||
This module defines the interface that all embedding models must implement,
|
||||
ensuring consistency across different model implementations (CLIP, Pix2Struct, etc.).
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class EmbedderBase(ABC):
|
||||
"""
|
||||
Abstract base class for image embedding models.
|
||||
|
||||
All embedding models must implement this interface to ensure
|
||||
compatibility with the workflow matching system.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, image: Image.Image) -> np.ndarray:
|
||||
"""
|
||||
Generate an embedding vector for a single image.
|
||||
|
||||
Args:
|
||||
image: PIL Image to embed
|
||||
|
||||
Returns:
|
||||
np.ndarray: Normalized embedding vector of shape (dimension,)
|
||||
The vector should be L2-normalized for cosine similarity
|
||||
|
||||
Raises:
|
||||
ValueError: If image is invalid or cannot be processed
|
||||
RuntimeError: If model inference fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dimension(self) -> int:
|
||||
"""
|
||||
Get the dimensionality of embeddings produced by this model.
|
||||
|
||||
Returns:
|
||||
int: Embedding dimension (e.g., 512 for CLIP ViT-B/32, 768 for Pix2Struct)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_name(self) -> str:
|
||||
"""
|
||||
Get a unique identifier for this model.
|
||||
|
||||
Returns:
|
||||
str: Model name (e.g., "clip-vit-b32", "pix2struct-base")
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def supports_batch(self) -> bool:
|
||||
"""
|
||||
Check if this model supports batch processing.
|
||||
|
||||
Returns:
|
||||
bool: True if embed_batch() is optimized, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings for multiple images.
|
||||
|
||||
Default implementation processes images one by one.
|
||||
Subclasses can override this for optimized batch processing.
|
||||
|
||||
Args:
|
||||
images: List of PIL Images to embed
|
||||
|
||||
Returns:
|
||||
np.ndarray: Array of embeddings with shape (len(images), dimension)
|
||||
Each row is a normalized embedding vector
|
||||
|
||||
Raises:
|
||||
ValueError: If any image is invalid
|
||||
RuntimeError: If model inference fails
|
||||
"""
|
||||
if not images:
|
||||
return np.array([]).reshape(0, self.get_dimension())
|
||||
|
||||
embeddings = []
|
||||
for img in images:
|
||||
embedding = self.embed(img)
|
||||
embeddings.append(embedding)
|
||||
|
||||
return np.array(embeddings)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of the embedder."""
|
||||
return f"{self.__class__.__name__}(model={self.get_model_name()}, dim={self.get_dimension()})"
|
||||
358
geniusia2/core/embedders/clip_embedder.py
Normal file
358
geniusia2/core/embedders/clip_embedder.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
CLIP-based embedder implementation.
|
||||
|
||||
This module provides a wrapper around OpenCLIP for generating image embeddings
|
||||
using the CLIP (Contrastive Language-Image Pre-training) model.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
|
||||
try:
|
||||
import open_clip
|
||||
except ImportError:
|
||||
open_clip = None
|
||||
|
||||
from .base import EmbedderBase
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CLIPEmbedder(EmbedderBase):
|
||||
"""
|
||||
CLIP-based image embedder using OpenCLIP.
|
||||
|
||||
This embedder uses the ViT-B/32 architecture by default, which produces
|
||||
512-dimensional embeddings. It automatically handles GPU/CPU device selection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "ViT-B-32",
|
||||
pretrained: str = "openai",
|
||||
device: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize the CLIP embedder.
|
||||
|
||||
Args:
|
||||
model_name: CLIP model architecture (default: ViT-B-32)
|
||||
pretrained: Pretrained weights to use (default: openai)
|
||||
device: Device to use ('cuda', 'cpu', or None for auto-detect)
|
||||
Note: Defaults to CPU to save GPU memory for other models
|
||||
|
||||
Raises:
|
||||
ImportError: If open_clip is not installed
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
if open_clip is None:
|
||||
raise ImportError(
|
||||
"OpenCLIP is not installed. "
|
||||
"Install it with: pip install open-clip-torch"
|
||||
)
|
||||
|
||||
# Default to CPU to save GPU for vision models (Qwen3-VL)
|
||||
if device is None:
|
||||
device = "cpu"
|
||||
|
||||
self.model_name = model_name
|
||||
self.pretrained = pretrained
|
||||
self.device = device
|
||||
self._embedding_dim = None
|
||||
|
||||
# Load model
|
||||
try:
|
||||
self.model, _, self.preprocess = open_clip.create_model_and_transforms(
|
||||
model_name,
|
||||
pretrained=pretrained,
|
||||
device=device
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
# Determine embedding dimension
|
||||
with torch.no_grad():
|
||||
dummy_image = torch.zeros(1, 3, 224, 224).to(self.device)
|
||||
dummy_embedding = self.model.encode_image(dummy_image)
|
||||
self._embedding_dim = dummy_embedding.shape[-1]
|
||||
|
||||
logger.info(
|
||||
f"CLIPEmbedder loaded: {model_name} on {device}, "
|
||||
f"dimension={self._embedding_dim}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load CLIP model: {e}")
|
||||
|
||||
def embed(self, image: Image.Image) -> np.ndarray:
|
||||
"""
|
||||
Generate embedding for a single image.
|
||||
|
||||
Args:
|
||||
image: PIL Image to embed
|
||||
|
||||
Returns:
|
||||
np.ndarray: Normalized embedding vector of shape (dimension,)
|
||||
|
||||
Raises:
|
||||
ValueError: If image is invalid
|
||||
RuntimeError: If embedding generation fails
|
||||
"""
|
||||
if not isinstance(image, Image.Image):
|
||||
raise ValueError("Input must be a PIL Image")
|
||||
|
||||
try:
|
||||
# Preprocess image
|
||||
image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
|
||||
|
||||
# Generate embedding
|
||||
with torch.no_grad():
|
||||
embedding = self.model.encode_image(image_tensor)
|
||||
# L2 normalize for cosine similarity
|
||||
embedding = embedding / embedding.norm(dim=-1, keepdim=True)
|
||||
|
||||
return embedding.cpu().numpy().flatten()
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to generate embedding: {e}")
|
||||
|
||||
def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings for multiple images (optimized batch processing).
|
||||
|
||||
Args:
|
||||
images: List of PIL Images to embed
|
||||
|
||||
Returns:
|
||||
np.ndarray: Array of embeddings with shape (len(images), dimension)
|
||||
|
||||
Raises:
|
||||
ValueError: If any image is invalid
|
||||
RuntimeError: If embedding generation fails
|
||||
"""
|
||||
if not images:
|
||||
return np.array([]).reshape(0, self.get_dimension())
|
||||
|
||||
# Validate all images
|
||||
for i, img in enumerate(images):
|
||||
if not isinstance(img, Image.Image):
|
||||
raise ValueError(f"Image at index {i} is not a PIL Image")
|
||||
|
||||
try:
|
||||
# Preprocess all images
|
||||
image_tensors = torch.stack([
|
||||
self.preprocess(img) for img in images
|
||||
]).to(self.device)
|
||||
|
||||
# Generate embeddings in batch
|
||||
with torch.no_grad():
|
||||
embeddings = self.model.encode_image(image_tensors)
|
||||
# L2 normalize for cosine similarity
|
||||
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
|
||||
|
||||
return embeddings.cpu().numpy()
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to generate batch embeddings: {e}")
|
||||
|
||||
def get_dimension(self) -> int:
|
||||
"""
|
||||
Get the dimensionality of embeddings.
|
||||
|
||||
Returns:
|
||||
int: Embedding dimension (512 for ViT-B/32)
|
||||
"""
|
||||
return self._embedding_dim
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
"""
|
||||
Get model identifier.
|
||||
|
||||
Returns:
|
||||
str: Model name (e.g., "clip-vit-b32")
|
||||
"""
|
||||
return f"clip-{self.model_name.lower().replace('/', '-')}"
|
||||
|
||||
def supports_batch(self) -> bool:
|
||||
"""
|
||||
Check if batch processing is supported.
|
||||
|
||||
Returns:
|
||||
bool: True (CLIP supports efficient batch processing)
|
||||
"""
|
||||
return True
|
||||
|
||||
def fine_tune(
|
||||
self,
|
||||
positive_images: List[Image.Image],
|
||||
negative_images: List[Image.Image],
|
||||
epochs: int = 1,
|
||||
learning_rate: float = 1e-4
|
||||
) -> dict:
|
||||
"""
|
||||
Fine-tune the model using contrastive learning.
|
||||
|
||||
This method fine-tunes only the final projection layer to adapt
|
||||
the model to user-specific workflows. It uses a simple contrastive
|
||||
loss: positive examples should be similar, negative examples should
|
||||
be dissimilar.
|
||||
|
||||
Args:
|
||||
positive_images: Images from successful workflows
|
||||
negative_images: Images from rejected workflows
|
||||
epochs: Number of training epochs (default: 1 for speed)
|
||||
learning_rate: Learning rate (default: 1e-4)
|
||||
|
||||
Returns:
|
||||
dict: Training metrics (loss, accuracy, etc.)
|
||||
"""
|
||||
if not positive_images and not negative_images:
|
||||
return {'loss': 0.0, 'note': 'No examples to train on'}
|
||||
|
||||
# Set model to training mode
|
||||
self.model.train()
|
||||
|
||||
# Only train the visual projection layer (last layer)
|
||||
# Freeze all other parameters
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# Unfreeze visual projection
|
||||
if hasattr(self.model, 'visual') and hasattr(self.model.visual, 'proj'):
|
||||
if self.model.visual.proj is not None:
|
||||
self.model.visual.proj.requires_grad = True
|
||||
|
||||
# Setup optimizer (only for trainable parameters)
|
||||
trainable_params = [p for p in self.model.parameters() if p.requires_grad]
|
||||
|
||||
if not trainable_params:
|
||||
logger.warning("No trainable parameters found, skipping fine-tuning")
|
||||
self.model.eval()
|
||||
return {'loss': 0.0, 'note': 'No trainable parameters'}
|
||||
|
||||
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)
|
||||
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
try:
|
||||
for epoch in range(epochs):
|
||||
# Process positive examples (should be similar to each other)
|
||||
if len(positive_images) >= 2:
|
||||
pos_loss = self._contrastive_loss_positive(positive_images)
|
||||
total_loss += pos_loss
|
||||
num_batches += 1
|
||||
|
||||
optimizer.zero_grad()
|
||||
pos_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Process negative examples (should be dissimilar from positives)
|
||||
if positive_images and negative_images:
|
||||
neg_loss = self._contrastive_loss_negative(
|
||||
positive_images[:5], # Use subset for speed
|
||||
negative_images[:5]
|
||||
)
|
||||
total_loss += neg_loss
|
||||
num_batches += 1
|
||||
|
||||
optimizer.zero_grad()
|
||||
neg_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
finally:
|
||||
# Always restore eval mode and freeze parameters
|
||||
self.model.eval()
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
|
||||
|
||||
return {
|
||||
'loss': float(avg_loss),
|
||||
'epochs': epochs,
|
||||
'learning_rate': learning_rate,
|
||||
'positive_count': len(positive_images),
|
||||
'negative_count': len(negative_images),
|
||||
'num_batches': num_batches
|
||||
}
|
||||
|
||||
def _contrastive_loss_positive(self, images: List[Image.Image]) -> torch.Tensor:
|
||||
"""
|
||||
Contrastive loss for positive examples (should be similar).
|
||||
|
||||
Args:
|
||||
images: List of positive example images
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Loss value
|
||||
"""
|
||||
# Generate embeddings
|
||||
embeddings = []
|
||||
for img in images:
|
||||
img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
|
||||
emb = self.model.encode_image(img_tensor)
|
||||
emb = emb / emb.norm(dim=-1, keepdim=True)
|
||||
embeddings.append(emb)
|
||||
|
||||
embeddings = torch.cat(embeddings, dim=0)
|
||||
|
||||
# Compute pairwise cosine similarities
|
||||
similarities = torch.mm(embeddings, embeddings.t())
|
||||
|
||||
# Loss: maximize similarity (minimize negative similarity)
|
||||
# Exclude diagonal (self-similarity)
|
||||
mask = torch.eye(len(images), device=self.device).bool()
|
||||
similarities = similarities.masked_fill(mask, 0)
|
||||
|
||||
# We want high similarity, so minimize (1 - similarity)
|
||||
loss = (1 - similarities).mean()
|
||||
|
||||
return loss
|
||||
|
||||
def _contrastive_loss_negative(
|
||||
self,
|
||||
positive_images: List[Image.Image],
|
||||
negative_images: List[Image.Image]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Contrastive loss for negative examples (should be dissimilar).
|
||||
|
||||
Args:
|
||||
positive_images: Positive example images
|
||||
negative_images: Negative example images
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Loss value
|
||||
"""
|
||||
# Generate embeddings for positives
|
||||
pos_embeddings = []
|
||||
for img in positive_images:
|
||||
img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
|
||||
emb = self.model.encode_image(img_tensor)
|
||||
emb = emb / emb.norm(dim=-1, keepdim=True)
|
||||
pos_embeddings.append(emb)
|
||||
|
||||
pos_embeddings = torch.cat(pos_embeddings, dim=0)
|
||||
|
||||
# Generate embeddings for negatives
|
||||
neg_embeddings = []
|
||||
for img in negative_images:
|
||||
img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
|
||||
emb = self.model.encode_image(img_tensor)
|
||||
emb = emb / emb.norm(dim=-1, keepdim=True)
|
||||
neg_embeddings.append(emb)
|
||||
|
||||
neg_embeddings = torch.cat(neg_embeddings, dim=0)
|
||||
|
||||
# Compute cross-similarities (positive vs negative)
|
||||
similarities = torch.mm(pos_embeddings, neg_embeddings.t())
|
||||
|
||||
# We want low similarity, so minimize similarity directly
|
||||
# (or maximize dissimilarity)
|
||||
loss = similarities.mean()
|
||||
|
||||
return loss
|
||||
309
geniusia2/core/embedders/embedding_manager.py
Normal file
309
geniusia2/core/embedders/embedding_manager.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
Embedding manager with model selection, caching, and fallback.
|
||||
|
||||
This module provides a high-level interface for generating embeddings,
|
||||
with automatic model selection, LRU caching, and fallback to CLIP if
|
||||
the selected model fails to load.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from collections import OrderedDict
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from .base import EmbedderBase
|
||||
from .clip_embedder import CLIPEmbedder
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingManager:
|
||||
"""
|
||||
High-level manager for image embeddings.
|
||||
|
||||
Features:
|
||||
- Model selection (CLIP, Pix2Struct, etc.)
|
||||
- Automatic fallback to CLIP on errors
|
||||
- LRU cache (1000 entries) for performance
|
||||
- GPU/CPU management
|
||||
- Logging and monitoring
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "clip",
|
||||
fallback_enabled: bool = True,
|
||||
cache_size: int = 1000,
|
||||
device: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize the embedding manager.
|
||||
|
||||
Args:
|
||||
model_name: Model to use ("clip" or "pix2struct")
|
||||
fallback_enabled: If True, fallback to CLIP on model load failure
|
||||
cache_size: Maximum number of cached embeddings (LRU eviction)
|
||||
device: Device to use ('cuda', 'cpu', or None for auto)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model loading fails and fallback is disabled
|
||||
"""
|
||||
self.model_name = model_name.lower()
|
||||
self.fallback_enabled = fallback_enabled
|
||||
self.cache_size = cache_size
|
||||
self.device = device
|
||||
|
||||
# Initialize embedder
|
||||
self.embedder = self._load_embedder()
|
||||
|
||||
# Initialize LRU cache
|
||||
self._cache: OrderedDict[str, np.ndarray] = OrderedDict()
|
||||
|
||||
# Statistics
|
||||
self._cache_hits = 0
|
||||
self._cache_misses = 0
|
||||
|
||||
logger.info(
|
||||
f"EmbeddingManager initialized: model={self.embedder.get_model_name()}, "
|
||||
f"dimension={self.embedder.get_dimension()}, "
|
||||
f"cache_size={cache_size}"
|
||||
)
|
||||
|
||||
def _load_embedder(self) -> EmbedderBase:
|
||||
"""
|
||||
Load the specified embedder with fallback support.
|
||||
|
||||
Returns:
|
||||
EmbedderBase: Loaded embedder instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If loading fails and fallback is disabled
|
||||
"""
|
||||
try:
|
||||
if self.model_name == "clip":
|
||||
return CLIPEmbedder(device=self.device)
|
||||
|
||||
elif self.model_name == "pix2struct":
|
||||
# Import here to avoid dependency if not used
|
||||
try:
|
||||
from .pix2struct_embedder import Pix2StructEmbedder
|
||||
return Pix2StructEmbedder(device=self.device)
|
||||
except ImportError as e:
|
||||
if self.fallback_enabled:
|
||||
logger.warning(
|
||||
f"Pix2Struct not available ({e}), falling back to CLIP"
|
||||
)
|
||||
return CLIPEmbedder(device=self.device)
|
||||
raise
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown model: {self.model_name}")
|
||||
|
||||
except Exception as e:
|
||||
if self.fallback_enabled:
|
||||
logger.warning(
|
||||
f"Failed to load {self.model_name} ({e}), falling back to CLIP"
|
||||
)
|
||||
return CLIPEmbedder(device=self.device)
|
||||
raise RuntimeError(f"Failed to load embedder: {e}")
|
||||
|
||||
def embed(self, image: Image.Image, use_cache: bool = True) -> np.ndarray:
|
||||
"""
|
||||
Generate embedding for an image with caching.
|
||||
|
||||
Args:
|
||||
image: PIL Image to embed
|
||||
use_cache: If True, use cache for identical images
|
||||
|
||||
Returns:
|
||||
np.ndarray: Normalized embedding vector
|
||||
|
||||
Raises:
|
||||
ValueError: If image is invalid
|
||||
RuntimeError: If embedding generation fails
|
||||
"""
|
||||
if not isinstance(image, Image.Image):
|
||||
raise ValueError("Input must be a PIL Image")
|
||||
|
||||
# Check cache if enabled
|
||||
if use_cache:
|
||||
cache_key = self._get_cache_key(image)
|
||||
|
||||
if cache_key in self._cache:
|
||||
# Move to end (most recently used)
|
||||
self._cache.move_to_end(cache_key)
|
||||
self._cache_hits += 1
|
||||
logger.debug(f"Cache hit (total: {self._cache_hits})")
|
||||
return self._cache[cache_key]
|
||||
|
||||
self._cache_misses += 1
|
||||
|
||||
# Generate embedding
|
||||
embedding = self.embedder.embed(image)
|
||||
|
||||
# Store in cache
|
||||
if use_cache:
|
||||
self._add_to_cache(cache_key, embedding)
|
||||
|
||||
return embedding
|
||||
|
||||
def embed_batch(
|
||||
self,
|
||||
images: list[Image.Image],
|
||||
use_cache: bool = True
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings for multiple images.
|
||||
|
||||
This method checks cache for each image individually and only
|
||||
generates embeddings for cache misses.
|
||||
|
||||
Args:
|
||||
images: List of PIL Images to embed
|
||||
use_cache: If True, use cache for identical images
|
||||
|
||||
Returns:
|
||||
np.ndarray: Array of embeddings (len(images), dimension)
|
||||
|
||||
Raises:
|
||||
ValueError: If any image is invalid
|
||||
RuntimeError: If embedding generation fails
|
||||
"""
|
||||
if not images:
|
||||
return np.array([]).reshape(0, self.get_dimension())
|
||||
|
||||
embeddings = []
|
||||
images_to_embed = []
|
||||
indices_to_embed = []
|
||||
|
||||
# Check cache for each image
|
||||
for i, img in enumerate(images):
|
||||
if not isinstance(img, Image.Image):
|
||||
raise ValueError(f"Image at index {i} is not a PIL Image")
|
||||
|
||||
if use_cache:
|
||||
cache_key = self._get_cache_key(img)
|
||||
if cache_key in self._cache:
|
||||
self._cache.move_to_end(cache_key)
|
||||
self._cache_hits += 1
|
||||
embeddings.append((i, self._cache[cache_key]))
|
||||
continue
|
||||
|
||||
self._cache_misses += 1
|
||||
|
||||
# Need to generate embedding
|
||||
images_to_embed.append(img)
|
||||
indices_to_embed.append(i)
|
||||
|
||||
# Generate embeddings for cache misses
|
||||
if images_to_embed:
|
||||
if self.embedder.supports_batch():
|
||||
new_embeddings = self.embedder.embed_batch(images_to_embed)
|
||||
else:
|
||||
new_embeddings = np.array([
|
||||
self.embedder.embed(img) for img in images_to_embed
|
||||
])
|
||||
|
||||
# Add to cache and results
|
||||
for img, idx, emb in zip(images_to_embed, indices_to_embed, new_embeddings):
|
||||
if use_cache:
|
||||
cache_key = self._get_cache_key(img)
|
||||
self._add_to_cache(cache_key, emb)
|
||||
embeddings.append((idx, emb))
|
||||
|
||||
# Sort by original index and extract embeddings
|
||||
embeddings.sort(key=lambda x: x[0])
|
||||
return np.array([emb for _, emb in embeddings])
|
||||
|
||||
def _get_cache_key(self, image: Image.Image) -> str:
|
||||
"""
|
||||
Generate cache key from image content.
|
||||
|
||||
Uses MD5 hash of image bytes for fast lookup.
|
||||
|
||||
Args:
|
||||
image: PIL Image
|
||||
|
||||
Returns:
|
||||
str: Cache key (MD5 hash)
|
||||
"""
|
||||
return hashlib.md5(image.tobytes()).hexdigest()
|
||||
|
||||
def _add_to_cache(self, key: str, embedding: np.ndarray):
|
||||
"""
|
||||
Add embedding to cache with LRU eviction.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
embedding: Embedding to cache
|
||||
"""
|
||||
# Add to cache
|
||||
self._cache[key] = embedding
|
||||
|
||||
# Evict oldest if cache is full
|
||||
if len(self._cache) > self.cache_size:
|
||||
oldest_key = next(iter(self._cache))
|
||||
del self._cache[oldest_key]
|
||||
logger.debug(f"Cache eviction: size={len(self._cache)}")
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear all cached embeddings."""
|
||||
self._cache.clear()
|
||||
logger.info("Cache cleared")
|
||||
|
||||
def get_dimension(self) -> int:
|
||||
"""
|
||||
Get embedding dimension.
|
||||
|
||||
Returns:
|
||||
int: Embedding dimension
|
||||
"""
|
||||
return self.embedder.get_dimension()
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
"""
|
||||
Get current model name.
|
||||
|
||||
Returns:
|
||||
str: Model identifier
|
||||
"""
|
||||
return self.embedder.get_model_name()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get manager statistics.
|
||||
|
||||
Returns:
|
||||
Dict with keys:
|
||||
- model_name: Current model
|
||||
- dimension: Embedding dimension
|
||||
- cache_size: Current cache size
|
||||
- cache_capacity: Maximum cache size
|
||||
- cache_hits: Number of cache hits
|
||||
- cache_misses: Number of cache misses
|
||||
- cache_hit_rate: Hit rate (0-1)
|
||||
"""
|
||||
total_requests = self._cache_hits + self._cache_misses
|
||||
hit_rate = self._cache_hits / total_requests if total_requests > 0 else 0.0
|
||||
|
||||
return {
|
||||
'model_name': self.get_model_name(),
|
||||
'dimension': self.get_dimension(),
|
||||
'cache_size': len(self._cache),
|
||||
'cache_capacity': self.cache_size,
|
||||
'cache_hits': self._cache_hits,
|
||||
'cache_misses': self._cache_misses,
|
||||
'cache_hit_rate': hit_rate
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation."""
|
||||
stats = self.get_stats()
|
||||
return (
|
||||
f"EmbeddingManager(model={stats['model_name']}, "
|
||||
f"cache={stats['cache_size']}/{stats['cache_capacity']}, "
|
||||
f"hit_rate={stats['cache_hit_rate']:.2%})"
|
||||
)
|
||||
309
geniusia2/core/embedders/faiss_index.py
Normal file
309
geniusia2/core/embedders/faiss_index.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
FAISS index wrapper with proper dimension handling and persistence.
|
||||
|
||||
This module provides a robust wrapper around FAISS for storing and searching
|
||||
image embeddings, with proper error handling for dimension mismatches and
|
||||
reliable save/load functionality.
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import faiss
|
||||
except ImportError:
|
||||
faiss = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FAISSIndex:
|
||||
"""
|
||||
Wrapper around FAISS index with metadata storage and dimension validation.
|
||||
|
||||
This class handles:
|
||||
- Dimension validation on add/search operations
|
||||
- Metadata storage alongside embeddings
|
||||
- Reliable persistence (save/load)
|
||||
- Automatic index rebuilding on dimension changes
|
||||
"""
|
||||
|
||||
def __init__(self, dimension: int):
|
||||
"""
|
||||
Initialize a new FAISS index.
|
||||
|
||||
Args:
|
||||
dimension: Embedding dimension (e.g., 512 for CLIP, 768 for Pix2Struct)
|
||||
|
||||
Raises:
|
||||
ImportError: If FAISS is not installed
|
||||
ValueError: If dimension is invalid
|
||||
"""
|
||||
if faiss is None:
|
||||
raise ImportError(
|
||||
"FAISS is not installed. "
|
||||
"Install it with: pip install faiss-cpu or faiss-gpu"
|
||||
)
|
||||
|
||||
if dimension <= 0:
|
||||
raise ValueError(f"Dimension must be positive, got {dimension}")
|
||||
|
||||
self.dimension = dimension
|
||||
self.index = faiss.IndexFlatL2(dimension)
|
||||
self.metadata: List[Dict[str, Any]] = []
|
||||
|
||||
logger.info(f"FAISSIndex created with dimension={dimension}")
|
||||
|
||||
def add(self, embeddings: np.ndarray, metadata: List[Dict[str, Any]]):
|
||||
"""
|
||||
Add embeddings to the index with associated metadata.
|
||||
|
||||
Args:
|
||||
embeddings: Array of shape (N, dimension) containing N embeddings
|
||||
metadata: List of N metadata dictionaries
|
||||
|
||||
Raises:
|
||||
ValueError: If dimensions don't match or array shapes are invalid
|
||||
"""
|
||||
# Validate input shape
|
||||
if embeddings.ndim == 1:
|
||||
# Single embedding, reshape to (1, dimension)
|
||||
embeddings = embeddings.reshape(1, -1)
|
||||
elif embeddings.ndim != 2:
|
||||
raise ValueError(
|
||||
f"Embeddings must be 1D or 2D array, got shape {embeddings.shape}"
|
||||
)
|
||||
|
||||
# Validate dimension
|
||||
if embeddings.shape[1] != self.dimension:
|
||||
raise ValueError(
|
||||
f"Embedding dimension {embeddings.shape[1]} doesn't match "
|
||||
f"index dimension {self.dimension}"
|
||||
)
|
||||
|
||||
# Validate metadata count
|
||||
if len(metadata) != embeddings.shape[0]:
|
||||
raise ValueError(
|
||||
f"Number of metadata entries ({len(metadata)}) doesn't match "
|
||||
f"number of embeddings ({embeddings.shape[0]})"
|
||||
)
|
||||
|
||||
# Add to FAISS index
|
||||
self.index.add(embeddings.astype('float32'))
|
||||
|
||||
# Store metadata
|
||||
self.metadata.extend(metadata)
|
||||
|
||||
logger.debug(
|
||||
f"Added {embeddings.shape[0]} embeddings to index "
|
||||
f"(total: {self.index.ntotal})"
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: np.ndarray,
|
||||
k: int = 5
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search for the k most similar embeddings.
|
||||
|
||||
Args:
|
||||
query: Query embedding of shape (dimension,) or (1, dimension)
|
||||
k: Number of results to return
|
||||
|
||||
Returns:
|
||||
List of dicts with keys:
|
||||
- 'index': Index in the FAISS index
|
||||
- 'distance': L2 distance
|
||||
- 'similarity': Similarity score (1 / (1 + distance))
|
||||
- 'metadata': Associated metadata dict
|
||||
|
||||
Raises:
|
||||
ValueError: If query dimension doesn't match index dimension
|
||||
"""
|
||||
if self.index.ntotal == 0:
|
||||
logger.warning("Search called on empty index")
|
||||
return []
|
||||
|
||||
# Reshape query if needed
|
||||
if query.ndim == 1:
|
||||
query = query.reshape(1, -1)
|
||||
elif query.ndim != 2:
|
||||
raise ValueError(
|
||||
f"Query must be 1D or 2D array, got shape {query.shape}"
|
||||
)
|
||||
|
||||
# Validate dimension
|
||||
if query.shape[1] != self.dimension:
|
||||
raise ValueError(
|
||||
f"Query dimension {query.shape[1]} doesn't match "
|
||||
f"index dimension {self.dimension}"
|
||||
)
|
||||
|
||||
# Limit k to available embeddings
|
||||
k = min(k, self.index.ntotal)
|
||||
|
||||
# Search
|
||||
distances, indices = self.index.search(query.astype('float32'), k)
|
||||
|
||||
# Format results
|
||||
results = []
|
||||
for dist, idx in zip(distances[0], indices[0]):
|
||||
# FAISS returns -1 if not enough results
|
||||
if idx >= 0 and idx < len(self.metadata):
|
||||
results.append({
|
||||
'index': int(idx),
|
||||
'distance': float(dist),
|
||||
'similarity': float(1.0 / (1.0 + dist)),
|
||||
'metadata': self.metadata[idx]
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
def save(self, path: str):
|
||||
"""
|
||||
Save index and metadata to disk.
|
||||
|
||||
Args:
|
||||
path: Base path for saving (will create .index and .metadata files)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If save operation fails
|
||||
"""
|
||||
try:
|
||||
path_obj = Path(path)
|
||||
path_obj.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save FAISS index
|
||||
index_file = f"{path}.index"
|
||||
faiss.write_index(self.index, index_file)
|
||||
|
||||
# Save metadata
|
||||
metadata_file = f"{path}.metadata"
|
||||
with open(metadata_file, 'wb') as f:
|
||||
pickle.dump({
|
||||
'dimension': self.dimension,
|
||||
'metadata': self.metadata
|
||||
}, f)
|
||||
|
||||
logger.info(
|
||||
f"Saved index with {self.index.ntotal} embeddings to {path}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to save index: {e}")
|
||||
|
||||
def load(self, path: str):
|
||||
"""
|
||||
Load index and metadata from disk.
|
||||
|
||||
Args:
|
||||
path: Base path for loading (will read .index and .metadata files)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If files don't exist
|
||||
RuntimeError: If load operation fails or dimension mismatch
|
||||
"""
|
||||
try:
|
||||
index_file = f"{path}.index"
|
||||
metadata_file = f"{path}.metadata"
|
||||
|
||||
# Check files exist
|
||||
if not Path(index_file).exists():
|
||||
raise FileNotFoundError(f"Index file not found: {index_file}")
|
||||
if not Path(metadata_file).exists():
|
||||
raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
|
||||
|
||||
# Load FAISS index
|
||||
loaded_index = faiss.read_index(index_file)
|
||||
|
||||
# Load metadata
|
||||
with open(metadata_file, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
loaded_dimension = data['dimension']
|
||||
loaded_metadata = data['metadata']
|
||||
|
||||
# Validate dimension
|
||||
if loaded_dimension != self.dimension:
|
||||
raise RuntimeError(
|
||||
f"Loaded index dimension ({loaded_dimension}) doesn't match "
|
||||
f"current dimension ({self.dimension}). "
|
||||
f"Use rebuild_if_needed() to handle dimension changes."
|
||||
)
|
||||
|
||||
# Update state
|
||||
self.index = loaded_index
|
||||
self.metadata = loaded_metadata
|
||||
|
||||
logger.info(
|
||||
f"Loaded index with {self.index.ntotal} embeddings from {path}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, (FileNotFoundError, RuntimeError)):
|
||||
raise
|
||||
raise RuntimeError(f"Failed to load index: {e}")
|
||||
|
||||
def rebuild_if_needed(self, new_dimension: int) -> bool:
|
||||
"""
|
||||
Rebuild index if dimension has changed.
|
||||
|
||||
This creates a new empty index with the new dimension.
|
||||
Old embeddings are lost and need to be regenerated.
|
||||
|
||||
Args:
|
||||
new_dimension: New embedding dimension
|
||||
|
||||
Returns:
|
||||
bool: True if index was rebuilt, False if dimension unchanged
|
||||
"""
|
||||
if new_dimension == self.dimension:
|
||||
return False
|
||||
|
||||
logger.warning(
|
||||
f"Rebuilding FAISS index: dimension changed from "
|
||||
f"{self.dimension} to {new_dimension}. "
|
||||
f"Old embeddings ({self.index.ntotal}) will be lost."
|
||||
)
|
||||
|
||||
# Create new index
|
||||
self.dimension = new_dimension
|
||||
self.index = faiss.IndexFlatL2(new_dimension)
|
||||
self.metadata = []
|
||||
|
||||
return True
|
||||
|
||||
def clear(self):
|
||||
"""Clear all embeddings from the index."""
|
||||
self.index = faiss.IndexFlatL2(self.dimension)
|
||||
self.metadata = []
|
||||
logger.info("Index cleared")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get index statistics.
|
||||
|
||||
Returns:
|
||||
Dict with keys: num_embeddings, dimension, is_trained
|
||||
"""
|
||||
return {
|
||||
'num_embeddings': self.index.ntotal,
|
||||
'dimension': self.dimension,
|
||||
'is_trained': self.index.is_trained
|
||||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return number of embeddings in the index."""
|
||||
return self.index.ntotal
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of the index."""
|
||||
return (
|
||||
f"FAISSIndex(dimension={self.dimension}, "
|
||||
f"num_embeddings={self.index.ntotal})"
|
||||
)
|
||||
321
geniusia2/core/embedders/fine_tuner.py
Normal file
321
geniusia2/core/embedders/fine_tuner.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
Lightweight fine-tuner for embedding models.
|
||||
|
||||
This module provides incremental fine-tuning capabilities that run in the
|
||||
background, adapting the embedding model to user-specific workflows over time.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import pickle
|
||||
import logging
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LightweightFineTuner:
|
||||
"""
|
||||
Lightweight fine-tuner for incremental model adaptation.
|
||||
|
||||
This class collects positive and negative examples from user interactions
|
||||
and periodically fine-tunes the embedding model to improve accuracy on
|
||||
user-specific workflows.
|
||||
|
||||
Features:
|
||||
- Automatic triggering after N examples
|
||||
- Background training (non-blocking)
|
||||
- Checkpoint save/load for recovery
|
||||
- Metrics tracking
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder,
|
||||
trigger_threshold: int = 10,
|
||||
max_examples: int = 1000,
|
||||
checkpoint_dir: str = "data/fine_tuning"
|
||||
):
|
||||
"""
|
||||
Initialize the fine-tuner.
|
||||
|
||||
Args:
|
||||
embedder: Embedder instance to fine-tune (must support fine_tune method)
|
||||
trigger_threshold: Number of new examples before triggering fine-tuning
|
||||
max_examples: Maximum examples to keep (LRU eviction)
|
||||
checkpoint_dir: Directory for saving checkpoints
|
||||
"""
|
||||
self.embedder = embedder
|
||||
self.trigger_threshold = trigger_threshold
|
||||
self.max_examples = max_examples
|
||||
self.checkpoint_dir = Path(checkpoint_dir)
|
||||
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Example storage (deque for automatic LRU)
|
||||
self.positive_examples = deque(maxlen=max_examples)
|
||||
self.negative_examples = deque(maxlen=max_examples)
|
||||
|
||||
# Training state
|
||||
self.is_training = False
|
||||
self.training_thread: Optional[threading.Thread] = None
|
||||
self.last_training_time = 0
|
||||
self.training_count = 0
|
||||
|
||||
# Metrics
|
||||
self.metrics_history: List[Dict[str, Any]] = []
|
||||
|
||||
logger.info(
|
||||
f"LightweightFineTuner initialized: "
|
||||
f"trigger={trigger_threshold}, max_examples={max_examples}"
|
||||
)
|
||||
|
||||
def add_positive_example(self, image: Image.Image, workflow_id: str, metadata: Optional[Dict] = None):
|
||||
"""
|
||||
Add a positive example (successful workflow execution).
|
||||
|
||||
Args:
|
||||
image: Screenshot where workflow succeeded
|
||||
workflow_id: ID of the successful workflow
|
||||
metadata: Optional additional metadata
|
||||
"""
|
||||
example = {
|
||||
'image': image,
|
||||
'workflow_id': workflow_id,
|
||||
'metadata': metadata or {},
|
||||
'timestamp': time.time()
|
||||
}
|
||||
|
||||
self.positive_examples.append(example)
|
||||
|
||||
logger.debug(
|
||||
f"Added positive example: workflow={workflow_id}, "
|
||||
f"total_positive={len(self.positive_examples)}"
|
||||
)
|
||||
|
||||
self._check_trigger()
|
||||
|
||||
def add_negative_example(self, image: Image.Image, workflow_id: str, metadata: Optional[Dict] = None):
|
||||
"""
|
||||
Add a negative example (rejected workflow suggestion).
|
||||
|
||||
Args:
|
||||
image: Screenshot where workflow was rejected
|
||||
workflow_id: ID of the rejected workflow
|
||||
metadata: Optional additional metadata
|
||||
"""
|
||||
example = {
|
||||
'image': image,
|
||||
'workflow_id': workflow_id,
|
||||
'metadata': metadata or {},
|
||||
'timestamp': time.time()
|
||||
}
|
||||
|
||||
self.negative_examples.append(example)
|
||||
|
||||
logger.debug(
|
||||
f"Added negative example: workflow={workflow_id}, "
|
||||
f"total_negative={len(self.negative_examples)}"
|
||||
)
|
||||
|
||||
self._check_trigger()
|
||||
|
||||
def _check_trigger(self):
|
||||
"""Check if we should trigger fine-tuning."""
|
||||
total_new = len(self.positive_examples) + len(self.negative_examples)
|
||||
|
||||
# Don't trigger if already training
|
||||
if self.is_training:
|
||||
logger.debug("Fine-tuning already in progress, skipping trigger check")
|
||||
return
|
||||
|
||||
# Check if we have enough examples
|
||||
if total_new >= self.trigger_threshold:
|
||||
logger.info(
|
||||
f"Fine-tuning triggered: {total_new} examples "
|
||||
f"({len(self.positive_examples)} positive, "
|
||||
f"{len(self.negative_examples)} negative)"
|
||||
)
|
||||
self._start_training()
|
||||
|
||||
def _start_training(self):
|
||||
"""Start training in background thread."""
|
||||
if self.is_training:
|
||||
logger.warning("Training already in progress")
|
||||
return
|
||||
|
||||
self.training_thread = threading.Thread(
|
||||
target=self._train,
|
||||
name="FineTuningThread",
|
||||
daemon=True
|
||||
)
|
||||
self.training_thread.start()
|
||||
logger.info("Fine-tuning thread started")
|
||||
|
||||
def _train(self):
|
||||
"""Fine-tune the model (runs in background thread)."""
|
||||
self.is_training = True
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Check if embedder supports fine-tuning
|
||||
if not hasattr(self.embedder, 'fine_tune'):
|
||||
logger.info(
|
||||
f"Embedder {self.embedder.get_model_name()} doesn't support "
|
||||
f"fine-tuning, skipping"
|
||||
)
|
||||
return
|
||||
|
||||
# Prepare training data
|
||||
positive_images = [ex['image'] for ex in self.positive_examples]
|
||||
negative_images = [ex['image'] for ex in self.negative_examples]
|
||||
|
||||
if not positive_images and not negative_images:
|
||||
logger.warning("No examples to train on")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Starting fine-tuning: {len(positive_images)} positive, "
|
||||
f"{len(negative_images)} negative examples"
|
||||
)
|
||||
|
||||
# Fine-tune (implementation depends on embedder)
|
||||
metrics = self.embedder.fine_tune(
|
||||
positive_images=positive_images,
|
||||
negative_images=negative_images,
|
||||
epochs=1,
|
||||
learning_rate=1e-4
|
||||
)
|
||||
|
||||
# Record metrics
|
||||
duration = time.time() - start_time
|
||||
metrics['duration_seconds'] = duration
|
||||
metrics['timestamp'] = time.time()
|
||||
metrics['positive_count'] = len(positive_images)
|
||||
metrics['negative_count'] = len(negative_images)
|
||||
metrics['training_number'] = self.training_count
|
||||
|
||||
self.metrics_history.append(metrics)
|
||||
self.last_training_time = time.time()
|
||||
self.training_count += 1
|
||||
|
||||
logger.info(
|
||||
f"Fine-tuning complete #{self.training_count}: "
|
||||
f"loss={metrics.get('loss', 'N/A'):.4f}, "
|
||||
f"duration={duration:.1f}s"
|
||||
)
|
||||
|
||||
# Clear examples after successful training
|
||||
self.positive_examples.clear()
|
||||
self.negative_examples.clear()
|
||||
logger.debug("Training examples cleared")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Fine-tuning failed: {e}", exc_info=True)
|
||||
|
||||
finally:
|
||||
self.is_training = False
|
||||
|
||||
def save_checkpoint(self, name: str = "checkpoint"):
|
||||
"""
|
||||
Save training examples and metrics for recovery.
|
||||
|
||||
Args:
|
||||
name: Checkpoint name
|
||||
"""
|
||||
try:
|
||||
checkpoint_path = self.checkpoint_dir / f"{name}.pkl"
|
||||
|
||||
data = {
|
||||
'positive_examples': list(self.positive_examples),
|
||||
'negative_examples': list(self.negative_examples),
|
||||
'metrics_history': self.metrics_history,
|
||||
'training_count': self.training_count,
|
||||
'last_training_time': self.last_training_time
|
||||
}
|
||||
|
||||
with open(checkpoint_path, 'wb') as f:
|
||||
pickle.dump(data, f)
|
||||
|
||||
logger.info(f"Checkpoint saved: {checkpoint_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save checkpoint: {e}")
|
||||
|
||||
def load_checkpoint(self, name: str = "checkpoint"):
|
||||
"""
|
||||
Load training examples and metrics from checkpoint.
|
||||
|
||||
Args:
|
||||
name: Checkpoint name
|
||||
|
||||
Returns:
|
||||
bool: True if loaded successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
checkpoint_path = self.checkpoint_dir / f"{name}.pkl"
|
||||
|
||||
if not checkpoint_path.exists():
|
||||
logger.warning(f"Checkpoint not found: {checkpoint_path}")
|
||||
return False
|
||||
|
||||
with open(checkpoint_path, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
self.positive_examples.extend(data.get('positive_examples', []))
|
||||
self.negative_examples.extend(data.get('negative_examples', []))
|
||||
self.metrics_history = data.get('metrics_history', [])
|
||||
self.training_count = data.get('training_count', 0)
|
||||
self.last_training_time = data.get('last_training_time', 0)
|
||||
|
||||
logger.info(
|
||||
f"Checkpoint loaded: {len(self.positive_examples)} positive, "
|
||||
f"{len(self.negative_examples)} negative examples"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load checkpoint: {e}")
|
||||
return False
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get fine-tuning statistics.
|
||||
|
||||
Returns:
|
||||
Dict with statistics
|
||||
"""
|
||||
return {
|
||||
'positive_examples': len(self.positive_examples),
|
||||
'negative_examples': len(self.negative_examples),
|
||||
'total_examples': len(self.positive_examples) + len(self.negative_examples),
|
||||
'is_training': self.is_training,
|
||||
'training_count': self.training_count,
|
||||
'last_training_time': self.last_training_time,
|
||||
'metrics_history': self.metrics_history,
|
||||
'trigger_threshold': self.trigger_threshold
|
||||
}
|
||||
|
||||
def wait_for_training(self, timeout: Optional[float] = None):
|
||||
"""
|
||||
Wait for current training to complete.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in seconds (None = wait forever)
|
||||
"""
|
||||
if self.training_thread and self.training_thread.is_alive():
|
||||
self.training_thread.join(timeout=timeout)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation."""
|
||||
stats = self.get_stats()
|
||||
return (
|
||||
f"LightweightFineTuner("
|
||||
f"examples={stats['total_examples']}, "
|
||||
f"trainings={stats['training_count']}, "
|
||||
f"is_training={stats['is_training']})"
|
||||
)
|
||||
193
geniusia2/core/embedders/pix2struct_embedder.py
Normal file
193
geniusia2/core/embedders/pix2struct_embedder.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Pix2Struct-based embedder implementation.
|
||||
|
||||
This module provides a wrapper around Google's Pix2Struct model for generating
|
||||
image embeddings specialized for UI understanding and document screenshots.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
|
||||
try:
|
||||
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
|
||||
except ImportError:
|
||||
Pix2StructProcessor = None
|
||||
Pix2StructForConditionalGeneration = None
|
||||
|
||||
from .base import EmbedderBase
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Pix2StructEmbedder(EmbedderBase):
|
||||
"""
|
||||
Pix2Struct-based image embedder specialized for UI understanding.
|
||||
|
||||
Pix2Struct is a vision-language model trained on screenshots and structured
|
||||
documents, making it particularly well-suited for RPA and UI automation tasks.
|
||||
|
||||
This embedder uses the encoder's hidden states as embeddings, which capture
|
||||
visual features optimized for understanding UI elements and layouts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "google/pix2struct-base",
|
||||
device: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize the Pix2Struct embedder.
|
||||
|
||||
Args:
|
||||
model_name: Pix2Struct model to use (default: google/pix2struct-base)
|
||||
device: Device to use ('cuda', 'cpu', or None for auto-detect)
|
||||
|
||||
Raises:
|
||||
ImportError: If transformers is not installed
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
if Pix2StructProcessor is None or Pix2StructForConditionalGeneration is None:
|
||||
raise ImportError(
|
||||
"Transformers is not installed or version is too old. "
|
||||
"Install it with: pip install transformers>=4.35.0"
|
||||
)
|
||||
|
||||
# Auto-detect device
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self.model_name = model_name
|
||||
self.device = device
|
||||
self._embedding_dim = None
|
||||
|
||||
# Load model and processor
|
||||
try:
|
||||
logger.info(f"Loading Pix2Struct model: {model_name}")
|
||||
|
||||
self.processor = Pix2StructProcessor.from_pretrained(model_name)
|
||||
self.model = Pix2StructForConditionalGeneration.from_pretrained(
|
||||
model_name
|
||||
).to(device)
|
||||
self.model.eval()
|
||||
|
||||
# Determine embedding dimension from encoder
|
||||
with torch.no_grad():
|
||||
dummy_image = Image.new('RGB', (224, 224), color=(128, 128, 128))
|
||||
inputs = self.processor(images=dummy_image, return_tensors="pt").to(device)
|
||||
encoder_outputs = self.model.encoder(**inputs)
|
||||
# Use mean pooling of last hidden state
|
||||
self._embedding_dim = encoder_outputs.last_hidden_state.shape[-1]
|
||||
|
||||
logger.info(
|
||||
f"Pix2StructEmbedder loaded: {model_name} on {device}, "
|
||||
f"dimension={self._embedding_dim}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load Pix2Struct model: {e}")
|
||||
|
||||
def embed(self, image: Image.Image) -> np.ndarray:
|
||||
"""
|
||||
Generate embedding for a single image.
|
||||
|
||||
Args:
|
||||
image: PIL Image to embed
|
||||
|
||||
Returns:
|
||||
np.ndarray: Normalized embedding vector of shape (dimension,)
|
||||
|
||||
Raises:
|
||||
ValueError: If image is invalid
|
||||
RuntimeError: If embedding generation fails
|
||||
"""
|
||||
if not isinstance(image, Image.Image):
|
||||
raise ValueError("Input must be a PIL Image")
|
||||
|
||||
try:
|
||||
# Process image
|
||||
inputs = self.processor(images=image, return_tensors="pt").to(self.device)
|
||||
|
||||
# Generate embedding from encoder
|
||||
with torch.no_grad():
|
||||
encoder_outputs = self.model.encoder(**inputs)
|
||||
# Mean pooling over sequence dimension
|
||||
embedding = encoder_outputs.last_hidden_state.mean(dim=1)
|
||||
# L2 normalize for cosine similarity
|
||||
embedding = embedding / embedding.norm(dim=-1, keepdim=True)
|
||||
|
||||
return embedding.cpu().numpy().flatten()
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to generate embedding: {e}")
|
||||
|
||||
def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings for multiple images (optimized batch processing).
|
||||
|
||||
Args:
|
||||
images: List of PIL Images to embed
|
||||
|
||||
Returns:
|
||||
np.ndarray: Array of embeddings with shape (len(images), dimension)
|
||||
|
||||
Raises:
|
||||
ValueError: If any image is invalid
|
||||
RuntimeError: If embedding generation fails
|
||||
"""
|
||||
if not images:
|
||||
return np.array([]).reshape(0, self.get_dimension())
|
||||
|
||||
# Validate all images
|
||||
for i, img in enumerate(images):
|
||||
if not isinstance(img, Image.Image):
|
||||
raise ValueError(f"Image at index {i} is not a PIL Image")
|
||||
|
||||
try:
|
||||
# Process all images in batch
|
||||
inputs = self.processor(images=images, return_tensors="pt").to(self.device)
|
||||
|
||||
# Generate embeddings in batch
|
||||
with torch.no_grad():
|
||||
encoder_outputs = self.model.encoder(**inputs)
|
||||
# Mean pooling over sequence dimension
|
||||
embeddings = encoder_outputs.last_hidden_state.mean(dim=1)
|
||||
# L2 normalize for cosine similarity
|
||||
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
|
||||
|
||||
return embeddings.cpu().numpy()
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to generate batch embeddings: {e}")
|
||||
|
||||
def get_dimension(self) -> int:
|
||||
"""
|
||||
Get the dimensionality of embeddings.
|
||||
|
||||
Returns:
|
||||
int: Embedding dimension (768 for pix2struct-base)
|
||||
"""
|
||||
return self._embedding_dim
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
"""
|
||||
Get model identifier.
|
||||
|
||||
Returns:
|
||||
str: Model name (e.g., "pix2struct-base")
|
||||
"""
|
||||
# Extract model name from full path
|
||||
model_id = self.model_name.split('/')[-1]
|
||||
return f"pix2struct-{model_id}" if not model_id.startswith('pix2struct') else model_id
|
||||
|
||||
def supports_batch(self) -> bool:
|
||||
"""
|
||||
Check if batch processing is supported.
|
||||
|
||||
Returns:
|
||||
bool: True (Pix2Struct supports efficient batch processing)
|
||||
"""
|
||||
return True
|
||||
413
geniusia2/core/embeddings_manager.py
Normal file
413
geniusia2/core/embeddings_manager.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
Gestionnaire d'embeddings visuels avec OpenCLIP et FAISS.
|
||||
Gère l'encodage d'images, l'indexation et la recherche de similarité.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import pickle
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
import open_clip
|
||||
except ImportError:
|
||||
open_clip = None
|
||||
|
||||
try:
|
||||
import faiss
|
||||
except ImportError:
|
||||
faiss = None
|
||||
|
||||
from .logger import Logger
|
||||
|
||||
|
||||
class EmbeddingsManager:
|
||||
"""
|
||||
Gestionnaire d'embeddings visuels utilisant OpenCLIP pour l'encodage
|
||||
et FAISS pour l'indexation et la recherche de similarité.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "ViT-B-32",
|
||||
pretrained: str = "openai",
|
||||
index_path: str = "data/faiss_index",
|
||||
device: Optional[str] = None,
|
||||
logger: Optional[Logger] = None
|
||||
):
|
||||
"""
|
||||
Initialise le gestionnaire d'embeddings.
|
||||
|
||||
Args:
|
||||
model_name: Nom du modèle OpenCLIP
|
||||
pretrained: Dataset de pré-entraînement
|
||||
index_path: Chemin vers l'index FAISS
|
||||
device: Device PyTorch (cuda/cpu)
|
||||
logger: Instance du logger
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.pretrained = pretrained
|
||||
self.index_path = Path(index_path)
|
||||
# Forcer CPU pour OpenCLIP pour économiser la mémoire GPU (Qwen3-VL prioritaire)
|
||||
self.device = "cpu"
|
||||
self.logger = logger
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "embeddings_manager_init",
|
||||
"device": self.device,
|
||||
"reason": "CPU forcé pour économiser GPU pour Qwen3-VL"
|
||||
})
|
||||
|
||||
# Créer le répertoire d'index si nécessaire
|
||||
self.index_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Chemins des fichiers
|
||||
self.index_file = self.index_path / "embeddings.index"
|
||||
self.metadata_file = self.index_path / "metadata.pkl"
|
||||
|
||||
# Initialiser le modèle et l'index
|
||||
self.clip_model = None
|
||||
self.preprocess = None
|
||||
self.faiss_index = None
|
||||
self.metadata_store: Dict[int, Dict[str, Any]] = {}
|
||||
self.embedding_dim = 512 # Dimension par défaut pour ViT-B-32
|
||||
|
||||
self._load_model()
|
||||
self._load_or_create_index()
|
||||
|
||||
def _load_model(self):
|
||||
"""Charge le modèle OpenCLIP."""
|
||||
if open_clip is None:
|
||||
raise ImportError(
|
||||
"OpenCLIP n'est pas installé. "
|
||||
"Installez-le avec: pip install open-clip-torch"
|
||||
)
|
||||
|
||||
try:
|
||||
self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(
|
||||
self.model_name,
|
||||
pretrained=self.pretrained,
|
||||
device=self.device
|
||||
)
|
||||
self.clip_model.eval()
|
||||
|
||||
# Obtenir la dimension d'embedding réelle
|
||||
with torch.no_grad():
|
||||
dummy_image = torch.zeros(1, 3, 224, 224).to(self.device)
|
||||
dummy_embedding = self.clip_model.encode_image(dummy_image)
|
||||
self.embedding_dim = dummy_embedding.shape[-1]
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "model_loaded",
|
||||
"model": self.model_name,
|
||||
"device": self.device,
|
||||
"embedding_dim": self.embedding_dim
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Erreur lors du chargement du modèle OpenCLIP: {e}"
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "model_load_error",
|
||||
"error": str(e)
|
||||
})
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
def _load_or_create_index(self):
|
||||
"""Charge l'index FAISS existant ou en crée un nouveau."""
|
||||
if faiss is None:
|
||||
raise ImportError(
|
||||
"FAISS n'est pas installé. "
|
||||
"Installez-le avec: pip install faiss-cpu ou faiss-gpu"
|
||||
)
|
||||
|
||||
# Charger l'index existant
|
||||
if self.index_file.exists() and self.metadata_file.exists():
|
||||
try:
|
||||
self.faiss_index = faiss.read_index(str(self.index_file))
|
||||
with open(self.metadata_file, 'rb') as f:
|
||||
self.metadata_store = pickle.load(f)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "index_loaded",
|
||||
"num_vectors": self.faiss_index.ntotal,
|
||||
"path": str(self.index_file)
|
||||
})
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "index_load_error",
|
||||
"error": str(e)
|
||||
})
|
||||
# Continuer pour créer un nouvel index
|
||||
|
||||
# Créer un nouvel index
|
||||
self.faiss_index = faiss.IndexFlatL2(self.embedding_dim)
|
||||
self.metadata_store = {}
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "index_created",
|
||||
"embedding_dim": self.embedding_dim
|
||||
})
|
||||
|
||||
def encode_image(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Génère un embedding 512-d pour une image.
|
||||
|
||||
Args:
|
||||
image: Image numpy array (H, W, C) en BGR ou RGB
|
||||
|
||||
Returns:
|
||||
Embedding numpy array de forme (embedding_dim,)
|
||||
"""
|
||||
try:
|
||||
# Convertir BGR vers RGB si nécessaire
|
||||
if len(image.shape) == 3 and image.shape[2] == 3:
|
||||
# Supposer BGR (format OpenCV)
|
||||
image_rgb = image[:, :, ::-1]
|
||||
else:
|
||||
image_rgb = image
|
||||
|
||||
# Convertir en PIL Image
|
||||
pil_image = Image.fromarray(image_rgb.astype(np.uint8))
|
||||
|
||||
# Prétraiter l'image
|
||||
image_tensor = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
||||
|
||||
# Générer l'embedding
|
||||
with torch.no_grad():
|
||||
embedding = self.clip_model.encode_image(image_tensor)
|
||||
embedding = embedding / embedding.norm(dim=-1, keepdim=True) # Normaliser
|
||||
|
||||
return embedding.cpu().numpy().flatten()
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "encoding_error",
|
||||
"error": str(e)
|
||||
})
|
||||
raise RuntimeError(f"Erreur lors de l'encodage de l'image: {e}")
|
||||
|
||||
def add_to_index(self, embedding: np.ndarray, metadata: Dict[str, Any]) -> int:
|
||||
"""
|
||||
Ajoute un embedding à l'index FAISS avec ses métadonnées.
|
||||
|
||||
Args:
|
||||
embedding: Embedding numpy array
|
||||
metadata: Dictionnaire de métadonnées associées
|
||||
|
||||
Returns:
|
||||
ID de l'embedding dans l'index
|
||||
"""
|
||||
try:
|
||||
# Obtenir l'ID avant l'ajout
|
||||
idx = self.faiss_index.ntotal
|
||||
|
||||
# Ajouter à l'index FAISS
|
||||
embedding_2d = embedding.reshape(1, -1).astype(np.float32)
|
||||
self.faiss_index.add(embedding_2d)
|
||||
|
||||
# Stocker les métadonnées
|
||||
self.metadata_store[idx] = metadata
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "embedding_added",
|
||||
"id": idx,
|
||||
"metadata": metadata
|
||||
})
|
||||
|
||||
return idx
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "add_to_index_error",
|
||||
"error": str(e)
|
||||
})
|
||||
raise RuntimeError(f"Erreur lors de l'ajout à l'index: {e}")
|
||||
|
||||
def search_similar(
|
||||
self,
|
||||
query_embedding: np.ndarray,
|
||||
k: int = 5
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Recherche les k embeddings les plus similaires.
|
||||
|
||||
Args:
|
||||
query_embedding: Embedding de requête
|
||||
k: Nombre de résultats à retourner
|
||||
|
||||
Returns:
|
||||
Liste de dictionnaires avec id, distance et metadata
|
||||
"""
|
||||
try:
|
||||
if self.faiss_index.ntotal == 0:
|
||||
return []
|
||||
|
||||
# Limiter k au nombre d'embeddings disponibles
|
||||
k = min(k, self.faiss_index.ntotal)
|
||||
|
||||
# Rechercher
|
||||
query_2d = query_embedding.reshape(1, -1).astype(np.float32)
|
||||
distances, indices = self.faiss_index.search(query_2d, k)
|
||||
|
||||
# Formater les résultats
|
||||
results = []
|
||||
for dist, idx in zip(distances[0], indices[0]):
|
||||
if idx != -1: # FAISS retourne -1 si pas assez de résultats
|
||||
results.append({
|
||||
"id": int(idx),
|
||||
"distance": float(dist),
|
||||
"similarity": float(1.0 / (1.0 + dist)), # Convertir distance en similarité
|
||||
"metadata": self.metadata_store.get(int(idx), {})
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "search_error",
|
||||
"error": str(e)
|
||||
})
|
||||
raise RuntimeError(f"Erreur lors de la recherche: {e}")
|
||||
|
||||
def get_embedding_similarity(
|
||||
self,
|
||||
emb1: np.ndarray,
|
||||
emb2: np.ndarray
|
||||
) -> float:
|
||||
"""
|
||||
Calcule la similarité cosinus entre deux embeddings.
|
||||
|
||||
Args:
|
||||
emb1: Premier embedding
|
||||
emb2: Deuxième embedding
|
||||
|
||||
Returns:
|
||||
Similarité cosinus (0-1)
|
||||
"""
|
||||
try:
|
||||
# Normaliser les embeddings
|
||||
emb1_norm = emb1 / np.linalg.norm(emb1)
|
||||
emb2_norm = emb2 / np.linalg.norm(emb2)
|
||||
|
||||
# Calculer la similarité cosinus
|
||||
similarity = np.dot(emb1_norm, emb2_norm)
|
||||
|
||||
# Convertir de [-1, 1] à [0, 1]
|
||||
similarity = (similarity + 1.0) / 2.0
|
||||
|
||||
return float(similarity)
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "similarity_calculation_error",
|
||||
"error": str(e)
|
||||
})
|
||||
return 0.0
|
||||
|
||||
def rebuild_index(self):
|
||||
"""Reconstruit l'index FAISS à partir des embeddings stockés."""
|
||||
try:
|
||||
if self.faiss_index.ntotal == 0:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "rebuild_skipped",
|
||||
"reason": "index_empty"
|
||||
})
|
||||
return
|
||||
|
||||
# Extraire tous les embeddings
|
||||
all_embeddings = []
|
||||
for i in range(self.faiss_index.ntotal):
|
||||
embedding = self.faiss_index.reconstruct(i)
|
||||
all_embeddings.append(embedding)
|
||||
|
||||
# Créer un nouvel index
|
||||
new_index = faiss.IndexFlatL2(self.embedding_dim)
|
||||
|
||||
# Ajouter tous les embeddings
|
||||
embeddings_array = np.array(all_embeddings).astype(np.float32)
|
||||
new_index.add(embeddings_array)
|
||||
|
||||
# Remplacer l'ancien index
|
||||
self.faiss_index = new_index
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "index_rebuilt",
|
||||
"num_vectors": self.faiss_index.ntotal
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "rebuild_error",
|
||||
"error": str(e)
|
||||
})
|
||||
raise RuntimeError(f"Erreur lors de la reconstruction de l'index: {e}")
|
||||
|
||||
def save_index(self):
|
||||
"""Sauvegarde l'index FAISS et les métadonnées sur disque."""
|
||||
try:
|
||||
# Sauvegarder l'index FAISS
|
||||
faiss.write_index(self.faiss_index, str(self.index_file))
|
||||
|
||||
# Sauvegarder les métadonnées
|
||||
with open(self.metadata_file, 'wb') as f:
|
||||
pickle.dump(self.metadata_store, f)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "index_saved",
|
||||
"num_vectors": self.faiss_index.ntotal,
|
||||
"path": str(self.index_file)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "save_error",
|
||||
"error": str(e)
|
||||
})
|
||||
raise RuntimeError(f"Erreur lors de la sauvegarde de l'index: {e}")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Retourne des statistiques sur l'index.
|
||||
|
||||
Returns:
|
||||
Dictionnaire de statistiques
|
||||
"""
|
||||
return {
|
||||
"num_embeddings": self.faiss_index.ntotal,
|
||||
"embedding_dim": self.embedding_dim,
|
||||
"model_name": self.model_name,
|
||||
"device": self.device,
|
||||
"index_path": str(self.index_file)
|
||||
}
|
||||
|
||||
def clear_index(self):
|
||||
"""Efface tous les embeddings de l'index."""
|
||||
self.faiss_index = faiss.IndexFlatL2(self.embedding_dim)
|
||||
self.metadata_store = {}
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "index_cleared"
|
||||
})
|
||||
850
geniusia2/core/enhanced_workflow_matcher.py
Normal file
850
geniusia2/core/enhanced_workflow_matcher.py
Normal file
@@ -0,0 +1,850 @@
|
||||
"""
|
||||
EnhancedWorkflowMatcher pour la Phase 3 - Mode Complet.
|
||||
Matching de workflows amélioré avec embeddings multi-modaux et matching au niveau élément.
|
||||
|
||||
Améliorations par rapport au WorkflowMatcher classique:
|
||||
1. Utilise les embeddings multi-modaux fusionnés
|
||||
2. Matching au niveau des éléments UI individuels
|
||||
3. Scoring composite (écran global + éléments)
|
||||
4. Cache des embeddings pour performance
|
||||
5. Métriques détaillées
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from .ui_element_models import EnrichedScreenState, UIElement
|
||||
from .multimodal_embedding_manager import MultiModalEmbeddingManager
|
||||
from .logger import Logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ElementMatch:
|
||||
"""Représente un match entre un élément UI et un élément de workflow."""
|
||||
ui_element: UIElement
|
||||
workflow_element_id: str
|
||||
similarity_score: float
|
||||
match_type: str # "exact", "similar", "partial"
|
||||
confidence: float
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire."""
|
||||
return {
|
||||
"ui_element_id": self.ui_element.element_id,
|
||||
"ui_element_label": self.ui_element.label,
|
||||
"workflow_element_id": self.workflow_element_id,
|
||||
"similarity_score": float(self.similarity_score),
|
||||
"match_type": self.match_type,
|
||||
"confidence": float(self.confidence)
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchDifference:
|
||||
"""Représente une différence détectée lors du matching."""
|
||||
difference_type: str # "missing_element", "wrong_type", "wrong_position", "low_similarity"
|
||||
severity: str # "critical", "major", "minor"
|
||||
description: str
|
||||
expected: Optional[Any] = None
|
||||
actual: Optional[Any] = None
|
||||
suggestion: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire."""
|
||||
return {
|
||||
"type": self.difference_type,
|
||||
"severity": self.severity,
|
||||
"description": self.description,
|
||||
"expected": self.expected,
|
||||
"actual": self.actual,
|
||||
"suggestion": self.suggestion
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowMatch:
|
||||
"""Représente un match complet entre un écran et un workflow."""
|
||||
workflow_id: str
|
||||
workflow_name: str
|
||||
screen_similarity: float
|
||||
element_matches: List[ElementMatch]
|
||||
composite_score: float
|
||||
confidence: float
|
||||
match_details: Dict[str, Any]
|
||||
differences: Optional[List[MatchDifference]] = None # Nouveau champ pour le feedback
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire."""
|
||||
result = {
|
||||
"workflow_id": self.workflow_id,
|
||||
"workflow_name": self.workflow_name,
|
||||
"screen_similarity": float(self.screen_similarity),
|
||||
"element_matches": [match.to_dict() for match in self.element_matches],
|
||||
"composite_score": float(self.composite_score),
|
||||
"confidence": float(self.confidence),
|
||||
"match_details": self.match_details
|
||||
}
|
||||
if self.differences:
|
||||
result["differences"] = [diff.to_dict() for diff in self.differences]
|
||||
return result
|
||||
|
||||
def get_feedback_summary(self) -> str:
|
||||
"""Génère un résumé lisible du feedback."""
|
||||
if not self.differences:
|
||||
return "✓ Match réussi - Aucun problème détecté"
|
||||
|
||||
lines = [f"⚠ Match partiel - {len(self.differences)} différence(s) détectée(s):"]
|
||||
|
||||
# Grouper par sévérité
|
||||
critical = [d for d in self.differences if d.severity == "critical"]
|
||||
major = [d for d in self.differences if d.severity == "major"]
|
||||
minor = [d for d in self.differences if d.severity == "minor"]
|
||||
|
||||
if critical:
|
||||
lines.append(f"\n🔴 Critique ({len(critical)}):")
|
||||
for diff in critical:
|
||||
lines.append(f" - {diff.description}")
|
||||
if diff.suggestion:
|
||||
lines.append(f" 💡 {diff.suggestion}")
|
||||
|
||||
if major:
|
||||
lines.append(f"\n🟠 Majeur ({len(major)}):")
|
||||
for diff in major:
|
||||
lines.append(f" - {diff.description}")
|
||||
if diff.suggestion:
|
||||
lines.append(f" 💡 {diff.suggestion}")
|
||||
|
||||
if minor:
|
||||
lines.append(f"\n🟡 Mineur ({len(minor)}):")
|
||||
for diff in minor:
|
||||
lines.append(f" - {diff.description}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class EnhancedWorkflowMatcher:
|
||||
"""
|
||||
Matcher de workflows amélioré avec embeddings multi-modaux.
|
||||
|
||||
Stratégie de matching:
|
||||
1. Matching global de l'écran (embedding multi-modal)
|
||||
2. Matching individuel des éléments UI
|
||||
3. Scoring composite pondéré
|
||||
4. Filtrage par seuils de confiance
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
multimodal_manager: MultiModalEmbeddingManager,
|
||||
logger: Optional[Logger] = None,
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Initialise l'EnhancedWorkflowMatcher.
|
||||
|
||||
Args:
|
||||
multimodal_manager: Gestionnaire d'embeddings multi-modaux
|
||||
logger: Logger
|
||||
config: Configuration
|
||||
"""
|
||||
self.multimodal_manager = multimodal_manager
|
||||
self.logger = logger
|
||||
self.config = config or {}
|
||||
|
||||
# Configuration du matching
|
||||
self.screen_weight = self.config.get("screen_weight", 0.6)
|
||||
self.elements_weight = self.config.get("elements_weight", 0.4)
|
||||
self.min_similarity_threshold = self.config.get("min_similarity_threshold", 0.3)
|
||||
self.min_confidence_threshold = self.config.get("min_confidence_threshold", 0.5)
|
||||
self.max_candidates = self.config.get("max_candidates", 10)
|
||||
|
||||
# Cache des workflows et embeddings
|
||||
self._workflow_cache = {}
|
||||
self._embedding_cache = {}
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "enhanced_workflow_matcher_initialized",
|
||||
"screen_weight": self.screen_weight,
|
||||
"elements_weight": self.elements_weight,
|
||||
"min_similarity_threshold": self.min_similarity_threshold
|
||||
})
|
||||
|
||||
def find_matching_workflows(
|
||||
self,
|
||||
screen_state: EnrichedScreenState,
|
||||
screenshot: Optional[np.ndarray] = None,
|
||||
workflows: Optional[List] = None,
|
||||
top_k: int = 5
|
||||
) -> List[WorkflowMatch]:
|
||||
"""
|
||||
Trouve les workflows qui matchent le mieux avec l'écran actuel.
|
||||
|
||||
Args:
|
||||
screen_state: État d'écran enrichi
|
||||
screenshot: Screenshot numpy array
|
||||
workflows: Liste de workflows à comparer (charge tous si None)
|
||||
top_k: Nombre de meilleurs matches à retourner
|
||||
|
||||
Returns:
|
||||
Liste des meilleurs WorkflowMatch triés par score
|
||||
"""
|
||||
try:
|
||||
# Générer l'embedding multi-modal de l'écran actuel
|
||||
current_embedding = self._get_screen_embedding(screen_state, screenshot)
|
||||
|
||||
if current_embedding is None:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "matching_failed_no_embedding",
|
||||
"screen_state_id": screen_state.screen_state_id
|
||||
})
|
||||
return []
|
||||
|
||||
# Charger les workflows si non fournis
|
||||
if workflows is None:
|
||||
workflows = self._load_all_workflows()
|
||||
|
||||
# Calculer les matches pour chaque workflow
|
||||
matches = []
|
||||
for workflow in workflows:
|
||||
match = self._compute_workflow_match(
|
||||
screen_state, current_embedding, workflow
|
||||
)
|
||||
if match and match.composite_score >= self.min_similarity_threshold:
|
||||
matches.append(match)
|
||||
|
||||
# Trier par score composite décroissant
|
||||
matches.sort(key=lambda m: m.composite_score, reverse=True)
|
||||
|
||||
# Limiter au top_k
|
||||
matches = matches[:top_k]
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "workflow_matching_completed",
|
||||
"screen_state_id": screen_state.screen_state_id,
|
||||
"workflows_evaluated": len(workflows),
|
||||
"matches_found": len(matches),
|
||||
"top_score": matches[0].composite_score if matches else 0.0
|
||||
})
|
||||
|
||||
return matches
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "workflow_matching_error",
|
||||
"screen_state_id": screen_state.screen_state_id,
|
||||
"error": str(e)
|
||||
})
|
||||
return []
|
||||
|
||||
def _get_screen_embedding(
|
||||
self,
|
||||
screen_state: EnrichedScreenState,
|
||||
screenshot: Optional[np.ndarray]
|
||||
) -> Optional[np.ndarray]:
|
||||
"""Obtient l'embedding multi-modal de l'écran."""
|
||||
try:
|
||||
# Vérifier le cache
|
||||
cache_key = screen_state.screen_state_id
|
||||
if cache_key in self._embedding_cache:
|
||||
return self._embedding_cache[cache_key]
|
||||
|
||||
# Générer l'embedding si pas en cache
|
||||
if screen_state.state_embedding and screen_state.state_embedding.provider == "multimodal_fusion_v1":
|
||||
# Charger l'embedding existant
|
||||
embedding = self.multimodal_manager.load_fused_embedding(
|
||||
screen_state.state_embedding.vector_id
|
||||
)
|
||||
else:
|
||||
# Générer un nouvel embedding multi-modal
|
||||
state_embedding = self.multimodal_manager.generate_multimodal_embedding(
|
||||
screen_state, screenshot, save=False
|
||||
)
|
||||
embedding = self.multimodal_manager.load_fused_embedding(
|
||||
state_embedding.vector_id
|
||||
)
|
||||
|
||||
# Mettre en cache
|
||||
if embedding is not None:
|
||||
self._embedding_cache[cache_key] = embedding
|
||||
|
||||
return embedding
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "screen_embedding_error",
|
||||
"screen_state_id": screen_state.screen_state_id,
|
||||
"error": str(e)
|
||||
})
|
||||
return None
|
||||
|
||||
def _compute_workflow_match(
|
||||
self,
|
||||
screen_state: EnrichedScreenState,
|
||||
current_embedding: np.ndarray,
|
||||
workflow: Any
|
||||
) -> Optional[WorkflowMatch]:
|
||||
"""
|
||||
Calcule le match entre un écran et un workflow.
|
||||
|
||||
Args:
|
||||
screen_state: État d'écran actuel
|
||||
current_embedding: Embedding de l'écran actuel
|
||||
workflow: Workflow à comparer
|
||||
|
||||
Returns:
|
||||
WorkflowMatch ou None si pas de match valide
|
||||
"""
|
||||
try:
|
||||
# 1. Matching global de l'écran
|
||||
screen_similarity = self._compute_screen_similarity(
|
||||
current_embedding, workflow
|
||||
)
|
||||
|
||||
# 2. Matching des éléments UI
|
||||
element_matches = self._compute_element_matches(
|
||||
screen_state.ui_elements, workflow
|
||||
)
|
||||
|
||||
# 3. Calcul du score composite
|
||||
elements_score = self._compute_elements_score(element_matches)
|
||||
composite_score = (
|
||||
self.screen_weight * screen_similarity +
|
||||
self.elements_weight * elements_score
|
||||
)
|
||||
|
||||
# 4. Calcul de la confiance
|
||||
confidence = self._compute_match_confidence(
|
||||
screen_similarity, elements_score, element_matches
|
||||
)
|
||||
|
||||
# 5. Détails du match
|
||||
match_details = {
|
||||
"screen_similarity": float(screen_similarity),
|
||||
"elements_score": float(elements_score),
|
||||
"elements_count": len(element_matches),
|
||||
"exact_matches": len([m for m in element_matches if m.match_type == "exact"]),
|
||||
"similar_matches": len([m for m in element_matches if m.match_type == "similar"]),
|
||||
"partial_matches": len([m for m in element_matches if m.match_type == "partial"]),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# 6. Générer le feedback détaillé si le match n'est pas parfait
|
||||
differences = None
|
||||
if composite_score < 0.9 or confidence < 0.8:
|
||||
differences = self._generate_match_feedback(
|
||||
screen_state, workflow, screen_similarity,
|
||||
element_matches, composite_score
|
||||
)
|
||||
|
||||
return WorkflowMatch(
|
||||
workflow_id=getattr(workflow, 'workflow_id', 'unknown'),
|
||||
workflow_name=getattr(workflow, 'name', 'unknown'),
|
||||
screen_similarity=screen_similarity,
|
||||
element_matches=element_matches,
|
||||
composite_score=composite_score,
|
||||
confidence=confidence,
|
||||
match_details=match_details,
|
||||
differences=differences
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "workflow_match_computation_error",
|
||||
"workflow_id": getattr(workflow, 'workflow_id', 'unknown'),
|
||||
"error": str(e)
|
||||
})
|
||||
return None
|
||||
|
||||
def _compute_screen_similarity(
|
||||
self,
|
||||
current_embedding: np.ndarray,
|
||||
workflow: Any
|
||||
) -> float:
|
||||
"""
|
||||
Calcule la similarité globale de l'écran avec le workflow.
|
||||
|
||||
Compare l'embedding de l'écran actuel avec les embeddings des steps du workflow.
|
||||
Retourne la similarité maximale trouvée.
|
||||
|
||||
Args:
|
||||
current_embedding: Embedding de l'écran actuel
|
||||
workflow: Workflow à comparer
|
||||
|
||||
Returns:
|
||||
Score de similarité entre 0.0 et 1.0
|
||||
"""
|
||||
try:
|
||||
if not hasattr(workflow, 'steps') or not workflow.steps:
|
||||
return 0.0
|
||||
|
||||
similarities = []
|
||||
|
||||
for step in workflow.steps:
|
||||
if step.embedding is not None:
|
||||
# Calculer la similarité cosinus
|
||||
similarity = self.multimodal_manager.compute_similarity(
|
||||
current_embedding,
|
||||
step.embedding,
|
||||
metric="cosine"
|
||||
)
|
||||
similarities.append(similarity)
|
||||
|
||||
if similarities:
|
||||
# Retourner la similarité maximale (meilleur match)
|
||||
max_similarity = float(np.max(similarities))
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "screen_similarity_computed",
|
||||
"workflow_id": getattr(workflow, 'workflow_id', 'unknown'),
|
||||
"max_similarity": max_similarity,
|
||||
"avg_similarity": float(np.mean(similarities)),
|
||||
"steps_compared": len(similarities)
|
||||
})
|
||||
|
||||
return max_similarity
|
||||
else:
|
||||
# Aucun embedding disponible dans le workflow
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "screen_similarity_no_embeddings",
|
||||
"workflow_id": getattr(workflow, 'workflow_id', 'unknown')
|
||||
})
|
||||
return 0.0
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "screen_similarity_error",
|
||||
"workflow_id": getattr(workflow, 'workflow_id', 'unknown'),
|
||||
"error": str(e)
|
||||
})
|
||||
return 0.0
|
||||
|
||||
def _compute_element_matches(
|
||||
self,
|
||||
ui_elements: List[UIElement],
|
||||
workflow: Any
|
||||
) -> List[ElementMatch]:
|
||||
"""
|
||||
Calcule les matches entre éléments UI et éléments de workflow.
|
||||
|
||||
Compare les éléments UI détectés avec les target_description des steps du workflow.
|
||||
Utilise la similarité textuelle et la position pour trouver les meilleurs matches.
|
||||
|
||||
Args:
|
||||
ui_elements: Liste des éléments UI détectés
|
||||
workflow: Workflow à comparer
|
||||
|
||||
Returns:
|
||||
Liste des ElementMatch trouvés
|
||||
"""
|
||||
if not ui_elements or not hasattr(workflow, 'steps') or not workflow.steps:
|
||||
return []
|
||||
|
||||
matches = []
|
||||
|
||||
try:
|
||||
# Pour chaque step du workflow, chercher le meilleur élément UI correspondant
|
||||
for step in workflow.steps:
|
||||
if not step.target_description:
|
||||
continue
|
||||
|
||||
best_match = None
|
||||
best_score = 0.0
|
||||
|
||||
# Comparer avec chaque élément UI
|
||||
for ui_element in ui_elements:
|
||||
# Calculer la similarité textuelle
|
||||
text_similarity = self._compute_text_similarity(
|
||||
ui_element.label.lower(),
|
||||
step.target_description.lower()
|
||||
)
|
||||
|
||||
# Calculer la similarité de position (si disponible)
|
||||
position_similarity = 0.0
|
||||
if step.position and ui_element.bbox:
|
||||
position_similarity = self._compute_position_similarity(
|
||||
ui_element.bbox,
|
||||
step.position
|
||||
)
|
||||
|
||||
# Score combiné (70% texte, 30% position)
|
||||
combined_score = 0.7 * text_similarity + 0.3 * position_similarity
|
||||
|
||||
if combined_score > best_score:
|
||||
best_score = combined_score
|
||||
|
||||
# Déterminer le type de match
|
||||
if combined_score >= 0.85:
|
||||
match_type = "exact"
|
||||
elif combined_score >= 0.6:
|
||||
match_type = "similar"
|
||||
else:
|
||||
match_type = "partial"
|
||||
|
||||
best_match = ElementMatch(
|
||||
ui_element=ui_element,
|
||||
workflow_element_id=f"{workflow.workflow_id}_step_{step.step_id}",
|
||||
similarity_score=combined_score,
|
||||
match_type=match_type,
|
||||
confidence=combined_score * ui_element.confidence
|
||||
)
|
||||
|
||||
# Ajouter le meilleur match si au-dessus du seuil
|
||||
if best_match and best_match.similarity_score >= 0.3:
|
||||
matches.append(best_match)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "element_matches_computed",
|
||||
"workflow_id": getattr(workflow, 'workflow_id', 'unknown'),
|
||||
"ui_elements_count": len(ui_elements),
|
||||
"workflow_steps_count": len(workflow.steps),
|
||||
"matches_found": len(matches),
|
||||
"exact_matches": len([m for m in matches if m.match_type == "exact"]),
|
||||
"similar_matches": len([m for m in matches if m.match_type == "similar"]),
|
||||
"partial_matches": len([m for m in matches if m.match_type == "partial"])
|
||||
})
|
||||
|
||||
return matches
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "element_matches_error",
|
||||
"workflow_id": getattr(workflow, 'workflow_id', 'unknown'),
|
||||
"error": str(e)
|
||||
})
|
||||
return []
|
||||
|
||||
def _compute_position_similarity(
|
||||
self,
|
||||
bbox: Tuple[int, int, int, int],
|
||||
target_position: Tuple[int, int]
|
||||
) -> float:
|
||||
"""
|
||||
Calcule la similarité de position entre un bbox et une position cible.
|
||||
|
||||
Args:
|
||||
bbox: (x, y, width, height) de l'élément UI
|
||||
target_position: (x, y) de la position cible
|
||||
|
||||
Returns:
|
||||
Score de similarité entre 0.0 et 1.0
|
||||
"""
|
||||
try:
|
||||
# Calculer le centre du bbox
|
||||
center_x = bbox[0] + bbox[2] / 2
|
||||
center_y = bbox[1] + bbox[3] / 2
|
||||
|
||||
# Calculer la distance euclidienne
|
||||
distance = np.sqrt(
|
||||
(center_x - target_position[0]) ** 2 +
|
||||
(center_y - target_position[1]) ** 2
|
||||
)
|
||||
|
||||
# Normaliser la distance (supposons un écran de 1920x1080)
|
||||
max_distance = np.sqrt(1920**2 + 1080**2)
|
||||
normalized_distance = distance / max_distance
|
||||
|
||||
# Convertir en similarité (1.0 = même position, 0.0 = très éloigné)
|
||||
similarity = 1.0 - normalized_distance
|
||||
|
||||
return max(0.0, min(1.0, similarity))
|
||||
|
||||
except Exception as e:
|
||||
return 0.0
|
||||
|
||||
def _compute_text_similarity(self, text1: str, text2: str) -> float:
|
||||
"""
|
||||
Calcule la similarité entre deux textes.
|
||||
|
||||
Utilise la similarité de Jaccard (intersection / union des mots).
|
||||
|
||||
Args:
|
||||
text1: Premier texte
|
||||
text2: Deuxième texte
|
||||
|
||||
Returns:
|
||||
Score de similarité entre 0.0 et 1.0
|
||||
"""
|
||||
if not text1 or not text2:
|
||||
return 0.0
|
||||
|
||||
# Tokenisation simple
|
||||
words1 = set(text1.lower().split())
|
||||
words2 = set(text2.lower().split())
|
||||
|
||||
# Intersection et union
|
||||
intersection = words1.intersection(words2)
|
||||
union = words1.union(words2)
|
||||
|
||||
if not union:
|
||||
return 0.0
|
||||
|
||||
# Jaccard similarity
|
||||
return len(intersection) / len(union)
|
||||
|
||||
def _compute_elements_score(self, element_matches: List[ElementMatch]) -> float:
|
||||
"""
|
||||
Calcule le score global des matches d'éléments.
|
||||
"""
|
||||
if not element_matches:
|
||||
return 0.0
|
||||
|
||||
# Moyenne pondérée par la confiance
|
||||
total_score = 0.0
|
||||
total_weight = 0.0
|
||||
|
||||
for match in element_matches:
|
||||
weight = match.confidence
|
||||
total_score += match.similarity_score * weight
|
||||
total_weight += weight
|
||||
|
||||
if total_weight > 0:
|
||||
return total_score / total_weight
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def _compute_match_confidence(
|
||||
self,
|
||||
screen_similarity: float,
|
||||
elements_score: float,
|
||||
element_matches: List[ElementMatch]
|
||||
) -> float:
|
||||
"""
|
||||
Calcule la confiance globale du match.
|
||||
"""
|
||||
# Facteurs de confiance
|
||||
base_confidence = (screen_similarity + elements_score) / 2
|
||||
|
||||
# Bonus pour nombre d'éléments matchés
|
||||
elements_bonus = min(len(element_matches) * 0.1, 0.3)
|
||||
|
||||
# Bonus pour matches exacts
|
||||
exact_matches = len([m for m in element_matches if m.match_type == "exact"])
|
||||
exact_bonus = min(exact_matches * 0.05, 0.2)
|
||||
|
||||
confidence = base_confidence + elements_bonus + exact_bonus
|
||||
return min(confidence, 1.0)
|
||||
|
||||
def _load_all_workflows(self) -> List:
|
||||
"""
|
||||
Charge tous les workflows disponibles.
|
||||
Pour l'instant, retourne une liste vide.
|
||||
TODO: Intégrer avec le système de workflows existant.
|
||||
"""
|
||||
# Placeholder - à intégrer avec le WorkflowDetector existant
|
||||
return []
|
||||
|
||||
def _generate_match_feedback(
|
||||
self,
|
||||
screen_state: EnrichedScreenState,
|
||||
workflow: Any,
|
||||
screen_similarity: float,
|
||||
element_matches: List[ElementMatch],
|
||||
composite_score: float
|
||||
) -> List[MatchDifference]:
|
||||
"""
|
||||
Génère un feedback détaillé sur les différences détectées.
|
||||
|
||||
Args:
|
||||
screen_state: État d'écran actuel
|
||||
workflow: Workflow comparé
|
||||
screen_similarity: Score de similarité d'écran
|
||||
element_matches: Matches d'éléments trouvés
|
||||
composite_score: Score composite final
|
||||
|
||||
Returns:
|
||||
Liste des différences détectées
|
||||
"""
|
||||
differences = []
|
||||
|
||||
try:
|
||||
# 1. Vérifier la similarité d'écran
|
||||
if screen_similarity < 0.7:
|
||||
severity = "critical" if screen_similarity < 0.5 else "major"
|
||||
differences.append(MatchDifference(
|
||||
difference_type="low_similarity",
|
||||
severity=severity,
|
||||
description=f"Similarité d'écran faible: {screen_similarity:.2f}",
|
||||
expected="≥ 0.70",
|
||||
actual=f"{screen_similarity:.2f}",
|
||||
suggestion="L'écran actuel semble très différent du workflow attendu. Vérifiez que vous êtes sur la bonne application/fenêtre."
|
||||
))
|
||||
|
||||
# 2. Vérifier les éléments manquants
|
||||
if hasattr(workflow, 'steps') and workflow.steps:
|
||||
expected_elements = len(workflow.steps)
|
||||
matched_elements = len([m for m in element_matches if m.match_type in ["exact", "similar"]])
|
||||
|
||||
if matched_elements < expected_elements:
|
||||
missing_count = expected_elements - matched_elements
|
||||
severity = "critical" if missing_count > expected_elements / 2 else "major"
|
||||
differences.append(MatchDifference(
|
||||
difference_type="missing_element",
|
||||
severity=severity,
|
||||
description=f"{missing_count} élément(s) requis manquant(s)",
|
||||
expected=f"{expected_elements} éléments",
|
||||
actual=f"{matched_elements} éléments trouvés",
|
||||
suggestion=f"Vérifiez que tous les éléments UI sont visibles à l'écran. Éléments manquants: {missing_count}"
|
||||
))
|
||||
|
||||
# 3. Vérifier les matches partiels
|
||||
partial_matches = [m for m in element_matches if m.match_type == "partial"]
|
||||
if partial_matches:
|
||||
for match in partial_matches:
|
||||
differences.append(MatchDifference(
|
||||
difference_type="low_similarity",
|
||||
severity="minor",
|
||||
description=f"Élément '{match.ui_element.label}' partiellement correspondant",
|
||||
expected="Match exact ou similaire",
|
||||
actual=f"Score: {match.similarity_score:.2f}",
|
||||
suggestion="L'élément existe mais ne correspond pas parfaitement (texte différent ou position décalée)"
|
||||
))
|
||||
|
||||
# 4. Vérifier les types d'éléments incorrects
|
||||
for match in element_matches:
|
||||
# Vérifier si le type d'élément est cohérent avec l'action attendue
|
||||
# (Cette logique pourrait être améliorée avec plus d'informations du workflow)
|
||||
if match.confidence < 0.5:
|
||||
differences.append(MatchDifference(
|
||||
difference_type="wrong_type",
|
||||
severity="major",
|
||||
description=f"Type d'élément incertain pour '{match.ui_element.label}'",
|
||||
expected="Confiance ≥ 0.50",
|
||||
actual=f"Confiance: {match.confidence:.2f}",
|
||||
suggestion="L'élément détecté pourrait ne pas être du bon type pour l'action attendue"
|
||||
))
|
||||
|
||||
# 5. Vérifier le score composite global
|
||||
if composite_score < 0.5:
|
||||
differences.append(MatchDifference(
|
||||
difference_type="low_similarity",
|
||||
severity="critical",
|
||||
description=f"Score composite très faible: {composite_score:.2f}",
|
||||
expected="≥ 0.50",
|
||||
actual=f"{composite_score:.2f}",
|
||||
suggestion="Le workflow ne correspond pas à l'écran actuel. Considérez un workflow différent ou vérifiez l'état de l'application."
|
||||
))
|
||||
elif composite_score < 0.7:
|
||||
differences.append(MatchDifference(
|
||||
difference_type="low_similarity",
|
||||
severity="major",
|
||||
description=f"Score composite modéré: {composite_score:.2f}",
|
||||
expected="≥ 0.70",
|
||||
actual=f"{composite_score:.2f}",
|
||||
suggestion="Le match est acceptable mais pas optimal. Certains éléments peuvent être différents."
|
||||
))
|
||||
|
||||
if self.logger and differences:
|
||||
self.logger.log_action({
|
||||
"action": "match_feedback_generated",
|
||||
"workflow_id": getattr(workflow, 'workflow_id', 'unknown'),
|
||||
"differences_count": len(differences),
|
||||
"critical": len([d for d in differences if d.severity == "critical"]),
|
||||
"major": len([d for d in differences if d.severity == "major"]),
|
||||
"minor": len([d for d in differences if d.severity == "minor"])
|
||||
})
|
||||
|
||||
return differences
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "feedback_generation_error",
|
||||
"workflow_id": getattr(workflow, 'workflow_id', 'unknown'),
|
||||
"error": str(e)
|
||||
})
|
||||
return []
|
||||
|
||||
def get_match_explanation(
|
||||
self,
|
||||
match: WorkflowMatch
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Génère une explication détaillée d'un match.
|
||||
|
||||
Args:
|
||||
match: WorkflowMatch à expliquer
|
||||
|
||||
Returns:
|
||||
Dictionnaire avec l'explication détaillée
|
||||
"""
|
||||
explanation = {
|
||||
"workflow_info": {
|
||||
"id": match.workflow_id,
|
||||
"name": match.workflow_name
|
||||
},
|
||||
"scores": {
|
||||
"composite_score": match.composite_score,
|
||||
"screen_similarity": match.screen_similarity,
|
||||
"confidence": match.confidence
|
||||
},
|
||||
"element_analysis": {
|
||||
"total_matches": len(match.element_matches),
|
||||
"exact_matches": len([m for m in match.element_matches if m.match_type == "exact"]),
|
||||
"similar_matches": len([m for m in match.element_matches if m.match_type == "similar"]),
|
||||
"partial_matches": len([m for m in match.element_matches if m.match_type == "partial"])
|
||||
},
|
||||
"top_element_matches": [
|
||||
{
|
||||
"ui_element": m.ui_element.label,
|
||||
"similarity": m.similarity_score,
|
||||
"type": m.match_type
|
||||
}
|
||||
for m in sorted(match.element_matches, key=lambda x: x.similarity_score, reverse=True)[:5]
|
||||
],
|
||||
"match_details": match.match_details
|
||||
}
|
||||
|
||||
return explanation
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Tests basiques
|
||||
from .logger import Logger
|
||||
from .multimodal_embedding_manager import MultiModalEmbeddingManager
|
||||
import shutil
|
||||
|
||||
print("EnhancedWorkflowMatcher - Tests basiques")
|
||||
print("=" * 50)
|
||||
|
||||
# Test EnhancedWorkflowMatcher
|
||||
print("\n1. Test EnhancedWorkflowMatcher:")
|
||||
logger = Logger(log_dir="test_logs")
|
||||
multimodal_manager = MultiModalEmbeddingManager(logger=logger, data_dir="test_data")
|
||||
|
||||
matcher = EnhancedWorkflowMatcher(
|
||||
multimodal_manager=multimodal_manager,
|
||||
logger=logger,
|
||||
config={
|
||||
"screen_weight": 0.6,
|
||||
"elements_weight": 0.4,
|
||||
"min_similarity_threshold": 0.3
|
||||
}
|
||||
)
|
||||
|
||||
print(f" Matcher créé")
|
||||
print(f" Screen weight: {matcher.screen_weight}")
|
||||
print(f" Elements weight: {matcher.elements_weight}")
|
||||
print(f" Min threshold: {matcher.min_similarity_threshold}")
|
||||
|
||||
print("\n✓ Tous les tests basiques réussis!")
|
||||
|
||||
# Nettoyage
|
||||
if Path("test_data").exists():
|
||||
shutil.rmtree("test_data")
|
||||
if Path("test_logs").exists():
|
||||
shutil.rmtree("test_logs")
|
||||
367
geniusia2/core/enriched_screen_capture.py
Normal file
367
geniusia2/core/enriched_screen_capture.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Module d'intégration pour la capture d'écran enrichie avec détection d'éléments UI.
|
||||
Intègre UIElementDetector avec ScreenStateManager pour le mode enrichi.
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
import numpy as np
|
||||
|
||||
from .ui_element_detector import UIElementDetector
|
||||
from .screen_state_manager import ScreenStateManager
|
||||
from .ui_element_models import EnrichedScreenState, WindowInfo
|
||||
from .multimodal_embedding_manager import MultiModalEmbeddingManager
|
||||
from .enhanced_workflow_matcher import EnhancedWorkflowMatcher
|
||||
from .llm_manager import LLMManager
|
||||
from .logger import Logger
|
||||
|
||||
|
||||
class EnrichedScreenCapture:
|
||||
"""
|
||||
Gestionnaire de capture d'écran enrichie.
|
||||
|
||||
Combine:
|
||||
- Capture d'écran
|
||||
- Détection d'éléments UI (mode enrichi)
|
||||
- Création d'EnrichedScreenState
|
||||
- Sauvegarde
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_manager: Optional[LLMManager] = None,
|
||||
logger: Optional[Logger] = None,
|
||||
data_dir: str = "data",
|
||||
mode: str = "enriched",
|
||||
config: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Initialise le gestionnaire de capture enrichie.
|
||||
|
||||
Args:
|
||||
llm_manager: Gestionnaire LLM pour VLM
|
||||
logger: Logger
|
||||
data_dir: Répertoire de données
|
||||
mode: Mode de traitement ("light", "enriched", "complete")
|
||||
config: Configuration
|
||||
"""
|
||||
self.llm = llm_manager
|
||||
self.logger = logger
|
||||
self.data_dir = data_dir
|
||||
self.mode = mode
|
||||
self.config = config or {}
|
||||
|
||||
# Créer le ScreenStateManager
|
||||
self.screen_state_manager = ScreenStateManager(
|
||||
logger=logger,
|
||||
data_dir=data_dir,
|
||||
mode=mode
|
||||
)
|
||||
|
||||
# Créer le UIElementDetector (seulement en mode enrichi ou complet)
|
||||
self.ui_detector = None
|
||||
if mode in ["enriched", "complete"]:
|
||||
self.ui_detector = UIElementDetector(
|
||||
llm_manager=llm_manager,
|
||||
logger=logger,
|
||||
config=self.config.get("ui_detector", {})
|
||||
)
|
||||
|
||||
# Créer le MultiModalEmbeddingManager (seulement en mode complet)
|
||||
self.multimodal_manager = None
|
||||
if mode == "complete":
|
||||
self.multimodal_manager = MultiModalEmbeddingManager(
|
||||
logger=logger,
|
||||
data_dir=data_dir,
|
||||
config=self.config.get("multimodal_embedding", {})
|
||||
)
|
||||
|
||||
# Créer l'EnhancedWorkflowMatcher (seulement en mode complet)
|
||||
self.enhanced_matcher = None
|
||||
if mode == "complete" and self.multimodal_manager:
|
||||
self.enhanced_matcher = EnhancedWorkflowMatcher(
|
||||
multimodal_manager=self.multimodal_manager,
|
||||
logger=logger,
|
||||
config=self.config.get("enhanced_matcher", {})
|
||||
)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "enriched_screen_capture_initialized",
|
||||
"mode": mode,
|
||||
"ui_detection_enabled": self.ui_detector is not None,
|
||||
"multimodal_embedding_enabled": self.multimodal_manager is not None,
|
||||
"enhanced_matching_enabled": self.enhanced_matcher is not None
|
||||
})
|
||||
|
||||
def capture_and_enrich(
|
||||
self,
|
||||
screenshot: np.ndarray,
|
||||
session_id: str,
|
||||
window_title: str,
|
||||
app_name: str,
|
||||
screen_resolution: tuple,
|
||||
detected_text: Optional[List[str]] = None,
|
||||
context_tags: Optional[List[str]] = None,
|
||||
workflow_candidate: Optional[str] = None,
|
||||
save: bool = True
|
||||
) -> EnrichedScreenState:
|
||||
"""
|
||||
Capture et enrichit un screenshot avec détection d'éléments.
|
||||
|
||||
Args:
|
||||
screenshot: Screenshot numpy array
|
||||
session_id: ID de session
|
||||
window_title: Titre de la fenêtre
|
||||
app_name: Nom de l'application
|
||||
screen_resolution: Résolution d'écran
|
||||
detected_text: Texte détecté (optionnel)
|
||||
context_tags: Tags de contexte
|
||||
workflow_candidate: Workflow candidat
|
||||
save: Sauvegarder sur disque
|
||||
|
||||
Returns:
|
||||
EnrichedScreenState créé
|
||||
"""
|
||||
# Créer les informations de fenêtre
|
||||
window_info = WindowInfo(
|
||||
app_name=app_name,
|
||||
window_title=window_title,
|
||||
screen_resolution=screen_resolution
|
||||
)
|
||||
|
||||
# Détecter les éléments UI (si mode enrichi/complet)
|
||||
ui_elements = []
|
||||
if self.ui_detector:
|
||||
try:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "ui_detection_started",
|
||||
"app_name": app_name
|
||||
})
|
||||
|
||||
ui_elements = self.ui_detector.detect_elements(
|
||||
screenshot=screenshot,
|
||||
window_info=window_info,
|
||||
data_dir=self.data_dir
|
||||
)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "ui_detection_completed",
|
||||
"elements_count": len(ui_elements)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "ui_detection_error",
|
||||
"error": str(e)
|
||||
})
|
||||
# Continuer sans éléments UI
|
||||
ui_elements = []
|
||||
|
||||
# Créer l'EnrichedScreenState
|
||||
# Pour l'instant, on utilise le ScreenStateManager pour créer la base
|
||||
screen_state = self.screen_state_manager.create_screen_state(
|
||||
session_id=session_id,
|
||||
window_title=window_title,
|
||||
app_name=app_name,
|
||||
screenshot_path=f"{self.data_dir}/screens/temp_screenshot.png",
|
||||
screen_resolution=screen_resolution,
|
||||
detected_text=detected_text,
|
||||
context_tags=context_tags,
|
||||
workflow_candidate=workflow_candidate
|
||||
)
|
||||
|
||||
# Ajouter les éléments UI détectés
|
||||
screen_state.ui_elements = ui_elements
|
||||
screen_state.mode = self.mode
|
||||
|
||||
# Mode complet: générer l'embedding multi-modal
|
||||
if self.mode == "complete" and self.multimodal_manager:
|
||||
try:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "multimodal_embedding_generation_started",
|
||||
"screen_state_id": screen_state.screen_state_id
|
||||
})
|
||||
|
||||
# Générer l'embedding multi-modal
|
||||
multimodal_embedding = self.multimodal_manager.generate_multimodal_embedding(
|
||||
screen_state=screen_state,
|
||||
screenshot=screenshot,
|
||||
save=save
|
||||
)
|
||||
|
||||
# Remplacer l'embedding simple par l'embedding multi-modal
|
||||
screen_state.state_embedding = multimodal_embedding
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "multimodal_embedding_generated",
|
||||
"screen_state_id": screen_state.screen_state_id,
|
||||
"provider": multimodal_embedding.provider
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "multimodal_embedding_error",
|
||||
"screen_state_id": screen_state.screen_state_id,
|
||||
"error": str(e)
|
||||
})
|
||||
# Continuer avec l'embedding simple
|
||||
|
||||
# Sauvegarder si demandé
|
||||
if save:
|
||||
# Sauvegarder le screenshot
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
screenshot_path = Path(self.data_dir) / "screens" / f"{screen_state.screen_state_id}.png"
|
||||
screenshot_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cv2.imwrite(str(screenshot_path), screenshot)
|
||||
|
||||
# Mettre à jour le chemin
|
||||
screen_state.raw.screenshot_path = str(screenshot_path)
|
||||
|
||||
# Sauvegarder l'état
|
||||
self.screen_state_manager.save_screen_state(screen_state)
|
||||
|
||||
return screen_state
|
||||
|
||||
def get_mode(self) -> str:
|
||||
"""Retourne le mode actuel."""
|
||||
return self.mode
|
||||
|
||||
def set_mode(self, mode: str):
|
||||
"""
|
||||
Change le mode de traitement.
|
||||
|
||||
Args:
|
||||
mode: Nouveau mode ("light", "enriched", "complete")
|
||||
"""
|
||||
self.mode = mode
|
||||
self.screen_state_manager.mode = mode
|
||||
|
||||
# Créer/détruire le UIElementDetector selon le mode
|
||||
if mode in ["enriched", "complete"] and self.ui_detector is None:
|
||||
self.ui_detector = UIElementDetector(
|
||||
llm_manager=self.llm,
|
||||
logger=self.logger,
|
||||
config=self.config.get("ui_detector", {})
|
||||
)
|
||||
elif mode == "light":
|
||||
self.ui_detector = None
|
||||
|
||||
# Créer/détruire le MultiModalEmbeddingManager selon le mode
|
||||
if mode == "complete" and self.multimodal_manager is None:
|
||||
self.multimodal_manager = MultiModalEmbeddingManager(
|
||||
logger=self.logger,
|
||||
data_dir=self.data_dir,
|
||||
config=self.config.get("multimodal_embedding", {})
|
||||
)
|
||||
elif mode != "complete":
|
||||
self.multimodal_manager = None
|
||||
|
||||
# Créer/détruire l'EnhancedWorkflowMatcher selon le mode
|
||||
if mode == "complete" and self.multimodal_manager and self.enhanced_matcher is None:
|
||||
self.enhanced_matcher = EnhancedWorkflowMatcher(
|
||||
multimodal_manager=self.multimodal_manager,
|
||||
logger=self.logger,
|
||||
config=self.config.get("enhanced_matcher", {})
|
||||
)
|
||||
elif mode != "complete":
|
||||
self.enhanced_matcher = None
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "mode_changed",
|
||||
"new_mode": mode,
|
||||
"ui_detection_enabled": self.ui_detector is not None,
|
||||
"multimodal_embedding_enabled": self.multimodal_manager is not None,
|
||||
"enhanced_matching_enabled": self.enhanced_matcher is not None
|
||||
})
|
||||
|
||||
def find_matching_workflows(
|
||||
self,
|
||||
screen_state: EnrichedScreenState,
|
||||
screenshot: Optional[np.ndarray] = None,
|
||||
workflows: Optional[List] = None,
|
||||
top_k: int = 5
|
||||
):
|
||||
"""
|
||||
Trouve les workflows qui matchent le mieux avec l'écran actuel.
|
||||
|
||||
Utilise l'EnhancedWorkflowMatcher en mode complet, sinon retourne None.
|
||||
|
||||
Args:
|
||||
screen_state: État d'écran enrichi
|
||||
screenshot: Screenshot numpy array (optionnel)
|
||||
workflows: Liste de workflows à comparer (charge tous si None)
|
||||
top_k: Nombre de meilleurs matches à retourner
|
||||
|
||||
Returns:
|
||||
Liste des meilleurs WorkflowMatch ou None si pas en mode complet
|
||||
"""
|
||||
if self.mode == "complete" and self.enhanced_matcher:
|
||||
return self.enhanced_matcher.find_matching_workflows(
|
||||
screen_state=screen_state,
|
||||
screenshot=screenshot,
|
||||
workflows=workflows,
|
||||
top_k=top_k
|
||||
)
|
||||
else:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "enhanced_matching_not_available",
|
||||
"current_mode": self.mode,
|
||||
"reason": "Enhanced matching requires 'complete' mode"
|
||||
})
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Tests basiques
|
||||
from .logger import Logger
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
print("EnrichedScreenCapture - Tests basiques")
|
||||
print("=" * 50)
|
||||
|
||||
# Créer un logger de test
|
||||
logger = Logger(log_dir="test_logs")
|
||||
|
||||
# Test mode light
|
||||
print("\n1. Test mode light:")
|
||||
capture_light = EnrichedScreenCapture(
|
||||
logger=logger,
|
||||
data_dir="test_data",
|
||||
mode="light"
|
||||
)
|
||||
print(f" Mode: {capture_light.get_mode()}")
|
||||
print(f" UI Detector: {capture_light.ui_detector is not None}")
|
||||
|
||||
# Test mode enriched
|
||||
print("\n2. Test mode enriched:")
|
||||
capture_enriched = EnrichedScreenCapture(
|
||||
logger=logger,
|
||||
data_dir="test_data",
|
||||
mode="enriched"
|
||||
)
|
||||
print(f" Mode: {capture_enriched.get_mode()}")
|
||||
print(f" UI Detector: {capture_enriched.ui_detector is not None}")
|
||||
|
||||
# Test changement de mode
|
||||
print("\n3. Test changement de mode:")
|
||||
capture_enriched.set_mode("light")
|
||||
print(f" Nouveau mode: {capture_enriched.get_mode()}")
|
||||
print(f" UI Detector après changement: {capture_enriched.ui_detector is not None}")
|
||||
|
||||
print("\n✓ Tests basiques réussis!")
|
||||
|
||||
# Nettoyage
|
||||
if Path("test_data").exists():
|
||||
shutil.rmtree("test_data")
|
||||
if Path("test_logs").exists():
|
||||
shutil.rmtree("test_logs")
|
||||
488
geniusia2/core/event_capture.py
Normal file
488
geniusia2/core/event_capture.py
Normal file
@@ -0,0 +1,488 @@
|
||||
"""
|
||||
Capture des événements utilisateur (clavier et souris) pour l'apprentissage.
|
||||
Utilise pynput pour capturer les actions en temps réel.
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Callable, Any
|
||||
from threading import Thread, Lock
|
||||
from collections import deque
|
||||
|
||||
try:
|
||||
from pynput import mouse, keyboard
|
||||
PYNPUT_AVAILABLE = True
|
||||
except ImportError:
|
||||
PYNPUT_AVAILABLE = False
|
||||
print("⚠️ pynput n'est pas installé. Capture d'événements désactivée.")
|
||||
|
||||
from .logger import Logger
|
||||
from .utils.image_utils import get_active_window
|
||||
from .session_manager import SessionManager
|
||||
from .workflow_detector import WorkflowDetector
|
||||
|
||||
|
||||
class EventCapture:
|
||||
"""
|
||||
Capture les événements clavier et souris pour détecter les patterns répétitifs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logger: Optional[Logger] = None,
|
||||
max_history: int = 1000,
|
||||
pattern_threshold: int = 3,
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Initialise le capteur d'événements.
|
||||
|
||||
Args:
|
||||
logger: Logger pour journalisation
|
||||
max_history: Nombre max d'événements à garder en mémoire
|
||||
pattern_threshold: Nombre de répétitions pour détecter un pattern
|
||||
config: Configuration pour les workflows
|
||||
"""
|
||||
self.logger = logger
|
||||
self.max_history = max_history
|
||||
self.pattern_threshold = pattern_threshold
|
||||
self.config = config or {}
|
||||
|
||||
# Historique des événements
|
||||
self.events: deque = deque(maxlen=max_history)
|
||||
self.events_lock = Lock()
|
||||
|
||||
# État de capture
|
||||
self.capturing = False
|
||||
self.mouse_listener = None
|
||||
self.keyboard_listener = None
|
||||
|
||||
# Callbacks pour patterns détectés
|
||||
self.pattern_callbacks: List[Callable] = []
|
||||
|
||||
# Dernière fenêtre active
|
||||
self.last_window = None
|
||||
|
||||
# Composants pour la détection de workflows
|
||||
self.session_manager = SessionManager(logger, self.config)
|
||||
self.workflow_detector = WorkflowDetector(logger, self.config)
|
||||
|
||||
# Connecter les callbacks
|
||||
self.session_manager.on_session_completed = self._on_session_completed
|
||||
self.workflow_detector.on_workflow_detected = self._on_workflow_detected
|
||||
|
||||
if not PYNPUT_AVAILABLE:
|
||||
if logger:
|
||||
logger.log_action({
|
||||
"action": "event_capture_unavailable",
|
||||
"reason": "pynput not installed"
|
||||
})
|
||||
|
||||
def start(self):
|
||||
"""Démarre la capture d'événements."""
|
||||
if not PYNPUT_AVAILABLE:
|
||||
print("⚠️ Impossible de démarrer la capture : pynput non disponible")
|
||||
return False
|
||||
|
||||
if self.capturing:
|
||||
return True
|
||||
|
||||
self.capturing = True
|
||||
|
||||
# Démarrer les listeners
|
||||
self.mouse_listener = mouse.Listener(
|
||||
on_click=self._on_mouse_click,
|
||||
on_move=self._on_mouse_move,
|
||||
on_scroll=self._on_mouse_scroll
|
||||
)
|
||||
|
||||
self.keyboard_listener = keyboard.Listener(
|
||||
on_press=self._on_key_press
|
||||
)
|
||||
|
||||
self.mouse_listener.start()
|
||||
self.keyboard_listener.start()
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "event_capture_started"
|
||||
})
|
||||
|
||||
print("✅ Capture d'événements démarrée")
|
||||
return True
|
||||
|
||||
def stop(self):
|
||||
"""Arrête la capture d'événements de manière synchrone."""
|
||||
if not self.capturing:
|
||||
return
|
||||
|
||||
self.capturing = False
|
||||
|
||||
# Arrêter les listeners et attendre qu'ils se terminent
|
||||
if self.mouse_listener:
|
||||
self.mouse_listener.stop()
|
||||
try:
|
||||
# Attendre max 2 secondes que le listener se termine
|
||||
self.mouse_listener.join(timeout=2.0)
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "mouse_listener_stop_error",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
if self.keyboard_listener:
|
||||
self.keyboard_listener.stop()
|
||||
try:
|
||||
# Attendre max 2 secondes que le listener se termine
|
||||
self.keyboard_listener.join(timeout=2.0)
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "keyboard_listener_stop_error",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "event_capture_stopped",
|
||||
"total_events": len(self.events)
|
||||
})
|
||||
|
||||
print("⏹️ Capture d'événements arrêtée")
|
||||
|
||||
def _on_mouse_click(self, x: int, y: int, button, pressed: bool):
|
||||
"""Callback pour les clics souris."""
|
||||
if not pressed: # On enregistre seulement les clics (pas les relâchements)
|
||||
return
|
||||
|
||||
from .utils.image_utils import capture_screen
|
||||
|
||||
window = get_active_window()
|
||||
|
||||
# Capturer l'écran immédiatement
|
||||
screenshot = capture_screen()
|
||||
|
||||
event = {
|
||||
"type": "mouse_click",
|
||||
"x": x,
|
||||
"y": y,
|
||||
"button": str(button),
|
||||
"window": window,
|
||||
"timestamp": datetime.now(),
|
||||
"screenshot": screenshot # Ajout du screenshot
|
||||
}
|
||||
|
||||
self._add_event(event)
|
||||
|
||||
def _on_mouse_move(self, x: int, y: int):
|
||||
"""Callback pour les mouvements souris (optionnel, peut être bruyant)."""
|
||||
# On n'enregistre pas tous les mouvements pour éviter le bruit
|
||||
pass
|
||||
|
||||
def _on_mouse_scroll(self, x: int, y: int, dx: int, dy: int):
|
||||
"""Callback pour le scroll."""
|
||||
from .utils.image_utils import capture_screen
|
||||
|
||||
window = get_active_window()
|
||||
screenshot = capture_screen()
|
||||
|
||||
event = {
|
||||
"type": "scroll",
|
||||
"x": x,
|
||||
"y": y,
|
||||
"dx": dx,
|
||||
"dy": dy,
|
||||
"window": window,
|
||||
"timestamp": datetime.now(),
|
||||
"screenshot": screenshot
|
||||
}
|
||||
|
||||
self._add_event(event)
|
||||
|
||||
def _on_key_press(self, key):
|
||||
"""Callback pour les frappes clavier."""
|
||||
from .utils.image_utils import capture_screen
|
||||
|
||||
window = get_active_window()
|
||||
|
||||
try:
|
||||
key_char = key.char
|
||||
except AttributeError:
|
||||
key_char = str(key)
|
||||
|
||||
# Détecter les combinaisons (Ctrl+C, Ctrl+V, etc.)
|
||||
is_ctrl = hasattr(key, 'name') and 'ctrl' in str(key).lower()
|
||||
is_combo = is_ctrl or (hasattr(key, 'name') and key.name in ['ctrl_l', 'ctrl_r', 'alt_l', 'alt_r'])
|
||||
|
||||
# Capturer screenshot seulement pour les combos importants
|
||||
screenshot = None
|
||||
if is_combo or key_char in ['c', 'v', 'a', 'x', 'z']:
|
||||
screenshot = capture_screen()
|
||||
|
||||
event = {
|
||||
"type": "key_press",
|
||||
"key": key_char,
|
||||
"window": window,
|
||||
"timestamp": datetime.now(),
|
||||
"screenshot": screenshot,
|
||||
"is_combo": is_combo
|
||||
}
|
||||
|
||||
self._add_event(event)
|
||||
|
||||
def _add_event(self, event: Dict[str, Any]):
|
||||
"""Ajoute un événement à l'historique."""
|
||||
with self.events_lock:
|
||||
self.events.append(event)
|
||||
|
||||
# Limiter la mémoire : garder seulement les 100 derniers
|
||||
if len(self.events) > 100:
|
||||
# Supprimer le plus ancien
|
||||
old_event = self.events.popleft()
|
||||
# Libérer la mémoire du screenshot
|
||||
if 'screenshot' in old_event:
|
||||
del old_event['screenshot']
|
||||
|
||||
# Passer l'événement au SessionManager pour segmentation
|
||||
try:
|
||||
self.session_manager.add_action(event)
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "session_add_action_failed",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# Vérifier si on détecte un pattern
|
||||
self._check_for_patterns()
|
||||
|
||||
def _check_for_patterns(self):
|
||||
"""Vérifie si les derniers événements forment un pattern répétitif."""
|
||||
# Détecter le pattern avec le lock
|
||||
with self.events_lock:
|
||||
if len(self.events) < self.pattern_threshold * 2:
|
||||
return
|
||||
|
||||
# Analyser les N derniers événements
|
||||
recent_events = list(self.events)[-20:] # 20 derniers événements
|
||||
|
||||
# Détecter des séquences répétitives
|
||||
pattern = self._detect_repetitive_sequence(recent_events)
|
||||
|
||||
# Appeler les callbacks EN DEHORS du lock pour éviter les deadlocks
|
||||
if pattern:
|
||||
print(f"🎯 Pattern détecté dans event_capture !")
|
||||
print(f" Répétitions: {pattern['repetitions']}")
|
||||
print(f" Longueur: {pattern['length']}")
|
||||
|
||||
# Notifier les callbacks
|
||||
for callback in self.pattern_callbacks:
|
||||
try:
|
||||
callback(pattern)
|
||||
except Exception as e:
|
||||
print(f"❌ Erreur dans callback: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def _detect_repetitive_sequence(
|
||||
self,
|
||||
events: List[Dict[str, Any]]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Détecte une séquence répétitive dans les événements.
|
||||
|
||||
Returns:
|
||||
Dictionnaire décrivant le pattern ou None
|
||||
"""
|
||||
if len(events) < self.pattern_threshold:
|
||||
return None
|
||||
|
||||
# Simplifier les événements pour la comparaison
|
||||
simplified = []
|
||||
for e in events:
|
||||
if e["type"] == "mouse_click":
|
||||
# Regrouper les clics proches (tolérance de 100px - plus permissif)
|
||||
simplified.append({
|
||||
"type": "click",
|
||||
"x_zone": e["x"] // 100, # Zones de 100px
|
||||
"y_zone": e["y"] // 100,
|
||||
"window": e["window"]
|
||||
})
|
||||
elif e["type"] == "key_press":
|
||||
simplified.append({
|
||||
"type": "key",
|
||||
"key": e["key"],
|
||||
"window": e["window"]
|
||||
})
|
||||
elif e["type"] == "scroll":
|
||||
simplified.append({
|
||||
"type": "scroll",
|
||||
"window": e["window"]
|
||||
})
|
||||
|
||||
# Chercher des répétitions
|
||||
for seq_len in range(1, len(simplified) // self.pattern_threshold + 1):
|
||||
sequence = simplified[-seq_len:]
|
||||
|
||||
# Vérifier si cette séquence se répète
|
||||
repetitions = 0
|
||||
for i in range(len(simplified) - seq_len, -1, -seq_len):
|
||||
if simplified[i:i+seq_len] == sequence:
|
||||
repetitions += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if repetitions >= self.pattern_threshold:
|
||||
return {
|
||||
"sequence": sequence,
|
||||
"repetitions": repetitions,
|
||||
"length": seq_len,
|
||||
"window": sequence[0]["window"] if sequence else None
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
def register_pattern_callback(self, callback: Callable):
|
||||
"""Enregistre un callback à appeler quand un pattern est détecté."""
|
||||
self.pattern_callbacks.append(callback)
|
||||
|
||||
def get_recent_events(self, count: int = 10) -> List[Dict[str, Any]]:
|
||||
"""Retourne les N derniers événements."""
|
||||
with self.events_lock:
|
||||
return list(self.events)[-count:]
|
||||
|
||||
def get_events_for_window(self, window_title: str) -> List[Dict[str, Any]]:
|
||||
"""Retourne tous les événements pour une fenêtre donnée."""
|
||||
with self.events_lock:
|
||||
return [e for e in self.events if e.get("window") == window_title]
|
||||
|
||||
def clear_history(self):
|
||||
"""Efface l'historique des événements."""
|
||||
with self.events_lock:
|
||||
self.events.clear()
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "event_history_cleared"
|
||||
})
|
||||
|
||||
def get_last_screenshots(self, count: int = 3) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retourne les N derniers événements avec screenshots.
|
||||
|
||||
Returns:
|
||||
Liste d'événements avec screenshots
|
||||
"""
|
||||
with self.events_lock:
|
||||
events_with_screenshots = [
|
||||
e for e in self.events
|
||||
if e.get('screenshot') is not None
|
||||
]
|
||||
return events_with_screenshots[-count:] if events_with_screenshots else []
|
||||
|
||||
def _on_session_completed(self, session):
|
||||
"""
|
||||
Callback appelé quand une session est terminée.
|
||||
|
||||
Args:
|
||||
session: Session terminée
|
||||
"""
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "session_completed_callback",
|
||||
"session_id": session.session_id,
|
||||
"action_count": session.action_count
|
||||
})
|
||||
|
||||
# Analyser les sessions récentes pour détecter des workflows
|
||||
recent_sessions = self.session_manager.get_recent_sessions(10)
|
||||
self.workflow_detector.analyze_sessions(recent_sessions)
|
||||
|
||||
def _on_workflow_detected(self, workflow):
|
||||
"""
|
||||
Callback appelé quand un workflow est détecté.
|
||||
|
||||
Args:
|
||||
workflow: Workflow détecté (dictionnaire)
|
||||
"""
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "workflow_detected_callback",
|
||||
"workflow_id": workflow.get("workflow_id"),
|
||||
"name": workflow.get("name"),
|
||||
"steps": len(workflow.get("steps", [])),
|
||||
"repetitions": workflow.get("repetitions")
|
||||
})
|
||||
|
||||
# Notifier les callbacks de pattern (pour compatibilité)
|
||||
pattern_data = {
|
||||
"type": "workflow",
|
||||
"workflow_id": workflow.get("workflow_id"),
|
||||
"name": workflow.get("name"),
|
||||
"steps": len(workflow.get("steps", [])),
|
||||
"repetitions": workflow.get("repetitions"),
|
||||
"confidence": workflow.get("confidence")
|
||||
}
|
||||
|
||||
for callback in self.pattern_callbacks:
|
||||
try:
|
||||
callback(pattern_data)
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "workflow_callback_error",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
def get_workflows(self):
|
||||
"""
|
||||
Retourne les workflows détectés.
|
||||
|
||||
Returns:
|
||||
Liste des workflows
|
||||
"""
|
||||
return self.workflow_detector.get_workflows()
|
||||
|
||||
def get_sessions(self, count: int = 10):
|
||||
"""
|
||||
Retourne les sessions récentes.
|
||||
|
||||
Args:
|
||||
count: Nombre de sessions à retourner
|
||||
|
||||
Returns:
|
||||
Liste des sessions
|
||||
"""
|
||||
return self.session_manager.get_recent_sessions(count)
|
||||
|
||||
def get_workflow_stats(self):
|
||||
"""
|
||||
Retourne les statistiques des workflows.
|
||||
|
||||
Returns:
|
||||
Dictionnaire de statistiques
|
||||
"""
|
||||
return {
|
||||
"sessions": self.session_manager.get_stats(),
|
||||
"workflows": self.workflow_detector.get_stats()
|
||||
}
|
||||
|
||||
def force_finalize_session(self):
|
||||
"""
|
||||
Force la finalisation de la session courante.
|
||||
"""
|
||||
self.session_manager.force_finalize_session()
|
||||
|
||||
def capture_event(self, action: Dict[str, Any]):
|
||||
"""
|
||||
Capture un événement manuellement (pour les tests).
|
||||
|
||||
Args:
|
||||
action: Action à capturer
|
||||
"""
|
||||
# Ajouter à la session
|
||||
self.session_manager.add_action(action)
|
||||
|
||||
# Ajouter à l'historique
|
||||
self._add_event(action)
|
||||
408
geniusia2/core/faiss_index_builder.py
Normal file
408
geniusia2/core/faiss_index_builder.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
Constructeur et reconstructeur d'index FAISS.
|
||||
Scanne les tâches existantes et reconstruit l'index à partir des signatures sauvegardées.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pickle
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from .logger import Logger
|
||||
from .embeddings_manager import EmbeddingsManager
|
||||
from .models import TaskProfile
|
||||
|
||||
|
||||
class FAISSIndexBuilder:
|
||||
"""
|
||||
Construit et reconstruit l'index FAISS à partir des tâches sauvegardées.
|
||||
Résout le problème critique : index vide malgré 19+ tâches existantes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embeddings_manager: EmbeddingsManager,
|
||||
logger: Logger,
|
||||
profiles_path: str = "data/user_profiles"
|
||||
):
|
||||
"""
|
||||
Initialise le constructeur d'index.
|
||||
|
||||
Args:
|
||||
embeddings_manager: Gestionnaire d'embeddings et index FAISS
|
||||
logger: Logger pour tracer les opérations
|
||||
profiles_path: Chemin vers les profils de tâches
|
||||
"""
|
||||
self.embeddings_manager = embeddings_manager
|
||||
self.logger = logger
|
||||
self.profiles_path = Path(profiles_path)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "faiss_index_builder_initialized",
|
||||
"profiles_path": str(self.profiles_path)
|
||||
})
|
||||
|
||||
def scan_tasks(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Scanne tous les dossiers de tâches pour trouver les tâches existantes.
|
||||
|
||||
Returns:
|
||||
Liste de dictionnaires avec task_id et chemin
|
||||
"""
|
||||
tasks = []
|
||||
|
||||
if not self.profiles_path.exists():
|
||||
self.logger.log_action({
|
||||
"action": "scan_tasks_no_directory",
|
||||
"path": str(self.profiles_path)
|
||||
})
|
||||
return tasks
|
||||
|
||||
# Scanner tous les sous-dossiers
|
||||
for task_dir in self.profiles_path.iterdir():
|
||||
if not task_dir.is_dir():
|
||||
continue
|
||||
|
||||
# Vérifier qu'il y a un fichier metadata.json
|
||||
metadata_file = task_dir / "metadata.json"
|
||||
if not metadata_file.exists():
|
||||
self.logger.log_action({
|
||||
"action": "scan_tasks_no_metadata",
|
||||
"task_dir": str(task_dir)
|
||||
})
|
||||
continue
|
||||
|
||||
tasks.append({
|
||||
"task_id": task_dir.name,
|
||||
"path": task_dir,
|
||||
"metadata_file": metadata_file
|
||||
})
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "scan_tasks_completed",
|
||||
"tasks_found": len(tasks)
|
||||
})
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
def load_task_embeddings(self, task_info: Dict[str, Any]) -> Tuple[List[np.ndarray], Dict[str, Any]]:
|
||||
"""
|
||||
Charge les embeddings d'une tâche depuis ses fichiers.
|
||||
|
||||
Args:
|
||||
task_info: Dictionnaire avec task_id, path, metadata_file
|
||||
|
||||
Returns:
|
||||
Tuple (liste d'embeddings, métadonnées de la tâche)
|
||||
"""
|
||||
task_id = task_info["task_id"]
|
||||
task_path = task_info["path"]
|
||||
|
||||
embeddings = []
|
||||
metadata = {}
|
||||
|
||||
try:
|
||||
# Charger les métadonnées
|
||||
with open(task_info["metadata_file"], 'r') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
# Charger les signatures (qui contiennent les embeddings)
|
||||
signatures_file = task_path / "signatures.pkl"
|
||||
if signatures_file.exists():
|
||||
with open(signatures_file, 'rb') as f:
|
||||
signatures = pickle.load(f)
|
||||
|
||||
# Extraire les embeddings des signatures
|
||||
for sig in signatures:
|
||||
if isinstance(sig, dict) and "embedding" in sig:
|
||||
emb = sig["embedding"]
|
||||
if emb is not None:
|
||||
# Convertir en numpy array si nécessaire
|
||||
if not isinstance(emb, np.ndarray):
|
||||
emb = np.array(emb, dtype=np.float32)
|
||||
embeddings.append(emb)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "load_task_embeddings_success",
|
||||
"task_id": task_id,
|
||||
"embeddings_count": len(embeddings)
|
||||
})
|
||||
else:
|
||||
self.logger.log_action({
|
||||
"action": "load_task_embeddings_no_signatures",
|
||||
"task_id": task_id
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "load_task_embeddings_error",
|
||||
"task_id": task_id,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return embeddings, metadata
|
||||
|
||||
def rebuild_index(self, force: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Reconstruit l'index FAISS complet à partir de toutes les tâches.
|
||||
|
||||
Args:
|
||||
force: Si True, reconstruit même si l'index n'est pas vide
|
||||
|
||||
Returns:
|
||||
Rapport de reconstruction avec statistiques
|
||||
"""
|
||||
start_time = datetime.now()
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "rebuild_index_started",
|
||||
"force": force,
|
||||
"current_index_size": self.embeddings_manager.faiss_index.ntotal
|
||||
})
|
||||
|
||||
# Vérifier si reconstruction nécessaire
|
||||
if not force and self.embeddings_manager.faiss_index.ntotal > 0:
|
||||
self.logger.log_action({
|
||||
"action": "rebuild_index_skipped",
|
||||
"reason": "index_not_empty"
|
||||
})
|
||||
return {
|
||||
"success": False,
|
||||
"reason": "index_not_empty",
|
||||
"current_size": self.embeddings_manager.faiss_index.ntotal
|
||||
}
|
||||
|
||||
# Scanner les tâches
|
||||
tasks = self.scan_tasks()
|
||||
|
||||
if not tasks:
|
||||
self.logger.log_action({
|
||||
"action": "rebuild_index_no_tasks",
|
||||
"reason": "no_tasks_found"
|
||||
})
|
||||
return {
|
||||
"success": False,
|
||||
"reason": "no_tasks_found",
|
||||
"tasks_scanned": 0
|
||||
}
|
||||
|
||||
# Statistiques
|
||||
stats = {
|
||||
"tasks_scanned": len(tasks),
|
||||
"tasks_processed": 0,
|
||||
"tasks_failed": 0,
|
||||
"embeddings_added": 0,
|
||||
"errors": []
|
||||
}
|
||||
|
||||
# Traiter chaque tâche
|
||||
for task_info in tasks:
|
||||
task_id = task_info["task_id"]
|
||||
|
||||
try:
|
||||
# Charger les embeddings
|
||||
embeddings, metadata = self.load_task_embeddings(task_info)
|
||||
|
||||
if not embeddings:
|
||||
self.logger.log_action({
|
||||
"action": "rebuild_index_task_no_embeddings",
|
||||
"task_id": task_id
|
||||
})
|
||||
stats["tasks_failed"] += 1
|
||||
continue
|
||||
|
||||
# Ajouter chaque embedding à l'index
|
||||
for i, embedding in enumerate(embeddings):
|
||||
try:
|
||||
# Valider l'embedding
|
||||
if not self._validate_embedding(embedding):
|
||||
self.logger.log_action({
|
||||
"action": "rebuild_index_invalid_embedding",
|
||||
"task_id": task_id,
|
||||
"embedding_index": i
|
||||
})
|
||||
continue
|
||||
|
||||
# Ajouter à l'index FAISS
|
||||
self.embeddings_manager.add_to_index(
|
||||
embedding,
|
||||
{
|
||||
"task_id": task_id,
|
||||
"embedding_index": i,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"from_rebuild": True
|
||||
}
|
||||
)
|
||||
stats["embeddings_added"] += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "rebuild_index_embedding_error",
|
||||
"task_id": task_id,
|
||||
"embedding_index": i,
|
||||
"error": str(e)
|
||||
})
|
||||
stats["errors"].append({
|
||||
"task_id": task_id,
|
||||
"embedding_index": i,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
stats["tasks_processed"] += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "rebuild_index_task_error",
|
||||
"task_id": task_id,
|
||||
"error": str(e)
|
||||
})
|
||||
stats["tasks_failed"] += 1
|
||||
stats["errors"].append({
|
||||
"task_id": task_id,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# Sauvegarder l'index
|
||||
try:
|
||||
self.embeddings_manager.save_index()
|
||||
self.logger.log_action({
|
||||
"action": "rebuild_index_saved",
|
||||
"embeddings_count": stats["embeddings_added"]
|
||||
})
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "rebuild_index_save_error",
|
||||
"error": str(e)
|
||||
})
|
||||
stats["errors"].append({
|
||||
"error": f"Failed to save index: {str(e)}"
|
||||
})
|
||||
|
||||
# Calculer la durée
|
||||
duration = (datetime.now() - start_time).total_seconds()
|
||||
|
||||
stats["success"] = stats["embeddings_added"] > 0
|
||||
stats["duration_seconds"] = duration
|
||||
stats["final_index_size"] = self.embeddings_manager.faiss_index.ntotal
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "rebuild_index_completed",
|
||||
"stats": stats
|
||||
})
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def verify_index_integrity(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Vérifie la cohérence entre les tâches sauvegardées et l'index FAISS.
|
||||
|
||||
Returns:
|
||||
Rapport de vérification avec incohérences détectées
|
||||
"""
|
||||
self.logger.log_action({
|
||||
"action": "verify_index_integrity_started"
|
||||
})
|
||||
|
||||
# Scanner les tâches
|
||||
tasks = self.scan_tasks()
|
||||
|
||||
# Compter les embeddings attendus
|
||||
expected_embeddings = 0
|
||||
tasks_with_embeddings = 0
|
||||
|
||||
for task_info in tasks:
|
||||
embeddings, _ = self.load_task_embeddings(task_info)
|
||||
if embeddings:
|
||||
expected_embeddings += len(embeddings)
|
||||
tasks_with_embeddings += 1
|
||||
|
||||
# Comparer avec l'index actuel
|
||||
actual_embeddings = self.embeddings_manager.faiss_index.ntotal
|
||||
|
||||
# Déterminer la cohérence
|
||||
is_consistent = (actual_embeddings >= expected_embeddings * 0.9) # Tolérance de 10%
|
||||
|
||||
report = {
|
||||
"is_consistent": is_consistent,
|
||||
"tasks_scanned": len(tasks),
|
||||
"tasks_with_embeddings": tasks_with_embeddings,
|
||||
"expected_embeddings": expected_embeddings,
|
||||
"actual_embeddings": actual_embeddings,
|
||||
"missing_embeddings": max(0, expected_embeddings - actual_embeddings),
|
||||
"needs_rebuild": not is_consistent
|
||||
}
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "verify_index_integrity_completed",
|
||||
"report": report
|
||||
})
|
||||
|
||||
return report
|
||||
|
||||
def _validate_embedding(self, embedding: np.ndarray) -> bool:
|
||||
"""
|
||||
Valide qu'un embedding est correct.
|
||||
|
||||
Args:
|
||||
embedding: Embedding à valider
|
||||
|
||||
Returns:
|
||||
True si valide, False sinon
|
||||
"""
|
||||
try:
|
||||
# Vérifier que c'est un numpy array
|
||||
if not isinstance(embedding, np.ndarray):
|
||||
return False
|
||||
|
||||
# Vérifier la dimension
|
||||
if embedding.shape[0] != self.embeddings_manager.embedding_dim:
|
||||
self.logger.log_action({
|
||||
"action": "validate_embedding_wrong_dimension",
|
||||
"expected": self.embeddings_manager.embedding_dim,
|
||||
"actual": embedding.shape[0]
|
||||
})
|
||||
return False
|
||||
|
||||
# Vérifier pas de NaN ou Inf
|
||||
if np.isnan(embedding).any() or np.isinf(embedding).any():
|
||||
self.logger.log_action({
|
||||
"action": "validate_embedding_nan_inf"
|
||||
})
|
||||
return False
|
||||
|
||||
# Vérifier norme non-nulle
|
||||
norm = np.linalg.norm(embedding)
|
||||
if norm == 0:
|
||||
self.logger.log_action({
|
||||
"action": "validate_embedding_zero_norm"
|
||||
})
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "validate_embedding_error",
|
||||
"error": str(e)
|
||||
})
|
||||
return False
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Retourne des statistiques sur l'état actuel.
|
||||
|
||||
Returns:
|
||||
Dictionnaire de statistiques
|
||||
"""
|
||||
tasks = self.scan_tasks()
|
||||
|
||||
return {
|
||||
"tasks_found": len(tasks),
|
||||
"index_size": self.embeddings_manager.faiss_index.ntotal,
|
||||
"profiles_path": str(self.profiles_path),
|
||||
"profiles_path_exists": self.profiles_path.exists()
|
||||
}
|
||||
908
geniusia2/core/learning_manager.py
Normal file
908
geniusia2/core/learning_manager.py
Normal file
@@ -0,0 +1,908 @@
|
||||
"""
|
||||
Gestionnaire d'apprentissage pour RPA Vision V2.
|
||||
Gère les transitions de mode, le calcul de confiance et l'état d'apprentissage.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from datetime import datetime
|
||||
from dataclasses import asdict
|
||||
import numpy as np
|
||||
|
||||
from .models import TaskProfile, Action
|
||||
from .embeddings_manager import EmbeddingsManager
|
||||
from .logger import Logger
|
||||
from .ui_change_detector import UIChangeDetector
|
||||
|
||||
|
||||
class LearningManager:
|
||||
"""
|
||||
Gestionnaire d'apprentissage qui suit la progression, gère les transitions
|
||||
de mode et calcule les scores de confiance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embeddings_manager: EmbeddingsManager,
|
||||
logger: Logger,
|
||||
config: Dict[str, Any],
|
||||
profiles_path: str = "data/user_profiles"
|
||||
):
|
||||
"""
|
||||
Initialise le gestionnaire d'apprentissage.
|
||||
|
||||
Args:
|
||||
embeddings_manager: Gestionnaire d'embeddings
|
||||
logger: Logger pour la journalisation
|
||||
config: Configuration globale
|
||||
profiles_path: Chemin vers les profils utilisateur
|
||||
"""
|
||||
self.embeddings_manager = embeddings_manager
|
||||
self.logger = logger
|
||||
self.config = config
|
||||
self.profiles_path = Path(profiles_path)
|
||||
|
||||
# Créer le répertoire de profils
|
||||
self.profiles_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# État actuel
|
||||
self.mode = "shadow" # Mode initial
|
||||
self.tasks: Dict[str, TaskProfile] = {}
|
||||
self.current_task_id: Optional[str] = None
|
||||
self.current_context: Dict[str, Any] = {}
|
||||
|
||||
# Seuils de configuration
|
||||
self.autopilot_observations = config.get("thresholds", {}).get(
|
||||
"autopilot_observations", 20
|
||||
)
|
||||
self.autopilot_concordance = config.get("thresholds", {}).get(
|
||||
"autopilot_concordance", 0.95
|
||||
)
|
||||
self.confidence_min = config.get("thresholds", {}).get(
|
||||
"confidence_min", 0.90
|
||||
)
|
||||
self.rollback_confidence = config.get("thresholds", {}).get(
|
||||
"rollback_confidence", 0.85
|
||||
)
|
||||
|
||||
# Initialiser le détecteur de changements UI
|
||||
self.ui_change_detector = UIChangeDetector(
|
||||
embeddings_manager,
|
||||
logger,
|
||||
config
|
||||
)
|
||||
|
||||
# Charger les profils existants
|
||||
self._load_profiles()
|
||||
|
||||
# Charger les tâches existantes dans l'index FAISS
|
||||
self._load_existing_tasks_to_index()
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "learning_manager_initialized",
|
||||
"mode": self.mode,
|
||||
"num_tasks": len(self.tasks)
|
||||
})
|
||||
|
||||
def _save_profile(self, task_id: str):
|
||||
"""Sauvegarde un profil de tâche."""
|
||||
if task_id not in self.tasks:
|
||||
return
|
||||
|
||||
try:
|
||||
profile = self.tasks[task_id]
|
||||
profile_file = self.profiles_path / f"{task_id}.json"
|
||||
|
||||
with open(profile_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(profile.to_json(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "profile_saved",
|
||||
"task_id": task_id
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "profile_save_error",
|
||||
"task_id": task_id,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
def observe(self, action: Action):
|
||||
"""
|
||||
Enregistre une observation en mode Shadow.
|
||||
|
||||
Args:
|
||||
action: Action observée
|
||||
"""
|
||||
# Obtenir ou créer le profil de tâche
|
||||
task_id = self.current_task_id or self._generate_task_id(action)
|
||||
|
||||
if task_id not in self.tasks:
|
||||
self.tasks[task_id] = TaskProfile(
|
||||
task_id=task_id,
|
||||
task_name=f"Tâche {task_id}",
|
||||
mode="shadow",
|
||||
observation_count=0,
|
||||
concordance_rate=0.0,
|
||||
confidence_score=0.0,
|
||||
correction_count=0,
|
||||
last_execution=datetime.now(),
|
||||
window_whitelist=[action.window_title],
|
||||
action_sequence=[],
|
||||
embeddings=[],
|
||||
metadata={}
|
||||
)
|
||||
|
||||
task = self.tasks[task_id]
|
||||
task.observation_count += 1
|
||||
task.action_sequence.append(action)
|
||||
task.last_execution = datetime.now()
|
||||
|
||||
# Ajouter l'embedding si disponible
|
||||
if action.embedding is not None:
|
||||
task.embeddings.append(action.embedding)
|
||||
|
||||
# Ajouter à l'index FAISS
|
||||
self.embeddings_manager.add_to_index(
|
||||
action.embedding,
|
||||
{
|
||||
"task_id": task_id,
|
||||
"action_type": action.action_type,
|
||||
"target_element": action.target_element,
|
||||
"timestamp": action.timestamp.isoformat()
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "observation_recorded",
|
||||
"task_id": task_id,
|
||||
"observation_count": task.observation_count,
|
||||
"action_type": action.action_type
|
||||
})
|
||||
|
||||
# Sauvegarder le profil
|
||||
self._save_profile(task_id)
|
||||
|
||||
# Vérifier si on peut passer en mode Assisté
|
||||
if task.observation_count >= 5 and task.mode == "shadow":
|
||||
self._transition_mode(task_id, "assist")
|
||||
|
||||
def _generate_task_id(self, action: Action) -> str:
|
||||
"""Génère un ID de tâche basé sur l'action."""
|
||||
window_clean = action.window_title.replace(" ", "_").lower()
|
||||
element_clean = action.target_element.replace(" ", "_").lower()
|
||||
return f"{window_clean}_{element_clean}_{action.action_type}"
|
||||
|
||||
def suggest_action(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Génère une suggestion d'action en mode Assisté.
|
||||
|
||||
Args:
|
||||
context: Contexte actuel (fenêtre, capture d'écran, etc.)
|
||||
|
||||
Returns:
|
||||
Suggestion d'action ou None
|
||||
"""
|
||||
if not self.current_task_id or self.current_task_id not in self.tasks:
|
||||
return None
|
||||
|
||||
task = self.tasks[self.current_task_id]
|
||||
|
||||
if task.mode != "assist" and task.mode != "auto":
|
||||
return None
|
||||
|
||||
# Trouver l'action la plus similaire dans l'historique
|
||||
if not task.embeddings or 'current_embedding' not in context:
|
||||
return None
|
||||
|
||||
current_emb = context['current_embedding']
|
||||
|
||||
# Rechercher les embeddings similaires
|
||||
similar = self.embeddings_manager.search_similar(current_emb, k=3)
|
||||
|
||||
if not similar:
|
||||
return None
|
||||
|
||||
# Prendre le plus similaire
|
||||
best_match = similar[0]
|
||||
similarity = best_match['similarity']
|
||||
|
||||
# Trouver l'action correspondante
|
||||
matching_action = None
|
||||
for action in task.action_sequence:
|
||||
if (action.timestamp.isoformat() ==
|
||||
best_match['metadata'].get('timestamp')):
|
||||
matching_action = action
|
||||
break
|
||||
|
||||
if not matching_action:
|
||||
return None
|
||||
|
||||
# Calculer la confiance
|
||||
vision_conf = similarity
|
||||
llm_score = context.get('llm_score', 0.5)
|
||||
confidence = self.calculate_confidence(vision_conf, llm_score, self.current_task_id)
|
||||
|
||||
suggestion = {
|
||||
"action": matching_action,
|
||||
"confidence": confidence,
|
||||
"similarity": similarity,
|
||||
"task_id": self.current_task_id
|
||||
}
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "suggestion_generated",
|
||||
"task_id": self.current_task_id,
|
||||
"confidence": confidence,
|
||||
"action_type": matching_action.action_type
|
||||
})
|
||||
|
||||
return suggestion
|
||||
|
||||
def confirm_action(self, feedback: Dict[str, Any]):
|
||||
"""
|
||||
Traite la validation ou correction utilisateur.
|
||||
|
||||
Args:
|
||||
feedback: Dictionnaire avec type (accept/reject/correct) et données
|
||||
"""
|
||||
feedback_type = feedback.get("type")
|
||||
task_id = feedback.get("task_id", self.current_task_id)
|
||||
|
||||
if not task_id or task_id not in self.tasks:
|
||||
return
|
||||
|
||||
task = self.tasks[task_id]
|
||||
|
||||
if feedback_type == "accept":
|
||||
# Action acceptée, augmenter la concordance
|
||||
self._update_concordance(task_id, success=True)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "action_accepted",
|
||||
"task_id": task_id,
|
||||
"concordance_rate": task.concordance_rate
|
||||
})
|
||||
|
||||
elif feedback_type == "reject":
|
||||
# Action rejetée
|
||||
self._update_concordance(task_id, success=False)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "action_rejected",
|
||||
"task_id": task_id,
|
||||
"concordance_rate": task.concordance_rate
|
||||
})
|
||||
|
||||
elif feedback_type == "correct":
|
||||
# Correction fournie
|
||||
task.correction_count += 1
|
||||
corrected_action = feedback.get("corrected_action")
|
||||
|
||||
if corrected_action:
|
||||
# Ajouter la correction à la séquence
|
||||
task.action_sequence.append(corrected_action)
|
||||
|
||||
# Mettre à jour les embeddings
|
||||
if corrected_action.embedding is not None:
|
||||
task.embeddings.append(corrected_action.embedding)
|
||||
self.embeddings_manager.add_to_index(
|
||||
corrected_action.embedding,
|
||||
{
|
||||
"task_id": task_id,
|
||||
"action_type": corrected_action.action_type,
|
||||
"target_element": corrected_action.target_element,
|
||||
"timestamp": corrected_action.timestamp.isoformat(),
|
||||
"is_correction": True
|
||||
}
|
||||
)
|
||||
|
||||
self._update_concordance(task_id, success=False)
|
||||
|
||||
self.logger.log_correction({
|
||||
"task_id": task_id,
|
||||
"correction_count": task.correction_count,
|
||||
"corrected_action": corrected_action.to_dict() if corrected_action else None
|
||||
})
|
||||
|
||||
# Sauvegarder le profil
|
||||
self._save_profile(task_id)
|
||||
|
||||
# Vérifier les transitions de mode
|
||||
self._check_mode_transitions(task_id)
|
||||
|
||||
def _update_concordance(self, task_id: str, success: bool):
|
||||
"""Met à jour le taux de concordance."""
|
||||
if task_id not in self.tasks:
|
||||
return
|
||||
|
||||
task = self.tasks[task_id]
|
||||
|
||||
# Utiliser une moyenne mobile sur les 10 dernières exécutions
|
||||
if not hasattr(task, '_recent_results'):
|
||||
task.metadata['recent_results'] = []
|
||||
|
||||
recent_results = task.metadata.get('recent_results', [])
|
||||
recent_results.append(1 if success else 0)
|
||||
|
||||
# Garder seulement les 10 derniers
|
||||
if len(recent_results) > 10:
|
||||
recent_results = recent_results[-10:]
|
||||
|
||||
task.metadata['recent_results'] = recent_results
|
||||
task.concordance_rate = sum(recent_results) / len(recent_results)
|
||||
|
||||
def calculate_confidence(
|
||||
self,
|
||||
vision_conf: float,
|
||||
llm_score: float,
|
||||
task_id: str
|
||||
) -> float:
|
||||
"""
|
||||
Calcule le score de confiance pondéré.
|
||||
|
||||
Args:
|
||||
vision_conf: Confiance de la détection vision (0-1)
|
||||
llm_score: Score du LLM (0-1)
|
||||
task_id: ID de la tâche
|
||||
|
||||
Returns:
|
||||
Score de confiance (0-1)
|
||||
"""
|
||||
# Obtenir la performance historique
|
||||
history_score = self._get_historical_performance(task_id)
|
||||
|
||||
# Formule : 0.6 × vision + 0.3 × llm + 0.1 × historique
|
||||
confidence = (
|
||||
0.6 * vision_conf +
|
||||
0.3 * llm_score +
|
||||
0.1 * history_score
|
||||
)
|
||||
|
||||
return max(0.0, min(1.0, confidence))
|
||||
|
||||
def _get_historical_performance(self, task_id: str) -> float:
|
||||
"""Obtient la performance historique d'une tâche."""
|
||||
if task_id not in self.tasks:
|
||||
return 0.5
|
||||
|
||||
task = self.tasks[task_id]
|
||||
return task.concordance_rate
|
||||
|
||||
def evaluate_task(self, task_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Évalue une tâche et retourne ses métriques.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche
|
||||
|
||||
Returns:
|
||||
Dictionnaire de métriques
|
||||
"""
|
||||
if task_id not in self.tasks:
|
||||
return {}
|
||||
|
||||
task = self.tasks[task_id]
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"task_name": task.task_name,
|
||||
"mode": task.mode,
|
||||
"observation_count": task.observation_count,
|
||||
"concordance_rate": task.concordance_rate,
|
||||
"confidence_score": task.confidence_score,
|
||||
"correction_count": task.correction_count,
|
||||
"correction_rate": (
|
||||
task.correction_count / max(1, task.observation_count)
|
||||
),
|
||||
"last_execution": task.last_execution.isoformat() if task.last_execution else None
|
||||
}
|
||||
|
||||
def should_transition_to_auto(self, task_id: str) -> bool:
|
||||
"""
|
||||
Vérifie si une tâche remplit les critères pour passer en Autopilot.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche
|
||||
|
||||
Returns:
|
||||
True si éligible, False sinon
|
||||
"""
|
||||
if task_id not in self.tasks:
|
||||
return False
|
||||
|
||||
task = self.tasks[task_id]
|
||||
|
||||
return (
|
||||
task.observation_count >= self.autopilot_observations and
|
||||
task.concordance_rate >= self.autopilot_concordance
|
||||
)
|
||||
|
||||
def rollback_if_low_confidence(self, task_id: str):
|
||||
"""
|
||||
Rétrograde une tâche au mode Assisté si la confiance est faible.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche
|
||||
"""
|
||||
if task_id not in self.tasks:
|
||||
return
|
||||
|
||||
task = self.tasks[task_id]
|
||||
|
||||
if task.mode == "auto" and task.confidence_score < self.confidence_min:
|
||||
self._transition_mode(task_id, "assist")
|
||||
|
||||
self.logger.log_mode_transition(
|
||||
task_id,
|
||||
"auto",
|
||||
"assist",
|
||||
f"Confiance faible: {task.confidence_score:.2f}"
|
||||
)
|
||||
|
||||
def _check_mode_transitions(self, task_id: str):
|
||||
"""Vérifie et effectue les transitions de mode si nécessaire."""
|
||||
if task_id not in self.tasks:
|
||||
return
|
||||
|
||||
task = self.tasks[task_id]
|
||||
|
||||
# Shadow → Assist (après 5 observations)
|
||||
if task.mode == "shadow" and task.observation_count >= 5:
|
||||
self._transition_mode(task_id, "assist")
|
||||
|
||||
# Assist → Auto (si critères remplis)
|
||||
elif task.mode == "assist" and self.should_transition_to_auto(task_id):
|
||||
self._transition_mode(task_id, "auto")
|
||||
|
||||
# Auto → Assist (si confiance faible)
|
||||
elif task.mode == "auto" and task.concordance_rate < self.rollback_confidence:
|
||||
self._transition_mode(task_id, "assist")
|
||||
|
||||
def _transition_mode(self, task_id: str, new_mode: str):
|
||||
"""Effectue une transition de mode."""
|
||||
if task_id not in self.tasks:
|
||||
return
|
||||
|
||||
task = self.tasks[task_id]
|
||||
old_mode = task.mode
|
||||
|
||||
if old_mode == new_mode:
|
||||
return
|
||||
|
||||
task.mode = new_mode
|
||||
|
||||
self.logger.log_mode_transition(
|
||||
task_id,
|
||||
old_mode,
|
||||
new_mode,
|
||||
f"Observations: {task.observation_count}, Concordance: {task.concordance_rate:.2%}"
|
||||
)
|
||||
|
||||
self._save_profile(task_id)
|
||||
|
||||
def get_mode(self) -> str:
|
||||
"""Retourne le mode opérationnel actuel."""
|
||||
if self.current_task_id and self.current_task_id in self.tasks:
|
||||
return self.tasks[self.current_task_id].mode
|
||||
return self.mode
|
||||
|
||||
def get_current_intent(self) -> str:
|
||||
"""Retourne l'intention actuelle."""
|
||||
return self.current_context.get("intent", "")
|
||||
|
||||
def set_current_intent(self, intent: str):
|
||||
"""
|
||||
Définit l'intention actuelle.
|
||||
|
||||
Args:
|
||||
intent: Intention à définir (ex: "button", "text field", "form")
|
||||
"""
|
||||
self.current_context["intent"] = intent
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "intent_set",
|
||||
"intent": intent,
|
||||
"mode": self.mode
|
||||
})
|
||||
|
||||
def create_task_from_signatures(
|
||||
self,
|
||||
signatures: List[Dict[str, Any]],
|
||||
description: str = "Tâche automatique"
|
||||
) -> TaskProfile:
|
||||
"""
|
||||
Crée une tâche à partir de signatures visuelles.
|
||||
|
||||
Args:
|
||||
signatures: Liste de signatures d'actions
|
||||
description: Description de la tâche
|
||||
|
||||
Returns:
|
||||
TaskProfile créé
|
||||
"""
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
# Générer un ID unique
|
||||
task_id = hashlib.md5(
|
||||
json.dumps(str(signatures[0])).encode()
|
||||
).hexdigest()[:8]
|
||||
|
||||
# Créer le profil de tâche
|
||||
task = TaskProfile(
|
||||
task_id=f"task_{task_id}",
|
||||
task_name=description,
|
||||
window_whitelist=[signatures[0].get("window", "Unknown")],
|
||||
observation_count=len(signatures),
|
||||
embeddings=[sig.get("embedding") for sig in signatures if sig.get("embedding") is not None],
|
||||
metadata={"signatures": signatures} # Stocker les signatures dans metadata
|
||||
)
|
||||
|
||||
# Ajouter aux tâches
|
||||
self.tasks[task.task_id] = task
|
||||
|
||||
# Sauvegarder
|
||||
self._save_task(task)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "task_created",
|
||||
"task_id": task.task_id,
|
||||
"signatures_count": len(signatures)
|
||||
})
|
||||
|
||||
return task
|
||||
|
||||
def _save_task(self, task: TaskProfile):
|
||||
"""Sauvegarde une tâche sur disque."""
|
||||
import json
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
task_dir = self.profiles_path / task.task_id
|
||||
task_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Sauvegarder les métadonnées
|
||||
metadata = {
|
||||
"task_id": task.task_id,
|
||||
"task_name": task.task_name,
|
||||
"window_whitelist": task.window_whitelist,
|
||||
"observation_count": task.observation_count,
|
||||
"mode": task.mode,
|
||||
"confidence_score": task.confidence_score
|
||||
}
|
||||
|
||||
with open(task_dir / "metadata.json", "w") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
# Sauvegarder les signatures (depuis metadata)
|
||||
if "signatures" in task.metadata:
|
||||
with open(task_dir / "signatures.pkl", "wb") as f:
|
||||
pickle.dump(task.metadata["signatures"], f)
|
||||
|
||||
# Sauvegarder l'index FAISS automatiquement (MVP)
|
||||
try:
|
||||
self.embeddings_manager.save_index()
|
||||
self.logger.log_action({
|
||||
"action": "faiss_index_saved",
|
||||
"task_id": task.task_id
|
||||
})
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "faiss_index_save_error",
|
||||
"task_id": task.task_id,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
def load_task(self, task_id: str) -> Optional[TaskProfile]:
|
||||
"""Charge une tâche depuis le disque."""
|
||||
import json
|
||||
import pickle
|
||||
|
||||
task_dir = self.profiles_path / task_id
|
||||
|
||||
if not task_dir.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
# Charger métadonnées
|
||||
with open(task_dir / "metadata.json", "r") as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
task = TaskProfile(
|
||||
task_id=metadata["task_id"],
|
||||
task_name=metadata.get("task_name", metadata.get("description", "Unknown")),
|
||||
window_whitelist=metadata.get("window_whitelist", [metadata.get("window_title", "Unknown")]),
|
||||
observation_count=metadata.get("observation_count", metadata.get("observations", 0)),
|
||||
mode=metadata.get("mode", "shadow"),
|
||||
confidence_score=metadata.get("confidence_score", 0.0)
|
||||
)
|
||||
|
||||
# Charger signatures dans metadata
|
||||
sig_file = task_dir / "signatures.pkl"
|
||||
if sig_file.exists():
|
||||
with open(sig_file, "rb") as f:
|
||||
task.metadata["signatures"] = pickle.load(f)
|
||||
|
||||
return task
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "task_load_failed",
|
||||
"task_id": task_id,
|
||||
"error": str(e)
|
||||
})
|
||||
return None
|
||||
|
||||
def set_current_task(self, task_id: str):
|
||||
"""Définit la tâche actuelle."""
|
||||
self.current_task_id = task_id
|
||||
|
||||
def set_current_context(self, context: Dict[str, Any]):
|
||||
"""Définit le contexte actuel."""
|
||||
self.current_context = context
|
||||
|
||||
def record_execution(self, decision: Dict[str, Any]):
|
||||
"""
|
||||
Enregistre l'exécution d'une action.
|
||||
|
||||
Args:
|
||||
decision: Dictionnaire décrivant la décision et l'action
|
||||
"""
|
||||
task_id = decision.get("task_id", self.current_task_id)
|
||||
|
||||
if not task_id or task_id not in self.tasks:
|
||||
return
|
||||
|
||||
task = self.tasks[task_id]
|
||||
task.last_execution = datetime.now()
|
||||
|
||||
# Mettre à jour le score de confiance
|
||||
if "confidence" in decision:
|
||||
task.confidence_score = decision["confidence"]
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "execution_recorded",
|
||||
"task_id": task_id,
|
||||
"confidence": task.confidence_score,
|
||||
"mode": task.mode
|
||||
})
|
||||
|
||||
self._save_profile(task_id)
|
||||
|
||||
def get_all_tasks(self) -> List[Dict[str, Any]]:
|
||||
"""Retourne toutes les tâches avec leurs métriques."""
|
||||
return [self.evaluate_task(task_id) for task_id in self.tasks.keys()]
|
||||
|
||||
def get_task_stats(self) -> Dict[str, Any]:
|
||||
"""Retourne des statistiques globales."""
|
||||
total_tasks = len(self.tasks)
|
||||
shadow_tasks = sum(1 for t in self.tasks.values() if t.mode == "shadow")
|
||||
assist_tasks = sum(1 for t in self.tasks.values() if t.mode == "assist")
|
||||
auto_tasks = sum(1 for t in self.tasks.values() if t.mode == "auto")
|
||||
|
||||
return {
|
||||
"total_tasks": total_tasks,
|
||||
"shadow_tasks": shadow_tasks,
|
||||
"assist_tasks": assist_tasks,
|
||||
"auto_tasks": auto_tasks,
|
||||
"current_mode": self.get_mode()
|
||||
}
|
||||
|
||||
def check_ui_changes(
|
||||
self,
|
||||
task_id: str,
|
||||
current_embedding: np.ndarray,
|
||||
predicted_bbox: Optional[Tuple[int, int, int, int]] = None,
|
||||
actual_bbox: Optional[Tuple[int, int, int, int]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Vérifie les changements UI pour une tâche et déclenche le ré-entraînement si nécessaire.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche
|
||||
current_embedding: Embedding visuel actuel
|
||||
predicted_bbox: Bbox prédite (optionnel)
|
||||
actual_bbox: Bbox réelle (optionnel)
|
||||
|
||||
Returns:
|
||||
Résultats de la vérification
|
||||
"""
|
||||
if task_id not in self.tasks:
|
||||
return {
|
||||
"error": "task_not_found",
|
||||
"task_id": task_id
|
||||
}
|
||||
|
||||
task = self.tasks[task_id]
|
||||
|
||||
# Vérifier les changements avec le détecteur
|
||||
result = self.ui_change_detector.check_and_trigger_retraining(
|
||||
task_id,
|
||||
current_embedding,
|
||||
task.embeddings,
|
||||
predicted_bbox,
|
||||
actual_bbox
|
||||
)
|
||||
|
||||
# Si ré-entraînement déclenché, mettre à jour le profil
|
||||
if result.get("retraining_triggered"):
|
||||
# Ajouter l'embedding actuel pour améliorer la détection future
|
||||
task.embeddings.append(current_embedding)
|
||||
|
||||
# Ajouter à l'index FAISS
|
||||
self.embeddings_manager.add_to_index(
|
||||
current_embedding,
|
||||
{
|
||||
"task_id": task_id,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"is_retraining": True
|
||||
}
|
||||
)
|
||||
|
||||
# Sauvegarder le profil
|
||||
self._save_profile(task_id)
|
||||
|
||||
# Logger l'événement
|
||||
self.logger.log_action({
|
||||
"action": "ui_change_retraining",
|
||||
"task_id": task_id,
|
||||
"similarity": result.get("similarity"),
|
||||
"deltas": result.get("deltas"),
|
||||
"reason": "ui_drift_detected"
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
def monitor_execution_drift(
|
||||
self,
|
||||
task_id: str,
|
||||
predicted_action: Action,
|
||||
actual_action: Action
|
||||
) -> bool:
|
||||
"""
|
||||
Surveille la dérive entre l'action prédite et l'action réelle.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche
|
||||
predicted_action: Action prédite par le système
|
||||
actual_action: Action réelle effectuée/validée
|
||||
|
||||
Returns:
|
||||
True si dérive détectée et ré-entraînement déclenché
|
||||
"""
|
||||
if task_id not in self.tasks:
|
||||
return False
|
||||
|
||||
# Vérifier les changements UI
|
||||
result = self.check_ui_changes(
|
||||
task_id,
|
||||
actual_action.embedding,
|
||||
predicted_action.bbox,
|
||||
actual_action.bbox
|
||||
)
|
||||
|
||||
return result.get("retraining_triggered", False)
|
||||
|
||||
def get_ui_change_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Retourne les statistiques de détection de changements UI.
|
||||
|
||||
Returns:
|
||||
Statistiques du détecteur de changements
|
||||
"""
|
||||
return self.ui_change_detector.get_stats()
|
||||
|
||||
def _load_profiles(self):
|
||||
"""Charge les profils de tâches existants."""
|
||||
if not self.profiles_path.exists():
|
||||
return
|
||||
|
||||
for profile_file in self.profiles_path.glob("*.json"):
|
||||
try:
|
||||
with open(profile_file, 'r', encoding='utf-8') as f:
|
||||
json_str = f.read()
|
||||
|
||||
# Charger le profil (sans embeddings pour l'instant)
|
||||
task = TaskProfile.from_json(json_str)
|
||||
self.tasks[task.task_id] = task
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "profile_loaded",
|
||||
"task_id": task.task_id,
|
||||
"mode": task.mode
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "profile_load_error",
|
||||
"file": str(profile_file),
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
def _load_existing_tasks_to_index(self):
|
||||
"""
|
||||
Charge les tâches existantes dans l'index FAISS au démarrage.
|
||||
Résout le problème : index vide malgré 40 tâches sauvegardées.
|
||||
"""
|
||||
from .faiss_index_builder import FAISSIndexBuilder
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "load_existing_tasks_started"
|
||||
})
|
||||
|
||||
try:
|
||||
# Créer le builder
|
||||
builder = FAISSIndexBuilder(
|
||||
self.embeddings_manager,
|
||||
self.logger,
|
||||
str(self.profiles_path)
|
||||
)
|
||||
|
||||
# Vérifier l'intégrité de l'index
|
||||
report = builder.verify_index_integrity()
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "index_integrity_check",
|
||||
"report": report
|
||||
})
|
||||
|
||||
# Si l'index est vide ou incohérent, le reconstruire
|
||||
if report['needs_rebuild']:
|
||||
self.logger.log_action({
|
||||
"action": "rebuilding_index_automatically",
|
||||
"reason": "index_empty_or_inconsistent"
|
||||
})
|
||||
|
||||
stats = builder.rebuild_index(force=True)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "index_rebuilt_automatically",
|
||||
"stats": stats
|
||||
})
|
||||
|
||||
if stats['success']:
|
||||
self.logger.log_action({
|
||||
"action": "load_existing_tasks_success",
|
||||
"embeddings_loaded": stats['embeddings_added'],
|
||||
"tasks_processed": stats['tasks_processed']
|
||||
})
|
||||
else:
|
||||
self.logger.log_action({
|
||||
"action": "load_existing_tasks_failed",
|
||||
"reason": "rebuild_failed"
|
||||
})
|
||||
else:
|
||||
self.logger.log_action({
|
||||
"action": "load_existing_tasks_skipped",
|
||||
"reason": "index_already_consistent",
|
||||
"index_size": report['actual_embeddings']
|
||||
})
|
||||
|
||||
# Charger les tâches dans self.tasks si pas déjà fait
|
||||
tasks_info = builder.scan_tasks()
|
||||
for task_info in tasks_info:
|
||||
task_id = task_info['task_id']
|
||||
if task_id not in self.tasks:
|
||||
# Charger la tâche
|
||||
task = self.load_task(task_id)
|
||||
if task:
|
||||
self.tasks[task_id] = task
|
||||
self.logger.log_action({
|
||||
"action": "task_loaded_from_disk",
|
||||
"task_id": task_id
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "load_existing_tasks_error",
|
||||
"error": str(e)
|
||||
})
|
||||
# Ne pas bloquer le démarrage si erreur
|
||||
import traceback
|
||||
self.logger.log_action({
|
||||
"action": "load_existing_tasks_traceback",
|
||||
"traceback": traceback.format_exc()
|
||||
})
|
||||
460
geniusia2/core/llm_manager.py
Normal file
460
geniusia2/core/llm_manager.py
Normal file
@@ -0,0 +1,460 @@
|
||||
"""
|
||||
Gestionnaire LLM pour le raisonnement visuel avec Ollama.
|
||||
Interface vers les modèles vision-langage pour la prise de décision.
|
||||
"""
|
||||
|
||||
import json
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
import ollama
|
||||
except ImportError:
|
||||
ollama = None
|
||||
|
||||
from .logger import Logger
|
||||
|
||||
|
||||
class LLMManager:
|
||||
"""
|
||||
Gestionnaire LLM pour le raisonnement visuel utilisant Ollama.
|
||||
Supporte les modèles vision-langage comme Qwen 2.5-VL et CogVLM.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "qwen2.5-vl:3b",
|
||||
ollama_host: str = "localhost:11434",
|
||||
logger: Optional[Logger] = None,
|
||||
fallback_to_vision: bool = True
|
||||
):
|
||||
"""
|
||||
Initialise le gestionnaire LLM.
|
||||
|
||||
Args:
|
||||
model_name: Nom du modèle Ollama
|
||||
ollama_host: Hôte Ollama
|
||||
logger: Instance du logger
|
||||
fallback_to_vision: Utiliser la vision pure en cas d'échec LLM
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.ollama_host = ollama_host
|
||||
self.logger = logger
|
||||
self.fallback_to_vision = fallback_to_vision
|
||||
|
||||
# Initialiser le client Ollama
|
||||
self._init_client()
|
||||
|
||||
def _init_client(self):
|
||||
"""Initialise le client Ollama."""
|
||||
if ollama is None:
|
||||
raise ImportError(
|
||||
"Ollama n'est pas installé. "
|
||||
"Installez-le avec: pip install ollama"
|
||||
)
|
||||
|
||||
try:
|
||||
self.client = ollama.Client(host=self.ollama_host)
|
||||
|
||||
# Vérifier que le modèle est disponible
|
||||
models = self.client.list()
|
||||
model_names = [m.model for m in models.models] if hasattr(models, 'models') else []
|
||||
|
||||
if self.model_name not in model_names:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "model_not_found",
|
||||
"model": self.model_name,
|
||||
"available_models": model_names
|
||||
})
|
||||
print(f"Avertissement: Le modèle {self.model_name} n'est pas trouvé.")
|
||||
print(f"Modèles disponibles: {model_names}")
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "llm_client_initialized",
|
||||
"model": self.model_name,
|
||||
"host": self.ollama_host
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Erreur lors de l'initialisation du client Ollama: {e}"
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "llm_init_error",
|
||||
"error": str(e)
|
||||
})
|
||||
if not self.fallback_to_vision:
|
||||
raise RuntimeError(error_msg)
|
||||
print(f"Avertissement: {error_msg}")
|
||||
self.client = None
|
||||
|
||||
def _image_to_base64(self, image: np.ndarray) -> str:
|
||||
"""
|
||||
Convertit une image numpy en base64.
|
||||
|
||||
Args:
|
||||
image: Image numpy array (H, W, C)
|
||||
|
||||
Returns:
|
||||
String base64 de l'image
|
||||
"""
|
||||
# Convertir BGR vers RGB si nécessaire
|
||||
if len(image.shape) == 3 and image.shape[2] == 3:
|
||||
image_rgb = image[:, :, ::-1]
|
||||
else:
|
||||
image_rgb = image
|
||||
|
||||
# Convertir en PIL Image
|
||||
pil_image = Image.fromarray(image_rgb.astype(np.uint8))
|
||||
|
||||
# Convertir en base64
|
||||
buffered = BytesIO()
|
||||
pil_image.save(buffered, format="PNG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode()
|
||||
|
||||
return img_str
|
||||
|
||||
def reason_about_detections(
|
||||
self,
|
||||
detections: List[Dict[str, Any]],
|
||||
context: Dict[str, Any],
|
||||
intent: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Utilise le VLM pour sélectionner la meilleure action parmi les détections.
|
||||
|
||||
Args:
|
||||
detections: Liste de détections avec labels, bbox, images ROI
|
||||
context: Contexte actuel (fenêtre, historique, etc.)
|
||||
intent: Intention utilisateur
|
||||
|
||||
Returns:
|
||||
Dictionnaire avec l'élément sélectionné et le score de confiance
|
||||
"""
|
||||
if not detections:
|
||||
return {
|
||||
"selected_element": None,
|
||||
"confidence": 0.0,
|
||||
"reasoning": "Aucune détection disponible"
|
||||
}
|
||||
|
||||
# Fallback si pas de client Ollama
|
||||
if self.client is None and self.fallback_to_vision:
|
||||
return self._fallback_to_vision_only(detections)
|
||||
|
||||
try:
|
||||
# Préparer le prompt
|
||||
elements_desc = [
|
||||
f"- Élément {i+1}: {d['label']} (confiance: {d['confidence']:.2f})"
|
||||
for i, d in enumerate(detections)
|
||||
]
|
||||
|
||||
prompt = f"""Tu es un assistant d'automatisation RPA. Analyse ces éléments UI détectés et détermine lequel correspond le mieux à l'intention de l'utilisateur.
|
||||
|
||||
Intention: {intent}
|
||||
Contexte: Fenêtre '{context.get('window', 'Inconnue')}'
|
||||
|
||||
Éléments détectés:
|
||||
{chr(10).join(elements_desc)}
|
||||
|
||||
Réponds UNIQUEMENT avec un JSON au format suivant:
|
||||
{{
|
||||
"element_index": <index de l'élément (0-{len(detections)-1})>,
|
||||
"confidence": <score de confiance 0.0-1.0>,
|
||||
"reasoning": "<explication brève>"
|
||||
}}"""
|
||||
|
||||
# Préparer les images
|
||||
images = []
|
||||
for detection in detections:
|
||||
if 'roi_image' in detection and detection['roi_image'] is not None:
|
||||
img_b64 = self._image_to_base64(detection['roi_image'])
|
||||
images.append(img_b64)
|
||||
|
||||
# Générer la réponse
|
||||
response = self.client.generate(
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
images=images if images else None,
|
||||
stream=False
|
||||
)
|
||||
|
||||
# Parser la réponse
|
||||
result = self._parse_llm_response(response['response'], detections)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "llm_reasoning",
|
||||
"intent": intent,
|
||||
"num_detections": len(detections),
|
||||
"selected_index": result.get("element_index"),
|
||||
"confidence": result.get("confidence")
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "llm_reasoning_error",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
if self.fallback_to_vision:
|
||||
return self._fallback_to_vision_only(detections)
|
||||
|
||||
return {
|
||||
"selected_element": None,
|
||||
"confidence": 0.0,
|
||||
"reasoning": f"Erreur LLM: {str(e)}"
|
||||
}
|
||||
|
||||
def _fallback_to_vision_only(
|
||||
self,
|
||||
detections: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Fallback vers la sélection basée uniquement sur la confiance vision.
|
||||
|
||||
Args:
|
||||
detections: Liste de détections
|
||||
|
||||
Returns:
|
||||
Dictionnaire avec l'élément le plus confiant
|
||||
"""
|
||||
if not detections:
|
||||
return {
|
||||
"selected_element": None,
|
||||
"confidence": 0.0,
|
||||
"reasoning": "Aucune détection"
|
||||
}
|
||||
|
||||
# Sélectionner la détection avec la confiance la plus élevée
|
||||
best_detection = max(detections, key=lambda d: d.get('confidence', 0.0))
|
||||
best_index = detections.index(best_detection)
|
||||
|
||||
return {
|
||||
"element_index": best_index,
|
||||
"selected_element": best_detection,
|
||||
"confidence": best_detection.get('confidence', 0.0),
|
||||
"reasoning": "Sélection basée sur la confiance vision (fallback)",
|
||||
"llm_score": 0.0
|
||||
}
|
||||
|
||||
def _parse_llm_response(
|
||||
self,
|
||||
response: str,
|
||||
detections: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse la réponse du LLM.
|
||||
|
||||
Args:
|
||||
response: Réponse texte du LLM
|
||||
detections: Liste des détections originales
|
||||
|
||||
Returns:
|
||||
Dictionnaire avec l'élément sélectionné et les métadonnées
|
||||
"""
|
||||
try:
|
||||
# Extraire le JSON de la réponse
|
||||
response_clean = response.strip()
|
||||
|
||||
# Chercher le JSON dans la réponse
|
||||
start_idx = response_clean.find('{')
|
||||
end_idx = response_clean.rfind('}') + 1
|
||||
|
||||
if start_idx != -1 and end_idx > start_idx:
|
||||
json_str = response_clean[start_idx:end_idx]
|
||||
parsed = json.loads(json_str)
|
||||
|
||||
element_index = parsed.get('element_index', 0)
|
||||
confidence = parsed.get('confidence', 0.5)
|
||||
reasoning = parsed.get('reasoning', '')
|
||||
|
||||
# Valider l'index
|
||||
if 0 <= element_index < len(detections):
|
||||
selected = detections[element_index]
|
||||
|
||||
return {
|
||||
"element_index": element_index,
|
||||
"selected_element": selected,
|
||||
"confidence": confidence,
|
||||
"reasoning": reasoning,
|
||||
"llm_score": confidence
|
||||
}
|
||||
|
||||
# Si le parsing échoue, fallback
|
||||
return self._fallback_to_vision_only(detections)
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "llm_parse_error",
|
||||
"error": str(e),
|
||||
"response": response
|
||||
})
|
||||
return self._fallback_to_vision_only(detections)
|
||||
|
||||
def generate_with_vision(
|
||||
self,
|
||||
prompt: str,
|
||||
images: Optional[List[np.ndarray]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Génération multi-modale avec images.
|
||||
|
||||
Args:
|
||||
prompt: Prompt texte
|
||||
images: Liste d'images numpy arrays
|
||||
|
||||
Returns:
|
||||
Réponse générée
|
||||
"""
|
||||
if self.client is None:
|
||||
return "Erreur: Client Ollama non disponible"
|
||||
|
||||
try:
|
||||
# Convertir les images en base64
|
||||
images_b64 = []
|
||||
if images:
|
||||
print(f"[LLM] Conversion de {len(images)} images en base64...")
|
||||
for i, img in enumerate(images):
|
||||
try:
|
||||
img_b64 = self._image_to_base64(img)
|
||||
images_b64.append(img_b64)
|
||||
print(f"[LLM] Image {i+1}/{len(images)} convertie ({len(img_b64)} bytes)")
|
||||
except Exception as e:
|
||||
print(f"[LLM] Erreur conversion image {i+1}: {e}")
|
||||
raise
|
||||
|
||||
# Générer
|
||||
print(f"[LLM] Appel Ollama avec modèle {self.model_name}...")
|
||||
print(f"[LLM] Prompt: {prompt[:100]}...")
|
||||
print(f"[LLM] Images: {len(images_b64) if images_b64 else 0}")
|
||||
|
||||
response = self.client.generate(
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
images=images_b64 if images_b64 else None,
|
||||
stream=False,
|
||||
options={
|
||||
"temperature": 0.3, # Basse température pour réponses plus déterministes
|
||||
"num_predict": 20, # Limiter à 20 tokens (environ 3-4 mots)
|
||||
"top_p": 0.9,
|
||||
"top_k": 40
|
||||
}
|
||||
)
|
||||
|
||||
print(f"[LLM] Réponse brute: {response}")
|
||||
|
||||
# Qwen3-VL peut mettre la réponse dans 'thinking' au lieu de 'response'
|
||||
result = response.get('response', '')
|
||||
|
||||
# Si response est vide, essayer thinking
|
||||
if not result and 'thinking' in response:
|
||||
thinking = response['thinking']
|
||||
print(f"[LLM] Response vide, extraction depuis thinking: '{thinking}'")
|
||||
# Nettoyer les balises spéciales de Qwen
|
||||
result = thinking.replace('<|im_start|>', '').replace('<|im_end|>', '').replace('<think>', '').replace('</think>', '').strip()
|
||||
|
||||
print(f"[LLM] Réponse extraite: '{result}' (longueur: {len(result)})")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f"[LLM] ❌ EXCEPTION: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "generation_error",
|
||||
"error": str(e)
|
||||
})
|
||||
return f"Erreur de génération: {str(e)}"
|
||||
|
||||
def score_action_relevance(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
intent: str
|
||||
) -> float:
|
||||
"""
|
||||
Calcule un score de pertinence pour une action donnée.
|
||||
|
||||
Args:
|
||||
action: Dictionnaire décrivant l'action
|
||||
intent: Intention utilisateur
|
||||
|
||||
Returns:
|
||||
Score de confiance (0.0-1.0)
|
||||
"""
|
||||
if self.client is None:
|
||||
# Retourner un score neutre si pas de LLM
|
||||
return 0.5
|
||||
|
||||
try:
|
||||
prompt = f"""Évalue la pertinence de cette action par rapport à l'intention utilisateur.
|
||||
|
||||
Intention: {intent}
|
||||
Action: {action.get('action_type', 'unknown')} sur '{action.get('target_element', 'unknown')}'
|
||||
|
||||
Réponds UNIQUEMENT avec un score entre 0.0 et 1.0 (ex: 0.85)"""
|
||||
|
||||
response = self.client.generate(
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
stream=False
|
||||
)
|
||||
|
||||
# Extraire le score
|
||||
response_text = response['response'].strip()
|
||||
|
||||
# Chercher un nombre décimal
|
||||
import re
|
||||
match = re.search(r'0\.\d+|1\.0|0|1', response_text)
|
||||
if match:
|
||||
score = float(match.group())
|
||||
return max(0.0, min(1.0, score))
|
||||
|
||||
return 0.5
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "scoring_error",
|
||||
"error": str(e)
|
||||
})
|
||||
return 0.5
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
Vérifie si le service LLM est disponible.
|
||||
|
||||
Returns:
|
||||
True si disponible, False sinon
|
||||
"""
|
||||
if self.client is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
self.client.list()
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Retourne des informations sur le modèle.
|
||||
|
||||
Returns:
|
||||
Dictionnaire d'informations
|
||||
"""
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"host": self.ollama_host,
|
||||
"available": self.is_available(),
|
||||
"fallback_enabled": self.fallback_to_vision
|
||||
}
|
||||
584
geniusia2/core/logger.py
Normal file
584
geniusia2/core/logger.py
Normal file
@@ -0,0 +1,584 @@
|
||||
"""
|
||||
Logger chiffré pour RPA Vision V2
|
||||
Gère la journalisation sécurisée avec chiffrement AES-256 de toutes les actions,
|
||||
corrections et transitions de mode du système.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import base64
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import padding, hashes, hmac
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
import secrets
|
||||
|
||||
try:
|
||||
from .config import get_data_paths, get_security_config
|
||||
except ImportError:
|
||||
# Pour tests standalone
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from config import get_data_paths, get_security_config
|
||||
|
||||
|
||||
class Logger:
|
||||
"""
|
||||
Logger chiffré avec AES-256-CBC pour journalisation sécurisée
|
||||
|
||||
Attributes:
|
||||
log_dir: Répertoire de stockage des logs
|
||||
key_path: Chemin du fichier de clé de chiffrement
|
||||
encryption_key: Clé AES-256 (32 bytes)
|
||||
hmac_key: Clé HMAC pour vérification d'intégrité
|
||||
"""
|
||||
|
||||
def __init__(self, log_dir: Optional[str] = None, key_path: Optional[str] = None):
|
||||
"""
|
||||
Initialise le logger avec génération ou chargement des clés AES
|
||||
|
||||
Args:
|
||||
log_dir: Répertoire pour les logs (utilise config par défaut si None)
|
||||
key_path: Chemin du fichier de clés (utilise config par défaut si None)
|
||||
"""
|
||||
# Configuration des chemins
|
||||
data_paths = get_data_paths()
|
||||
self.log_dir = Path(log_dir) if log_dir else Path(data_paths["logs"])
|
||||
self.key_path = Path(key_path) if key_path else Path(data_paths["encryption_keys"])
|
||||
|
||||
# Créer les répertoires si nécessaire
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.key_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Charger ou générer les clés de chiffrement
|
||||
self.encryption_key, self.hmac_key = self._load_or_generate_keys()
|
||||
|
||||
# Configuration de sécurité
|
||||
self.security_config = get_security_config()
|
||||
|
||||
def _load_or_generate_keys(self) -> tuple[bytes, bytes]:
|
||||
"""
|
||||
Charge les clés existantes ou en génère de nouvelles
|
||||
|
||||
Returns:
|
||||
Tuple (encryption_key, hmac_key)
|
||||
"""
|
||||
key_file = self.key_path / "encryption.key"
|
||||
|
||||
if key_file.exists():
|
||||
# Charger les clés existantes
|
||||
with open(key_file, 'rb') as f:
|
||||
key_data = f.read()
|
||||
# Les 32 premiers bytes sont la clé de chiffrement
|
||||
# Les 32 suivants sont la clé HMAC
|
||||
encryption_key = key_data[:32]
|
||||
hmac_key = key_data[32:64]
|
||||
else:
|
||||
# Générer de nouvelles clés
|
||||
encryption_key = secrets.token_bytes(32) # 256 bits
|
||||
hmac_key = secrets.token_bytes(32) # 256 bits
|
||||
|
||||
# S'assurer que le répertoire parent existe
|
||||
key_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Sauvegarder les clés de manière sécurisée
|
||||
with open(key_file, 'wb') as f:
|
||||
f.write(encryption_key + hmac_key)
|
||||
|
||||
# Restreindre les permissions (lecture/écriture propriétaire uniquement)
|
||||
os.chmod(key_file, 0o600)
|
||||
|
||||
return encryption_key, hmac_key
|
||||
|
||||
def encrypt_entry(self, data: Dict[str, Any]) -> bytes:
|
||||
"""
|
||||
Chiffre une entrée de log avec AES-256-CBC
|
||||
|
||||
Args:
|
||||
data: Dictionnaire contenant les données à chiffrer
|
||||
|
||||
Returns:
|
||||
Données chiffrées (IV + ciphertext + HMAC)
|
||||
"""
|
||||
# Convertir les données en JSON
|
||||
json_data = json.dumps(data, ensure_ascii=False, default=str)
|
||||
plaintext = json_data.encode('utf-8')
|
||||
|
||||
# Générer un IV aléatoire (16 bytes pour AES)
|
||||
iv = secrets.token_bytes(16)
|
||||
|
||||
# Padding PKCS7
|
||||
padder = padding.PKCS7(128).padder()
|
||||
padded_data = padder.update(plaintext) + padder.finalize()
|
||||
|
||||
# Chiffrement AES-256-CBC
|
||||
cipher = Cipher(
|
||||
algorithms.AES(self.encryption_key),
|
||||
modes.CBC(iv),
|
||||
backend=default_backend()
|
||||
)
|
||||
encryptor = cipher.encryptor()
|
||||
ciphertext = encryptor.update(padded_data) + encryptor.finalize()
|
||||
|
||||
# Calculer HMAC pour vérification d'intégrité
|
||||
h = hmac.HMAC(self.hmac_key, hashes.SHA256(), backend=default_backend())
|
||||
h.update(iv + ciphertext)
|
||||
mac = h.finalize()
|
||||
|
||||
# Retourner IV + ciphertext + HMAC
|
||||
return iv + ciphertext + mac
|
||||
|
||||
def decrypt_entry(self, encrypted_data: bytes) -> Dict[str, Any]:
|
||||
"""
|
||||
Déchiffre une entrée de log
|
||||
|
||||
Args:
|
||||
encrypted_data: Données chiffrées (IV + ciphertext + HMAC)
|
||||
|
||||
Returns:
|
||||
Dictionnaire contenant les données déchiffrées
|
||||
|
||||
Raises:
|
||||
ValueError: Si la vérification HMAC échoue
|
||||
"""
|
||||
# Extraire IV, ciphertext et HMAC
|
||||
iv = encrypted_data[:16]
|
||||
mac = encrypted_data[-32:]
|
||||
ciphertext = encrypted_data[16:-32]
|
||||
|
||||
# Vérifier l'intégrité avec HMAC
|
||||
h = hmac.HMAC(self.hmac_key, hashes.SHA256(), backend=default_backend())
|
||||
h.update(iv + ciphertext)
|
||||
try:
|
||||
h.verify(mac)
|
||||
except Exception:
|
||||
raise ValueError("HMAC verification failed - data may be corrupted or tampered")
|
||||
|
||||
# Déchiffrement AES-256-CBC
|
||||
cipher = Cipher(
|
||||
algorithms.AES(self.encryption_key),
|
||||
modes.CBC(iv),
|
||||
backend=default_backend()
|
||||
)
|
||||
decryptor = cipher.decryptor()
|
||||
padded_plaintext = decryptor.update(ciphertext) + decryptor.finalize()
|
||||
|
||||
# Retirer le padding PKCS7
|
||||
unpadder = padding.PKCS7(128).unpadder()
|
||||
plaintext = unpadder.update(padded_plaintext) + unpadder.finalize()
|
||||
|
||||
# Convertir JSON en dictionnaire
|
||||
json_data = plaintext.decode('utf-8')
|
||||
return json.loads(json_data)
|
||||
|
||||
def _get_log_file_path(self, date: Optional[datetime] = None) -> Path:
|
||||
"""
|
||||
Génère le chemin du fichier de log pour une date donnée
|
||||
|
||||
Args:
|
||||
date: Date pour le fichier de log (utilise aujourd'hui si None)
|
||||
|
||||
Returns:
|
||||
Chemin du fichier de log
|
||||
"""
|
||||
if date is None:
|
||||
date = datetime.now()
|
||||
|
||||
# Format: logs_YYYY-MM-DD.enc
|
||||
filename = f"logs_{date.strftime('%Y-%m-%d')}.enc"
|
||||
return self.log_dir / filename
|
||||
|
||||
def _write_encrypted_entry(self, encrypted_data: bytes):
|
||||
"""
|
||||
Écrit une entrée chiffrée dans le fichier de log du jour
|
||||
|
||||
Args:
|
||||
encrypted_data: Données chiffrées à écrire
|
||||
"""
|
||||
log_file = self._get_log_file_path()
|
||||
|
||||
# Encoder en base64 pour stockage texte
|
||||
encoded_data = base64.b64encode(encrypted_data).decode('ascii')
|
||||
|
||||
# Ajouter au fichier (mode append)
|
||||
with open(log_file, 'a') as f:
|
||||
f.write(encoded_data + '\n')
|
||||
|
||||
def log_action(self, action_data: Dict[str, Any]):
|
||||
"""
|
||||
Enregistre une action (sans chiffrement pour MVP)
|
||||
|
||||
Args:
|
||||
action_data: Dictionnaire contenant les données de l'action
|
||||
"""
|
||||
# Ajouter métadonnées système
|
||||
log_entry = {
|
||||
"type": "action",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
**action_data
|
||||
}
|
||||
|
||||
# Écrire en JSON simple (pas de chiffrement pour MVP)
|
||||
self._write_plain_entry(log_entry)
|
||||
|
||||
def _write_plain_entry(self, entry: Dict[str, Any]):
|
||||
"""Écrit une entrée en JSON simple (MVP - pas de chiffrement)."""
|
||||
log_file = self.log_dir / f"logs_{datetime.now().strftime('%Y-%m-%d')}.json"
|
||||
|
||||
# Écrire en mode append
|
||||
with open(log_file, 'a') as f:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
|
||||
|
||||
def log_correction(self, correction_data: Dict[str, Any]):
|
||||
"""
|
||||
Enregistre une correction utilisateur (sans chiffrement pour MVP)
|
||||
|
||||
Args:
|
||||
correction_data: Dictionnaire contenant les données de correction
|
||||
Champs attendus:
|
||||
- task_id: Identifiant de la tâche
|
||||
- incorrect_element: Élément incorrectement détecté
|
||||
- correct_element: Élément correct fourni par l'utilisateur
|
||||
- incorrect_bbox: Bounding box incorrecte
|
||||
- correct_bbox: Bounding box correcte
|
||||
- window: Titre de la fenêtre
|
||||
- mode: Mode opérationnel au moment de la correction
|
||||
"""
|
||||
# Ajouter métadonnées système
|
||||
log_entry = {
|
||||
"type": "correction",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
**correction_data
|
||||
}
|
||||
|
||||
# Écrire en JSON simple (pas de chiffrement pour MVP)
|
||||
self._write_plain_entry(log_entry)
|
||||
|
||||
def log_mode_transition(self, task_id: str, from_mode: str, to_mode: str, reason: str):
|
||||
"""
|
||||
Enregistre une transition de mode pour une tâche (sans chiffrement pour MVP)
|
||||
|
||||
Args:
|
||||
task_id: Identifiant de la tâche
|
||||
from_mode: Mode d'origine ("shadow", "assist", "auto")
|
||||
to_mode: Nouveau mode ("shadow", "assist", "auto")
|
||||
reason: Raison de la transition (ex: "low_confidence", "threshold_met")
|
||||
"""
|
||||
log_entry = {
|
||||
"type": "mode_transition",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"task_id": task_id,
|
||||
"from_mode": from_mode,
|
||||
"to_mode": to_mode,
|
||||
"reason": reason
|
||||
}
|
||||
|
||||
# Écrire en JSON simple (pas de chiffrement pour MVP)
|
||||
self._write_plain_entry(log_entry)
|
||||
|
||||
def log_security_event(self, event_data: Dict[str, Any]):
|
||||
"""
|
||||
Enregistre un événement de sécurité (ex: violation de liste blanche)
|
||||
|
||||
Args:
|
||||
event_data: Dictionnaire contenant les données de l'événement
|
||||
Champs attendus:
|
||||
- event_type: Type d'événement ("whitelist_violation", "rollback", etc.)
|
||||
- window: Titre de la fenêtre concernée
|
||||
- action_attempted: Action tentée
|
||||
- details: Détails additionnels
|
||||
"""
|
||||
log_entry = {
|
||||
"type": "security_event",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
**event_data
|
||||
}
|
||||
|
||||
# Chiffrer et écrire
|
||||
encrypted = self.encrypt_entry(log_entry)
|
||||
self._write_encrypted_entry(encrypted)
|
||||
|
||||
def get_logs(
|
||||
self,
|
||||
task_id: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
log_type: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Interroge les logs avec filtres optionnels
|
||||
|
||||
Args:
|
||||
task_id: Filtrer par identifiant de tâche
|
||||
start_time: Filtrer par date de début
|
||||
end_time: Filtrer par date de fin
|
||||
log_type: Filtrer par type de log ("action", "correction", "mode_transition", "security_event")
|
||||
|
||||
Returns:
|
||||
Liste des entrées de log correspondant aux critères
|
||||
"""
|
||||
results = []
|
||||
|
||||
# Déterminer les fichiers de log à lire
|
||||
if start_time is None:
|
||||
start_time = datetime.now() - timedelta(days=self.security_config["log_retention_days"])
|
||||
if end_time is None:
|
||||
end_time = datetime.now()
|
||||
|
||||
# Parcourir les jours dans la plage
|
||||
current_date = start_time.date()
|
||||
end_date = end_time.date()
|
||||
|
||||
while current_date <= end_date:
|
||||
log_file = self._get_log_file_path(datetime.combine(current_date, datetime.min.time()))
|
||||
|
||||
if log_file.exists():
|
||||
# Lire et déchiffrer les entrées
|
||||
with open(log_file, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Décoder base64 et déchiffrer
|
||||
encrypted_data = base64.b64decode(line)
|
||||
entry = self.decrypt_entry(encrypted_data)
|
||||
|
||||
# Appliquer les filtres
|
||||
entry_time = datetime.fromisoformat(entry["timestamp"])
|
||||
|
||||
if entry_time < start_time or entry_time > end_time:
|
||||
continue
|
||||
|
||||
if task_id and entry.get("task_id") != task_id:
|
||||
continue
|
||||
|
||||
if log_type and entry.get("type") != log_type:
|
||||
continue
|
||||
|
||||
results.append(entry)
|
||||
|
||||
except Exception as e:
|
||||
# Logger l'erreur mais continuer le traitement
|
||||
print(f"Warning: Failed to decrypt log entry: {e}")
|
||||
continue
|
||||
|
||||
# Passer au jour suivant
|
||||
current_date += timedelta(days=1)
|
||||
|
||||
# Trier par timestamp
|
||||
results.sort(key=lambda x: x["timestamp"])
|
||||
|
||||
return results
|
||||
|
||||
def cleanup_old_logs(self):
|
||||
"""
|
||||
Supprime les logs plus anciens que la période de rétention configurée
|
||||
"""
|
||||
retention_days = self.security_config["log_retention_days"]
|
||||
cutoff_date = datetime.now() - timedelta(days=retention_days)
|
||||
|
||||
# Parcourir tous les fichiers de log
|
||||
for log_file in self.log_dir.glob("logs_*.enc"):
|
||||
try:
|
||||
# Extraire la date du nom de fichier
|
||||
date_str = log_file.stem.replace("logs_", "")
|
||||
file_date = datetime.strptime(date_str, "%Y-%m-%d")
|
||||
|
||||
# Supprimer si trop ancien
|
||||
if file_date < cutoff_date:
|
||||
log_file.unlink()
|
||||
print(f"Deleted old log file: {log_file.name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to process log file {log_file.name}: {e}")
|
||||
|
||||
def get_task_statistics(self, task_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Calcule des statistiques pour une tâche spécifique
|
||||
|
||||
Args:
|
||||
task_id: Identifiant de la tâche
|
||||
|
||||
Returns:
|
||||
Dictionnaire contenant les statistiques:
|
||||
- total_actions: Nombre total d'actions
|
||||
- success_count: Nombre d'actions réussies
|
||||
- failed_count: Nombre d'actions échouées
|
||||
- correction_count: Nombre de corrections
|
||||
- mode_transitions: Liste des transitions de mode
|
||||
- avg_confidence: Confiance moyenne
|
||||
- last_execution: Dernière exécution
|
||||
"""
|
||||
# Récupérer tous les logs pour cette tâche
|
||||
logs = self.get_logs(task_id=task_id)
|
||||
|
||||
stats = {
|
||||
"task_id": task_id,
|
||||
"total_actions": 0,
|
||||
"success_count": 0,
|
||||
"failed_count": 0,
|
||||
"correction_count": 0,
|
||||
"mode_transitions": [],
|
||||
"avg_confidence": 0.0,
|
||||
"last_execution": None
|
||||
}
|
||||
|
||||
confidence_sum = 0.0
|
||||
confidence_count = 0
|
||||
|
||||
for entry in logs:
|
||||
entry_type = entry.get("type")
|
||||
|
||||
if entry_type == "action":
|
||||
stats["total_actions"] += 1
|
||||
|
||||
result = entry.get("result")
|
||||
if result == "success":
|
||||
stats["success_count"] += 1
|
||||
elif result == "failed":
|
||||
stats["failed_count"] += 1
|
||||
|
||||
if "confidence" in entry:
|
||||
confidence_sum += entry["confidence"]
|
||||
confidence_count += 1
|
||||
|
||||
# Mettre à jour la dernière exécution
|
||||
timestamp = entry.get("timestamp")
|
||||
if timestamp:
|
||||
if stats["last_execution"] is None or timestamp > stats["last_execution"]:
|
||||
stats["last_execution"] = timestamp
|
||||
|
||||
elif entry_type == "correction":
|
||||
stats["correction_count"] += 1
|
||||
|
||||
elif entry_type == "mode_transition":
|
||||
stats["mode_transitions"].append({
|
||||
"timestamp": entry.get("timestamp"),
|
||||
"from_mode": entry.get("from_mode"),
|
||||
"to_mode": entry.get("to_mode"),
|
||||
"reason": entry.get("reason")
|
||||
})
|
||||
|
||||
# Calculer la confiance moyenne
|
||||
if confidence_count > 0:
|
||||
stats["avg_confidence"] = confidence_sum / confidence_count
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Tests du logger
|
||||
print("Test du Logger chiffré RPA Vision V2")
|
||||
print("=" * 50)
|
||||
|
||||
# Créer un logger de test
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
test_dir = tempfile.mkdtemp()
|
||||
test_log_dir = os.path.join(test_dir, "logs")
|
||||
test_key_dir = os.path.join(test_dir, "keys")
|
||||
|
||||
try:
|
||||
logger = Logger(log_dir=test_log_dir, key_path=test_key_dir)
|
||||
print(f"✓ Logger initialisé")
|
||||
print(f" Log dir: {logger.log_dir}")
|
||||
print(f" Key path: {logger.key_path}")
|
||||
|
||||
# Test 1: Log d'action
|
||||
print("\n1. Test log_action:")
|
||||
logger.log_action({
|
||||
"window": "Dolibarr - Facturation",
|
||||
"action": "click",
|
||||
"element": "valider_button",
|
||||
"bbox": [450, 320, 120, 40],
|
||||
"confidence": 0.97,
|
||||
"mode": "auto",
|
||||
"result": "success",
|
||||
"task_id": "ouvrir_facture_001"
|
||||
})
|
||||
print(" ✓ Action loggée")
|
||||
|
||||
# Test 2: Log de correction
|
||||
print("\n2. Test log_correction:")
|
||||
logger.log_correction({
|
||||
"task_id": "ouvrir_facture_001",
|
||||
"incorrect_element": "annuler_button",
|
||||
"correct_element": "valider_button",
|
||||
"incorrect_bbox": [350, 320, 120, 40],
|
||||
"correct_bbox": [450, 320, 120, 40],
|
||||
"window": "Dolibarr - Facturation",
|
||||
"mode": "assist"
|
||||
})
|
||||
print(" ✓ Correction loggée")
|
||||
|
||||
# Test 3: Log de transition de mode
|
||||
print("\n3. Test log_mode_transition:")
|
||||
logger.log_mode_transition(
|
||||
task_id="ouvrir_facture_001",
|
||||
from_mode="assist",
|
||||
to_mode="auto",
|
||||
reason="threshold_met"
|
||||
)
|
||||
print(" ✓ Transition de mode loggée")
|
||||
|
||||
# Test 4: Log d'événement de sécurité
|
||||
print("\n4. Test log_security_event:")
|
||||
logger.log_security_event({
|
||||
"event_type": "whitelist_violation",
|
||||
"window": "Unknown Application",
|
||||
"action_attempted": "click",
|
||||
"details": "Window not in whitelist"
|
||||
})
|
||||
print(" ✓ Événement de sécurité loggé")
|
||||
|
||||
# Test 5: Récupération des logs
|
||||
print("\n5. Test get_logs:")
|
||||
all_logs = logger.get_logs()
|
||||
print(f" ✓ {len(all_logs)} entrées récupérées")
|
||||
|
||||
# Filtrer par task_id
|
||||
task_logs = logger.get_logs(task_id="ouvrir_facture_001")
|
||||
print(f" ✓ {len(task_logs)} entrées pour task_id='ouvrir_facture_001'")
|
||||
|
||||
# Filtrer par type
|
||||
action_logs = logger.get_logs(log_type="action")
|
||||
print(f" ✓ {len(action_logs)} entrées de type 'action'")
|
||||
|
||||
# Test 6: Statistiques de tâche
|
||||
print("\n6. Test get_task_statistics:")
|
||||
stats = logger.get_task_statistics("ouvrir_facture_001")
|
||||
print(f" ✓ Statistiques calculées:")
|
||||
print(f" - Total actions: {stats['total_actions']}")
|
||||
print(f" - Succès: {stats['success_count']}")
|
||||
print(f" - Corrections: {stats['correction_count']}")
|
||||
print(f" - Transitions: {len(stats['mode_transitions'])}")
|
||||
print(f" - Confiance moyenne: {stats['avg_confidence']:.2f}")
|
||||
|
||||
# Test 7: Chiffrement/Déchiffrement
|
||||
print("\n7. Test encrypt/decrypt:")
|
||||
test_data = {
|
||||
"test": "data",
|
||||
"number": 42,
|
||||
"nested": {"key": "value"}
|
||||
}
|
||||
encrypted = logger.encrypt_entry(test_data)
|
||||
print(f" ✓ Données chiffrées ({len(encrypted)} bytes)")
|
||||
|
||||
decrypted = logger.decrypt_entry(encrypted)
|
||||
print(f" ✓ Données déchiffrées")
|
||||
assert decrypted == test_data, "Decryption mismatch!"
|
||||
print(f" ✓ Vérification: données identiques")
|
||||
|
||||
print("\n✓ Tous les tests réussis!")
|
||||
|
||||
finally:
|
||||
# Nettoyer les fichiers de test
|
||||
shutil.rmtree(test_dir)
|
||||
print(f"\n✓ Fichiers de test nettoyés")
|
||||
495
geniusia2/core/metrics_collector.py
Normal file
495
geniusia2/core/metrics_collector.py
Normal file
@@ -0,0 +1,495 @@
|
||||
"""
|
||||
Collecteur de métriques pour surveiller les performances du système RPA.
|
||||
Suit la latence, la concordance, le taux de correction et génère des alertes.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
|
||||
from .logger import Logger
|
||||
|
||||
|
||||
class MetricsCollector:
|
||||
"""
|
||||
Collecteur de métriques pour surveillance des performances.
|
||||
"""
|
||||
|
||||
def __init__(self, logger: Logger, config: Dict[str, Any]):
|
||||
"""
|
||||
Initialise le collecteur de métriques.
|
||||
|
||||
Args:
|
||||
logger: Logger pour journalisation
|
||||
config: Configuration globale
|
||||
"""
|
||||
self.logger = logger
|
||||
self.config = config
|
||||
|
||||
# Seuils de performance
|
||||
self.latency_threshold = config.get("performance", {}).get(
|
||||
"max_latency_ms", 400
|
||||
)
|
||||
self.concordance_threshold = config.get("thresholds", {}).get(
|
||||
"concordance_rate", 0.95
|
||||
)
|
||||
self.correction_rate_threshold = config.get("thresholds", {}).get(
|
||||
"correction_rate", 0.03
|
||||
)
|
||||
|
||||
# Métriques par tâche
|
||||
self.task_metrics: Dict[str, Dict[str, Any]] = defaultdict(
|
||||
lambda: {
|
||||
"latencies": [],
|
||||
"successes": 0,
|
||||
"failures": 0,
|
||||
"corrections": 0,
|
||||
"total_executions": 0,
|
||||
"last_execution": None
|
||||
}
|
||||
)
|
||||
|
||||
# Métriques globales
|
||||
self.global_metrics = {
|
||||
"total_latencies": [],
|
||||
"total_successes": 0,
|
||||
"total_failures": 0,
|
||||
"total_corrections": 0,
|
||||
"total_executions": 0,
|
||||
"alerts_generated": 0
|
||||
}
|
||||
|
||||
# Historique des alertes
|
||||
self.alerts: List[Dict[str, Any]] = []
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "metrics_collector_initialized",
|
||||
"latency_threshold_ms": self.latency_threshold,
|
||||
"concordance_threshold": self.concordance_threshold,
|
||||
"correction_rate_threshold": self.correction_rate_threshold
|
||||
})
|
||||
|
||||
def track_latency(
|
||||
self,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
task_id: Optional[str] = None,
|
||||
operation: str = "execution"
|
||||
) -> float:
|
||||
"""
|
||||
Enregistre la latence d'une opération.
|
||||
|
||||
Args:
|
||||
start_time: Timestamp de début
|
||||
end_time: Timestamp de fin
|
||||
task_id: ID de la tâche (optionnel)
|
||||
operation: Type d'opération
|
||||
|
||||
Returns:
|
||||
Latence en millisecondes
|
||||
"""
|
||||
latency_ms = (end_time - start_time) * 1000
|
||||
|
||||
# Enregistrer dans les métriques globales
|
||||
self.global_metrics["total_latencies"].append(latency_ms)
|
||||
|
||||
# Enregistrer par tâche si spécifié
|
||||
if task_id:
|
||||
self.task_metrics[task_id]["latencies"].append(latency_ms)
|
||||
|
||||
# Logger
|
||||
self.logger.log_action({
|
||||
"action": "latency_tracked",
|
||||
"task_id": task_id,
|
||||
"operation": operation,
|
||||
"latency_ms": latency_ms,
|
||||
"threshold_exceeded": latency_ms > self.latency_threshold
|
||||
})
|
||||
|
||||
# Vérifier le seuil
|
||||
if latency_ms > self.latency_threshold:
|
||||
self._generate_alert(
|
||||
"latency_threshold_exceeded",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"operation": operation,
|
||||
"latency_ms": latency_ms,
|
||||
"threshold_ms": self.latency_threshold
|
||||
}
|
||||
)
|
||||
|
||||
return latency_ms
|
||||
|
||||
def track_concordance(
|
||||
self,
|
||||
task_id: str,
|
||||
success: bool,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Enregistre le résultat d'une exécution pour calcul de concordance.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche
|
||||
success: True si succès, False si échec
|
||||
metadata: Métadonnées additionnelles
|
||||
"""
|
||||
# Mettre à jour les compteurs
|
||||
self.task_metrics[task_id]["total_executions"] += 1
|
||||
self.task_metrics[task_id]["last_execution"] = datetime.now().isoformat()
|
||||
self.global_metrics["total_executions"] += 1
|
||||
|
||||
if success:
|
||||
self.task_metrics[task_id]["successes"] += 1
|
||||
self.global_metrics["total_successes"] += 1
|
||||
else:
|
||||
self.task_metrics[task_id]["failures"] += 1
|
||||
self.global_metrics["total_failures"] += 1
|
||||
|
||||
# Calculer le taux de concordance
|
||||
concordance_rate = self.get_concordance_rate(task_id)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "concordance_tracked",
|
||||
"task_id": task_id,
|
||||
"success": success,
|
||||
"concordance_rate": concordance_rate,
|
||||
"total_executions": self.task_metrics[task_id]["total_executions"],
|
||||
"metadata": metadata
|
||||
})
|
||||
|
||||
# Vérifier le seuil (seulement si assez d'exécutions)
|
||||
if self.task_metrics[task_id]["total_executions"] >= 10:
|
||||
if concordance_rate < self.concordance_threshold:
|
||||
self._generate_alert(
|
||||
"concordance_below_threshold",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"concordance_rate": concordance_rate,
|
||||
"threshold": self.concordance_threshold,
|
||||
"total_executions": self.task_metrics[task_id]["total_executions"]
|
||||
}
|
||||
)
|
||||
|
||||
def track_correction_rate(
|
||||
self,
|
||||
task_id: str,
|
||||
correction_made: bool = True
|
||||
):
|
||||
"""
|
||||
Enregistre une correction utilisateur.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche
|
||||
correction_made: True si correction effectuée
|
||||
"""
|
||||
if correction_made:
|
||||
self.task_metrics[task_id]["corrections"] += 1
|
||||
self.global_metrics["total_corrections"] += 1
|
||||
|
||||
# Calculer le taux de correction
|
||||
correction_rate = self.get_correction_rate(task_id)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "correction_tracked",
|
||||
"task_id": task_id,
|
||||
"correction_rate": correction_rate,
|
||||
"total_corrections": self.task_metrics[task_id]["corrections"],
|
||||
"total_executions": self.task_metrics[task_id]["total_executions"]
|
||||
})
|
||||
|
||||
# Vérifier le seuil (seulement si assez d'exécutions)
|
||||
if self.task_metrics[task_id]["total_executions"] >= 10:
|
||||
if correction_rate > self.correction_rate_threshold:
|
||||
self._generate_alert(
|
||||
"correction_rate_above_threshold",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"correction_rate": correction_rate,
|
||||
"threshold": self.correction_rate_threshold,
|
||||
"total_corrections": self.task_metrics[task_id]["corrections"],
|
||||
"total_executions": self.task_metrics[task_id]["total_executions"]
|
||||
}
|
||||
)
|
||||
|
||||
def check_performance_thresholds(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Vérifie tous les seuils de performance et génère des alertes.
|
||||
|
||||
Returns:
|
||||
Liste des alertes générées
|
||||
"""
|
||||
alerts = []
|
||||
|
||||
# Vérifier la latence moyenne globale
|
||||
if self.global_metrics["total_latencies"]:
|
||||
avg_latency = sum(self.global_metrics["total_latencies"]) / len(
|
||||
self.global_metrics["total_latencies"]
|
||||
)
|
||||
|
||||
if avg_latency > self.latency_threshold:
|
||||
alert = self._generate_alert(
|
||||
"global_latency_high",
|
||||
{
|
||||
"avg_latency_ms": avg_latency,
|
||||
"threshold_ms": self.latency_threshold,
|
||||
"num_measurements": len(self.global_metrics["total_latencies"])
|
||||
}
|
||||
)
|
||||
alerts.append(alert)
|
||||
|
||||
# Vérifier la concordance globale
|
||||
if self.global_metrics["total_executions"] > 0:
|
||||
global_concordance = (
|
||||
self.global_metrics["total_successes"] /
|
||||
self.global_metrics["total_executions"]
|
||||
)
|
||||
|
||||
if global_concordance < self.concordance_threshold:
|
||||
alert = self._generate_alert(
|
||||
"global_concordance_low",
|
||||
{
|
||||
"concordance_rate": global_concordance,
|
||||
"threshold": self.concordance_threshold,
|
||||
"total_executions": self.global_metrics["total_executions"]
|
||||
}
|
||||
)
|
||||
alerts.append(alert)
|
||||
|
||||
# Vérifier le taux de correction global
|
||||
if self.global_metrics["total_executions"] > 0:
|
||||
global_correction_rate = (
|
||||
self.global_metrics["total_corrections"] /
|
||||
self.global_metrics["total_executions"]
|
||||
)
|
||||
|
||||
if global_correction_rate > self.correction_rate_threshold:
|
||||
alert = self._generate_alert(
|
||||
"global_correction_rate_high",
|
||||
{
|
||||
"correction_rate": global_correction_rate,
|
||||
"threshold": self.correction_rate_threshold,
|
||||
"total_corrections": self.global_metrics["total_corrections"],
|
||||
"total_executions": self.global_metrics["total_executions"]
|
||||
}
|
||||
)
|
||||
alerts.append(alert)
|
||||
|
||||
# Vérifier chaque tâche
|
||||
for task_id, metrics in self.task_metrics.items():
|
||||
if metrics["total_executions"] < 10:
|
||||
continue # Pas assez de données
|
||||
|
||||
# Latence moyenne par tâche
|
||||
if metrics["latencies"]:
|
||||
avg_latency = sum(metrics["latencies"]) / len(metrics["latencies"])
|
||||
if avg_latency > self.latency_threshold:
|
||||
alert = self._generate_alert(
|
||||
"task_latency_high",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"avg_latency_ms": avg_latency,
|
||||
"threshold_ms": self.latency_threshold
|
||||
}
|
||||
)
|
||||
alerts.append(alert)
|
||||
|
||||
return alerts
|
||||
|
||||
def get_concordance_rate(self, task_id: str) -> float:
|
||||
"""
|
||||
Calcule le taux de concordance pour une tâche.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche
|
||||
|
||||
Returns:
|
||||
Taux de concordance (0.0 à 1.0)
|
||||
"""
|
||||
metrics = self.task_metrics[task_id]
|
||||
total = metrics["total_executions"]
|
||||
|
||||
if total == 0:
|
||||
return 0.0
|
||||
|
||||
return metrics["successes"] / total
|
||||
|
||||
def get_correction_rate(self, task_id: str) -> float:
|
||||
"""
|
||||
Calcule le taux de correction pour une tâche.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche
|
||||
|
||||
Returns:
|
||||
Taux de correction (0.0 à 1.0)
|
||||
"""
|
||||
metrics = self.task_metrics[task_id]
|
||||
total = metrics["total_executions"]
|
||||
|
||||
if total == 0:
|
||||
return 0.0
|
||||
|
||||
return metrics["corrections"] / total
|
||||
|
||||
def get_average_latency(
|
||||
self,
|
||||
task_id: Optional[str] = None,
|
||||
window_size: Optional[int] = None
|
||||
) -> float:
|
||||
"""
|
||||
Calcule la latence moyenne.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche (None pour global)
|
||||
window_size: Nombre de dernières mesures à considérer
|
||||
|
||||
Returns:
|
||||
Latence moyenne en millisecondes
|
||||
"""
|
||||
if task_id:
|
||||
latencies = self.task_metrics[task_id]["latencies"]
|
||||
else:
|
||||
latencies = self.global_metrics["total_latencies"]
|
||||
|
||||
if not latencies:
|
||||
return 0.0
|
||||
|
||||
if window_size:
|
||||
latencies = latencies[-window_size:]
|
||||
|
||||
return sum(latencies) / len(latencies)
|
||||
|
||||
def get_task_metrics(self, task_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Retourne les métriques d'une tâche.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche
|
||||
|
||||
Returns:
|
||||
Dictionnaire de métriques
|
||||
"""
|
||||
metrics = self.task_metrics[task_id]
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"total_executions": metrics["total_executions"],
|
||||
"successes": metrics["successes"],
|
||||
"failures": metrics["failures"],
|
||||
"corrections": metrics["corrections"],
|
||||
"concordance_rate": self.get_concordance_rate(task_id),
|
||||
"correction_rate": self.get_correction_rate(task_id),
|
||||
"avg_latency_ms": self.get_average_latency(task_id),
|
||||
"last_execution": metrics["last_execution"]
|
||||
}
|
||||
|
||||
def get_global_metrics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Retourne les métriques globales.
|
||||
|
||||
Returns:
|
||||
Dictionnaire de métriques globales
|
||||
"""
|
||||
total_exec = self.global_metrics["total_executions"]
|
||||
|
||||
return {
|
||||
"total_executions": total_exec,
|
||||
"total_successes": self.global_metrics["total_successes"],
|
||||
"total_failures": self.global_metrics["total_failures"],
|
||||
"total_corrections": self.global_metrics["total_corrections"],
|
||||
"global_concordance_rate": (
|
||||
self.global_metrics["total_successes"] / total_exec
|
||||
if total_exec > 0 else 0.0
|
||||
),
|
||||
"global_correction_rate": (
|
||||
self.global_metrics["total_corrections"] / total_exec
|
||||
if total_exec > 0 else 0.0
|
||||
),
|
||||
"avg_latency_ms": self.get_average_latency(),
|
||||
"alerts_generated": self.global_metrics["alerts_generated"],
|
||||
"num_tasks_tracked": len(self.task_metrics)
|
||||
}
|
||||
|
||||
def get_alerts(
|
||||
self,
|
||||
limit: int = 50,
|
||||
alert_type: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retourne l'historique des alertes.
|
||||
|
||||
Args:
|
||||
limit: Nombre maximum d'alertes à retourner
|
||||
alert_type: Filtrer par type d'alerte
|
||||
|
||||
Returns:
|
||||
Liste des alertes
|
||||
"""
|
||||
alerts = self.alerts
|
||||
|
||||
if alert_type:
|
||||
alerts = [a for a in alerts if a["type"] == alert_type]
|
||||
|
||||
return alerts[-limit:]
|
||||
|
||||
def _generate_alert(
|
||||
self,
|
||||
alert_type: str,
|
||||
data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Génère une alerte.
|
||||
|
||||
Args:
|
||||
alert_type: Type d'alerte
|
||||
data: Données de l'alerte
|
||||
|
||||
Returns:
|
||||
Alerte générée
|
||||
"""
|
||||
alert = {
|
||||
"type": alert_type,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": data
|
||||
}
|
||||
|
||||
self.alerts.append(alert)
|
||||
self.global_metrics["alerts_generated"] += 1
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "alert_generated",
|
||||
**alert
|
||||
})
|
||||
|
||||
return alert
|
||||
|
||||
def reset_metrics(self, task_id: Optional[str] = None):
|
||||
"""
|
||||
Réinitialise les métriques.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche (None pour tout réinitialiser)
|
||||
"""
|
||||
if task_id:
|
||||
if task_id in self.task_metrics:
|
||||
del self.task_metrics[task_id]
|
||||
self.logger.log_action({
|
||||
"action": "task_metrics_reset",
|
||||
"task_id": task_id
|
||||
})
|
||||
else:
|
||||
self.task_metrics.clear()
|
||||
self.global_metrics = {
|
||||
"total_latencies": [],
|
||||
"total_successes": 0,
|
||||
"total_failures": 0,
|
||||
"total_corrections": 0,
|
||||
"total_executions": 0,
|
||||
"alerts_generated": 0
|
||||
}
|
||||
self.alerts.clear()
|
||||
self.logger.log_action({
|
||||
"action": "all_metrics_reset"
|
||||
})
|
||||
379
geniusia2/core/models.py
Normal file
379
geniusia2/core/models.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""
|
||||
Modèles de données pour RPA Vision V2
|
||||
Contient les dataclasses pour TaskProfile, Action et Detection
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class Detection:
|
||||
"""
|
||||
Représente une détection d'élément UI par un modèle de vision
|
||||
|
||||
Attributes:
|
||||
label: Nom/label de l'élément détecté
|
||||
confidence: Score de confiance de la détection (0-1)
|
||||
bbox: Bounding box (x, y, width, height) en pixels
|
||||
embedding: Embedding visuel 512-d de l'élément
|
||||
model_source: Modèle ayant effectué la détection ("owl-v2", "dino", "yolo")
|
||||
roi_image: Image de la région d'intérêt (optionnel)
|
||||
metadata: Métadonnées additionnelles du modèle
|
||||
"""
|
||||
label: str
|
||||
confidence: float
|
||||
bbox: Tuple[int, int, int, int]
|
||||
embedding: np.ndarray
|
||||
model_source: str
|
||||
roi_image: Optional[np.ndarray] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Convertit la détection en dictionnaire pour sérialisation
|
||||
Note: Les arrays numpy ne sont pas sérialisés directement
|
||||
"""
|
||||
return {
|
||||
"label": self.label,
|
||||
"confidence": float(self.confidence),
|
||||
"bbox": list(self.bbox),
|
||||
"model_source": self.model_source,
|
||||
"metadata": self.metadata,
|
||||
# embedding et roi_image sont exclus car non JSON-sérialisables
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any], embedding: Optional[np.ndarray] = None) -> 'Detection':
|
||||
"""
|
||||
Crée une instance Detection depuis un dictionnaire
|
||||
|
||||
Args:
|
||||
data: Dictionnaire contenant les données de détection
|
||||
embedding: Embedding numpy (doit être fourni séparément)
|
||||
"""
|
||||
return cls(
|
||||
label=data["label"],
|
||||
confidence=data["confidence"],
|
||||
bbox=tuple(data["bbox"]),
|
||||
embedding=embedding if embedding is not None else np.array([]),
|
||||
model_source=data["model_source"],
|
||||
metadata=data.get("metadata", {})
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Action:
|
||||
"""
|
||||
Représente une action UI effectuée ou suggérée
|
||||
|
||||
Attributes:
|
||||
action_type: Type d'action ("click", "type", "scroll", "wait")
|
||||
target_element: Nom de l'élément cible
|
||||
bbox: Bounding box de l'élément cible
|
||||
confidence: Score de confiance pour cette action
|
||||
embedding: Embedding visuel de l'élément cible
|
||||
timestamp: Horodatage de l'action
|
||||
window_title: Titre de la fenêtre où l'action est effectuée
|
||||
parameters: Paramètres additionnels (ex: texte à taper, direction de scroll)
|
||||
result: Résultat de l'exécution ("success", "failed", "pending")
|
||||
"""
|
||||
action_type: str
|
||||
target_element: str
|
||||
bbox: Tuple[int, int, int, int]
|
||||
confidence: float
|
||||
embedding: np.ndarray
|
||||
timestamp: datetime
|
||||
window_title: str
|
||||
parameters: Dict[str, Any] = field(default_factory=dict)
|
||||
result: str = "pending"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Convertit l'action en dictionnaire pour sérialisation
|
||||
"""
|
||||
return {
|
||||
"action_type": self.action_type,
|
||||
"target_element": self.target_element,
|
||||
"bbox": list(self.bbox),
|
||||
"confidence": float(self.confidence),
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"window_title": self.window_title,
|
||||
"parameters": self.parameters,
|
||||
"result": self.result,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any], embedding: Optional[np.ndarray] = None) -> 'Action':
|
||||
"""
|
||||
Crée une instance Action depuis un dictionnaire
|
||||
|
||||
Args:
|
||||
data: Dictionnaire contenant les données d'action
|
||||
embedding: Embedding numpy (doit être fourni séparément)
|
||||
"""
|
||||
return cls(
|
||||
action_type=data["action_type"],
|
||||
target_element=data["target_element"],
|
||||
bbox=tuple(data["bbox"]),
|
||||
confidence=data["confidence"],
|
||||
embedding=embedding if embedding is not None else np.array([]),
|
||||
timestamp=datetime.fromisoformat(data["timestamp"]),
|
||||
window_title=data["window_title"],
|
||||
parameters=data.get("parameters", {}),
|
||||
result=data.get("result", "pending"),
|
||||
)
|
||||
|
||||
def get_inverse_action(self) -> Optional['Action']:
|
||||
"""
|
||||
Retourne l'action inverse pour rollback (si applicable)
|
||||
"""
|
||||
# Cette méthode sera implémentée plus tard dans input_utils
|
||||
# Pour l'instant, retourne None
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskProfile:
|
||||
"""
|
||||
Profil d'une tâche apprise par le système
|
||||
|
||||
Attributes:
|
||||
task_id: Identifiant unique de la tâche
|
||||
task_name: Nom descriptif de la tâche
|
||||
mode: Mode opérationnel actuel ("shadow", "assist", "auto")
|
||||
observation_count: Nombre d'observations de cette tâche
|
||||
concordance_rate: Taux de concordance (0-1)
|
||||
confidence_score: Score de confiance global (0-1)
|
||||
correction_count: Nombre de corrections reçues
|
||||
last_execution: Horodatage de la dernière exécution
|
||||
window_whitelist: Liste des fenêtres autorisées pour cette tâche
|
||||
action_sequence: Séquence d'actions composant la tâche
|
||||
embeddings: Liste des embeddings visuels associés
|
||||
metadata: Métadonnées additionnelles
|
||||
execution_history: Historique des exécutions récentes
|
||||
"""
|
||||
task_id: str
|
||||
task_name: str
|
||||
mode: str = "shadow"
|
||||
observation_count: int = 0
|
||||
concordance_rate: float = 0.0
|
||||
confidence_score: float = 0.0
|
||||
correction_count: int = 0
|
||||
last_execution: Optional[datetime] = None
|
||||
window_whitelist: List[str] = field(default_factory=list)
|
||||
action_sequence: List[Action] = field(default_factory=list)
|
||||
embeddings: List[np.ndarray] = field(default_factory=list)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
execution_history: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""
|
||||
Sérialise le profil de tâche en JSON
|
||||
Note: Les embeddings numpy ne sont pas inclus dans le JSON
|
||||
"""
|
||||
data = {
|
||||
"task_id": self.task_id,
|
||||
"task_name": self.task_name,
|
||||
"mode": self.mode,
|
||||
"observation_count": self.observation_count,
|
||||
"concordance_rate": float(self.concordance_rate),
|
||||
"confidence_score": float(self.confidence_score),
|
||||
"correction_count": self.correction_count,
|
||||
"last_execution": self.last_execution.isoformat() if self.last_execution else None,
|
||||
"window_whitelist": self.window_whitelist,
|
||||
"action_sequence": [action.to_dict() for action in self.action_sequence],
|
||||
"metadata": self.metadata,
|
||||
"execution_history": self.execution_history,
|
||||
}
|
||||
return json.dumps(data, indent=2, ensure_ascii=False)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_str: str, embeddings: Optional[List[np.ndarray]] = None) -> 'TaskProfile':
|
||||
"""
|
||||
Crée une instance TaskProfile depuis une chaîne JSON
|
||||
|
||||
Args:
|
||||
json_str: Chaîne JSON contenant les données du profil
|
||||
embeddings: Liste d'embeddings numpy (doivent être fournis séparément)
|
||||
"""
|
||||
data = json.loads(json_str)
|
||||
|
||||
# Reconstruire les actions
|
||||
actions = [Action.from_dict(action_data) for action_data in data.get("action_sequence", [])]
|
||||
|
||||
return cls(
|
||||
task_id=data["task_id"],
|
||||
task_name=data["task_name"],
|
||||
mode=data.get("mode", "shadow"),
|
||||
observation_count=data.get("observation_count", 0),
|
||||
concordance_rate=data.get("concordance_rate", 0.0),
|
||||
confidence_score=data.get("confidence_score", 0.0),
|
||||
correction_count=data.get("correction_count", 0),
|
||||
last_execution=datetime.fromisoformat(data["last_execution"]) if data.get("last_execution") else None,
|
||||
window_whitelist=data.get("window_whitelist", []),
|
||||
action_sequence=actions,
|
||||
embeddings=embeddings if embeddings is not None else [],
|
||||
metadata=data.get("metadata", {}),
|
||||
execution_history=data.get("execution_history", []),
|
||||
)
|
||||
|
||||
def get_historical_performance(self) -> float:
|
||||
"""
|
||||
Calcule la performance historique basée sur les exécutions récentes
|
||||
|
||||
Returns:
|
||||
Score de performance (0-1)
|
||||
"""
|
||||
if not self.execution_history:
|
||||
return 0.0
|
||||
|
||||
# Calculer le taux de succès sur les exécutions récentes
|
||||
recent_executions = self.execution_history[-10:] # 10 dernières exécutions
|
||||
success_count = sum(1 for exec in recent_executions if exec.get("result") == "success")
|
||||
|
||||
return success_count / len(recent_executions) if recent_executions else 0.0
|
||||
|
||||
def add_execution(self, result: str, confidence: float, latency_ms: float):
|
||||
"""
|
||||
Ajoute une exécution à l'historique
|
||||
|
||||
Args:
|
||||
result: Résultat de l'exécution ("success", "failed")
|
||||
confidence: Score de confiance de l'exécution
|
||||
latency_ms: Latence en millisecondes
|
||||
"""
|
||||
execution = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"result": result,
|
||||
"confidence": confidence,
|
||||
"latency_ms": latency_ms,
|
||||
}
|
||||
self.execution_history.append(execution)
|
||||
self.last_execution = datetime.now()
|
||||
|
||||
# Limiter l'historique aux 50 dernières exécutions
|
||||
if len(self.execution_history) > 50:
|
||||
self.execution_history = self.execution_history[-50:]
|
||||
|
||||
def update_concordance_rate(self, success: bool):
|
||||
"""
|
||||
Met à jour le taux de concordance basé sur le résultat d'une exécution
|
||||
|
||||
Args:
|
||||
success: True si l'exécution a réussi, False sinon
|
||||
"""
|
||||
# Utiliser une moyenne mobile pour le taux de concordance
|
||||
window_size = 10 # Fenêtre de 10 exécutions
|
||||
recent_executions = self.execution_history[-window_size:]
|
||||
|
||||
if recent_executions:
|
||||
success_count = sum(1 for exec in recent_executions if exec.get("result") == "success")
|
||||
self.concordance_rate = success_count / len(recent_executions)
|
||||
|
||||
def should_transition_to_auto(self, min_observations: int = 20, min_concordance: float = 0.95) -> bool:
|
||||
"""
|
||||
Vérifie si la tâche remplit les critères pour passer en mode Autopilot
|
||||
|
||||
Args:
|
||||
min_observations: Nombre minimum d'observations requises
|
||||
min_concordance: Taux de concordance minimum requis
|
||||
|
||||
Returns:
|
||||
True si les critères sont remplis
|
||||
"""
|
||||
return (self.observation_count >= min_observations and
|
||||
self.concordance_rate >= min_concordance)
|
||||
|
||||
def should_rollback_to_assist(self, min_confidence: float = 0.90) -> bool:
|
||||
"""
|
||||
Vérifie si la tâche doit être rétrogradée au mode Assisté
|
||||
|
||||
Args:
|
||||
min_confidence: Score de confiance minimum requis
|
||||
|
||||
Returns:
|
||||
True si la confiance est trop faible
|
||||
"""
|
||||
return self.mode == "auto" and self.confidence_score < min_confidence
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Tests basiques des modèles
|
||||
print("Test des modèles de données RPA Vision V2")
|
||||
print("=" * 50)
|
||||
|
||||
# Test Detection
|
||||
print("\n1. Test Detection:")
|
||||
detection = Detection(
|
||||
label="valider_button",
|
||||
confidence=0.93,
|
||||
bbox=(450, 320, 120, 40),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="owl-v2",
|
||||
metadata={"class": "button"}
|
||||
)
|
||||
print(f" Label: {detection.label}")
|
||||
print(f" Confidence: {detection.confidence}")
|
||||
print(f" BBox: {detection.bbox}")
|
||||
print(f" Model: {detection.model_source}")
|
||||
det_dict = detection.to_dict()
|
||||
print(f" Dict keys: {list(det_dict.keys())}")
|
||||
|
||||
# Test Action
|
||||
print("\n2. Test Action:")
|
||||
action = Action(
|
||||
action_type="click",
|
||||
target_element="valider_button",
|
||||
bbox=(450, 320, 120, 40),
|
||||
confidence=0.95,
|
||||
embedding=np.random.rand(512),
|
||||
timestamp=datetime.now(),
|
||||
window_title="Dolibarr - Facturation",
|
||||
parameters={"button": "left"},
|
||||
result="success"
|
||||
)
|
||||
print(f" Type: {action.action_type}")
|
||||
print(f" Target: {action.target_element}")
|
||||
print(f" Window: {action.window_title}")
|
||||
print(f" Result: {action.result}")
|
||||
action_dict = action.to_dict()
|
||||
print(f" Dict keys: {list(action_dict.keys())}")
|
||||
|
||||
# Test TaskProfile
|
||||
print("\n3. Test TaskProfile:")
|
||||
task = TaskProfile(
|
||||
task_id="ouvrir_facture_001",
|
||||
task_name="Ouvrir Facture",
|
||||
mode="assist",
|
||||
observation_count=15,
|
||||
concordance_rate=0.87,
|
||||
confidence_score=0.92,
|
||||
window_whitelist=["Dolibarr - Facturation"],
|
||||
action_sequence=[action]
|
||||
)
|
||||
print(f" Task ID: {task.task_id}")
|
||||
print(f" Mode: {task.mode}")
|
||||
print(f" Observations: {task.observation_count}")
|
||||
print(f" Concordance: {task.concordance_rate:.2%}")
|
||||
print(f" Confidence: {task.confidence_score:.2%}")
|
||||
|
||||
# Test sérialisation JSON
|
||||
print("\n4. Test sérialisation JSON:")
|
||||
json_str = task.to_json()
|
||||
print(f" JSON length: {len(json_str)} chars")
|
||||
|
||||
# Test désérialisation
|
||||
task_restored = TaskProfile.from_json(json_str)
|
||||
print(f" Restored task ID: {task_restored.task_id}")
|
||||
print(f" Restored mode: {task_restored.mode}")
|
||||
|
||||
# Test méthodes de transition
|
||||
print("\n5. Test méthodes de transition:")
|
||||
print(f" Should transition to auto: {task.should_transition_to_auto()}")
|
||||
print(f" Should rollback to assist: {task.should_rollback_to_assist()}")
|
||||
|
||||
print("\n✓ Tous les tests basiques réussis!")
|
||||
915
geniusia2/core/multimodal_embedding_manager.py
Normal file
915
geniusia2/core/multimodal_embedding_manager.py
Normal file
@@ -0,0 +1,915 @@
|
||||
"""
|
||||
Gestionnaire d'embeddings multi-modaux pour la Phase 3 - Mode Complet.
|
||||
Fusionne les embeddings de différentes modalités en un seul embedding unifié.
|
||||
|
||||
Modalités supportées:
|
||||
- Image (screenshot entier)
|
||||
- Texte (texte détecté)
|
||||
- Titre (window_title)
|
||||
- UI (éléments UI importants)
|
||||
- Contexte (métadonnées workflow)
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from .ui_element_models import (
|
||||
UIElement,
|
||||
StateEmbedding,
|
||||
EmbeddingComponents,
|
||||
ComponentInfo
|
||||
)
|
||||
from .llm_manager import LLMManager
|
||||
from .logger import Logger
|
||||
|
||||
# Import optionnel de EmbeddingManager
|
||||
try:
|
||||
from .embedders.embedding_manager import EmbeddingManager as BaseEmbeddingManager
|
||||
except ImportError:
|
||||
BaseEmbeddingManager = None
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingWeights:
|
||||
"""Poids pour la fusion des embeddings multi-modaux."""
|
||||
image: float = 0.4
|
||||
text: float = 0.2
|
||||
title: float = 0.1
|
||||
ui: float = 0.2
|
||||
context: float = 0.1
|
||||
|
||||
def normalize(self) -> 'EmbeddingWeights':
|
||||
"""Normalise les poids pour qu'ils somment à 1.0."""
|
||||
total = self.image + self.text + self.title + self.ui + self.context
|
||||
if total == 0:
|
||||
return EmbeddingWeights()
|
||||
|
||||
return EmbeddingWeights(
|
||||
image=self.image / total,
|
||||
text=self.text / total,
|
||||
title=self.title / total,
|
||||
ui=self.ui / total,
|
||||
context=self.context / total
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, float]:
|
||||
"""Convertit en dictionnaire."""
|
||||
return {
|
||||
"image": self.image,
|
||||
"text": self.text,
|
||||
"title": self.title,
|
||||
"ui": self.ui,
|
||||
"context": self.context
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, float]) -> 'EmbeddingWeights':
|
||||
"""Crée depuis un dictionnaire."""
|
||||
return cls(
|
||||
image=data.get("image", 0.4),
|
||||
text=data.get("text", 0.2),
|
||||
title=data.get("title", 0.1),
|
||||
ui=data.get("ui", 0.2),
|
||||
context=data.get("context", 0.1)
|
||||
)
|
||||
|
||||
|
||||
class MultiModalEmbeddingManager:
|
||||
"""
|
||||
Gestionnaire d'embeddings multi-modaux.
|
||||
|
||||
Fusionne les embeddings de 5 modalités:
|
||||
1. Image globale (screenshot)
|
||||
2. Texte détecté (OCR/VLM)
|
||||
3. Titre de fenêtre
|
||||
4. Éléments UI (moyenne des éléments importants)
|
||||
5. Contexte workflow
|
||||
|
||||
La fusion est une combinaison pondérée normalisée.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_manager: Optional[BaseEmbeddingManager] = None,
|
||||
logger: Optional[Logger] = None,
|
||||
data_dir: str = "data",
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Initialise le gestionnaire d'embeddings multi-modaux.
|
||||
|
||||
Args:
|
||||
embedding_manager: Gestionnaire d'embeddings existant
|
||||
logger: Logger
|
||||
data_dir: Répertoire de données
|
||||
config: Configuration
|
||||
"""
|
||||
self.embedding_manager = embedding_manager
|
||||
self.logger = logger
|
||||
self.data_dir = Path(data_dir)
|
||||
self.config = config or {}
|
||||
|
||||
# Configuration
|
||||
self.embedding_dim = self.config.get("embedding_dim", 512)
|
||||
self.fusion_method = self.config.get("fusion_method", "weighted_average")
|
||||
self.use_cache = self.config.get("use_cache", True)
|
||||
|
||||
# Poids par défaut
|
||||
weights_config = self.config.get("weights", {})
|
||||
self.default_weights = EmbeddingWeights.from_dict(weights_config).normalize()
|
||||
|
||||
# Poids de fusion (pour compatibilité)
|
||||
self.weights = {
|
||||
'image': self.default_weights.image,
|
||||
'text': self.default_weights.text,
|
||||
'title': self.default_weights.title,
|
||||
'ui': self.default_weights.ui,
|
||||
'context': self.default_weights.context
|
||||
}
|
||||
|
||||
# Cache des embeddings
|
||||
self._embedding_cache = {} if self.use_cache else None
|
||||
|
||||
# Créer les répertoires
|
||||
self.embeddings_dir = self.data_dir / "embeddings" / "multimodal"
|
||||
self.embeddings_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "multimodal_embedding_manager_initialized",
|
||||
"embedding_dim": self.embedding_dim,
|
||||
"fusion_method": self.fusion_method,
|
||||
"default_weights": self.default_weights.to_dict()
|
||||
})
|
||||
|
||||
def create_state_embedding(
|
||||
self,
|
||||
screenshot: np.ndarray,
|
||||
detected_text: List[str],
|
||||
window_title: str,
|
||||
ui_elements: List[UIElement],
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
data_dir: str = "data"
|
||||
) -> StateEmbedding:
|
||||
"""
|
||||
Crée un embedding d'état unifié en fusionnant toutes les modalités.
|
||||
|
||||
Args:
|
||||
screenshot: Screenshot numpy array
|
||||
detected_text: Liste de texte détecté
|
||||
window_title: Titre de la fenêtre
|
||||
ui_elements: Liste des éléments UI
|
||||
context: Contexte workflow (optionnel)
|
||||
data_dir: Répertoire de données
|
||||
|
||||
Returns:
|
||||
StateEmbedding avec fusion multi-modale
|
||||
"""
|
||||
# Créer le répertoire pour les embeddings
|
||||
embeddings_dir = Path(data_dir) / "embeddings" / "multimodal"
|
||||
embeddings_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Générer un ID unique pour cet état
|
||||
import time
|
||||
state_id = f"state_{int(time.time() * 1000000)}"
|
||||
|
||||
# Composante 1: Image globale
|
||||
image_emb, image_path = self._compute_image_embedding(
|
||||
screenshot, state_id, embeddings_dir
|
||||
)
|
||||
image_emb_norm = self._normalize(image_emb)
|
||||
|
||||
# Composante 2: Texte concaténé
|
||||
text_emb, text_path = self._compute_text_embedding(
|
||||
detected_text, state_id, embeddings_dir
|
||||
)
|
||||
text_emb_norm = self._normalize(text_emb)
|
||||
|
||||
# Composante 3: Titre de fenêtre
|
||||
title_emb, title_path = self._compute_title_embedding(
|
||||
window_title, state_id, embeddings_dir
|
||||
)
|
||||
title_emb_norm = self._normalize(title_emb)
|
||||
|
||||
# Composante 4: UI éléments
|
||||
ui_emb, ui_path = self._compute_ui_embedding(
|
||||
ui_elements, state_id, embeddings_dir
|
||||
)
|
||||
ui_emb_norm = self._normalize(ui_emb)
|
||||
|
||||
# Composante 5: Contexte
|
||||
context_emb, context_path = self._compute_context_embedding(
|
||||
context, state_id, embeddings_dir
|
||||
)
|
||||
context_emb_norm = self._normalize(context_emb)
|
||||
|
||||
# Fusion pondérée
|
||||
state_emb = (
|
||||
self.weights['image'] * image_emb_norm +
|
||||
self.weights['text'] * text_emb_norm +
|
||||
self.weights['title'] * title_emb_norm +
|
||||
self.weights['ui'] * ui_emb_norm +
|
||||
self.weights['context'] * context_emb_norm
|
||||
)
|
||||
|
||||
# Normalisation finale
|
||||
state_emb_final = self._normalize(state_emb)
|
||||
|
||||
# Sauvegarder l'embedding fusionné
|
||||
fused_path = embeddings_dir / f"{state_id}_fused.npy"
|
||||
np.save(fused_path, state_emb_final)
|
||||
|
||||
# Créer les composantes
|
||||
components = EmbeddingComponents(
|
||||
image_embedding=ComponentInfo(
|
||||
provider="openclip_ViT-B-32",
|
||||
vector_id=str(image_path)
|
||||
),
|
||||
text_embedding=ComponentInfo(
|
||||
provider="clip_text",
|
||||
vector_id=str(text_path)
|
||||
),
|
||||
title_embedding=ComponentInfo(
|
||||
provider="clip_text",
|
||||
vector_id=str(title_path)
|
||||
),
|
||||
ui_embedding=ComponentInfo(
|
||||
provider="openclip_ViT-B-32",
|
||||
vector_id=str(ui_path)
|
||||
),
|
||||
context_embedding=ComponentInfo(
|
||||
provider="numeric_context_v1",
|
||||
vector_id=str(context_path)
|
||||
) if self.weights['context'] > 0 else None
|
||||
)
|
||||
|
||||
# Créer le StateEmbedding
|
||||
state_embedding = StateEmbedding(
|
||||
provider="multimodal_fusion_v1",
|
||||
vector_id=str(fused_path),
|
||||
components=components
|
||||
)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "state_embedding_created",
|
||||
"state_id": state_id,
|
||||
"components": {
|
||||
"image": image_emb.shape,
|
||||
"text": text_emb.shape,
|
||||
"title": title_emb.shape,
|
||||
"ui": ui_emb.shape,
|
||||
"context": context_emb.shape
|
||||
}
|
||||
})
|
||||
|
||||
return state_embedding
|
||||
|
||||
def _compute_image_embedding(
|
||||
self,
|
||||
screenshot: np.ndarray,
|
||||
state_id: str,
|
||||
embeddings_dir: Path
|
||||
) -> tuple:
|
||||
"""Calcule l'embedding de l'image globale."""
|
||||
try:
|
||||
# Convertir en PIL Image
|
||||
from PIL import Image
|
||||
if screenshot.shape[2] == 3:
|
||||
# BGR to RGB
|
||||
screenshot_rgb = screenshot[:, :, ::-1]
|
||||
else:
|
||||
screenshot_rgb = screenshot
|
||||
|
||||
pil_image = Image.fromarray(screenshot_rgb.astype(np.uint8))
|
||||
|
||||
# Générer l'embedding
|
||||
embedding = self.image_embedder.embed(pil_image)
|
||||
|
||||
# Sauvegarder
|
||||
path = embeddings_dir / f"{state_id}_image.npy"
|
||||
np.save(path, embedding)
|
||||
|
||||
return embedding, path
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "image_embedding_error",
|
||||
"error": str(e)
|
||||
})
|
||||
# Retourner un vecteur zéro
|
||||
embedding = np.zeros(self.embedding_dim)
|
||||
path = embeddings_dir / f"{state_id}_image.npy"
|
||||
np.save(path, embedding)
|
||||
return embedding, path
|
||||
|
||||
def _compute_text_embedding(
|
||||
self,
|
||||
detected_text: List[str],
|
||||
state_id: str,
|
||||
embeddings_dir: Path
|
||||
) -> tuple:
|
||||
"""Calcule l'embedding du texte concaténé."""
|
||||
try:
|
||||
# Concaténer le texte
|
||||
text_concat = " ".join(detected_text) if detected_text else ""
|
||||
|
||||
if not text_concat:
|
||||
# Pas de texte, retourner vecteur zéro
|
||||
embedding = np.zeros(self.embedding_dim)
|
||||
else:
|
||||
# Pour l'instant, utiliser un embedding simple
|
||||
# TODO: Intégrer avec un vrai text embedder
|
||||
embedding = self._simple_text_embedding(text_concat)
|
||||
|
||||
# Sauvegarder
|
||||
path = embeddings_dir / f"{state_id}_text.npy"
|
||||
np.save(path, embedding)
|
||||
|
||||
return embedding, path
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "text_embedding_error",
|
||||
"error": str(e)
|
||||
})
|
||||
embedding = np.zeros(self.embedding_dim)
|
||||
path = embeddings_dir / f"{state_id}_text.npy"
|
||||
np.save(path, embedding)
|
||||
return embedding, path
|
||||
|
||||
def _compute_title_embedding(
|
||||
self,
|
||||
window_title: str,
|
||||
state_id: str,
|
||||
embeddings_dir: Path
|
||||
) -> tuple:
|
||||
"""Calcule l'embedding du titre de fenêtre."""
|
||||
try:
|
||||
if not window_title:
|
||||
embedding = np.zeros(self.embedding_dim)
|
||||
else:
|
||||
embedding = self._simple_text_embedding(window_title)
|
||||
|
||||
# Sauvegarder
|
||||
path = embeddings_dir / f"{state_id}_title.npy"
|
||||
np.save(path, embedding)
|
||||
|
||||
return embedding, path
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "title_embedding_error",
|
||||
"error": str(e)
|
||||
})
|
||||
embedding = np.zeros(self.embedding_dim)
|
||||
path = embeddings_dir / f"{state_id}_title.npy"
|
||||
np.save(path, embedding)
|
||||
return embedding, path
|
||||
|
||||
def _compute_ui_embedding(
|
||||
self,
|
||||
ui_elements: List[UIElement],
|
||||
state_id: str,
|
||||
embeddings_dir: Path
|
||||
) -> tuple:
|
||||
"""Calcule l'embedding des éléments UI (moyenne des éléments importants)."""
|
||||
try:
|
||||
if not ui_elements:
|
||||
embedding = np.zeros(self.embedding_dim)
|
||||
else:
|
||||
# Filtrer les éléments importants
|
||||
important_elements = [
|
||||
elem for elem in ui_elements
|
||||
if elem.properties.is_clickable or 'primary_action' in elem.tags
|
||||
]
|
||||
|
||||
if not important_elements:
|
||||
# Prendre les 5 premiers éléments
|
||||
important_elements = ui_elements[:5]
|
||||
|
||||
# Charger et moyenner les embeddings
|
||||
embeddings = []
|
||||
for elem in important_elements:
|
||||
try:
|
||||
emb = np.load(elem.visual.embedding_vector_id)
|
||||
embeddings.append(emb)
|
||||
except:
|
||||
continue
|
||||
|
||||
if embeddings:
|
||||
embedding = np.mean(embeddings, axis=0)
|
||||
else:
|
||||
embedding = np.zeros(self.embedding_dim)
|
||||
|
||||
# Sauvegarder
|
||||
path = embeddings_dir / f"{state_id}_ui.npy"
|
||||
np.save(path, embedding)
|
||||
|
||||
return embedding, path
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "ui_embedding_error",
|
||||
"error": str(e)
|
||||
})
|
||||
embedding = np.zeros(self.embedding_dim)
|
||||
path = embeddings_dir / f"{state_id}_ui.npy"
|
||||
np.save(path, embedding)
|
||||
return embedding, path
|
||||
|
||||
def _compute_context_embedding(
|
||||
self,
|
||||
context: Optional[Dict[str, Any]],
|
||||
state_id: str,
|
||||
embeddings_dir: Path
|
||||
) -> tuple:
|
||||
"""Calcule l'embedding du contexte workflow."""
|
||||
try:
|
||||
if not context or self.weights['context'] == 0:
|
||||
embedding = np.zeros(self.embedding_dim)
|
||||
else:
|
||||
# Encoder les métadonnées de contexte en vecteur
|
||||
embedding = self._encode_context(context)
|
||||
|
||||
# Sauvegarder
|
||||
path = embeddings_dir / f"{state_id}_context.npy"
|
||||
np.save(path, embedding)
|
||||
|
||||
return embedding, path
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "context_embedding_error",
|
||||
"error": str(e)
|
||||
})
|
||||
embedding = np.zeros(self.embedding_dim)
|
||||
path = embeddings_dir / f"{state_id}_context.npy"
|
||||
np.save(path, embedding)
|
||||
return embedding, path
|
||||
|
||||
def _simple_text_embedding(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
Crée un embedding simple de texte.
|
||||
TODO: Remplacer par un vrai text embedder (CLIP text, Sentence-BERT, etc.)
|
||||
"""
|
||||
# Pour l'instant, utiliser un hash simple
|
||||
import hashlib
|
||||
hash_obj = hashlib.sha256(text.encode('utf-8'))
|
||||
hash_bytes = hash_obj.digest()
|
||||
|
||||
# Convertir en vecteur de dimension embedding_dim
|
||||
embedding = np.zeros(self.embedding_dim)
|
||||
for i in range(min(len(hash_bytes), self.embedding_dim)):
|
||||
embedding[i] = hash_bytes[i] / 255.0
|
||||
|
||||
return embedding
|
||||
|
||||
def _encode_context(self, context: Dict[str, Any]) -> np.ndarray:
|
||||
"""
|
||||
Encode le contexte en vecteur numérique.
|
||||
TODO: Améliorer l'encodage du contexte.
|
||||
"""
|
||||
# Pour l'instant, encoder simplement les clés/valeurs
|
||||
context_str = str(context)
|
||||
return self._simple_text_embedding(context_str)
|
||||
|
||||
def _normalize(self, vector: np.ndarray) -> np.ndarray:
|
||||
"""Normalise un vecteur (norme L2 = 1.0)."""
|
||||
norm = np.linalg.norm(vector)
|
||||
if norm > 0:
|
||||
return vector / norm
|
||||
return vector
|
||||
|
||||
def get_weights(self) -> Dict[str, float]:
|
||||
"""Retourne les poids de fusion actuels."""
|
||||
return self.weights.copy()
|
||||
|
||||
def set_weights(self, weights: Dict[str, float]):
|
||||
"""
|
||||
Modifie les poids de fusion.
|
||||
|
||||
Args:
|
||||
weights: Dictionnaire des nouveaux poids
|
||||
"""
|
||||
self.weights.update(weights)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "weights_updated",
|
||||
"new_weights": self.weights
|
||||
})
|
||||
|
||||
def compute_similarity(
|
||||
self,
|
||||
embedding1: np.ndarray,
|
||||
embedding2: np.ndarray,
|
||||
metric: str = "cosine"
|
||||
) -> float:
|
||||
"""
|
||||
Calcule la similarité entre deux embeddings.
|
||||
|
||||
Args:
|
||||
embedding1: Premier embedding
|
||||
embedding2: Deuxième embedding
|
||||
metric: Métrique de similarité ("cosine" ou "euclidean")
|
||||
|
||||
Returns:
|
||||
Score de similarité entre 0.0 et 1.0
|
||||
"""
|
||||
try:
|
||||
if metric == "cosine":
|
||||
# Similarité cosinus
|
||||
dot_product = np.dot(embedding1, embedding2)
|
||||
norm1 = np.linalg.norm(embedding1)
|
||||
norm2 = np.linalg.norm(embedding2)
|
||||
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
|
||||
return float(dot_product / (norm1 * norm2))
|
||||
|
||||
elif metric == "euclidean":
|
||||
# Distance euclidienne (convertie en similarité)
|
||||
distance = np.linalg.norm(embedding1 - embedding2)
|
||||
return float(1.0 / (1.0 + distance))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Métrique non supportée: {metric}")
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "similarity_computation_error",
|
||||
"metric": metric,
|
||||
"error": str(e)
|
||||
})
|
||||
return 0.0
|
||||
|
||||
def load_fused_embedding(self, vector_id: str) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Charge un embedding fusionné depuis son vector_id.
|
||||
|
||||
Args:
|
||||
vector_id: ID du vecteur (chemin de fichier ou ID temporaire)
|
||||
|
||||
Returns:
|
||||
Embedding numpy array ou None si non trouvé
|
||||
"""
|
||||
try:
|
||||
if vector_id.startswith("temp_"):
|
||||
# Embedding temporaire, générer un embedding aléatoire
|
||||
return np.random.rand(self.embedding_dim)
|
||||
|
||||
# Charger depuis le fichier
|
||||
path = Path(vector_id)
|
||||
if path.exists():
|
||||
return np.load(path)
|
||||
else:
|
||||
# Fichier non trouvé, générer un embedding par défaut
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "fused_embedding_not_found",
|
||||
"vector_id": vector_id
|
||||
})
|
||||
return np.random.rand(self.embedding_dim)
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "fused_embedding_load_error",
|
||||
"vector_id": vector_id,
|
||||
"error": str(e)
|
||||
})
|
||||
return None
|
||||
|
||||
def generate_multimodal_embedding(
|
||||
self,
|
||||
screen_state,
|
||||
screenshot: Optional[np.ndarray] = None,
|
||||
weights: Optional[EmbeddingWeights] = None,
|
||||
save: bool = True
|
||||
) -> StateEmbedding:
|
||||
"""
|
||||
Génère un embedding multi-modal complet pour un état d'écran.
|
||||
|
||||
Args:
|
||||
screen_state: EnrichedScreenState
|
||||
screenshot: Screenshot numpy array (optionnel)
|
||||
weights: Poids de fusion (utilise les poids par défaut si None)
|
||||
save: Sauvegarder les embeddings
|
||||
|
||||
Returns:
|
||||
StateEmbedding avec composantes et embedding fusionné
|
||||
"""
|
||||
if weights is None:
|
||||
# Utiliser les poids de configuration
|
||||
weights = EmbeddingWeights(
|
||||
image=self.weights.get('image', 0.4),
|
||||
text=self.weights.get('text', 0.2),
|
||||
title=self.weights.get('title', 0.1),
|
||||
ui=self.weights.get('ui', 0.2),
|
||||
context=self.weights.get('context', 0.1)
|
||||
).normalize()
|
||||
else:
|
||||
weights = weights.normalize()
|
||||
|
||||
try:
|
||||
# Pour l'instant, générer un embedding simulé
|
||||
# TODO: Implémenter la vraie génération avec les embedders
|
||||
|
||||
# Créer les composantes
|
||||
components = EmbeddingComponents()
|
||||
|
||||
# Image embedding
|
||||
if screenshot is not None:
|
||||
components.image_embedding = ComponentInfo(
|
||||
provider="openclip_ViT-B-32",
|
||||
vector_id=f"temp_{screen_state.screen_state_id}_image"
|
||||
)
|
||||
|
||||
# Text embedding
|
||||
if screen_state.perception.detected_text:
|
||||
components.text_embedding = ComponentInfo(
|
||||
provider="clip_text",
|
||||
vector_id=f"temp_{screen_state.screen_state_id}_text"
|
||||
)
|
||||
|
||||
# Title embedding
|
||||
if screen_state.window.window_title:
|
||||
components.title_embedding = ComponentInfo(
|
||||
provider="clip_text",
|
||||
vector_id=f"temp_{screen_state.screen_state_id}_title"
|
||||
)
|
||||
|
||||
# UI embedding
|
||||
if screen_state.ui_elements:
|
||||
components.ui_embedding = ComponentInfo(
|
||||
provider="ui_aggregation_v1",
|
||||
vector_id=f"temp_{screen_state.screen_state_id}_ui"
|
||||
)
|
||||
|
||||
# Context embedding
|
||||
if screen_state.context.current_workflow_candidate or screen_state.context.tags:
|
||||
components.context_embedding = ComponentInfo(
|
||||
provider="context_embedding_v1",
|
||||
vector_id=f"temp_{screen_state.screen_state_id}_context"
|
||||
)
|
||||
|
||||
# Créer le StateEmbedding
|
||||
state_embedding = StateEmbedding(
|
||||
provider="multimodal_fusion_v1",
|
||||
vector_id=f"temp_{screen_state.screen_state_id}_fused",
|
||||
components=components
|
||||
)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "multimodal_embedding_generated",
|
||||
"screen_state_id": screen_state.screen_state_id,
|
||||
"provider": state_embedding.provider
|
||||
})
|
||||
|
||||
return state_embedding
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "multimodal_embedding_error",
|
||||
"screen_state_id": screen_state.screen_state_id,
|
||||
"error": str(e)
|
||||
})
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Tests basiques (sans dépendances lourdes)
|
||||
print("MultiModalEmbeddingManager - Tests basiques")
|
||||
print("=" * 50)
|
||||
|
||||
# Test normalisation (pas besoin de logger ou embedder)
|
||||
print("\n1. Test normalisation:")
|
||||
|
||||
# Créer une instance minimale pour tester la normalisation
|
||||
class MinimalManager:
|
||||
def _normalize(self, vector):
|
||||
norm = np.linalg.norm(vector)
|
||||
if norm > 0:
|
||||
return vector / norm
|
||||
return vector
|
||||
|
||||
manager = MinimalManager()
|
||||
vector = np.array([3.0, 4.0, 0.0])
|
||||
normalized = manager._normalize(vector)
|
||||
norm = np.linalg.norm(normalized)
|
||||
print(f" Vecteur original: {vector}")
|
||||
print(f" Vecteur normalisé: {normalized}")
|
||||
print(f" Norme: {norm:.6f}")
|
||||
assert abs(norm - 1.0) < 0.001, "La norme doit être 1.0"
|
||||
print(f" ✓ Normalisation correcte")
|
||||
|
||||
# Test poids par défaut
|
||||
print("\n2. Test configuration des poids:")
|
||||
default_weights = {
|
||||
'image': 0.5,
|
||||
'text': 0.3,
|
||||
'title': 0.1,
|
||||
'ui': 0.1,
|
||||
'context': 0.0
|
||||
}
|
||||
print(f" Poids par défaut: {default_weights}")
|
||||
total = sum(default_weights.values())
|
||||
print(f" Somme des poids: {total}")
|
||||
print(f" ✓ Configuration valide")
|
||||
|
||||
print("\n✓ Tous les tests basiques réussis!")
|
||||
|
||||
|
||||
def compute_similarity(
|
||||
self,
|
||||
embedding1: np.ndarray,
|
||||
embedding2: np.ndarray,
|
||||
metric: str = "cosine"
|
||||
) -> float:
|
||||
"""
|
||||
Calcule la similarité entre deux embeddings.
|
||||
|
||||
Args:
|
||||
embedding1: Premier embedding
|
||||
embedding2: Deuxième embedding
|
||||
metric: Métrique de similarité ("cosine" ou "euclidean")
|
||||
|
||||
Returns:
|
||||
Score de similarité entre 0.0 et 1.0
|
||||
"""
|
||||
try:
|
||||
if metric == "cosine":
|
||||
# Similarité cosinus
|
||||
dot_product = np.dot(embedding1, embedding2)
|
||||
norm1 = np.linalg.norm(embedding1)
|
||||
norm2 = np.linalg.norm(embedding2)
|
||||
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
|
||||
return float(dot_product / (norm1 * norm2))
|
||||
|
||||
elif metric == "euclidean":
|
||||
# Distance euclidienne (convertie en similarité)
|
||||
distance = np.linalg.norm(embedding1 - embedding2)
|
||||
return float(1.0 / (1.0 + distance))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Métrique non supportée: {metric}")
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "similarity_computation_error",
|
||||
"metric": metric,
|
||||
"error": str(e)
|
||||
})
|
||||
return 0.0
|
||||
|
||||
def load_fused_embedding(self, vector_id: str) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Charge un embedding fusionné depuis son vector_id.
|
||||
|
||||
Args:
|
||||
vector_id: ID du vecteur (chemin de fichier ou ID temporaire)
|
||||
|
||||
Returns:
|
||||
Embedding numpy array ou None si non trouvé
|
||||
"""
|
||||
try:
|
||||
if vector_id.startswith("temp_"):
|
||||
# Embedding temporaire, générer un embedding aléatoire
|
||||
return np.random.rand(self.embedding_dim)
|
||||
|
||||
# Charger depuis le fichier
|
||||
path = Path(vector_id)
|
||||
if path.exists():
|
||||
return np.load(path)
|
||||
else:
|
||||
# Fichier non trouvé, générer un embedding par défaut
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "fused_embedding_not_found",
|
||||
"vector_id": vector_id
|
||||
})
|
||||
return np.random.rand(self.embedding_dim)
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "fused_embedding_load_error",
|
||||
"vector_id": vector_id,
|
||||
"error": str(e)
|
||||
})
|
||||
return None
|
||||
|
||||
def generate_multimodal_embedding(
|
||||
self,
|
||||
screen_state,
|
||||
screenshot: Optional[np.ndarray] = None,
|
||||
weights: Optional[EmbeddingWeights] = None,
|
||||
save: bool = True
|
||||
) -> StateEmbedding:
|
||||
"""
|
||||
Génère un embedding multi-modal complet pour un état d'écran.
|
||||
|
||||
Args:
|
||||
screen_state: EnrichedScreenState
|
||||
screenshot: Screenshot numpy array (optionnel)
|
||||
weights: Poids de fusion (utilise les poids par défaut si None)
|
||||
save: Sauvegarder les embeddings
|
||||
|
||||
Returns:
|
||||
StateEmbedding avec composantes et embedding fusionné
|
||||
"""
|
||||
if weights is None:
|
||||
# Utiliser les poids de configuration
|
||||
weights = EmbeddingWeights(
|
||||
image=self.weights.get('image', 0.4),
|
||||
text=self.weights.get('text', 0.2),
|
||||
title=self.weights.get('title', 0.1),
|
||||
ui=self.weights.get('ui', 0.2),
|
||||
context=self.weights.get('context', 0.1)
|
||||
).normalize()
|
||||
else:
|
||||
weights = weights.normalize()
|
||||
|
||||
try:
|
||||
# Pour l'instant, générer un embedding simulé
|
||||
# TODO: Implémenter la vraie génération avec les embedders
|
||||
|
||||
# Créer les composantes
|
||||
components = EmbeddingComponents()
|
||||
|
||||
# Image embedding
|
||||
if screenshot is not None:
|
||||
components.image_embedding = ComponentInfo(
|
||||
provider="openclip_ViT-B-32",
|
||||
vector_id=f"temp_{screen_state.screen_state_id}_image"
|
||||
)
|
||||
|
||||
# Text embedding
|
||||
if screen_state.perception.detected_text:
|
||||
components.text_embedding = ComponentInfo(
|
||||
provider="clip_text",
|
||||
vector_id=f"temp_{screen_state.screen_state_id}_text"
|
||||
)
|
||||
|
||||
# Title embedding
|
||||
if screen_state.window.window_title:
|
||||
components.title_embedding = ComponentInfo(
|
||||
provider="clip_text",
|
||||
vector_id=f"temp_{screen_state.screen_state_id}_title"
|
||||
)
|
||||
|
||||
# UI embedding
|
||||
if screen_state.ui_elements:
|
||||
components.ui_embedding = ComponentInfo(
|
||||
provider="ui_aggregation_v1",
|
||||
vector_id=f"temp_{screen_state.screen_state_id}_ui"
|
||||
)
|
||||
|
||||
# Context embedding
|
||||
if screen_state.context.current_workflow_candidate or screen_state.context.tags:
|
||||
components.context_embedding = ComponentInfo(
|
||||
provider="context_embedding_v1",
|
||||
vector_id=f"temp_{screen_state.screen_state_id}_context"
|
||||
)
|
||||
|
||||
# Créer le StateEmbedding
|
||||
state_embedding = StateEmbedding(
|
||||
provider="multimodal_fusion_v1",
|
||||
vector_id=f"temp_{screen_state.screen_state_id}_fused",
|
||||
components=components
|
||||
)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "multimodal_embedding_generated",
|
||||
"screen_state_id": screen_state.screen_state_id,
|
||||
"provider": state_embedding.provider
|
||||
})
|
||||
|
||||
return state_embedding
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "multimodal_embedding_error",
|
||||
"screen_state_id": screen_state.screen_state_id,
|
||||
"error": str(e)
|
||||
})
|
||||
raise
|
||||
2257
geniusia2/core/orchestrator.py
Normal file
2257
geniusia2/core/orchestrator.py
Normal file
File diff suppressed because it is too large
Load Diff
460
geniusia2/core/replay_async.py
Normal file
460
geniusia2/core/replay_async.py
Normal file
@@ -0,0 +1,460 @@
|
||||
"""
|
||||
Moteur de rejeu d'actions pour rollback asynchrone.
|
||||
Permet de rejouer des séquences d'actions et d'annuler les dernières actions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
from enum import Enum
|
||||
|
||||
from .utils.input_utils import InputUtils, ActionType
|
||||
from .logger import Logger
|
||||
|
||||
|
||||
class ReplayStatus(Enum):
|
||||
"""Statuts de rejeu."""
|
||||
IDLE = "idle"
|
||||
REPLAYING = "replaying"
|
||||
ROLLING_BACK = "rolling_back"
|
||||
PAUSED = "paused"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class ReplayEngine:
|
||||
"""
|
||||
Moteur de rejeu d'actions pour exécution asynchrone et rollback.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_utils: InputUtils,
|
||||
logger: Logger,
|
||||
config: Dict[str, Any]
|
||||
):
|
||||
"""
|
||||
Initialise le moteur de rejeu.
|
||||
|
||||
Args:
|
||||
input_utils: Utilitaires d'entrée pour exécuter actions
|
||||
logger: Logger pour journalisation
|
||||
config: Configuration globale
|
||||
"""
|
||||
self.input_utils = input_utils
|
||||
self.logger = logger
|
||||
self.config = config
|
||||
|
||||
# État du moteur
|
||||
self.status = ReplayStatus.IDLE
|
||||
self.current_sequence: List[Dict[str, Any]] = []
|
||||
self.current_index = 0
|
||||
|
||||
# Callbacks pour notifications
|
||||
self.on_action_executed: Optional[Callable] = None
|
||||
self.on_sequence_completed: Optional[Callable] = None
|
||||
self.on_rollback_completed: Optional[Callable] = None
|
||||
self.on_error: Optional[Callable] = None
|
||||
|
||||
# Configuration
|
||||
self.delay_between_actions = config.get("replay", {}).get(
|
||||
"delay_between_actions", 0.5
|
||||
)
|
||||
self.max_rollback_attempts = config.get("replay", {}).get(
|
||||
"max_rollback_attempts", 3
|
||||
)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "replay_engine_initialized",
|
||||
"delay_between_actions": self.delay_between_actions,
|
||||
"max_rollback_attempts": self.max_rollback_attempts
|
||||
})
|
||||
|
||||
async def replay_sequence(
|
||||
self,
|
||||
action_sequence: List[Dict[str, Any]],
|
||||
start_index: int = 0
|
||||
) -> bool:
|
||||
"""
|
||||
Rejoue une séquence d'actions de manière asynchrone.
|
||||
|
||||
Args:
|
||||
action_sequence: Liste d'actions à rejouer
|
||||
start_index: Index de départ dans la séquence
|
||||
|
||||
Returns:
|
||||
True si succès complet, False sinon
|
||||
"""
|
||||
if self.status != ReplayStatus.IDLE:
|
||||
self.logger.log_action({
|
||||
"action": "replay_rejected",
|
||||
"reason": "engine_busy",
|
||||
"current_status": self.status.value
|
||||
})
|
||||
return False
|
||||
|
||||
self.status = ReplayStatus.REPLAYING
|
||||
self.current_sequence = action_sequence
|
||||
self.current_index = start_index
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "replay_sequence_started",
|
||||
"num_actions": len(action_sequence),
|
||||
"start_index": start_index
|
||||
})
|
||||
|
||||
try:
|
||||
for i in range(start_index, len(action_sequence)):
|
||||
if self.status == ReplayStatus.PAUSED:
|
||||
# Attendre la reprise
|
||||
while self.status == ReplayStatus.PAUSED:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if self.status != ReplayStatus.REPLAYING:
|
||||
# Arrêt demandé
|
||||
break
|
||||
|
||||
action = action_sequence[i]
|
||||
self.current_index = i
|
||||
|
||||
# Exécuter l'action
|
||||
success = await self._execute_action(action)
|
||||
|
||||
if not success:
|
||||
self.logger.log_action({
|
||||
"action": "replay_action_failed",
|
||||
"index": i,
|
||||
"action_type": action.get("type")
|
||||
})
|
||||
|
||||
if self.on_error:
|
||||
self.on_error(i, action)
|
||||
|
||||
self.status = ReplayStatus.FAILED
|
||||
return False
|
||||
|
||||
# Notifier l'exécution
|
||||
if self.on_action_executed:
|
||||
self.on_action_executed(i, action)
|
||||
|
||||
# Attendre entre les actions
|
||||
if i < len(action_sequence) - 1:
|
||||
await asyncio.sleep(self.delay_between_actions)
|
||||
|
||||
# Séquence terminée avec succès
|
||||
self.status = ReplayStatus.COMPLETED
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "replay_sequence_completed",
|
||||
"num_actions_executed": len(action_sequence) - start_index
|
||||
})
|
||||
|
||||
if self.on_sequence_completed:
|
||||
self.on_sequence_completed()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "replay_sequence_error",
|
||||
"error": str(e),
|
||||
"index": self.current_index
|
||||
})
|
||||
|
||||
self.status = ReplayStatus.FAILED
|
||||
|
||||
if self.on_error:
|
||||
self.on_error(self.current_index, None)
|
||||
|
||||
return False
|
||||
|
||||
finally:
|
||||
if self.status in [ReplayStatus.COMPLETED, ReplayStatus.FAILED]:
|
||||
self.status = ReplayStatus.IDLE
|
||||
|
||||
async def rollback_last_n(self, n: int) -> bool:
|
||||
"""
|
||||
Annule les n dernières actions en exécutant leurs inverses.
|
||||
|
||||
Args:
|
||||
n: Nombre d'actions à annuler
|
||||
|
||||
Returns:
|
||||
True si rollback réussi, False sinon
|
||||
"""
|
||||
if self.status != ReplayStatus.IDLE:
|
||||
self.logger.log_action({
|
||||
"action": "rollback_rejected",
|
||||
"reason": "engine_busy",
|
||||
"current_status": self.status.value
|
||||
})
|
||||
return False
|
||||
|
||||
# Récupérer les dernières actions de l'historique
|
||||
action_history = self.input_utils.get_action_history(limit=n)
|
||||
|
||||
if len(action_history) < n:
|
||||
self.logger.log_action({
|
||||
"action": "rollback_partial",
|
||||
"requested": n,
|
||||
"available": len(action_history)
|
||||
})
|
||||
n = len(action_history)
|
||||
|
||||
if n == 0:
|
||||
return True
|
||||
|
||||
self.status = ReplayStatus.ROLLING_BACK
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "rollback_started",
|
||||
"num_actions": n
|
||||
})
|
||||
|
||||
try:
|
||||
# Inverser les actions (de la plus récente à la plus ancienne)
|
||||
actions_to_rollback = list(reversed(action_history[-n:]))
|
||||
|
||||
success_count = 0
|
||||
for i, action in enumerate(actions_to_rollback):
|
||||
# Générer l'action inverse
|
||||
inverse_action = self.input_utils.get_inverse_action(action)
|
||||
|
||||
if inverse_action is None:
|
||||
self.logger.log_action({
|
||||
"action": "rollback_action_not_invertible",
|
||||
"index": i,
|
||||
"action_type": action.get("type")
|
||||
})
|
||||
continue
|
||||
|
||||
# Exécuter l'action inverse avec retry
|
||||
success = await self._execute_action_with_retry(
|
||||
inverse_action,
|
||||
max_attempts=self.max_rollback_attempts
|
||||
)
|
||||
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
self.logger.log_action({
|
||||
"action": "rollback_action_failed",
|
||||
"index": i,
|
||||
"action_type": action.get("type"),
|
||||
"inverse_action": inverse_action
|
||||
})
|
||||
|
||||
# Attendre entre les actions
|
||||
if i < len(actions_to_rollback) - 1:
|
||||
await asyncio.sleep(self.delay_between_actions)
|
||||
|
||||
# Vérifier le succès
|
||||
all_success = success_count == len(actions_to_rollback)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "rollback_completed",
|
||||
"total_actions": len(actions_to_rollback),
|
||||
"successful": success_count,
|
||||
"failed": len(actions_to_rollback) - success_count,
|
||||
"all_success": all_success
|
||||
})
|
||||
|
||||
if self.on_rollback_completed:
|
||||
self.on_rollback_completed(all_success, success_count, len(actions_to_rollback))
|
||||
|
||||
return all_success
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "rollback_error",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
if self.on_error:
|
||||
self.on_error(-1, None)
|
||||
|
||||
return False
|
||||
|
||||
finally:
|
||||
self.status = ReplayStatus.IDLE
|
||||
|
||||
async def execute_inverse_actions(
|
||||
self,
|
||||
actions: List[Dict[str, Any]]
|
||||
) -> bool:
|
||||
"""
|
||||
Exécute les actions inverses d'une liste d'actions.
|
||||
|
||||
Args:
|
||||
actions: Liste d'actions à inverser et exécuter
|
||||
|
||||
Returns:
|
||||
True si toutes les actions inverses ont été exécutées
|
||||
"""
|
||||
if self.status != ReplayStatus.IDLE:
|
||||
return False
|
||||
|
||||
self.status = ReplayStatus.ROLLING_BACK
|
||||
|
||||
try:
|
||||
# Inverser l'ordre et générer les actions inverses
|
||||
inverse_actions = []
|
||||
for action in reversed(actions):
|
||||
inverse = self.input_utils.get_inverse_action(action)
|
||||
if inverse:
|
||||
inverse_actions.append(inverse)
|
||||
|
||||
# Exécuter les actions inverses
|
||||
success = await self.replay_sequence(inverse_actions)
|
||||
|
||||
return success
|
||||
|
||||
finally:
|
||||
self.status = ReplayStatus.IDLE
|
||||
|
||||
async def _execute_action(self, action: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Exécute une action unique.
|
||||
|
||||
Args:
|
||||
action: Action à exécuter
|
||||
|
||||
Returns:
|
||||
True si succès, False sinon
|
||||
"""
|
||||
action_type = action.get("type")
|
||||
|
||||
try:
|
||||
if action_type == ActionType.CLICK.value:
|
||||
return self.input_utils.click(
|
||||
action["x"],
|
||||
action["y"],
|
||||
button=action.get("button", "left"),
|
||||
clicks=action.get("clicks", 1)
|
||||
)
|
||||
|
||||
elif action_type == ActionType.TYPE.value:
|
||||
return self.input_utils.type_text(
|
||||
action["text"],
|
||||
interval=action.get("interval", 0.0)
|
||||
)
|
||||
|
||||
elif action_type == ActionType.SCROLL.value:
|
||||
return self.input_utils.scroll(
|
||||
action["direction"],
|
||||
amount=action.get("amount", 3),
|
||||
x=action.get("x"),
|
||||
y=action.get("y")
|
||||
)
|
||||
|
||||
elif action_type == ActionType.WAIT.value:
|
||||
return self.input_utils.wait(action["duration"])
|
||||
|
||||
elif action_type == ActionType.MOVE.value:
|
||||
return self.input_utils.move(
|
||||
action["x"],
|
||||
action["y"],
|
||||
duration=action.get("duration", 0.2)
|
||||
)
|
||||
|
||||
elif action_type == ActionType.DRAG.value:
|
||||
return self.input_utils.drag(
|
||||
action["start_x"],
|
||||
action["start_y"],
|
||||
action["end_x"],
|
||||
action["end_y"],
|
||||
duration=action.get("duration", 0.5),
|
||||
button=action.get("button", "left")
|
||||
)
|
||||
|
||||
elif action_type == "press_key":
|
||||
# Action spéciale pour rollback de saisie
|
||||
import pyautogui
|
||||
key = action.get("key")
|
||||
presses = action.get("presses", 1)
|
||||
for _ in range(presses):
|
||||
pyautogui.press(key)
|
||||
await asyncio.sleep(0.05)
|
||||
return True
|
||||
|
||||
else:
|
||||
self.logger.log_action({
|
||||
"action": "unknown_action_type",
|
||||
"type": action_type
|
||||
})
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "execute_action_error",
|
||||
"action_type": action_type,
|
||||
"error": str(e)
|
||||
})
|
||||
return False
|
||||
|
||||
async def _execute_action_with_retry(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
max_attempts: int = 3
|
||||
) -> bool:
|
||||
"""
|
||||
Exécute une action avec retry en cas d'échec.
|
||||
|
||||
Args:
|
||||
action: Action à exécuter
|
||||
max_attempts: Nombre maximum de tentatives
|
||||
|
||||
Returns:
|
||||
True si succès, False sinon
|
||||
"""
|
||||
for attempt in range(max_attempts):
|
||||
success = await self._execute_action(action)
|
||||
|
||||
if success:
|
||||
return True
|
||||
|
||||
if attempt < max_attempts - 1:
|
||||
# Attendre avant de réessayer
|
||||
await asyncio.sleep(0.5)
|
||||
self.logger.log_action({
|
||||
"action": "action_retry",
|
||||
"attempt": attempt + 1,
|
||||
"max_attempts": max_attempts
|
||||
})
|
||||
|
||||
return False
|
||||
|
||||
def pause(self):
|
||||
"""Met en pause le rejeu en cours."""
|
||||
if self.status == ReplayStatus.REPLAYING:
|
||||
self.status = ReplayStatus.PAUSED
|
||||
self.logger.log_action({"action": "replay_paused"})
|
||||
|
||||
def resume(self):
|
||||
"""Reprend le rejeu en pause."""
|
||||
if self.status == ReplayStatus.PAUSED:
|
||||
self.status = ReplayStatus.REPLAYING
|
||||
self.logger.log_action({"action": "replay_resumed"})
|
||||
|
||||
def stop(self):
|
||||
"""Arrête le rejeu en cours."""
|
||||
if self.status in [ReplayStatus.REPLAYING, ReplayStatus.PAUSED]:
|
||||
self.status = ReplayStatus.IDLE
|
||||
self.logger.log_action({"action": "replay_stopped"})
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Retourne l'état actuel du moteur.
|
||||
|
||||
Returns:
|
||||
Dictionnaire avec l'état
|
||||
"""
|
||||
return {
|
||||
"status": self.status.value,
|
||||
"current_index": self.current_index,
|
||||
"total_actions": len(self.current_sequence),
|
||||
"progress": (
|
||||
self.current_index / len(self.current_sequence)
|
||||
if len(self.current_sequence) > 0 else 0
|
||||
)
|
||||
}
|
||||
388
geniusia2/core/screen_state_manager.py
Normal file
388
geniusia2/core/screen_state_manager.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
Gestionnaire d'états d'écran pour le système RPA Vision V2.
|
||||
Gère la création et la persistence des EnrichedScreenState en mode light.
|
||||
|
||||
Phase 1 - Mode Light: Compatibilité arrière complète avec le système existant.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
from .ui_element_models import (
|
||||
EnrichedScreenState,
|
||||
WindowInfo,
|
||||
RawData,
|
||||
PerceptionData,
|
||||
StateEmbedding,
|
||||
ContextData,
|
||||
UIElement
|
||||
)
|
||||
from .logger import Logger
|
||||
|
||||
|
||||
class ScreenStateManager:
|
||||
"""
|
||||
Gestionnaire d'états d'écran enrichis.
|
||||
|
||||
En mode light (Phase 1):
|
||||
- Crée des EnrichedScreenState avec ui_elements vide
|
||||
- Utilise uniquement l'embedding image (pas de fusion multi-modale)
|
||||
- Assure la compatibilité arrière avec le système existant
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logger: Logger,
|
||||
data_dir: str = "data",
|
||||
mode: str = "light"
|
||||
):
|
||||
"""
|
||||
Initialise le gestionnaire d'états d'écran.
|
||||
|
||||
Args:
|
||||
logger: Logger pour journalisation
|
||||
data_dir: Répertoire de données
|
||||
mode: Mode de traitement ("light", "enriched", "complete")
|
||||
"""
|
||||
self.logger = logger
|
||||
self.data_dir = Path(data_dir)
|
||||
self.mode = mode
|
||||
|
||||
# Créer les répertoires nécessaires
|
||||
self.screens_dir = self.data_dir / "screens"
|
||||
self.embeddings_dir = self.data_dir / "embeddings" / "screens"
|
||||
self.states_dir = self.data_dir / "screen_states"
|
||||
|
||||
self.screens_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.embeddings_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.states_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "screen_state_manager_initialized",
|
||||
"mode": self.mode,
|
||||
"data_dir": str(self.data_dir)
|
||||
})
|
||||
|
||||
def create_screen_state(
|
||||
self,
|
||||
session_id: str,
|
||||
window_title: str,
|
||||
app_name: str,
|
||||
screenshot_path: str,
|
||||
screen_resolution: tuple,
|
||||
embedding_provider: str = "openclip_ViT-B-32",
|
||||
embedding_vector_id: Optional[str] = None,
|
||||
detected_text: Optional[list] = None,
|
||||
context_tags: Optional[list] = None,
|
||||
workflow_candidate: Optional[str] = None
|
||||
) -> EnrichedScreenState:
|
||||
"""
|
||||
Crée un EnrichedScreenState en mode light.
|
||||
|
||||
Args:
|
||||
session_id: ID de session
|
||||
window_title: Titre de la fenêtre
|
||||
app_name: Nom de l'application
|
||||
screenshot_path: Chemin vers le screenshot
|
||||
screen_resolution: Résolution d'écran (width, height)
|
||||
embedding_provider: Provider de l'embedding
|
||||
embedding_vector_id: ID du vecteur d'embedding (généré si None)
|
||||
detected_text: Texte détecté (optionnel)
|
||||
context_tags: Tags de contexte (optionnel)
|
||||
workflow_candidate: Workflow candidat (optionnel)
|
||||
|
||||
Returns:
|
||||
EnrichedScreenState créé
|
||||
"""
|
||||
# Générer un ID unique pour l'état d'écran
|
||||
timestamp = datetime.now()
|
||||
screen_state_id = f"screen_{timestamp.strftime('%Y%m%d_%H%M%S_%f')}"
|
||||
|
||||
# Générer l'ID du vecteur d'embedding si non fourni
|
||||
if embedding_vector_id is None:
|
||||
embedding_vector_id = str(self.embeddings_dir / f"{screen_state_id}.npy")
|
||||
|
||||
# Créer les informations de fenêtre
|
||||
window = WindowInfo(
|
||||
app_name=app_name,
|
||||
window_title=window_title,
|
||||
screen_resolution=screen_resolution
|
||||
)
|
||||
|
||||
# Créer les données de perception
|
||||
perception = PerceptionData(
|
||||
detected_text=detected_text or [],
|
||||
ocr_results=None
|
||||
)
|
||||
|
||||
# Créer le state embedding (mode light: image uniquement)
|
||||
state_embedding = StateEmbedding(
|
||||
provider=embedding_provider,
|
||||
vector_id=embedding_vector_id,
|
||||
components=None # Pas de composantes en mode light
|
||||
)
|
||||
|
||||
# Créer le contexte
|
||||
context = ContextData(
|
||||
current_workflow_candidate=workflow_candidate,
|
||||
tags=context_tags or [],
|
||||
metadata={}
|
||||
)
|
||||
|
||||
# Créer l'EnrichedScreenState
|
||||
screen_state = EnrichedScreenState(
|
||||
screen_state_id=screen_state_id,
|
||||
timestamp=timestamp,
|
||||
session_id=session_id,
|
||||
window=window,
|
||||
raw=RawData(screenshot_path=screenshot_path),
|
||||
perception=perception,
|
||||
ui_elements=[], # Vide en mode light
|
||||
state_embedding=state_embedding,
|
||||
context=context,
|
||||
mode=self.mode
|
||||
)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "screen_state_created",
|
||||
"screen_state_id": screen_state_id,
|
||||
"mode": self.mode,
|
||||
"session_id": session_id,
|
||||
"app_name": app_name
|
||||
})
|
||||
|
||||
return screen_state
|
||||
|
||||
def save_screen_state(
|
||||
self,
|
||||
screen_state: EnrichedScreenState,
|
||||
save_embedding: bool = False,
|
||||
embedding_vector: Optional[np.ndarray] = None
|
||||
) -> Path:
|
||||
"""
|
||||
Sauvegarde un EnrichedScreenState sur disque.
|
||||
|
||||
Args:
|
||||
screen_state: État d'écran à sauvegarder
|
||||
save_embedding: Si True, sauvegarde aussi le vecteur d'embedding
|
||||
embedding_vector: Vecteur d'embedding à sauvegarder (si save_embedding=True)
|
||||
|
||||
Returns:
|
||||
Chemin du fichier JSON créé
|
||||
"""
|
||||
# Créer le fichier JSON
|
||||
state_file = self.states_dir / f"{screen_state.screen_state_id}.json"
|
||||
|
||||
try:
|
||||
# Sérialiser en JSON
|
||||
json_str = screen_state.to_json()
|
||||
|
||||
# Écrire le fichier
|
||||
with open(state_file, 'w', encoding='utf-8') as f:
|
||||
f.write(json_str)
|
||||
|
||||
# Sauvegarder l'embedding si demandé
|
||||
if save_embedding and embedding_vector is not None:
|
||||
embedding_path = Path(screen_state.state_embedding.vector_id)
|
||||
embedding_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
np.save(embedding_path, embedding_vector)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "screen_state_saved",
|
||||
"screen_state_id": screen_state.screen_state_id,
|
||||
"file": str(state_file),
|
||||
"embedding_saved": save_embedding
|
||||
})
|
||||
|
||||
return state_file
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "screen_state_save_failed",
|
||||
"screen_state_id": screen_state.screen_state_id,
|
||||
"error": str(e)
|
||||
})
|
||||
raise
|
||||
|
||||
def load_screen_state(self, screen_state_id: str) -> Optional[EnrichedScreenState]:
|
||||
"""
|
||||
Charge un EnrichedScreenState depuis le disque.
|
||||
|
||||
Args:
|
||||
screen_state_id: ID de l'état d'écran à charger
|
||||
|
||||
Returns:
|
||||
EnrichedScreenState chargé ou None si non trouvé
|
||||
"""
|
||||
state_file = self.states_dir / f"{screen_state_id}.json"
|
||||
|
||||
if not state_file.exists():
|
||||
self.logger.log_action({
|
||||
"action": "screen_state_not_found",
|
||||
"screen_state_id": screen_state_id
|
||||
})
|
||||
return None
|
||||
|
||||
try:
|
||||
# Lire le fichier JSON
|
||||
with open(state_file, 'r', encoding='utf-8') as f:
|
||||
json_str = f.read()
|
||||
|
||||
# Désérialiser
|
||||
screen_state = EnrichedScreenState.from_json(json_str)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "screen_state_loaded",
|
||||
"screen_state_id": screen_state_id,
|
||||
"mode": screen_state.mode
|
||||
})
|
||||
|
||||
return screen_state
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "screen_state_load_failed",
|
||||
"screen_state_id": screen_state_id,
|
||||
"error": str(e)
|
||||
})
|
||||
return None
|
||||
|
||||
def load_embedding(self, vector_id: str) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Charge un vecteur d'embedding depuis le disque.
|
||||
|
||||
Args:
|
||||
vector_id: Chemin vers le fichier .npy
|
||||
|
||||
Returns:
|
||||
Vecteur numpy ou None si non trouvé
|
||||
"""
|
||||
embedding_path = Path(vector_id)
|
||||
|
||||
if not embedding_path.exists():
|
||||
self.logger.log_action({
|
||||
"action": "embedding_not_found",
|
||||
"vector_id": vector_id
|
||||
})
|
||||
return None
|
||||
|
||||
try:
|
||||
embedding = np.load(embedding_path)
|
||||
return embedding
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "embedding_load_failed",
|
||||
"vector_id": vector_id,
|
||||
"error": str(e)
|
||||
})
|
||||
return None
|
||||
|
||||
def list_screen_states(
|
||||
self,
|
||||
session_id: Optional[str] = None,
|
||||
limit: Optional[int] = None
|
||||
) -> list:
|
||||
"""
|
||||
Liste les états d'écran disponibles.
|
||||
|
||||
Args:
|
||||
session_id: Filtrer par session (optionnel)
|
||||
limit: Limiter le nombre de résultats (optionnel)
|
||||
|
||||
Returns:
|
||||
Liste des screen_state_id
|
||||
"""
|
||||
state_files = sorted(self.states_dir.glob("*.json"), reverse=True)
|
||||
|
||||
screen_state_ids = []
|
||||
for state_file in state_files:
|
||||
if limit and len(screen_state_ids) >= limit:
|
||||
break
|
||||
|
||||
# Si on filtre par session, charger et vérifier
|
||||
if session_id:
|
||||
try:
|
||||
with open(state_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
if data.get("session_id") == session_id:
|
||||
screen_state_ids.append(state_file.stem)
|
||||
except:
|
||||
continue
|
||||
else:
|
||||
screen_state_ids.append(state_file.stem)
|
||||
|
||||
return screen_state_ids
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Tests basiques
|
||||
from .logger import Logger
|
||||
|
||||
print("Test du ScreenStateManager")
|
||||
print("=" * 50)
|
||||
|
||||
# Créer un logger de test
|
||||
logger = Logger(log_dir="test_logs")
|
||||
|
||||
# Créer le manager
|
||||
manager = ScreenStateManager(
|
||||
logger=logger,
|
||||
data_dir="test_data",
|
||||
mode="light"
|
||||
)
|
||||
|
||||
print("\n1. Test création d'un screen state:")
|
||||
screen_state = manager.create_screen_state(
|
||||
session_id="test_session_001",
|
||||
window_title="Test Window",
|
||||
app_name="test_app",
|
||||
screenshot_path="test_data/screens/test_001.png",
|
||||
screen_resolution=(1920, 1080),
|
||||
detected_text=["Test", "Button"],
|
||||
context_tags=["test"]
|
||||
)
|
||||
|
||||
print(f" Screen State ID: {screen_state.screen_state_id}")
|
||||
print(f" Mode: {screen_state.mode}")
|
||||
print(f" Session ID: {screen_state.session_id}")
|
||||
print(f" UI Elements: {len(screen_state.ui_elements)}")
|
||||
|
||||
print("\n2. Test sauvegarde:")
|
||||
# Créer un embedding de test
|
||||
test_embedding = np.random.rand(512)
|
||||
state_file = manager.save_screen_state(
|
||||
screen_state,
|
||||
save_embedding=True,
|
||||
embedding_vector=test_embedding
|
||||
)
|
||||
print(f" Saved to: {state_file}")
|
||||
|
||||
print("\n3. Test chargement:")
|
||||
loaded_state = manager.load_screen_state(screen_state.screen_state_id)
|
||||
if loaded_state:
|
||||
print(f" Loaded screen_state_id: {loaded_state.screen_state_id}")
|
||||
print(f" Loaded mode: {loaded_state.mode}")
|
||||
print(f" Loaded session_id: {loaded_state.session_id}")
|
||||
|
||||
print("\n4. Test chargement d'embedding:")
|
||||
loaded_embedding = manager.load_embedding(screen_state.state_embedding.vector_id)
|
||||
if loaded_embedding is not None:
|
||||
print(f" Loaded embedding shape: {loaded_embedding.shape}")
|
||||
print(f" Embeddings match: {np.allclose(test_embedding, loaded_embedding)}")
|
||||
|
||||
print("\n5. Test listage:")
|
||||
state_ids = manager.list_screen_states(session_id="test_session_001")
|
||||
print(f" Found {len(state_ids)} screen states")
|
||||
|
||||
print("\n✓ Tous les tests ScreenStateManager réussis!")
|
||||
|
||||
# Nettoyage
|
||||
import shutil
|
||||
if Path("test_data").exists():
|
||||
shutil.rmtree("test_data")
|
||||
if Path("test_logs").exists():
|
||||
shutil.rmtree("test_logs")
|
||||
198
geniusia2/core/session_manager.py
Normal file
198
geniusia2/core/session_manager.py
Normal file
@@ -0,0 +1,198 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
SessionManager - Segmente les actions en sessions pour détecter les workflows
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
"""Représente une session d'actions utilisateur."""
|
||||
session_id: str
|
||||
start_time: datetime
|
||||
end_time: Optional[datetime] = None
|
||||
actions: List[Dict[str, Any]] = field(default_factory=list)
|
||||
window: Optional[str] = None
|
||||
|
||||
@property
|
||||
def duration(self) -> timedelta:
|
||||
"""Durée de la session."""
|
||||
if self.end_time:
|
||||
return self.end_time - self.start_time
|
||||
return datetime.now() - self.start_time
|
||||
|
||||
@property
|
||||
def action_count(self) -> int:
|
||||
"""Nombre d'actions dans la session."""
|
||||
return len(self.actions)
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
Gestionnaire de sessions pour segmenter les actions en sessions.
|
||||
Une session = groupe d'actions dans une fenêtre de temps.
|
||||
"""
|
||||
|
||||
def __init__(self, logger, config: Dict[str, Any]):
|
||||
"""
|
||||
Initialise le gestionnaire de sessions.
|
||||
|
||||
Args:
|
||||
logger: Logger pour journalisation
|
||||
config: Configuration globale
|
||||
"""
|
||||
self.logger = logger
|
||||
self.config = config
|
||||
|
||||
# Configuration
|
||||
self.session_timeout = config.get("workflow", {}).get(
|
||||
"session_timeout", 300 # 5 minutes par défaut
|
||||
)
|
||||
|
||||
# Session courante
|
||||
self.current_session: Optional[Session] = None
|
||||
|
||||
# Historique des sessions
|
||||
self.sessions: List[Session] = []
|
||||
|
||||
# Callback pour session complétée
|
||||
self.on_session_completed: Optional[Callable] = None
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "session_manager_initialized",
|
||||
"session_timeout": self.session_timeout
|
||||
})
|
||||
|
||||
def add_action(self, action: Dict[str, Any]):
|
||||
"""
|
||||
Ajoute une action à la session courante.
|
||||
Crée une nouvelle session si nécessaire.
|
||||
|
||||
Args:
|
||||
action: Action à ajouter
|
||||
"""
|
||||
# Vérifier si on doit créer une nouvelle session
|
||||
if self.should_start_new_session(action):
|
||||
self.finalize_current_session()
|
||||
self.start_new_session(action)
|
||||
|
||||
# Ajouter l'action à la session courante
|
||||
if self.current_session:
|
||||
self.current_session.actions.append(action)
|
||||
|
||||
def should_start_new_session(self, action: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Détermine si une nouvelle session doit être créée.
|
||||
|
||||
Args:
|
||||
action: Action à évaluer
|
||||
|
||||
Returns:
|
||||
True si nouvelle session nécessaire
|
||||
"""
|
||||
# Pas de session courante
|
||||
if not self.current_session:
|
||||
return True
|
||||
|
||||
# Vérifier le timeout
|
||||
if self.current_session.actions:
|
||||
last_action = self.current_session.actions[-1]
|
||||
last_time = last_action.get("timestamp")
|
||||
|
||||
if last_time:
|
||||
if isinstance(last_time, str):
|
||||
last_time = datetime.fromisoformat(last_time)
|
||||
|
||||
current_time = action.get("timestamp")
|
||||
if isinstance(current_time, str):
|
||||
current_time = datetime.fromisoformat(current_time)
|
||||
elif not current_time:
|
||||
current_time = datetime.now()
|
||||
|
||||
time_gap = (current_time - last_time).total_seconds()
|
||||
|
||||
if time_gap > self.session_timeout:
|
||||
return True
|
||||
|
||||
# Changement de fenêtre majeur
|
||||
current_window = action.get("window")
|
||||
if current_window and self.current_session.window:
|
||||
if current_window != self.current_session.window:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def start_new_session(self, first_action: Dict[str, Any]):
|
||||
"""
|
||||
Démarre une nouvelle session.
|
||||
|
||||
Args:
|
||||
first_action: Première action de la session
|
||||
"""
|
||||
session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
|
||||
timestamp = first_action.get("timestamp")
|
||||
if isinstance(timestamp, str):
|
||||
timestamp = datetime.fromisoformat(timestamp)
|
||||
elif not timestamp:
|
||||
timestamp = datetime.now()
|
||||
|
||||
self.current_session = Session(
|
||||
session_id=session_id,
|
||||
start_time=timestamp,
|
||||
window=first_action.get("window")
|
||||
)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "session_started",
|
||||
"session_id": session_id,
|
||||
"window": first_action.get("window")
|
||||
})
|
||||
|
||||
def finalize_current_session(self):
|
||||
"""Finalise la session courante."""
|
||||
if not self.current_session:
|
||||
return
|
||||
|
||||
# Marquer la fin
|
||||
self.current_session.end_time = datetime.now()
|
||||
|
||||
# Ajouter à l'historique
|
||||
self.sessions.append(self.current_session)
|
||||
|
||||
# Callback (l'Orchestrator loggera avec plus de détails)
|
||||
if self.on_session_completed:
|
||||
self.on_session_completed(self.current_session)
|
||||
|
||||
# Réinitialiser
|
||||
self.current_session = None
|
||||
|
||||
def force_finalize_session(self):
|
||||
"""Force la finalisation de la session courante."""
|
||||
self.finalize_current_session()
|
||||
|
||||
def get_recent_sessions(self, n: int = 10) -> List[Session]:
|
||||
"""
|
||||
Retourne les N sessions les plus récentes.
|
||||
|
||||
Args:
|
||||
n: Nombre de sessions à retourner
|
||||
|
||||
Returns:
|
||||
Liste des sessions récentes
|
||||
"""
|
||||
return self.sessions[-n:] if len(self.sessions) >= n else self.sessions
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Retourne les statistiques des sessions."""
|
||||
return {
|
||||
"total_sessions": len(self.sessions),
|
||||
"current_session_actions": self.current_session.action_count if self.current_session else 0,
|
||||
"avg_session_duration": sum(s.duration.total_seconds() for s in self.sessions) / len(self.sessions) if self.sessions else 0,
|
||||
"avg_actions_per_session": sum(s.action_count for s in self.sessions) / len(self.sessions) if self.sessions else 0
|
||||
}
|
||||
696
geniusia2/core/suggestion_manager.py
Normal file
696
geniusia2/core/suggestion_manager.py
Normal file
@@ -0,0 +1,696 @@
|
||||
"""
|
||||
Gestionnaire de suggestions pour le Mode Assisté.
|
||||
Gère les suggestions en temps réel, les scores de confiance et les timeouts.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, Any, Optional, Callable, List
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
|
||||
from .learning_manager import LearningManager
|
||||
from .embeddings_manager import EmbeddingsManager
|
||||
from .logger import Logger
|
||||
from .workflow_matcher import WorkflowMatcher, WorkflowMatch
|
||||
|
||||
|
||||
class SuggestionManager:
|
||||
"""
|
||||
Gestionnaire de suggestions pour le Mode Assisté.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
learning_manager: LearningManager,
|
||||
embeddings_manager: EmbeddingsManager,
|
||||
logger: Logger,
|
||||
config: Dict[str, Any],
|
||||
workflow_matcher: Optional[WorkflowMatcher] = None
|
||||
):
|
||||
"""
|
||||
Initialise le gestionnaire de suggestions.
|
||||
|
||||
Args:
|
||||
learning_manager: Gestionnaire d'apprentissage
|
||||
embeddings_manager: Gestionnaire d'embeddings
|
||||
logger: Logger
|
||||
config: Configuration
|
||||
workflow_matcher: Matcher de workflows (optionnel)
|
||||
"""
|
||||
self.learning_manager = learning_manager
|
||||
self.embeddings_manager = embeddings_manager
|
||||
self.logger = logger
|
||||
self.config = config
|
||||
|
||||
# WorkflowMatcher pour la détection de workflows
|
||||
self.workflow_matcher = workflow_matcher or WorkflowMatcher(logger, config)
|
||||
|
||||
# Configuration
|
||||
self.similarity_threshold = config.get("assist", {}).get(
|
||||
"similarity_threshold", 0.75
|
||||
)
|
||||
self.suggestion_timeout = config.get("assist", {}).get(
|
||||
"suggestion_timeout", 10.0 # secondes
|
||||
)
|
||||
self.workflow_confidence_threshold = config.get("workflow", {}).get(
|
||||
"min_confidence", 0.80 # 80% par défaut
|
||||
)
|
||||
|
||||
# État actuel
|
||||
self.current_suggestion: Optional[Dict[str, Any]] = None
|
||||
self.suggestion_lock = Lock()
|
||||
self.suggestion_start_time: Optional[datetime] = None
|
||||
|
||||
# Tracking des rejets par workflow
|
||||
self.workflow_rejections: Dict[str, int] = {} # workflow_id -> count
|
||||
self.workflow_priority_adjustments: Dict[str, float] = {} # workflow_id -> multiplier
|
||||
|
||||
# Callbacks
|
||||
self.on_suggestion_created: Optional[Callable] = None
|
||||
self.on_suggestion_accepted: Optional[Callable] = None
|
||||
self.on_suggestion_rejected: Optional[Callable] = None
|
||||
self.on_suggestion_timeout: Optional[Callable] = None
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "suggestion_manager_initialized",
|
||||
"similarity_threshold": self.similarity_threshold,
|
||||
"timeout": self.suggestion_timeout,
|
||||
"workflow_confidence_threshold": self.workflow_confidence_threshold
|
||||
})
|
||||
|
||||
def find_suggestion(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Recherche une suggestion basée sur le contexte actuel.
|
||||
|
||||
Args:
|
||||
context: Contexte actuel (embedding, fenêtre, etc.)
|
||||
|
||||
Returns:
|
||||
Suggestion ou None
|
||||
"""
|
||||
# D'abord, vérifier s'il y a un workflow en cours
|
||||
workflow_suggestion = self._check_workflow_suggestion(context)
|
||||
if workflow_suggestion:
|
||||
return workflow_suggestion
|
||||
|
||||
# Sinon, recherche classique par embedding
|
||||
embedding = context.get("embedding")
|
||||
|
||||
if embedding is None:
|
||||
return None
|
||||
|
||||
# Rechercher dans FAISS
|
||||
results = self.embeddings_manager.search_similar(embedding, k=3)
|
||||
|
||||
if not results:
|
||||
return None
|
||||
|
||||
# Filtrer par seuil de similarité
|
||||
best_match = results[0]
|
||||
|
||||
if best_match["similarity"] < self.similarity_threshold:
|
||||
return None
|
||||
|
||||
# Récupérer les métadonnées
|
||||
metadata = best_match["metadata"]
|
||||
task_id = metadata.get("task_id")
|
||||
|
||||
if not task_id:
|
||||
return None
|
||||
|
||||
# Charger la tâche
|
||||
task = self.learning_manager.load_task(task_id)
|
||||
|
||||
if not task:
|
||||
return None
|
||||
|
||||
# Créer la suggestion
|
||||
suggestion = {
|
||||
"type": "action", # Type de suggestion
|
||||
"task_id": task_id,
|
||||
"task_name": task.task_name,
|
||||
"action_type": metadata.get("action_type", "unknown"),
|
||||
"description": metadata.get("description", ""),
|
||||
"similarity": best_match["similarity"],
|
||||
"confidence": self._calculate_confidence(best_match, task),
|
||||
"metadata": metadata,
|
||||
"timestamp": datetime.now()
|
||||
}
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "suggestion_found",
|
||||
"task_id": task_id,
|
||||
"similarity": best_match["similarity"],
|
||||
"confidence": suggestion["confidence"]
|
||||
})
|
||||
|
||||
return suggestion
|
||||
|
||||
def _check_workflow_suggestion(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Vérifie s'il y a un workflow en cours qui correspond au contexte.
|
||||
|
||||
Args:
|
||||
context: Contexte actuel
|
||||
|
||||
Returns:
|
||||
Suggestion de workflow ou None
|
||||
"""
|
||||
# Récupérer l'event_capture depuis le contexte
|
||||
event_capture = context.get("event_capture")
|
||||
if not event_capture:
|
||||
return None
|
||||
|
||||
# Récupérer les workflows détectés
|
||||
workflows = event_capture.get_workflows()
|
||||
if not workflows:
|
||||
return None
|
||||
|
||||
# Récupérer la session courante
|
||||
current_session = event_capture.session_manager.current_session
|
||||
if not current_session or not current_session.actions:
|
||||
return None
|
||||
|
||||
# Comparer avec les workflows connus
|
||||
for workflow in workflows:
|
||||
# Vérifier si on est au début d'un workflow
|
||||
match_score = self._match_workflow_start(current_session.actions, workflow)
|
||||
|
||||
if match_score >= 0.8: # 80% de correspondance
|
||||
# Calculer quelle est la prochaine étape
|
||||
next_step_index = len(current_session.actions)
|
||||
|
||||
if next_step_index < len(workflow.steps):
|
||||
next_step = workflow.steps[next_step_index]
|
||||
|
||||
# Créer une suggestion de workflow
|
||||
suggestion = {
|
||||
"type": "workflow", # Type workflow
|
||||
"workflow_id": workflow.workflow_id,
|
||||
"workflow_name": workflow.name,
|
||||
"current_step": next_step_index,
|
||||
"total_steps": len(workflow.steps),
|
||||
"next_action": {
|
||||
"action_type": next_step.action_type,
|
||||
"description": next_step.target_description,
|
||||
"position": next_step.position
|
||||
},
|
||||
"remaining_steps": len(workflow.steps) - next_step_index,
|
||||
"confidence": workflow.confidence * match_score,
|
||||
"repetitions": workflow.repetitions,
|
||||
"timestamp": datetime.now()
|
||||
}
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "workflow_suggestion_found",
|
||||
"workflow_id": workflow.workflow_id,
|
||||
"step": next_step_index,
|
||||
"confidence": suggestion["confidence"]
|
||||
})
|
||||
|
||||
return suggestion
|
||||
|
||||
return None
|
||||
|
||||
def _match_workflow_start(self, current_actions: list, workflow) -> float:
|
||||
"""
|
||||
Calcule le score de correspondance entre les actions courantes et le début d'un workflow.
|
||||
|
||||
Args:
|
||||
current_actions: Actions de la session courante
|
||||
workflow: Workflow à comparer
|
||||
|
||||
Returns:
|
||||
Score de correspondance (0-1)
|
||||
"""
|
||||
if not current_actions or not workflow.steps:
|
||||
return 0.0
|
||||
|
||||
# Comparer les N premières actions
|
||||
n = min(len(current_actions), len(workflow.steps))
|
||||
matches = 0
|
||||
|
||||
for i in range(n):
|
||||
action = current_actions[i]
|
||||
step = workflow.steps[i]
|
||||
|
||||
# Comparer le type d'action
|
||||
if action.get("action_type") == step.action_type:
|
||||
matches += 1
|
||||
|
||||
# Bonus si même fenêtre
|
||||
if action.get("window") == step.window:
|
||||
matches += 0.5
|
||||
|
||||
# Score normalisé
|
||||
max_score = n * 1.5 # 1 pour le type + 0.5 pour la fenêtre
|
||||
return matches / max_score if max_score > 0 else 0.0
|
||||
|
||||
def _calculate_confidence(
|
||||
self,
|
||||
match: Dict[str, Any],
|
||||
task: Any
|
||||
) -> float:
|
||||
"""
|
||||
Calcule le score de confiance pour une suggestion.
|
||||
|
||||
Args:
|
||||
match: Résultat de recherche FAISS
|
||||
task: Profil de tâche
|
||||
|
||||
Returns:
|
||||
Score de confiance (0-1)
|
||||
"""
|
||||
# Similarité visuelle
|
||||
vision_score = match["similarity"]
|
||||
|
||||
# Performance historique
|
||||
historical_score = task.concordance_rate if hasattr(task, "concordance_rate") else 0.5
|
||||
|
||||
# Formule : 70% vision + 30% historique
|
||||
confidence = 0.7 * vision_score + 0.3 * historical_score
|
||||
|
||||
return max(0.0, min(1.0, confidence))
|
||||
|
||||
def create_suggestion(self, context: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Crée une nouvelle suggestion si applicable.
|
||||
|
||||
Args:
|
||||
context: Contexte actuel
|
||||
|
||||
Returns:
|
||||
True si suggestion créée
|
||||
"""
|
||||
with self.suggestion_lock:
|
||||
# Vérifier qu'il n'y a pas déjà une suggestion active
|
||||
if self.current_suggestion is not None:
|
||||
return False
|
||||
|
||||
# Rechercher une suggestion
|
||||
suggestion = self.find_suggestion(context)
|
||||
|
||||
if suggestion is None:
|
||||
return False
|
||||
|
||||
# Vérifier le seuil de confiance
|
||||
if suggestion["confidence"] < self.similarity_threshold:
|
||||
return False
|
||||
|
||||
# Créer la suggestion
|
||||
self.current_suggestion = suggestion
|
||||
self.suggestion_start_time = datetime.now()
|
||||
|
||||
# Callback
|
||||
if self.on_suggestion_created:
|
||||
self.on_suggestion_created(suggestion)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "suggestion_created",
|
||||
"task_id": suggestion["task_id"],
|
||||
"confidence": suggestion["confidence"]
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
def accept_suggestion(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Accepte la suggestion actuelle.
|
||||
|
||||
Returns:
|
||||
Suggestion acceptée ou None
|
||||
"""
|
||||
with self.suggestion_lock:
|
||||
if self.current_suggestion is None:
|
||||
return None
|
||||
|
||||
suggestion = self.current_suggestion
|
||||
self.current_suggestion = None
|
||||
self.suggestion_start_time = None
|
||||
|
||||
# Si c'est une suggestion de workflow, tracker l'acceptation
|
||||
if suggestion.get("type") == "workflow":
|
||||
workflow_id = suggestion.get("workflow_id")
|
||||
if workflow_id:
|
||||
self._track_workflow_acceptance(workflow_id)
|
||||
|
||||
# Mettre à jour les statistiques (pour les suggestions d'action)
|
||||
if suggestion.get("task_id"):
|
||||
self.learning_manager.confirm_action({
|
||||
"type": "accept",
|
||||
"task_id": suggestion["task_id"]
|
||||
})
|
||||
|
||||
# Callback
|
||||
if self.on_suggestion_accepted:
|
||||
self.on_suggestion_accepted(suggestion)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "suggestion_accepted",
|
||||
"suggestion_type": suggestion.get("type"),
|
||||
"workflow_id": suggestion.get("workflow_id"),
|
||||
"task_id": suggestion.get("task_id")
|
||||
})
|
||||
|
||||
return suggestion
|
||||
|
||||
def reject_suggestion(self) -> bool:
|
||||
"""
|
||||
Rejette la suggestion actuelle.
|
||||
|
||||
Returns:
|
||||
True si suggestion rejetée
|
||||
"""
|
||||
with self.suggestion_lock:
|
||||
if self.current_suggestion is None:
|
||||
return False
|
||||
|
||||
suggestion = self.current_suggestion
|
||||
self.current_suggestion = None
|
||||
self.suggestion_start_time = None
|
||||
|
||||
# Si c'est une suggestion de workflow, tracker le rejet
|
||||
if suggestion.get("type") == "workflow":
|
||||
workflow_id = suggestion.get("workflow_id")
|
||||
if workflow_id:
|
||||
self._track_workflow_rejection(workflow_id)
|
||||
|
||||
# Mettre à jour les statistiques (pour les suggestions d'action)
|
||||
if suggestion.get("task_id"):
|
||||
self.learning_manager.confirm_action({
|
||||
"type": "reject",
|
||||
"task_id": suggestion["task_id"]
|
||||
})
|
||||
|
||||
# Callback
|
||||
if self.on_suggestion_rejected:
|
||||
self.on_suggestion_rejected(suggestion)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "suggestion_rejected",
|
||||
"suggestion_type": suggestion.get("type"),
|
||||
"workflow_id": suggestion.get("workflow_id"),
|
||||
"task_id": suggestion.get("task_id")
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
def check_timeout(self) -> bool:
|
||||
"""
|
||||
Vérifie si la suggestion actuelle a expiré.
|
||||
|
||||
Returns:
|
||||
True si timeout
|
||||
"""
|
||||
with self.suggestion_lock:
|
||||
if self.current_suggestion is None:
|
||||
return False
|
||||
|
||||
if self.suggestion_start_time is None:
|
||||
return False
|
||||
|
||||
elapsed = (datetime.now() - self.suggestion_start_time).total_seconds()
|
||||
|
||||
if elapsed >= self.suggestion_timeout:
|
||||
# Timeout
|
||||
suggestion = self.current_suggestion
|
||||
self.current_suggestion = None
|
||||
self.suggestion_start_time = None
|
||||
|
||||
# Callback
|
||||
if self.on_suggestion_timeout:
|
||||
self.on_suggestion_timeout(suggestion)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "suggestion_timeout",
|
||||
"task_id": suggestion["task_id"],
|
||||
"elapsed": elapsed
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_current_suggestion(self) -> Optional[Dict[str, Any]]:
|
||||
"""Retourne la suggestion actuelle."""
|
||||
with self.suggestion_lock:
|
||||
return self.current_suggestion
|
||||
|
||||
def clear_suggestion(self):
|
||||
"""Efface la suggestion actuelle."""
|
||||
with self.suggestion_lock:
|
||||
self.current_suggestion = None
|
||||
self.suggestion_start_time = None
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Retourne les statistiques du gestionnaire."""
|
||||
return {
|
||||
"has_active_suggestion": self.current_suggestion is not None,
|
||||
"similarity_threshold": self.similarity_threshold,
|
||||
"timeout": self.suggestion_timeout,
|
||||
"workflow_rejections": len(self.workflow_rejections),
|
||||
"workflows_with_adjusted_priority": len(self.workflow_priority_adjustments)
|
||||
}
|
||||
|
||||
def check_workflow_match(
|
||||
self,
|
||||
session_actions: List[Dict[str, Any]],
|
||||
workflows: List[Dict[str, Any]]
|
||||
) -> Optional[WorkflowMatch]:
|
||||
"""
|
||||
Vérifie périodiquement si les actions courantes correspondent à un workflow connu.
|
||||
|
||||
Cette méthode doit être appelée régulièrement (ex: toutes les 2s) en mode Assist
|
||||
pour détecter les correspondances de workflows.
|
||||
|
||||
Args:
|
||||
session_actions: Liste des actions de la session courante
|
||||
workflows: Liste des workflows connus
|
||||
|
||||
Returns:
|
||||
Meilleure correspondance si trouvée, None sinon
|
||||
"""
|
||||
if not session_actions or not workflows:
|
||||
return None
|
||||
|
||||
# Vérifier qu'il n'y a pas déjà une suggestion active
|
||||
with self.suggestion_lock:
|
||||
if self.current_suggestion is not None:
|
||||
return None
|
||||
|
||||
# Utiliser le WorkflowMatcher pour trouver les correspondances
|
||||
matches = self.workflow_matcher.match_current_session(
|
||||
session_actions,
|
||||
workflows
|
||||
)
|
||||
|
||||
if not matches:
|
||||
return None
|
||||
|
||||
# Appliquer les ajustements de priorité basés sur les rejets
|
||||
adjusted_matches = []
|
||||
for match in matches:
|
||||
adjusted_confidence = self._apply_priority_adjustment(match)
|
||||
|
||||
# Créer un nouveau match avec la confiance ajustée
|
||||
adjusted_match = WorkflowMatch(
|
||||
workflow_id=match.workflow_id,
|
||||
workflow_name=match.workflow_name,
|
||||
confidence=adjusted_confidence,
|
||||
matched_steps=match.matched_steps,
|
||||
total_steps=match.total_steps,
|
||||
remaining_steps=match.remaining_steps,
|
||||
current_step_index=match.current_step_index
|
||||
)
|
||||
adjusted_matches.append(adjusted_match)
|
||||
|
||||
# Trier à nouveau par confiance ajustée
|
||||
adjusted_matches.sort(key=lambda m: m.confidence, reverse=True)
|
||||
|
||||
# Trouver le meilleur match
|
||||
best_match = self.workflow_matcher.find_best_match(adjusted_matches)
|
||||
|
||||
if best_match:
|
||||
self.logger.log_action({
|
||||
"action": "workflow_match_found",
|
||||
"workflow_id": best_match.workflow_id,
|
||||
"workflow_name": best_match.workflow_name,
|
||||
"confidence": best_match.confidence,
|
||||
"matched_steps": best_match.matched_steps,
|
||||
"remaining_steps": len(best_match.remaining_steps)
|
||||
})
|
||||
|
||||
return best_match
|
||||
|
||||
def create_workflow_suggestion(
|
||||
self,
|
||||
workflow_match: WorkflowMatch
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Crée une suggestion de workflow avec les détails des étapes restantes.
|
||||
|
||||
Args:
|
||||
workflow_match: Correspondance de workflow trouvée
|
||||
|
||||
Returns:
|
||||
Suggestion créée ou None si impossible
|
||||
"""
|
||||
with self.suggestion_lock:
|
||||
# Vérifier qu'il n'y a pas déjà une suggestion active
|
||||
if self.current_suggestion is not None:
|
||||
return None
|
||||
|
||||
# Vérifier le seuil de confiance
|
||||
if workflow_match.confidence < self.workflow_confidence_threshold:
|
||||
self.logger.log_action({
|
||||
"action": "workflow_suggestion_rejected_low_confidence",
|
||||
"workflow_id": workflow_match.workflow_id,
|
||||
"confidence": workflow_match.confidence,
|
||||
"threshold": self.workflow_confidence_threshold
|
||||
})
|
||||
return None
|
||||
|
||||
# Créer la suggestion avec les détails des étapes
|
||||
suggestion = {
|
||||
"type": "workflow",
|
||||
"workflow_id": workflow_match.workflow_id,
|
||||
"workflow_name": workflow_match.workflow_name,
|
||||
"confidence": workflow_match.confidence,
|
||||
"current_step": workflow_match.current_step_index,
|
||||
"total_steps": workflow_match.total_steps,
|
||||
"matched_steps": workflow_match.matched_steps,
|
||||
"remaining_steps": workflow_match.remaining_steps,
|
||||
"next_steps_preview": workflow_match.remaining_steps[:3], # 3 prochaines étapes
|
||||
"created_at": datetime.now(),
|
||||
"timeout": self.suggestion_timeout
|
||||
}
|
||||
|
||||
# Enregistrer la suggestion
|
||||
self.current_suggestion = suggestion
|
||||
self.suggestion_start_time = datetime.now()
|
||||
|
||||
# Callback
|
||||
if self.on_suggestion_created:
|
||||
self.on_suggestion_created(suggestion)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "workflow_suggestion_created",
|
||||
"workflow_id": workflow_match.workflow_id,
|
||||
"workflow_name": workflow_match.workflow_name,
|
||||
"confidence": workflow_match.confidence,
|
||||
"remaining_steps": len(workflow_match.remaining_steps)
|
||||
})
|
||||
|
||||
return suggestion
|
||||
|
||||
def _apply_priority_adjustment(self, match: WorkflowMatch) -> float:
|
||||
"""
|
||||
Applique l'ajustement de priorité basé sur les rejets précédents.
|
||||
|
||||
Args:
|
||||
match: Correspondance de workflow
|
||||
|
||||
Returns:
|
||||
Confiance ajustée
|
||||
"""
|
||||
workflow_id = match.workflow_id
|
||||
|
||||
# Récupérer le multiplicateur d'ajustement
|
||||
adjustment = self.workflow_priority_adjustments.get(workflow_id, 1.0)
|
||||
|
||||
# Appliquer l'ajustement
|
||||
adjusted_confidence = match.confidence * adjustment
|
||||
|
||||
# S'assurer que la confiance reste dans [0, 1]
|
||||
adjusted_confidence = max(0.0, min(1.0, adjusted_confidence))
|
||||
|
||||
if adjustment != 1.0:
|
||||
self.logger.log_action({
|
||||
"action": "priority_adjustment_applied",
|
||||
"workflow_id": workflow_id,
|
||||
"original_confidence": match.confidence,
|
||||
"adjusted_confidence": adjusted_confidence,
|
||||
"adjustment_multiplier": adjustment
|
||||
})
|
||||
|
||||
return adjusted_confidence
|
||||
|
||||
def _track_workflow_rejection(self, workflow_id: str):
|
||||
"""
|
||||
Enregistre un rejet de workflow et ajuste la priorité si nécessaire.
|
||||
|
||||
Après 3 rejets, la priorité du workflow est réduite (confiance * 0.9).
|
||||
|
||||
Args:
|
||||
workflow_id: ID du workflow rejeté
|
||||
"""
|
||||
# Incrémenter le compteur de rejets
|
||||
current_rejections = self.workflow_rejections.get(workflow_id, 0)
|
||||
current_rejections += 1
|
||||
self.workflow_rejections[workflow_id] = current_rejections
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "workflow_rejection_tracked",
|
||||
"workflow_id": workflow_id,
|
||||
"total_rejections": current_rejections
|
||||
})
|
||||
|
||||
# Après 3 rejets, ajuster la priorité
|
||||
if current_rejections >= 3:
|
||||
# Réduire la confiance de 10% à chaque tranche de 3 rejets
|
||||
adjustment_factor = 0.9 ** (current_rejections // 3)
|
||||
self.workflow_priority_adjustments[workflow_id] = adjustment_factor
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "workflow_priority_adjusted",
|
||||
"workflow_id": workflow_id,
|
||||
"rejections": current_rejections,
|
||||
"new_adjustment_factor": adjustment_factor
|
||||
})
|
||||
|
||||
def _track_workflow_acceptance(self, workflow_id: str):
|
||||
"""
|
||||
Enregistre une acceptation de workflow et améliore la priorité.
|
||||
|
||||
Args:
|
||||
workflow_id: ID du workflow accepté
|
||||
"""
|
||||
# Réduire le compteur de rejets (récompenser les acceptations)
|
||||
current_rejections = self.workflow_rejections.get(workflow_id, 0)
|
||||
if current_rejections > 0:
|
||||
current_rejections = max(0, current_rejections - 2) # Réduire de 2
|
||||
self.workflow_rejections[workflow_id] = current_rejections
|
||||
|
||||
# Recalculer l'ajustement de priorité
|
||||
if current_rejections >= 3:
|
||||
adjustment_factor = 0.9 ** (current_rejections // 3)
|
||||
self.workflow_priority_adjustments[workflow_id] = adjustment_factor
|
||||
else:
|
||||
# Retirer l'ajustement si moins de 3 rejets
|
||||
if workflow_id in self.workflow_priority_adjustments:
|
||||
del self.workflow_priority_adjustments[workflow_id]
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "workflow_acceptance_tracked",
|
||||
"workflow_id": workflow_id,
|
||||
"remaining_rejections": current_rejections
|
||||
})
|
||||
|
||||
def on_workflow_detected(self, workflow: Dict[str, Any]):
|
||||
"""
|
||||
Callback appelé quand un workflow est détecté.
|
||||
Peut créer une suggestion immédiate si le workflow est pertinent.
|
||||
|
||||
Args:
|
||||
workflow: Workflow détecté
|
||||
"""
|
||||
self.logger.log_action({
|
||||
"action": "workflow_detected_in_suggestion_manager",
|
||||
"workflow_id": workflow.get("workflow_id"),
|
||||
"workflow_name": workflow.get("name"),
|
||||
"confidence": workflow.get("confidence")
|
||||
})
|
||||
|
||||
# Pour l'instant, on log seulement
|
||||
# Dans le futur, on pourrait créer une suggestion proactive
|
||||
# basée sur le workflow détecté
|
||||
498
geniusia2/core/task_replay.py
Normal file
498
geniusia2/core/task_replay.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""
|
||||
Système de rejeu intelligent de tâches apprises avec reconnaissance visuelle.
|
||||
Permet de rejouer des tâches en s'adaptant aux variations d'interface.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from .models import TaskProfile, Action
|
||||
from .learning_manager import LearningManager
|
||||
from .embeddings_manager import EmbeddingsManager
|
||||
from .utils.vision_utils import VisionUtils
|
||||
from .utils.input_utils import InputUtils
|
||||
from .logger import Logger
|
||||
|
||||
|
||||
class TaskReplayEngine:
|
||||
"""
|
||||
Moteur de rejeu intelligent qui utilise la vision pour localiser
|
||||
les éléments et rejouer les tâches apprises.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
learning_manager: LearningManager,
|
||||
embeddings_manager: EmbeddingsManager,
|
||||
vision_utils: VisionUtils,
|
||||
input_utils: InputUtils,
|
||||
logger: Logger,
|
||||
config: Dict[str, Any]
|
||||
):
|
||||
"""
|
||||
Initialise le moteur de rejeu.
|
||||
|
||||
Args:
|
||||
learning_manager: Gestionnaire d'apprentissage
|
||||
embeddings_manager: Gestionnaire d'embeddings
|
||||
vision_utils: Utilitaires de vision
|
||||
input_utils: Utilitaires d'entrée
|
||||
logger: Logger
|
||||
config: Configuration
|
||||
"""
|
||||
self.learning_manager = learning_manager
|
||||
self.embeddings_manager = embeddings_manager
|
||||
self.vision_utils = vision_utils
|
||||
self.input_utils = input_utils
|
||||
self.logger = logger
|
||||
self.config = config
|
||||
|
||||
# Configuration
|
||||
self.similarity_threshold = config.get("replay", {}).get(
|
||||
"similarity_threshold", 0.75
|
||||
)
|
||||
self.max_search_attempts = config.get("replay", {}).get(
|
||||
"max_search_attempts", 3
|
||||
)
|
||||
self.delay_between_actions = config.get("replay", {}).get(
|
||||
"delay_between_actions", 0.5
|
||||
)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "task_replay_engine_initialized",
|
||||
"similarity_threshold": self.similarity_threshold
|
||||
})
|
||||
|
||||
async def replay_task(
|
||||
self,
|
||||
task_id: str,
|
||||
interactive: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Rejoue une tâche apprise.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche à rejouer
|
||||
interactive: Si True, demande confirmation avant chaque action
|
||||
|
||||
Returns:
|
||||
Résultats du rejeu
|
||||
"""
|
||||
# Charger la tâche
|
||||
task = self.learning_manager.load_task(task_id)
|
||||
|
||||
if not task:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "task_not_found",
|
||||
"task_id": task_id
|
||||
}
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "task_replay_started",
|
||||
"task_id": task_id,
|
||||
"interactive": interactive
|
||||
})
|
||||
|
||||
# Récupérer les signatures
|
||||
signatures = task.metadata.get("signatures", [])
|
||||
|
||||
if not signatures:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "no_signatures",
|
||||
"task_id": task_id
|
||||
}
|
||||
|
||||
results = {
|
||||
"task_id": task_id,
|
||||
"total_actions": len(signatures),
|
||||
"executed_actions": 0,
|
||||
"failed_actions": 0,
|
||||
"actions": []
|
||||
}
|
||||
|
||||
# Rejouer chaque action
|
||||
for i, signature in enumerate(signatures):
|
||||
self.logger.log_action({
|
||||
"action": "replaying_step",
|
||||
"step": i + 1,
|
||||
"total": len(signatures)
|
||||
})
|
||||
|
||||
# Localiser l'élément visuellement
|
||||
location = await self._find_element_visually(signature)
|
||||
|
||||
if not location:
|
||||
self.logger.log_action({
|
||||
"action": "element_not_found",
|
||||
"step": i + 1,
|
||||
"signature": signature.get("description", "Unknown")
|
||||
})
|
||||
|
||||
results["failed_actions"] += 1
|
||||
results["actions"].append({
|
||||
"step": i + 1,
|
||||
"success": False,
|
||||
"error": "element_not_found"
|
||||
})
|
||||
continue
|
||||
|
||||
# Exécuter l'action
|
||||
success = await self._execute_action_at_location(
|
||||
signature,
|
||||
location,
|
||||
interactive
|
||||
)
|
||||
|
||||
results["actions"].append({
|
||||
"step": i + 1,
|
||||
"success": success,
|
||||
"location": location,
|
||||
"action_type": signature.get("action_type")
|
||||
})
|
||||
|
||||
if success:
|
||||
results["executed_actions"] += 1
|
||||
else:
|
||||
results["failed_actions"] += 1
|
||||
|
||||
# Attendre entre les actions
|
||||
if i < len(signatures) - 1:
|
||||
await asyncio.sleep(self.delay_between_actions)
|
||||
|
||||
results["success"] = results["failed_actions"] == 0
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "task_replay_completed",
|
||||
"task_id": task_id,
|
||||
"success": results["success"],
|
||||
"executed": results["executed_actions"],
|
||||
"failed": results["failed_actions"]
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
async def _find_element_visually(
|
||||
self,
|
||||
signature: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Localise un élément visuellement en utilisant son embedding.
|
||||
|
||||
Args:
|
||||
signature: Signature visuelle de l'élément
|
||||
|
||||
Returns:
|
||||
Localisation (x, y, confidence) ou None
|
||||
"""
|
||||
embedding = signature.get("embedding")
|
||||
|
||||
if embedding is None:
|
||||
return None
|
||||
|
||||
# Capturer l'écran actuel
|
||||
import pyautogui
|
||||
screenshot = pyautogui.screenshot()
|
||||
screenshot_np = np.array(screenshot)
|
||||
|
||||
# Rechercher l'élément avec vision
|
||||
for attempt in range(self.max_search_attempts):
|
||||
# Générer l'embedding de l'écran actuel
|
||||
current_embedding = self.vision_utils.generate_clip_embedding(
|
||||
screenshot_np
|
||||
)
|
||||
|
||||
# Rechercher les zones similaires
|
||||
similar_regions = await self._search_similar_regions(
|
||||
screenshot_np,
|
||||
embedding
|
||||
)
|
||||
|
||||
if similar_regions:
|
||||
best_match = similar_regions[0]
|
||||
|
||||
if best_match["similarity"] >= self.similarity_threshold:
|
||||
self.logger.log_action({
|
||||
"action": "element_found",
|
||||
"similarity": best_match["similarity"],
|
||||
"attempt": attempt + 1
|
||||
})
|
||||
|
||||
return {
|
||||
"x": best_match["x"],
|
||||
"y": best_match["y"],
|
||||
"confidence": best_match["similarity"],
|
||||
"bbox": best_match.get("bbox")
|
||||
}
|
||||
|
||||
# Attendre un peu avant de réessayer
|
||||
if attempt < self.max_search_attempts - 1:
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return None
|
||||
|
||||
async def _search_similar_regions(
|
||||
self,
|
||||
screenshot: np.ndarray,
|
||||
target_embedding: np.ndarray,
|
||||
grid_size: int = 4
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Recherche les régions similaires dans une capture d'écran.
|
||||
|
||||
Args:
|
||||
screenshot: Capture d'écran
|
||||
target_embedding: Embedding cible
|
||||
grid_size: Taille de la grille de recherche
|
||||
|
||||
Returns:
|
||||
Liste de régions similaires triées par similarité
|
||||
"""
|
||||
height, width = screenshot.shape[:2]
|
||||
cell_height = height // grid_size
|
||||
cell_width = width // grid_size
|
||||
|
||||
regions = []
|
||||
|
||||
# Parcourir la grille
|
||||
for row in range(grid_size):
|
||||
for col in range(grid_size):
|
||||
y1 = row * cell_height
|
||||
x1 = col * cell_width
|
||||
y2 = min((row + 1) * cell_height, height)
|
||||
x2 = min((col + 1) * cell_width, width)
|
||||
|
||||
# Extraire la région
|
||||
region = screenshot[y1:y2, x1:x2]
|
||||
|
||||
# Générer l'embedding
|
||||
region_embedding = self.vision_utils.generate_clip_embedding(region)
|
||||
|
||||
# Calculer la similarité
|
||||
similarity = self._cosine_similarity(
|
||||
target_embedding,
|
||||
region_embedding
|
||||
)
|
||||
|
||||
# Centre de la région
|
||||
center_x = (x1 + x2) // 2
|
||||
center_y = (y1 + y2) // 2
|
||||
|
||||
regions.append({
|
||||
"x": center_x,
|
||||
"y": center_y,
|
||||
"bbox": (x1, y1, x2, y2),
|
||||
"similarity": similarity
|
||||
})
|
||||
|
||||
# Trier par similarité décroissante
|
||||
regions.sort(key=lambda r: r["similarity"], reverse=True)
|
||||
|
||||
return regions
|
||||
|
||||
def _cosine_similarity(
|
||||
self,
|
||||
emb1: np.ndarray,
|
||||
emb2: np.ndarray
|
||||
) -> float:
|
||||
"""Calcule la similarité cosinus entre deux embeddings."""
|
||||
dot_product = np.dot(emb1, emb2)
|
||||
norm1 = np.linalg.norm(emb1)
|
||||
norm2 = np.linalg.norm(emb2)
|
||||
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
|
||||
return float(dot_product / (norm1 * norm2))
|
||||
|
||||
async def _execute_action_at_location(
|
||||
self,
|
||||
signature: Dict[str, Any],
|
||||
location: Dict[str, Any],
|
||||
interactive: bool
|
||||
) -> bool:
|
||||
"""
|
||||
Exécute une action à une localisation donnée.
|
||||
|
||||
Args:
|
||||
signature: Signature de l'action
|
||||
location: Localisation de l'élément
|
||||
interactive: Si True, demande confirmation
|
||||
|
||||
Returns:
|
||||
True si succès
|
||||
"""
|
||||
action_type = signature.get("action_type", "click")
|
||||
x = location["x"]
|
||||
y = location["y"]
|
||||
|
||||
# Demander confirmation si mode interactif
|
||||
if interactive:
|
||||
confirmed = await self._ask_confirmation(signature, location)
|
||||
if not confirmed:
|
||||
return False
|
||||
|
||||
try:
|
||||
if action_type == "click":
|
||||
self.input_utils.click(x, y)
|
||||
return True
|
||||
|
||||
elif action_type == "type":
|
||||
text = signature.get("text", "")
|
||||
self.input_utils.type_text(text)
|
||||
return True
|
||||
|
||||
elif action_type == "scroll":
|
||||
direction = signature.get("direction", "down")
|
||||
amount = signature.get("amount", 3)
|
||||
self.input_utils.scroll(direction, amount, x, y)
|
||||
return True
|
||||
|
||||
elif action_type == "drag":
|
||||
# Pour le drag, on a besoin de la destination
|
||||
end_x = signature.get("end_x", x + 100)
|
||||
end_y = signature.get("end_y", y)
|
||||
self.input_utils.drag(x, y, end_x, end_y)
|
||||
return True
|
||||
|
||||
else:
|
||||
self.logger.log_action({
|
||||
"action": "unknown_action_type",
|
||||
"type": action_type
|
||||
})
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "execute_action_error",
|
||||
"error": str(e),
|
||||
"action_type": action_type
|
||||
})
|
||||
return False
|
||||
|
||||
async def _ask_confirmation(
|
||||
self,
|
||||
signature: Dict[str, Any],
|
||||
location: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
Demande confirmation à l'utilisateur (mode interactif).
|
||||
|
||||
Args:
|
||||
signature: Signature de l'action
|
||||
location: Localisation
|
||||
|
||||
Returns:
|
||||
True si confirmé
|
||||
"""
|
||||
# TODO: Implémenter une vraie interface de confirmation
|
||||
# Pour l'instant, on accepte automatiquement
|
||||
return True
|
||||
|
||||
def list_available_tasks(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Liste toutes les tâches disponibles pour le rejeu.
|
||||
|
||||
Returns:
|
||||
Liste des tâches avec leurs métadonnées
|
||||
"""
|
||||
tasks = []
|
||||
profiles_path = Path(self.learning_manager.profiles_path)
|
||||
|
||||
for task_dir in profiles_path.iterdir():
|
||||
if not task_dir.is_dir():
|
||||
continue
|
||||
|
||||
metadata_file = task_dir / "metadata.json"
|
||||
if not metadata_file.exists():
|
||||
continue
|
||||
|
||||
try:
|
||||
import json
|
||||
with open(metadata_file, "r") as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
tasks.append({
|
||||
"task_id": metadata.get("task_id"),
|
||||
"task_name": metadata.get("task_name", metadata.get("description")),
|
||||
"observation_count": metadata.get("observation_count", metadata.get("observations")),
|
||||
"confidence": metadata.get("confidence_score", 0.0)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "task_list_error",
|
||||
"task_dir": str(task_dir),
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return tasks
|
||||
|
||||
async def replay_task_with_monitoring(
|
||||
self,
|
||||
task_id: str,
|
||||
on_step_completed: Optional[callable] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Rejoue une tâche avec monitoring en temps réel.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche
|
||||
on_step_completed: Callback appelé après chaque étape
|
||||
|
||||
Returns:
|
||||
Résultats du rejeu
|
||||
"""
|
||||
task = self.learning_manager.load_task(task_id)
|
||||
|
||||
if not task:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "task_not_found"
|
||||
}
|
||||
|
||||
signatures = task.metadata.get("signatures", [])
|
||||
results = {
|
||||
"task_id": task_id,
|
||||
"steps": []
|
||||
}
|
||||
|
||||
for i, signature in enumerate(signatures):
|
||||
step_result = {
|
||||
"step": i + 1,
|
||||
"description": signature.get("description", "Unknown"),
|
||||
"status": "pending"
|
||||
}
|
||||
|
||||
# Localiser et exécuter
|
||||
location = await self._find_element_visually(signature)
|
||||
|
||||
if location:
|
||||
success = await self._execute_action_at_location(
|
||||
signature,
|
||||
location,
|
||||
False
|
||||
)
|
||||
step_result["status"] = "success" if success else "failed"
|
||||
step_result["location"] = location
|
||||
else:
|
||||
step_result["status"] = "not_found"
|
||||
|
||||
results["steps"].append(step_result)
|
||||
|
||||
# Callback
|
||||
if on_step_completed:
|
||||
on_step_completed(step_result)
|
||||
|
||||
await asyncio.sleep(self.delay_between_actions)
|
||||
|
||||
results["success"] = all(
|
||||
s["status"] == "success" for s in results["steps"]
|
||||
)
|
||||
|
||||
return results
|
||||
406
geniusia2/core/ui_change_detector.py
Normal file
406
geniusia2/core/ui_change_detector.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""
|
||||
Détecteur de changements UI pour RPA Vision V2.
|
||||
Détecte les dérives d'interface et déclenche le ré-entraînement si nécessaire.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Any, Tuple, TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .embeddings_manager import EmbeddingsManager
|
||||
from .logger import Logger
|
||||
|
||||
|
||||
class UIChangeDetector:
|
||||
"""
|
||||
Détecteur de changements UI qui surveille les dérives d'interface
|
||||
et déclenche le ré-entraînement lorsque nécessaire.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embeddings_manager: "EmbeddingsManager",
|
||||
logger: "Logger",
|
||||
config: Dict[str, Any]
|
||||
):
|
||||
"""
|
||||
Initialise le détecteur de changements UI.
|
||||
|
||||
Args:
|
||||
embeddings_manager: Gestionnaire d'embeddings pour comparaison
|
||||
logger: Logger pour la journalisation
|
||||
config: Configuration globale
|
||||
"""
|
||||
self.embeddings_manager = embeddings_manager
|
||||
self.logger = logger
|
||||
self.config = config
|
||||
|
||||
# Seuils de configuration
|
||||
self.ui_change_threshold = config.get("thresholds", {}).get(
|
||||
"ui_change_similarity", 0.70
|
||||
)
|
||||
self.bbox_delta_threshold = config.get("thresholds", {}).get(
|
||||
"bbox_delta_pixels", 10
|
||||
)
|
||||
|
||||
# Historique des changements détectés
|
||||
self.change_history: List[Dict[str, Any]] = []
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "ui_change_detector_initialized",
|
||||
"ui_change_threshold": self.ui_change_threshold,
|
||||
"bbox_delta_threshold": self.bbox_delta_threshold
|
||||
})
|
||||
|
||||
def detect_ui_change(
|
||||
self,
|
||||
current_embedding: np.ndarray,
|
||||
stored_embeddings: List[np.ndarray],
|
||||
task_id: str
|
||||
) -> Tuple[bool, float]:
|
||||
"""
|
||||
Détecte si l'UI a changé en comparant l'embedding actuel
|
||||
avec les embeddings stockés.
|
||||
|
||||
Args:
|
||||
current_embedding: Embedding visuel actuel de l'élément UI
|
||||
stored_embeddings: Liste des embeddings stockés pour cette tâche
|
||||
task_id: ID de la tâche concernée
|
||||
|
||||
Returns:
|
||||
Tuple (changement_détecté, similarité_max)
|
||||
- changement_détecté: True si similarité < seuil (70%)
|
||||
- similarité_max: Meilleure similarité trouvée
|
||||
"""
|
||||
if not stored_embeddings or len(stored_embeddings) == 0:
|
||||
self.logger.log_action({
|
||||
"action": "ui_change_check_skipped",
|
||||
"task_id": task_id,
|
||||
"reason": "no_stored_embeddings"
|
||||
})
|
||||
return False, 0.0
|
||||
|
||||
try:
|
||||
# Calculer la similarité avec tous les embeddings stockés
|
||||
similarities = []
|
||||
for stored_emb in stored_embeddings:
|
||||
similarity = self.embeddings_manager.get_embedding_similarity(
|
||||
current_embedding,
|
||||
stored_emb
|
||||
)
|
||||
similarities.append(similarity)
|
||||
|
||||
# Prendre la meilleure similarité
|
||||
max_similarity = max(similarities)
|
||||
avg_similarity = np.mean(similarities)
|
||||
|
||||
# Détecter le changement si la similarité est inférieure au seuil
|
||||
change_detected = max_similarity < self.ui_change_threshold
|
||||
|
||||
# Logger le résultat
|
||||
self.logger.log_action({
|
||||
"action": "ui_change_detected" if change_detected else "ui_stable",
|
||||
"task_id": task_id,
|
||||
"max_similarity": float(max_similarity),
|
||||
"avg_similarity": float(avg_similarity),
|
||||
"threshold": self.ui_change_threshold,
|
||||
"num_stored_embeddings": len(stored_embeddings)
|
||||
})
|
||||
|
||||
# Enregistrer dans l'historique si changement détecté
|
||||
if change_detected:
|
||||
self.change_history.append({
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"task_id": task_id,
|
||||
"max_similarity": float(max_similarity),
|
||||
"avg_similarity": float(avg_similarity),
|
||||
"threshold": self.ui_change_threshold
|
||||
})
|
||||
|
||||
return change_detected, max_similarity
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "ui_change_detection_error",
|
||||
"task_id": task_id,
|
||||
"error": str(e)
|
||||
})
|
||||
# En cas d'erreur, considérer qu'il n'y a pas de changement
|
||||
return False, 0.0
|
||||
|
||||
def calculate_delta(
|
||||
self,
|
||||
predicted_bbox: Tuple[int, int, int, int],
|
||||
actual_bbox: Tuple[int, int, int, int]
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Calcule le delta en pixels entre la bbox prédite et la bbox réelle.
|
||||
|
||||
Args:
|
||||
predicted_bbox: Bounding box prédite (x, y, width, height)
|
||||
actual_bbox: Bounding box réelle (x, y, width, height)
|
||||
|
||||
Returns:
|
||||
Dictionnaire avec les deltas:
|
||||
- delta_x: Différence en x (pixels)
|
||||
- delta_y: Différence en y (pixels)
|
||||
- delta_width: Différence en largeur (pixels)
|
||||
- delta_height: Différence en hauteur (pixels)
|
||||
- delta_center: Distance euclidienne entre les centres (pixels)
|
||||
- max_delta: Delta maximum parmi x, y
|
||||
"""
|
||||
try:
|
||||
pred_x, pred_y, pred_w, pred_h = predicted_bbox
|
||||
actual_x, actual_y, actual_w, actual_h = actual_bbox
|
||||
|
||||
# Calculer les deltas
|
||||
delta_x = abs(actual_x - pred_x)
|
||||
delta_y = abs(actual_y - pred_y)
|
||||
delta_width = abs(actual_w - pred_w)
|
||||
delta_height = abs(actual_h - pred_h)
|
||||
|
||||
# Calculer les centres
|
||||
pred_center_x = pred_x + pred_w / 2
|
||||
pred_center_y = pred_y + pred_h / 2
|
||||
actual_center_x = actual_x + actual_w / 2
|
||||
actual_center_y = actual_y + actual_h / 2
|
||||
|
||||
# Distance euclidienne entre les centres
|
||||
delta_center = np.sqrt(
|
||||
(actual_center_x - pred_center_x) ** 2 +
|
||||
(actual_center_y - pred_center_y) ** 2
|
||||
)
|
||||
|
||||
# Delta maximum (position)
|
||||
max_delta = max(delta_x, delta_y)
|
||||
|
||||
deltas = {
|
||||
"delta_x": float(delta_x),
|
||||
"delta_y": float(delta_y),
|
||||
"delta_width": float(delta_width),
|
||||
"delta_height": float(delta_height),
|
||||
"delta_center": float(delta_center),
|
||||
"max_delta": float(max_delta)
|
||||
}
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "bbox_delta_calculated",
|
||||
"predicted_bbox": list(predicted_bbox),
|
||||
"actual_bbox": list(actual_bbox),
|
||||
"deltas": deltas
|
||||
})
|
||||
|
||||
return deltas
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "bbox_delta_calculation_error",
|
||||
"error": str(e)
|
||||
})
|
||||
return {
|
||||
"delta_x": 0.0,
|
||||
"delta_y": 0.0,
|
||||
"delta_width": 0.0,
|
||||
"delta_height": 0.0,
|
||||
"delta_center": 0.0,
|
||||
"max_delta": 0.0
|
||||
}
|
||||
|
||||
def should_trigger_retraining(
|
||||
self,
|
||||
deltas: Dict[str, float],
|
||||
similarity: float
|
||||
) -> bool:
|
||||
"""
|
||||
Détermine si le ré-entraînement doit être déclenché.
|
||||
|
||||
Args:
|
||||
deltas: Dictionnaire des deltas de bbox
|
||||
similarity: Similarité d'embedding
|
||||
|
||||
Returns:
|
||||
True si ré-entraînement nécessaire
|
||||
"""
|
||||
# Déclencher si delta de position > seuil OU similarité < seuil
|
||||
position_drift = deltas.get("max_delta", 0) > self.bbox_delta_threshold
|
||||
visual_drift = similarity < self.ui_change_threshold
|
||||
|
||||
should_retrain = position_drift or visual_drift
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "retraining_decision",
|
||||
"should_retrain": should_retrain,
|
||||
"position_drift": position_drift,
|
||||
"visual_drift": visual_drift,
|
||||
"max_delta": deltas.get("max_delta", 0),
|
||||
"similarity": float(similarity),
|
||||
"bbox_threshold": self.bbox_delta_threshold,
|
||||
"similarity_threshold": self.ui_change_threshold
|
||||
})
|
||||
|
||||
return should_retrain
|
||||
|
||||
def trigger_retraining(
|
||||
self,
|
||||
task_id: str,
|
||||
reason: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Déclenche le ré-entraînement pour une tâche spécifique.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche nécessitant un ré-entraînement
|
||||
reason: Raison du ré-entraînement
|
||||
metadata: Métadonnées additionnelles
|
||||
"""
|
||||
retraining_event = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"task_id": task_id,
|
||||
"reason": reason,
|
||||
"metadata": metadata or {}
|
||||
}
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "retraining_triggered",
|
||||
**retraining_event
|
||||
})
|
||||
|
||||
# Enregistrer dans l'historique
|
||||
self.change_history.append(retraining_event)
|
||||
|
||||
# Note: L'implémentation réelle du ré-entraînement sera gérée
|
||||
# par le LearningManager qui recevra cette notification
|
||||
|
||||
def check_and_trigger_retraining(
|
||||
self,
|
||||
task_id: str,
|
||||
current_embedding: np.ndarray,
|
||||
stored_embeddings: List[np.ndarray],
|
||||
predicted_bbox: Optional[Tuple[int, int, int, int]] = None,
|
||||
actual_bbox: Optional[Tuple[int, int, int, int]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Vérifie les changements UI et déclenche le ré-entraînement si nécessaire.
|
||||
|
||||
Args:
|
||||
task_id: ID de la tâche
|
||||
current_embedding: Embedding actuel
|
||||
stored_embeddings: Embeddings stockés
|
||||
predicted_bbox: Bbox prédite (optionnel)
|
||||
actual_bbox: Bbox réelle (optionnel)
|
||||
|
||||
Returns:
|
||||
Dictionnaire avec les résultats de la vérification
|
||||
"""
|
||||
result = {
|
||||
"task_id": task_id,
|
||||
"ui_change_detected": False,
|
||||
"position_drift_detected": False,
|
||||
"retraining_triggered": False,
|
||||
"similarity": 0.0,
|
||||
"deltas": {}
|
||||
}
|
||||
|
||||
# Vérifier les changements visuels
|
||||
ui_changed, similarity = self.detect_ui_change(
|
||||
current_embedding,
|
||||
stored_embeddings,
|
||||
task_id
|
||||
)
|
||||
result["ui_change_detected"] = ui_changed
|
||||
result["similarity"] = similarity
|
||||
|
||||
# Vérifier les deltas de position si disponibles
|
||||
if predicted_bbox is not None and actual_bbox is not None:
|
||||
deltas = self.calculate_delta(predicted_bbox, actual_bbox)
|
||||
result["deltas"] = deltas
|
||||
result["position_drift_detected"] = (
|
||||
deltas.get("max_delta", 0) > self.bbox_delta_threshold
|
||||
)
|
||||
else:
|
||||
deltas = {}
|
||||
|
||||
# Décider si ré-entraînement nécessaire
|
||||
if self.should_trigger_retraining(deltas, similarity):
|
||||
reasons = []
|
||||
if ui_changed:
|
||||
reasons.append(f"visual_drift (similarity={similarity:.2f})")
|
||||
if result["position_drift_detected"]:
|
||||
reasons.append(f"position_drift (delta={deltas.get('max_delta', 0):.1f}px)")
|
||||
|
||||
reason = ", ".join(reasons)
|
||||
self.trigger_retraining(
|
||||
task_id,
|
||||
reason,
|
||||
{
|
||||
"similarity": similarity,
|
||||
"deltas": deltas,
|
||||
"ui_change_detected": ui_changed,
|
||||
"position_drift_detected": result["position_drift_detected"]
|
||||
}
|
||||
)
|
||||
result["retraining_triggered"] = True
|
||||
|
||||
return result
|
||||
|
||||
def get_change_history(
|
||||
self,
|
||||
task_id: Optional[str] = None,
|
||||
limit: int = 50
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retourne l'historique des changements détectés.
|
||||
|
||||
Args:
|
||||
task_id: Filtrer par task_id (optionnel)
|
||||
limit: Nombre maximum d'entrées à retourner
|
||||
|
||||
Returns:
|
||||
Liste des changements détectés
|
||||
"""
|
||||
history = self.change_history
|
||||
|
||||
# Filtrer par task_id si spécifié
|
||||
if task_id:
|
||||
history = [h for h in history if h.get("task_id") == task_id]
|
||||
|
||||
# Limiter le nombre de résultats
|
||||
return history[-limit:]
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Retourne des statistiques sur les changements détectés.
|
||||
|
||||
Returns:
|
||||
Dictionnaire de statistiques
|
||||
"""
|
||||
total_changes = len(self.change_history)
|
||||
|
||||
# Compter les changements par tâche
|
||||
changes_by_task = {}
|
||||
for change in self.change_history:
|
||||
task_id = change.get("task_id", "unknown")
|
||||
changes_by_task[task_id] = changes_by_task.get(task_id, 0) + 1
|
||||
|
||||
# Compter les ré-entraînements déclenchés
|
||||
retraining_count = sum(
|
||||
1 for change in self.change_history
|
||||
if change.get("reason") is not None
|
||||
)
|
||||
|
||||
return {
|
||||
"total_changes_detected": total_changes,
|
||||
"retraining_triggered_count": retraining_count,
|
||||
"changes_by_task": changes_by_task,
|
||||
"ui_change_threshold": self.ui_change_threshold,
|
||||
"bbox_delta_threshold": self.bbox_delta_threshold
|
||||
}
|
||||
|
||||
def clear_history(self):
|
||||
"""Efface l'historique des changements."""
|
||||
self.change_history = []
|
||||
self.logger.log_action({
|
||||
"action": "change_history_cleared"
|
||||
})
|
||||
793
geniusia2/core/ui_element_detector.py
Normal file
793
geniusia2/core/ui_element_detector.py
Normal file
@@ -0,0 +1,793 @@
|
||||
"""
|
||||
Détecteur d'éléments UI pour la Phase 2 - Mode Enrichi.
|
||||
Implémente le pipeline complet de détection d'éléments UI.
|
||||
|
||||
Pipeline:
|
||||
1. RegionProposer - Propose des régions d'intérêt candidates
|
||||
2. ElementCharacterizer - Caractérise chaque région
|
||||
3. ElementClassifier - Classifie type et rôle
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from .ui_element_models import (
|
||||
UIElement,
|
||||
UIElementType,
|
||||
VisualData,
|
||||
TextData,
|
||||
ElementProperties,
|
||||
ElementContext,
|
||||
WindowInfo
|
||||
)
|
||||
from .llm_manager import LLMManager
|
||||
from .logger import Logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class BoundingBox:
|
||||
"""Représente une bounding box."""
|
||||
x1: int
|
||||
y1: int
|
||||
x2: int
|
||||
y2: int
|
||||
confidence: float = 1.0
|
||||
source: str = "unknown"
|
||||
|
||||
def area(self) -> int:
|
||||
"""Calcule l'aire de la bbox."""
|
||||
return (self.x2 - self.x1) * (self.y2 - self.y1)
|
||||
|
||||
def center(self) -> Tuple[int, int]:
|
||||
"""Calcule le centre de la bbox."""
|
||||
return ((self.x1 + self.x2) // 2, (self.y1 + self.y2) // 2)
|
||||
|
||||
def iou(self, other: 'BoundingBox') -> float:
|
||||
"""Calcule l'Intersection over Union avec une autre bbox."""
|
||||
# Intersection
|
||||
x1_inter = max(self.x1, other.x1)
|
||||
y1_inter = max(self.y1, other.y1)
|
||||
x2_inter = min(self.x2, other.x2)
|
||||
y2_inter = min(self.y2, other.y2)
|
||||
|
||||
if x2_inter < x1_inter or y2_inter < y1_inter:
|
||||
return 0.0
|
||||
|
||||
inter_area = (x2_inter - x1_inter) * (y2_inter - y1_inter)
|
||||
|
||||
# Union
|
||||
union_area = self.area() + other.area() - inter_area
|
||||
|
||||
return inter_area / union_area if union_area > 0 else 0.0
|
||||
|
||||
|
||||
class RegionProposer:
|
||||
"""
|
||||
Propose des régions d'intérêt candidates pour les éléments UI.
|
||||
|
||||
Méthodes:
|
||||
1. Détection de zones de texte (rapide)
|
||||
2. Détection de rectangles autour de texte (heuristique)
|
||||
3. Requête VLM pour zones cliquables (optionnel, coûteux)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_manager: Optional[LLMManager] = None,
|
||||
logger: Optional[Logger] = None,
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Initialise le RegionProposer.
|
||||
|
||||
Args:
|
||||
llm_manager: Gestionnaire LLM pour requêtes VLM
|
||||
logger: Logger pour journalisation
|
||||
config: Configuration
|
||||
"""
|
||||
self.llm = llm_manager
|
||||
self.logger = logger
|
||||
self.config = config or {}
|
||||
|
||||
# Configuration
|
||||
self.use_text_detection = self.config.get("use_text_detection", True)
|
||||
self.use_rectangle_detection = self.config.get("use_rectangle_detection", True)
|
||||
self.use_vlm_detection = self.config.get("use_vlm_detection", False)
|
||||
self.min_region_size = self.config.get("min_region_size", 20)
|
||||
self.max_region_size = self.config.get("max_region_size", 500)
|
||||
|
||||
def propose_regions(
|
||||
self,
|
||||
screenshot: np.ndarray,
|
||||
window_info: WindowInfo
|
||||
) -> List[BoundingBox]:
|
||||
"""
|
||||
Propose des régions d'intérêt candidates.
|
||||
|
||||
Args:
|
||||
screenshot: Screenshot numpy array
|
||||
window_info: Informations sur la fenêtre
|
||||
|
||||
Returns:
|
||||
Liste de BoundingBox candidates
|
||||
"""
|
||||
regions = []
|
||||
|
||||
# Méthode 1: Zones de texte (rapide)
|
||||
if self.use_text_detection:
|
||||
text_regions = self._detect_text_regions(screenshot)
|
||||
regions.extend(text_regions)
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "text_regions_detected",
|
||||
"count": len(text_regions)
|
||||
})
|
||||
|
||||
# Méthode 2: Rectangles propres (heuristique)
|
||||
if self.use_rectangle_detection:
|
||||
rect_regions = self._detect_rectangles(screenshot)
|
||||
regions.extend(rect_regions)
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "rectangle_regions_detected",
|
||||
"count": len(rect_regions)
|
||||
})
|
||||
|
||||
# Méthode 3: VLM (précis mais lent)
|
||||
if self.use_vlm_detection and self._should_use_vlm(window_info):
|
||||
vlm_regions = self._query_vlm_for_regions(screenshot)
|
||||
regions.extend(vlm_regions)
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "vlm_regions_detected",
|
||||
"count": len(vlm_regions)
|
||||
})
|
||||
|
||||
# Fusion et nettoyage
|
||||
regions = self._merge_overlapping_regions(regions)
|
||||
regions = self._filter_invalid_regions(regions, screenshot.shape)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "regions_proposed",
|
||||
"total_count": len(regions)
|
||||
})
|
||||
|
||||
return regions
|
||||
|
||||
def _detect_text_regions(self, screenshot: np.ndarray) -> List[BoundingBox]:
|
||||
"""Détecte les zones de texte en utilisant des heuristiques simples."""
|
||||
regions = []
|
||||
|
||||
try:
|
||||
# Convertir en niveaux de gris
|
||||
gray = cv2.cvtColor(screenshot, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Appliquer un seuil adaptatif
|
||||
thresh = cv2.adaptiveThreshold(
|
||||
gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
|
||||
cv2.THRESH_BINARY_INV, 11, 2
|
||||
)
|
||||
|
||||
# Trouver les contours
|
||||
contours, _ = cv2.findContours(
|
||||
thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
|
||||
# Filtrer et créer des bboxes
|
||||
for contour in contours:
|
||||
x, y, w, h = cv2.boundingRect(contour)
|
||||
|
||||
# Filtrer par taille
|
||||
if w < self.min_region_size or h < self.min_region_size:
|
||||
continue
|
||||
if w > self.max_region_size or h > self.max_region_size:
|
||||
continue
|
||||
|
||||
# Filtrer par ratio (texte généralement horizontal)
|
||||
ratio = w / h if h > 0 else 0
|
||||
if ratio < 0.5 or ratio > 20:
|
||||
continue
|
||||
|
||||
regions.append(BoundingBox(
|
||||
x1=x, y1=y, x2=x+w, y2=y+h,
|
||||
confidence=0.7,
|
||||
source="text_detection"
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "text_detection_error",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return regions
|
||||
|
||||
def _detect_rectangles(self, screenshot: np.ndarray) -> List[BoundingBox]:
|
||||
"""Détecte les rectangles propres (boutons, champs, etc.)."""
|
||||
regions = []
|
||||
|
||||
try:
|
||||
# Convertir en niveaux de gris
|
||||
gray = cv2.cvtColor(screenshot, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Détection de contours avec Canny
|
||||
edges = cv2.Canny(gray, 50, 150)
|
||||
|
||||
# Dilatation pour connecter les contours
|
||||
kernel = np.ones((3, 3), np.uint8)
|
||||
dilated = cv2.dilate(edges, kernel, iterations=1)
|
||||
|
||||
# Trouver les contours
|
||||
contours, _ = cv2.findContours(
|
||||
dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
|
||||
# Filtrer et créer des bboxes
|
||||
for contour in contours:
|
||||
# Approximer le contour
|
||||
epsilon = 0.02 * cv2.arcLength(contour, True)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
|
||||
# Garder seulement les formes rectangulaires (4 coins)
|
||||
if len(approx) >= 4:
|
||||
x, y, w, h = cv2.boundingRect(contour)
|
||||
|
||||
# Filtrer par taille
|
||||
if w < self.min_region_size or h < self.min_region_size:
|
||||
continue
|
||||
if w > self.max_region_size or h > self.max_region_size:
|
||||
continue
|
||||
|
||||
regions.append(BoundingBox(
|
||||
x1=x, y1=y, x2=x+w, y2=y+h,
|
||||
confidence=0.8,
|
||||
source="rectangle_detection"
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "rectangle_detection_error",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return regions
|
||||
|
||||
def _query_vlm_for_regions(self, screenshot: np.ndarray) -> List[BoundingBox]:
|
||||
"""Utilise le VLM pour détecter les zones cliquables."""
|
||||
regions = []
|
||||
|
||||
if not self.llm or not self.llm.is_available():
|
||||
return regions
|
||||
|
||||
try:
|
||||
prompt = """Identifie tous les éléments cliquables dans cette interface (boutons, liens, champs).
|
||||
Réponds avec une liste de coordonnées au format: x1,y1,x2,y2;x1,y1,x2,y2;..."""
|
||||
|
||||
response = self.llm.generate_with_vision(prompt, [screenshot])
|
||||
|
||||
# Parser la réponse
|
||||
# Format attendu: "x1,y1,x2,y2;x1,y1,x2,y2;..."
|
||||
if response and ';' in response:
|
||||
for bbox_str in response.split(';'):
|
||||
try:
|
||||
coords = [int(x.strip()) for x in bbox_str.split(',')]
|
||||
if len(coords) == 4:
|
||||
regions.append(BoundingBox(
|
||||
x1=coords[0], y1=coords[1],
|
||||
x2=coords[2], y2=coords[3],
|
||||
confidence=0.9,
|
||||
source="vlm_detection"
|
||||
))
|
||||
except:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "vlm_detection_error",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return regions
|
||||
|
||||
def _should_use_vlm(self, window_info: WindowInfo) -> bool:
|
||||
"""Décide si on doit utiliser le VLM (coûteux)."""
|
||||
# Pour l'instant, toujours False sauf si explicitement activé
|
||||
return self.use_vlm_detection
|
||||
|
||||
def _merge_overlapping_regions(
|
||||
self,
|
||||
regions: List[BoundingBox],
|
||||
iou_threshold: float = 0.5
|
||||
) -> List[BoundingBox]:
|
||||
"""Fusionne les régions qui se chevauchent."""
|
||||
if not regions:
|
||||
return []
|
||||
|
||||
# Trier par confiance décroissante
|
||||
regions = sorted(regions, key=lambda r: r.confidence, reverse=True)
|
||||
|
||||
merged = []
|
||||
used = set()
|
||||
|
||||
for i, region in enumerate(regions):
|
||||
if i in used:
|
||||
continue
|
||||
|
||||
# Chercher les régions qui se chevauchent
|
||||
overlapping = [region]
|
||||
for j, other in enumerate(regions[i+1:], start=i+1):
|
||||
if j in used:
|
||||
continue
|
||||
|
||||
if region.iou(other) > iou_threshold:
|
||||
overlapping.append(other)
|
||||
used.add(j)
|
||||
|
||||
# Fusionner en prenant l'union
|
||||
if len(overlapping) > 1:
|
||||
x1 = min(r.x1 for r in overlapping)
|
||||
y1 = min(r.y1 for r in overlapping)
|
||||
x2 = max(r.x2 for r in overlapping)
|
||||
y2 = max(r.y2 for r in overlapping)
|
||||
conf = max(r.confidence for r in overlapping)
|
||||
|
||||
merged.append(BoundingBox(
|
||||
x1=x1, y1=y1, x2=x2, y2=y2,
|
||||
confidence=conf,
|
||||
source="merged"
|
||||
))
|
||||
else:
|
||||
merged.append(region)
|
||||
|
||||
return merged
|
||||
|
||||
def _filter_invalid_regions(
|
||||
self,
|
||||
regions: List[BoundingBox],
|
||||
image_shape: Tuple[int, ...]
|
||||
) -> List[BoundingBox]:
|
||||
"""Filtre les régions invalides."""
|
||||
height, width = image_shape[:2]
|
||||
|
||||
valid = []
|
||||
for region in regions:
|
||||
# Vérifier que la région est dans l'image
|
||||
if region.x1 < 0 or region.y1 < 0:
|
||||
continue
|
||||
if region.x2 > width or region.y2 > height:
|
||||
continue
|
||||
|
||||
# Vérifier la taille
|
||||
w = region.x2 - region.x1
|
||||
h = region.y2 - region.y1
|
||||
|
||||
if w < self.min_region_size or h < self.min_region_size:
|
||||
continue
|
||||
if w > self.max_region_size or h > self.max_region_size:
|
||||
continue
|
||||
|
||||
valid.append(region)
|
||||
|
||||
return valid
|
||||
|
||||
|
||||
# À suivre: ElementCharacterizer et ElementClassifier dans la prochaine partie...
|
||||
|
||||
|
||||
|
||||
class ElementCharacterizer:
|
||||
"""
|
||||
Caractérise chaque région détectée en extrayant ses caractéristiques.
|
||||
|
||||
Extrait:
|
||||
- Crop image de la région
|
||||
- Embedding image via CLIP
|
||||
- Texte dans/autour de la région via VLM
|
||||
- Embedding texte
|
||||
- Position bbox
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_manager: Optional[LLMManager] = None,
|
||||
image_embedder: Optional[Any] = None,
|
||||
text_embedder: Optional[Any] = None,
|
||||
logger: Optional[Logger] = None
|
||||
):
|
||||
"""
|
||||
Initialise l'ElementCharacterizer.
|
||||
|
||||
Args:
|
||||
llm_manager: Gestionnaire LLM pour extraction de texte
|
||||
image_embedder: Embedder d'images (CLIP)
|
||||
text_embedder: Embedder de texte
|
||||
logger: Logger
|
||||
"""
|
||||
self.llm = llm_manager
|
||||
self.image_embedder = image_embedder
|
||||
self.text_embedder = text_embedder
|
||||
self.logger = logger
|
||||
|
||||
def characterize(
|
||||
self,
|
||||
screenshot: np.ndarray,
|
||||
region: BoundingBox,
|
||||
window_info: WindowInfo,
|
||||
data_dir: str = "data"
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Caractérise une région détectée.
|
||||
|
||||
Args:
|
||||
screenshot: Screenshot complet
|
||||
region: Région à caractériser
|
||||
window_info: Informations sur la fenêtre
|
||||
data_dir: Répertoire de données
|
||||
|
||||
Returns:
|
||||
Dictionnaire avec les caractéristiques ou None si échec
|
||||
"""
|
||||
try:
|
||||
# Extraire le crop
|
||||
crop = screenshot[region.y1:region.y2, region.x1:region.x2]
|
||||
|
||||
if crop.size == 0:
|
||||
return None
|
||||
|
||||
# Sauvegarder le crop (temporaire)
|
||||
crop_path = Path(data_dir) / "temp" / f"crop_{region.x1}_{region.y1}.png"
|
||||
crop_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cv2.imwrite(str(crop_path), crop)
|
||||
|
||||
# Extraire le texte via VLM
|
||||
text = self._extract_text(crop)
|
||||
|
||||
# Générer les embeddings (simulés pour l'instant)
|
||||
image_embedding_path = str(crop_path).replace(".png", "_img_emb.npy")
|
||||
text_embedding_path = str(crop_path).replace(".png", "_text_emb.npy")
|
||||
|
||||
# Sauvegarder des embeddings factices (à remplacer par vrais embeddings)
|
||||
np.save(image_embedding_path, np.random.rand(512))
|
||||
np.save(text_embedding_path, np.random.rand(512))
|
||||
|
||||
return {
|
||||
"crop_path": str(crop_path),
|
||||
"crop_image": crop,
|
||||
"text_raw": text,
|
||||
"text_normalized": text.lower().strip(),
|
||||
"bbox": (region.x1, region.y1, region.x2, region.y2),
|
||||
"image_embedding_path": image_embedding_path,
|
||||
"text_embedding_path": text_embedding_path,
|
||||
"confidence": region.confidence,
|
||||
"source": region.source
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "characterization_error",
|
||||
"error": str(e),
|
||||
"region": f"({region.x1},{region.y1},{region.x2},{region.y2})"
|
||||
})
|
||||
return None
|
||||
|
||||
def _extract_text(self, crop: np.ndarray) -> str:
|
||||
"""Extrait le texte d'un crop via VLM."""
|
||||
if not self.llm or not self.llm.is_available():
|
||||
return ""
|
||||
|
||||
try:
|
||||
prompt = "Quel est le texte visible dans cette image? Réponds UNIQUEMENT avec le texte, rien d'autre."
|
||||
response = self.llm.generate_with_vision(prompt, [crop])
|
||||
return response.strip() if response else ""
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "text_extraction_error",
|
||||
"error": str(e)
|
||||
})
|
||||
return ""
|
||||
|
||||
|
||||
class ElementClassifier:
|
||||
"""
|
||||
Classifie le type et le rôle sémantique des éléments.
|
||||
|
||||
Utilise:
|
||||
- Caractéristiques visuelles
|
||||
- Analyse textuelle VLM
|
||||
- Heuristiques
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_manager: Optional[LLMManager] = None,
|
||||
logger: Optional[Logger] = None,
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Initialise l'ElementClassifier.
|
||||
|
||||
Args:
|
||||
llm_manager: Gestionnaire LLM
|
||||
logger: Logger
|
||||
config: Configuration
|
||||
"""
|
||||
self.llm = llm_manager
|
||||
self.logger = logger
|
||||
self.config = config or {}
|
||||
|
||||
def classify(
|
||||
self,
|
||||
characterized_element: Dict[str, Any],
|
||||
window_info: WindowInfo
|
||||
) -> Tuple[UIElementType, str, float]:
|
||||
"""
|
||||
Classifie un élément caractérisé.
|
||||
|
||||
Args:
|
||||
characterized_element: Élément caractérisé
|
||||
window_info: Informations sur la fenêtre
|
||||
|
||||
Returns:
|
||||
Tuple (type, role, confidence)
|
||||
"""
|
||||
text = characterized_element.get("text_raw", "").lower()
|
||||
bbox = characterized_element.get("bbox", (0, 0, 0, 0))
|
||||
|
||||
# Classification basée sur le texte et les heuristiques
|
||||
element_type, confidence = self._classify_type(text, bbox)
|
||||
role = self._infer_role(text, element_type)
|
||||
|
||||
return element_type, role, confidence
|
||||
|
||||
def _classify_type(
|
||||
self,
|
||||
text: str,
|
||||
bbox: Tuple[int, int, int, int]
|
||||
) -> Tuple[UIElementType, float]:
|
||||
"""Classifie le type d'élément."""
|
||||
# Heuristiques simples basées sur le texte
|
||||
text_lower = text.lower()
|
||||
|
||||
# Boutons
|
||||
button_keywords = ["valider", "ok", "annuler", "enregistrer", "submit", "cancel", "save", "delete", "supprimer"]
|
||||
if any(kw in text_lower for kw in button_keywords):
|
||||
return UIElementType.BUTTON, 0.8
|
||||
|
||||
# Champs de texte
|
||||
if "recherche" in text_lower or "search" in text_lower:
|
||||
return UIElementType.TEXT_INPUT, 0.7
|
||||
|
||||
# Liens
|
||||
link_keywords = ["lien", "link", "voir", "view", "plus", "more"]
|
||||
if any(kw in text_lower for kw in link_keywords):
|
||||
return UIElementType.LINK, 0.6
|
||||
|
||||
# Par défaut: élément interactif générique
|
||||
return UIElementType.GENERIC_INTERACTIVE, 0.5
|
||||
|
||||
def _infer_role(self, text: str, element_type: UIElementType) -> str:
|
||||
"""Infère le rôle sémantique."""
|
||||
text_lower = text.lower()
|
||||
|
||||
# Rôles basés sur le texte
|
||||
if "valider" in text_lower or "submit" in text_lower:
|
||||
return "validate_action"
|
||||
elif "annuler" in text_lower or "cancel" in text_lower:
|
||||
return "cancel_action"
|
||||
elif "enregistrer" in text_lower or "save" in text_lower:
|
||||
return "save_action"
|
||||
elif "supprimer" in text_lower or "delete" in text_lower:
|
||||
return "delete_action"
|
||||
elif "recherche" in text_lower or "search" in text_lower:
|
||||
return "search_field"
|
||||
else:
|
||||
return "generic_action"
|
||||
|
||||
|
||||
class UIElementDetector:
|
||||
"""
|
||||
Détecteur principal d'éléments UI.
|
||||
Orchestre le pipeline complet: RegionProposer → ElementCharacterizer → ElementClassifier
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_manager: Optional[LLMManager] = None,
|
||||
image_embedder: Optional[Any] = None,
|
||||
text_embedder: Optional[Any] = None,
|
||||
logger: Optional[Logger] = None,
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Initialise le UIElementDetector.
|
||||
|
||||
Args:
|
||||
llm_manager: Gestionnaire LLM
|
||||
image_embedder: Embedder d'images
|
||||
text_embedder: Embedder de texte
|
||||
logger: Logger
|
||||
config: Configuration
|
||||
"""
|
||||
self.llm = llm_manager
|
||||
self.logger = logger
|
||||
self.config = config or {}
|
||||
|
||||
# Créer les composants du pipeline
|
||||
self.region_proposer = RegionProposer(
|
||||
llm_manager=llm_manager,
|
||||
logger=logger,
|
||||
config=self.config.get("region_proposer", {})
|
||||
)
|
||||
|
||||
self.characterizer = ElementCharacterizer(
|
||||
llm_manager=llm_manager,
|
||||
image_embedder=image_embedder,
|
||||
text_embedder=text_embedder,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
self.classifier = ElementClassifier(
|
||||
llm_manager=llm_manager,
|
||||
logger=logger,
|
||||
config=self.config.get("classifier", {})
|
||||
)
|
||||
|
||||
def detect_elements(
|
||||
self,
|
||||
screenshot: np.ndarray,
|
||||
window_info: WindowInfo,
|
||||
data_dir: str = "data"
|
||||
) -> List[UIElement]:
|
||||
"""
|
||||
Pipeline complet de détection d'éléments.
|
||||
|
||||
Args:
|
||||
screenshot: Screenshot numpy array
|
||||
window_info: Informations sur la fenêtre
|
||||
data_dir: Répertoire de données
|
||||
|
||||
Returns:
|
||||
Liste de UIElement détectés
|
||||
"""
|
||||
elements = []
|
||||
|
||||
try:
|
||||
# Étape 1: Proposer des régions
|
||||
if self.logger:
|
||||
self.logger.log_action({"action": "detection_started"})
|
||||
|
||||
regions = self.region_proposer.propose_regions(screenshot, window_info)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "regions_proposed",
|
||||
"count": len(regions)
|
||||
})
|
||||
|
||||
# Étape 2: Caractériser chaque région
|
||||
characterized = []
|
||||
for i, region in enumerate(regions):
|
||||
char_elem = self.characterizer.characterize(
|
||||
screenshot, region, window_info, data_dir
|
||||
)
|
||||
if char_elem:
|
||||
characterized.append(char_elem)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "elements_characterized",
|
||||
"count": len(characterized)
|
||||
})
|
||||
|
||||
# Étape 3: Classifier chaque élément
|
||||
for char_elem in characterized:
|
||||
try:
|
||||
element_type, role, confidence = self.classifier.classify(
|
||||
char_elem, window_info
|
||||
)
|
||||
|
||||
# Créer l'UIElement
|
||||
bbox = char_elem["bbox"]
|
||||
label = char_elem["text_raw"] or "Élément sans texte"
|
||||
|
||||
element_id = UIElement.generate_element_id(
|
||||
app_name=window_info.app_name,
|
||||
bbox=bbox,
|
||||
label=label
|
||||
)
|
||||
|
||||
element = UIElement(
|
||||
element_id=element_id,
|
||||
type=element_type,
|
||||
role=role,
|
||||
bbox=bbox,
|
||||
label=label,
|
||||
visual=VisualData(
|
||||
screenshot_path=char_elem["crop_path"],
|
||||
embedding_provider="openclip_ViT-B-32",
|
||||
embedding_vector_id=char_elem["image_embedding_path"]
|
||||
),
|
||||
text=TextData(
|
||||
raw=char_elem["text_raw"],
|
||||
normalized=char_elem["text_normalized"],
|
||||
embedding_provider="clip_text",
|
||||
embedding_vector_id=char_elem["text_embedding_path"]
|
||||
),
|
||||
properties=ElementProperties(
|
||||
is_clickable=(element_type in [
|
||||
UIElementType.BUTTON,
|
||||
UIElementType.LINK
|
||||
]),
|
||||
is_focusable=(element_type == UIElementType.TEXT_INPUT),
|
||||
is_dangerous=("supprimer" in label.lower() or "delete" in label.lower())
|
||||
),
|
||||
context=ElementContext(
|
||||
app_name=window_info.app_name,
|
||||
window_title=window_info.window_title,
|
||||
workflow_hint=None
|
||||
),
|
||||
tags=[],
|
||||
confidence=confidence,
|
||||
detection_method=char_elem["source"]
|
||||
)
|
||||
|
||||
elements.append(element)
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "element_classification_error",
|
||||
"error": str(e)
|
||||
})
|
||||
# Continuer avec les autres éléments
|
||||
continue
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "detection_completed",
|
||||
"elements_detected": len(elements)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "detection_pipeline_error",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return elements
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Tests basiques
|
||||
print("UIElementDetector - Tests basiques")
|
||||
print("=" * 50)
|
||||
|
||||
# Test BoundingBox
|
||||
print("\n1. Test BoundingBox:")
|
||||
bbox1 = BoundingBox(10, 10, 50, 50)
|
||||
bbox2 = BoundingBox(30, 30, 70, 70)
|
||||
print(f" BBox1: ({bbox1.x1},{bbox1.y1},{bbox1.x2},{bbox1.y2})")
|
||||
print(f" BBox2: ({bbox2.x1},{bbox2.y1},{bbox2.x2},{bbox2.y2})")
|
||||
print(f" IoU: {bbox1.iou(bbox2):.2f}")
|
||||
print(f" Area1: {bbox1.area()}")
|
||||
print(f" Center1: {bbox1.center()}")
|
||||
|
||||
# Test RegionProposer
|
||||
print("\n2. Test RegionProposer:")
|
||||
proposer = RegionProposer()
|
||||
print(f" RegionProposer créé")
|
||||
print(f" use_text_detection: {proposer.use_text_detection}")
|
||||
print(f" use_rectangle_detection: {proposer.use_rectangle_detection}")
|
||||
|
||||
print("\n✓ Tests basiques réussis!")
|
||||
827
geniusia2/core/ui_element_models.py
Normal file
827
geniusia2/core/ui_element_models.py
Normal file
@@ -0,0 +1,827 @@
|
||||
"""
|
||||
Modèles de données pour la détection d'éléments UI et l'état d'écran enrichi.
|
||||
Implémente les structures UIElement et EnrichedScreenState pour le système RPA Vision V2.
|
||||
|
||||
Phase 1 - Mode Light: Structures de base avec compatibilité arrière complète.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from enum import Enum
|
||||
import json
|
||||
import hashlib
|
||||
import numpy as np
|
||||
|
||||
|
||||
class UIElementType(Enum):
|
||||
"""Types d'éléments UI supportés."""
|
||||
BUTTON = "button"
|
||||
TEXT_INPUT = "text_input"
|
||||
DROPDOWN = "dropdown"
|
||||
TAB = "tab"
|
||||
CHECKBOX = "checkbox"
|
||||
RADIO_BUTTON = "radio_button"
|
||||
LINK = "link"
|
||||
GENERIC_INTERACTIVE = "generic_interactive"
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisualData:
|
||||
"""Données visuelles d'un élément UI."""
|
||||
screenshot_path: str
|
||||
embedding_provider: str # ex: "openclip_ViT-B-32"
|
||||
embedding_vector_id: str # chemin vers le fichier .npy
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire pour sérialisation JSON."""
|
||||
return {
|
||||
"screenshot_path": self.screenshot_path,
|
||||
"embedding": {
|
||||
"provider": self.embedding_provider,
|
||||
"vector_id": self.embedding_vector_id
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'VisualData':
|
||||
"""Crée une instance depuis un dictionnaire."""
|
||||
if "embedding" in data:
|
||||
# Nouveau format
|
||||
return cls(
|
||||
screenshot_path=data["screenshot_path"],
|
||||
embedding_provider=data["embedding"]["provider"],
|
||||
embedding_vector_id=data["embedding"]["vector_id"]
|
||||
)
|
||||
else:
|
||||
# Format legacy
|
||||
return cls(
|
||||
screenshot_path=data["screenshot_path"],
|
||||
embedding_provider=data.get("embedding_provider", ""),
|
||||
embedding_vector_id=data.get("embedding_vector_id", "")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextData:
|
||||
"""Données textuelles d'un élément UI."""
|
||||
raw: str
|
||||
normalized: str
|
||||
embedding_provider: str # ex: "clip_text"
|
||||
embedding_vector_id: str # chemin vers le fichier .npy
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire pour sérialisation JSON."""
|
||||
return {
|
||||
"raw": self.raw,
|
||||
"normalized": self.normalized,
|
||||
"embedding": {
|
||||
"provider": self.embedding_provider,
|
||||
"vector_id": self.embedding_vector_id
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'TextData':
|
||||
"""Crée une instance depuis un dictionnaire."""
|
||||
if "embedding" in data:
|
||||
# Nouveau format
|
||||
return cls(
|
||||
raw=data["raw"],
|
||||
normalized=data["normalized"],
|
||||
embedding_provider=data["embedding"]["provider"],
|
||||
embedding_vector_id=data["embedding"]["vector_id"]
|
||||
)
|
||||
else:
|
||||
# Format legacy
|
||||
return cls(
|
||||
raw=data.get("raw", ""),
|
||||
normalized=data.get("normalized", ""),
|
||||
embedding_provider=data.get("embedding_provider", ""),
|
||||
embedding_vector_id=data.get("embedding_vector_id", "")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ElementProperties:
|
||||
"""Propriétés d'un élément UI."""
|
||||
is_clickable: bool = False
|
||||
is_focusable: bool = False
|
||||
is_dangerous: bool = False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire pour sérialisation JSON."""
|
||||
return {
|
||||
"is_clickable": self.is_clickable,
|
||||
"is_focusable": self.is_focusable,
|
||||
"is_dangerous": self.is_dangerous
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ElementProperties':
|
||||
"""Crée une instance depuis un dictionnaire."""
|
||||
return cls(
|
||||
is_clickable=data.get("is_clickable", False),
|
||||
is_focusable=data.get("is_focusable", False),
|
||||
is_dangerous=data.get("is_dangerous", False)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ElementContext:
|
||||
"""Contexte d'un élément UI."""
|
||||
app_name: str
|
||||
window_title: str
|
||||
workflow_hint: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire pour sérialisation JSON."""
|
||||
return {
|
||||
"app_name": self.app_name,
|
||||
"window_title": self.window_title,
|
||||
"workflow_hint": self.workflow_hint
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ElementContext':
|
||||
"""Crée une instance depuis un dictionnaire."""
|
||||
return cls(
|
||||
app_name=data["app_name"],
|
||||
window_title=data["window_title"],
|
||||
workflow_hint=data.get("workflow_hint")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UIElement:
|
||||
"""
|
||||
Représente un élément d'interface utilisateur détecté.
|
||||
|
||||
Attributes:
|
||||
element_id: Identifiant stable basé sur hash(app_name + center_bbox + label_normalized)
|
||||
type: Type d'élément (button, text_input, etc.)
|
||||
role: Rôle sémantique (validate_invoice, search_field, etc.)
|
||||
bbox: Bounding box (x1, y1, x2, y2)
|
||||
label: Texte visible de l'élément
|
||||
visual: Données visuelles (screenshot, embedding)
|
||||
text: Données textuelles (raw, normalized, embedding)
|
||||
properties: Propriétés (is_clickable, is_focusable, is_dangerous)
|
||||
context: Contexte (app_name, window_title, workflow_hint)
|
||||
tags: Tags additionnels
|
||||
confidence: Score de confiance de la détection (0.0-1.0)
|
||||
detection_method: Méthode de détection utilisée
|
||||
"""
|
||||
element_id: str
|
||||
type: UIElementType
|
||||
role: str
|
||||
bbox: Tuple[int, int, int, int] # (x1, y1, x2, y2)
|
||||
label: str
|
||||
visual: VisualData
|
||||
text: TextData
|
||||
properties: ElementProperties
|
||||
context: ElementContext
|
||||
tags: List[str] = field(default_factory=list)
|
||||
confidence: float = 1.0
|
||||
detection_method: str = "unknown"
|
||||
|
||||
@staticmethod
|
||||
def generate_element_id(app_name: str, bbox: Tuple[int, int, int, int], label: str) -> str:
|
||||
"""
|
||||
Génère un identifiant stable pour un élément UI.
|
||||
|
||||
Args:
|
||||
app_name: Nom de l'application
|
||||
bbox: Bounding box (x1, y1, x2, y2)
|
||||
label: Label de l'élément
|
||||
|
||||
Returns:
|
||||
Identifiant stable basé sur hash
|
||||
"""
|
||||
# Calculer le centre de la bbox
|
||||
center_x = (bbox[0] + bbox[2]) // 2
|
||||
center_y = (bbox[1] + bbox[3]) // 2
|
||||
|
||||
# Normaliser le label (lowercase, strip whitespace)
|
||||
label_normalized = label.lower().strip()
|
||||
|
||||
# Créer la chaîne à hasher
|
||||
hash_input = f"{app_name}_{center_x}_{center_y}_{label_normalized}"
|
||||
|
||||
# Générer le hash
|
||||
hash_obj = hashlib.sha256(hash_input.encode('utf-8'))
|
||||
hash_hex = hash_obj.hexdigest()[:16] # Prendre les 16 premiers caractères
|
||||
|
||||
return f"el_{hash_hex}"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire pour sérialisation JSON."""
|
||||
return {
|
||||
"schema_version": "uielement_v1",
|
||||
"element_id": self.element_id,
|
||||
"type": self.type.value,
|
||||
"role": self.role,
|
||||
"bbox": list(self.bbox),
|
||||
"label": self.label,
|
||||
"confidence": float(self.confidence),
|
||||
"detection_method": self.detection_method,
|
||||
"visual": self.visual.to_dict(),
|
||||
"text": self.text.to_dict(),
|
||||
"properties": self.properties.to_dict(),
|
||||
"context": self.context.to_dict(),
|
||||
"tags": self.tags
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'UIElement':
|
||||
"""Crée une instance depuis un dictionnaire."""
|
||||
# Gérer la compatibilité avec différentes versions de schéma
|
||||
schema_version = data.get("schema_version", "uielement_v1")
|
||||
|
||||
# Parser le type
|
||||
element_type = UIElementType(data["type"])
|
||||
|
||||
# Reconstruire les sous-structures
|
||||
visual = VisualData.from_dict(data["visual"])
|
||||
text = TextData.from_dict(data["text"])
|
||||
properties = ElementProperties.from_dict(data["properties"])
|
||||
context = ElementContext.from_dict(data["context"])
|
||||
|
||||
return cls(
|
||||
element_id=data["element_id"],
|
||||
type=element_type,
|
||||
role=data["role"],
|
||||
bbox=tuple(data["bbox"]),
|
||||
label=data["label"],
|
||||
visual=visual,
|
||||
text=text,
|
||||
properties=properties,
|
||||
context=context,
|
||||
tags=data.get("tags", []),
|
||||
confidence=data.get("confidence", 1.0),
|
||||
detection_method=data.get("detection_method", "unknown")
|
||||
)
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""Sérialise en JSON."""
|
||||
return json.dumps(self.to_dict(), indent=2, ensure_ascii=False)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_str: str) -> 'UIElement':
|
||||
"""Désérialise depuis JSON."""
|
||||
data = json.loads(json_str)
|
||||
return cls.from_dict(data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Tests basiques
|
||||
print("Test des modèles UIElement")
|
||||
print("=" * 50)
|
||||
|
||||
# Test génération d'element_id
|
||||
print("\n1. Test génération d'element_id:")
|
||||
element_id = UIElement.generate_element_id(
|
||||
app_name="test_app",
|
||||
bbox=(100, 200, 300, 250),
|
||||
label="Valider"
|
||||
)
|
||||
print(f" Element ID: {element_id}")
|
||||
|
||||
# Test création d'un UIElement
|
||||
print("\n2. Test création UIElement:")
|
||||
element = UIElement(
|
||||
element_id=element_id,
|
||||
type=UIElementType.BUTTON,
|
||||
role="validate_action",
|
||||
bbox=(100, 200, 300, 250),
|
||||
label="Valider",
|
||||
visual=VisualData(
|
||||
screenshot_path="data/elements/el_001.png",
|
||||
embedding_provider="openclip_ViT-B-32",
|
||||
embedding_vector_id="data/embeddings/el_001.npy"
|
||||
),
|
||||
text=TextData(
|
||||
raw="Valider",
|
||||
normalized="valider",
|
||||
embedding_provider="clip_text",
|
||||
embedding_vector_id="data/embeddings/el_001_text.npy"
|
||||
),
|
||||
properties=ElementProperties(
|
||||
is_clickable=True,
|
||||
is_focusable=True,
|
||||
is_dangerous=False
|
||||
),
|
||||
context=ElementContext(
|
||||
app_name="test_app",
|
||||
window_title="Test Window",
|
||||
workflow_hint="WF_test"
|
||||
),
|
||||
tags=["primary_action"],
|
||||
confidence=0.95,
|
||||
detection_method="heuristic_rectangle"
|
||||
)
|
||||
|
||||
print(f" Element ID: {element.element_id}")
|
||||
print(f" Type: {element.type.value}")
|
||||
print(f" Role: {element.role}")
|
||||
print(f" Label: {element.label}")
|
||||
print(f" Confidence: {element.confidence}")
|
||||
|
||||
# Test sérialisation
|
||||
print("\n3. Test sérialisation JSON:")
|
||||
json_str = element.to_json()
|
||||
print(f" JSON length: {len(json_str)} chars")
|
||||
print(f" Schema version: uielement_v1")
|
||||
|
||||
# Test désérialisation
|
||||
print("\n4. Test désérialisation:")
|
||||
element_restored = UIElement.from_json(json_str)
|
||||
print(f" Restored element_id: {element_restored.element_id}")
|
||||
print(f" Restored type: {element_restored.type.value}")
|
||||
print(f" Restored label: {element_restored.label}")
|
||||
|
||||
# Test stabilité de l'ID
|
||||
print("\n5. Test stabilité de l'element_id:")
|
||||
element_id_2 = UIElement.generate_element_id(
|
||||
app_name="test_app",
|
||||
bbox=(100, 200, 300, 250),
|
||||
label="Valider"
|
||||
)
|
||||
print(f" ID 1: {element_id}")
|
||||
print(f" ID 2: {element_id_2}")
|
||||
print(f" IDs identiques: {element_id == element_id_2}")
|
||||
|
||||
print("\n✓ Tous les tests basiques réussis!")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# EnrichedScreenState and related structures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class WindowInfo:
|
||||
"""Informations sur la fenêtre active."""
|
||||
app_name: str
|
||||
window_title: str
|
||||
screen_resolution: Tuple[int, int] # (width, height)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire pour sérialisation JSON."""
|
||||
return {
|
||||
"app_name": self.app_name,
|
||||
"window_title": self.window_title,
|
||||
"screen_resolution": list(self.screen_resolution)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'WindowInfo':
|
||||
"""Crée une instance depuis un dictionnaire."""
|
||||
return cls(
|
||||
app_name=data["app_name"],
|
||||
window_title=data["window_title"],
|
||||
screen_resolution=tuple(data["screen_resolution"])
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RawData:
|
||||
"""Données brutes de capture d'écran."""
|
||||
screenshot_path: str
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire pour sérialisation JSON."""
|
||||
return {
|
||||
"screenshot_path": self.screenshot_path
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'RawData':
|
||||
"""Crée une instance depuis un dictionnaire."""
|
||||
return cls(screenshot_path=data["screenshot_path"])
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerceptionData:
|
||||
"""Données de perception (texte détecté, OCR, etc.)."""
|
||||
detected_text: List[str] = field(default_factory=list)
|
||||
ocr_results: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire pour sérialisation JSON."""
|
||||
return {
|
||||
"detected_text": self.detected_text,
|
||||
"ocr_results": self.ocr_results
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'PerceptionData':
|
||||
"""Crée une instance depuis un dictionnaire."""
|
||||
return cls(
|
||||
detected_text=data.get("detected_text", []),
|
||||
ocr_results=data.get("ocr_results")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComponentInfo:
|
||||
"""Informations sur une composante d'embedding."""
|
||||
provider: str
|
||||
vector_id: str
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire pour sérialisation JSON."""
|
||||
return {
|
||||
"provider": self.provider,
|
||||
"vector_id": self.vector_id
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ComponentInfo':
|
||||
"""Crée une instance depuis un dictionnaire."""
|
||||
return cls(
|
||||
provider=data["provider"],
|
||||
vector_id=data["vector_id"]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingComponents:
|
||||
"""Composantes individuelles d'un state embedding multi-modal."""
|
||||
image_embedding: Optional[ComponentInfo] = None
|
||||
text_embedding: Optional[ComponentInfo] = None
|
||||
title_embedding: Optional[ComponentInfo] = None
|
||||
ui_embedding: Optional[ComponentInfo] = None
|
||||
context_embedding: Optional[ComponentInfo] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire pour sérialisation JSON."""
|
||||
result = {}
|
||||
if self.image_embedding:
|
||||
result["image_embedding"] = self.image_embedding.to_dict()
|
||||
if self.text_embedding:
|
||||
result["text_embedding"] = self.text_embedding.to_dict()
|
||||
if self.title_embedding:
|
||||
result["title_embedding"] = self.title_embedding.to_dict()
|
||||
if self.ui_embedding:
|
||||
result["ui_embedding"] = self.ui_embedding.to_dict()
|
||||
if self.context_embedding:
|
||||
result["context_embedding"] = self.context_embedding.to_dict()
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'EmbeddingComponents':
|
||||
"""Crée une instance depuis un dictionnaire."""
|
||||
return cls(
|
||||
image_embedding=ComponentInfo.from_dict(data["image_embedding"]) if "image_embedding" in data else None,
|
||||
text_embedding=ComponentInfo.from_dict(data["text_embedding"]) if "text_embedding" in data else None,
|
||||
title_embedding=ComponentInfo.from_dict(data["title_embedding"]) if "title_embedding" in data else None,
|
||||
ui_embedding=ComponentInfo.from_dict(data["ui_embedding"]) if "ui_embedding" in data else None,
|
||||
context_embedding=ComponentInfo.from_dict(data["context_embedding"]) if "context_embedding" in data else None
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateEmbedding:
|
||||
"""Embedding d'état unifié (multi-modal ou simple)."""
|
||||
provider: str
|
||||
vector_id: str
|
||||
components: Optional[EmbeddingComponents] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire pour sérialisation JSON."""
|
||||
result = {
|
||||
"provider": self.provider,
|
||||
"vector_id": self.vector_id
|
||||
}
|
||||
if self.components:
|
||||
result["components"] = self.components.to_dict()
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'StateEmbedding':
|
||||
"""Crée une instance depuis un dictionnaire."""
|
||||
components = None
|
||||
if "components" in data and data["components"]:
|
||||
components = EmbeddingComponents.from_dict(data["components"])
|
||||
|
||||
return cls(
|
||||
provider=data["provider"],
|
||||
vector_id=data["vector_id"],
|
||||
components=components
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextData:
|
||||
"""Données de contexte workflow."""
|
||||
current_workflow_candidate: Optional[str] = None
|
||||
tags: List[str] = field(default_factory=list)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire pour sérialisation JSON."""
|
||||
return {
|
||||
"current_workflow_candidate": self.current_workflow_candidate,
|
||||
"tags": self.tags,
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ContextData':
|
||||
"""Crée une instance depuis un dictionnaire."""
|
||||
return cls(
|
||||
current_workflow_candidate=data.get("current_workflow_candidate"),
|
||||
tags=data.get("tags", []),
|
||||
metadata=data.get("metadata", {})
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnrichedScreenState:
|
||||
"""
|
||||
ScreenState enrichi avec éléments d'UI et embedding multi-modal.
|
||||
|
||||
Attributes:
|
||||
screen_state_id: Identifiant unique de l'état d'écran
|
||||
timestamp: Horodatage de la capture
|
||||
session_id: Identifiant de session
|
||||
window: Informations sur la fenêtre
|
||||
raw: Données brutes (screenshot_path)
|
||||
perception: Données de perception (texte détecté)
|
||||
ui_elements: Liste des éléments UI détectés
|
||||
state_embedding: Embedding d'état unifié
|
||||
context: Contexte workflow
|
||||
mode: Mode de traitement ("light", "enriched", "complete")
|
||||
processing_metadata: Métadonnées de traitement (optionnel)
|
||||
"""
|
||||
screen_state_id: str
|
||||
timestamp: datetime
|
||||
session_id: str
|
||||
window: WindowInfo
|
||||
raw: RawData
|
||||
perception: PerceptionData
|
||||
ui_elements: List[UIElement]
|
||||
state_embedding: StateEmbedding
|
||||
context: ContextData
|
||||
mode: str = "light"
|
||||
processing_metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertit en dictionnaire pour sérialisation JSON."""
|
||||
result = {
|
||||
"schema_version": "screenstate_v1",
|
||||
"mode": self.mode,
|
||||
"screen_state_id": self.screen_state_id,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"session_id": self.session_id,
|
||||
"window": self.window.to_dict(),
|
||||
"raw": self.raw.to_dict(),
|
||||
"perception": self.perception.to_dict(),
|
||||
"ui_elements": [elem.to_dict() for elem in self.ui_elements],
|
||||
"state_embedding": self.state_embedding.to_dict(),
|
||||
"context": self.context.to_dict()
|
||||
}
|
||||
|
||||
if self.processing_metadata:
|
||||
result["processing_metadata"] = self.processing_metadata
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'EnrichedScreenState':
|
||||
"""Crée une instance depuis un dictionnaire."""
|
||||
# Gérer la compatibilité avec différentes versions de schéma
|
||||
schema_version = data.get("schema_version", "screenstate_v1")
|
||||
|
||||
# Parser le timestamp
|
||||
timestamp = datetime.fromisoformat(data["timestamp"])
|
||||
|
||||
# Reconstruire les sous-structures
|
||||
window = WindowInfo.from_dict(data["window"])
|
||||
raw = RawData.from_dict(data["raw"])
|
||||
perception = PerceptionData.from_dict(data["perception"])
|
||||
|
||||
# Reconstruire les UI elements
|
||||
ui_elements = [UIElement.from_dict(elem_data) for elem_data in data.get("ui_elements", [])]
|
||||
|
||||
# Reconstruire le state embedding
|
||||
state_embedding = StateEmbedding.from_dict(data["state_embedding"])
|
||||
|
||||
# Reconstruire le contexte
|
||||
context = ContextData.from_dict(data["context"])
|
||||
|
||||
return cls(
|
||||
screen_state_id=data["screen_state_id"],
|
||||
timestamp=timestamp,
|
||||
session_id=data["session_id"],
|
||||
window=window,
|
||||
raw=raw,
|
||||
perception=perception,
|
||||
ui_elements=ui_elements,
|
||||
state_embedding=state_embedding,
|
||||
context=context,
|
||||
mode=data.get("mode", "light"),
|
||||
processing_metadata=data.get("processing_metadata")
|
||||
)
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""Sérialise en JSON."""
|
||||
return json.dumps(self.to_dict(), indent=2, ensure_ascii=False)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_str: str) -> 'EnrichedScreenState':
|
||||
"""Désérialise depuis JSON."""
|
||||
data = json.loads(json_str)
|
||||
return cls.from_dict(data)
|
||||
|
||||
@classmethod
|
||||
def create_light_mode(
|
||||
cls,
|
||||
screen_state_id: str,
|
||||
session_id: str,
|
||||
window: WindowInfo,
|
||||
screenshot_path: str,
|
||||
image_embedding_provider: str,
|
||||
image_embedding_vector_id: str
|
||||
) -> 'EnrichedScreenState':
|
||||
"""
|
||||
Crée un EnrichedScreenState en mode light (compatibilité arrière).
|
||||
|
||||
Args:
|
||||
screen_state_id: ID de l'état d'écran
|
||||
session_id: ID de session
|
||||
window: Informations sur la fenêtre
|
||||
screenshot_path: Chemin vers le screenshot
|
||||
image_embedding_provider: Provider de l'embedding image
|
||||
image_embedding_vector_id: ID du vecteur d'embedding image
|
||||
|
||||
Returns:
|
||||
EnrichedScreenState en mode light
|
||||
"""
|
||||
return cls(
|
||||
screen_state_id=screen_state_id,
|
||||
timestamp=datetime.now(),
|
||||
session_id=session_id,
|
||||
window=window,
|
||||
raw=RawData(screenshot_path=screenshot_path),
|
||||
perception=PerceptionData(detected_text=[]),
|
||||
ui_elements=[], # Vide en mode light
|
||||
state_embedding=StateEmbedding(
|
||||
provider=image_embedding_provider,
|
||||
vector_id=image_embedding_vector_id,
|
||||
components=None # Pas de composantes en mode light
|
||||
),
|
||||
context=ContextData(),
|
||||
mode="light"
|
||||
)
|
||||
|
||||
|
||||
# Tests pour EnrichedScreenState
|
||||
def test_enriched_screen_state():
|
||||
"""Tests basiques pour EnrichedScreenState."""
|
||||
print("\n" + "=" * 50)
|
||||
print("Test des modèles EnrichedScreenState")
|
||||
print("=" * 50)
|
||||
|
||||
# Test mode light
|
||||
print("\n1. Test création en mode light:")
|
||||
window = WindowInfo(
|
||||
app_name="test_app",
|
||||
window_title="Test Window",
|
||||
screen_resolution=(1920, 1080)
|
||||
)
|
||||
|
||||
screen_state_light = EnrichedScreenState.create_light_mode(
|
||||
screen_state_id="screen_001",
|
||||
session_id="session_001",
|
||||
window=window,
|
||||
screenshot_path="data/screens/screen_001.png",
|
||||
image_embedding_provider="openclip_ViT-B-32",
|
||||
image_embedding_vector_id="data/embeddings/screen_001.npy"
|
||||
)
|
||||
|
||||
print(f" Screen State ID: {screen_state_light.screen_state_id}")
|
||||
print(f" Mode: {screen_state_light.mode}")
|
||||
print(f" UI Elements: {len(screen_state_light.ui_elements)}")
|
||||
print(f" State Embedding Provider: {screen_state_light.state_embedding.provider}")
|
||||
print(f" Has Components: {screen_state_light.state_embedding.components is not None}")
|
||||
|
||||
# Test sérialisation mode light
|
||||
print("\n2. Test sérialisation JSON (mode light):")
|
||||
json_str = screen_state_light.to_json()
|
||||
print(f" JSON length: {len(json_str)} chars")
|
||||
|
||||
# Test désérialisation mode light
|
||||
print("\n3. Test désérialisation (mode light):")
|
||||
screen_state_restored = EnrichedScreenState.from_json(json_str)
|
||||
print(f" Restored screen_state_id: {screen_state_restored.screen_state_id}")
|
||||
print(f" Restored mode: {screen_state_restored.mode}")
|
||||
print(f" Restored UI elements count: {len(screen_state_restored.ui_elements)}")
|
||||
|
||||
# Test mode enriched avec éléments
|
||||
print("\n4. Test création en mode enriched:")
|
||||
element = UIElement(
|
||||
element_id="el_test_001",
|
||||
type=UIElementType.BUTTON,
|
||||
role="validate_action",
|
||||
bbox=(100, 200, 300, 250),
|
||||
label="Valider",
|
||||
visual=VisualData(
|
||||
screenshot_path="data/elements/el_001.png",
|
||||
embedding_provider="openclip_ViT-B-32",
|
||||
embedding_vector_id="data/embeddings/el_001.npy"
|
||||
),
|
||||
text=TextData(
|
||||
raw="Valider",
|
||||
normalized="valider",
|
||||
embedding_provider="clip_text",
|
||||
embedding_vector_id="data/embeddings/el_001_text.npy"
|
||||
),
|
||||
properties=ElementProperties(is_clickable=True),
|
||||
context=ElementContext(
|
||||
app_name="test_app",
|
||||
window_title="Test Window"
|
||||
),
|
||||
tags=["primary_action"],
|
||||
confidence=0.95
|
||||
)
|
||||
|
||||
screen_state_enriched = EnrichedScreenState(
|
||||
screen_state_id="screen_002",
|
||||
timestamp=datetime.now(),
|
||||
session_id="session_001",
|
||||
window=window,
|
||||
raw=RawData(screenshot_path="data/screens/screen_002.png"),
|
||||
perception=PerceptionData(detected_text=["Valider", "Annuler"]),
|
||||
ui_elements=[element],
|
||||
state_embedding=StateEmbedding(
|
||||
provider="openclip_ViT-B-32",
|
||||
vector_id="data/embeddings/screen_002.npy",
|
||||
components=None
|
||||
),
|
||||
context=ContextData(tags=["test"]),
|
||||
mode="enriched"
|
||||
)
|
||||
|
||||
print(f" Screen State ID: {screen_state_enriched.screen_state_id}")
|
||||
print(f" Mode: {screen_state_enriched.mode}")
|
||||
print(f" UI Elements: {len(screen_state_enriched.ui_elements)}")
|
||||
print(f" Detected Text: {screen_state_enriched.perception.detected_text}")
|
||||
|
||||
# Test mode complete avec composantes
|
||||
print("\n5. Test création en mode complete:")
|
||||
components = EmbeddingComponents(
|
||||
image_embedding=ComponentInfo(
|
||||
provider="openclip_ViT-B-32",
|
||||
vector_id="data/embeddings/screen_003_image.npy"
|
||||
),
|
||||
text_embedding=ComponentInfo(
|
||||
provider="clip_text",
|
||||
vector_id="data/embeddings/screen_003_text.npy"
|
||||
),
|
||||
title_embedding=ComponentInfo(
|
||||
provider="clip_text",
|
||||
vector_id="data/embeddings/screen_003_title.npy"
|
||||
)
|
||||
)
|
||||
|
||||
screen_state_complete = EnrichedScreenState(
|
||||
screen_state_id="screen_003",
|
||||
timestamp=datetime.now(),
|
||||
session_id="session_001",
|
||||
window=window,
|
||||
raw=RawData(screenshot_path="data/screens/screen_003.png"),
|
||||
perception=PerceptionData(detected_text=["Valider", "Annuler"]),
|
||||
ui_elements=[element],
|
||||
state_embedding=StateEmbedding(
|
||||
provider="multimodal_fusion_v1",
|
||||
vector_id="data/embeddings/screen_003_fused.npy",
|
||||
components=components
|
||||
),
|
||||
context=ContextData(tags=["test"]),
|
||||
mode="complete"
|
||||
)
|
||||
|
||||
print(f" Screen State ID: {screen_state_complete.screen_state_id}")
|
||||
print(f" Mode: {screen_state_complete.mode}")
|
||||
print(f" State Embedding Provider: {screen_state_complete.state_embedding.provider}")
|
||||
print(f" Has Components: {screen_state_complete.state_embedding.components is not None}")
|
||||
|
||||
# Test sérialisation mode complete
|
||||
print("\n6. Test sérialisation JSON (mode complete):")
|
||||
json_str_complete = screen_state_complete.to_json()
|
||||
print(f" JSON length: {len(json_str_complete)} chars")
|
||||
|
||||
# Test désérialisation mode complete
|
||||
print("\n7. Test désérialisation (mode complete):")
|
||||
screen_state_complete_restored = EnrichedScreenState.from_json(json_str_complete)
|
||||
print(f" Restored screen_state_id: {screen_state_complete_restored.screen_state_id}")
|
||||
print(f" Restored mode: {screen_state_complete_restored.mode}")
|
||||
print(f" Restored components: {screen_state_complete_restored.state_embedding.components is not None}")
|
||||
|
||||
print("\n✓ Tous les tests EnrichedScreenState réussis!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Exécuter les tests
|
||||
test_enriched_screen_state()
|
||||
398
geniusia2/core/utils/INPUT_UTILS_README.md
Normal file
398
geniusia2/core/utils/INPUT_UTILS_README.md
Normal file
@@ -0,0 +1,398 @@
|
||||
# Input Utils - Documentation d'Implémentation
|
||||
|
||||
## Vue d'Ensemble
|
||||
|
||||
Le module `input_utils.py` fournit une interface complète pour exécuter des actions UI (souris, clavier) avec support du rollback et journalisation complète.
|
||||
|
||||
## Fonctionnalités Implémentées
|
||||
|
||||
### 1. Classe ActionType (Enum)
|
||||
|
||||
Énumération des types d'actions supportées :
|
||||
- ✅ `CLICK` - Clics souris
|
||||
- ✅ `TYPE` - Saisie texte
|
||||
- ✅ `SCROLL` - Défilement
|
||||
- ✅ `WAIT` - Attente temporisée
|
||||
- ✅ `MOVE` - Déplacement souris
|
||||
- ✅ `DRAG` - Glisser-déposer
|
||||
|
||||
### 2. Classe InputUtils
|
||||
|
||||
#### Initialisation
|
||||
- ✅ Configuration PyAutoGUI avec FAILSAFE activé
|
||||
- ✅ Pause configurable entre actions
|
||||
- ✅ Historique des actions pour rollback
|
||||
- ✅ Mapping AZERTY pour caractères spéciaux
|
||||
- ✅ Intégration avec Logger pour traçabilité
|
||||
|
||||
#### Méthodes d'Action
|
||||
|
||||
##### `click(x, y, button="left", clicks=1, interval=0.0)`
|
||||
- ✅ Effectue un clic souris à la position spécifiée
|
||||
- ✅ Support multi-boutons (left, right, middle)
|
||||
- ✅ Support clics multiples (double-clic, etc.)
|
||||
- ✅ Enregistre la position précédente pour rollback
|
||||
- ✅ Logging complet de l'opération
|
||||
|
||||
##### `type_text(text, interval=0.0, use_azerty=True)`
|
||||
- ✅ Saisit du texte au clavier
|
||||
- ✅ Support du mapping AZERTY
|
||||
- ✅ Intervalle configurable entre caractères
|
||||
- ✅ Enregistre la longueur pour rollback (suppression)
|
||||
- ✅ Logging du texte saisi
|
||||
|
||||
##### `scroll(direction, amount=3, x=None, y=None)`
|
||||
- ✅ Effectue un défilement vertical ou horizontal
|
||||
- ✅ Directions supportées : up, down, left, right
|
||||
- ✅ Quantité configurable
|
||||
- ✅ Position optionnelle
|
||||
- ✅ Action inversible pour rollback
|
||||
|
||||
##### `wait(duration)`
|
||||
- ✅ Attend pendant une durée spécifiée
|
||||
- ✅ Enregistré dans l'historique
|
||||
- ✅ Non inversible (pas de rollback)
|
||||
|
||||
##### `move(x, y, duration=0.2)`
|
||||
- ✅ Déplace la souris vers une position
|
||||
- ✅ Durée de mouvement configurable
|
||||
- ✅ Enregistre la position précédente
|
||||
- ✅ Inversible pour rollback
|
||||
|
||||
##### `drag(start_x, start_y, end_x, end_y, duration=0.5, button="left")`
|
||||
- ✅ Effectue un glisser-déposer
|
||||
- ✅ Support multi-boutons
|
||||
- ✅ Durée configurable
|
||||
- ✅ Complètement inversible
|
||||
|
||||
#### Méthodes de Rollback
|
||||
|
||||
##### `get_inverse_action(action)`
|
||||
- ✅ Génère l'action inverse pour rollback
|
||||
- ✅ Support pour tous les types d'actions inversibles
|
||||
- ✅ Retourne None pour actions non inversibles
|
||||
|
||||
**Actions inversibles :**
|
||||
- `CLICK` → Retour à la position précédente
|
||||
- `TYPE` → Suppression du texte (backspace × longueur)
|
||||
- `SCROLL` → Défilement inverse
|
||||
- `MOVE` → Retour à la position précédente
|
||||
- `DRAG` → Glissement inverse
|
||||
|
||||
**Actions non inversibles :**
|
||||
- `WAIT` → Pas d'inverse logique
|
||||
|
||||
##### `execute_inverse_action(action)`
|
||||
- ✅ Exécute l'action inverse générée
|
||||
- ✅ Gestion d'erreurs robuste
|
||||
- ✅ Retourne True/False selon le succès
|
||||
|
||||
#### Méthode d'Exécution Unifiée
|
||||
|
||||
##### `execute_action(action_data)`
|
||||
- ✅ Exécute une action depuis un dictionnaire
|
||||
- ✅ Calcul automatique du centre de bbox
|
||||
- ✅ Support de tous les types d'actions
|
||||
- ✅ Interface unifiée pour l'orchestrateur
|
||||
|
||||
**Format d'entrée :**
|
||||
```python
|
||||
action_data = {
|
||||
"action_type": "click", # ou "type", "scroll", etc.
|
||||
"bbox": (x, y, w, h), # Bounding box de l'élément
|
||||
"parameters": { # Paramètres spécifiques
|
||||
"button": "left",
|
||||
"text": "...",
|
||||
"direction": "down",
|
||||
# etc.
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Actions supportées :**
|
||||
- `click` - Clic simple
|
||||
- `double_click` - Double-clic
|
||||
- `right_click` - Clic droit
|
||||
- `type` - Saisie texte
|
||||
- `scroll` - Défilement
|
||||
- `wait` - Attente
|
||||
- `move` - Déplacement
|
||||
- `drag` - Glisser-déposer
|
||||
|
||||
#### Méthodes Utilitaires
|
||||
|
||||
##### `get_action_history(limit=50)`
|
||||
- ✅ Retourne l'historique des actions
|
||||
- ✅ Limite configurable
|
||||
- ✅ Utilisé pour rollback et analyse
|
||||
|
||||
##### `clear_history()`
|
||||
- ✅ Efface l'historique des actions
|
||||
- ✅ Logging de l'opération
|
||||
|
||||
##### `_convert_to_azerty(text)` (privée)
|
||||
- ✅ Convertit du texte pour clavier AZERTY
|
||||
- ✅ Mapping des caractères spéciaux
|
||||
- ✅ PyAutoGUI gère déjà le layout système
|
||||
|
||||
## Conformité aux Exigences
|
||||
|
||||
### Exigence 3.2
|
||||
> LORSQUE le Système_RPA fonctionne en Mode_Autopilot, LE Système_RPA DOIT exécuter automatiquement les actions suggérées
|
||||
|
||||
✅ **Implémenté** : La méthode `execute_action()` permet l'exécution automatique depuis l'orchestrateur.
|
||||
|
||||
### Exigence 3.4
|
||||
> LORSQU'une action automatisée échoue, LE Système_RPA DOIT effectuer un rollback des 3 dernières actions
|
||||
|
||||
✅ **Implémenté** :
|
||||
- `get_inverse_action()` génère les actions inverses
|
||||
- `execute_inverse_action()` exécute le rollback
|
||||
- `action_history` conserve toutes les actions
|
||||
|
||||
### Exigence 3.5
|
||||
> LORSQU'un rollback est effectué, LE Système_RPA DOIT journaliser l'événement
|
||||
|
||||
✅ **Implémenté** : Toutes les actions et leurs inverses sont loggées via `self.logger.log_action()`.
|
||||
|
||||
### Exigence 5.1
|
||||
> LE Système_RPA DOIT supporter les claviers AZERTY
|
||||
|
||||
✅ **Implémenté** :
|
||||
- Mapping AZERTY dans `azerty_mapping`
|
||||
- Méthode `_convert_to_azerty()`
|
||||
- Option `use_azerty` dans `type_text()`
|
||||
|
||||
## Sécurité
|
||||
|
||||
### FAILSAFE
|
||||
- ✅ `pyautogui.FAILSAFE = True` activé
|
||||
- ✅ Déplacer la souris dans un coin arrête toutes les opérations
|
||||
- ✅ Protection contre les boucles infinies
|
||||
|
||||
### Pause entre Actions
|
||||
- ✅ Configurable via `config["input"]["pause_between_actions"]`
|
||||
- ✅ Évite les actions trop rapides
|
||||
- ✅ Améliore la fiabilité
|
||||
|
||||
### Journalisation Complète
|
||||
- ✅ Toutes les actions loggées avec paramètres complets
|
||||
- ✅ Positions précédentes enregistrées
|
||||
- ✅ Timestamps pour traçabilité
|
||||
- ✅ Erreurs capturées et loggées
|
||||
|
||||
## Format de l'Historique
|
||||
|
||||
Chaque action dans `action_history` contient :
|
||||
|
||||
```python
|
||||
{
|
||||
"type": "click", # Type d'action
|
||||
"x": 450, # Coordonnées
|
||||
"y": 320,
|
||||
"button": "left", # Paramètres spécifiques
|
||||
"clicks": 1,
|
||||
"previous_position": (100, 200), # Pour rollback
|
||||
"timestamp": 1234567890.123 # Horodatage
|
||||
}
|
||||
```
|
||||
|
||||
## Utilisation
|
||||
|
||||
### Exemple Basique
|
||||
|
||||
```python
|
||||
from geniusia2.core.utils.input_utils import InputUtils
|
||||
from geniusia2.core.logger import Logger
|
||||
from geniusia2.core.config import get_config
|
||||
|
||||
# Initialiser
|
||||
logger = Logger()
|
||||
config = get_config()
|
||||
input_utils = InputUtils(logger, config)
|
||||
|
||||
# Effectuer un clic
|
||||
success = input_utils.click(450, 320, button="left")
|
||||
|
||||
# Saisir du texte
|
||||
success = input_utils.type_text("Bonjour!", use_azerty=True)
|
||||
|
||||
# Défiler
|
||||
success = input_utils.scroll("down", amount=3)
|
||||
|
||||
# Obtenir l'historique
|
||||
history = input_utils.get_action_history(limit=10)
|
||||
```
|
||||
|
||||
### Exemple avec Rollback
|
||||
|
||||
```python
|
||||
# Exécuter plusieurs actions
|
||||
input_utils.click(100, 100)
|
||||
input_utils.type_text("test")
|
||||
input_utils.click(200, 200)
|
||||
|
||||
# Obtenir les 3 dernières actions
|
||||
recent_actions = input_utils.get_action_history(limit=3)
|
||||
|
||||
# Rollback dans l'ordre inverse
|
||||
for action in reversed(recent_actions):
|
||||
input_utils.execute_inverse_action(action)
|
||||
```
|
||||
|
||||
### Exemple avec l'Orchestrateur
|
||||
|
||||
```python
|
||||
# L'orchestrateur prépare l'action
|
||||
action_data = {
|
||||
"action_type": "click",
|
||||
"bbox": (450, 320, 120, 40),
|
||||
"parameters": {"button": "left"}
|
||||
}
|
||||
|
||||
# Exécution unifiée
|
||||
success = input_utils.execute_action(action_data)
|
||||
|
||||
if not success:
|
||||
# Rollback des 3 dernières actions
|
||||
recent = input_utils.get_action_history(limit=3)
|
||||
for action in reversed(recent):
|
||||
input_utils.execute_inverse_action(action)
|
||||
```
|
||||
|
||||
## Intégration avec l'Orchestrateur
|
||||
|
||||
L'orchestrateur utilise `InputUtils` dans sa méthode `execute_action()` :
|
||||
|
||||
```python
|
||||
# Dans orchestrator.py
|
||||
def execute_action(self, decision: Dict[str, Any]):
|
||||
action = decision.get("action")
|
||||
|
||||
# Préparer les données d'action
|
||||
action_data = {
|
||||
"action_type": action.action_type,
|
||||
"bbox": action.bbox,
|
||||
"parameters": action.parameters
|
||||
}
|
||||
|
||||
# Exécuter via InputUtils
|
||||
success = self.input_utils.execute_action(action_data)
|
||||
|
||||
if not success:
|
||||
# Rollback si échec
|
||||
self.rollback_last_actions(count=3)
|
||||
```
|
||||
|
||||
## Dépendances
|
||||
|
||||
### Requises
|
||||
- `pyautogui` - Contrôle souris et clavier
|
||||
- `time` - Gestion des délais
|
||||
- `typing` - Annotations de types
|
||||
- `enum` - Énumération des types d'actions
|
||||
|
||||
### Internes
|
||||
- `Logger` - Journalisation chiffrée
|
||||
- `config` - Configuration globale
|
||||
|
||||
## Tests
|
||||
|
||||
Tests de validation dans `test_input_utils_simple.py` :
|
||||
- ✅ Vérification de la structure
|
||||
- ✅ Présence de toutes les méthodes
|
||||
- ✅ Support AZERTY
|
||||
- ✅ Support rollback
|
||||
- ✅ Sécurité (FAILSAFE, logging)
|
||||
- ✅ Conformité aux exigences
|
||||
|
||||
## Notes d'Implémentation
|
||||
|
||||
1. **PyAutoGUI** : Utilisé pour toutes les opérations bas niveau
|
||||
2. **FAILSAFE** : Toujours activé pour sécurité
|
||||
3. **Historique** : Conservé en mémoire, pas persisté
|
||||
4. **AZERTY** : PyAutoGUI détecte automatiquement le layout système
|
||||
5. **Rollback** : Limité aux actions inversibles logiquement
|
||||
6. **Logging** : Toutes les opérations sont tracées
|
||||
|
||||
## Limitations Connues
|
||||
|
||||
1. **Clics non inversibles** : Un clic ne peut pas être "annulé" logiquement
|
||||
2. **Attentes non inversibles** : Le temps ne peut pas être "rembobiné"
|
||||
3. **Dépendance système** : Nécessite un environnement graphique
|
||||
4. **Permissions** : Peut nécessiter des permissions spéciales sur certains OS
|
||||
|
||||
## Statut
|
||||
|
||||
✅ **Implémentation COMPLÈTE**
|
||||
|
||||
Toutes les fonctionnalités requises sont implémentées :
|
||||
- ✅ Actions souris (clic, déplacement, glisser-déposer)
|
||||
- ✅ Saisie texte avec support AZERTY
|
||||
- ✅ Défilement vertical et horizontal
|
||||
- ✅ Actions inverses pour rollback
|
||||
- ✅ Historique des actions
|
||||
- ✅ Logging complet
|
||||
- ✅ Gestion d'erreurs robuste
|
||||
- ✅ Interface unifiée pour l'orchestrateur
|
||||
|
||||
## Prochaines Étapes
|
||||
|
||||
L'InputUtils est maintenant prêt pour intégration avec :
|
||||
1. **Orchestrateur** - Exécution des actions en mode Autopilot
|
||||
2. **Moteur de Rejeu** - Rollback automatique en cas d'échec
|
||||
3. **Tests d'intégration** - Validation avec actions réelles
|
||||
|
||||
## Exemple Complet
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python3
|
||||
"""Exemple d'utilisation complète d'InputUtils"""
|
||||
|
||||
from geniusia2.core.utils.input_utils import InputUtils
|
||||
from geniusia2.core.logger import Logger
|
||||
from geniusia2.core.config import get_config
|
||||
|
||||
# Initialisation
|
||||
logger = Logger()
|
||||
config = get_config()
|
||||
input_utils = InputUtils(logger, config)
|
||||
|
||||
# Scénario : Remplir un formulaire
|
||||
print("Remplissage du formulaire...")
|
||||
|
||||
# 1. Cliquer sur le champ nom
|
||||
input_utils.click(300, 200)
|
||||
input_utils.wait(0.5)
|
||||
|
||||
# 2. Saisir le nom
|
||||
input_utils.type_text("Jean Dupont")
|
||||
input_utils.wait(0.3)
|
||||
|
||||
# 3. Cliquer sur le champ email
|
||||
input_utils.click(300, 250)
|
||||
input_utils.wait(0.5)
|
||||
|
||||
# 4. Saisir l'email
|
||||
input_utils.type_text("jean.dupont@example.com")
|
||||
input_utils.wait(0.3)
|
||||
|
||||
# 5. Défiler vers le bas
|
||||
input_utils.scroll("down", amount=3)
|
||||
input_utils.wait(0.5)
|
||||
|
||||
# 6. Cliquer sur le bouton valider
|
||||
success = input_utils.click(450, 400)
|
||||
|
||||
if not success:
|
||||
print("Échec du clic, rollback...")
|
||||
# Rollback des 3 dernières actions
|
||||
history = input_utils.get_action_history(limit=3)
|
||||
for action in reversed(history):
|
||||
input_utils.execute_inverse_action(action)
|
||||
else:
|
||||
print("Formulaire soumis avec succès!")
|
||||
|
||||
# Afficher l'historique
|
||||
print(f"\nActions effectuées: {len(input_utils.get_action_history())}")
|
||||
```
|
||||
199
geniusia2/core/utils/VISION_UTILS_README.md
Normal file
199
geniusia2/core/utils/VISION_UTILS_README.md
Normal file
@@ -0,0 +1,199 @@
|
||||
# Vision Utils - Documentation d'Implémentation
|
||||
|
||||
## Vue d'Ensemble
|
||||
|
||||
Le module `vision_utils.py` fournit une interface unifiée pour la détection d'éléments UI en utilisant plusieurs modèles de vision par ordinateur avec fallback automatique.
|
||||
|
||||
## Fonctionnalités Implémentées
|
||||
|
||||
### 1. Classe VisionUtils
|
||||
|
||||
#### Initialisation
|
||||
- ✅ Chargement configurable des modèles (OWL-v2, Grounding DINO, YOLO-World)
|
||||
- ✅ Configuration du modèle principal et ordre de fallback
|
||||
- ✅ Lazy loading des modèles pour optimiser la mémoire
|
||||
|
||||
#### Méthodes de Détection
|
||||
|
||||
##### `detect_with_owlv2(prompt, frame)`
|
||||
- ✅ Détection open-vocabulary avec OWL-v2
|
||||
- ✅ Support des prompts textuels pour décrire les éléments UI
|
||||
- ✅ Conversion automatique des bounding boxes
|
||||
- ✅ Extraction ROI pour chaque détection
|
||||
- ✅ Génération d'embeddings (placeholder pour intégration OpenCLIP future)
|
||||
- ✅ Gestion d'erreurs robuste
|
||||
|
||||
##### `detect_with_dino(prompt, frame)`
|
||||
- ✅ Interface préparée pour Grounding DINO
|
||||
- ✅ Stub fonctionnel en attente d'implémentation complète
|
||||
- ✅ Gestion d'erreurs
|
||||
|
||||
##### `detect_with_yolo(prompt, frame)`
|
||||
- ✅ Interface préparée pour YOLO-World
|
||||
- ✅ Stub fonctionnel en attente d'implémentation complète
|
||||
- ✅ Gestion d'erreurs
|
||||
|
||||
##### `detect(prompt, frame, model=None)`
|
||||
- ✅ Détection avec fallback automatique entre modèles
|
||||
- ✅ Essai séquentiel des modèles jusqu'à obtenir des détections
|
||||
- ✅ Logging détaillé des tentatives et échecs
|
||||
- ✅ Retour gracieux en cas d'échec de tous les modèles
|
||||
|
||||
#### Méthodes de Sélection et Filtrage
|
||||
|
||||
##### `select_best_detection(detections, context=None)`
|
||||
- ✅ Sélection intelligente basée sur plusieurs critères:
|
||||
- Score de confiance
|
||||
- Modèle source (bonus pour modèle principal)
|
||||
- Proximité avec position précédente (si contexte fourni)
|
||||
- Taille raisonnable de bounding box
|
||||
- ✅ Support du contexte pour améliorer la sélection
|
||||
|
||||
##### `filter_detections(detections, min_confidence, max_detections)`
|
||||
- ✅ Filtrage par seuil de confiance minimum
|
||||
- ✅ Tri par confiance décroissante
|
||||
- ✅ Limitation du nombre de détections retournées
|
||||
|
||||
##### `merge_overlapping_detections(detections, iou_threshold)`
|
||||
- ✅ Calcul d'IoU (Intersection over Union)
|
||||
- ✅ Fusion des détections chevauchantes
|
||||
- ✅ Conservation de la détection avec meilleure confiance
|
||||
|
||||
#### Méthodes Utilitaires
|
||||
|
||||
##### `get_detection_statistics(detections)`
|
||||
- ✅ Calcul de statistiques complètes:
|
||||
- Nombre de détections
|
||||
- Confiance moyenne, max, min, écart-type
|
||||
- Modèles utilisés
|
||||
- Distribution par modèle
|
||||
|
||||
##### `unload_models()`
|
||||
- ✅ Déchargement propre des modèles
|
||||
- ✅ Libération de la mémoire GPU/CPU
|
||||
- ✅ Garbage collection
|
||||
|
||||
## Conformité aux Exigences
|
||||
|
||||
### Exigence 1.1
|
||||
> LORSQUE le Système_RPA fonctionne en Mode_Shadow, LE Système_RPA DOIT capturer toutes les trames d'écran et coordonnées d'Élément_UI
|
||||
|
||||
✅ **Implémenté**: Les méthodes de détection acceptent des frames et retournent des objets Detection avec coordonnées bbox précises.
|
||||
|
||||
### Exigence 2.1
|
||||
> LORSQUE le Système_RPA fonctionne en Mode_Assisté, LE Système_RPA DOIT surligner les Élément_UI suggérés
|
||||
|
||||
✅ **Implémenté**: Les détections incluent bbox et roi_image pour permettre le surlignage visuel par la GUI.
|
||||
|
||||
### Exigence 4.1
|
||||
> LORSQU'une action automatisée est exécutée, LE Gestionnaire_Apprentissage DOIT calculer le delta entre l'emplacement prédit de l'Élément_UI et l'emplacement réel
|
||||
|
||||
✅ **Implémenté**: Les détections fournissent les coordonnées précises nécessaires au calcul de delta. La méthode `select_best_detection` supporte le contexte avec `previous_bbox` pour comparaison.
|
||||
|
||||
## Gestion d'Erreurs avec Fallback
|
||||
|
||||
Le système implémente une stratégie de fallback robuste:
|
||||
|
||||
1. **Tentative avec modèle principal** (configuré dans config.py)
|
||||
2. **Fallback automatique** vers les modèles alternatifs
|
||||
3. **Logging détaillé** de chaque tentative
|
||||
4. **Retour gracieux** avec liste vide si tous les modèles échouent
|
||||
|
||||
Exemple de séquence de fallback:
|
||||
```
|
||||
OWL-v2 (principal) → Grounding DINO → YOLO-World
|
||||
```
|
||||
|
||||
## Format des Détections
|
||||
|
||||
Chaque détection retournée est un objet `Detection` avec:
|
||||
- `label`: Nom de l'élément détecté
|
||||
- `confidence`: Score de confiance (0-1)
|
||||
- `bbox`: Bounding box (x, y, width, height)
|
||||
- `embedding`: Embedding visuel 512-d
|
||||
- `model_source`: Modèle ayant effectué la détection
|
||||
- `roi_image`: Image de la région d'intérêt
|
||||
- `metadata`: Métadonnées additionnelles
|
||||
|
||||
## Tests
|
||||
|
||||
Tests unitaires complets dans `tests/test_vision_utils.py`:
|
||||
- ✅ Initialisation
|
||||
- ✅ Filtrage des détections
|
||||
- ✅ Sélection de la meilleure détection
|
||||
- ✅ Fusion des détections chevauchantes
|
||||
- ✅ Calcul de statistiques
|
||||
- ✅ Gestion des cas limites (liste vide, détection unique)
|
||||
|
||||
Tous les tests passent avec succès.
|
||||
|
||||
## Dépendances
|
||||
|
||||
### Requises
|
||||
- numpy
|
||||
- logging (standard library)
|
||||
|
||||
### Optionnelles (pour détection complète)
|
||||
- transformers (pour OWL-v2)
|
||||
- torch (pour OWL-v2)
|
||||
- PIL/Pillow (pour traitement d'images)
|
||||
|
||||
### À implémenter
|
||||
- Grounding DINO (nécessite installation spéciale)
|
||||
- YOLO-World (nécessite ultralytics)
|
||||
|
||||
## Utilisation
|
||||
|
||||
```python
|
||||
from geniusia2.core.utils.vision_utils import VisionUtils
|
||||
import numpy as np
|
||||
|
||||
# Initialiser
|
||||
vision = VisionUtils()
|
||||
|
||||
# Capturer un frame (exemple)
|
||||
frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
|
||||
|
||||
# Détecter un élément
|
||||
detections = vision.detect("button valider", frame)
|
||||
|
||||
# Filtrer les détections
|
||||
filtered = vision.filter_detections(detections, min_confidence=0.5)
|
||||
|
||||
# Sélectionner la meilleure
|
||||
best = vision.select_best_detection(filtered)
|
||||
|
||||
if best:
|
||||
print(f"Élément trouvé: {best.label}")
|
||||
print(f"Confiance: {best.confidence:.2f}")
|
||||
print(f"Position: {best.bbox}")
|
||||
```
|
||||
|
||||
## Intégration Future
|
||||
|
||||
Le module est conçu pour s'intégrer avec:
|
||||
- **EmbeddingsManager**: Pour remplacer les embeddings placeholder par OpenCLIP
|
||||
- **Orchestrator**: Pour la boucle cognitive principale
|
||||
- **LearningManager**: Pour le calcul de confiance et adaptation
|
||||
- **GUI**: Pour l'affichage des détections et surlignage
|
||||
|
||||
## Notes d'Implémentation
|
||||
|
||||
1. **Lazy Loading**: Les modèles ne sont chargés qu'à la première utilisation pour économiser la mémoire
|
||||
2. **GPU Support**: Détection automatique et utilisation du GPU si disponible
|
||||
3. **Logging**: Logging détaillé à tous les niveaux pour debugging
|
||||
4. **Extensibilité**: Architecture permettant l'ajout facile de nouveaux modèles
|
||||
5. **Robustesse**: Gestion d'erreurs complète avec fallback automatique
|
||||
|
||||
## Statut
|
||||
|
||||
✅ **Tâche 5.1 COMPLÈTE**
|
||||
|
||||
Toutes les fonctionnalités requises sont implémentées:
|
||||
- ✅ Classe VisionUtils avec chargement modèles
|
||||
- ✅ Méthode detect_with_owlv2()
|
||||
- ✅ Méthode detect_with_dino() (stub)
|
||||
- ✅ Méthode detect_with_yolo() (stub)
|
||||
- ✅ Méthode select_best_detection()
|
||||
- ✅ Gestion d'erreurs avec fallback
|
||||
- ✅ Tests unitaires complets
|
||||
3
geniusia2/core/utils/__init__.py
Normal file
3
geniusia2/core/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Utilitaires pour le système RPA Vision V2
|
||||
"""
|
||||
BIN
geniusia2/core/utils/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
geniusia2/core/utils/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/utils/__pycache__/image_utils.cpython-312.pyc
Normal file
BIN
geniusia2/core/utils/__pycache__/image_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/utils/__pycache__/input_utils.cpython-312.pyc
Normal file
BIN
geniusia2/core/utils/__pycache__/input_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/utils/__pycache__/vision_utils.cpython-312.pyc
Normal file
BIN
geniusia2/core/utils/__pycache__/vision_utils.cpython-312.pyc
Normal file
Binary file not shown.
520
geniusia2/core/utils/image_utils.py
Normal file
520
geniusia2/core/utils/image_utils.py
Normal file
@@ -0,0 +1,520 @@
|
||||
"""
|
||||
Utilitaires pour la capture d'écran et le traitement d'images
|
||||
Fournit des fonctions pour capturer l'écran, extraire des ROI et dessiner des bounding boxes
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from typing import Tuple, Optional
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
|
||||
def capture_screen() -> np.ndarray:
|
||||
"""
|
||||
Capture l'écran complet et retourne l'image en format numpy array
|
||||
|
||||
Returns:
|
||||
Image de l'écran en format BGR (OpenCV standard)
|
||||
|
||||
Raises:
|
||||
RuntimeError: Si la capture d'écran échoue
|
||||
"""
|
||||
try:
|
||||
# Utiliser différentes méthodes selon le système d'exploitation
|
||||
system = platform.system()
|
||||
|
||||
if system == "Linux":
|
||||
# Sur Linux, utiliser scrot ou gnome-screenshot
|
||||
return _capture_screen_linux()
|
||||
elif system == "Windows":
|
||||
# Sur Windows, utiliser mss ou pyautogui
|
||||
return _capture_screen_windows()
|
||||
elif system == "Darwin": # macOS
|
||||
# Sur macOS, utiliser screencapture
|
||||
return _capture_screen_macos()
|
||||
else:
|
||||
raise RuntimeError(f"Système d'exploitation non supporté: {system}")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Échec de la capture d'écran: {str(e)}")
|
||||
|
||||
|
||||
def _capture_screen_linux() -> np.ndarray:
|
||||
"""
|
||||
Capture d'écran spécifique à Linux
|
||||
Utilise mss pour une capture rapide
|
||||
"""
|
||||
try:
|
||||
import mss
|
||||
import mss.tools
|
||||
|
||||
with mss.mss() as sct:
|
||||
# Capturer le moniteur principal
|
||||
monitor = sct.monitors[1]
|
||||
screenshot = sct.grab(monitor)
|
||||
|
||||
# Convertir en numpy array
|
||||
img = np.array(screenshot)
|
||||
|
||||
# Convertir BGRA vers BGR
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
|
||||
|
||||
return img
|
||||
except ImportError:
|
||||
# Fallback: utiliser PIL/Pillow
|
||||
return _capture_screen_pil()
|
||||
|
||||
|
||||
def _capture_screen_windows() -> np.ndarray:
|
||||
"""
|
||||
Capture d'écran spécifique à Windows
|
||||
"""
|
||||
try:
|
||||
import mss
|
||||
import mss.tools
|
||||
|
||||
with mss.mss() as sct:
|
||||
monitor = sct.monitors[1]
|
||||
screenshot = sct.grab(monitor)
|
||||
img = np.array(screenshot)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
|
||||
return img
|
||||
except ImportError:
|
||||
return _capture_screen_pil()
|
||||
|
||||
|
||||
def _capture_screen_macos() -> np.ndarray:
|
||||
"""
|
||||
Capture d'écran spécifique à macOS
|
||||
"""
|
||||
try:
|
||||
import mss
|
||||
import mss.tools
|
||||
|
||||
with mss.mss() as sct:
|
||||
monitor = sct.monitors[1]
|
||||
screenshot = sct.grab(monitor)
|
||||
img = np.array(screenshot)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
|
||||
return img
|
||||
except ImportError:
|
||||
return _capture_screen_pil()
|
||||
|
||||
|
||||
def _capture_screen_pil() -> np.ndarray:
|
||||
"""
|
||||
Capture d'écran en utilisant PIL/Pillow (fallback)
|
||||
"""
|
||||
try:
|
||||
from PIL import ImageGrab
|
||||
|
||||
screenshot = ImageGrab.grab()
|
||||
img = np.array(screenshot)
|
||||
|
||||
# Convertir RGB vers BGR (format OpenCV)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
||||
|
||||
return img
|
||||
except ImportError:
|
||||
raise RuntimeError("Aucune bibliothèque de capture d'écran disponible. "
|
||||
"Installez 'mss' ou 'Pillow'.")
|
||||
|
||||
|
||||
def get_active_window() -> str:
|
||||
"""
|
||||
Obtient le titre de la fenêtre active
|
||||
|
||||
Returns:
|
||||
Titre de la fenêtre active, ou chaîne vide si impossible à déterminer
|
||||
"""
|
||||
try:
|
||||
system = platform.system()
|
||||
|
||||
if system == "Linux":
|
||||
return _get_active_window_linux()
|
||||
elif system == "Windows":
|
||||
return _get_active_window_windows()
|
||||
elif system == "Darwin": # macOS
|
||||
return _get_active_window_macos()
|
||||
else:
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
print(f"Erreur lors de la récupération de la fenêtre active: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def _get_active_window_linux() -> str:
|
||||
"""
|
||||
Obtient la fenêtre active sur Linux avec plusieurs méthodes de fallback
|
||||
"""
|
||||
# Méthode 1: xdotool (le plus fiable)
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["xdotool", "getactivewindow", "getwindowname"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=1,
|
||||
check=False
|
||||
)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
return result.stdout.strip()
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
|
||||
# Méthode 2: xprop avec _NET_ACTIVE_WINDOW
|
||||
try:
|
||||
# Obtenir l'ID de la fenêtre active
|
||||
result = subprocess.run(
|
||||
["xprop", "-root", "_NET_ACTIVE_WINDOW"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=1,
|
||||
check=False
|
||||
)
|
||||
if result.returncode == 0:
|
||||
# Extraire l'ID de fenêtre (format: "_NET_ACTIVE_WINDOW(WINDOW): window id # 0x...")
|
||||
window_id = result.stdout.strip().split()[-1]
|
||||
|
||||
# Obtenir le nom de la fenêtre
|
||||
result2 = subprocess.run(
|
||||
["xprop", "-id", window_id, "WM_NAME"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=1,
|
||||
check=False
|
||||
)
|
||||
if result2.returncode == 0:
|
||||
# Format: WM_NAME(STRING) = "Titre de la fenêtre"
|
||||
name = result2.stdout.strip()
|
||||
if '=' in name:
|
||||
title = name.split('=', 1)[1].strip().strip('"')
|
||||
if title:
|
||||
return title
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
|
||||
# Méthode 3: wmctrl
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["wmctrl", "-l", "-p"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=1,
|
||||
check=False
|
||||
)
|
||||
if result.returncode == 0:
|
||||
lines = result.stdout.strip().split('\n')
|
||||
# Essayer de trouver la fenêtre active (première ligne comme approximation)
|
||||
if lines and lines[0]:
|
||||
parts = lines[0].split(None, 4)
|
||||
if len(parts) >= 5:
|
||||
return parts[4]
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
|
||||
# Méthode 4: Essayer avec Python Xlib (si disponible)
|
||||
try:
|
||||
from Xlib import X, display
|
||||
from Xlib.error import XError
|
||||
|
||||
d = display.Display()
|
||||
root = d.screen().root
|
||||
|
||||
# Obtenir la fenêtre active
|
||||
window_id = root.get_full_property(
|
||||
d.intern_atom('_NET_ACTIVE_WINDOW'),
|
||||
X.AnyPropertyType
|
||||
)
|
||||
|
||||
if window_id and window_id.value:
|
||||
active_window = d.create_resource_object('window', window_id.value[0])
|
||||
window_name = active_window.get_wm_name()
|
||||
if window_name:
|
||||
return window_name
|
||||
except (ImportError, XError, Exception):
|
||||
pass
|
||||
|
||||
return "Unknown Window"
|
||||
|
||||
|
||||
def _get_active_window_windows() -> str:
|
||||
"""
|
||||
Obtient la fenêtre active sur Windows
|
||||
"""
|
||||
try:
|
||||
import win32gui
|
||||
|
||||
hwnd = win32gui.GetForegroundWindow()
|
||||
return win32gui.GetWindowText(hwnd)
|
||||
except ImportError:
|
||||
# Fallback sans pywin32
|
||||
try:
|
||||
import ctypes
|
||||
|
||||
hwnd = ctypes.windll.user32.GetForegroundWindow()
|
||||
length = ctypes.windll.user32.GetWindowTextLengthW(hwnd)
|
||||
buff = ctypes.create_unicode_buffer(length + 1)
|
||||
ctypes.windll.user32.GetWindowTextW(hwnd, buff, length + 1)
|
||||
return buff.value
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _get_active_window_macos() -> str:
|
||||
"""
|
||||
Obtient la fenêtre active sur macOS
|
||||
"""
|
||||
try:
|
||||
script = '''
|
||||
tell application "System Events"
|
||||
set frontApp to name of first application process whose frontmost is true
|
||||
set frontWindow to name of front window of application process frontApp
|
||||
return frontApp & " - " & frontWindow
|
||||
end tell
|
||||
'''
|
||||
result = subprocess.run(
|
||||
["osascript", "-e", script],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=1
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip()
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def extract_roi(frame: np.ndarray, bbox: Tuple[int, int, int, int]) -> np.ndarray:
|
||||
"""
|
||||
Extrait une région d'intérêt (ROI) d'une image
|
||||
|
||||
Args:
|
||||
frame: Image source en format numpy array
|
||||
bbox: Bounding box (x, y, width, height) en pixels
|
||||
|
||||
Returns:
|
||||
Image de la région d'intérêt
|
||||
|
||||
Raises:
|
||||
ValueError: Si la bounding box est invalide
|
||||
"""
|
||||
x, y, w, h = bbox
|
||||
|
||||
# Valider les dimensions
|
||||
if w <= 0 or h <= 0:
|
||||
raise ValueError(f"Dimensions de bounding box invalides: width={w}, height={h}")
|
||||
|
||||
# Obtenir les dimensions de l'image
|
||||
img_height, img_width = frame.shape[:2]
|
||||
|
||||
# Limiter les coordonnées aux dimensions de l'image
|
||||
x = max(0, min(x, img_width - 1))
|
||||
y = max(0, min(y, img_height - 1))
|
||||
x2 = max(0, min(x + w, img_width))
|
||||
y2 = max(0, min(y + h, img_height))
|
||||
|
||||
# Extraire la ROI
|
||||
roi = frame[y:y2, x:x2]
|
||||
|
||||
# Vérifier que la ROI n'est pas vide
|
||||
if roi.size == 0:
|
||||
raise ValueError(f"ROI vide avec bbox={bbox}, image_size=({img_width}, {img_height})")
|
||||
|
||||
return roi
|
||||
|
||||
|
||||
def draw_bbox(frame: np.ndarray, bbox: Tuple[int, int, int, int],
|
||||
label: str = "", color: Tuple[int, int, int] = (0, 255, 0),
|
||||
thickness: int = 2) -> np.ndarray:
|
||||
"""
|
||||
Dessine une bounding box sur une image avec un label optionnel
|
||||
|
||||
Args:
|
||||
frame: Image sur laquelle dessiner
|
||||
bbox: Bounding box (x, y, width, height) en pixels
|
||||
label: Label à afficher au-dessus de la box (optionnel)
|
||||
color: Couleur BGR de la box (par défaut: vert)
|
||||
thickness: Épaisseur de la ligne en pixels
|
||||
|
||||
Returns:
|
||||
Image avec la bounding box dessinée (copie de l'original)
|
||||
"""
|
||||
# Créer une copie pour ne pas modifier l'original
|
||||
img = frame.copy()
|
||||
|
||||
x, y, w, h = bbox
|
||||
|
||||
# Dessiner le rectangle
|
||||
cv2.rectangle(img, (x, y), (x + w, y + h), color, thickness)
|
||||
|
||||
# Dessiner le label si fourni
|
||||
if label:
|
||||
# Calculer la taille du texte
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_scale = 0.6
|
||||
font_thickness = 2
|
||||
(text_width, text_height), baseline = cv2.getTextSize(
|
||||
label, font, font_scale, font_thickness
|
||||
)
|
||||
|
||||
# Dessiner un rectangle de fond pour le texte
|
||||
label_y = y - 10 if y - 10 > text_height else y + h + text_height + 10
|
||||
cv2.rectangle(
|
||||
img,
|
||||
(x, label_y - text_height - baseline),
|
||||
(x + text_width, label_y + baseline),
|
||||
color,
|
||||
-1 # Remplir
|
||||
)
|
||||
|
||||
# Dessiner le texte
|
||||
cv2.putText(
|
||||
img,
|
||||
label,
|
||||
(x, label_y),
|
||||
font,
|
||||
font_scale,
|
||||
(255, 255, 255), # Blanc
|
||||
font_thickness
|
||||
)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def resize_image(image: np.ndarray, max_width: int = 1920,
|
||||
max_height: int = 1080) -> np.ndarray:
|
||||
"""
|
||||
Redimensionne une image en conservant le ratio d'aspect
|
||||
|
||||
Args:
|
||||
image: Image à redimensionner
|
||||
max_width: Largeur maximale
|
||||
max_height: Hauteur maximale
|
||||
|
||||
Returns:
|
||||
Image redimensionnée
|
||||
"""
|
||||
height, width = image.shape[:2]
|
||||
|
||||
# Calculer le ratio de redimensionnement
|
||||
ratio = min(max_width / width, max_height / height)
|
||||
|
||||
# Si l'image est déjà plus petite, ne pas la redimensionner
|
||||
if ratio >= 1.0:
|
||||
return image
|
||||
|
||||
# Calculer les nouvelles dimensions
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
|
||||
# Redimensionner
|
||||
resized = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
||||
|
||||
return resized
|
||||
|
||||
|
||||
def save_image(image: np.ndarray, filepath: str) -> bool:
|
||||
"""
|
||||
Sauvegarde une image sur le disque
|
||||
|
||||
Args:
|
||||
image: Image à sauvegarder
|
||||
filepath: Chemin du fichier de destination
|
||||
|
||||
Returns:
|
||||
True si la sauvegarde a réussi, False sinon
|
||||
"""
|
||||
try:
|
||||
cv2.imwrite(filepath, image)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Erreur lors de la sauvegarde de l'image: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def load_image(filepath: str) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Charge une image depuis le disque
|
||||
|
||||
Args:
|
||||
filepath: Chemin du fichier image
|
||||
|
||||
Returns:
|
||||
Image en format numpy array, ou None si le chargement échoue
|
||||
"""
|
||||
try:
|
||||
image = cv2.imread(filepath)
|
||||
if image is None:
|
||||
print(f"Impossible de charger l'image: {filepath}")
|
||||
return image
|
||||
except Exception as e:
|
||||
print(f"Erreur lors du chargement de l'image: {e}")
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Tests basiques des utilitaires d'image
|
||||
print("Test des utilitaires d'image RPA Vision V2")
|
||||
print("=" * 50)
|
||||
|
||||
# Test 1: Capture d'écran
|
||||
print("\n1. Test capture_screen():")
|
||||
try:
|
||||
screen = capture_screen()
|
||||
print(f" ✓ Capture réussie: {screen.shape} (H x W x C)")
|
||||
print(f" Type: {screen.dtype}")
|
||||
except Exception as e:
|
||||
print(f" ✗ Échec: {e}")
|
||||
|
||||
# Test 2: Fenêtre active
|
||||
print("\n2. Test get_active_window():")
|
||||
window_title = get_active_window()
|
||||
if window_title:
|
||||
print(f" ✓ Fenêtre active: '{window_title}'")
|
||||
else:
|
||||
print(f" ⚠ Impossible de déterminer la fenêtre active")
|
||||
|
||||
# Test 3: Extraction ROI
|
||||
print("\n3. Test extract_roi():")
|
||||
try:
|
||||
# Créer une image de test
|
||||
test_img = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
test_img[100:200, 150:300] = [0, 255, 0] # Rectangle vert
|
||||
|
||||
# Extraire une ROI
|
||||
roi = extract_roi(test_img, (150, 100, 150, 100))
|
||||
print(f" ✓ ROI extraite: {roi.shape}")
|
||||
|
||||
# Test avec bbox invalide (devrait être limité)
|
||||
roi2 = extract_roi(test_img, (600, 400, 100, 100))
|
||||
print(f" ✓ ROI avec bbox hors limites: {roi2.shape}")
|
||||
except Exception as e:
|
||||
print(f" ✗ Échec: {e}")
|
||||
|
||||
# Test 4: Dessin de bounding box
|
||||
print("\n4. Test draw_bbox():")
|
||||
try:
|
||||
test_img = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
|
||||
# Dessiner plusieurs bounding boxes
|
||||
img_with_bbox = draw_bbox(test_img, (100, 100, 200, 150), "Bouton 1", (0, 255, 0))
|
||||
img_with_bbox = draw_bbox(img_with_bbox, (350, 200, 150, 100), "Bouton 2", (255, 0, 0))
|
||||
|
||||
print(f" ✓ Bounding boxes dessinées: {img_with_bbox.shape}")
|
||||
except Exception as e:
|
||||
print(f" ✗ Échec: {e}")
|
||||
|
||||
# Test 5: Redimensionnement
|
||||
print("\n5. Test resize_image():")
|
||||
try:
|
||||
large_img = np.zeros((2160, 3840, 3), dtype=np.uint8)
|
||||
resized = resize_image(large_img, max_width=1920, max_height=1080)
|
||||
print(f" ✓ Image redimensionnée: {large_img.shape} -> {resized.shape}")
|
||||
except Exception as e:
|
||||
print(f" ✗ Échec: {e}")
|
||||
|
||||
print("\n✓ Tests terminés!")
|
||||
608
geniusia2/core/utils/input_utils.py
Normal file
608
geniusia2/core/utils/input_utils.py
Normal file
@@ -0,0 +1,608 @@
|
||||
"""
|
||||
Utilitaires d'entrée pour exécuter des actions UI (souris, clavier, etc.).
|
||||
Support du clavier AZERTY et gestion du rollback d'actions.
|
||||
"""
|
||||
|
||||
import time
|
||||
import pyautogui
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from enum import Enum
|
||||
|
||||
from ..logger import Logger
|
||||
|
||||
|
||||
class ActionType(Enum):
|
||||
"""Types d'actions UI supportées."""
|
||||
CLICK = "click"
|
||||
TYPE = "type"
|
||||
SCROLL = "scroll"
|
||||
WAIT = "wait"
|
||||
MOVE = "move"
|
||||
DRAG = "drag"
|
||||
|
||||
|
||||
class InputUtils:
|
||||
"""
|
||||
Gestionnaire d'entrées utilisateur pour exécuter des actions UI.
|
||||
Support du clavier AZERTY et rollback d'actions.
|
||||
"""
|
||||
|
||||
def __init__(self, logger: Logger, config: Dict[str, Any]):
|
||||
"""
|
||||
Initialise les utilitaires d'entrée.
|
||||
|
||||
Args:
|
||||
logger: Logger pour journalisation
|
||||
config: Configuration globale
|
||||
"""
|
||||
self.logger = logger
|
||||
self.config = config
|
||||
|
||||
# Configuration PyAutoGUI
|
||||
pyautogui.FAILSAFE = True # Déplacer souris dans coin = arrêt
|
||||
pyautogui.PAUSE = config.get("input", {}).get("pause_between_actions", 0.1)
|
||||
|
||||
# Historique des actions pour rollback
|
||||
self.action_history = []
|
||||
|
||||
# Mapping AZERTY pour caractères spéciaux
|
||||
self.azerty_mapping = {
|
||||
'0': 'à',
|
||||
'1': '&',
|
||||
'2': 'é',
|
||||
'3': '"',
|
||||
'4': "'",
|
||||
'5': '(',
|
||||
'6': '-',
|
||||
'7': 'è',
|
||||
'8': '_',
|
||||
'9': 'ç',
|
||||
'.': ':',
|
||||
'/': '!',
|
||||
',': ';',
|
||||
';': ',',
|
||||
':': '.',
|
||||
'!': '/',
|
||||
'?': 'M', # Shift + ,
|
||||
}
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "input_utils_initialized",
|
||||
"failsafe": True,
|
||||
"pause": pyautogui.PAUSE
|
||||
})
|
||||
|
||||
def click(
|
||||
self,
|
||||
x: int,
|
||||
y: int,
|
||||
button: str = "left",
|
||||
clicks: int = 1,
|
||||
interval: float = 0.0
|
||||
) -> bool:
|
||||
"""
|
||||
Effectue un clic souris à la position spécifiée.
|
||||
|
||||
Args:
|
||||
x: Coordonnée X
|
||||
y: Coordonnée Y
|
||||
button: Bouton souris ("left", "right", "middle")
|
||||
clicks: Nombre de clics
|
||||
interval: Intervalle entre clics multiples
|
||||
|
||||
Returns:
|
||||
True si succès, False sinon
|
||||
"""
|
||||
try:
|
||||
# Enregistrer position actuelle pour rollback
|
||||
current_pos = pyautogui.position()
|
||||
|
||||
# Effectuer le clic
|
||||
pyautogui.click(x, y, clicks=clicks, interval=interval, button=button)
|
||||
|
||||
# Enregistrer dans l'historique
|
||||
action_record = {
|
||||
"type": ActionType.CLICK.value,
|
||||
"x": x,
|
||||
"y": y,
|
||||
"button": button,
|
||||
"clicks": clicks,
|
||||
"previous_position": current_pos,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
self.action_history.append(action_record)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "click_executed",
|
||||
**action_record
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "click_failed",
|
||||
"x": x,
|
||||
"y": y,
|
||||
"error": str(e)
|
||||
})
|
||||
return False
|
||||
|
||||
def type_text(
|
||||
self,
|
||||
text: str,
|
||||
interval: float = 0.0,
|
||||
use_azerty: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
Saisit du texte au clavier.
|
||||
|
||||
Args:
|
||||
text: Texte à saisir
|
||||
interval: Intervalle entre chaque caractère
|
||||
use_azerty: Utiliser le mapping AZERTY
|
||||
|
||||
Returns:
|
||||
True si succès, False sinon
|
||||
"""
|
||||
try:
|
||||
# Convertir pour AZERTY si nécessaire
|
||||
if use_azerty:
|
||||
converted_text = self._convert_to_azerty(text)
|
||||
else:
|
||||
converted_text = text
|
||||
|
||||
# Saisir le texte
|
||||
pyautogui.write(converted_text, interval=interval)
|
||||
|
||||
# Enregistrer dans l'historique
|
||||
action_record = {
|
||||
"type": ActionType.TYPE.value,
|
||||
"text": text,
|
||||
"converted_text": converted_text,
|
||||
"length": len(text),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
self.action_history.append(action_record)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "text_typed",
|
||||
**action_record
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "type_text_failed",
|
||||
"text": text[:50], # Limiter pour logs
|
||||
"error": str(e)
|
||||
})
|
||||
return False
|
||||
|
||||
def scroll(
|
||||
self,
|
||||
direction: str,
|
||||
amount: int = 3,
|
||||
x: Optional[int] = None,
|
||||
y: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Effectue un défilement.
|
||||
|
||||
Args:
|
||||
direction: Direction ("up", "down", "left", "right")
|
||||
amount: Quantité de défilement (nombre de "clics" de molette)
|
||||
x: Position X optionnelle
|
||||
y: Position Y optionnelle
|
||||
|
||||
Returns:
|
||||
True si succès, False sinon
|
||||
"""
|
||||
try:
|
||||
# Calculer le montant de défilement
|
||||
if direction in ["up", "right"]:
|
||||
scroll_amount = amount
|
||||
elif direction in ["down", "left"]:
|
||||
scroll_amount = -amount
|
||||
else:
|
||||
raise ValueError(f"Direction invalide: {direction}")
|
||||
|
||||
# Déplacer la souris si position spécifiée
|
||||
if x is not None and y is not None:
|
||||
pyautogui.moveTo(x, y)
|
||||
|
||||
# Effectuer le défilement
|
||||
if direction in ["up", "down"]:
|
||||
pyautogui.scroll(scroll_amount)
|
||||
else:
|
||||
pyautogui.hscroll(scroll_amount)
|
||||
|
||||
# Enregistrer dans l'historique
|
||||
action_record = {
|
||||
"type": ActionType.SCROLL.value,
|
||||
"direction": direction,
|
||||
"amount": amount,
|
||||
"x": x,
|
||||
"y": y,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
self.action_history.append(action_record)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "scroll_executed",
|
||||
**action_record
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "scroll_failed",
|
||||
"direction": direction,
|
||||
"amount": amount,
|
||||
"error": str(e)
|
||||
})
|
||||
return False
|
||||
|
||||
def wait(self, duration: float) -> bool:
|
||||
"""
|
||||
Attend pendant une durée spécifiée.
|
||||
|
||||
Args:
|
||||
duration: Durée en secondes
|
||||
|
||||
Returns:
|
||||
True
|
||||
"""
|
||||
try:
|
||||
time.sleep(duration)
|
||||
|
||||
action_record = {
|
||||
"type": ActionType.WAIT.value,
|
||||
"duration": duration,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
self.action_history.append(action_record)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "wait_executed",
|
||||
**action_record
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "wait_failed",
|
||||
"duration": duration,
|
||||
"error": str(e)
|
||||
})
|
||||
return False
|
||||
|
||||
def move(self, x: int, y: int, duration: float = 0.2) -> bool:
|
||||
"""
|
||||
Déplace la souris vers une position.
|
||||
|
||||
Args:
|
||||
x: Coordonnée X
|
||||
y: Coordonnée Y
|
||||
duration: Durée du mouvement en secondes
|
||||
|
||||
Returns:
|
||||
True si succès, False sinon
|
||||
"""
|
||||
try:
|
||||
current_pos = pyautogui.position()
|
||||
|
||||
pyautogui.moveTo(x, y, duration=duration)
|
||||
|
||||
action_record = {
|
||||
"type": ActionType.MOVE.value,
|
||||
"x": x,
|
||||
"y": y,
|
||||
"previous_position": current_pos,
|
||||
"duration": duration,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
self.action_history.append(action_record)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "move_executed",
|
||||
**action_record
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "move_failed",
|
||||
"x": x,
|
||||
"y": y,
|
||||
"error": str(e)
|
||||
})
|
||||
return False
|
||||
|
||||
def drag(
|
||||
self,
|
||||
start_x: int,
|
||||
start_y: int,
|
||||
end_x: int,
|
||||
end_y: int,
|
||||
duration: float = 0.5,
|
||||
button: str = "left"
|
||||
) -> bool:
|
||||
"""
|
||||
Effectue un glisser-déposer.
|
||||
|
||||
Args:
|
||||
start_x: X de départ
|
||||
start_y: Y de départ
|
||||
end_x: X d'arrivée
|
||||
end_y: Y d'arrivée
|
||||
duration: Durée du glissement
|
||||
button: Bouton souris
|
||||
|
||||
Returns:
|
||||
True si succès, False sinon
|
||||
"""
|
||||
try:
|
||||
current_pos = pyautogui.position()
|
||||
|
||||
pyautogui.moveTo(start_x, start_y)
|
||||
pyautogui.drag(end_x - start_x, end_y - start_y, duration=duration, button=button)
|
||||
|
||||
action_record = {
|
||||
"type": ActionType.DRAG.value,
|
||||
"start_x": start_x,
|
||||
"start_y": start_y,
|
||||
"end_x": end_x,
|
||||
"end_y": end_y,
|
||||
"previous_position": current_pos,
|
||||
"duration": duration,
|
||||
"button": button,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
self.action_history.append(action_record)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "drag_executed",
|
||||
**action_record
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "drag_failed",
|
||||
"start": (start_x, start_y),
|
||||
"end": (end_x, end_y),
|
||||
"error": str(e)
|
||||
})
|
||||
return False
|
||||
|
||||
def execute_inverse_action(self, action: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Exécute l'action inverse pour rollback.
|
||||
|
||||
Args:
|
||||
action: Action à inverser
|
||||
|
||||
Returns:
|
||||
True si succès, False sinon
|
||||
"""
|
||||
inverse = self.get_inverse_action(action)
|
||||
if not inverse:
|
||||
return False
|
||||
|
||||
action_type = inverse.get("type")
|
||||
|
||||
if action_type == ActionType.MOVE.value:
|
||||
return self.move(inverse["x"], inverse["y"], inverse.get("duration", 0.2))
|
||||
elif action_type == ActionType.SCROLL.value:
|
||||
return self.scroll(
|
||||
inverse["direction"],
|
||||
inverse["amount"],
|
||||
inverse.get("x"),
|
||||
inverse.get("y")
|
||||
)
|
||||
elif action_type == ActionType.DRAG.value:
|
||||
return self.drag(
|
||||
inverse["start_x"],
|
||||
inverse["start_y"],
|
||||
inverse["end_x"],
|
||||
inverse["end_y"],
|
||||
inverse.get("duration", 0.5),
|
||||
inverse.get("button", "left")
|
||||
)
|
||||
elif action_type == "press_key":
|
||||
# Exécuter les suppressions
|
||||
for _ in range(inverse.get("presses", 0)):
|
||||
pyautogui.press("backspace")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_inverse_action(self, action: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Génère l'action inverse pour rollback.
|
||||
|
||||
Args:
|
||||
action: Action à inverser
|
||||
|
||||
Returns:
|
||||
Action inverse ou None si non inversible
|
||||
"""
|
||||
action_type = action.get("type")
|
||||
|
||||
if action_type == ActionType.CLICK.value:
|
||||
# Un clic n'est pas vraiment inversible
|
||||
# On peut retourner à la position précédente
|
||||
prev_pos = action.get("previous_position")
|
||||
if prev_pos:
|
||||
return {
|
||||
"type": ActionType.MOVE.value,
|
||||
"x": prev_pos[0],
|
||||
"y": prev_pos[1],
|
||||
"duration": 0.2
|
||||
}
|
||||
|
||||
elif action_type == ActionType.TYPE.value:
|
||||
# Inverser la saisie = supprimer le texte
|
||||
text_length = action.get("length", 0)
|
||||
return {
|
||||
"type": "press_key",
|
||||
"key": "backspace",
|
||||
"presses": text_length
|
||||
}
|
||||
|
||||
elif action_type == ActionType.SCROLL.value:
|
||||
# Inverser le défilement
|
||||
direction = action.get("direction")
|
||||
amount = action.get("amount")
|
||||
inverse_direction = {
|
||||
"up": "down",
|
||||
"down": "up",
|
||||
"left": "right",
|
||||
"right": "left"
|
||||
}.get(direction)
|
||||
|
||||
return {
|
||||
"type": ActionType.SCROLL.value,
|
||||
"direction": inverse_direction,
|
||||
"amount": amount,
|
||||
"x": action.get("x"),
|
||||
"y": action.get("y")
|
||||
}
|
||||
|
||||
elif action_type == ActionType.MOVE.value:
|
||||
# Retourner à la position précédente
|
||||
prev_pos = action.get("previous_position")
|
||||
if prev_pos:
|
||||
return {
|
||||
"type": ActionType.MOVE.value,
|
||||
"x": prev_pos[0],
|
||||
"y": prev_pos[1],
|
||||
"duration": 0.2
|
||||
}
|
||||
|
||||
elif action_type == ActionType.DRAG.value:
|
||||
# Inverser le glissement
|
||||
return {
|
||||
"type": ActionType.DRAG.value,
|
||||
"start_x": action.get("end_x"),
|
||||
"start_y": action.get("end_y"),
|
||||
"end_x": action.get("start_x"),
|
||||
"end_y": action.get("start_y"),
|
||||
"duration": action.get("duration", 0.5),
|
||||
"button": action.get("button", "left")
|
||||
}
|
||||
|
||||
elif action_type == ActionType.WAIT.value:
|
||||
# L'attente n'a pas d'inverse
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _convert_to_azerty(self, text: str) -> str:
|
||||
"""
|
||||
Convertit du texte pour clavier AZERTY.
|
||||
|
||||
Args:
|
||||
text: Texte à convertir
|
||||
|
||||
Returns:
|
||||
Texte converti
|
||||
"""
|
||||
# Pour l'instant, retourner tel quel
|
||||
# PyAutoGUI gère déjà le layout clavier du système
|
||||
# Cette méthode peut être étendue si nécessaire
|
||||
return text
|
||||
|
||||
def get_action_history(self, limit: int = 50) -> list:
|
||||
"""
|
||||
Retourne l'historique des actions.
|
||||
|
||||
Args:
|
||||
limit: Nombre maximum d'actions à retourner
|
||||
|
||||
Returns:
|
||||
Liste des dernières actions
|
||||
"""
|
||||
return self.action_history[-limit:]
|
||||
|
||||
def clear_history(self):
|
||||
"""Efface l'historique des actions."""
|
||||
self.action_history = []
|
||||
self.logger.log_action({
|
||||
"action": "action_history_cleared"
|
||||
})
|
||||
|
||||
def execute_action(self, action_data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Exécute une action depuis un dictionnaire de données.
|
||||
|
||||
Args:
|
||||
action_data: Données de l'action à exécuter
|
||||
{
|
||||
"action_type": str,
|
||||
"bbox": (x, y, w, h),
|
||||
"parameters": dict
|
||||
}
|
||||
|
||||
Returns:
|
||||
True si succès, False sinon
|
||||
"""
|
||||
action_type = action_data.get("action_type", "").lower()
|
||||
bbox = action_data.get("bbox", (0, 0, 0, 0))
|
||||
params = action_data.get("parameters", {})
|
||||
|
||||
# Calculer le centre de la bbox pour les actions de clic
|
||||
x, y, w, h = bbox
|
||||
center_x = x + w // 2
|
||||
center_y = y + h // 2
|
||||
|
||||
if action_type == "click":
|
||||
button = params.get("button", "left")
|
||||
clicks = params.get("clicks", 1)
|
||||
return self.click(center_x, center_y, button=button, clicks=clicks)
|
||||
|
||||
elif action_type == "double_click":
|
||||
return self.click(center_x, center_y, clicks=2)
|
||||
|
||||
elif action_type == "right_click":
|
||||
return self.click(center_x, center_y, button="right")
|
||||
|
||||
elif action_type == "type":
|
||||
text = params.get("text", "")
|
||||
interval = params.get("interval", 0.0)
|
||||
return self.type_text(text, interval=interval)
|
||||
|
||||
elif action_type == "scroll":
|
||||
direction = params.get("direction", "down")
|
||||
amount = params.get("amount", 3)
|
||||
return self.scroll(direction, amount, center_x, center_y)
|
||||
|
||||
elif action_type == "wait":
|
||||
duration = params.get("duration", 1.0)
|
||||
return self.wait(duration)
|
||||
|
||||
elif action_type == "move":
|
||||
duration = params.get("duration", 0.2)
|
||||
return self.move(center_x, center_y, duration=duration)
|
||||
|
||||
elif action_type == "drag":
|
||||
end_bbox = params.get("end_bbox", bbox)
|
||||
end_x, end_y, end_w, end_h = end_bbox
|
||||
end_center_x = end_x + end_w // 2
|
||||
end_center_y = end_y + end_h // 2
|
||||
duration = params.get("duration", 0.5)
|
||||
button = params.get("button", "left")
|
||||
return self.drag(center_x, center_y, end_center_x, end_center_y, duration, button)
|
||||
|
||||
else:
|
||||
self.logger.log_action({
|
||||
"action": "unknown_action_type",
|
||||
"action_type": action_type
|
||||
})
|
||||
return False
|
||||
798
geniusia2/core/utils/vision_utils.py
Normal file
798
geniusia2/core/utils/vision_utils.py
Normal file
@@ -0,0 +1,798 @@
|
||||
"""
|
||||
Utilitaires de vision pour détection d'éléments UI
|
||||
Fournit des interfaces vers les modèles de vision (OWL-v2, Grounding DINO, YOLO-World)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from ..models import Detection
|
||||
from ..config import get_config, get_model_config
|
||||
from .image_utils import extract_roi
|
||||
|
||||
# Configuration du logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VisionUtils:
|
||||
"""
|
||||
Classe utilitaire pour la détection d'éléments UI avec plusieurs modèles de vision
|
||||
Supporte OWL-v2, Grounding DINO et YOLO-World avec fallback automatique
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
||||
"""
|
||||
Initialise VisionUtils avec les modèles de vision
|
||||
|
||||
Args:
|
||||
config: Configuration optionnelle (utilise CONFIG global si None)
|
||||
"""
|
||||
self.config = config or get_config()
|
||||
self.model_config = get_model_config()
|
||||
|
||||
# Modèle principal configuré
|
||||
self.primary_model = self.model_config.get("vision", "owl-v2")
|
||||
|
||||
# Ordre de fallback des modèles
|
||||
self.fallback_order = ["owl-v2", "dino", "yolo"]
|
||||
|
||||
# Modèles chargés (lazy loading)
|
||||
self._models = {}
|
||||
self._models_loaded = {
|
||||
"owl-v2": False,
|
||||
"dino": False,
|
||||
"yolo": False,
|
||||
}
|
||||
|
||||
logger.info(f"VisionUtils initialisé avec modèle principal: {self.primary_model}")
|
||||
|
||||
def _load_owlv2(self) -> Any:
|
||||
"""
|
||||
Charge le modèle OWL-v2 (OWLv2 pour détection open-vocabulary)
|
||||
|
||||
Returns:
|
||||
Modèle OWL-v2 chargé
|
||||
"""
|
||||
try:
|
||||
logger.info("Chargement du modèle OWL-v2...")
|
||||
|
||||
# Import dynamique pour éviter les dépendances si non utilisé
|
||||
from transformers import Owlv2Processor, Owlv2ForObjectDetection
|
||||
import torch
|
||||
|
||||
model_path = self.model_config["paths"].get("owl_v2")
|
||||
|
||||
# Charger le modèle pré-entraîné
|
||||
processor = Owlv2Processor.from_pretrained(
|
||||
"google/owlv2-base-patch16-ensemble",
|
||||
cache_dir=model_path
|
||||
)
|
||||
model = Owlv2ForObjectDetection.from_pretrained(
|
||||
"google/owlv2-base-patch16-ensemble",
|
||||
cache_dir=model_path
|
||||
)
|
||||
|
||||
# Déplacer vers GPU si disponible
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
self._models["owl-v2"] = {
|
||||
"processor": processor,
|
||||
"model": model,
|
||||
"device": device
|
||||
}
|
||||
self._models_loaded["owl-v2"] = True
|
||||
|
||||
logger.info(f"OWL-v2 chargé avec succès sur {device}")
|
||||
return self._models["owl-v2"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors du chargement d'OWL-v2: {e}")
|
||||
self._models_loaded["owl-v2"] = False
|
||||
raise
|
||||
|
||||
def _load_dino(self) -> Any:
|
||||
"""
|
||||
Charge le modèle Grounding DINO
|
||||
|
||||
Returns:
|
||||
Modèle Grounding DINO chargé
|
||||
"""
|
||||
try:
|
||||
logger.info("Chargement du modèle Grounding DINO...")
|
||||
|
||||
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
||||
import torch
|
||||
|
||||
# Charger le modèle Grounding DINO depuis HuggingFace
|
||||
model_id = "IDEA-Research/grounding-dino-tiny"
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
|
||||
|
||||
# Déplacer vers GPU si disponible
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
self._models["dino"] = {
|
||||
"processor": processor,
|
||||
"model": model,
|
||||
"device": device
|
||||
}
|
||||
self._models_loaded["dino"] = True
|
||||
|
||||
logger.info(f"Grounding DINO chargé avec succès sur {device}")
|
||||
return self._models["dino"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors du chargement de Grounding DINO: {e}")
|
||||
self._models_loaded["dino"] = False
|
||||
self._models["dino"] = {"model": None, "loaded": False}
|
||||
return self._models["dino"]
|
||||
|
||||
def _load_yolo(self) -> Any:
|
||||
"""
|
||||
Charge le modèle YOLO-World
|
||||
|
||||
Returns:
|
||||
Modèle YOLO-World chargé
|
||||
"""
|
||||
try:
|
||||
logger.info("Chargement du modèle YOLO-World...")
|
||||
|
||||
from ultralytics import YOLOWorld
|
||||
|
||||
# Charger YOLO-World (modèle pré-entraîné)
|
||||
model = YOLOWorld("yolov8s-worldv2.pt")
|
||||
|
||||
self._models["yolo"] = {
|
||||
"model": model
|
||||
}
|
||||
self._models_loaded["yolo"] = True
|
||||
|
||||
logger.info("YOLO-World chargé avec succès")
|
||||
return self._models["yolo"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors du chargement de YOLO-World: {e}")
|
||||
self._models_loaded["yolo"] = False
|
||||
self._models["yolo"] = {"model": None, "loaded": False}
|
||||
return self._models["yolo"]
|
||||
|
||||
def _ensure_model_loaded(self, model_name: str) -> bool:
|
||||
"""
|
||||
S'assure qu'un modèle est chargé
|
||||
|
||||
Args:
|
||||
model_name: Nom du modèle ("owl-v2", "dino", "yolo")
|
||||
|
||||
Returns:
|
||||
True si le modèle est chargé avec succès
|
||||
"""
|
||||
if self._models_loaded.get(model_name, False):
|
||||
return True
|
||||
|
||||
try:
|
||||
if model_name == "owl-v2":
|
||||
self._load_owlv2()
|
||||
elif model_name == "dino":
|
||||
self._load_dino()
|
||||
elif model_name == "yolo":
|
||||
self._load_yolo()
|
||||
else:
|
||||
logger.error(f"Modèle inconnu: {model_name}")
|
||||
return False
|
||||
|
||||
return self._models_loaded.get(model_name, False)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Impossible de charger le modèle {model_name}: {e}")
|
||||
return False
|
||||
|
||||
def detect_with_owlv2(self, prompt: str, frame: np.ndarray) -> List[Detection]:
|
||||
"""
|
||||
Détection d'éléments UI avec OWL-v2
|
||||
|
||||
Args:
|
||||
prompt: Description textuelle de l'élément à détecter
|
||||
frame: Image de l'écran (numpy array RGB)
|
||||
|
||||
Returns:
|
||||
Liste de détections trouvées
|
||||
"""
|
||||
try:
|
||||
# S'assurer que le modèle est chargé
|
||||
if not self._ensure_model_loaded("owl-v2"):
|
||||
logger.error("OWL-v2 n'est pas disponible")
|
||||
return []
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
model_data = self._models["owl-v2"]
|
||||
processor = model_data["processor"]
|
||||
model = model_data["model"]
|
||||
device = model_data["device"]
|
||||
|
||||
# Convertir frame numpy en PIL Image
|
||||
if frame.dtype != np.uint8:
|
||||
frame = (frame * 255).astype(np.uint8)
|
||||
image = Image.fromarray(frame)
|
||||
|
||||
# Préparer les prompts (OWL-v2 accepte plusieurs prompts)
|
||||
texts = [[prompt]]
|
||||
|
||||
# Traiter l'image et le texte
|
||||
inputs = processor(text=texts, images=image, return_tensors="pt")
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
# Inférence
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Post-traitement des résultats
|
||||
target_sizes = torch.tensor([image.size[::-1]]).to(device)
|
||||
results = processor.post_process_object_detection(
|
||||
outputs=outputs,
|
||||
threshold=0.1, # Seuil bas pour capturer plus de détections
|
||||
target_sizes=target_sizes
|
||||
)[0]
|
||||
|
||||
# Convertir en objets Detection
|
||||
detections = []
|
||||
boxes = results["boxes"].cpu().numpy()
|
||||
scores = results["scores"].cpu().numpy()
|
||||
labels = results["labels"].cpu().numpy()
|
||||
|
||||
for box, score, label in zip(boxes, scores, labels):
|
||||
# Convertir bbox de [x1, y1, x2, y2] vers [x, y, w, h]
|
||||
x1, y1, x2, y2 = box
|
||||
x, y = int(x1), int(y1)
|
||||
w, h = int(x2 - x1), int(y2 - y1)
|
||||
|
||||
# Extraire ROI pour embedding
|
||||
roi = extract_roi(frame, (x, y, w, h))
|
||||
|
||||
# Créer embedding simple (sera remplacé par OpenCLIP plus tard)
|
||||
embedding = np.random.rand(512) # Placeholder
|
||||
|
||||
detection = Detection(
|
||||
label=prompt,
|
||||
confidence=float(score),
|
||||
bbox=(x, y, w, h),
|
||||
embedding=embedding,
|
||||
model_source="owl-v2",
|
||||
roi_image=roi,
|
||||
metadata={
|
||||
"label_id": int(label),
|
||||
"raw_box": box.tolist()
|
||||
}
|
||||
)
|
||||
detections.append(detection)
|
||||
|
||||
logger.info(f"OWL-v2: {len(detections)} détections pour '{prompt}'")
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de la détection OWL-v2: {e}")
|
||||
return []
|
||||
|
||||
def detect_with_dino(self, prompt: str, frame: np.ndarray) -> List[Detection]:
|
||||
"""
|
||||
Détection d'éléments UI avec Grounding DINO
|
||||
|
||||
Args:
|
||||
prompt: Description textuelle de l'élément à détecter
|
||||
frame: Image de l'écran (numpy array RGB)
|
||||
|
||||
Returns:
|
||||
Liste de détections trouvées
|
||||
"""
|
||||
try:
|
||||
# S'assurer que le modèle est chargé
|
||||
if not self._ensure_model_loaded("dino"):
|
||||
logger.warning("Grounding DINO n'est pas disponible")
|
||||
return []
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
model_data = self._models["dino"]
|
||||
if not model_data.get("model"):
|
||||
return []
|
||||
|
||||
processor = model_data["processor"]
|
||||
model = model_data["model"]
|
||||
device = model_data["device"]
|
||||
|
||||
# Convertir frame numpy en PIL Image
|
||||
if frame.dtype != np.uint8:
|
||||
frame = (frame * 255).astype(np.uint8)
|
||||
image = Image.fromarray(frame)
|
||||
|
||||
# Préparer les inputs
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
# Inférence
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Post-traitement
|
||||
target_sizes = torch.tensor([image.size[::-1]]).to(device)
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs=outputs,
|
||||
input_ids=inputs["input_ids"],
|
||||
threshold=0.3,
|
||||
target_sizes=target_sizes
|
||||
)[0]
|
||||
|
||||
# Convertir en objets Detection
|
||||
detections = []
|
||||
boxes = results["boxes"].cpu().numpy()
|
||||
scores = results["scores"].cpu().numpy()
|
||||
labels = results["labels"]
|
||||
|
||||
for box, score, label in zip(boxes, scores, labels):
|
||||
x1, y1, x2, y2 = box
|
||||
x, y = int(x1), int(y1)
|
||||
w, h = int(x2 - x1), int(y2 - y1)
|
||||
|
||||
roi = extract_roi(frame, (x, y, w, h))
|
||||
embedding = np.random.rand(512) # Placeholder
|
||||
|
||||
detection = Detection(
|
||||
label=label,
|
||||
confidence=float(score),
|
||||
bbox=(x, y, w, h),
|
||||
embedding=embedding,
|
||||
model_source="dino",
|
||||
roi_image=roi,
|
||||
metadata={"raw_box": box.tolist()}
|
||||
)
|
||||
detections.append(detection)
|
||||
|
||||
logger.info(f"Grounding DINO: {len(detections)} détections pour '{prompt}'")
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de la détection Grounding DINO: {e}")
|
||||
return []
|
||||
|
||||
def detect_with_yolo(self, prompt: str, frame: np.ndarray) -> List[Detection]:
|
||||
"""
|
||||
Détection d'éléments UI avec YOLO-World
|
||||
|
||||
Args:
|
||||
prompt: Description textuelle de l'élément à détecter
|
||||
frame: Image de l'écran (numpy array RGB)
|
||||
|
||||
Returns:
|
||||
Liste de détections trouvées
|
||||
"""
|
||||
try:
|
||||
# S'assurer que le modèle est chargé
|
||||
if not self._ensure_model_loaded("yolo"):
|
||||
logger.warning("YOLO-World n'est pas disponible")
|
||||
return []
|
||||
|
||||
model_data = self._models["yolo"]
|
||||
if not model_data.get("model"):
|
||||
return []
|
||||
|
||||
model = model_data["model"]
|
||||
|
||||
# Définir les classes à détecter (YOLO-World accepte des prompts textuels)
|
||||
model.set_classes([prompt])
|
||||
|
||||
# Convertir BGR vers RGB si nécessaire
|
||||
if frame.dtype != np.uint8:
|
||||
frame = (frame * 255).astype(np.uint8)
|
||||
|
||||
# Inférence
|
||||
results = model.predict(frame, conf=0.1, verbose=False)
|
||||
|
||||
# Convertir en objets Detection
|
||||
detections = []
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
for box in boxes:
|
||||
# Extraire les coordonnées
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
|
||||
x, y = int(x1), int(y1)
|
||||
w, h = int(x2 - x1), int(y2 - y1)
|
||||
|
||||
# Score de confiance
|
||||
confidence = float(box.conf[0])
|
||||
|
||||
# Classe détectée
|
||||
cls_id = int(box.cls[0])
|
||||
label = model.names[cls_id] if cls_id < len(model.names) else prompt
|
||||
|
||||
roi = extract_roi(frame, (x, y, w, h))
|
||||
embedding = np.random.rand(512) # Placeholder
|
||||
|
||||
detection = Detection(
|
||||
label=label,
|
||||
confidence=confidence,
|
||||
bbox=(x, y, w, h),
|
||||
embedding=embedding,
|
||||
model_source="yolo",
|
||||
roi_image=roi,
|
||||
metadata={"class_id": cls_id}
|
||||
)
|
||||
detections.append(detection)
|
||||
|
||||
logger.info(f"YOLO-World: {len(detections)} détections pour '{prompt}'")
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de la détection YOLO-World: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def detect(self, prompt: str, frame: np.ndarray,
|
||||
model: Optional[str] = None) -> List[Detection]:
|
||||
"""
|
||||
Détection d'éléments UI avec fallback automatique entre modèles
|
||||
|
||||
Args:
|
||||
prompt: Description textuelle de l'élément à détecter
|
||||
frame: Image de l'écran (numpy array RGB)
|
||||
model: Modèle spécifique à utiliser (None = utiliser le modèle principal)
|
||||
|
||||
Returns:
|
||||
Liste de détections trouvées
|
||||
"""
|
||||
# Déterminer l'ordre des modèles à essayer
|
||||
if model:
|
||||
models_to_try = [model] + [m for m in self.fallback_order if m != model]
|
||||
else:
|
||||
models_to_try = [self.primary_model] + [m for m in self.fallback_order if m != self.primary_model]
|
||||
|
||||
# Essayer chaque modèle jusqu'à obtenir des détections
|
||||
for model_name in models_to_try:
|
||||
try:
|
||||
logger.info(f"Tentative de détection avec {model_name}...")
|
||||
|
||||
if model_name == "owl-v2":
|
||||
detections = self.detect_with_owlv2(prompt, frame)
|
||||
elif model_name == "dino":
|
||||
detections = self.detect_with_dino(prompt, frame)
|
||||
elif model_name == "yolo":
|
||||
detections = self.detect_with_yolo(prompt, frame)
|
||||
else:
|
||||
logger.warning(f"Modèle inconnu: {model_name}")
|
||||
continue
|
||||
|
||||
# Si des détections sont trouvées, retourner
|
||||
if detections:
|
||||
logger.info(f"Détection réussie avec {model_name}: {len(detections)} éléments")
|
||||
return detections
|
||||
else:
|
||||
logger.warning(f"Aucune détection avec {model_name}, essai du modèle suivant...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur avec {model_name}: {e}, essai du modèle suivant...")
|
||||
continue
|
||||
|
||||
# Aucun modèle n'a réussi
|
||||
logger.error(f"Aucun modèle n'a pu détecter '{prompt}'")
|
||||
return []
|
||||
|
||||
def select_best_detection(self, detections: List[Detection],
|
||||
context: Optional[Dict[str, Any]] = None) -> Optional[Detection]:
|
||||
"""
|
||||
Sélectionne la meilleure détection parmi une liste
|
||||
|
||||
Args:
|
||||
detections: Liste de détections à évaluer
|
||||
context: Contexte additionnel pour la sélection (position précédente, etc.)
|
||||
|
||||
Returns:
|
||||
La meilleure détection ou None si la liste est vide
|
||||
"""
|
||||
if not detections:
|
||||
return None
|
||||
|
||||
# Si une seule détection, la retourner
|
||||
if len(detections) == 1:
|
||||
return detections[0]
|
||||
|
||||
# Stratégie de sélection basée sur plusieurs critères
|
||||
best_detection = None
|
||||
best_score = -1
|
||||
|
||||
for detection in detections:
|
||||
score = detection.confidence
|
||||
|
||||
# Bonus pour les détections du modèle principal
|
||||
if detection.model_source == self.primary_model:
|
||||
score *= 1.1
|
||||
|
||||
# Si contexte fourni avec position précédente, favoriser les détections proches
|
||||
if context and "previous_bbox" in context:
|
||||
prev_x, prev_y, prev_w, prev_h = context["previous_bbox"]
|
||||
curr_x, curr_y, curr_w, curr_h = detection.bbox
|
||||
|
||||
# Calculer la distance entre les centres
|
||||
prev_center = (prev_x + prev_w / 2, prev_y + prev_h / 2)
|
||||
curr_center = (curr_x + curr_w / 2, curr_y + curr_h / 2)
|
||||
distance = np.sqrt(
|
||||
(prev_center[0] - curr_center[0]) ** 2 +
|
||||
(prev_center[1] - curr_center[1]) ** 2
|
||||
)
|
||||
|
||||
# Bonus inversement proportionnel à la distance (max 20% bonus)
|
||||
proximity_bonus = max(0, 1 - distance / 500) * 0.2
|
||||
score *= (1 + proximity_bonus)
|
||||
|
||||
# Favoriser les détections avec des bounding boxes de taille raisonnable
|
||||
x, y, w, h = detection.bbox
|
||||
area = w * h
|
||||
if 100 < area < 100000: # Taille raisonnable pour un élément UI
|
||||
score *= 1.05
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_detection = detection
|
||||
|
||||
logger.info(f"Meilleure détection sélectionnée: {best_detection.label} "
|
||||
f"(confiance: {best_detection.confidence:.2f}, "
|
||||
f"modèle: {best_detection.model_source})")
|
||||
|
||||
return best_detection
|
||||
|
||||
def filter_detections(self, detections: List[Detection],
|
||||
min_confidence: float = 0.3,
|
||||
max_detections: int = 10) -> List[Detection]:
|
||||
"""
|
||||
Filtre les détections selon des critères de qualité
|
||||
|
||||
Args:
|
||||
detections: Liste de détections à filtrer
|
||||
min_confidence: Confiance minimale requise
|
||||
max_detections: Nombre maximum de détections à retourner
|
||||
|
||||
Returns:
|
||||
Liste filtrée et triée de détections
|
||||
"""
|
||||
# Filtrer par confiance minimale
|
||||
filtered = [d for d in detections if d.confidence >= min_confidence]
|
||||
|
||||
# Trier par confiance décroissante
|
||||
filtered.sort(key=lambda d: d.confidence, reverse=True)
|
||||
|
||||
# Limiter le nombre de détections
|
||||
filtered = filtered[:max_detections]
|
||||
|
||||
logger.info(f"Filtrage: {len(detections)} -> {len(filtered)} détections "
|
||||
f"(seuil: {min_confidence})")
|
||||
|
||||
return filtered
|
||||
|
||||
def merge_overlapping_detections(self, detections: List[Detection],
|
||||
iou_threshold: float = 0.5) -> List[Detection]:
|
||||
"""
|
||||
Fusionne les détections qui se chevauchent (même élément détecté plusieurs fois)
|
||||
|
||||
Args:
|
||||
detections: Liste de détections
|
||||
iou_threshold: Seuil d'IoU pour considérer deux détections comme identiques
|
||||
|
||||
Returns:
|
||||
Liste de détections fusionnées
|
||||
"""
|
||||
if len(detections) <= 1:
|
||||
return detections
|
||||
|
||||
def calculate_iou(box1: Tuple[int, int, int, int],
|
||||
box2: Tuple[int, int, int, int]) -> float:
|
||||
"""Calcule l'Intersection over Union entre deux bounding boxes"""
|
||||
x1, y1, w1, h1 = box1
|
||||
x2, y2, w2, h2 = box2
|
||||
|
||||
# Coordonnées de l'intersection
|
||||
xi1 = max(x1, x2)
|
||||
yi1 = max(y1, y2)
|
||||
xi2 = min(x1 + w1, x2 + w2)
|
||||
yi2 = min(y1 + h1, y2 + h2)
|
||||
|
||||
# Aire de l'intersection
|
||||
inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
|
||||
|
||||
# Aires des deux boxes
|
||||
box1_area = w1 * h1
|
||||
box2_area = w2 * h2
|
||||
|
||||
# Union
|
||||
union_area = box1_area + box2_area - inter_area
|
||||
|
||||
# IoU
|
||||
return inter_area / union_area if union_area > 0 else 0
|
||||
|
||||
# Trier par confiance décroissante
|
||||
sorted_detections = sorted(detections, key=lambda d: d.confidence, reverse=True)
|
||||
|
||||
merged = []
|
||||
used = set()
|
||||
|
||||
for i, det1 in enumerate(sorted_detections):
|
||||
if i in used:
|
||||
continue
|
||||
|
||||
# Trouver toutes les détections qui se chevauchent avec det1
|
||||
overlapping = [det1]
|
||||
for j, det2 in enumerate(sorted_detections[i+1:], start=i+1):
|
||||
if j in used:
|
||||
continue
|
||||
|
||||
iou = calculate_iou(det1.bbox, det2.bbox)
|
||||
if iou >= iou_threshold:
|
||||
overlapping.append(det2)
|
||||
used.add(j)
|
||||
|
||||
# Si plusieurs détections se chevauchent, garder celle avec la meilleure confiance
|
||||
# (det1 est déjà la meilleure car la liste est triée)
|
||||
merged.append(det1)
|
||||
used.add(i)
|
||||
|
||||
logger.info(f"Fusion: {len(detections)} -> {len(merged)} détections "
|
||||
f"(seuil IoU: {iou_threshold})")
|
||||
|
||||
return merged
|
||||
|
||||
def get_detection_statistics(self, detections: List[Detection]) -> Dict[str, Any]:
|
||||
"""
|
||||
Calcule des statistiques sur une liste de détections
|
||||
|
||||
Args:
|
||||
detections: Liste de détections
|
||||
|
||||
Returns:
|
||||
Dictionnaire de statistiques
|
||||
"""
|
||||
if not detections:
|
||||
return {
|
||||
"count": 0,
|
||||
"avg_confidence": 0.0,
|
||||
"max_confidence": 0.0,
|
||||
"min_confidence": 0.0,
|
||||
"models_used": []
|
||||
}
|
||||
|
||||
confidences = [d.confidence for d in detections]
|
||||
models = [d.model_source for d in detections]
|
||||
|
||||
stats = {
|
||||
"count": len(detections),
|
||||
"avg_confidence": float(np.mean(confidences)),
|
||||
"max_confidence": float(np.max(confidences)),
|
||||
"min_confidence": float(np.min(confidences)),
|
||||
"std_confidence": float(np.std(confidences)),
|
||||
"models_used": list(set(models)),
|
||||
"model_distribution": {model: models.count(model) for model in set(models)}
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
def unload_models(self):
|
||||
"""Décharge tous les modèles de la mémoire"""
|
||||
logger.info("Déchargement des modèles de vision...")
|
||||
self._models.clear()
|
||||
self._models_loaded = {k: False for k in self._models_loaded}
|
||||
|
||||
# Forcer le garbage collection
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
# Si CUDA disponible, vider le cache
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
logger.info("Modèles déchargés")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Tests basiques de VisionUtils"""
|
||||
import sys
|
||||
|
||||
print("Test de VisionUtils")
|
||||
print("=" * 50)
|
||||
|
||||
# Initialiser VisionUtils
|
||||
print("\n1. Initialisation de VisionUtils...")
|
||||
vision = VisionUtils()
|
||||
print(f" Modèle principal: {vision.primary_model}")
|
||||
print(f" Ordre de fallback: {vision.fallback_order}")
|
||||
|
||||
# Créer une image de test
|
||||
print("\n2. Création d'une image de test...")
|
||||
test_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
|
||||
print(f" Taille de l'image: {test_frame.shape}")
|
||||
|
||||
# Test de détection (nécessite les modèles installés)
|
||||
print("\n3. Test de détection...")
|
||||
try:
|
||||
detections = vision.detect("button", test_frame)
|
||||
print(f" Détections trouvées: {len(detections)}")
|
||||
|
||||
if detections:
|
||||
print("\n4. Statistiques des détections:")
|
||||
stats = vision.get_detection_statistics(detections)
|
||||
for key, value in stats.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print("\n5. Sélection de la meilleure détection:")
|
||||
best = vision.select_best_detection(detections)
|
||||
if best:
|
||||
print(f" Label: {best.label}")
|
||||
print(f" Confiance: {best.confidence:.2f}")
|
||||
print(f" BBox: {best.bbox}")
|
||||
print(f" Modèle: {best.model_source}")
|
||||
except Exception as e:
|
||||
print(f" Erreur lors de la détection: {e}")
|
||||
print(" (Normal si les modèles ne sont pas installés)")
|
||||
|
||||
# Test de filtrage
|
||||
print("\n6. Test de filtrage de détections...")
|
||||
mock_detections = [
|
||||
Detection(
|
||||
label="button1",
|
||||
confidence=0.95,
|
||||
bbox=(100, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="owl-v2"
|
||||
),
|
||||
Detection(
|
||||
label="button2",
|
||||
confidence=0.25,
|
||||
bbox=(200, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="owl-v2"
|
||||
),
|
||||
Detection(
|
||||
label="button3",
|
||||
confidence=0.75,
|
||||
bbox=(300, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="dino"
|
||||
),
|
||||
]
|
||||
|
||||
filtered = vision.filter_detections(mock_detections, min_confidence=0.5)
|
||||
print(f" Détections avant filtrage: {len(mock_detections)}")
|
||||
print(f" Détections après filtrage: {len(filtered)}")
|
||||
|
||||
# Test de fusion
|
||||
print("\n7. Test de fusion de détections chevauchantes...")
|
||||
overlapping_detections = [
|
||||
Detection(
|
||||
label="button",
|
||||
confidence=0.95,
|
||||
bbox=(100, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="owl-v2"
|
||||
),
|
||||
Detection(
|
||||
label="button",
|
||||
confidence=0.85,
|
||||
bbox=(105, 102, 48, 28), # Légèrement décalé
|
||||
embedding=np.random.rand(512),
|
||||
model_source="dino"
|
||||
),
|
||||
]
|
||||
|
||||
merged = vision.merge_overlapping_detections(overlapping_detections, iou_threshold=0.5)
|
||||
print(f" Détections avant fusion: {len(overlapping_detections)}")
|
||||
print(f" Détections après fusion: {len(merged)}")
|
||||
|
||||
print("\n✓ Tests basiques terminés!")
|
||||
279
geniusia2/core/vision_analysis.py
Normal file
279
geniusia2/core/vision_analysis.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
Analyse visuelle des actions utilisateur.
|
||||
Extrait et analyse la région autour d'une action pour créer une signature visuelle.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from typing import Dict, Any, Optional, Tuple, Union
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
from .embeddings_manager import EmbeddingsManager
|
||||
from .embedders import EmbeddingManager as NewEmbeddingManager
|
||||
from .utils.vision_utils import VisionUtils
|
||||
from .llm_manager import LLMManager
|
||||
from .logger import Logger
|
||||
|
||||
|
||||
class VisionAnalysis:
|
||||
"""
|
||||
Analyse visuelle des actions pour créer des signatures réutilisables.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embeddings_manager: Union[EmbeddingsManager, NewEmbeddingManager],
|
||||
vision_utils: VisionUtils,
|
||||
llm_manager: Optional[LLMManager] = None,
|
||||
logger: Optional[Logger] = None
|
||||
):
|
||||
"""
|
||||
Initialise l'analyseur visuel.
|
||||
|
||||
Args:
|
||||
embeddings_manager: Pour créer les embeddings (old or new system)
|
||||
vision_utils: Pour la détection d'éléments
|
||||
llm_manager: Pour le contexte (optionnel)
|
||||
logger: Pour la journalisation
|
||||
"""
|
||||
self.embeddings = embeddings_manager
|
||||
self.vision = vision_utils
|
||||
self.llm = llm_manager
|
||||
self.logger = logger
|
||||
|
||||
# Detect if using new embedding system
|
||||
self._use_new_system = isinstance(embeddings_manager, NewEmbeddingManager)
|
||||
|
||||
def analyze_action(
|
||||
self,
|
||||
screenshot: np.ndarray,
|
||||
x: int,
|
||||
y: int,
|
||||
action_type: str,
|
||||
window: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyse une action utilisateur et crée sa signature visuelle.
|
||||
|
||||
Args:
|
||||
screenshot: Image complète de l'écran
|
||||
x, y: Position de l'action
|
||||
action_type: Type d'action (mouse_click, key_press, etc.)
|
||||
window: Fenêtre active
|
||||
|
||||
Returns:
|
||||
Signature visuelle de l'action (always returns a valid signature)
|
||||
"""
|
||||
# Initialize defaults
|
||||
element_type = "unknown"
|
||||
element_description = ""
|
||||
region = None
|
||||
region_coords = (0, 0, 0, 0)
|
||||
embedding = None
|
||||
|
||||
try:
|
||||
# 1. Extraire la région autour de l'action
|
||||
region, region_coords = self._extract_region(screenshot, x, y, size=100)
|
||||
|
||||
# 2. Créer l'embedding de la région
|
||||
if self._use_new_system:
|
||||
# New system: convert numpy to PIL, then embed
|
||||
region_rgb = cv2.cvtColor(region, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(region_rgb.astype(np.uint8))
|
||||
embedding = self.embeddings.embed(pil_image)
|
||||
else:
|
||||
# Old system: use encode_image directly
|
||||
embedding = self.embeddings.encode_image(region)
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "region_extraction_failed",
|
||||
"error": str(e),
|
||||
"position": (x, y)
|
||||
})
|
||||
# Continue with defaults
|
||||
|
||||
# 3. Détecter le type d'élément avec le LLM
|
||||
try:
|
||||
# Utiliser Qwen3-VL pour identifier l'élément
|
||||
if self.llm and region is not None:
|
||||
try:
|
||||
prompt = f"""Analyse cette région d'interface utilisateur où l'utilisateur a cliqué.
|
||||
|
||||
Position du clic: ({x}, {y})
|
||||
Type d'action: {action_type}
|
||||
|
||||
Identifie l'élément UI en une courte phrase (max 30 caractères).
|
||||
Exemples: "Bouton Rafraîchir", "Icône Paramètres", "Champ de texte"
|
||||
|
||||
Réponds UNIQUEMENT avec l'identification, sans explication."""
|
||||
|
||||
response = self.llm.generate_with_vision(
|
||||
prompt=prompt,
|
||||
images=[region]
|
||||
)
|
||||
|
||||
element_description = response.strip()[:50]
|
||||
|
||||
# Extraire le type (premier mot généralement)
|
||||
words = element_description.lower().split()
|
||||
if words:
|
||||
if "bouton" in words or "button" in words:
|
||||
element_type = "button"
|
||||
elif "icône" in words or "icon" in words:
|
||||
element_type = "icon"
|
||||
elif "champ" in words or "field" in words or "input" in words:
|
||||
element_type = "text_field"
|
||||
elif "lien" in words or "link" in words:
|
||||
element_type = "link"
|
||||
else:
|
||||
element_type = words[0]
|
||||
|
||||
except Exception as llm_error:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "llm_analysis_failed",
|
||||
"error": str(llm_error),
|
||||
"position": (x, y)
|
||||
})
|
||||
|
||||
# Fallback: détection visuelle classique
|
||||
if element_type == "unknown" and region is not None:
|
||||
try:
|
||||
all_detections = []
|
||||
for elem_type in ["button", "icon", "text field"]:
|
||||
detections = self.vision.detect(elem_type, screenshot)
|
||||
all_detections.extend(detections)
|
||||
|
||||
if all_detections:
|
||||
closest = self._find_closest_detection(all_detections, x, y)
|
||||
if closest:
|
||||
element_type = closest.label
|
||||
|
||||
except Exception as vision_error:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "vision_detection_failed",
|
||||
"error": str(vision_error),
|
||||
"position": (x, y)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "element_detection_failed",
|
||||
"error": str(e),
|
||||
"position": (x, y)
|
||||
})
|
||||
|
||||
# 4. Créer la signature (always return a valid signature)
|
||||
signature = {
|
||||
"position": (x, y),
|
||||
"region_coords": region_coords,
|
||||
"region_image": region,
|
||||
"embedding": embedding,
|
||||
"element_type": element_type,
|
||||
"element_description": element_description,
|
||||
"action_type": action_type,
|
||||
"window": window,
|
||||
"screenshot_shape": screenshot.shape
|
||||
}
|
||||
|
||||
return signature
|
||||
|
||||
def _extract_region(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
x: int,
|
||||
y: int,
|
||||
size: int = 100
|
||||
) -> Tuple[np.ndarray, Tuple[int, int, int, int]]:
|
||||
"""
|
||||
Extrait une région carrée autour d'un point.
|
||||
|
||||
Returns:
|
||||
(région extraite, coordonnées (x1, y1, x2, y2))
|
||||
"""
|
||||
h, w = image.shape[:2]
|
||||
half_size = size // 2
|
||||
|
||||
# Calculer les coordonnées en restant dans l'image
|
||||
x1 = max(0, x - half_size)
|
||||
y1 = max(0, y - half_size)
|
||||
x2 = min(w, x + half_size)
|
||||
y2 = min(h, y + half_size)
|
||||
|
||||
region = image[y1:y2, x1:x2].copy()
|
||||
|
||||
# Redimensionner à size x size si nécessaire
|
||||
if region.shape[0] != size or region.shape[1] != size:
|
||||
region = cv2.resize(region, (size, size))
|
||||
|
||||
return region, (x1, y1, x2, y2)
|
||||
|
||||
def _find_closest_detection(
|
||||
self,
|
||||
detections: list,
|
||||
x: int,
|
||||
y: int,
|
||||
max_distance: int = 50
|
||||
):
|
||||
"""
|
||||
Trouve la détection la plus proche d'un point.
|
||||
|
||||
Returns:
|
||||
Détection la plus proche ou None
|
||||
"""
|
||||
if not detections:
|
||||
return None
|
||||
|
||||
closest = None
|
||||
min_dist = float('inf')
|
||||
|
||||
for det in detections:
|
||||
# Gérer à la fois les objets Detection et les dicts
|
||||
if hasattr(det, 'bbox'):
|
||||
bbox = det.bbox # Objet Detection
|
||||
else:
|
||||
bbox = det.get("bbox", []) # Dict
|
||||
|
||||
if bbox and len(bbox) >= 4:
|
||||
# bbox format: (x, y, w, h)
|
||||
cx = bbox[0] + bbox[2] / 2
|
||||
cy = bbox[1] + bbox[3] / 2
|
||||
|
||||
# Distance au point
|
||||
dist = np.sqrt((cx - x)**2 + (cy - y)**2)
|
||||
|
||||
if dist < min_dist and dist < max_distance:
|
||||
min_dist = dist
|
||||
closest = det
|
||||
|
||||
return closest
|
||||
|
||||
def compare_signatures(
|
||||
self,
|
||||
sig1: Dict[str, Any],
|
||||
sig2: Dict[str, Any]
|
||||
) -> float:
|
||||
"""
|
||||
Compare deux signatures visuelles.
|
||||
|
||||
Returns:
|
||||
Score de similarité (0-1)
|
||||
"""
|
||||
# Comparer les embeddings
|
||||
emb1 = sig1.get("embedding")
|
||||
emb2 = sig2.get("embedding")
|
||||
|
||||
if emb1 is None or emb2 is None:
|
||||
return 0.0
|
||||
|
||||
# Similarité cosinus
|
||||
similarity = np.dot(emb1, emb2) / (
|
||||
np.linalg.norm(emb1) * np.linalg.norm(emb2)
|
||||
)
|
||||
|
||||
return float(similarity)
|
||||
212
geniusia2/core/vision_search.py
Normal file
212
geniusia2/core/vision_search.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Recherche visuelle d'éléments dans l'écran.
|
||||
Utilise une approche hybride : template matching (rapide) + embeddings (robuste).
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from typing import Dict, Any, Optional, Tuple, List
|
||||
|
||||
from .embeddings_manager import EmbeddingsManager
|
||||
from .logger import Logger
|
||||
|
||||
|
||||
class VisionSearch:
|
||||
"""
|
||||
Recherche visuelle d'éléments en utilisant template matching et embeddings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embeddings_manager: EmbeddingsManager,
|
||||
logger: Optional[Logger] = None
|
||||
):
|
||||
"""
|
||||
Initialise le moteur de recherche visuelle.
|
||||
|
||||
Args:
|
||||
embeddings_manager: Pour les embeddings
|
||||
logger: Pour la journalisation
|
||||
"""
|
||||
self.embeddings = embeddings_manager
|
||||
self.logger = logger
|
||||
|
||||
def find_element(
|
||||
self,
|
||||
screenshot: np.ndarray,
|
||||
target_signature: Dict[str, Any],
|
||||
confidence_threshold: float = 0.8
|
||||
) -> Optional[Tuple[int, int, float]]:
|
||||
"""
|
||||
Trouve un élément dans l'écran en utilisant sa signature visuelle.
|
||||
|
||||
Args:
|
||||
screenshot: Image de l'écran actuel
|
||||
target_signature: Signature de l'élément à trouver
|
||||
confidence_threshold: Seuil de confiance minimum
|
||||
|
||||
Returns:
|
||||
(x, y, confidence) ou None si non trouvé
|
||||
"""
|
||||
# 1. Essayer template matching (rapide)
|
||||
result = self._template_matching(
|
||||
screenshot,
|
||||
target_signature.get("region_image"),
|
||||
confidence_threshold=0.9 # Seuil élevé pour template
|
||||
)
|
||||
|
||||
if result:
|
||||
return result
|
||||
|
||||
# 2. Sinon, recherche par embedding (plus lent mais robuste)
|
||||
result = self._embedding_search(
|
||||
screenshot,
|
||||
target_signature.get("embedding"),
|
||||
region_size=100,
|
||||
confidence_threshold=confidence_threshold
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _template_matching(
|
||||
self,
|
||||
screenshot: np.ndarray,
|
||||
template: np.ndarray,
|
||||
confidence_threshold: float = 0.9
|
||||
) -> Optional[Tuple[int, int, float]]:
|
||||
"""
|
||||
Recherche par template matching OpenCV (rapide).
|
||||
|
||||
Returns:
|
||||
(x, y, confidence) ou None
|
||||
"""
|
||||
if template is None or template.size == 0:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Convertir en niveaux de gris
|
||||
gray_screenshot = cv2.cvtColor(screenshot, cv2.COLOR_BGR2GRAY)
|
||||
gray_template = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Template matching
|
||||
result = cv2.matchTemplate(
|
||||
gray_screenshot,
|
||||
gray_template,
|
||||
cv2.TM_CCOEFF_NORMED
|
||||
)
|
||||
|
||||
min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)
|
||||
|
||||
if max_val >= confidence_threshold:
|
||||
# Centre du template
|
||||
h, w = gray_template.shape
|
||||
x = max_loc[0] + w // 2
|
||||
y = max_loc[1] + h // 2
|
||||
|
||||
return (x, y, float(max_val))
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "template_matching_failed",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return None
|
||||
|
||||
def _embedding_search(
|
||||
self,
|
||||
screenshot: np.ndarray,
|
||||
target_embedding: np.ndarray,
|
||||
region_size: int = 100,
|
||||
confidence_threshold: float = 0.8,
|
||||
step: int = 20
|
||||
) -> Optional[Tuple[int, int, float]]:
|
||||
"""
|
||||
Recherche par similarité d'embeddings (robuste mais lent).
|
||||
|
||||
Args:
|
||||
step: Pas de la fenêtre glissante (20 = rapide, 10 = précis)
|
||||
|
||||
Returns:
|
||||
(x, y, confidence) ou None
|
||||
"""
|
||||
if target_embedding is None:
|
||||
return None
|
||||
|
||||
h, w = screenshot.shape[:2]
|
||||
half_size = region_size // 2
|
||||
|
||||
best_position = None
|
||||
best_similarity = 0.0
|
||||
|
||||
# Fenêtre glissante
|
||||
for y in range(half_size, h - half_size, step):
|
||||
for x in range(half_size, w - half_size, step):
|
||||
# Extraire région
|
||||
x1 = x - half_size
|
||||
y1 = y - half_size
|
||||
x2 = x + half_size
|
||||
y2 = y + half_size
|
||||
|
||||
region = screenshot[y1:y2, x1:x2]
|
||||
|
||||
# Redimensionner si nécessaire
|
||||
if region.shape[0] != region_size or region.shape[1] != region_size:
|
||||
region = cv2.resize(region, (region_size, region_size))
|
||||
|
||||
# Calculer embedding
|
||||
try:
|
||||
embedding = self.embeddings.encode_image(region)
|
||||
|
||||
# Similarité cosinus
|
||||
similarity = np.dot(embedding, target_embedding) / (
|
||||
np.linalg.norm(embedding) * np.linalg.norm(target_embedding)
|
||||
)
|
||||
|
||||
if similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_position = (x, y)
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if best_similarity >= confidence_threshold:
|
||||
return (*best_position, float(best_similarity))
|
||||
|
||||
return None
|
||||
|
||||
def find_in_region(
|
||||
self,
|
||||
screenshot: np.ndarray,
|
||||
target_signature: Dict[str, Any],
|
||||
search_region: Tuple[int, int, int, int],
|
||||
confidence_threshold: float = 0.8
|
||||
) -> Optional[Tuple[int, int, float]]:
|
||||
"""
|
||||
Recherche dans une région spécifique (optimisation).
|
||||
|
||||
Args:
|
||||
search_region: (x1, y1, x2, y2) région de recherche
|
||||
|
||||
Returns:
|
||||
(x, y, confidence) ou None
|
||||
"""
|
||||
x1, y1, x2, y2 = search_region
|
||||
|
||||
# Extraire la région de recherche
|
||||
region_screenshot = screenshot[y1:y2, x1:x2]
|
||||
|
||||
# Chercher dans cette région
|
||||
result = self.find_element(
|
||||
region_screenshot,
|
||||
target_signature,
|
||||
confidence_threshold
|
||||
)
|
||||
|
||||
if result:
|
||||
# Ajuster les coordonnées
|
||||
x, y, conf = result
|
||||
return (x + x1, y + y1, conf)
|
||||
|
||||
return None
|
||||
665
geniusia2/core/whitelist_manager.py
Normal file
665
geniusia2/core/whitelist_manager.py
Normal file
@@ -0,0 +1,665 @@
|
||||
"""
|
||||
Gestionnaire de liste blanche pour RPA Vision V2
|
||||
Gère la liste des fenêtres d'application autorisées pour l'automatisation
|
||||
avec persistance et confirmation administrateur pour les modifications.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
from .logger import Logger
|
||||
from .config import get_data_paths, get_security_config
|
||||
except ImportError:
|
||||
# Pour tests standalone
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from logger import Logger
|
||||
from config import get_data_paths, get_security_config
|
||||
|
||||
|
||||
class WhitelistManager:
|
||||
"""
|
||||
Gestionnaire de liste blanche pour contrôler les fenêtres autorisées
|
||||
|
||||
Attributes:
|
||||
whitelist_path: Chemin du fichier de liste blanche
|
||||
whitelist: Liste des patterns de fenêtres autorisées
|
||||
logger: Logger pour journalisation des modifications
|
||||
require_admin_confirmation: Si True, nécessite confirmation pour ajouts
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
whitelist_path: Optional[str] = None,
|
||||
logger: Optional[Logger] = None,
|
||||
require_admin_confirmation: bool = True
|
||||
):
|
||||
"""
|
||||
Initialise le gestionnaire de liste blanche
|
||||
|
||||
Args:
|
||||
whitelist_path: Chemin du fichier de liste blanche (utilise config par défaut si None)
|
||||
logger: Logger pour journalisation (crée un nouveau si None)
|
||||
require_admin_confirmation: Si True, nécessite confirmation pour ajouts
|
||||
"""
|
||||
# Configuration des chemins
|
||||
data_paths = get_data_paths()
|
||||
if whitelist_path:
|
||||
self.whitelist_path = Path(whitelist_path)
|
||||
else:
|
||||
# Utiliser le répertoire user_profiles pour stocker la liste blanche
|
||||
profiles_dir = Path(data_paths["user_profiles"])
|
||||
profiles_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.whitelist_path = profiles_dir / "whitelist.json"
|
||||
|
||||
# Logger
|
||||
self.logger = logger or Logger()
|
||||
|
||||
# Configuration
|
||||
self.require_admin_confirmation = require_admin_confirmation
|
||||
self.security_config = get_security_config()
|
||||
|
||||
# Charger la liste blanche
|
||||
self.whitelist: List[str] = []
|
||||
self.metadata: Dict[str, Any] = {}
|
||||
self.load_whitelist()
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_manager_initialized",
|
||||
"whitelist_path": str(self.whitelist_path),
|
||||
"whitelist_size": len(self.whitelist),
|
||||
"require_admin_confirmation": self.require_admin_confirmation
|
||||
})
|
||||
|
||||
def is_window_allowed(self, window_title: str) -> bool:
|
||||
"""
|
||||
Vérifie si une fenêtre est autorisée dans la liste blanche
|
||||
|
||||
Args:
|
||||
window_title: Titre de la fenêtre à vérifier
|
||||
|
||||
Returns:
|
||||
True si la fenêtre est autorisée, False sinon
|
||||
"""
|
||||
if not window_title:
|
||||
return False
|
||||
|
||||
# Si la liste blanche est vide, tout est bloqué par défaut
|
||||
if not self.whitelist:
|
||||
self.logger.log_security_event({
|
||||
"event_type": "whitelist_check",
|
||||
"window": window_title,
|
||||
"allowed": False,
|
||||
"reason": "empty_whitelist"
|
||||
})
|
||||
return False
|
||||
|
||||
# Vérifier si le titre correspond à un pattern de la liste blanche
|
||||
window_title_lower = window_title.lower()
|
||||
|
||||
for pattern in self.whitelist:
|
||||
pattern_lower = pattern.lower()
|
||||
|
||||
# Support pour wildcards simples
|
||||
if pattern_lower.endswith("*"):
|
||||
# Pattern de type "Firefox*" - correspondance de préfixe
|
||||
prefix = pattern_lower[:-1]
|
||||
if window_title_lower.startswith(prefix):
|
||||
return True
|
||||
elif pattern_lower.startswith("*"):
|
||||
# Pattern de type "*Firefox" - correspondance de suffixe
|
||||
suffix = pattern_lower[1:]
|
||||
if window_title_lower.endswith(suffix):
|
||||
return True
|
||||
elif "*" in pattern_lower:
|
||||
# Pattern de type "Fire*fox" - correspondance avec wildcard au milieu
|
||||
parts = pattern_lower.split("*")
|
||||
if all(part in window_title_lower for part in parts):
|
||||
# Vérifier l'ordre des parties
|
||||
pos = 0
|
||||
for part in parts:
|
||||
idx = window_title_lower.find(part, pos)
|
||||
if idx == -1:
|
||||
break
|
||||
pos = idx + len(part)
|
||||
else:
|
||||
return True
|
||||
else:
|
||||
# Correspondance exacte (insensible à la casse)
|
||||
if pattern_lower == window_title_lower:
|
||||
return True
|
||||
# Ou correspondance partielle
|
||||
if pattern_lower in window_title_lower:
|
||||
return True
|
||||
|
||||
# Aucune correspondance trouvée
|
||||
self.logger.log_security_event({
|
||||
"event_type": "whitelist_check",
|
||||
"window": window_title,
|
||||
"allowed": False,
|
||||
"reason": "no_match"
|
||||
})
|
||||
|
||||
return False
|
||||
|
||||
def add_to_whitelist(
|
||||
self,
|
||||
window_title: str,
|
||||
admin_confirmed: bool = False,
|
||||
added_by: str = "user"
|
||||
) -> bool:
|
||||
"""
|
||||
Ajoute une fenêtre à la liste blanche avec confirmation admin optionnelle
|
||||
|
||||
Args:
|
||||
window_title: Pattern de titre de fenêtre à autoriser
|
||||
admin_confirmed: Si True, bypass la confirmation admin
|
||||
added_by: Identifiant de l'utilisateur qui ajoute (pour audit)
|
||||
|
||||
Returns:
|
||||
True si ajouté avec succès, False sinon
|
||||
"""
|
||||
if not window_title or not window_title.strip():
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_add_failed",
|
||||
"reason": "empty_window_title"
|
||||
})
|
||||
return False
|
||||
|
||||
window_title = window_title.strip()
|
||||
|
||||
# Vérifier si déjà dans la liste blanche
|
||||
if window_title in self.whitelist:
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_add_skipped",
|
||||
"window": window_title,
|
||||
"reason": "already_exists"
|
||||
})
|
||||
return True # Déjà présent, considéré comme succès
|
||||
|
||||
# Vérifier la confirmation admin si requise
|
||||
if self.require_admin_confirmation and not admin_confirmed:
|
||||
self.logger.log_security_event({
|
||||
"event_type": "whitelist_add_pending",
|
||||
"window": window_title,
|
||||
"added_by": added_by,
|
||||
"details": "Confirmation administrateur requise"
|
||||
})
|
||||
return False
|
||||
|
||||
# Ajouter à la liste blanche
|
||||
self.whitelist.append(window_title)
|
||||
|
||||
# Mettre à jour les métadonnées
|
||||
if "entries" not in self.metadata:
|
||||
self.metadata["entries"] = {}
|
||||
|
||||
self.metadata["entries"][window_title] = {
|
||||
"added_at": datetime.now().isoformat(),
|
||||
"added_by": added_by,
|
||||
"admin_confirmed": admin_confirmed
|
||||
}
|
||||
|
||||
# Sauvegarder
|
||||
self.save_whitelist()
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_add_success",
|
||||
"window": window_title,
|
||||
"added_by": added_by,
|
||||
"admin_confirmed": admin_confirmed,
|
||||
"whitelist_size": len(self.whitelist)
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
def remove_from_whitelist(self, window_title: str) -> bool:
|
||||
"""
|
||||
Retire une fenêtre de la liste blanche
|
||||
|
||||
Args:
|
||||
window_title: Pattern de titre de fenêtre à retirer
|
||||
|
||||
Returns:
|
||||
True si retiré avec succès, False si non trouvé
|
||||
"""
|
||||
if window_title not in self.whitelist:
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_remove_failed",
|
||||
"window": window_title,
|
||||
"reason": "not_found"
|
||||
})
|
||||
return False
|
||||
|
||||
# Retirer de la liste blanche
|
||||
self.whitelist.remove(window_title)
|
||||
|
||||
# Retirer des métadonnées
|
||||
if "entries" in self.metadata and window_title in self.metadata["entries"]:
|
||||
del self.metadata["entries"][window_title]
|
||||
|
||||
# Sauvegarder
|
||||
self.save_whitelist()
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_remove_success",
|
||||
"window": window_title,
|
||||
"whitelist_size": len(self.whitelist)
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
def load_whitelist(self) -> bool:
|
||||
"""
|
||||
Charge la liste blanche depuis le fichier
|
||||
|
||||
Returns:
|
||||
True si chargé avec succès, False sinon
|
||||
"""
|
||||
try:
|
||||
if not self.whitelist_path.exists():
|
||||
# Créer une liste blanche par défaut
|
||||
self.whitelist = self.security_config.get("default_whitelist", [])
|
||||
self.metadata = {
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"version": "1.0",
|
||||
"entries": {}
|
||||
}
|
||||
|
||||
# Sauvegarder la liste blanche par défaut
|
||||
self.save_whitelist()
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_created",
|
||||
"path": str(self.whitelist_path),
|
||||
"default_size": len(self.whitelist)
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
# Charger depuis le fichier
|
||||
with open(self.whitelist_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
self.whitelist = data.get("whitelist", [])
|
||||
self.metadata = data.get("metadata", {})
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_loaded",
|
||||
"path": str(self.whitelist_path),
|
||||
"whitelist_size": len(self.whitelist)
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_load_failed",
|
||||
"path": str(self.whitelist_path),
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# En cas d'erreur, utiliser une liste vide
|
||||
self.whitelist = []
|
||||
self.metadata = {}
|
||||
|
||||
return False
|
||||
|
||||
def save_whitelist(self) -> bool:
|
||||
"""
|
||||
Sauvegarde la liste blanche dans le fichier
|
||||
|
||||
Returns:
|
||||
True si sauvegardé avec succès, False sinon
|
||||
"""
|
||||
try:
|
||||
# Mettre à jour les métadonnées
|
||||
self.metadata["last_modified"] = datetime.now().isoformat()
|
||||
self.metadata["version"] = self.metadata.get("version", "1.0")
|
||||
|
||||
# Préparer les données
|
||||
data = {
|
||||
"whitelist": self.whitelist,
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
# S'assurer que le répertoire parent existe
|
||||
self.whitelist_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Sauvegarder dans le fichier
|
||||
with open(self.whitelist_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Restreindre les permissions (lecture/écriture propriétaire uniquement)
|
||||
os.chmod(self.whitelist_path, 0o600)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_saved",
|
||||
"path": str(self.whitelist_path),
|
||||
"whitelist_size": len(self.whitelist)
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_save_failed",
|
||||
"path": str(self.whitelist_path),
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return False
|
||||
|
||||
def get_whitelist(self) -> List[str]:
|
||||
"""
|
||||
Retourne une copie de la liste blanche actuelle
|
||||
|
||||
Returns:
|
||||
Liste des patterns de fenêtres autorisées
|
||||
"""
|
||||
return self.whitelist.copy()
|
||||
|
||||
def clear_whitelist(self) -> bool:
|
||||
"""
|
||||
Vide complètement la liste blanche
|
||||
|
||||
Returns:
|
||||
True si vidé avec succès
|
||||
"""
|
||||
old_size = len(self.whitelist)
|
||||
self.whitelist = []
|
||||
self.metadata["entries"] = {}
|
||||
|
||||
self.save_whitelist()
|
||||
|
||||
self.logger.log_security_event({
|
||||
"event_type": "whitelist_cleared",
|
||||
"previous_size": old_size,
|
||||
"details": "Liste blanche complètement vidée"
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
def get_metadata(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Retourne les métadonnées de la liste blanche
|
||||
|
||||
Returns:
|
||||
Dictionnaire contenant les métadonnées
|
||||
"""
|
||||
return self.metadata.copy()
|
||||
|
||||
def get_entry_info(self, window_title: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Retourne les informations sur une entrée spécifique
|
||||
|
||||
Args:
|
||||
window_title: Pattern de fenêtre à rechercher
|
||||
|
||||
Returns:
|
||||
Dictionnaire avec les informations ou None si non trouvé
|
||||
"""
|
||||
if window_title not in self.whitelist:
|
||||
return None
|
||||
|
||||
entry_metadata = self.metadata.get("entries", {}).get(window_title, {})
|
||||
|
||||
return {
|
||||
"window_title": window_title,
|
||||
"in_whitelist": True,
|
||||
**entry_metadata
|
||||
}
|
||||
|
||||
def export_whitelist(self, export_path: str) -> bool:
|
||||
"""
|
||||
Exporte la liste blanche vers un fichier
|
||||
|
||||
Args:
|
||||
export_path: Chemin du fichier d'export
|
||||
|
||||
Returns:
|
||||
True si exporté avec succès
|
||||
"""
|
||||
try:
|
||||
export_data = {
|
||||
"whitelist": self.whitelist,
|
||||
"metadata": self.metadata,
|
||||
"exported_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
with open(export_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(export_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_exported",
|
||||
"export_path": export_path,
|
||||
"whitelist_size": len(self.whitelist)
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_export_failed",
|
||||
"export_path": export_path,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return False
|
||||
|
||||
def import_whitelist(
|
||||
self,
|
||||
import_path: str,
|
||||
merge: bool = False,
|
||||
admin_confirmed: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
Importe une liste blanche depuis un fichier
|
||||
|
||||
Args:
|
||||
import_path: Chemin du fichier à importer
|
||||
merge: Si True, fusionne avec la liste existante; sinon remplace
|
||||
admin_confirmed: Si True, bypass la confirmation admin
|
||||
|
||||
Returns:
|
||||
True si importé avec succès
|
||||
"""
|
||||
try:
|
||||
with open(import_path, 'r', encoding='utf-8') as f:
|
||||
import_data = json.load(f)
|
||||
|
||||
imported_whitelist = import_data.get("whitelist", [])
|
||||
|
||||
if not merge:
|
||||
# Remplacer complètement
|
||||
self.whitelist = imported_whitelist
|
||||
self.metadata = import_data.get("metadata", {})
|
||||
else:
|
||||
# Fusionner
|
||||
for window_title in imported_whitelist:
|
||||
if window_title not in self.whitelist:
|
||||
self.add_to_whitelist(
|
||||
window_title,
|
||||
admin_confirmed=admin_confirmed,
|
||||
added_by="import"
|
||||
)
|
||||
|
||||
self.save_whitelist()
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_imported",
|
||||
"import_path": import_path,
|
||||
"merge": merge,
|
||||
"whitelist_size": len(self.whitelist)
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "whitelist_import_failed",
|
||||
"import_path": import_path,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return False
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Retourne des statistiques sur la liste blanche
|
||||
|
||||
Returns:
|
||||
Dictionnaire contenant les statistiques
|
||||
"""
|
||||
stats = {
|
||||
"total_entries": len(self.whitelist),
|
||||
"created_at": self.metadata.get("created_at"),
|
||||
"last_modified": self.metadata.get("last_modified"),
|
||||
"version": self.metadata.get("version"),
|
||||
"entries_with_wildcards": sum(1 for w in self.whitelist if "*" in w),
|
||||
"entries_exact": sum(1 for w in self.whitelist if "*" not in w)
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Tests du gestionnaire de liste blanche"""
|
||||
print("Test du WhitelistManager RPA Vision V2")
|
||||
print("=" * 50)
|
||||
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
# Créer un répertoire de test
|
||||
test_dir = tempfile.mkdtemp()
|
||||
test_whitelist_path = os.path.join(test_dir, "whitelist.json")
|
||||
|
||||
try:
|
||||
# Test 1: Initialisation
|
||||
print("\n1. Test d'initialisation:")
|
||||
manager = WhitelistManager(
|
||||
whitelist_path=test_whitelist_path,
|
||||
require_admin_confirmation=False # Désactiver pour les tests
|
||||
)
|
||||
print(f" ✓ WhitelistManager initialisé")
|
||||
print(f" - Chemin: {manager.whitelist_path}")
|
||||
print(f" - Taille initiale: {len(manager.whitelist)}")
|
||||
|
||||
# Test 2: Ajout d'entrées
|
||||
print("\n2. Test d'ajout d'entrées:")
|
||||
test_windows = [
|
||||
"Dolibarr*",
|
||||
"Firefox*",
|
||||
"*Chrome",
|
||||
"Visual Studio Code"
|
||||
]
|
||||
|
||||
for window in test_windows:
|
||||
success = manager.add_to_whitelist(window, admin_confirmed=True)
|
||||
status = "✓" if success else "✗"
|
||||
print(f" {status} Ajouté: {window}")
|
||||
|
||||
print(f" - Taille après ajouts: {len(manager.whitelist)}")
|
||||
|
||||
# Test 3: Vérification is_window_allowed
|
||||
print("\n3. Test is_window_allowed:")
|
||||
test_cases = [
|
||||
("Dolibarr - Facturation", True),
|
||||
("Firefox - Mozilla", True),
|
||||
("Google Chrome", True),
|
||||
("Visual Studio Code", True),
|
||||
("Unknown Application", False),
|
||||
("Notepad", False)
|
||||
]
|
||||
|
||||
for window, expected in test_cases:
|
||||
allowed = manager.is_window_allowed(window)
|
||||
status = "✓" if allowed == expected else "✗"
|
||||
result = "Autorisé" if allowed else "Bloqué"
|
||||
print(f" {status} {result}: {window}")
|
||||
|
||||
# Test 4: Sauvegarde et rechargement
|
||||
print("\n4. Test sauvegarde et rechargement:")
|
||||
manager.save_whitelist()
|
||||
print(" ✓ Liste blanche sauvegardée")
|
||||
|
||||
# Créer un nouveau manager pour tester le chargement
|
||||
manager2 = WhitelistManager(
|
||||
whitelist_path=test_whitelist_path,
|
||||
require_admin_confirmation=False
|
||||
)
|
||||
print(f" ✓ Liste blanche rechargée")
|
||||
print(f" - Taille après rechargement: {len(manager2.whitelist)}")
|
||||
assert len(manager2.whitelist) == len(test_windows), "Taille incorrecte après rechargement!"
|
||||
|
||||
# Test 5: Suppression d'entrée
|
||||
print("\n5. Test de suppression:")
|
||||
removed = manager.remove_from_whitelist("Firefox*")
|
||||
print(f" ✓ Supprimé: Firefox* (succès={removed})")
|
||||
print(f" - Taille après suppression: {len(manager.whitelist)}")
|
||||
|
||||
# Test 6: Métadonnées
|
||||
print("\n6. Test des métadonnées:")
|
||||
metadata = manager.get_metadata()
|
||||
print(f" ✓ Métadonnées récupérées:")
|
||||
print(f" - Version: {metadata.get('version')}")
|
||||
print(f" - Dernière modification: {metadata.get('last_modified')}")
|
||||
|
||||
# Test 7: Informations sur une entrée
|
||||
print("\n7. Test get_entry_info:")
|
||||
info = manager.get_entry_info("Dolibarr*")
|
||||
if info:
|
||||
print(f" ✓ Informations pour 'Dolibarr*':")
|
||||
print(f" - Ajouté le: {info.get('added_at')}")
|
||||
print(f" - Ajouté par: {info.get('added_by')}")
|
||||
|
||||
# Test 8: Statistiques
|
||||
print("\n8. Test des statistiques:")
|
||||
stats = manager.get_statistics()
|
||||
print(f" ✓ Statistiques:")
|
||||
print(f" - Total entrées: {stats['total_entries']}")
|
||||
print(f" - Avec wildcards: {stats['entries_with_wildcards']}")
|
||||
print(f" - Exactes: {stats['entries_exact']}")
|
||||
|
||||
# Test 9: Export/Import
|
||||
print("\n9. Test export/import:")
|
||||
export_path = os.path.join(test_dir, "whitelist_export.json")
|
||||
manager.export_whitelist(export_path)
|
||||
print(f" ✓ Exporté vers: {export_path}")
|
||||
|
||||
# Vider et réimporter
|
||||
manager.clear_whitelist()
|
||||
print(f" ✓ Liste blanche vidée (taille={len(manager.whitelist)})")
|
||||
|
||||
manager.import_whitelist(export_path, admin_confirmed=True)
|
||||
print(f" ✓ Réimporté (taille={len(manager.whitelist)})")
|
||||
|
||||
# Test 10: Patterns avec wildcards
|
||||
print("\n10. Test patterns avec wildcards:")
|
||||
manager.clear_whitelist()
|
||||
manager.add_to_whitelist("Fire*fox", admin_confirmed=True)
|
||||
|
||||
test_wildcard_cases = [
|
||||
("Firefox", True),
|
||||
("Firefoxes", True),
|
||||
("Fire123fox", True),
|
||||
("Chrome", False)
|
||||
]
|
||||
|
||||
for window, expected in test_wildcard_cases:
|
||||
allowed = manager.is_window_allowed(window)
|
||||
status = "✓" if allowed == expected else "✗"
|
||||
result = "Autorisé" if allowed else "Bloqué"
|
||||
print(f" {status} {result}: {window}")
|
||||
|
||||
print("\n✓ Tous les tests réussis!")
|
||||
|
||||
finally:
|
||||
# Nettoyer les fichiers de test
|
||||
shutil.rmtree(test_dir)
|
||||
print(f"\n✓ Fichiers de test nettoyés")
|
||||
492
geniusia2/core/workflow_detector.py
Normal file
492
geniusia2/core/workflow_detector.py
Normal file
@@ -0,0 +1,492 @@
|
||||
"""
|
||||
Détecteur de workflows pour identifier les séquences répétitives.
|
||||
Analyse les sessions pour détecter des patterns de workflows.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowStep:
|
||||
"""Représente une étape dans un workflow."""
|
||||
step_id: int
|
||||
action_type: str
|
||||
target_description: str
|
||||
position: tuple
|
||||
window: str
|
||||
embedding: Optional[np.ndarray] = None
|
||||
screenshot: Optional[np.ndarray] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Workflow:
|
||||
"""Représente un workflow détecté."""
|
||||
workflow_id: str
|
||||
name: str
|
||||
steps: List[WorkflowStep]
|
||||
repetitions: int
|
||||
confidence: float
|
||||
last_seen: datetime
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class WorkflowDetector:
|
||||
"""
|
||||
Détecteur de workflows pour identifier les séquences répétitives.
|
||||
"""
|
||||
|
||||
def __init__(self, logger, config: Dict[str, Any]):
|
||||
"""
|
||||
Initialise le détecteur de workflows.
|
||||
|
||||
Args:
|
||||
logger: Logger pour journalisation
|
||||
config: Configuration globale
|
||||
"""
|
||||
self.logger = logger
|
||||
self.config = config
|
||||
|
||||
# Configuration
|
||||
self.min_repetitions = config.get("workflow", {}).get(
|
||||
"min_repetitions", 3 # 3 répétitions minimum
|
||||
)
|
||||
self.similarity_threshold = config.get("workflow", {}).get(
|
||||
"similarity_threshold", 0.75 # 75% de similarité
|
||||
)
|
||||
self.min_workflow_length = config.get("workflow", {}).get(
|
||||
"min_workflow_length", 3 # 3 actions minimum pour un workflow
|
||||
)
|
||||
|
||||
# Workflows détectés
|
||||
self.workflows: List[Workflow] = []
|
||||
|
||||
# Buffer de sessions récentes pour analyse
|
||||
self.recent_sessions: List[Any] = []
|
||||
self.max_recent_sessions = 10 # Garder les 10 dernières sessions
|
||||
|
||||
# Callback pour workflow détecté
|
||||
self.on_workflow_detected: Optional[Callable] = None
|
||||
|
||||
# Répertoire de persistence
|
||||
self.workflows_dir = Path(config.get("data_dir", "data")) / "user_profiles" / "workflows"
|
||||
self.workflows_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Charger les workflows existants
|
||||
self._load_workflows()
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "workflow_detector_initialized",
|
||||
"min_repetitions": self.min_repetitions,
|
||||
"similarity_threshold": self.similarity_threshold,
|
||||
"workflows_loaded": len(self.workflows)
|
||||
})
|
||||
|
||||
def analyze_session(self, session: Any):
|
||||
"""
|
||||
Analyse une session individuelle et l'ajoute au buffer.
|
||||
Détecte des workflows si suffisamment de sessions similaires.
|
||||
|
||||
Args:
|
||||
session: Session à analyser
|
||||
"""
|
||||
# Ajouter au buffer
|
||||
self.recent_sessions.append(session)
|
||||
|
||||
# Limiter la taille du buffer
|
||||
if len(self.recent_sessions) > self.max_recent_sessions:
|
||||
self.recent_sessions.pop(0)
|
||||
|
||||
# Analyser si on a assez de sessions
|
||||
if len(self.recent_sessions) >= self.min_repetitions:
|
||||
self.analyze_sessions(self.recent_sessions)
|
||||
|
||||
def analyze_sessions(self, sessions: List[Any]):
|
||||
"""
|
||||
Analyse les sessions pour détecter des workflows.
|
||||
|
||||
Args:
|
||||
sessions: Liste de sessions à analyser
|
||||
"""
|
||||
if len(sessions) < self.min_repetitions:
|
||||
return
|
||||
|
||||
# Grouper les sessions par fenêtre
|
||||
sessions_by_window = {}
|
||||
for session in sessions:
|
||||
window = session.window or "Unknown"
|
||||
if window not in sessions_by_window:
|
||||
sessions_by_window[window] = []
|
||||
sessions_by_window[window].append(session)
|
||||
|
||||
# Analyser chaque groupe de fenêtre séparément
|
||||
for window, window_sessions in sessions_by_window.items():
|
||||
if len(window_sessions) < self.min_repetitions:
|
||||
continue
|
||||
|
||||
# Comparer les sessions de cette fenêtre
|
||||
for i in range(len(window_sessions) - self.min_repetitions + 1):
|
||||
# Prendre N sessions consécutives
|
||||
session_group = window_sessions[i:i + self.min_repetitions]
|
||||
|
||||
# Calculer la similarité entre elles
|
||||
similarity = self._calculate_session_similarity(session_group)
|
||||
|
||||
if similarity >= self.similarity_threshold:
|
||||
# Workflow détecté !
|
||||
workflow = self._create_workflow_from_sessions(session_group)
|
||||
|
||||
# Ignorer les workflows trop courts (probablement du bruit)
|
||||
if len(workflow.steps) < self.min_workflow_length:
|
||||
continue
|
||||
|
||||
# Vérifier si ce workflow existe déjà
|
||||
existing = self._find_existing_workflow(workflow)
|
||||
|
||||
if existing:
|
||||
# Mettre à jour
|
||||
existing.repetitions += 1
|
||||
existing.last_seen = datetime.now()
|
||||
existing.confidence = min(1.0, existing.confidence + 0.05)
|
||||
|
||||
# Sauvegarder les modifications
|
||||
self._save_workflow(existing)
|
||||
else:
|
||||
# Nouveau workflow
|
||||
self.workflows.append(workflow)
|
||||
|
||||
# Sauvegarder sur disque
|
||||
self._save_workflow(workflow)
|
||||
|
||||
# Callback avec dictionnaire
|
||||
if self.on_workflow_detected:
|
||||
workflow_dict = {
|
||||
"workflow_id": workflow.workflow_id,
|
||||
"name": workflow.name,
|
||||
"pattern": [step.action_type for step in workflow.steps],
|
||||
"steps": workflow.steps,
|
||||
"confidence": workflow.confidence,
|
||||
"repetitions": workflow.repetitions,
|
||||
"last_seen": workflow.last_seen,
|
||||
"window": workflow.steps[0].window if workflow.steps else ""
|
||||
}
|
||||
self.on_workflow_detected(workflow_dict)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "workflow_detected",
|
||||
"workflow_id": workflow.workflow_id,
|
||||
"workflow_name": workflow.name,
|
||||
"steps": len(workflow.steps),
|
||||
"confidence": workflow.confidence
|
||||
})
|
||||
|
||||
def _calculate_session_similarity(self, sessions: List[Any]) -> float:
|
||||
"""
|
||||
Calcule la similarité entre plusieurs sessions.
|
||||
|
||||
Args:
|
||||
sessions: Sessions à comparer
|
||||
|
||||
Returns:
|
||||
Score de similarité (0-1)
|
||||
"""
|
||||
if len(sessions) < 2:
|
||||
return 0.0
|
||||
|
||||
# Comparer chaque paire de sessions
|
||||
similarities = []
|
||||
|
||||
for i in range(len(sessions) - 1):
|
||||
for j in range(i + 1, len(sessions)):
|
||||
sim = self._compare_two_sessions(sessions[i], sessions[j])
|
||||
similarities.append(sim)
|
||||
|
||||
# Moyenne des similarités
|
||||
return np.mean(similarities) if similarities else 0.0
|
||||
|
||||
def _compare_two_sessions(self, session1: Any, session2: Any) -> float:
|
||||
"""
|
||||
Compare deux sessions.
|
||||
|
||||
Args:
|
||||
session1: Première session
|
||||
session2: Deuxième session
|
||||
|
||||
Returns:
|
||||
Score de similarité (0-1)
|
||||
"""
|
||||
actions1 = session1.actions
|
||||
actions2 = session2.actions
|
||||
|
||||
# Vérifier la longueur
|
||||
if abs(len(actions1) - len(actions2)) > 2:
|
||||
return 0.0
|
||||
|
||||
# Comparer les actions
|
||||
min_len = min(len(actions1), len(actions2))
|
||||
matches = 0
|
||||
|
||||
for i in range(min_len):
|
||||
a1 = actions1[i]
|
||||
a2 = actions2[i]
|
||||
|
||||
# Comparer le type d'action (peut être "type" ou "action_type")
|
||||
type1 = a1.get("action_type") or a1.get("type")
|
||||
type2 = a2.get("action_type") or a2.get("type")
|
||||
|
||||
if type1 == type2:
|
||||
matches += 1
|
||||
|
||||
# Bonus si même fenêtre
|
||||
if a1.get("window") == a2.get("window"):
|
||||
matches += 0.5
|
||||
|
||||
# Score normalisé
|
||||
max_score = min_len * 1.5
|
||||
return matches / max_score if max_score > 0 else 0.0
|
||||
|
||||
def _create_workflow_from_sessions(self, sessions: List[Any]) -> Workflow:
|
||||
"""
|
||||
Crée un workflow à partir de sessions similaires.
|
||||
|
||||
Args:
|
||||
sessions: Sessions à fusionner
|
||||
|
||||
Returns:
|
||||
Workflow créé
|
||||
"""
|
||||
# Prendre la première session comme référence
|
||||
reference = sessions[0]
|
||||
|
||||
# Créer les étapes
|
||||
steps = []
|
||||
for i, action in enumerate(reference.actions):
|
||||
# Gérer les deux formats d'action
|
||||
action_type = action.get("action_type") or action.get("type", "unknown")
|
||||
position = action.get("position") or (action.get("x", 0), action.get("y", 0))
|
||||
|
||||
step = WorkflowStep(
|
||||
step_id=i,
|
||||
action_type=action_type,
|
||||
target_description=action.get("description", ""),
|
||||
position=position,
|
||||
window=action.get("window", ""),
|
||||
embedding=action.get("embedding")
|
||||
)
|
||||
steps.append(step)
|
||||
|
||||
# Générer un nom
|
||||
workflow_name = self._generate_workflow_name(steps)
|
||||
|
||||
# Créer le workflow
|
||||
workflow_id = f"workflow_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
|
||||
return Workflow(
|
||||
workflow_id=workflow_id,
|
||||
name=workflow_name,
|
||||
steps=steps,
|
||||
repetitions=len(sessions),
|
||||
confidence=0.8, # Confiance initiale
|
||||
last_seen=datetime.now()
|
||||
)
|
||||
|
||||
def _generate_workflow_name(self, steps: List[WorkflowStep]) -> str:
|
||||
"""
|
||||
Génère un nom pour le workflow.
|
||||
|
||||
Args:
|
||||
steps: Étapes du workflow
|
||||
|
||||
Returns:
|
||||
Nom généré
|
||||
"""
|
||||
if not steps:
|
||||
return "Workflow vide"
|
||||
|
||||
# Utiliser les premières actions
|
||||
first_actions = [s.action_type for s in steps[:3]]
|
||||
|
||||
# Créer un nom descriptif
|
||||
if len(steps) == 1:
|
||||
return f"{steps[0].action_type.capitalize()}"
|
||||
elif len(steps) <= 3:
|
||||
return " → ".join(a.capitalize() for a in first_actions)
|
||||
else:
|
||||
return f"{' → '.join(a.capitalize() for a in first_actions)} (+ {len(steps) - 3} étapes)"
|
||||
|
||||
def _find_existing_workflow(self, workflow: Workflow) -> Optional[Workflow]:
|
||||
"""
|
||||
Cherche un workflow existant similaire.
|
||||
|
||||
Args:
|
||||
workflow: Workflow à chercher
|
||||
|
||||
Returns:
|
||||
Workflow existant ou None
|
||||
"""
|
||||
for existing in self.workflows:
|
||||
# Comparer le nombre d'étapes
|
||||
if len(existing.steps) != len(workflow.steps):
|
||||
continue
|
||||
|
||||
# Comparer les types d'actions
|
||||
matches = 0
|
||||
for i in range(len(existing.steps)):
|
||||
if existing.steps[i].action_type == workflow.steps[i].action_type:
|
||||
matches += 1
|
||||
|
||||
similarity = matches / len(existing.steps)
|
||||
|
||||
if similarity >= 0.8: # 80% de correspondance
|
||||
return existing
|
||||
|
||||
return None
|
||||
|
||||
def get_workflows(self) -> List[Workflow]:
|
||||
"""Retourne tous les workflows détectés."""
|
||||
return self.workflows
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Retourne les statistiques des workflows."""
|
||||
return {
|
||||
"total_workflows": len(self.workflows),
|
||||
"avg_steps": np.mean([len(w.steps) for w in self.workflows]) if self.workflows else 0,
|
||||
"avg_repetitions": np.mean([w.repetitions for w in self.workflows]) if self.workflows else 0,
|
||||
"avg_confidence": np.mean([w.confidence for w in self.workflows]) if self.workflows else 0
|
||||
}
|
||||
|
||||
def _save_workflow(self, workflow: Workflow):
|
||||
"""
|
||||
Sauvegarde un workflow sur disque.
|
||||
|
||||
Args:
|
||||
workflow: Workflow à sauvegarder
|
||||
"""
|
||||
try:
|
||||
workflow_file = self.workflows_dir / f"{workflow.workflow_id}.json"
|
||||
|
||||
# Convertir en dictionnaire
|
||||
workflow_dict = {
|
||||
"workflow_id": workflow.workflow_id,
|
||||
"name": workflow.name,
|
||||
"repetitions": workflow.repetitions,
|
||||
"confidence": workflow.confidence,
|
||||
"last_seen": workflow.last_seen.isoformat(),
|
||||
"created_at": workflow.created_at.isoformat(),
|
||||
"steps": [
|
||||
{
|
||||
"step_id": step.step_id,
|
||||
"action_type": step.action_type,
|
||||
"target_description": step.target_description,
|
||||
"position": step.position,
|
||||
"window": step.window
|
||||
# Note: on ne sauvegarde pas embedding et screenshot (trop gros)
|
||||
}
|
||||
for step in workflow.steps
|
||||
]
|
||||
}
|
||||
|
||||
# Sauvegarder
|
||||
with open(workflow_file, 'w') as f:
|
||||
json.dump(workflow_dict, f, indent=2)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "workflow_saved",
|
||||
"workflow_id": workflow.workflow_id,
|
||||
"file": str(workflow_file)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "workflow_save_failed",
|
||||
"workflow_id": workflow.workflow_id,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
def _load_workflows(self):
|
||||
"""Charge les workflows depuis le disque."""
|
||||
try:
|
||||
if not self.workflows_dir.exists():
|
||||
return
|
||||
|
||||
workflow_files = list(self.workflows_dir.glob("*.json"))
|
||||
|
||||
for workflow_file in workflow_files:
|
||||
try:
|
||||
with open(workflow_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Reconstruire le workflow
|
||||
steps = [
|
||||
WorkflowStep(
|
||||
step_id=step_data["step_id"],
|
||||
action_type=step_data["action_type"],
|
||||
target_description=step_data["target_description"],
|
||||
position=tuple(step_data["position"]),
|
||||
window=step_data["window"]
|
||||
)
|
||||
for step_data in data["steps"]
|
||||
]
|
||||
|
||||
workflow = Workflow(
|
||||
workflow_id=data["workflow_id"],
|
||||
name=data["name"],
|
||||
steps=steps,
|
||||
repetitions=data["repetitions"],
|
||||
confidence=data["confidence"],
|
||||
last_seen=datetime.fromisoformat(data["last_seen"]),
|
||||
created_at=datetime.fromisoformat(data["created_at"])
|
||||
)
|
||||
|
||||
self.workflows.append(workflow)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "workflow_load_failed",
|
||||
"file": str(workflow_file),
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "workflows_loaded",
|
||||
"count": len(self.workflows)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log_action({
|
||||
"action": "workflows_load_error",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
def update_workflow_confidence(self, workflow_id: str, success: bool):
|
||||
"""
|
||||
Met à jour la confiance d'un workflow après exécution.
|
||||
|
||||
Args:
|
||||
workflow_id: ID du workflow
|
||||
success: True si l'exécution a réussi
|
||||
"""
|
||||
for workflow in self.workflows:
|
||||
if workflow.workflow_id == workflow_id:
|
||||
if success:
|
||||
workflow.confidence = min(1.0, workflow.confidence + 0.05)
|
||||
else:
|
||||
workflow.confidence = max(0.0, workflow.confidence - 0.1)
|
||||
|
||||
workflow.last_seen = datetime.now()
|
||||
|
||||
# Sauvegarder
|
||||
self._save_workflow(workflow)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "workflow_confidence_updated",
|
||||
"workflow_id": workflow_id,
|
||||
"new_confidence": workflow.confidence,
|
||||
"success": success
|
||||
})
|
||||
|
||||
break
|
||||
536
geniusia2/core/workflow_matcher.py
Normal file
536
geniusia2/core/workflow_matcher.py
Normal file
@@ -0,0 +1,536 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
WorkflowMatcher - Compare les actions courantes avec les workflows connus
|
||||
pour détecter les correspondances et suggérer l'auto-complétion.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowMatch:
|
||||
"""Représente une correspondance entre la session courante et un workflow."""
|
||||
workflow_id: str
|
||||
workflow_name: str
|
||||
confidence: float
|
||||
matched_steps: int
|
||||
total_steps: int
|
||||
remaining_steps: List[Dict[str, Any]]
|
||||
current_step_index: int
|
||||
|
||||
@property
|
||||
def completion_percentage(self) -> float:
|
||||
"""Pourcentage de complétion du workflow."""
|
||||
if self.total_steps == 0:
|
||||
return 0.0
|
||||
return (self.matched_steps / self.total_steps) * 100
|
||||
|
||||
|
||||
class WorkflowMatcher:
|
||||
"""
|
||||
Gestionnaire de correspondance de workflows.
|
||||
Compare les actions courantes avec les workflows connus.
|
||||
"""
|
||||
|
||||
def __init__(self, logger, config: Dict[str, Any], faiss_index=None):
|
||||
"""
|
||||
Initialise le matcher de workflows.
|
||||
|
||||
Args:
|
||||
logger: Logger pour journalisation
|
||||
config: Configuration globale
|
||||
faiss_index: Optional FAISSIndex for visual similarity matching
|
||||
"""
|
||||
self.logger = logger
|
||||
self.config = config
|
||||
self.faiss_index = faiss_index
|
||||
|
||||
# Configuration
|
||||
self.position_tolerance = config.get("workflow", {}).get(
|
||||
"position_tolerance", 50 # 50px par défaut
|
||||
)
|
||||
self.min_confidence = config.get("workflow", {}).get(
|
||||
"min_confidence", 0.80 # 80% par défaut
|
||||
)
|
||||
|
||||
# Visual similarity weight (if FAISS is available)
|
||||
self.visual_similarity_weight = config.get("workflow", {}).get(
|
||||
"visual_similarity_weight", 0.3 # 30% weight for visual similarity
|
||||
)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "workflow_matcher_initialized",
|
||||
"position_tolerance": self.position_tolerance,
|
||||
"min_confidence": self.min_confidence,
|
||||
"faiss_enabled": faiss_index is not None
|
||||
})
|
||||
|
||||
def match_current_session(
|
||||
self,
|
||||
session_actions: List[Dict[str, Any]],
|
||||
workflows: List[Dict[str, Any]]
|
||||
) -> List[WorkflowMatch]:
|
||||
"""
|
||||
Compare la session courante avec tous les workflows connus.
|
||||
|
||||
Args:
|
||||
session_actions: Liste des actions de la session courante
|
||||
workflows: Liste des workflows connus
|
||||
|
||||
Returns:
|
||||
Liste des correspondances trouvées, triées par confiance
|
||||
"""
|
||||
if not session_actions or not workflows:
|
||||
return []
|
||||
|
||||
matches = []
|
||||
|
||||
for workflow in workflows:
|
||||
match_score = self.calculate_match_score(
|
||||
session_actions,
|
||||
workflow.get("steps", [])
|
||||
)
|
||||
|
||||
if match_score > 0:
|
||||
# Calculer le nombre d'étapes matchées
|
||||
matched_steps = self._count_matched_steps(
|
||||
session_actions,
|
||||
workflow.get("steps", [])
|
||||
)
|
||||
|
||||
total_steps = len(workflow.get("steps", []))
|
||||
|
||||
# Créer la correspondance
|
||||
match = WorkflowMatch(
|
||||
workflow_id=workflow.get("workflow_id", ""),
|
||||
workflow_name=workflow.get("name", "Workflow inconnu"),
|
||||
confidence=match_score,
|
||||
matched_steps=matched_steps,
|
||||
total_steps=total_steps,
|
||||
remaining_steps=workflow.get("steps", [])[matched_steps:],
|
||||
current_step_index=matched_steps
|
||||
)
|
||||
|
||||
matches.append(match)
|
||||
|
||||
# Trier par confiance décroissante
|
||||
matches.sort(key=lambda m: m.confidence, reverse=True)
|
||||
|
||||
if matches and self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "workflows_matched",
|
||||
"num_matches": len(matches),
|
||||
"best_confidence": matches[0].confidence if matches else 0.0
|
||||
})
|
||||
|
||||
return matches
|
||||
|
||||
def calculate_match_score(
|
||||
self,
|
||||
actions: List[Dict[str, Any]],
|
||||
workflow_steps: List[Dict[str, Any]]
|
||||
) -> float:
|
||||
"""
|
||||
Calcule le score de correspondance entre des actions et un workflow.
|
||||
|
||||
Le score prend en compte:
|
||||
- La correspondance des types d'actions
|
||||
- La similarité des positions (avec tolérance)
|
||||
- La correspondance des fenêtres
|
||||
|
||||
Args:
|
||||
actions: Liste des actions à comparer
|
||||
workflow_steps: Étapes du workflow
|
||||
|
||||
Returns:
|
||||
Score de correspondance (0-1)
|
||||
"""
|
||||
if not actions or not workflow_steps:
|
||||
return 0.0
|
||||
|
||||
# Comparer seulement les N premières actions avec le début du workflow
|
||||
num_actions = len(actions)
|
||||
num_steps = len(workflow_steps)
|
||||
|
||||
# On ne peut pas matcher plus d'étapes qu'il n'y en a dans le workflow
|
||||
compare_length = min(num_actions, num_steps)
|
||||
|
||||
if compare_length == 0:
|
||||
return 0.0
|
||||
|
||||
total_score = 0.0
|
||||
|
||||
for i in range(compare_length):
|
||||
action = actions[i]
|
||||
step = workflow_steps[i]
|
||||
|
||||
# Score pour cette étape
|
||||
step_score = self._calculate_step_similarity(action, step)
|
||||
total_score += step_score
|
||||
|
||||
# Score moyen
|
||||
avg_score = total_score / compare_length
|
||||
|
||||
# Bonus si on a matché plusieurs étapes (plus fiable)
|
||||
# Plus on a d'étapes matchées, plus on est confiant
|
||||
sequence_bonus = min(0.1, compare_length * 0.02)
|
||||
|
||||
final_score = min(1.0, avg_score + sequence_bonus)
|
||||
|
||||
return final_score
|
||||
|
||||
def _normalize_action_type(self, action_type: str) -> str:
|
||||
"""
|
||||
Normalise les types d'actions pour le matching.
|
||||
|
||||
Args:
|
||||
action_type: Type d'action brut
|
||||
|
||||
Returns:
|
||||
Type d'action normalisé
|
||||
"""
|
||||
# Mapper les variantes vers un type standard
|
||||
type_mapping = {
|
||||
"mouse_click": "click",
|
||||
"mouse_move": "move",
|
||||
"key_press": "type",
|
||||
"keyboard": "type",
|
||||
}
|
||||
|
||||
return type_mapping.get(action_type, action_type)
|
||||
|
||||
def _calculate_step_similarity(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
step: Dict[str, Any]
|
||||
) -> float:
|
||||
"""
|
||||
Calcule la similarité entre une action et une étape de workflow.
|
||||
|
||||
Args:
|
||||
action: Action à comparer
|
||||
step: Étape du workflow
|
||||
|
||||
Returns:
|
||||
Score de similarité (0-1)
|
||||
"""
|
||||
# Check if we have visual embeddings and FAISS index
|
||||
has_visual = (
|
||||
self.faiss_index is not None and
|
||||
action.get("embedding") is not None and
|
||||
step.get("embedding") is not None
|
||||
)
|
||||
|
||||
if has_visual:
|
||||
# Adjust weights to include visual similarity
|
||||
weights = {
|
||||
"action_type": 0.3,
|
||||
"position": 0.2,
|
||||
"window": 0.2,
|
||||
"visual": 0.3
|
||||
}
|
||||
else:
|
||||
# Original weights without visual
|
||||
weights = {
|
||||
"action_type": 0.4,
|
||||
"position": 0.3,
|
||||
"window": 0.3
|
||||
}
|
||||
|
||||
score = 0.0
|
||||
|
||||
# 1. Correspondance du type d'action
|
||||
action_type = self._normalize_action_type(action.get("action_type", ""))
|
||||
step_type = self._normalize_action_type(step.get("action_type", ""))
|
||||
|
||||
action_type_match = (action_type == step_type)
|
||||
if action_type_match:
|
||||
score += weights["action_type"]
|
||||
|
||||
# 2. Similarité de position
|
||||
action_pos = action.get("position", [0, 0])
|
||||
step_pos = step.get("position", [0, 0])
|
||||
|
||||
if action_pos and step_pos:
|
||||
position_similarity = self._calculate_position_similarity(
|
||||
action_pos,
|
||||
step_pos
|
||||
)
|
||||
score += weights["position"] * position_similarity
|
||||
|
||||
# 3. Correspondance de fenêtre
|
||||
action_window = action.get("window", "")
|
||||
step_window = step.get("window", "")
|
||||
|
||||
if action_window and step_window:
|
||||
# Correspondance exacte ou partielle
|
||||
if action_window == step_window:
|
||||
score += weights["window"]
|
||||
elif action_window in step_window or step_window in action_window:
|
||||
score += weights["window"] * 0.5
|
||||
|
||||
# 4. Similarité visuelle (si disponible)
|
||||
if has_visual:
|
||||
visual_similarity = self._calculate_visual_similarity(
|
||||
action.get("embedding"),
|
||||
step.get("embedding")
|
||||
)
|
||||
score += weights["visual"] * visual_similarity
|
||||
|
||||
return score
|
||||
|
||||
def _calculate_visual_similarity(
|
||||
self,
|
||||
embedding1: np.ndarray,
|
||||
embedding2: np.ndarray
|
||||
) -> float:
|
||||
"""
|
||||
Calcule la similarité cosinus entre deux embeddings.
|
||||
|
||||
Args:
|
||||
embedding1: Premier embedding
|
||||
embedding2: Deuxième embedding
|
||||
|
||||
Returns:
|
||||
Score de similarité (0-1)
|
||||
"""
|
||||
if embedding1 is None or embedding2 is None:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
# Ensure embeddings are numpy arrays
|
||||
if not isinstance(embedding1, np.ndarray):
|
||||
embedding1 = np.array(embedding1)
|
||||
if not isinstance(embedding2, np.ndarray):
|
||||
embedding2 = np.array(embedding2)
|
||||
|
||||
# Normalize embeddings
|
||||
emb1_norm = embedding1 / (np.linalg.norm(embedding1) + 1e-8)
|
||||
emb2_norm = embedding2 / (np.linalg.norm(embedding2) + 1e-8)
|
||||
|
||||
# Cosine similarity
|
||||
similarity = np.dot(emb1_norm, emb2_norm)
|
||||
|
||||
# Clamp to [0, 1]
|
||||
return float(max(0.0, min(1.0, similarity)))
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "visual_similarity_error",
|
||||
"error": str(e)
|
||||
})
|
||||
return 0.0
|
||||
|
||||
def _calculate_position_similarity(
|
||||
self,
|
||||
pos1: List[int],
|
||||
pos2: List[int]
|
||||
) -> float:
|
||||
"""
|
||||
Calcule la similarité entre deux positions avec tolérance.
|
||||
|
||||
Args:
|
||||
pos1: Position 1 [x, y]
|
||||
pos2: Position 2 [x, y]
|
||||
|
||||
Returns:
|
||||
Score de similarité (0-1)
|
||||
"""
|
||||
if not pos1 or not pos2 or len(pos1) < 2 or len(pos2) < 2:
|
||||
return 0.0
|
||||
|
||||
# Distance euclidienne
|
||||
dx = pos1[0] - pos2[0]
|
||||
dy = pos1[1] - pos2[1]
|
||||
distance = math.sqrt(dx * dx + dy * dy)
|
||||
|
||||
# Si distance <= tolérance, score = 1.0
|
||||
# Si distance > tolérance, score décroît linéairement
|
||||
if distance <= self.position_tolerance:
|
||||
return 1.0
|
||||
else:
|
||||
# Décroissance linéaire jusqu'à 2x la tolérance
|
||||
max_distance = self.position_tolerance * 2
|
||||
if distance >= max_distance:
|
||||
return 0.0
|
||||
else:
|
||||
return 1.0 - ((distance - self.position_tolerance) / self.position_tolerance)
|
||||
|
||||
def _count_matched_steps(
|
||||
self,
|
||||
actions: List[Dict[str, Any]],
|
||||
workflow_steps: List[Dict[str, Any]]
|
||||
) -> int:
|
||||
"""
|
||||
Compte le nombre d'étapes matchées consécutivement.
|
||||
|
||||
Args:
|
||||
actions: Liste des actions
|
||||
workflow_steps: Étapes du workflow
|
||||
|
||||
Returns:
|
||||
Nombre d'étapes matchées
|
||||
"""
|
||||
compare_length = min(len(actions), len(workflow_steps))
|
||||
matched = 0
|
||||
|
||||
for i in range(compare_length):
|
||||
similarity = self._calculate_step_similarity(
|
||||
actions[i],
|
||||
workflow_steps[i]
|
||||
)
|
||||
|
||||
# On considère qu'une étape est matchée si similarité > 0.7
|
||||
if similarity >= 0.7:
|
||||
matched += 1
|
||||
else:
|
||||
# Arrêter au premier non-match (séquence consécutive)
|
||||
break
|
||||
|
||||
return matched
|
||||
|
||||
def find_best_match(
|
||||
self,
|
||||
matches: List[WorkflowMatch]
|
||||
) -> Optional[WorkflowMatch]:
|
||||
"""
|
||||
Trouve la meilleure correspondance parmi une liste.
|
||||
|
||||
Args:
|
||||
matches: Liste des correspondances
|
||||
|
||||
Returns:
|
||||
Meilleure correspondance si confiance > seuil, None sinon
|
||||
"""
|
||||
if not matches:
|
||||
return None
|
||||
|
||||
# Les matches sont déjà triés par confiance décroissante
|
||||
best_match = matches[0]
|
||||
|
||||
# Vérifier le seuil de confiance
|
||||
if best_match.confidence >= self.min_confidence:
|
||||
if self.logger:
|
||||
self.logger.log_action({
|
||||
"action": "best_match_found",
|
||||
"workflow_id": best_match.workflow_id,
|
||||
"workflow_name": best_match.workflow_name,
|
||||
"confidence": best_match.confidence,
|
||||
"matched_steps": best_match.matched_steps,
|
||||
"total_steps": best_match.total_steps
|
||||
})
|
||||
return best_match
|
||||
|
||||
return None
|
||||
|
||||
def get_match_details(self, match: WorkflowMatch) -> Dict[str, Any]:
|
||||
"""
|
||||
Retourne les détails d'une correspondance pour affichage.
|
||||
|
||||
Args:
|
||||
match: Correspondance à détailler
|
||||
|
||||
Returns:
|
||||
Dictionnaire avec les détails
|
||||
"""
|
||||
return {
|
||||
"workflow_id": match.workflow_id,
|
||||
"workflow_name": match.workflow_name,
|
||||
"confidence": match.confidence,
|
||||
"matched_steps": match.matched_steps,
|
||||
"total_steps": match.total_steps,
|
||||
"remaining_steps": match.remaining_steps,
|
||||
"completion_percentage": match.completion_percentage,
|
||||
"next_steps_preview": match.remaining_steps[:3] # 3 prochaines étapes
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Tests basiques
|
||||
print("Test du WorkflowMatcher")
|
||||
print("=" * 50)
|
||||
|
||||
# Mock logger
|
||||
class MockLogger:
|
||||
def log_action(self, data):
|
||||
print(f"[LOG] {data}")
|
||||
|
||||
logger = MockLogger()
|
||||
config = {
|
||||
"workflow": {
|
||||
"position_tolerance": 50,
|
||||
"min_confidence": 0.80
|
||||
}
|
||||
}
|
||||
|
||||
matcher = WorkflowMatcher(logger, config)
|
||||
|
||||
# Test 1: Match parfait
|
||||
print("\n1. Test match parfait:")
|
||||
session_actions = [
|
||||
{
|
||||
"action_type": "click",
|
||||
"position": [100, 100],
|
||||
"window": "Calculatrice"
|
||||
},
|
||||
{
|
||||
"action_type": "type",
|
||||
"position": [0, 0],
|
||||
"window": "Calculatrice"
|
||||
}
|
||||
]
|
||||
|
||||
workflow = {
|
||||
"workflow_id": "calc_001",
|
||||
"name": "Calcul simple",
|
||||
"steps": [
|
||||
{
|
||||
"action_type": "click",
|
||||
"position": [100, 100],
|
||||
"window": "Calculatrice"
|
||||
},
|
||||
{
|
||||
"action_type": "type",
|
||||
"position": [0, 0],
|
||||
"window": "Calculatrice"
|
||||
},
|
||||
{
|
||||
"action_type": "click",
|
||||
"position": [200, 200],
|
||||
"window": "Calculatrice"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
matches = matcher.match_current_session(session_actions, [workflow])
|
||||
print(f" Nombre de matches: {len(matches)}")
|
||||
if matches:
|
||||
print(f" Meilleur match: {matches[0].workflow_name}")
|
||||
print(f" Confiance: {matches[0].confidence:.2%}")
|
||||
print(f" Étapes matchées: {matches[0].matched_steps}/{matches[0].total_steps}")
|
||||
|
||||
# Test 2: Match avec tolérance de position
|
||||
print("\n2. Test avec tolérance de position:")
|
||||
session_actions[0]["position"] = [120, 110] # Légèrement décalé
|
||||
|
||||
matches = matcher.match_current_session(session_actions, [workflow])
|
||||
if matches:
|
||||
print(f" Confiance avec décalage: {matches[0].confidence:.2%}")
|
||||
|
||||
# Test 3: Trouver le meilleur match
|
||||
print("\n3. Test find_best_match:")
|
||||
best = matcher.find_best_match(matches)
|
||||
if best:
|
||||
print(f" Meilleur match trouvé: {best.workflow_name}")
|
||||
print(f" Confiance: {best.confidence:.2%}")
|
||||
details = matcher.get_match_details(best)
|
||||
print(f" Prochaines étapes: {len(details['next_steps_preview'])}")
|
||||
else:
|
||||
print(" Aucun match au-dessus du seuil")
|
||||
|
||||
print("\n✓ Tests terminés!")
|
||||
252
geniusia2/core/workflow_state_adapter.py
Normal file
252
geniusia2/core/workflow_state_adapter.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""
|
||||
Adaptateur pour intégrer EnrichedScreenState avec le système de workflows existant.
|
||||
Assure la compatibilité arrière complète en Phase 1 (Mode Light).
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
|
||||
from .ui_element_models import EnrichedScreenState, WindowInfo
|
||||
from .screen_state_manager import ScreenStateManager
|
||||
from .workflow_detector import WorkflowStep
|
||||
from .logger import Logger
|
||||
|
||||
|
||||
class WorkflowStateAdapter:
|
||||
"""
|
||||
Adaptateur pour convertir entre WorkflowStep (ancien) et EnrichedScreenState (nouveau).
|
||||
|
||||
En Phase 1 (Mode Light):
|
||||
- Convertit les WorkflowStep existants en EnrichedScreenState
|
||||
- Maintient la compatibilité avec le système de workflows existant
|
||||
- Permet une migration progressive sans casser les workflows existants
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
screen_state_manager: ScreenStateManager,
|
||||
logger: Logger
|
||||
):
|
||||
"""
|
||||
Initialise l'adaptateur.
|
||||
|
||||
Args:
|
||||
screen_state_manager: Gestionnaire d'états d'écran
|
||||
logger: Logger pour journalisation
|
||||
"""
|
||||
self.screen_state_manager = screen_state_manager
|
||||
self.logger = logger
|
||||
|
||||
def workflow_step_to_screen_state(
|
||||
self,
|
||||
step: WorkflowStep,
|
||||
session_id: str,
|
||||
screenshot_path: str,
|
||||
screen_resolution: tuple = (1920, 1080)
|
||||
) -> EnrichedScreenState:
|
||||
"""
|
||||
Convertit un WorkflowStep en EnrichedScreenState.
|
||||
|
||||
Args:
|
||||
step: WorkflowStep à convertir
|
||||
session_id: ID de session
|
||||
screenshot_path: Chemin vers le screenshot
|
||||
screen_resolution: Résolution d'écran
|
||||
|
||||
Returns:
|
||||
EnrichedScreenState créé
|
||||
"""
|
||||
# Extraire les informations du WorkflowStep
|
||||
app_name = step.window.split(" - ")[0] if " - " in step.window else step.window
|
||||
window_title = step.window
|
||||
|
||||
# Créer l'EnrichedScreenState
|
||||
screen_state = self.screen_state_manager.create_screen_state(
|
||||
session_id=session_id,
|
||||
window_title=window_title,
|
||||
app_name=app_name,
|
||||
screenshot_path=screenshot_path,
|
||||
screen_resolution=screen_resolution,
|
||||
embedding_provider="legacy_workflow_step",
|
||||
detected_text=[step.target_description] if step.target_description else [],
|
||||
context_tags=[step.action_type]
|
||||
)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "workflow_step_converted",
|
||||
"step_id": step.step_id,
|
||||
"screen_state_id": screen_state.screen_state_id
|
||||
})
|
||||
|
||||
return screen_state
|
||||
|
||||
def screen_state_to_workflow_step(
|
||||
self,
|
||||
screen_state: EnrichedScreenState,
|
||||
step_id: int,
|
||||
action_type: str = "unknown",
|
||||
position: tuple = (0, 0)
|
||||
) -> WorkflowStep:
|
||||
"""
|
||||
Convertit un EnrichedScreenState en WorkflowStep (pour compatibilité).
|
||||
|
||||
Args:
|
||||
screen_state: EnrichedScreenState à convertir
|
||||
step_id: ID de l'étape
|
||||
action_type: Type d'action
|
||||
position: Position de l'action
|
||||
|
||||
Returns:
|
||||
WorkflowStep créé
|
||||
"""
|
||||
# Extraire les informations de l'EnrichedScreenState
|
||||
target_description = " ".join(screen_state.perception.detected_text[:3]) # Premiers 3 mots
|
||||
|
||||
# Créer le WorkflowStep
|
||||
step = WorkflowStep(
|
||||
step_id=step_id,
|
||||
action_type=action_type,
|
||||
target_description=target_description,
|
||||
position=position,
|
||||
window=screen_state.window.window_title,
|
||||
embedding=None, # Sera chargé séparément si nécessaire
|
||||
screenshot=None # Sera chargé séparément si nécessaire
|
||||
)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "screen_state_converted",
|
||||
"screen_state_id": screen_state.screen_state_id,
|
||||
"step_id": step_id
|
||||
})
|
||||
|
||||
return step
|
||||
|
||||
def save_workflow_with_screen_states(
|
||||
self,
|
||||
workflow_id: str,
|
||||
steps: list,
|
||||
session_id: str
|
||||
) -> list:
|
||||
"""
|
||||
Sauvegarde un workflow en créant des EnrichedScreenState pour chaque étape.
|
||||
|
||||
Args:
|
||||
workflow_id: ID du workflow
|
||||
steps: Liste de WorkflowStep
|
||||
session_id: ID de session
|
||||
|
||||
Returns:
|
||||
Liste des screen_state_id créés
|
||||
"""
|
||||
screen_state_ids = []
|
||||
|
||||
for i, step in enumerate(steps):
|
||||
# Créer un chemin de screenshot (peut être None si pas de screenshot)
|
||||
screenshot_path = f"data/workflows/{workflow_id}/step_{i}.png"
|
||||
|
||||
# Convertir en EnrichedScreenState
|
||||
screen_state = self.workflow_step_to_screen_state(
|
||||
step=step,
|
||||
session_id=session_id,
|
||||
screenshot_path=screenshot_path
|
||||
)
|
||||
|
||||
# Sauvegarder
|
||||
embedding_vector = step.embedding if step.embedding is not None else None
|
||||
self.screen_state_manager.save_screen_state(
|
||||
screen_state,
|
||||
save_embedding=(embedding_vector is not None),
|
||||
embedding_vector=embedding_vector
|
||||
)
|
||||
|
||||
screen_state_ids.append(screen_state.screen_state_id)
|
||||
|
||||
self.logger.log_action({
|
||||
"action": "workflow_saved_with_screen_states",
|
||||
"workflow_id": workflow_id,
|
||||
"steps_count": len(steps),
|
||||
"screen_states_created": len(screen_state_ids)
|
||||
})
|
||||
|
||||
return screen_state_ids
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Tests basiques
|
||||
from .logger import Logger
|
||||
from .workflow_detector import WorkflowStep
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
print("Test du WorkflowStateAdapter")
|
||||
print("=" * 50)
|
||||
|
||||
# Créer un logger de test
|
||||
logger = Logger(log_dir="test_logs")
|
||||
|
||||
# Créer le manager et l'adaptateur
|
||||
manager = ScreenStateManager(
|
||||
logger=logger,
|
||||
data_dir="test_data",
|
||||
mode="light"
|
||||
)
|
||||
|
||||
adapter = WorkflowStateAdapter(
|
||||
screen_state_manager=manager,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
print("\n1. Test conversion WorkflowStep → EnrichedScreenState:")
|
||||
# Créer un WorkflowStep de test
|
||||
step = WorkflowStep(
|
||||
step_id=1,
|
||||
action_type="click",
|
||||
target_description="Valider Button",
|
||||
position=(100, 200),
|
||||
window="Test App - Main Window",
|
||||
embedding=np.random.rand(512),
|
||||
screenshot=None
|
||||
)
|
||||
|
||||
screen_state = adapter.workflow_step_to_screen_state(
|
||||
step=step,
|
||||
session_id="test_session",
|
||||
screenshot_path="test_data/screens/step_1.png"
|
||||
)
|
||||
|
||||
print(f" Screen State ID: {screen_state.screen_state_id}")
|
||||
print(f" Window Title: {screen_state.window.window_title}")
|
||||
print(f" Detected Text: {screen_state.perception.detected_text}")
|
||||
|
||||
print("\n2. Test conversion EnrichedScreenState → WorkflowStep:")
|
||||
converted_step = adapter.screen_state_to_workflow_step(
|
||||
screen_state=screen_state,
|
||||
step_id=1,
|
||||
action_type="click",
|
||||
position=(100, 200)
|
||||
)
|
||||
|
||||
print(f" Step ID: {converted_step.step_id}")
|
||||
print(f" Action Type: {converted_step.action_type}")
|
||||
print(f" Target: {converted_step.target_description}")
|
||||
print(f" Window: {converted_step.window}")
|
||||
|
||||
print("\n3. Test sauvegarde de workflow avec screen states:")
|
||||
steps = [step]
|
||||
screen_state_ids = adapter.save_workflow_with_screen_states(
|
||||
workflow_id="test_workflow_001",
|
||||
steps=steps,
|
||||
session_id="test_session"
|
||||
)
|
||||
|
||||
print(f" Created {len(screen_state_ids)} screen states")
|
||||
print(f" Screen State IDs: {screen_state_ids}")
|
||||
|
||||
print("\n✓ Tous les tests WorkflowStateAdapter réussis!")
|
||||
|
||||
# Nettoyage
|
||||
if Path("test_data").exists():
|
||||
shutil.rmtree("test_data")
|
||||
if Path("test_logs").exists():
|
||||
shutil.rmtree("test_logs")
|
||||
Reference in New Issue
Block a user