feat(grounding): pipeline centralisé + serveur UI-TARS transformers + nettoyage code mort
Architecture grounding complète :
- core/grounding/server.py : serveur FastAPI (port 8200) avec UI-TARS-1.5-7B en 4-bit NF4
Process séparé avec son propre contexte CUDA (résout le crash Flask/CUDA)
- core/grounding/pipeline.py : orchestrateur cascade template→OCR→UI-TARS→static
- core/grounding/template_matcher.py : TemplateMatcher centralisé (remplace 5 copies)
- core/grounding/ui_tars_grounder.py : client HTTP vers le serveur de grounding
- core/grounding/target.py : GroundingTarget + GroundingResult
ORA modifié :
- _act_click() : capture unique de l'écran envoyée au serveur de grounding
- Pre-check VLM skippé pour ui_tars (redondant, et Ollama n'a plus de VRAM)
- verify_level='none' par défaut (vérification titre OCR prévue en Phase 2)
- Détection réponses négatives UI-TARS ("I don't see it" → fallback OCR)
Nettoyage :
- 9 fichiers morts archivés dans _archive/ (~6300 lignes supprimées)
- 21 tests ajoutés pour TemplateMatcher
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,877 +0,0 @@
|
||||
"""
|
||||
Workflow Simulation Report - Fiche #16++
|
||||
|
||||
Système de simulation complète de workflows pour tester la chaîne complète :
|
||||
Node Matching (FAISS) → Target Resolution → Post-conditions → Transition
|
||||
|
||||
Utilise des "scenario packs" avec frames séquentielles pour simuler des workflows
|
||||
réalistes et générer des rapports de performance détaillés.
|
||||
|
||||
Auteur : Dom, Alice Kiro - 22 décembre 2025
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any, Tuple, Union
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
|
||||
from ..models.screen_state import ScreenState
|
||||
from ..models.ui_element import UIElement
|
||||
from ..models.workflow_graph import Workflow, WorkflowNode, WorkflowEdge, TargetSpec, PostConditions, PostConditionCheck
|
||||
from ..graph.node_matcher import NodeMatcher
|
||||
from ..embedding.state_embedding_builder import StateEmbeddingBuilder
|
||||
from ..execution.target_resolver import TargetResolver
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ScenarioFrame:
    """A single captured frame (one step) within a workflow scenario."""
    frame_id: str                # unique id, formatted "<scenario_id>_step_NNN"
    step_number: int             # position of this frame in the scenario sequence
    screen_state: ScreenState    # perceived screen contents for this step
    expected_node_id: Optional[str] = None  # workflow node this frame should match
    expected_action: Optional[Dict[str, Any]] = None  # expected action (incl. "target" spec)
    metadata: Dict[str, Any] = field(default_factory=dict)  # free-form extra info
|
||||
|
||||
|
||||
@dataclass
class ScenarioPack:
    """Complete scenario pack with sequential frames for one workflow."""
    scenario_id: str
    name: str
    description: str
    workflow_id: str             # workflow under test
    frames: List[ScenarioFrame]  # ordered frames, one per step
    expected_path: List[str]     # expected sequence of node_ids
    metadata: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def load_from_directory(cls, scenario_dir: Path) -> 'ScenarioPack':
        """Load a scenario pack from a directory.

        Expects a ``scenario.json`` manifest plus one ``step_NNN.json`` file
        per step (each containing a serialized ``screen_state``). Steps whose
        file is missing are skipped with a warning.

        Args:
            scenario_dir: Directory containing the manifest and step files.

        Returns:
            The reconstructed ScenarioPack.

        Raises:
            FileNotFoundError: If ``scenario.json`` is missing.
        """
        scenario_file = scenario_dir / "scenario.json"
        if not scenario_file.exists():
            raise FileNotFoundError(f"scenario.json not found in {scenario_dir}")

        with open(scenario_file, 'r', encoding='utf-8') as f:
            scenario_data = json.load(f)

        # Load the per-step frames described by the manifest.
        frames = []
        for step_data in scenario_data.get("steps", []):
            step_file = scenario_dir / f"step_{step_data['step_number']:03d}.json"
            if not step_file.exists():
                logger.warning(f"Step file not found: {step_file}")
                continue

            with open(step_file, 'r', encoding='utf-8') as f:
                step_content = json.load(f)

            # Rebuild the ScreenState from its JSON representation.
            screen_state = ScreenState.from_dict(step_content["screen_state"])

            frame = ScenarioFrame(
                frame_id=f"{scenario_data['scenario_id']}_step_{step_data['step_number']:03d}",
                step_number=step_data["step_number"],
                screen_state=screen_state,
                expected_node_id=step_data.get("expected_node_id"),
                expected_action=step_data.get("expected_action"),
                metadata=step_data.get("metadata", {})
            )
            frames.append(frame)

        # Fix: frames previously kept the raw manifest order; an out-of-order
        # "steps" list would silently break the sequential simulation.
        # Sorting by step_number makes frame order deterministic.
        frames.sort(key=lambda fr: fr.step_number)

        return cls(
            scenario_id=scenario_data["scenario_id"],
            name=scenario_data["name"],
            description=scenario_data["description"],
            workflow_id=scenario_data["workflow_id"],
            frames=frames,
            expected_path=scenario_data.get("expected_path", []),
            metadata=scenario_data.get("metadata", {})
        )
|
||||
|
||||
|
||||
@dataclass
class NodeMatchingResult:
    """Outcome of matching one frame's screen state against workflow nodes."""
    frame_id: str
    expected_node_id: Optional[str]  # node the scenario expected (may be unset)
    matched_node_id: Optional[str]   # node actually matched, or None on failure
    confidence: float                # matcher confidence for the winning node
    success: bool                    # True when a node was matched without error
    strategy_used: str               # e.g. "faiss_search", "none", "error"
    error_message: Optional[str] = None
    alternatives: List[Tuple[str, float]] = field(default_factory=list)  # (node_id, confidence)
|
||||
|
||||
|
||||
@dataclass
class TargetResolutionResult:
    """Outcome of resolving an action's target element on a frame."""
    frame_id: str
    target_spec: Optional[TargetSpec]        # spec that was resolved, if any
    resolved_element_id: Optional[str]       # element found, or None on failure
    expected_element_id: Optional[str]       # element the scenario expected, if declared
    confidence: float                        # resolver confidence
    success: bool                            # True when resolution succeeded (or no action)
    strategy_used: str                       # e.g. "no_action", "failed", "error", or resolver's strategy
    resolution_time_ms: float                # wall-clock resolution time
    error_message: Optional[str] = None
    alternatives: List[Dict[str, Any]] = field(default_factory=list)  # alternative candidates
|
||||
|
||||
|
||||
@dataclass
class PostConditionResult:
    """Outcome of verifying an edge's post-conditions on a frame."""
    frame_id: str
    post_conditions: Optional[PostConditions]  # conditions that were checked, if any
    checks_passed: int                         # number of individual checks that passed
    checks_total: int                          # total number of checks attempted
    success: bool                              # True when all checks passed (or none existed)
    timeout_occurred: bool                     # reserved for timeout reporting
    verification_time_ms: float                # wall-clock verification time
    failed_checks: List[str] = field(default_factory=list)  # "kind: value" labels of failures
    error_message: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class TransitionResult:
    """Outcome of the transition between two consecutive frames."""
    from_frame_id: str
    to_frame_id: str
    expected_transition: bool  # whether the scenario expects a node change here
    actual_transition: bool    # whether the simulation produced a node change
    success: bool              # expected == actual
    transition_confidence: float
    error_message: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class WorkflowStepResult:
    """Aggregated result of one simulated workflow step."""
    frame_id: str
    step_number: int
    node_matching: NodeMatchingResult
    target_resolution: Optional[TargetResolutionResult]
    post_conditions: Optional[PostConditionResult]
    transition: Optional[TransitionResult]
    overall_success: bool
    step_duration_ms: float

    @property
    def success_components(self) -> Dict[str, bool]:
        """Per-component success flags for fine-grained analysis.

        A component that was never exercised (None) counts as a pass.
        """
        def passed(component) -> bool:
            # Absent component -> vacuous pass; otherwise report its flag.
            return True if component is None else component.success

        return {
            "node_matching": self.node_matching.success,
            "target_resolution": passed(self.target_resolution),
            "post_conditions": passed(self.post_conditions),
            "transition": passed(self.transition),
        }
|
||||
|
||||
|
||||
@dataclass
class WorkflowSimulationReport:
    """Complete report for one workflow simulation run."""
    scenario_id: str
    workflow_id: str
    timestamp: datetime
    total_steps: int
    successful_steps: int
    step_results: List[WorkflowStepResult]

    # Global metrics (fractions in [0, 1])
    node_matching_accuracy: float
    target_resolution_accuracy: float
    post_condition_success_rate: float
    transition_accuracy: float

    # Performance
    total_simulation_time_ms: float
    avg_step_time_ms: float

    # Error analysis
    error_breakdown: Dict[str, int]  # failure-category -> count
    failure_points: List[str]        # human-readable labels of failed steps

    # Recommendations
    recommendations: List[str]

    @property
    def overall_success_rate(self) -> float:
        """Overall success rate, guarded against division by zero."""
        return self.successful_steps / max(1, self.total_steps)

    @staticmethod
    def _component_status(component) -> str:
        """Markdown status cell for an optional step component.

        "N/A" means the component was never exercised; a component that
        ran is shown as pass ("✅") or fail ("❌").
        """
        if component is None:
            return "N/A"
        return "✅" if component.success else "❌"

    @staticmethod
    def _step_to_dict(result: WorkflowStepResult) -> Dict[str, Any]:
        """Serialize one step result; absent sub-results serialize as None."""
        tr = result.target_resolution
        pc = result.post_conditions
        tx = result.transition
        return {
            "frame_id": result.frame_id,
            "step_number": result.step_number,
            "overall_success": result.overall_success,
            "step_duration_ms": result.step_duration_ms,
            "success_components": result.success_components,
            "node_matching": {
                "expected_node_id": result.node_matching.expected_node_id,
                "matched_node_id": result.node_matching.matched_node_id,
                "confidence": result.node_matching.confidence,
                "success": result.node_matching.success,
                "strategy_used": result.node_matching.strategy_used,
                "error_message": result.node_matching.error_message
            },
            # The inner per-field guards of the original were dead code:
            # each dict is only built when the component exists.
            "target_resolution": {
                "resolved_element_id": tr.resolved_element_id,
                "confidence": tr.confidence,
                "success": tr.success,
                "strategy_used": tr.strategy_used,
                "resolution_time_ms": tr.resolution_time_ms
            } if tr else None,
            "post_conditions": {
                "checks_passed": pc.checks_passed,
                "checks_total": pc.checks_total,
                "success": pc.success,
                "verification_time_ms": pc.verification_time_ms
            } if pc else None,
            "transition": {
                "expected_transition": tx.expected_transition,
                "actual_transition": tx.actual_transition,
                "success": tx.success,
                "transition_confidence": tx.transition_confidence
            } if tx else None
        }

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the full report to a JSON-friendly dictionary."""
        return {
            "scenario_id": self.scenario_id,
            "workflow_id": self.workflow_id,
            "timestamp": self.timestamp.isoformat(),
            "total_steps": self.total_steps,
            "successful_steps": self.successful_steps,
            "step_results": [self._step_to_dict(result) for result in self.step_results],
            "metrics": {
                "node_matching_accuracy": self.node_matching_accuracy,
                "target_resolution_accuracy": self.target_resolution_accuracy,
                "post_condition_success_rate": self.post_condition_success_rate,
                "transition_accuracy": self.transition_accuracy,
                "overall_success_rate": self.overall_success_rate
            },
            "performance": {
                "total_simulation_time_ms": self.total_simulation_time_ms,
                "avg_step_time_ms": self.avg_step_time_ms
            },
            "analysis": {
                "error_breakdown": self.error_breakdown,
                "failure_points": self.failure_points,
                "recommendations": self.recommendations
            }
        }

    def save_to_file(self, filepath: Path) -> None:
        """Write the report as pretty-printed UTF-8 JSON, creating parent dirs."""
        filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

    def generate_markdown_report(self) -> str:
        """Render a human-readable Markdown version of the report."""
        md_lines = [
            "# Workflow Simulation Report",
            "",
            f"**Scenario:** {self.scenario_id}",
            f"**Workflow:** {self.workflow_id}",
            f"**Date:** {self.timestamp.strftime('%Y-%m-%d %H:%M:%S')}",
            "",
            "## Summary",
            "",
            f"- **Total Steps:** {self.total_steps}",
            f"- **Successful Steps:** {self.successful_steps}",
            f"- **Overall Success Rate:** {self.overall_success_rate:.1%}",
            f"- **Total Simulation Time:** {self.total_simulation_time_ms:.0f}ms",
            f"- **Average Step Time:** {self.avg_step_time_ms:.0f}ms",
            "",
            "## Component Accuracy",
            "",
            "| Component | Accuracy |",
            "|-----------|----------|",
            f"| Node Matching | {self.node_matching_accuracy:.1%} |",
            f"| Target Resolution | {self.target_resolution_accuracy:.1%} |",
            f"| Post-conditions | {self.post_condition_success_rate:.1%} |",
            f"| Transitions | {self.transition_accuracy:.1%} |",
            "",
            "## Error Breakdown",
            ""
        ]

        if self.error_breakdown:
            for error_type, count in self.error_breakdown.items():
                md_lines.append(f"- **{error_type}:** {count}")
        else:
            md_lines.append("- No errors detected")

        md_lines.extend(["", "## Failure Points", ""])

        if self.failure_points:
            for failure in self.failure_points:
                md_lines.append(f"- {failure}")
        else:
            md_lines.append("- No critical failure points identified")

        md_lines.extend(["", "## Recommendations", ""])

        if self.recommendations:
            for rec in self.recommendations:
                md_lines.append(f"- {rec}")
        else:
            md_lines.append("- No specific recommendations at this time")

        md_lines.extend([
            "",
            "## Detailed Step Results",
            "",
            "| Step | Node Match | Target Res | Post-Cond | Transition | Duration |",
            "|------|------------|------------|-----------|------------|----------|"
        ])

        for result in self.step_results:
            # Bug fix: a component that ran but FAILED previously rendered as
            # "N/A" (indistinguishable from "not exercised"); it now renders
            # as "❌" via _component_status.
            node_status = "✅" if result.node_matching.success else "❌"
            target_status = self._component_status(result.target_resolution)
            post_status = self._component_status(result.post_conditions)
            trans_status = self._component_status(result.transition)

            md_lines.append(
                f"| {result.step_number} | {node_status} | {target_status} | {post_status} | {trans_status} | {result.step_duration_ms:.0f}ms |"
            )

        return "\n".join(md_lines)
|
||||
|
||||
|
||||
class WorkflowSimulator:
    """
    Full workflow simulator.

    Exercises the complete chain:
    Node Matching -> Target Resolution -> Post-conditions -> Transition.
    """

    def __init__(
        self,
        node_matcher: Optional[NodeMatcher] = None,
        target_resolver: Optional[TargetResolver] = None,
        state_embedding_builder: Optional[StateEmbeddingBuilder] = None
    ):
        """
        Initialize the simulator.

        Args:
            node_matcher: Node matcher (a default one is created if None)
            target_resolver: Target resolver (a default one is created if None)
            state_embedding_builder: Embedding builder (a default one is created if None)
        """
        self.node_matcher = node_matcher or NodeMatcher()
        self.target_resolver = target_resolver or TargetResolver()
        self.state_embedding_builder = state_embedding_builder or StateEmbeddingBuilder()

        logger.info("WorkflowSimulator initialized")

    def simulate_workflow(
        self,
        scenario_pack: ScenarioPack,
        workflow: Workflow,
        output_dir: Optional[Path] = None
    ) -> WorkflowSimulationReport:
        """
        Simulate a full workflow against a scenario pack.

        Args:
            scenario_pack: Scenario pack with sequential frames
            workflow: Workflow under test
            output_dir: Output directory for reports (optional)

        Returns:
            Complete simulation report.
        """
        start_time = time.time()
        step_results = []

        logger.info(f"Starting workflow simulation: {scenario_pack.scenario_id}")
        logger.info(f"Workflow: {workflow.workflow_id}, Steps: {len(scenario_pack.frames)}")

        # Simulate each step in sequence
        for i, frame in enumerate(scenario_pack.frames):
            step_start = time.time()

            # 1. Node Matching
            node_matching_result = self._simulate_node_matching(frame, workflow)

            # 2. Target Resolution (only if a node matched and an action is expected)
            target_resolution_result = None
            if node_matching_result.success and frame.expected_action:
                target_resolution_result = self._simulate_target_resolution(frame, workflow, node_matching_result.matched_node_id)

            # 3. Post-conditions (only if the target resolved)
            post_condition_result = None
            if target_resolution_result and target_resolution_result.success:
                post_condition_result = self._simulate_post_conditions(frame, workflow, node_matching_result.matched_node_id)

            # 4. Transition (skipped for the final frame)
            transition_result = None
            if i < len(scenario_pack.frames) - 1:
                next_frame = scenario_pack.frames[i + 1]
                transition_result = self._simulate_transition(frame, next_frame, workflow)

            # Overall step success: every executed stage must have succeeded;
            # a skipped stage (None) counts as a pass.
            overall_success = (
                node_matching_result.success and
                (target_resolution_result is None or target_resolution_result.success) and
                (post_condition_result is None or post_condition_result.success) and
                (transition_result is None or transition_result.success)
            )

            step_duration = (time.time() - step_start) * 1000

            step_result = WorkflowStepResult(
                frame_id=frame.frame_id,
                step_number=frame.step_number,
                node_matching=node_matching_result,
                target_resolution=target_resolution_result,
                post_conditions=post_condition_result,
                transition=transition_result,
                overall_success=overall_success,
                step_duration_ms=step_duration
            )

            step_results.append(step_result)

            logger.debug(f"Step {frame.step_number}: {'✅' if overall_success else '❌'} ({step_duration:.0f}ms)")

        # Compute global metrics and build the report
        total_time = (time.time() - start_time) * 1000
        report = self._generate_report(scenario_pack, workflow, step_results, total_time)

        # Persist reports if an output directory was given
        if output_dir:
            self._save_reports(report, output_dir)

        logger.info(f"Simulation completed: {report.overall_success_rate:.1%} success rate")
        return report

    def _simulate_node_matching(self, frame: ScenarioFrame, workflow: Workflow) -> NodeMatchingResult:
        """Simulate node matching for one frame; never raises (errors are captured)."""
        try:
            # Build an embedding for the frame.
            # NOTE(review): state_embedding is computed but never used below —
            # NodeMatcher.match is called on the raw screen_state; confirm
            # whether the builder call is needed here.
            state_embedding = self.state_embedding_builder.build(frame.screen_state)

            # Try to match against the workflow's nodes
            candidate_nodes = workflow.nodes
            match_result = self.node_matcher.match(frame.screen_state, candidate_nodes)

            if match_result:
                matched_node, confidence = match_result
                success = True
                matched_node_id = matched_node.node_id
                strategy_used = "faiss_search"  # NOTE(review): hard-coded; NodeMatcher may use another strategy
                error_message = None
            else:
                success = False
                matched_node_id = None
                confidence = 0.0
                strategy_used = "none"
                error_message = "No matching node found"

            return NodeMatchingResult(
                frame_id=frame.frame_id,
                expected_node_id=frame.expected_node_id,
                matched_node_id=matched_node_id,
                confidence=confidence,
                success=success,
                strategy_used=strategy_used,
                error_message=error_message
            )

        except Exception as e:
            logger.error(f"Node matching failed for frame {frame.frame_id}: {e}")
            return NodeMatchingResult(
                frame_id=frame.frame_id,
                expected_node_id=frame.expected_node_id,
                matched_node_id=None,
                confidence=0.0,
                success=False,
                strategy_used="error",
                error_message=str(e)
            )

    def _simulate_target_resolution(
        self,
        frame: ScenarioFrame,
        workflow: Workflow,
        matched_node_id: str
    ) -> TargetResolutionResult:
        """Simulate target resolution for the frame's expected action."""
        try:
            start_time = time.time()

            # Fetch the expected action; no action (or no target) is a vacuous pass
            expected_action = frame.expected_action
            if not expected_action or "target" not in expected_action:
                return TargetResolutionResult(
                    frame_id=frame.frame_id,
                    target_spec=None,
                    resolved_element_id=None,
                    expected_element_id=None,
                    confidence=0.0,
                    success=True,  # no action to resolve = success
                    strategy_used="no_action",
                    resolution_time_ms=0.0
                )

            # Build the TargetSpec from the expected action
            target_spec = TargetSpec.from_dict(expected_action["target"])

            # Resolve the target against the frame's screen state
            resolved_target = self.target_resolver.resolve_target(
                target_spec,
                frame.screen_state,
                context={}
            )

            resolution_time = (time.time() - start_time) * 1000

            if resolved_target:
                return TargetResolutionResult(
                    frame_id=frame.frame_id,
                    target_spec=target_spec,
                    resolved_element_id=resolved_target.element.element_id,
                    expected_element_id=expected_action.get("expected_element_id"),
                    confidence=resolved_target.confidence,
                    success=True,
                    strategy_used=resolved_target.strategy_used,
                    resolution_time_ms=resolution_time
                )
            else:
                return TargetResolutionResult(
                    frame_id=frame.frame_id,
                    target_spec=target_spec,
                    resolved_element_id=None,
                    expected_element_id=expected_action.get("expected_element_id"),
                    confidence=0.0,
                    success=False,
                    strategy_used="failed",
                    resolution_time_ms=resolution_time,
                    error_message="Target resolution failed"
                )

        except Exception as e:
            logger.error(f"Target resolution failed for frame {frame.frame_id}: {e}")
            return TargetResolutionResult(
                frame_id=frame.frame_id,
                target_spec=None,
                resolved_element_id=None,
                expected_element_id=None,
                confidence=0.0,
                success=False,
                strategy_used="error",
                resolution_time_ms=0.0,
                error_message=str(e)
            )

    def _simulate_post_conditions(
        self,
        frame: ScenarioFrame,
        workflow: Workflow,
        matched_node_id: str
    ) -> PostConditionResult:
        """Simulate post-condition verification for the matched node's outgoing edge."""
        try:
            start_time = time.time()

            # Find an outgoing edge to get the post-conditions from;
            # no outgoing edges means nothing to verify (vacuous pass)
            outgoing_edges = workflow.get_outgoing_edges(matched_node_id)
            if not outgoing_edges:
                return PostConditionResult(
                    frame_id=frame.frame_id,
                    post_conditions=None,
                    checks_passed=0,
                    checks_total=0,
                    success=True,  # no post-conditions = success
                    timeout_occurred=False,
                    verification_time_ms=0.0
                )

            # Simplification: only the first outgoing edge is considered
            edge = outgoing_edges[0]
            post_conditions = edge.post_conditions

            # post_conditions.success holds the list of success checks;
            # an empty/absent list is a vacuous pass
            if not post_conditions or not post_conditions.success:
                return PostConditionResult(
                    frame_id=frame.frame_id,
                    post_conditions=post_conditions,
                    checks_passed=0,
                    checks_total=0,
                    success=True,
                    timeout_occurred=False,
                    verification_time_ms=0.0
                )

            # Verify each success check against the frame's screen state
            checks_total = len(post_conditions.success)
            checks_passed = 0
            failed_checks = []

            for check in post_conditions.success:
                if self._verify_post_condition_check(check, frame.screen_state):
                    checks_passed += 1
                else:
                    failed_checks.append(f"{check.kind}: {check.value}")

            verification_time = (time.time() - start_time) * 1000
            success = checks_passed == checks_total

            return PostConditionResult(
                frame_id=frame.frame_id,
                post_conditions=post_conditions,
                checks_passed=checks_passed,
                checks_total=checks_total,
                success=success,
                timeout_occurred=False,
                verification_time_ms=verification_time,
                failed_checks=failed_checks
            )

        except Exception as e:
            logger.error(f"Post-condition verification failed for frame {frame.frame_id}: {e}")
            return PostConditionResult(
                frame_id=frame.frame_id,
                post_conditions=None,
                checks_passed=0,
                checks_total=0,
                success=False,
                timeout_occurred=False,
                verification_time_ms=0.0,
                error_message=str(e)
            )

    def _verify_post_condition_check(self, check: PostConditionCheck, screen_state: ScreenState) -> bool:
        """Verify a single post-condition check; unknown kinds and errors return False."""
        try:
            if check.kind == "text_present":
                # Substring match over the OCR-detected texts
                detected_texts = getattr(screen_state.perception, 'detected_text', []) if hasattr(screen_state, 'perception') else []
                return any(check.value in text for text in detected_texts)

            elif check.kind == "text_absent":
                # Inverse of text_present
                detected_texts = getattr(screen_state.perception, 'detected_text', []) if hasattr(screen_state, 'perception') else []
                return not any(check.value in text for text in detected_texts)

            elif check.kind == "element_present":
                # Element presence via the target resolver
                if not check.target:
                    return False
                resolved_target = self.target_resolver.resolve_target(check.target, screen_state, context={})
                return resolved_target is not None

            elif check.kind == "window_title_contains":
                # Substring match on the window title
                window_title = getattr(screen_state.window, 'window_title', '') if hasattr(screen_state, 'window') else ''
                return check.value in window_title

            else:
                logger.warning(f"Unknown post-condition check kind: {check.kind}")
                return False

        except Exception as e:
            logger.error(f"Post-condition check failed: {e}")
            return False

    def _simulate_transition(
        self,
        current_frame: ScenarioFrame,
        next_frame: ScenarioFrame,
        workflow: Workflow
    ) -> TransitionResult:
        """Simulate the transition to the next frame."""
        try:
            # A transition is expected when consecutive frames point to
            # different (and both known) expected nodes
            expected_transition = (
                current_frame.expected_node_id != next_frame.expected_node_id and
                current_frame.expected_node_id is not None and
                next_frame.expected_node_id is not None
            )

            # Simulation stub: the actual transition is assumed to mirror the
            # expectation, so success is always True on this path.
            # NOTE(review): this never exercises real transition logic — confirm
            # whether a real check is planned.
            actual_transition = expected_transition
            success = expected_transition == actual_transition
            transition_confidence = 1.0 if success else 0.0

            return TransitionResult(
                from_frame_id=current_frame.frame_id,
                to_frame_id=next_frame.frame_id,
                expected_transition=expected_transition,
                actual_transition=actual_transition,
                success=success,
                transition_confidence=transition_confidence
            )

        except Exception as e:
            logger.error(f"Transition simulation failed: {e}")
            return TransitionResult(
                from_frame_id=current_frame.frame_id,
                to_frame_id=next_frame.frame_id,
                expected_transition=False,
                actual_transition=False,
                success=False,
                transition_confidence=0.0,
                error_message=str(e)
            )

    def _generate_report(
        self,
        scenario_pack: ScenarioPack,
        workflow: Workflow,
        step_results: List[WorkflowStepResult],
        total_time_ms: float
    ) -> WorkflowSimulationReport:
        """Build the final simulation report from the collected step results."""
        total_steps = len(step_results)
        successful_steps = sum(1 for result in step_results if result.overall_success)

        # Per-component success counts (a skipped component counts as a pass)
        node_matching_successes = sum(1 for result in step_results if result.node_matching.success)
        target_resolution_successes = sum(1 for result in step_results
                                          if result.target_resolution is None or result.target_resolution.success)
        post_condition_successes = sum(1 for result in step_results
                                       if result.post_conditions is None or result.post_conditions.success)
        transition_successes = sum(1 for result in step_results
                                   if result.transition is None or result.transition.success)

        # max(1, ...) guards against division by zero on empty scenarios
        node_matching_accuracy = node_matching_successes / max(1, total_steps)
        target_resolution_accuracy = target_resolution_successes / max(1, total_steps)
        post_condition_success_rate = post_condition_successes / max(1, total_steps)
        transition_accuracy = transition_successes / max(1, total_steps)

        # Error analysis: count failures by category, record failed steps
        error_breakdown: Dict[str, int] = {}
        failure_points: List[str] = []

        for result in step_results:
            if not result.overall_success:
                failure_points.append(f"Step {result.step_number}: {result.frame_id}")

            if not result.node_matching.success:
                error_breakdown["node_matching_failures"] = error_breakdown.get("node_matching_failures", 0) + 1
            if result.target_resolution and not result.target_resolution.success:
                error_breakdown["target_resolution_failures"] = error_breakdown.get("target_resolution_failures", 0) + 1
            if result.post_conditions and not result.post_conditions.success:
                error_breakdown["post_condition_failures"] = error_breakdown.get("post_condition_failures", 0) + 1
            if result.transition and not result.transition.success:
                error_breakdown["transition_failures"] = error_breakdown.get("transition_failures", 0) + 1

        # Recommendations: flag any component below a 90% success threshold
        recommendations = []
        if node_matching_accuracy < 0.9:
            recommendations.append("Consider improving node matching accuracy by updating embedding prototypes")
        if target_resolution_accuracy < 0.9:
            recommendations.append("Review target resolution strategies and fallback mechanisms")
        if post_condition_success_rate < 0.9:
            recommendations.append("Verify post-condition definitions and timeout settings")
        if transition_accuracy < 0.9:
            recommendations.append("Check workflow edge definitions and transition logic")

        avg_step_time = total_time_ms / max(1, total_steps)

        return WorkflowSimulationReport(
            scenario_id=scenario_pack.scenario_id,
            workflow_id=workflow.workflow_id,
            timestamp=datetime.now(),
            total_steps=total_steps,
            successful_steps=successful_steps,
            step_results=step_results,
            node_matching_accuracy=node_matching_accuracy,
            target_resolution_accuracy=target_resolution_accuracy,
            post_condition_success_rate=post_condition_success_rate,
            transition_accuracy=transition_accuracy,
            total_simulation_time_ms=total_time_ms,
            avg_step_time_ms=avg_step_time,
            error_breakdown=error_breakdown,
            failure_points=failure_points,
            recommendations=recommendations
        )

    def _save_reports(self, report: WorkflowSimulationReport, output_dir: Path) -> None:
        """Save the JSON and Markdown reports to the output directory."""
        output_dir.mkdir(parents=True, exist_ok=True)

        # JSON report (timestamped filename)
        json_path = output_dir / f"workflow_simulation_{report.scenario_id}_{report.timestamp.strftime('%Y%m%d_%H%M%S')}.json"
        report.save_to_file(json_path)

        # Markdown report (same timestamped stem)
        md_path = output_dir / f"workflow_simulation_{report.scenario_id}_{report.timestamp.strftime('%Y%m%d_%H%M%S')}.md"
        with open(md_path, 'w', encoding='utf-8') as f:
            f.write(report.generate_markdown_report())

        logger.info(f"Reports saved to {output_dir}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fonctions utilitaires
|
||||
# ============================================================================
|
||||
|
||||
def load_scenario_pack(scenario_dir: Union[str, Path]) -> ScenarioPack:
|
||||
"""Charger un scenario pack depuis un répertoire"""
|
||||
return ScenarioPack.load_from_directory(Path(scenario_dir))
|
||||
|
||||
|
||||
def simulate_workflow_from_files(
|
||||
scenario_dir: Union[str, Path],
|
||||
workflow_file: Union[str, Path],
|
||||
output_dir: Optional[Union[str, Path]] = None
|
||||
) -> WorkflowSimulationReport:
|
||||
"""
|
||||
Simuler un workflow depuis des fichiers
|
||||
|
||||
Args:
|
||||
scenario_dir: Répertoire du scenario pack
|
||||
workflow_file: Fichier JSON du workflow
|
||||
output_dir: Répertoire de sortie (optionnel)
|
||||
|
||||
Returns:
|
||||
Rapport de simulation
|
||||
"""
|
||||
# Charger scenario pack
|
||||
scenario_pack = load_scenario_pack(scenario_dir)
|
||||
|
||||
# Charger workflow
|
||||
workflow = Workflow.load_from_file(Path(workflow_file))
|
||||
|
||||
# Créer simulateur
|
||||
simulator = WorkflowSimulator()
|
||||
|
||||
# Exécuter simulation
|
||||
output_path = Path(output_dir) if output_dir else None
|
||||
return simulator.simulate_workflow(scenario_pack, workflow, output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test basique
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Exemple d'utilisation
|
||||
scenario_dir = Path("tests/scenarios/login_flow")
|
||||
workflow_file = Path("data/workflows/login_workflow.json")
|
||||
output_dir = Path("data/simulation_reports")
|
||||
|
||||
if scenario_dir.exists() and workflow_file.exists():
|
||||
report = simulate_workflow_from_files(scenario_dir, workflow_file, output_dir)
|
||||
print(f"Simulation completed: {report.overall_success_rate:.1%} success rate")
|
||||
else:
|
||||
print("Example files not found - create test scenarios first")
|
||||
@@ -1363,20 +1363,51 @@ Règles:
|
||||
x, y = None, None
|
||||
method_used = ''
|
||||
|
||||
# --- Méthode 1 : UI-TARS grounding (~3s, 94% précision) ---
|
||||
# Le plus fiable : on dit "click on X" et UI-TARS trouve les coordonnées
|
||||
# --- Capture unique de l'écran pour TOUTES les méthodes ---
|
||||
_screen_b64 = None
|
||||
if MSS_AVAILABLE and PIL_AVAILABLE:
|
||||
try:
|
||||
import io as _io
|
||||
with mss_lib.mss() as _sct:
|
||||
_mon = _sct.monitors[0]
|
||||
_grab = _sct.grab(_mon)
|
||||
_screen_pil = Image.frombytes('RGB', _grab.size, _grab.bgra, 'raw', 'BGRX')
|
||||
_buf = _io.BytesIO()
|
||||
_screen_pil.save(_buf, format='JPEG', quality=85)
|
||||
_screen_b64 = base64.b64encode(_buf.getvalue()).decode('utf-8')
|
||||
print(f"📸 [ORA/capture] Écran capturé: {_screen_pil.size}")
|
||||
except Exception as _e:
|
||||
print(f"⚠️ [ORA/capture] Erreur: {_e}")
|
||||
|
||||
# --- Méthode 1 : UI-TARS via serveur grounding (port 8200, ~3s) ---
|
||||
# Le serveur tourne dans un process séparé avec son propre CUDA context.
|
||||
# Si le serveur n'est pas lancé → on passe au template matching.
|
||||
if target_text or target_desc:
|
||||
try:
|
||||
from core.execution.input_handler import _grounding_ui_tars
|
||||
import requests as _http
|
||||
click_label = target_desc or target_text
|
||||
print(f"🎯 [ORA/UI-TARS] Recherche: '{click_label}'")
|
||||
result = _grounding_ui_tars(target_text, target_desc)
|
||||
if result:
|
||||
x, y = result['x'], result['y']
|
||||
method_used = 'ui_tars'
|
||||
print(f"✅ [ORA/UI-TARS] Trouvé à ({x}, {y})")
|
||||
_payload = {
|
||||
'target_text': target_text,
|
||||
'target_description': target_desc,
|
||||
}
|
||||
if _screen_b64:
|
||||
_payload['image_b64'] = _screen_b64
|
||||
_resp = _http.post('http://localhost:8200/ground', json=_payload, timeout=30)
|
||||
if _resp.status_code == 200:
|
||||
_data = _resp.json()
|
||||
if _data.get('x') is not None:
|
||||
x, y = _data['x'], _data['y']
|
||||
method_used = 'ui_tars'
|
||||
print(f"✅ [ORA/UI-TARS] Trouvé à ({x}, {y}) conf={_data.get('confidence', 0):.2f} ({_data.get('time_ms', 0):.0f}ms)")
|
||||
else:
|
||||
print(f"⚠️ [ORA/UI-TARS] Serveur n'a pas trouvé '{click_label}'")
|
||||
else:
|
||||
print(f"⚠️ [ORA/UI-TARS] Serveur HTTP {_resp.status_code}")
|
||||
except _http.ConnectionError:
|
||||
print(f"⚠️ [ORA/UI-TARS] Serveur grounding non démarré (port 8200)")
|
||||
except Exception as e:
|
||||
logger.debug(f"⚠️ [ORA/UI-TARS] Erreur: {e}")
|
||||
print(f"⚠️ [ORA/UI-TARS] Erreur: {e}")
|
||||
|
||||
# --- Méthode 2 : Template matching (~80ms) ---
|
||||
if x is None and screenshot_b64 and CV2_AVAILABLE and PIL_AVAILABLE and MSS_AVAILABLE:
|
||||
@@ -1405,19 +1436,22 @@ Règles:
|
||||
y = max_loc[1] + anchor_cv.shape[0] // 2
|
||||
method_used = 'template'
|
||||
except Exception as e:
|
||||
logger.debug(f"⚠️ [ORA/template] Erreur: {e}")
|
||||
print(f"⚠️ [ORA/template] Erreur: {e}")
|
||||
|
||||
# --- Méthode 3 : OCR texte (~1s) ---
|
||||
if x is None and target_text:
|
||||
try:
|
||||
from core.execution.input_handler import _grounding_ocr
|
||||
print(f"🔍 [ORA/OCR] Recherche: '{target_text}'")
|
||||
result = _grounding_ocr(target_text, anchor_bbox=bbox if bbox else None)
|
||||
if result:
|
||||
x, y = result['x'], result['y']
|
||||
method_used = 'ocr'
|
||||
print(f"🔍 [ORA/OCR] Trouvé à ({x}, {y})")
|
||||
else:
|
||||
print(f"🔍 [ORA/OCR] '{target_text}' non trouvé")
|
||||
except Exception as e:
|
||||
logger.debug(f"⚠️ [ORA/OCR] Erreur: {e}")
|
||||
print(f"⚠️ [ORA/OCR] Erreur: {e}")
|
||||
|
||||
# --- Exécuter le clic ---
|
||||
if x is None:
|
||||
@@ -1426,13 +1460,13 @@ Règles:
|
||||
x = int(bbox.get('x', 0) + bbox.get('width', 0) / 2)
|
||||
y = int(bbox.get('y', 0) + bbox.get('height', 0) / 2)
|
||||
method_used = 'static_fallback'
|
||||
logger.warning(f"⚠️ [ORA/click] Fallback coordonnées statiques: ({x}, {y})")
|
||||
print(f"⚠️ [ORA/click] Fallback coordonnées statiques: ({x}, {y})")
|
||||
else:
|
||||
logger.error(f"❌ [ORA/click] Impossible de localiser '{target_text}' — aucune méthode n'a fonctionné")
|
||||
return False
|
||||
|
||||
# --- Vérification pré-action : est-ce le bon élément ? ---
|
||||
if target_text and method_used not in ('template',) and MSS_AVAILABLE and PIL_AVAILABLE:
|
||||
# --- Vérification pré-action (skip si UI-TARS a déjà validé visuellement) ---
|
||||
if target_text and method_used not in ('template', 'ui_tars') and MSS_AVAILABLE and PIL_AVAILABLE:
|
||||
try:
|
||||
pre_check = self._verify_pre_click(x, y, target_text, target_desc)
|
||||
if not pre_check:
|
||||
|
||||
20
core/grounding/__init__.py
Normal file
20
core/grounding/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# core/grounding — Module de localisation d'éléments UI
|
||||
#
|
||||
# Centralise les méthodes de grounding visuel : template matching,
|
||||
# OCR, VLM, etc. Chaque méthode produit un GroundingResult uniforme.
|
||||
#
|
||||
# Le serveur de grounding (server.py) tourne dans un process séparé
|
||||
# sur le port 8200. Le client HTTP (UITarsGrounder) l'appelle via HTTP.
|
||||
# Le pipeline (GroundingPipeline) orchestre template → OCR → UI-TARS → static.
|
||||
|
||||
from core.grounding.template_matcher import TemplateMatcher, MatchResult
|
||||
from core.grounding.target import GroundingTarget, GroundingResult
|
||||
from core.grounding.ui_tars_grounder import UITarsGrounder
|
||||
from core.grounding.pipeline import GroundingPipeline
|
||||
|
||||
__all__ = [
|
||||
'TemplateMatcher', 'MatchResult',
|
||||
'GroundingTarget', 'GroundingResult',
|
||||
'UITarsGrounder',
|
||||
'GroundingPipeline',
|
||||
]
|
||||
190
core/grounding/pipeline.py
Normal file
190
core/grounding/pipeline.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
core/grounding/pipeline.py — Pipeline de grounding en cascade
|
||||
|
||||
Orchestre les methodes de localisation dans l'ordre :
|
||||
1. Template matching (TemplateMatcher, local, ~80ms)
|
||||
2. OCR (docTR via input_handler, local, ~1s)
|
||||
3. UI-TARS (HTTP vers serveur grounding, ~3s)
|
||||
4. Static fallback (coordonnees d'origine du workflow)
|
||||
|
||||
Chaque methode est essayee dans l'ordre. Des qu'une reussit, on retourne
|
||||
le resultat. Cela permet un equilibre entre vitesse (template) et robustesse
|
||||
(UI-TARS pour les elements qui ont change de position/apparence).
|
||||
|
||||
Utilisation :
|
||||
from core.grounding.pipeline import GroundingPipeline
|
||||
from core.grounding.target import GroundingTarget
|
||||
|
||||
pipeline = GroundingPipeline()
|
||||
result = pipeline.locate(GroundingTarget(
|
||||
text="Valider",
|
||||
description="bouton vert en bas",
|
||||
template_b64=screenshot_b64,
|
||||
original_bbox={"x": 100, "y": 200, "width": 80, "height": 30},
|
||||
))
|
||||
if result:
|
||||
print(f"Trouve a ({result.x}, {result.y}) via {result.method}")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from core.grounding.target import GroundingTarget, GroundingResult
|
||||
|
||||
|
||||
class GroundingPipeline:
|
||||
"""Pipeline de localisation en cascade : template -> OCR -> UI-TARS -> static."""
|
||||
|
||||
def __init__(self, template_threshold: float = 0.75, enable_uitars: bool = True):
|
||||
self.template_threshold = template_threshold
|
||||
self.enable_uitars = enable_uitars
|
||||
|
||||
def locate(self, target: GroundingTarget) -> Optional[GroundingResult]:
|
||||
"""Localise un element UI en essayant les methodes en cascade.
|
||||
|
||||
Args:
|
||||
target: description de l'element a localiser
|
||||
|
||||
Returns:
|
||||
GroundingResult ou None si aucune methode ne trouve l'element
|
||||
"""
|
||||
t0 = time.time()
|
||||
|
||||
# --- Methode 1 : Template matching (~80ms) ---
|
||||
result = self._try_template(target)
|
||||
if result:
|
||||
print(f"[GroundingPipeline] Localise via {result.method} en "
|
||||
f"{(time.time() - t0) * 1000:.0f}ms")
|
||||
return result
|
||||
|
||||
# --- Methode 2 : OCR texte (~1s) ---
|
||||
result = self._try_ocr(target)
|
||||
if result:
|
||||
print(f"[GroundingPipeline] Localise via {result.method} en "
|
||||
f"{(time.time() - t0) * 1000:.0f}ms")
|
||||
return result
|
||||
|
||||
# --- Methode 3 : UI-TARS via serveur HTTP (~3s) ---
|
||||
if self.enable_uitars:
|
||||
result = self._try_uitars(target)
|
||||
if result:
|
||||
print(f"[GroundingPipeline] Localise via {result.method} en "
|
||||
f"{(time.time() - t0) * 1000:.0f}ms")
|
||||
return result
|
||||
|
||||
# --- Methode 4 : Fallback statique ---
|
||||
result = self._try_static(target)
|
||||
if result:
|
||||
print(f"[GroundingPipeline] Localise via {result.method} en "
|
||||
f"{(time.time() - t0) * 1000:.0f}ms")
|
||||
return result
|
||||
|
||||
print(f"[GroundingPipeline] ECHEC: '{target.text}' introuvable "
|
||||
f"(toutes methodes epuisees, {(time.time() - t0) * 1000:.0f}ms)")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Methodes individuelles
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _try_template(self, target: GroundingTarget) -> Optional[GroundingResult]:
|
||||
"""Template matching — rapide, exact, mais sensible aux changements visuels."""
|
||||
if not target.template_b64:
|
||||
return None
|
||||
|
||||
try:
|
||||
from core.grounding.template_matcher import TemplateMatcher
|
||||
matcher = TemplateMatcher(threshold=self.template_threshold)
|
||||
match = matcher.match_screen(anchor_b64=target.template_b64)
|
||||
if match:
|
||||
print(f"[GroundingPipeline/template] score={match.score:.3f} "
|
||||
f"pos=({match.x},{match.y}) ({match.time_ms:.0f}ms)")
|
||||
return GroundingResult(
|
||||
x=match.x,
|
||||
y=match.y,
|
||||
method='template',
|
||||
confidence=match.score,
|
||||
time_ms=match.time_ms,
|
||||
)
|
||||
else:
|
||||
diag = matcher.match_screen_diagnostic(anchor_b64=target.template_b64)
|
||||
print(f"[GroundingPipeline/template] pas de match — best={diag}")
|
||||
except Exception as e:
|
||||
print(f"[GroundingPipeline/template] ERREUR: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _try_ocr(self, target: GroundingTarget) -> Optional[GroundingResult]:
|
||||
"""OCR : cherche le texte cible sur l'ecran via docTR."""
|
||||
if not target.text:
|
||||
return None
|
||||
|
||||
try:
|
||||
from core.execution.input_handler import _grounding_ocr
|
||||
bbox = target.original_bbox if target.original_bbox else None
|
||||
result = _grounding_ocr(target.text, anchor_bbox=bbox)
|
||||
if result:
|
||||
print(f"[GroundingPipeline/OCR] '{target.text}' -> ({result['x']}, {result['y']})")
|
||||
return GroundingResult(
|
||||
x=result['x'],
|
||||
y=result['y'],
|
||||
method='ocr',
|
||||
confidence=result.get('confidence', 0.80),
|
||||
time_ms=result.get('time_ms', 0),
|
||||
)
|
||||
else:
|
||||
print(f"[GroundingPipeline/OCR] '{target.text}' non trouve")
|
||||
except Exception as e:
|
||||
print(f"[GroundingPipeline/OCR] ERREUR: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _try_uitars(self, target: GroundingTarget) -> Optional[GroundingResult]:
|
||||
"""UI-TARS via serveur HTTP — robust, gere les changements de layout."""
|
||||
if not target.text and not target.description:
|
||||
return None
|
||||
|
||||
try:
|
||||
from core.grounding.ui_tars_grounder import UITarsGrounder
|
||||
grounder = UITarsGrounder.get_instance()
|
||||
result = grounder.ground(
|
||||
target_text=target.text,
|
||||
target_description=target.description,
|
||||
)
|
||||
if result:
|
||||
print(f"[GroundingPipeline/UI-TARS] ({result.x}, {result.y}) "
|
||||
f"conf={result.confidence:.2f} ({result.time_ms:.0f}ms)")
|
||||
return result
|
||||
else:
|
||||
print(f"[GroundingPipeline/UI-TARS] pas de resultat")
|
||||
except Exception as e:
|
||||
print(f"[GroundingPipeline/UI-TARS] ERREUR: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _try_static(self, target: GroundingTarget) -> Optional[GroundingResult]:
|
||||
"""Fallback : coordonnees d'origine du workflow (centre du bounding box)."""
|
||||
bbox = target.original_bbox
|
||||
if not bbox:
|
||||
return None
|
||||
|
||||
w = bbox.get('width', 0)
|
||||
h = bbox.get('height', 0)
|
||||
if not w or not h:
|
||||
return None
|
||||
|
||||
x = int(bbox.get('x', 0) + w / 2)
|
||||
y = int(bbox.get('y', 0) + h / 2)
|
||||
|
||||
print(f"[GroundingPipeline/static] fallback ({x}, {y}) "
|
||||
f"depuis bbox {bbox}")
|
||||
|
||||
return GroundingResult(
|
||||
x=x,
|
||||
y=y,
|
||||
method='static_fallback',
|
||||
confidence=0.30,
|
||||
time_ms=0.0,
|
||||
)
|
||||
433
core/grounding/server.py
Normal file
433
core/grounding/server.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
core/grounding/server.py — Serveur FastAPI de grounding visuel (port 8200)
|
||||
|
||||
Charge UI-TARS-1.5-7B en 4-bit NF4 dans son propre process Python avec son
|
||||
propre contexte CUDA. Le backend Flask VWB (port 5002) et la boucle ORA
|
||||
appellent ce serveur en HTTP au lieu de charger le modele in-process.
|
||||
|
||||
Lancement :
|
||||
.venv/bin/python3 -m core.grounding.server
|
||||
|
||||
Endpoints :
|
||||
GET /health — verifie que le modele est charge
|
||||
POST /ground — localise un element UI sur un screenshot
|
||||
"""
|
||||
|
||||
import base64
|
||||
import gc
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PORT = int(os.environ.get("GROUNDING_PORT", 8200))
|
||||
MODEL_ID = "ByteDance-Seed/UI-TARS-1.5-7B"
|
||||
MIN_PIXELS = 100 * 28 * 28
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Smart resize — identique a /tmp/test_uitars.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _smart_resize(height: int, width: int, factor: int = 28,
|
||||
min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS):
|
||||
"""UI-TARS smart resize (memes defaults que le test valide)."""
|
||||
h_bar = max(factor, round(height / factor) * factor)
|
||||
w_bar = max(factor, round(width / factor) * factor)
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = math.floor(height / beta / factor) * factor
|
||||
w_bar = math.floor(width / beta / factor) * factor
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = math.ceil(height * beta / factor) * factor
|
||||
w_bar = math.ceil(width * beta / factor) * factor
|
||||
return h_bar, w_bar
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt officiel UI-TARS — identique a /tmp/test_uitars.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_GROUNDING_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
## Output Format
|
||||
|
||||
Thought: ...
|
||||
Action: ...
|
||||
|
||||
|
||||
## Action Space
|
||||
click(start_box='(x1, y1)')
|
||||
|
||||
|
||||
## User Instruction
|
||||
{instruction}"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Modele singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_model = None
|
||||
_processor = None
|
||||
_model_loaded = False
|
||||
|
||||
|
||||
def _evict_ollama_models():
|
||||
"""Libere les modeles Ollama de la VRAM avant de charger UI-TARS."""
|
||||
try:
|
||||
import requests
|
||||
try:
|
||||
ps_resp = requests.get('http://localhost:11434/api/ps', timeout=3)
|
||||
if ps_resp.status_code == 200:
|
||||
loaded = ps_resp.json().get('models', [])
|
||||
model_names = [m.get('name', '') for m in loaded if m.get('name')]
|
||||
else:
|
||||
model_names = []
|
||||
except Exception:
|
||||
model_names = []
|
||||
|
||||
if not model_names:
|
||||
print("[grounding-server] Aucun modele Ollama en VRAM")
|
||||
return
|
||||
|
||||
for model_name in model_names:
|
||||
try:
|
||||
requests.post(
|
||||
'http://localhost:11434/api/generate',
|
||||
json={'model': model_name, 'keep_alive': '0'},
|
||||
timeout=5,
|
||||
)
|
||||
print(f"[grounding-server] Ollama: eviction de '{model_name}'")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time.sleep(1.0)
|
||||
print("[grounding-server] Modeles Ollama liberes")
|
||||
except ImportError:
|
||||
print("[grounding-server] requests non dispo, skip eviction Ollama")
|
||||
|
||||
|
||||
def _load_model():
|
||||
"""Charge UI-TARS-1.5-7B en 4-bit NF4 — code identique a /tmp/test_uitars.py."""
|
||||
global _model, _processor, _model_loaded
|
||||
|
||||
if _model_loaded:
|
||||
return
|
||||
|
||||
print("=" * 60)
|
||||
print(f"[grounding-server] Chargement de {MODEL_ID}")
|
||||
print("=" * 60)
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA non disponible — le serveur de grounding necessite un GPU")
|
||||
|
||||
# Liberer la VRAM Ollama
|
||||
_evict_ollama_models()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
|
||||
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
t0 = time.time()
|
||||
_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
MODEL_ID,
|
||||
quantization_config=bnb_config,
|
||||
device_map="auto",
|
||||
)
|
||||
_model.eval()
|
||||
|
||||
_processor = AutoProcessor.from_pretrained(
|
||||
MODEL_ID,
|
||||
min_pixels=MIN_PIXELS,
|
||||
max_pixels=MAX_PIXELS,
|
||||
)
|
||||
|
||||
_model_loaded = True
|
||||
load_time = time.time() - t0
|
||||
alloc = torch.cuda.memory_allocated() / 1024**3
|
||||
peak = torch.cuda.max_memory_allocated() / 1024**3
|
||||
print(f"[grounding-server] Modele charge en {load_time:.1f}s | "
|
||||
f"VRAM: {alloc:.2f} GB (peak: {peak:.2f} GB)")
|
||||
|
||||
|
||||
def _capture_screen():
|
||||
"""Capture l'ecran complet via mss. Retourne PIL Image ou None."""
|
||||
try:
|
||||
import mss as mss_lib
|
||||
from PIL import Image
|
||||
with mss_lib.mss() as sct:
|
||||
mon = sct.monitors[0]
|
||||
grab = sct.grab(mon)
|
||||
return Image.frombytes('RGB', grab.size, grab.bgra, 'raw', 'BGRX')
|
||||
except Exception as e:
|
||||
print(f"[grounding-server] Erreur capture ecran: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _parse_coordinates(raw: str, orig_w: int, orig_h: int,
|
||||
resized_w: int, resized_h: int):
|
||||
"""Parse les coordonnees du modele — identique a /tmp/test_uitars.py.
|
||||
|
||||
Retourne (px, py, method_detail, confidence) ou None.
|
||||
"""
|
||||
cx, cy = None, None
|
||||
|
||||
# Format 1: <point>x y</point>
|
||||
pm = re.search(r'<point>\s*(\d+)\s+(\d+)\s*</point>', raw)
|
||||
if pm:
|
||||
cx, cy = int(pm.group(1)), int(pm.group(2))
|
||||
|
||||
# Format 2: start_box='(x, y)'
|
||||
if cx is None:
|
||||
bm = re.search(r"start_box=\s*['\"]?\((\d+)\s*,\s*(\d+)\)", raw)
|
||||
if bm:
|
||||
cx, cy = int(bm.group(1)), int(bm.group(2))
|
||||
|
||||
# Format 3: fallback x, y
|
||||
if cx is None:
|
||||
fm = re.search(r'(\d+)\s*,\s*(\d+)', raw)
|
||||
if fm:
|
||||
cx, cy = int(fm.group(1)), int(fm.group(2))
|
||||
|
||||
if cx is None or cy is None:
|
||||
return None
|
||||
|
||||
# Conversion : tester les 2 interpretations, garder la meilleure
|
||||
# Methode A : coordonnees dans l'espace de l'image resizee
|
||||
px_r = int(cx / resized_w * orig_w)
|
||||
py_r = int(cy / resized_h * orig_h)
|
||||
delta_r = ((px_r - orig_w / 2) ** 2 + (py_r - orig_h / 2) ** 2) ** 0.5
|
||||
|
||||
# Methode B : coordonnees 0-1000
|
||||
px_1k = int(cx / 1000 * orig_w)
|
||||
py_1k = int(cy / 1000 * orig_h)
|
||||
delta_1k = ((px_1k - orig_w / 2) ** 2 + (py_1k - orig_h / 2) ** 2) ** 0.5
|
||||
|
||||
# Heuristique du script valide : si coords dans les limites du resize,
|
||||
# les deux sont possibles. UI-TARS utilise l'espace resize en natif.
|
||||
if cx <= resized_w and cy <= resized_h:
|
||||
in_screen_r = (0 <= px_r <= orig_w and 0 <= py_r <= orig_h)
|
||||
in_screen_1k = (0 <= px_1k <= orig_w and 0 <= py_1k <= orig_h)
|
||||
|
||||
if in_screen_r and in_screen_1k:
|
||||
px, py = px_r, py_r
|
||||
method_detail = "resized"
|
||||
elif in_screen_r:
|
||||
px, py = px_r, py_r
|
||||
method_detail = "resized"
|
||||
else:
|
||||
px, py = px_1k, py_1k
|
||||
method_detail = "0-1000"
|
||||
else:
|
||||
px, py = px_1k, py_1k
|
||||
method_detail = "0-1000"
|
||||
|
||||
confidence = 0.85 if ("start_box" in raw or "<point>" in raw) else 0.70
|
||||
|
||||
print(f"[grounding-server] model=({cx},{cy}) -> pixel=({px},{py}) "
|
||||
f"[{method_detail}] resized={resized_w}x{resized_h} orig={orig_w}x{orig_h}")
|
||||
|
||||
return px, py, method_detail, confidence
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FastAPI app
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
app = FastAPI(title="RPA Vision Grounding Server", version="1.0.0")
|
||||
|
||||
|
||||
class GroundRequest(BaseModel):
|
||||
target_text: str = ""
|
||||
target_description: str = ""
|
||||
image_b64: str = ""
|
||||
|
||||
|
||||
class GroundResponse(BaseModel):
|
||||
x: Optional[int] = None
|
||||
y: Optional[int] = None
|
||||
method: str = "ui_tars"
|
||||
confidence: float = 0.85
|
||||
time_ms: float = 0.0
|
||||
raw_output: str = ""
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {
|
||||
"status": "ok" if _model_loaded else "loading",
|
||||
"model": MODEL_ID,
|
||||
"model_loaded": _model_loaded,
|
||||
"cuda_available": torch.cuda.is_available(),
|
||||
"vram_allocated_gb": round(torch.cuda.memory_allocated() / 1024**3, 2) if torch.cuda.is_available() else 0,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/ground", response_model=GroundResponse)
|
||||
def ground(req: GroundRequest):
|
||||
if not _model_loaded:
|
||||
raise HTTPException(status_code=503, detail="Modele pas encore charge")
|
||||
|
||||
from PIL import Image
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
# Construire l'instruction
|
||||
parts = []
|
||||
if req.target_text:
|
||||
parts.append(req.target_text)
|
||||
if req.target_description:
|
||||
parts.append(req.target_description)
|
||||
if not parts:
|
||||
raise HTTPException(status_code=400, detail="target_text ou target_description requis")
|
||||
|
||||
instruction = f"Click on the {' — '.join(parts)}"
|
||||
|
||||
# Obtenir l'image (fournie en b64 ou capture ecran)
|
||||
if req.image_b64:
|
||||
try:
|
||||
raw_b64 = req.image_b64.split(',')[1] if ',' in req.image_b64 else req.image_b64
|
||||
img_data = base64.b64decode(raw_b64)
|
||||
screen_pil = Image.open(io.BytesIO(img_data)).convert('RGB')
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Erreur decodage image: {e}")
|
||||
else:
|
||||
screen_pil = _capture_screen()
|
||||
if screen_pil is None:
|
||||
raise HTTPException(status_code=500, detail="Capture ecran echouee")
|
||||
|
||||
W, H = screen_pil.size
|
||||
rH, rW = _smart_resize(H, W, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS)
|
||||
|
||||
# Sauver temporairement l'image pour qwen_vl_utils
|
||||
import tempfile
|
||||
tmp_path = os.path.join(tempfile.gettempdir(), f"grounding_screen_{os.getpid()}.png")
|
||||
screen_pil.save(tmp_path)
|
||||
|
||||
try:
|
||||
system_prompt = _GROUNDING_PROMPT.format(instruction=instruction)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": f"file://{tmp_path}",
|
||||
"min_pixels": MIN_PIXELS,
|
||||
"max_pixels": MAX_PIXELS,
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": system_prompt,
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
text = _processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = _processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(_model.device)
|
||||
|
||||
# Inference
|
||||
t0 = time.time()
|
||||
with torch.no_grad():
|
||||
gen = _model.generate(**inputs, max_new_tokens=256)
|
||||
infer_ms = (time.time() - t0) * 1000
|
||||
|
||||
# Decoder
|
||||
trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, gen)]
|
||||
raw = _processor.batch_decode(
|
||||
trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0].strip()
|
||||
|
||||
print(f"[grounding-server] '{instruction}' -> raw='{raw[:150]}' ({infer_ms:.0f}ms)")
|
||||
|
||||
# Détecter les réponses négatives (le modèle dit qu'il ne voit pas l'élément)
|
||||
_raw_lower = raw.lower()
|
||||
_negative_markers = ["don't see", "do not see", "cannot find", "can't find",
|
||||
"not visible", "not found", "doesn't appear", "does not appear",
|
||||
"i don't", "unable to find", "unable to locate"]
|
||||
for _neg in _negative_markers:
|
||||
if _neg in _raw_lower:
|
||||
print(f"[grounding-server] NÉGATIF détecté: '{_neg}' → élément non trouvé")
|
||||
return GroundResponse(x=None, y=None, method="ui_tars", confidence=0.0,
|
||||
time_ms=round(infer_ms, 1), raw_output=raw[:300])
|
||||
|
||||
# Parser les coordonnees
|
||||
parsed = _parse_coordinates(raw, W, H, rW, rH)
|
||||
if parsed is None:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Coordonnees non parsees dans la reponse: {raw[:200]}"
|
||||
)
|
||||
|
||||
px, py, method_detail, confidence = parsed
|
||||
|
||||
print(f"[grounding-server] Resultat: ({px}, {py}) conf={confidence:.2f} "
|
||||
f"[{method_detail}] ({infer_ms:.0f}ms)")
|
||||
|
||||
return GroundResponse(
|
||||
x=px,
|
||||
y=py,
|
||||
method="ui_tars",
|
||||
confidence=confidence,
|
||||
time_ms=round(infer_ms, 1),
|
||||
raw_output=raw[:300],
|
||||
)
|
||||
|
||||
finally:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entrypoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Charge le modele au demarrage du serveur."""
|
||||
print(f"[grounding-server] Demarrage sur port {PORT}...")
|
||||
_load_model()
|
||||
print(f"[grounding-server] Pret a recevoir des requetes sur http://localhost:{PORT}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"core.grounding.server:app",
|
||||
host="0.0.0.0",
|
||||
port=PORT,
|
||||
log_level="info",
|
||||
workers=1, # 1 seul worker (1 seul GPU)
|
||||
)
|
||||
48
core/grounding/target.py
Normal file
48
core/grounding/target.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
core/grounding/target.py — Types partagés pour le grounding visuel
|
||||
|
||||
Dataclasses décrivant une cible à localiser (GroundingTarget) et
|
||||
le résultat d'une localisation (GroundingResult).
|
||||
|
||||
Ces types sont la brique commune pour tous les modules de grounding :
|
||||
template matching, OCR, VLM, CLIP, etc.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroundingTarget:
|
||||
"""Description d'un élément UI à localiser sur l'écran.
|
||||
|
||||
Attributs :
|
||||
text : texte visible de l'élément (bouton, label, etc.)
|
||||
description : description sémantique libre (ex: "le bouton Valider en bas à droite")
|
||||
template_b64 : capture visuelle de l'élément, encodée en base64 PNG/JPEG
|
||||
original_bbox : position d'origine lors de la capture {x, y, width, height}
|
||||
"""
|
||||
text: str = ""
|
||||
description: str = ""
|
||||
template_b64: str = ""
|
||||
original_bbox: Optional[Dict[str, int]] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
class GroundingResult:
    """Outcome of locating a UI element on screen.

    Attributes:
        x: X coordinate of the element centre (screen pixels).
        y: Y coordinate of the element centre (screen pixels).
        method: technique that produced the hit
            ('template', 'ocr', 'vlm', 'clip', ...).
        confidence: confidence score in [0.0, 1.0].
        time_ms: lookup duration in milliseconds.
    """

    x: int
    y: int
    method: str
    confidence: float
    time_ms: float
|
||||
350
core/grounding/template_matcher.py
Normal file
350
core/grounding/template_matcher.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
core/grounding/template_matcher.py — Template matching centralisé
|
||||
|
||||
Fournit une classe TemplateMatcher qui localise une ancre visuelle (image template)
|
||||
dans un screenshot via cv2.matchTemplate. Supporte single-scale et multi-scale.
|
||||
|
||||
Remplace les implémentations dupliquées dans :
|
||||
- core/execution/observe_reason_act.py (~1348-1375)
|
||||
- visual_workflow_builder/backend/api_v3/execute.py (~930-963)
|
||||
- visual_workflow_builder/backend/catalog_routes_v2_vlm.py (~339-381)
|
||||
- visual_workflow_builder/backend/services/intelligent_executor.py (~131-210)
|
||||
- core/detection/omniparser_adapter.py (~330)
|
||||
|
||||
Utilisation :
|
||||
from core.grounding import TemplateMatcher, MatchResult
|
||||
|
||||
matcher = TemplateMatcher(threshold=0.75)
|
||||
result = matcher.match_screen(anchor_b64="...")
|
||||
if result:
|
||||
print(f"Trouvé à ({result.x}, {result.y}) score={result.score:.3f}")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Imports optionnels — le module se charge même sans cv2/PIL/mss
|
||||
try:
|
||||
import cv2
|
||||
_CV2 = True
|
||||
except ImportError:
|
||||
_CV2 = False
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
_NP = True
|
||||
except ImportError:
|
||||
_NP = False
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
_PIL = True
|
||||
except ImportError:
|
||||
_PIL = False
|
||||
|
||||
try:
|
||||
import mss as mss_lib
|
||||
_MSS = True
|
||||
except ImportError:
|
||||
_MSS = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Résultat d'un match
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class MatchResult:
    """Outcome of one template-matching attempt."""

    x: int              # centre X of the match (pixels)
    y: int              # centre Y of the match (pixels)
    score: float        # cv2 correlation score of the best location
    method: str         # 'template' | 'template_multiscale'
    time_ms: float      # matching duration in milliseconds
    scale: float = 1.0  # scale at which the best match was found
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TemplateMatcher
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TemplateMatcher:
    """Locate a visual anchor inside a screenshot via cv2 template matching.

    Supports single-scale and multi-scale matching.  The cv2 method is fixed
    to TM_CCOEFF_NORMED, the one used everywhere in this project.

    Parameters:
        threshold : minimum score to accept a match (default 0.75)
        multiscale : enable multi-scale matching (default False)
        scales : scales to try in multi-scale mode (default DEFAULT_SCALES)
        grayscale : convert images to grayscale before matching (default False)
    """

    # Default scales for multi-scale mode, ordered by decreasing likelihood
    # (1.0 first = fast path when the anchor appears unscaled).
    DEFAULT_SCALES: List[float] = [1.0, 0.95, 1.05, 0.9, 1.1, 0.85, 1.15, 0.8, 1.2]

    def __init__(
        self,
        threshold: float = 0.75,
        multiscale: bool = False,
        scales: Optional[List[float]] = None,
        grayscale: bool = False,
    ):
        self.threshold = threshold
        self.multiscale = multiscale
        self.scales = scales or self.DEFAULT_SCALES
        self.grayscale = grayscale
        # cv2.TM_CCOEFF_NORMED is the method used project-wide.
        self._cv2_method = cv2.TM_CCOEFF_NORMED if _CV2 else None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def match_screen(
        self,
        anchor_b64: Optional[str] = None,
        anchor_pil: Optional["Image.Image"] = None,
        screen_pil: Optional["Image.Image"] = None,
    ) -> Optional[MatchResult]:
        """Search for the anchor in the current (or supplied) screenshot.

        The anchor may be given as base64 or as a PIL Image.  The screenshot
        is captured with mss when not supplied.

        Returns a MatchResult, or None when no match reaches the threshold.
        """
        if not (_CV2 and _NP and _PIL):
            logger.debug("[TemplateMatcher] cv2/numpy/PIL non disponible")
            return None

        # --- Prepare the anchor (always normalised to RGB, see _decode_anchor) ---
        anchor_img = self._decode_anchor(anchor_b64, anchor_pil)
        if anchor_img is None:
            return None

        # --- Prepare the screenshot ---
        if screen_pil is None:
            screen_pil = self._capture_screen()
            if screen_pil is None:
                return None
        # Fix: normalise caller-supplied screenshots too.  cv2.cvtColor with
        # COLOR_RGB2BGR requires exactly 3 channels, so an RGBA/palette/
        # grayscale PIL image would raise before this change.
        if screen_pil.mode != 'RGB':
            screen_pil = screen_pil.convert('RGB')

        # --- Convert to cv2 BGR arrays ---
        screen_cv = cv2.cvtColor(np.array(screen_pil), cv2.COLOR_RGB2BGR)
        anchor_cv = cv2.cvtColor(np.array(anchor_img), cv2.COLOR_RGB2BGR)

        # --- Matching ---
        if self.multiscale:
            return self._match_multiscale(screen_cv, anchor_cv)
        return self._match_single(screen_cv, anchor_cv)

    def match_in_region(
        self,
        region_cv: "np.ndarray",
        anchor_cv: "np.ndarray",
        threshold: Optional[float] = None,
    ) -> Optional[MatchResult]:
        """Match inside an already-cropped region (BGR arrays).

        Used by pipelines that do their own capture/cropping.  An optional
        per-call threshold overrides the instance threshold.
        """
        if not (_CV2 and _NP):
            return None

        thr = threshold if threshold is not None else self.threshold

        if self.multiscale:
            return self._match_multiscale(region_cv, anchor_cv, threshold_override=thr)
        return self._match_single(region_cv, anchor_cv, threshold_override=thr)

    def match_screen_diagnostic(
        self,
        anchor_b64: Optional[str] = None,
        anchor_pil: Optional["Image.Image"] = None,
        screen_pil: Optional["Image.Image"] = None,
    ) -> str:
        """Return a textual diagnostic (best score + position) even without a match."""
        if not (_CV2 and _NP and _PIL):
            return "cv2/numpy/PIL non dispo"

        anchor_img = self._decode_anchor(anchor_b64, anchor_pil)
        if anchor_img is None:
            return "ancre non décodable"

        if screen_pil is None:
            screen_pil = self._capture_screen()
            if screen_pil is None:
                return "capture écran échouée"
        # Same RGB normalisation as match_screen (3 channels required).
        if screen_pil.mode != 'RGB':
            screen_pil = screen_pil.convert('RGB')

        screen_cv = cv2.cvtColor(np.array(screen_pil), cv2.COLOR_RGB2BGR)
        anchor_cv = cv2.cvtColor(np.array(anchor_img), cv2.COLOR_RGB2BGR)

        if anchor_cv.shape[0] >= screen_cv.shape[0] or anchor_cv.shape[1] >= screen_cv.shape[1]:
            return f"ancre {anchor_cv.shape[:2]} >= écran {screen_cv.shape[:2]}"

        s_img, a_img = self._maybe_grayscale(screen_cv, anchor_cv)
        result_tm = cv2.matchTemplate(s_img, a_img, self._cv2_method)
        _, max_val, _, max_loc = cv2.minMaxLoc(result_tm)
        return f"{max_val:.3f} pos={max_loc}"

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _match_single(
        self,
        screen_cv: "np.ndarray",
        anchor_cv: "np.ndarray",
        threshold_override: Optional[float] = None,
    ) -> Optional[MatchResult]:
        """Single-scale template matching; returns the centre of the best hit."""
        threshold = threshold_override if threshold_override is not None else self.threshold

        # matchTemplate requires the template to be strictly smaller.
        if anchor_cv.shape[0] >= screen_cv.shape[0] or anchor_cv.shape[1] >= screen_cv.shape[1]:
            logger.debug("[TemplateMatcher] Ancre plus grande que le screen")
            return None

        s_img, a_img = self._maybe_grayscale(screen_cv, anchor_cv)

        t0 = time.time()
        result_tm = cv2.matchTemplate(s_img, a_img, self._cv2_method)
        _, max_val, _, max_loc = cv2.minMaxLoc(result_tm)
        elapsed_ms = (time.time() - t0) * 1000

        logger.debug(
            "[TemplateMatcher] score=%.3f pos=%s (%.0fms)",
            max_val, max_loc, elapsed_ms,
        )

        if max_val >= threshold:
            # max_loc is the top-left corner; report the centre instead.
            cx = max_loc[0] + anchor_cv.shape[1] // 2
            cy = max_loc[1] + anchor_cv.shape[0] // 2
            return MatchResult(
                x=cx,
                y=cy,
                score=float(max_val),
                method='template',
                time_ms=elapsed_ms,
                scale=1.0,
            )
        return None

    def _match_multiscale(
        self,
        screen_cv: "np.ndarray",
        anchor_cv: "np.ndarray",
        threshold_override: Optional[float] = None,
    ) -> Optional[MatchResult]:
        """Multi-scale template matching: try each scale, keep the best score."""
        threshold = threshold_override if threshold_override is not None else self.threshold

        best_score = -1.0
        best_loc = None
        best_scale = 1.0
        best_anchor_shape = anchor_cv.shape

        t0 = time.time()

        for scale in self.scales:
            if scale == 1.0:
                scaled = anchor_cv
            else:
                new_w = int(anchor_cv.shape[1] * scale)
                new_h = int(anchor_cv.shape[0] * scale)
                # Skip degenerate sizes (too small to be meaningful)...
                if new_w < 8 or new_h < 8:
                    continue
                # ...and templates that would not fit inside the screen.
                if new_h >= screen_cv.shape[0] or new_w >= screen_cv.shape[1]:
                    continue
                scaled = cv2.resize(anchor_cv, (new_w, new_h), interpolation=cv2.INTER_AREA)

            if scaled.shape[0] >= screen_cv.shape[0] or scaled.shape[1] >= screen_cv.shape[1]:
                continue

            s_img, a_img = self._maybe_grayscale(screen_cv, scaled)
            result_tm = cv2.matchTemplate(s_img, a_img, self._cv2_method)
            _, max_val, _, max_loc = cv2.minMaxLoc(result_tm)

            if max_val > best_score:
                best_score = max_val
                best_loc = max_loc
                best_scale = scale
                best_anchor_shape = scaled.shape

        elapsed_ms = (time.time() - t0) * 1000

        logger.debug(
            "[TemplateMatcher/multiscale] best_score=%.3f scale=%.2f (%.0fms)",
            best_score, best_scale, elapsed_ms,
        )

        if best_score >= threshold and best_loc is not None:
            # Centre of the best match, computed with the *scaled* template size.
            cx = best_loc[0] + best_anchor_shape[1] // 2
            cy = best_loc[1] + best_anchor_shape[0] // 2
            return MatchResult(
                x=cx,
                y=cy,
                score=float(best_score),
                method='template_multiscale',
                time_ms=elapsed_ms,
                scale=best_scale,
            )
        return None

    def _maybe_grayscale(
        self,
        screen: "np.ndarray",
        anchor: "np.ndarray",
    ) -> Tuple["np.ndarray", "np.ndarray"]:
        """Convert both images to grayscale when self.grayscale is True."""
        if not self.grayscale:
            return screen, anchor
        s = cv2.cvtColor(screen, cv2.COLOR_BGR2GRAY) if len(screen.shape) == 3 else screen
        a = cv2.cvtColor(anchor, cv2.COLOR_BGR2GRAY) if len(anchor.shape) == 3 else anchor
        return s, a

    @staticmethod
    def _decode_anchor(
        anchor_b64: Optional[str],
        anchor_pil: Optional["Image.Image"],
    ) -> Optional["Image.Image"]:
        """Decode the anchor from base64, or pass the PIL image through.

        Fix: the result is always normalised to RGB mode.  RGBA / palette /
        grayscale PNGs would otherwise yield numpy arrays that
        cv2.cvtColor(..., COLOR_RGB2BGR) rejects (it expects exactly
        3 channels), crashing the match.
        """
        if anchor_pil is not None:
            return anchor_pil if anchor_pil.mode == 'RGB' else anchor_pil.convert('RGB')

        if anchor_b64 is None:
            logger.debug("[TemplateMatcher] Ni anchor_b64 ni anchor_pil fourni")
            return None

        try:
            # Tolerate data-URI prefixes ("data:image/png;base64,....").
            raw = anchor_b64.split(',')[1] if ',' in anchor_b64 else anchor_b64
            data = base64.b64decode(raw)
            return Image.open(io.BytesIO(data)).convert('RGB')
        except Exception as e:
            logger.debug("[TemplateMatcher] Erreur décodage ancre: %s", e)
            return None

    @staticmethod
    def _capture_screen() -> Optional["Image.Image"]:
        """Capture the full desktop via mss (monitor 0 = all screens combined)."""
        if not _MSS:
            logger.debug("[TemplateMatcher] mss non disponible")
            return None

        try:
            with mss_lib.mss() as sct:
                mon = sct.monitors[0]
                grab = sct.grab(mon)
                # mss returns BGRA bytes; decode them straight into an RGB image.
                return Image.frombytes('RGB', grab.size, grab.bgra, 'raw', 'BGRX')
        except Exception as e:
            logger.debug("[TemplateMatcher] Erreur capture écran: %s", e)
            return None
|
||||
204
core/grounding/ui_tars_grounder.py
Normal file
204
core/grounding/ui_tars_grounder.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
core/grounding/ui_tars_grounder.py — Client HTTP pour le serveur de grounding
|
||||
|
||||
Remplace le chargement in-process du modele UI-TARS (qui crashe dans Flask
|
||||
a cause de conflits CUDA) par un CLIENT HTTP qui appelle le serveur de
|
||||
grounding separe sur le port 8200.
|
||||
|
||||
Le serveur est lance separement via :
|
||||
.venv/bin/python3 -m core.grounding.server
|
||||
|
||||
Utilisation (inchangee) :
|
||||
from core.grounding.ui_tars_grounder import UITarsGrounder
|
||||
|
||||
grounder = UITarsGrounder.get_instance()
|
||||
result = grounder.ground("Bouton Valider", "le bouton vert en bas a droite")
|
||||
if result:
|
||||
print(f"Trouve a ({result.x}, {result.y})")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from core.grounding.target import GroundingResult
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Module-level storage for the UITarsGrounder singleton created by
# get_instance().  The lock guards the double-checked lazy creation so that
# concurrent first calls still produce a single instance.
_instance: Optional[UITarsGrounder] = None
_instance_lock = threading.Lock()
|
||||
|
||||
|
||||
class UITarsGrounder:
    """HTTP client for the UI-TARS grounding server (port 8200).

    Singleton: use get_instance() to obtain the unique instance.  The server
    must be started separately (.venv/bin/python3 -m core.grounding.server).
    """

    # Base URL of the grounding server; overridable via environment.
    SERVER_URL = os.environ.get("GROUNDING_SERVER_URL", "http://localhost:8200")

    # Single source of truth for the /ground timeout — UI-TARS inference
    # takes 3-5s plus network overhead.  Also used in the timeout message
    # so the two can never drift apart.
    REQUEST_TIMEOUT_S = 30

    # How long (seconds) a health-check verdict is cached to avoid spamming
    # the /health endpoint.
    HEALTH_CACHE_S = 30

    def __init__(self):
        # None = never checked; True/False = last health-check verdict.
        self._server_available: Optional[bool] = None
        # time.time() of the last health check (0.0 = never).
        self._last_check = 0.0

    @classmethod
    def get_instance(cls) -> "UITarsGrounder":
        """Return the singleton grounder (double-checked locking)."""
        global _instance
        if _instance is None:
            with _instance_lock:
                if _instance is None:
                    _instance = cls()
        return _instance

    # ------------------------------------------------------------------
    # Server availability
    # ------------------------------------------------------------------

    def _check_server(self, force: bool = False) -> bool:
        """Check whether the grounding server is up and its model loaded.

        The verdict is cached for HEALTH_CACHE_S seconds; pass force=True
        to bypass the cache.
        """
        now = time.time()
        if not force and self._server_available is not None and (now - self._last_check) < self.HEALTH_CACHE_S:
            return self._server_available

        try:
            import requests
            resp = requests.get(f"{self.SERVER_URL}/health", timeout=3)
            if resp.status_code == 200:
                data = resp.json()
                # The server answers 200 while still loading; only
                # model_loaded=True means /ground will actually work.
                self._server_available = data.get("model_loaded", False)
                if not self._server_available:
                    print("[UI-TARS/client] Serveur en cours de chargement...")
            else:
                self._server_available = False
        except Exception:
            # Connection refused, timeout, missing requests, bad JSON...
            # all mean "not usable right now".
            self._server_available = False

        self._last_check = now

        if not self._server_available:
            print(f"[UI-TARS/client] Serveur non disponible sur {self.SERVER_URL} "
                  f"— lancer: .venv/bin/python3 -m core.grounding.server")

        return self._server_available

    @property
    def is_loaded(self) -> bool:
        """Compatibility shim: True when the server is up and its model loaded."""
        return self._check_server()

    def load(self) -> None:
        """Compatibility shim: only warns (the server loads the model itself)."""
        if not self._check_server(force=True):
            print(f"[UI-TARS/client] ATTENTION: serveur non disponible sur {self.SERVER_URL}")
            print("[UI-TARS/client] Lancer le serveur: .venv/bin/python3 -m core.grounding.server")

    def unload(self) -> None:
        """Compatibility shim: no-op (the model lives in the server process)."""

    # ------------------------------------------------------------------
    # Grounding over HTTP
    # ------------------------------------------------------------------

    def ground(
        self,
        target_text: str = "",
        target_description: str = "",
        screen_pil: Optional["PIL.Image.Image"] = None,
    ) -> Optional["GroundingResult"]:
        """Locate a UI element by calling the grounding server.

        Args:
            target_text: visible text of the element (e.g. "Valider").
            target_description: semantic description (e.g. "le bouton vert en bas").
            screen_pil: PIL screenshot; the server captures one itself if None.

        Returns:
            GroundingResult with screen-pixel coordinates, or None on failure.
        """
        if not target_text and not target_description:
            print("[UI-TARS/client] Pas de target_text ni target_description")
            return None

        # Make sure the server is reachable before doing any work.
        if not self._check_server():
            return None

        import requests

        # Encode the screenshot as base64 PNG when one was supplied.
        image_b64 = ""
        if screen_pil is not None:
            try:
                buffer = io.BytesIO()
                screen_pil.save(buffer, format='PNG')
                image_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
            except Exception as e:
                print(f"[UI-TARS/client] Erreur encodage image: {e}")
                # Continue without the image — the server will grab the screen.

        payload = {
            "target_text": target_text,
            "target_description": target_description,
            "image_b64": image_b64,
        }

        try:
            t0 = time.time()
            resp = requests.post(
                f"{self.SERVER_URL}/ground",
                json=payload,
                timeout=self.REQUEST_TIMEOUT_S,
            )
            total_ms = (time.time() - t0) * 1000

            if resp.status_code == 200:
                data = resp.json()
                result = GroundingResult(
                    x=data["x"],
                    y=data["y"],
                    method=data.get("method", "ui_tars"),
                    confidence=data.get("confidence", 0.85),
                    time_ms=data.get("time_ms", total_ms),
                )
                print(f"[UI-TARS/client] '{target_text or target_description}' -> "
                      f"({result.x}, {result.y}) conf={result.confidence:.2f} "
                      f"({result.time_ms:.0f}ms)")
                return result

            if resp.status_code == 422:
                # The model answered but no coordinates could be parsed.
                detail = resp.json().get("detail", "")
                print(f"[UI-TARS/client] Pas de coordonnees parsees: {detail[:150]}")
                return None

            if resp.status_code == 503:
                print("[UI-TARS/client] Serveur pas encore pret (modele en chargement)")
                return None

            print(f"[UI-TARS/client] Erreur HTTP {resp.status_code}: {resp.text[:200]}")
            return None

        except requests.exceptions.ConnectionError:
            # Remember the failure so the next call short-circuits quickly.
            self._server_available = False
            print(f"[UI-TARS/client] Serveur non joignable sur {self.SERVER_URL}")
            return None
        except requests.exceptions.Timeout:
            print(f"[UI-TARS/client] Timeout (>{self.REQUEST_TIMEOUT_S}s) pour '{target_text}'")
            return None
        except Exception as e:
            print(f"[UI-TARS/client] Erreur inattendue: {e}")
            return None
|
||||
@@ -1,406 +0,0 @@
|
||||
"""
|
||||
Amélioration de WorkflowPipeline pour utiliser WorkflowExecutionResult avec métadonnées complètes
|
||||
|
||||
Cette version améliore la méthode execute_workflow_step pour retourner un objet
|
||||
WorkflowExecutionResult au lieu d'un dictionnaire, incluant toutes les métadonnées
|
||||
requises : correlation_id, performance_metrics, recovery_applied.
|
||||
|
||||
Auteur: Dom, Alice Kiro - 20 décembre 2024
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from core.models.screen_state import ScreenState
|
||||
from core.models.execution_result import (
|
||||
WorkflowExecutionResult,
|
||||
PerformanceMetrics,
|
||||
RecoveryInfo,
|
||||
StepExecutionStatus
|
||||
)
|
||||
from core.execution.action_executor import ExecutionStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowPipelineEnhanced:
    """
    Mixin that upgrades WorkflowPipeline with a complete ExecutionResult.

    It can be used to extend an existing WorkflowPipeline instance or serve
    as a reference implementation for the migration.
    """

    def execute_workflow_step_enhanced(
        self,
        workflow_id: str,
        current_state: ScreenState,
        context: Optional[Dict[str, Any]] = None
    ) -> WorkflowExecutionResult:
        """
        Execute one end-to-end workflow step with complete metadata.

        Integrated execution pipeline:
        1. Match the current screen state against the workflow
        2. Get the next action to execute
        3. Resolve the target with TargetResolver
        4. Execute the action with ActionExecutor
        5. Handle failures with ErrorHandler and the appropriate strategies
        6. Return a WorkflowExecutionResult with complete metadata

        Args:
            workflow_id: ID of the workflow to execute
            current_state: current screen state
            context: optional execution context (variables, etc.)

        Returns:
            WorkflowExecutionResult with complete metadata including:
            - a unique correlation_id for traceability
            - detailed per-phase performance_metrics
            - recovery_applied when recovery strategies were used
            - execution_details for custom metadata
        """
        # Generate the unique identifiers
        execution_id = str(uuid.uuid4())
        correlation_id = str(uuid.uuid4())
        start_time = datetime.now()

        logger.info(f"Executing workflow step: {workflow_id} (execution_id: {execution_id}, correlation_id: {correlation_id})")

        # Initialise the performance metrics
        performance_metrics = PerformanceMetrics(total_execution_time_ms=0.0)

        try:
            # 1. Match the current state, measuring how long it takes
            match_start = datetime.now()
            match_result = self.match_current_state(
                screenshot_path=current_state.raw.screenshot_path,
                workflow_id=workflow_id,
                window_title=current_state.window.window_title
            )
            performance_metrics.state_matching_time_ms = (datetime.now() - match_start).total_seconds() * 1000

            if not match_result:
                # Handle the matching failure through ErrorHandler
                workflow = self.load_workflow(workflow_id)
                candidate_nodes = workflow.nodes if workflow else []

                recovery_start = datetime.now()
                recovery_result = self.error_handler.handle_matching_failure(
                    screen_state=current_state,
                    candidate_nodes=candidate_nodes,
                    best_confidence=0.0,
                    threshold=0.85
                )
                recovery_duration = (datetime.now() - recovery_start).total_seconds() * 1000

                # Build the recovery information
                recovery_info = RecoveryInfo(
                    strategy=recovery_result.strategy_used.value,
                    message=recovery_result.message,
                    success=recovery_result.success,
                    attempts=1,
                    duration_ms=recovery_duration
                )

                # Finalise the metrics
                performance_metrics.total_execution_time_ms = (datetime.now() - start_time).total_seconds() * 1000
                performance_metrics.error_handling_time_ms = recovery_duration

                # Build and return the no_match result
                result = WorkflowExecutionResult.no_match(
                    execution_id=execution_id,
                    workflow_id=workflow_id,
                    current_state=current_state,
                    recovery_info=recovery_info,
                    performance_metrics=performance_metrics
                )
                result.correlation_id = correlation_id

                logger.warning(f"No match found for workflow {workflow_id}, applied recovery: {recovery_result.strategy_used.value}")
                return result

            current_node_id = match_result["node_id"]
            logger.info(f"Matched current state to node: {current_node_id} (confidence: {match_result['confidence']:.3f})")

            # 2. Get the next action (dict contract with an explicit status)
            action_info = self.get_next_action(workflow_id, current_node_id)
            action_status = action_info.get("status")

            if action_status == "terminal":
                # Workflow finished (no outgoing_edge = legitimate end)
                performance_metrics.total_execution_time_ms = (datetime.now() - start_time).total_seconds() * 1000

                result = WorkflowExecutionResult.workflow_complete(
                    execution_id=execution_id,
                    workflow_id=workflow_id,
                    current_node=current_node_id,
                    performance_metrics=performance_metrics,
                )
                result.correlation_id = correlation_id
                result.match_result = match_result

                logger.info(f"Workflow {workflow_id} completed at node {current_node_id}")
                return result

            if action_status == "blocked":
                # Edges exist but none passes the filters:
                # this is a blockage, not the end of the workflow.
                performance_metrics.total_execution_time_ms = (datetime.now() - start_time).total_seconds() * 1000

                result = WorkflowExecutionResult.error(
                    execution_id=execution_id,
                    workflow_id=workflow_id,
                    error_message=f"No valid edge: {action_info.get('reason', 'unknown')}",
                    step_type="action_selection",
                    current_node=current_node_id,
                    performance_metrics=performance_metrics,
                )
                result.correlation_id = correlation_id

                logger.warning(
                    f"Workflow {workflow_id} blocked at node {current_node_id}: "
                    f"{action_info.get('reason')}"
                )
                return result

            logger.info(f"Next action: {action_info['action']['type']} -> {action_info['target_node']}")

            # 3. Load the workflow to obtain the full edge
            workflow = self.load_workflow(workflow_id)
            if not workflow:
                performance_metrics.total_execution_time_ms = (datetime.now() - start_time).total_seconds() * 1000

                result = WorkflowExecutionResult.error(
                    execution_id=execution_id,
                    workflow_id=workflow_id,
                    error_message=f"Failed to load workflow: {workflow_id}",
                    step_type="workflow_loading",
                    current_node=current_node_id,
                    performance_metrics=performance_metrics
                )
                result.correlation_id = correlation_id

                logger.error(f"Failed to load workflow: {workflow_id}")
                return result

            # Find the matching edge
            edge = None
            for e in workflow.edges:
                if (hasattr(e, 'edge_id') and e.edge_id == action_info['edge_id']) or \
                   (e.from_node == current_node_id and e.to_node == action_info['target_node']):
                    edge = e
                    break

            if not edge:
                performance_metrics.total_execution_time_ms = (datetime.now() - start_time).total_seconds() * 1000

                result = WorkflowExecutionResult.error(
                    execution_id=execution_id,
                    workflow_id=workflow_id,
                    error_message=f"Edge not found: {current_node_id} -> {action_info['target_node']}",
                    step_type="edge_resolution",
                    current_node=current_node_id,
                    performance_metrics=performance_metrics
                )
                result.correlation_id = correlation_id

                logger.error(f"Edge not found: {current_node_id} -> {action_info['target_node']}")
                return result

            # 4. Execute the action with ActionExecutor, measuring duration
            execution_start = datetime.now()
            execution_result = self.action_executor.execute_edge(
                edge=edge,
                screen_state=current_state,
                context=context
            )
            performance_metrics.action_execution_time_ms = (datetime.now() - execution_start).total_seconds() * 1000

            # 5. Handle specific failures with ErrorHandler when needed
            recovery_info = None
            if execution_result.status != ExecutionStatus.SUCCESS:
                recovery_start = datetime.now()

                if execution_result.status == ExecutionStatus.TARGET_NOT_FOUND:
                    # ActionExecutor already handled this, but we can log it
                    logger.info("Target not found - ActionExecutor applied recovery strategies")
                    # Build recovery info describing what ActionExecutor did
                    recovery_info = RecoveryInfo(
                        strategy="target_resolution_fallback",
                        message="ActionExecutor applied target resolution fallback strategies",
                        success=False,  # the status is still TARGET_NOT_FOUND
                        attempts=1,
                        duration_ms=0.0  # ActionExecutor already measured its own time
                    )

                elif execution_result.status == ExecutionStatus.POSTCONDITION_FAILED:
                    # Handle the post-condition failure
                    recovery_result = self.error_handler.handle_postcondition_failure(
                        edge=edge,
                        screen_state=current_state,
                        timeout_ms=5000
                    )
                    recovery_duration = (datetime.now() - recovery_start).total_seconds() * 1000

                    recovery_info = RecoveryInfo(
                        strategy=recovery_result.strategy_used.value,
                        message=recovery_result.message,
                        success=recovery_result.success,
                        attempts=1,
                        duration_ms=recovery_duration
                    )
                    performance_metrics.error_handling_time_ms = recovery_duration
                    logger.warning(f"Post-condition failed - Recovery: {recovery_result.message}")

            # 6. Build the final result with complete metadata
            performance_metrics.total_execution_time_ms = (datetime.now() - start_time).total_seconds() * 1000

            # Build the executed-action dict with full details
            action_executed = {
                "edge_id": action_info.get('edge_id', 'unknown'),
                "type": action_info['action']['type'],
                "target": action_info['action'].get('target'),
                "parameters": action_info['action'].get('parameters', {}),
                "execution_status": execution_result.status.value,
                "execution_message": execution_result.message,
                "execution_duration_ms": execution_result.duration_ms
            }

            if execution_result.status == ExecutionStatus.SUCCESS:
                # Build the success result
                result = WorkflowExecutionResult.success(
                    execution_id=execution_id,
                    workflow_id=workflow_id,
                    current_node=current_node_id,
                    target_node=action_info['target_node'],
                    action_executed=action_executed,
                    target_resolved=execution_result.target_resolved,
                    match_result=match_result,
                    performance_metrics=performance_metrics
                )
                result.correlation_id = correlation_id

                # Attach custom execution details
                result.add_execution_detail("action_confidence", action_info.get('confidence', 1.0))
                result.add_execution_detail("match_confidence", match_result.get('confidence', 0.0))
                if context:
                    result.add_execution_detail("execution_context", context)

                logger.info(f"Workflow step executed successfully in {performance_metrics.total_execution_time_ms:.1f}ms")

            else:
                # Build the error result
                result = WorkflowExecutionResult.error(
                    execution_id=execution_id,
                    workflow_id=workflow_id,
                    error_message=execution_result.message,
                    step_type="action_execution",
                    current_node=current_node_id,
                    recovery_info=recovery_info,
                    performance_metrics=performance_metrics
                )
                result.correlation_id = correlation_id
                result.target_node = action_info['target_node']
                result.action_executed = action_executed
                result.target_resolved = execution_result.target_resolved
                result.match_result = match_result

                # Attach error details
                result.add_execution_detail("action_confidence", action_info.get('confidence', 1.0))
                result.add_execution_detail("match_confidence", match_result.get('confidence', 0.0))
                if execution_result.error:
                    result.add_execution_detail("original_error", str(execution_result.error))

                logger.error(f"Workflow step failed: {execution_result.message}")

            return result

        except Exception as e:
            # Exception handling with complete metadata
            performance_metrics.total_execution_time_ms = (datetime.now() - start_time).total_seconds() * 1000
            logger.error(f"Workflow step execution failed with exception: {e}", exc_info=True)

            # Use ErrorHandler to record the exception
            from core.execution.error_handler import ErrorContext, ErrorType
            error_ctx = ErrorContext(
                error_type=ErrorType.UNKNOWN,
                timestamp=datetime.now(),
                screen_state=current_state,
                message=f"Workflow execution exception: {str(e)}",
                details={
                    "workflow_id": workflow_id,
                    "execution_id": execution_id,
                    "correlation_id": correlation_id,
                    "exception_type": type(e).__name__
                }
            )
            self.error_handler.error_history.append(error_ctx)
            self.error_handler._log_error(error_ctx)

            # Build the error result with complete metadata
            result = WorkflowExecutionResult.error(
                execution_id=execution_id,
                workflow_id=workflow_id,
                error_message=str(e),
                step_type="execution_error",
                performance_metrics=performance_metrics
            )
            result.correlation_id = correlation_id

            # Attach exception details
            result.add_execution_detail("exception_type", type(e).__name__)
            result.add_execution_detail("exception_traceback", str(e))
            if context:
                result.add_execution_detail("execution_context", context)

            return result
|
||||
|
||||
|
||||
def enhance_workflow_pipeline(pipeline_instance):
    """
    Attach ``execute_workflow_step_enhanced`` to an existing WorkflowPipeline
    instance without modifying its class.

    Args:
        pipeline_instance: WorkflowPipeline instance to augment.

    Returns:
        The same instance, now exposing ``execute_workflow_step_enhanced``.
    """
    # BUG FIX: the original did
    #   enhanced_mixin.execute_workflow_step_enhanced.call(pipeline_instance, ...)
    # ``.call`` is a JavaScript idiom -- Python bound methods have no ``call``
    # attribute, so invoking the lambda raised AttributeError (and ``self``
    # would have been the mixin in any case).  Rebind the *unbound* function
    # from the class onto ``pipeline_instance`` via the descriptor protocol so
    # that ``self`` inside the enhanced method is the pipeline itself.
    enhanced_func = WorkflowPipelineEnhanced.execute_workflow_step_enhanced
    pipeline_instance.execute_workflow_step_enhanced = \
        enhanced_func.__get__(pipeline_instance, type(pipeline_instance))

    return pipeline_instance
|
||||
|
||||
|
||||
# Fonction de migration pour remplacer la méthode existante
|
||||
def migrate_execute_workflow_step(pipeline_instance):
    """
    Migrate an instance's ``execute_workflow_step`` to the enhanced version.

    WARNING: this replaces the existing method on the instance. The legacy
    implementation is preserved as ``_execute_workflow_step_legacy`` so
    callers can roll back if needed.

    Args:
        pipeline_instance: WorkflowPipeline instance to migrate.

    Returns:
        The instance with ``execute_workflow_step`` pointing at the enhanced
        implementation.
    """
    # Keep the old method around for potential rollback.
    if hasattr(pipeline_instance, 'execute_workflow_step'):
        pipeline_instance._execute_workflow_step_legacy = pipeline_instance.execute_workflow_step

    # BUG FIX: the original called ``__get__`` on a *bound* method
    # (``enhanced_mixin.execute_workflow_step_enhanced``).  A bound method's
    # ``__get__`` returns the method unchanged -- still bound to the mixin --
    # so ``self`` inside the enhanced method was the throwaway mixin instead
    # of the pipeline.  Rebinding must start from the unbound function on the
    # class.
    enhanced_func = WorkflowPipelineEnhanced.execute_workflow_step_enhanced
    pipeline_instance.execute_workflow_step = \
        enhanced_func.__get__(pipeline_instance, type(pipeline_instance))

    logger.info("WorkflowPipeline.execute_workflow_step migrated to enhanced version with complete metadata")
    return pipeline_instance
|
||||
@@ -1,483 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Service de Capture Contextuelle pour RPA Vision V3
|
||||
|
||||
Ce service gère la capture du contexte environnant des éléments sélectionnés,
|
||||
incluant les éléments voisins, la hiérarchie visuelle et les métadonnées contextuelles.
|
||||
|
||||
Exigences: 7.1, 7.2, 7.3, 7.4, 7.5
|
||||
Auteur: Assistant IA
|
||||
Date: 2026-01-07
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
from core.models import UIElement, BBox, ScreenState
|
||||
from core.capture.screen_capturer import ScreenCapturer
|
||||
from core.detection.ui_detector import UIDetector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class ContextualElement:
    """A contextual element found in the neighbourhood of a target element."""
    element: UIElement
    # Spatial relation to the target. NOTE(review): docstring set and the
    # values actually produced by _determine_spatial_relationship differ
    # slightly ('contains' is returned, 'adjacent' is not) -- confirm intent.
    spatial_relationship: str  # 'above', 'below', 'left', 'right', 'inside', 'adjacent'
    distance: float  # Distance in pixels (centre-to-centre)
    relevance_score: float  # Contextual relevance score (0-1)
    visual_similarity: float  # Visual similarity with the target element (0-1)
|
||||
|
||||
@dataclass
class VisualHierarchy:
    """Visual hierarchy of an element (containment relationships on screen)."""
    parent_container: Optional[UIElement] = None  # Smallest element enclosing the target
    child_elements: List[UIElement] = field(default_factory=list)  # Elements inside the target
    sibling_elements: List[UIElement] = field(default_factory=list)  # Other elements in the same parent
    depth_level: int = 0  # Number of containers enclosing the target
    container_type: str = "unknown"  # 'form', 'dialog', 'panel', 'page', etc.
|
||||
|
||||
@dataclass
class ContextualMetadata:
    """Enriched contextual metadata captured around a target element."""
    surrounding_elements: List[ContextualElement] = field(default_factory=list)
    visual_hierarchy: Optional[VisualHierarchy] = None
    screen_region: str = "unknown"  # 'header', 'sidebar', 'main', 'footer', etc.
    visual_density: float = 0.0  # Element density in the area (0-1)
    color_palette: List[str] = field(default_factory=list)  # Dominant colours (hex strings)
    text_context: List[str] = field(default_factory=list)  # Texts of nearby elements
    capture_timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
class ContextualCaptureService:
    """
    Contextual-capture service that enriches selected elements with
    information about their visual environment: neighbouring elements,
    visual hierarchy, screen region, density, colours and nearby text.
    """

    def __init__(self, screen_capturer: ScreenCapturer, ui_detector: UIDetector):
        """
        Initialize the contextual capture service.

        Args:
            screen_capturer: Screen-capture service
            ui_detector: UI-element detector
        """
        self.screen_capturer = screen_capturer
        self.ui_detector = ui_detector
        self.context_radius = 200  # Context-capture radius in pixels
        self.max_contextual_elements = 20  # Max number of contextual elements kept

        logger.info("Service de capture contextuelle initialisé")

    async def capture_element_context(
        self,
        target_element: UIElement,
        screen_state: Optional[ScreenState] = None
    ) -> ContextualMetadata:
        """
        Capture the full context of a target element.

        Args:
            target_element: Element whose context is being captured
            screen_state: Current screen state (optional; captured if absent)

        Returns:
            Enriched contextual metadata. On any failure an empty
            ContextualMetadata is returned (best-effort; never raises).
        """
        try:
            logger.info(f"Capture du contexte pour élément: {target_element.element_type}")

            # Capture the screen state if the caller did not supply one
            if screen_state is None:
                screen_state = await self._capture_current_screen()

            # Analyse the surrounding elements
            surrounding_elements = await self._analyze_surrounding_elements(
                target_element, screen_state
            )

            # Build the visual hierarchy
            visual_hierarchy = await self._build_visual_hierarchy(
                target_element, screen_state
            )

            # Determine the screen region
            screen_region = self._determine_screen_region(target_element, screen_state)

            # Compute the visual density
            visual_density = self._calculate_visual_density(target_element, screen_state)

            # Extract the colour palette
            color_palette = await self._extract_color_palette(target_element, screen_state)

            # Collect the textual context
            text_context = self._collect_text_context(target_element, screen_state)

            metadata = ContextualMetadata(
                surrounding_elements=surrounding_elements,
                visual_hierarchy=visual_hierarchy,
                screen_region=screen_region,
                visual_density=visual_density,
                color_palette=color_palette,
                text_context=text_context,
                capture_timestamp=datetime.now()
            )

            logger.info(f"Contexte capturé: {len(surrounding_elements)} éléments environnants")
            return metadata

        except Exception as e:
            # Deliberate best-effort: swallow the error and return empty metadata.
            logger.error(f"Erreur lors de la capture du contexte: {e}")
            return ContextualMetadata()

    async def _capture_current_screen(self) -> ScreenState:
        """Capture the current screen state; re-raises on capture failure."""
        try:
            screenshot = await self.screen_capturer.capture_screen()
            screen_state = await self.ui_detector.detect_elements(screenshot)
            return screen_state
        except Exception as e:
            logger.error(f"Erreur lors de la capture d'écran: {e}")
            raise

    async def _analyze_surrounding_elements(
        self,
        target_element: UIElement,
        screen_state: ScreenState
    ) -> List[ContextualElement]:
        """
        Analyse the elements surrounding a target element.

        Args:
            target_element: Target element
            screen_state: Full screen state

        Returns:
            Contextual elements within ``context_radius``, sorted by
            relevance (descending) and capped at ``max_contextual_elements``.
        """
        contextual_elements = []
        target_bbox = target_element.bounding_box
        target_center = self._get_bbox_center(target_bbox)

        for element in screen_state.ui_elements:
            if element == target_element:
                continue

            # Compute the centre-to-centre distance
            element_center = self._get_bbox_center(element.bounding_box)
            distance = self._calculate_distance(target_center, element_center)

            # Filter by context radius
            if distance > self.context_radius:
                continue

            # Determine the spatial relationship
            spatial_relationship = self._determine_spatial_relationship(
                target_bbox, element.bounding_box
            )

            # Compute the relevance score
            relevance_score = self._calculate_relevance_score(
                target_element, element, distance
            )

            # Compute the (basic, for now) visual similarity
            visual_similarity = self._calculate_visual_similarity(
                target_element, element
            )

            contextual_element = ContextualElement(
                element=element,
                spatial_relationship=spatial_relationship,
                distance=distance,
                relevance_score=relevance_score,
                visual_similarity=visual_similarity
            )

            contextual_elements.append(contextual_element)

        # Sort by relevance and cap the count
        contextual_elements.sort(key=lambda x: x.relevance_score, reverse=True)
        return contextual_elements[:self.max_contextual_elements]

    async def _build_visual_hierarchy(
        self,
        target_element: UIElement,
        screen_state: ScreenState
    ) -> VisualHierarchy:
        """
        Build the visual hierarchy of an element from bbox containment.

        Args:
            target_element: Target element
            screen_state: Full screen state

        Returns:
            Visual hierarchy: smallest enclosing parent, contained children,
            siblings sharing the same parent, depth and container type.
        """
        target_bbox = target_element.bounding_box

        # Find the parent container: smallest element that encloses the target
        parent_container = None
        min_area = float('inf')

        for element in screen_state.ui_elements:
            if element == target_element:
                continue

            # Does this element contain our target?
            if self._bbox_contains(element.bounding_box, target_bbox):
                area = self._calculate_bbox_area(element.bounding_box)
                if area < min_area:
                    min_area = area
                    parent_container = element

        # Find the child elements (fully inside the target)
        child_elements = []
        for element in screen_state.ui_elements:
            if element == target_element:
                continue

            if self._bbox_contains(target_bbox, element.bounding_box):
                child_elements.append(element)

        # Find the sibling elements (same parent container)
        sibling_elements = []
        if parent_container:
            for element in screen_state.ui_elements:
                if element == target_element or element == parent_container:
                    continue

                if self._bbox_contains(parent_container.bounding_box, element.bounding_box):
                    # Exclude descendants of the target itself
                    if not self._bbox_contains(target_bbox, element.bounding_box):
                        sibling_elements.append(element)

        # Determine the container type
        container_type = "unknown"
        if parent_container:
            container_type = self._determine_container_type(parent_container)

        # Compute the depth level
        depth_level = self._calculate_depth_level(target_element, screen_state)

        return VisualHierarchy(
            parent_container=parent_container,
            child_elements=child_elements,
            sibling_elements=sibling_elements,
            depth_level=depth_level,
            container_type=container_type
        )

    def _determine_screen_region(self, target_element: UIElement, screen_state: ScreenState) -> str:
        """Return the screen region of the element as '<vertical>_<horizontal>'."""
        bbox = target_element.bounding_box
        # Fall back to 1920x1080 when no screenshot is available.
        screen_width = screen_state.screenshot.width if screen_state.screenshot else 1920
        screen_height = screen_state.screenshot.height if screen_state.screenshot else 1080

        center_x = (bbox.x + bbox.width / 2) / screen_width
        center_y = (bbox.y + bbox.height / 2) / screen_height

        # Vertical region: top 20% = header, bottom 20% = footer
        if center_y < 0.2:
            vertical_region = "header"
        elif center_y > 0.8:
            vertical_region = "footer"
        else:
            vertical_region = "main"

        # Horizontal region: left 20% / right 20% / centre
        if center_x < 0.2:
            horizontal_region = "left"
        elif center_x > 0.8:
            horizontal_region = "right"
        else:
            horizontal_region = "center"

        return f"{vertical_region}_{horizontal_region}"

    def _calculate_visual_density(self, target_element: UIElement, screen_state: ScreenState) -> float:
        """Compute the visual density (0-1) around the element."""
        target_bbox = target_element.bounding_box

        # Define an analysis zone around the element.
        # NOTE(review): width/height are not clipped to the screen edge, so
        # the zone (and hence the density denominator) can extend off-screen.
        analysis_bbox = BBox(
            x=max(0, target_bbox.x - self.context_radius),
            y=max(0, target_bbox.y - self.context_radius),
            width=target_bbox.width + 2 * self.context_radius,
            height=target_bbox.height + 2 * self.context_radius
        )

        # Count elements intersecting this zone
        elements_in_zone = 0
        total_element_area = 0

        for element in screen_state.ui_elements:
            if self._bbox_intersects(element.bounding_box, analysis_bbox):
                elements_in_zone += 1
                total_element_area += self._calculate_bbox_area(element.bounding_box)

        # Density = occupied area / zone area, clamped to 1.0
        analysis_area = analysis_bbox.width * analysis_bbox.height
        density = min(1.0, total_element_area / analysis_area) if analysis_area > 0 else 0.0

        return density

    async def _extract_color_palette(
        self,
        target_element: UIElement,
        screen_state: ScreenState
    ) -> List[str]:
        """Extract the dominant colour palette around the element."""
        try:
            if not screen_state.screenshot:
                return []

            # Placeholder: return a fixed palette for now.
            # A full implementation would need PIL and sklearn.
            return ["#1976d2", "#dc004e", "#22c55e", "#f59e0b", "#ef4444"]

        except Exception as e:
            logger.warning(f"Erreur lors de l'extraction de couleurs: {e}")
            return []

    def _collect_text_context(self, target_element: UIElement, screen_state: ScreenState) -> List[str]:
        """Collect the textual context (nearby element texts) around the element."""
        text_context = []
        target_bbox = target_element.bounding_box
        target_center = self._get_bbox_center(target_bbox)

        # Collect texts of elements within the context radius
        for element in screen_state.ui_elements:
            if element == target_element or not element.text_content:
                continue

            element_center = self._get_bbox_center(element.bounding_box)
            distance = self._calculate_distance(target_center, element_center)

            if distance <= self.context_radius:
                text_context.append(element.text_content.strip())

        # Drop empties and cap the list
        text_context = [text for text in text_context if len(text) > 0]
        return text_context[:10]  # Cap at 10 texts

    # Utility methods

    def _get_bbox_center(self, bbox: BBox) -> Tuple[float, float]:
        """Return the centre point of a bounding box."""
        return (bbox.x + bbox.width / 2, bbox.y + bbox.height / 2)

    def _calculate_distance(self, point1: Tuple[float, float], point2: Tuple[float, float]) -> float:
        """Return the Euclidean distance between two points."""
        return np.sqrt((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)

    def _determine_spatial_relationship(self, bbox1: BBox, bbox2: BBox) -> str:
        """Determine the spatial relation of bbox2 relative to bbox1."""
        center1 = self._get_bbox_center(bbox1)
        center2 = self._get_bbox_center(bbox2)

        # Containment takes precedence over direction
        if self._bbox_contains(bbox1, bbox2):
            return "inside"
        if self._bbox_contains(bbox2, bbox1):
            return "contains"

        # Otherwise pick the dominant axis of the centre offset
        dx = center2[0] - center1[0]
        dy = center2[1] - center1[1]

        if abs(dx) > abs(dy):
            return "right" if dx > 0 else "left"
        else:
            return "below" if dy > 0 else "above"

    def _calculate_relevance_score(
        self,
        target_element: UIElement,
        contextual_element: UIElement,
        distance: float
    ) -> float:
        """Compute the relevance score (0-1) of a contextual element."""
        # Distance component: closer means more relevant
        distance_score = max(0, 1 - (distance / self.context_radius))

        # Bonus for elements of the same type
        type_bonus = 0.2 if target_element.element_type == contextual_element.element_type else 0

        # Bonus for elements carrying text
        text_bonus = 0.1 if contextual_element.text_content else 0

        return min(1.0, distance_score + type_bonus + text_bonus)

    def _calculate_visual_similarity(self, element1: UIElement, element2: UIElement) -> float:
        """Compute a basic visual similarity (0-1) between two elements."""
        # Same element type: high fixed similarity
        if element1.element_type == element2.element_type:
            return 0.8

        # Otherwise base the similarity on relative size
        area1 = self._calculate_bbox_area(element1.bounding_box)
        area2 = self._calculate_bbox_area(element2.bounding_box)

        if area1 > 0 and area2 > 0:
            size_ratio = min(area1, area2) / max(area1, area2)
            return size_ratio * 0.5

        return 0.1  # Minimal similarity

    def _bbox_contains(self, container: BBox, contained: BBox) -> bool:
        """Return True if ``container`` fully encloses ``contained`` (edges inclusive)."""
        return (
            container.x <= contained.x and
            container.y <= contained.y and
            container.x + container.width >= contained.x + contained.width and
            container.y + container.height >= contained.y + contained.height
        )

    def _bbox_intersects(self, bbox1: BBox, bbox2: BBox) -> bool:
        """Return True if the two bounding boxes overlap (edge-touching counts)."""
        return not (
            bbox1.x + bbox1.width < bbox2.x or
            bbox2.x + bbox2.width < bbox1.x or
            bbox1.y + bbox1.height < bbox2.y or
            bbox2.y + bbox2.height < bbox1.y
        )

    def _calculate_bbox_area(self, bbox: BBox) -> float:
        """Return the area of a bounding box."""
        return bbox.width * bbox.height

    def _determine_container_type(self, container: UIElement) -> str:
        """Determine a container's type from its element type, then size heuristics."""
        if container.element_type in ["form", "dialog", "modal"]:
            return container.element_type

        # Size-based heuristics (thresholds in squared pixels)
        area = self._calculate_bbox_area(container.bounding_box)

        if area > 500000:  # Large zone
            return "page"
        elif area > 100000:  # Medium zone
            return "panel"
        else:
            return "container"

    def _calculate_depth_level(self, target_element: UIElement, screen_state: ScreenState) -> int:
        """Return the nesting depth: how many elements enclose the target."""
        depth = 0
        current_bbox = target_element.bounding_box

        # Count every container that fully encloses the target
        for element in screen_state.ui_elements:
            if element == target_element:
                continue

            if self._bbox_contains(element.bounding_box, current_bbox):
                depth += 1

        return depth
|
||||
@@ -1,493 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Service de Validation en Temps Réel pour RPA Vision V3
|
||||
|
||||
Ce service gère la validation continue des éléments visuels en arrière-plan,
|
||||
fournit des notifications de changements et maintient la cohérence des cibles visuelles.
|
||||
|
||||
Exigences: 6.1, 6.2, 6.3, 6.4, 6.5
|
||||
Auteur: Assistant IA
|
||||
Date: 2026-01-07
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Callable, Any
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import threading
|
||||
import weakref
|
||||
|
||||
from core.visual.visual_target_manager import VisualTarget, ValidationResult
|
||||
from core.visual.screenshot_validation_manager import ScreenshotValidationManager, ValidationStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class NotificationLevel(Enum):
    """Severity levels for validation notifications."""
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"
|
||||
|
||||
@dataclass
class ValidationNotification:
    """A validation notification delivered to subscribers."""
    target_signature: str  # Signature of the visual target concerned
    level: NotificationLevel
    message: str
    timestamp: datetime
    validation_result: Optional[ValidationResult] = None  # Underlying result, if any
    suggested_actions: List[str] = field(default_factory=list)  # Human-readable remediation hints
    auto_fixable: bool = False  # Whether the service may apply an automatic fix
|
||||
|
||||
@dataclass
class ValidationSubscription:
    """A subscription to validation notifications for one target."""
    target_signature: str
    callback: Callable[[ValidationNotification], None]
    # Levels this subscriber wants to receive; defaults to every level.
    notification_levels: List[NotificationLevel] = field(default_factory=lambda: list(NotificationLevel))
    active: bool = True
    created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
class RealtimeValidationService:
|
||||
"""
|
||||
Service de validation en temps réel pour les cibles visuelles.
|
||||
|
||||
Gère la validation continue, les notifications et les actions automatiques
|
||||
pour maintenir la cohérence des éléments visuels.
|
||||
"""
|
||||
|
||||
    def __init__(self, validation_manager: ScreenshotValidationManager):
        """
        Initialize the real-time validation service.

        Args:
            validation_manager: Screenshot-validation manager used to run the
                actual target validations.
        """
        self.validation_manager = validation_manager

        # Subscription management.
        # NOTE(review): a threading.RLock guards subscriptions even though
        # delivery runs on the asyncio loop -- presumably subscribe/unsubscribe
        # may be called from other threads; confirm before simplifying.
        self._subscriptions: Dict[str, List[ValidationSubscription]] = {}
        self._subscription_lock = threading.RLock()

        # Service configuration
        self.validation_interval = 5.0  # Seconds between validations
        self.notification_queue_size = 1000
        self.auto_fix_enabled = True
        self.batch_validation_size = 10

        # Notification queue (bounded)
        self._notification_queue: asyncio.Queue = asyncio.Queue(maxsize=self.notification_queue_size)

        # Background service tasks
        self._service_tasks: List[asyncio.Task] = []
        self._service_running = False

        # Operational statistics
        self.stats = {
            'notifications_sent': 0,
            'auto_fixes_applied': 0,
            'validation_errors': 0,
            'active_subscriptions': 0
        }

        logger.info("Service de validation en temps réel initialisé")
|
||||
|
||||
    async def start_service(self):
        """Start the real-time validation service (no-op if already running)."""
        if self._service_running:
            logger.warning("Service déjà en cours d'exécution")
            return

        self._service_running = True

        # Launch the background service tasks
        self._service_tasks = [
            asyncio.create_task(self._notification_processor()),
            asyncio.create_task(self._periodic_health_check()),
            asyncio.create_task(self._cleanup_expired_subscriptions())
        ]

        logger.info("Service de validation en temps réel démarré")
|
||||
|
||||
    async def stop_service(self):
        """Stop the real-time validation service and wait for task shutdown."""
        if not self._service_running:
            return

        self._service_running = False

        # Cancel every background task
        for task in self._service_tasks:
            task.cancel()

        # Wait for all tasks to finish, swallowing CancelledError
        await asyncio.gather(*self._service_tasks, return_exceptions=True)
        self._service_tasks.clear()

        logger.info("Service de validation en temps réel arrêté")
|
||||
|
||||
    def subscribe_to_validation(
        self,
        target_signature: str,
        callback: Callable[[ValidationNotification], None],
        notification_levels: Optional[List[NotificationLevel]] = None
    ) -> str:
        """
        Subscribe to validation notifications for a target.

        Args:
            target_signature: Signature of the target to watch
            callback: Function invoked for each matching notification
            notification_levels: Levels to receive (all levels by default)

        Returns:
            Subscription ID (string embedding ``id()`` of the subscription).
        """
        if notification_levels is None:
            notification_levels = list(NotificationLevel)

        subscription = ValidationSubscription(
            target_signature=target_signature,
            callback=callback,
            notification_levels=notification_levels
        )

        with self._subscription_lock:
            if target_signature not in self._subscriptions:
                self._subscriptions[target_signature] = []

            self._subscriptions[target_signature].append(subscription)
            self.stats['active_subscriptions'] += 1

        # Generate a unique ID for the subscription.
        # NOTE(review): the ID embeds id(subscription), which is only unique
        # for the object's lifetime; unsubscribe matches it by recomputing
        # the same string.
        subscription_id = f"{target_signature}_{id(subscription)}"

        logger.info(f"Nouvel abonnement créé: {subscription_id}")
        return subscription_id
|
||||
|
||||
    def unsubscribe_from_validation(self, target_signature: str, subscription_id: str):
        """
        Unsubscribe from validation notifications.

        Args:
            target_signature: Signature of the watched target
            subscription_id: ID of the subscription to remove (as returned by
                ``subscribe_to_validation``)
        """
        with self._subscription_lock:
            if target_signature in self._subscriptions:
                # Locate and drop the subscription
                subscriptions = self._subscriptions[target_signature]
                original_count = len(subscriptions)

                # Keep only subscriptions whose recomputed ID differs
                # (matching by id() -- approximate, see subscribe).
                self._subscriptions[target_signature] = [
                    sub for sub in subscriptions
                    if f"{target_signature}_{id(sub)}" != subscription_id
                ]

                removed_count = original_count - len(self._subscriptions[target_signature])
                self.stats['active_subscriptions'] -= removed_count

                if removed_count > 0:
                    logger.info(f"Abonnement supprimé: {subscription_id}")
|
||||
|
||||
    async def validate_target_with_notification(self, target: VisualTarget) -> ValidationResult:
        """
        Validate a target and emit notifications as needed.

        Args:
            target: Visual target to validate

        Returns:
            The validation result; on exception, a synthetic ERROR result is
            returned (this method never raises).
        """
        try:
            # Run the validation itself
            validation_result = await self.validation_manager.validate_target_now(target)

            # Create and dispatch the appropriate notifications
            await self._process_validation_result(target, validation_result)

            return validation_result

        except Exception as e:
            logger.error(f"Erreur lors de la validation avec notification: {e}")
            self.stats['validation_errors'] += 1

            # Build an error notification for subscribers
            error_notification = ValidationNotification(
                target_signature=target.signature,
                level=NotificationLevel.ERROR,
                message=f"Erreur de validation: {str(e)}",
                timestamp=datetime.now()
            )

            await self._send_notification(error_notification)

            # Return a synthetic error result so callers always get one
            return ValidationResult(
                target_signature=target.signature,
                status=ValidationStatus.ERROR,
                confidence=0.0,
                timestamp=datetime.now(),
                issues=[f"Erreur de validation: {str(e)}"]
            )
|
||||
|
||||
    async def _process_validation_result(self, target: VisualTarget, result: ValidationResult):
        """
        Translate a validation result into subscriber notifications.

        One status-level notification is emitted (INFO on moderate-confidence
        VALID, WARNING, or ERROR), plus one extra WARNING per recognised
        issue keyword ('position', 'appearance').
        """
        notifications = []

        # Status-driven notification
        if result.status == ValidationStatus.VALID:
            if result.confidence < 0.9:  # Moderate confidence only; silent above
                notifications.append(ValidationNotification(
                    target_signature=target.signature,
                    level=NotificationLevel.INFO,
                    message=f"Élément validé avec confiance modérée: {result.confidence:.2f}",
                    timestamp=datetime.now(),
                    validation_result=result
                ))

        elif result.status == ValidationStatus.WARNING:
            notifications.append(ValidationNotification(
                target_signature=target.signature,
                level=NotificationLevel.WARNING,
                message=f"Avertissement de validation: confiance {result.confidence:.2f}",
                timestamp=datetime.now(),
                validation_result=result,
                suggested_actions=["Vérifier l'état de l'application", "Mettre à jour la capture"],
                auto_fixable=True
            ))

        elif result.status == ValidationStatus.ERROR:
            notifications.append(ValidationNotification(
                target_signature=target.signature,
                level=NotificationLevel.ERROR,
                message="Élément non trouvé ou invalide",
                timestamp=datetime.now(),
                validation_result=result,
                suggested_actions=["Re-sélectionner l'élément", "Vérifier l'application"],
                auto_fixable=False
            ))

        # Issue-specific notifications (keyword match on the issue text)
        for issue in result.issues:
            if "position" in issue.lower():
                notifications.append(ValidationNotification(
                    target_signature=target.signature,
                    level=NotificationLevel.WARNING,
                    message=f"Changement de position détecté: {issue}",
                    timestamp=datetime.now(),
                    validation_result=result,
                    suggested_actions=["Mettre à jour la position de référence"],
                    auto_fixable=True
                ))

            elif "appearance" in issue.lower():
                notifications.append(ValidationNotification(
                    target_signature=target.signature,
                    level=NotificationLevel.WARNING,
                    message=f"Changement d'apparence détecté: {issue}",
                    timestamp=datetime.now(),
                    validation_result=result,
                    suggested_actions=["Mettre à jour l'embedding de référence"],
                    auto_fixable=True
                ))

        # Dispatch every notification through the queue
        for notification in notifications:
            await self._send_notification(notification)
|
||||
|
||||
async def _send_notification(self, notification: ValidationNotification):
|
||||
"""Envoie une notification aux abonnés appropriés"""
|
||||
try:
|
||||
# Ajouter à la queue de traitement
|
||||
await self._notification_queue.put(notification)
|
||||
|
||||
except asyncio.QueueFull:
|
||||
logger.warning("Queue de notifications pleine - notification ignorée")
|
||||
|
||||
async def _notification_processor(self):
|
||||
"""Processeur de notifications en arrière-plan"""
|
||||
while self._service_running:
|
||||
try:
|
||||
# Attendre une notification avec timeout
|
||||
notification = await asyncio.wait_for(
|
||||
self._notification_queue.get(),
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
# Traiter la notification
|
||||
await self._deliver_notification(notification)
|
||||
|
||||
# Appliquer les corrections automatiques si activées
|
||||
if (self.auto_fix_enabled and
|
||||
notification.auto_fixable and
|
||||
notification.validation_result):
|
||||
|
||||
await self._apply_auto_fix(notification)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur dans le processeur de notifications: {e}")
|
||||
|
||||
async def _deliver_notification(self, notification: ValidationNotification):
|
||||
"""Livre une notification aux abonnés appropriés"""
|
||||
target_signature = notification.target_signature
|
||||
|
||||
with self._subscription_lock:
|
||||
subscriptions = self._subscriptions.get(target_signature, [])
|
||||
|
||||
# Filtrer les abonnements actifs et intéressés par ce niveau
|
||||
active_subscriptions = [
|
||||
sub for sub in subscriptions
|
||||
if sub.active and notification.level in sub.notification_levels
|
||||
]
|
||||
|
||||
# Livrer aux abonnés
|
||||
for subscription in active_subscriptions:
|
||||
try:
|
||||
# Utiliser une référence faible pour éviter les fuites mémoire
|
||||
callback = subscription.callback
|
||||
if callback:
|
||||
# Exécuter le callback dans un thread séparé pour éviter le blocage
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, callback, notification)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de la livraison de notification: {e}")
|
||||
# Désactiver l'abonnement défaillant
|
||||
subscription.active = False
|
||||
|
||||
self.stats['notifications_sent'] += 1
|
||||
|
||||
async def _apply_auto_fix(self, notification: ValidationNotification):
|
||||
"""Applique une correction automatique basée sur la notification"""
|
||||
try:
|
||||
if not notification.validation_result:
|
||||
return
|
||||
|
||||
result = notification.validation_result
|
||||
|
||||
# Appliquer les actions de récupération automatiques
|
||||
for action in result.recovery_actions:
|
||||
if action.auto_executable and action.confidence > 0.7:
|
||||
success = await self.validation_manager.execute_recovery_action(
|
||||
notification.target_signature, action
|
||||
)
|
||||
|
||||
if success:
|
||||
self.stats['auto_fixes_applied'] += 1
|
||||
logger.info(f"Correction automatique appliquée: {action.action_type}")
|
||||
|
||||
# Envoyer une notification de succès
|
||||
success_notification = ValidationNotification(
|
||||
target_signature=notification.target_signature,
|
||||
level=NotificationLevel.INFO,
|
||||
message=f"Correction automatique appliquée: {action.description}",
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
await self._send_notification(success_notification)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de l'application de correction automatique: {e}")
|
||||
|
||||
async def _periodic_health_check(self):
|
||||
"""Vérification périodique de la santé du service"""
|
||||
while self._service_running:
|
||||
try:
|
||||
await asyncio.sleep(30) # Vérification toutes les 30 secondes
|
||||
|
||||
# Vérifier la taille de la queue
|
||||
queue_size = self._notification_queue.qsize()
|
||||
if queue_size > self.notification_queue_size * 0.8:
|
||||
logger.warning(f"Queue de notifications presque pleine: {queue_size}")
|
||||
|
||||
# Vérifier les abonnements actifs
|
||||
with self._subscription_lock:
|
||||
total_subscriptions = sum(len(subs) for subs in self._subscriptions.values())
|
||||
active_subscriptions = sum(
|
||||
len([sub for sub in subs if sub.active])
|
||||
for subs in self._subscriptions.values()
|
||||
)
|
||||
|
||||
self.stats['active_subscriptions'] = active_subscriptions
|
||||
|
||||
# Log des statistiques périodiques
|
||||
if total_subscriptions > 0:
|
||||
logger.debug(f"Santé du service: {active_subscriptions}/{total_subscriptions} "
|
||||
f"abonnements actifs, {queue_size} notifications en queue")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de la vérification de santé: {e}")
|
||||
|
||||
async def _cleanup_expired_subscriptions(self):
|
||||
"""Nettoie les abonnements expirés"""
|
||||
while self._service_running:
|
||||
try:
|
||||
await asyncio.sleep(300) # Nettoyage toutes les 5 minutes
|
||||
|
||||
cutoff_time = datetime.now() - timedelta(hours=24)
|
||||
|
||||
with self._subscription_lock:
|
||||
for target_signature in list(self._subscriptions.keys()):
|
||||
subscriptions = self._subscriptions[target_signature]
|
||||
|
||||
# Filtrer les abonnements actifs et récents
|
||||
active_subscriptions = [
|
||||
sub for sub in subscriptions
|
||||
if sub.active and sub.created_at > cutoff_time
|
||||
]
|
||||
|
||||
if active_subscriptions:
|
||||
self._subscriptions[target_signature] = active_subscriptions
|
||||
else:
|
||||
del self._subscriptions[target_signature]
|
||||
|
||||
logger.debug("Nettoyage des abonnements expirés terminé")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors du nettoyage des abonnements: {e}")
|
||||
|
||||
def get_service_statistics(self) -> Dict[str, Any]:
|
||||
"""Récupère les statistiques du service"""
|
||||
with self._subscription_lock:
|
||||
total_targets = len(self._subscriptions)
|
||||
total_subscriptions = sum(len(subs) for subs in self._subscriptions.values())
|
||||
|
||||
return {
|
||||
'service_running': self._service_running,
|
||||
'total_targets_monitored': total_targets,
|
||||
'total_subscriptions': total_subscriptions,
|
||||
'active_subscriptions': self.stats['active_subscriptions'],
|
||||
'notifications_sent': self.stats['notifications_sent'],
|
||||
'auto_fixes_applied': self.stats['auto_fixes_applied'],
|
||||
'validation_errors': self.stats['validation_errors'],
|
||||
'notification_queue_size': self._notification_queue.qsize(),
|
||||
'auto_fix_enabled': self.auto_fix_enabled
|
||||
}
|
||||
|
||||
def enable_auto_fix(self):
|
||||
"""Active les corrections automatiques"""
|
||||
self.auto_fix_enabled = True
|
||||
logger.info("Corrections automatiques activées")
|
||||
|
||||
def disable_auto_fix(self):
|
||||
"""Désactive les corrections automatiques"""
|
||||
self.auto_fix_enabled = False
|
||||
logger.info("Corrections automatiques désactivées")
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Support du context manager async"""
|
||||
await self.start_service()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Support du context manager async"""
|
||||
await self.stop_service()
|
||||
@@ -1,642 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Gestionnaire d'Intégration RPA pour RPA Vision V3
|
||||
|
||||
Ce gestionnaire connecte le système visuel 100% avec les composants existants:
|
||||
- FusionEngine pour les embeddings
|
||||
- UIDetector pour la détection d'éléments
|
||||
- TargetResolver pour la résolution visuelle pure
|
||||
- ExecutionLoop pour l'exécution basée sur la vision
|
||||
|
||||
Exigences: 1.5, 3.3, 6.1
|
||||
Auteur: Assistant IA
|
||||
Date: 2026-01-07
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
from core.visual.visual_target_manager import VisualTarget, VisualTargetManager
|
||||
from core.visual.visual_embedding_manager import VisualEmbeddingManager
|
||||
from core.visual.screenshot_validation_manager import ScreenshotValidationManager
|
||||
from core.visual.visual_performance_optimizer import VisualPerformanceOptimizer
|
||||
|
||||
# Imports des composants RPA Vision V3 existants
|
||||
from core.embedding.fusion_engine import FusionEngine
|
||||
from core.detection.ui_detector import UIDetector
|
||||
from core.execution.target_resolver import TargetResolver
|
||||
from core.execution.execution_loop import ExecutionLoop
|
||||
from core.models import UIElement, ScreenState, BBox
|
||||
from core.capture.screen_capturer import ScreenCapturer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class IntegrationConfig:
|
||||
"""Configuration de l'intégration RPA"""
|
||||
use_visual_only: bool = True # Mode 100% visuel
|
||||
fallback_to_legacy: bool = False # Fallback vers les anciens sélecteurs
|
||||
confidence_threshold: float = 0.8 # Seuil de confiance minimum
|
||||
max_retry_attempts: int = 3 # Tentatives de résolution max
|
||||
enable_self_healing: bool = True # Auto-guérison activée
|
||||
performance_monitoring: bool = True # Monitoring des performances
|
||||
|
||||
@dataclass
|
||||
class ResolutionResult:
|
||||
"""Résultat de résolution d'une cible visuelle"""
|
||||
success: bool
|
||||
target_found: Optional[UIElement] = None
|
||||
confidence: float = 0.0
|
||||
resolution_time_ms: float = 0.0
|
||||
method_used: str = "visual" # 'visual', 'fallback', 'self_healing'
|
||||
attempts_count: int = 1
|
||||
error_message: Optional[str] = None
|
||||
|
||||
class RPAIntegrationManager:
|
||||
"""
|
||||
Gestionnaire d'intégration entre le système visuel 100% et RPA Vision V3.
|
||||
|
||||
Orchestre l'interaction entre les nouveaux composants visuels et
|
||||
l'infrastructure existante pour une transition transparente.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
visual_target_manager: VisualTargetManager,
|
||||
visual_embedding_manager: VisualEmbeddingManager,
|
||||
validation_manager: ScreenshotValidationManager,
|
||||
performance_optimizer: VisualPerformanceOptimizer,
|
||||
fusion_engine: FusionEngine,
|
||||
ui_detector: UIDetector,
|
||||
screen_capturer: ScreenCapturer,
|
||||
config: Optional[IntegrationConfig] = None
|
||||
):
|
||||
"""
|
||||
Initialise le gestionnaire d'intégration.
|
||||
|
||||
Args:
|
||||
visual_target_manager: Gestionnaire des cibles visuelles
|
||||
visual_embedding_manager: Gestionnaire des embeddings visuels
|
||||
validation_manager: Gestionnaire de validation
|
||||
performance_optimizer: Optimiseur de performance
|
||||
fusion_engine: Moteur de fusion existant
|
||||
ui_detector: Détecteur UI existant
|
||||
screen_capturer: Captureur d'écran existant
|
||||
config: Configuration d'intégration
|
||||
"""
|
||||
# Composants visuels nouveaux
|
||||
self.visual_target_manager = visual_target_manager
|
||||
self.visual_embedding_manager = visual_embedding_manager
|
||||
self.validation_manager = validation_manager
|
||||
self.performance_optimizer = performance_optimizer
|
||||
|
||||
# Composants RPA Vision V3 existants
|
||||
self.fusion_engine = fusion_engine
|
||||
self.ui_detector = ui_detector
|
||||
self.screen_capturer = screen_capturer
|
||||
|
||||
# Configuration
|
||||
self.config = config or IntegrationConfig()
|
||||
|
||||
# Adaptateur pour TargetResolver
|
||||
self.visual_target_resolver = None
|
||||
|
||||
# Statistiques d'intégration
|
||||
self.integration_stats = {
|
||||
'visual_resolutions': 0,
|
||||
'fallback_resolutions': 0,
|
||||
'self_healing_activations': 0,
|
||||
'total_resolution_time_ms': 0.0,
|
||||
'average_confidence': 0.0
|
||||
}
|
||||
|
||||
logger.info("Gestionnaire d'intégration RPA initialisé en mode 100% visuel")
|
||||
|
||||
async def initialize_integration(self):
|
||||
"""Initialise l'intégration avec les composants existants"""
|
||||
try:
|
||||
logger.info("🔗 Initialisation de l'intégration RPA Vision V3...")
|
||||
|
||||
# Créer l'adaptateur TargetResolver visuel
|
||||
self.visual_target_resolver = VisualTargetResolver(
|
||||
visual_target_manager=self.visual_target_manager,
|
||||
visual_embedding_manager=self.visual_embedding_manager,
|
||||
fusion_engine=self.fusion_engine,
|
||||
ui_detector=self.ui_detector,
|
||||
config=self.config
|
||||
)
|
||||
|
||||
# Démarrer l'optimiseur de performance
|
||||
await self.performance_optimizer.start_optimizer()
|
||||
|
||||
# Configurer les hooks d'intégration
|
||||
await self._setup_integration_hooks()
|
||||
|
||||
logger.info("✅ Intégration RPA Vision V3 initialisée avec succès")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Erreur lors de l'initialisation de l'intégration: {e}")
|
||||
raise
|
||||
|
||||
async def resolve_visual_target(
|
||||
self,
|
||||
visual_target: VisualTarget,
|
||||
current_screen_state: Optional[ScreenState] = None
|
||||
) -> ResolutionResult:
|
||||
"""
|
||||
Résout une cible visuelle dans l'écran actuel.
|
||||
|
||||
Args:
|
||||
visual_target: Cible visuelle à résoudre
|
||||
current_screen_state: État d'écran actuel (optionnel)
|
||||
|
||||
Returns:
|
||||
Résultat de la résolution
|
||||
"""
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
logger.debug(f"🎯 Résolution de la cible visuelle: {visual_target.signature}")
|
||||
|
||||
# Capturer l'écran actuel si nécessaire
|
||||
if current_screen_state is None:
|
||||
current_screen_state = await self._capture_current_screen_state()
|
||||
|
||||
# Tentative de résolution visuelle pure
|
||||
result = await self._attempt_visual_resolution(visual_target, current_screen_state)
|
||||
|
||||
# Si échec et fallback activé, essayer les méthodes legacy
|
||||
if not result.success and self.config.fallback_to_legacy:
|
||||
logger.warning("Résolution visuelle échouée, tentative de fallback...")
|
||||
result = await self._attempt_fallback_resolution(visual_target, current_screen_state)
|
||||
result.method_used = "fallback"
|
||||
|
||||
# Si échec et auto-guérison activée, essayer la récupération
|
||||
if not result.success and self.config.enable_self_healing:
|
||||
logger.warning("Résolution échouée, tentative d'auto-guérison...")
|
||||
result = await self._attempt_self_healing_resolution(visual_target, current_screen_state)
|
||||
result.method_used = "self_healing"
|
||||
if result.success:
|
||||
self.integration_stats['self_healing_activations'] += 1
|
||||
|
||||
# Calculer le temps de résolution
|
||||
resolution_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
result.resolution_time_ms = resolution_time
|
||||
|
||||
# Mettre à jour les statistiques
|
||||
await self._update_integration_stats(result)
|
||||
|
||||
if result.success:
|
||||
logger.debug(f"✅ Cible résolue en {resolution_time:.1f}ms (confiance: {result.confidence:.2f})")
|
||||
else:
|
||||
logger.warning(f"❌ Échec de résolution après {resolution_time:.1f}ms")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
resolution_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
logger.error(f"❌ Erreur lors de la résolution: {e}")
|
||||
|
||||
return ResolutionResult(
|
||||
success=False,
|
||||
resolution_time_ms=resolution_time,
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
async def execute_visual_action(
|
||||
self,
|
||||
visual_target: VisualTarget,
|
||||
action_type: str,
|
||||
action_parameters: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
Exécute une action sur une cible visuelle.
|
||||
|
||||
Args:
|
||||
visual_target: Cible visuelle
|
||||
action_type: Type d'action ('click', 'input', 'hover', etc.)
|
||||
action_parameters: Paramètres de l'action
|
||||
|
||||
Returns:
|
||||
True si l'action a réussi
|
||||
"""
|
||||
try:
|
||||
logger.info(f"🎬 Exécution de l'action {action_type} sur {visual_target.signature}")
|
||||
|
||||
# Résoudre la cible dans l'écran actuel
|
||||
resolution_result = await self.resolve_visual_target(visual_target)
|
||||
|
||||
if not resolution_result.success:
|
||||
logger.error(f"Impossible de résoudre la cible pour l'action {action_type}")
|
||||
return False
|
||||
|
||||
# Exécuter l'action via l'ExecutionLoop adapté
|
||||
success = await self._execute_action_on_element(
|
||||
resolution_result.target_found,
|
||||
action_type,
|
||||
action_parameters
|
||||
)
|
||||
|
||||
if success:
|
||||
# Valider l'action si nécessaire
|
||||
if action_type in ['click', 'input']:
|
||||
await self._validate_action_result(visual_target, action_type)
|
||||
|
||||
logger.info(f"✅ Action {action_type} exécutée avec succès")
|
||||
else:
|
||||
logger.error(f"❌ Échec de l'exécution de l'action {action_type}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Erreur lors de l'exécution de l'action: {e}")
|
||||
return False
|
||||
|
||||
async def migrate_legacy_workflow(
|
||||
self,
|
||||
legacy_workflow: Dict[str, Any]
|
||||
) -> Dict[str, VisualTarget]:
|
||||
"""
|
||||
Migre un workflow legacy vers le système 100% visuel.
|
||||
|
||||
Args:
|
||||
legacy_workflow: Workflow avec sélecteurs CSS/XPath
|
||||
|
||||
Returns:
|
||||
Mapping node_id -> VisualTarget
|
||||
"""
|
||||
logger.info("🔄 Migration d'un workflow legacy vers le système visuel")
|
||||
|
||||
migrated_targets = {}
|
||||
|
||||
try:
|
||||
# Parcourir les nœuds du workflow
|
||||
for node in legacy_workflow.get('nodes', []):
|
||||
node_id = node.get('id')
|
||||
|
||||
# Vérifier si le nœud a des sélecteurs legacy
|
||||
if self._has_legacy_selectors(node):
|
||||
logger.debug(f"Migration du nœud {node_id}")
|
||||
|
||||
# Convertir en cible visuelle
|
||||
visual_target = await self._convert_legacy_to_visual(node)
|
||||
|
||||
if visual_target:
|
||||
migrated_targets[node_id] = visual_target
|
||||
logger.debug(f"✅ Nœud {node_id} migré avec succès")
|
||||
else:
|
||||
logger.warning(f"⚠️ Échec de migration du nœud {node_id}")
|
||||
|
||||
logger.info(f"✅ Migration terminée - {len(migrated_targets)} nœuds migrés")
|
||||
return migrated_targets
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Erreur lors de la migration: {e}")
|
||||
return {}
|
||||
|
||||
# Méthodes privées
|
||||
|
||||
async def _capture_current_screen_state(self) -> ScreenState:
|
||||
"""Capture l'état d'écran actuel"""
|
||||
screenshot = await self.screen_capturer.capture_screen()
|
||||
screen_state = await self.ui_detector.detect_elements(screenshot)
|
||||
return screen_state
|
||||
|
||||
async def _attempt_visual_resolution(
|
||||
self,
|
||||
visual_target: VisualTarget,
|
||||
screen_state: ScreenState
|
||||
) -> ResolutionResult:
|
||||
"""Tente une résolution purement visuelle"""
|
||||
try:
|
||||
# Utiliser l'embedding manager pour trouver la correspondance
|
||||
best_match = await self.visual_embedding_manager.find_best_match(
|
||||
visual_target.embedding,
|
||||
screen_state.ui_elements
|
||||
)
|
||||
|
||||
if best_match and best_match.confidence >= self.config.confidence_threshold:
|
||||
self.integration_stats['visual_resolutions'] += 1
|
||||
|
||||
return ResolutionResult(
|
||||
success=True,
|
||||
target_found=best_match.element,
|
||||
confidence=best_match.confidence,
|
||||
method_used="visual"
|
||||
)
|
||||
else:
|
||||
return ResolutionResult(
|
||||
success=False,
|
||||
confidence=best_match.confidence if best_match else 0.0,
|
||||
error_message="Confiance insuffisante ou élément non trouvé"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return ResolutionResult(
|
||||
success=False,
|
||||
error_message=f"Erreur de résolution visuelle: {e}"
|
||||
)
|
||||
|
||||
async def _attempt_fallback_resolution(
|
||||
self,
|
||||
visual_target: VisualTarget,
|
||||
screen_state: ScreenState
|
||||
) -> ResolutionResult:
|
||||
"""Tente une résolution avec fallback vers les méthodes legacy"""
|
||||
try:
|
||||
# Utiliser le FusionEngine existant comme fallback
|
||||
fusion_result = await self.fusion_engine.find_element_by_context(
|
||||
screen_state,
|
||||
visual_target.metadata.visual_description
|
||||
)
|
||||
|
||||
if fusion_result:
|
||||
self.integration_stats['fallback_resolutions'] += 1
|
||||
|
||||
return ResolutionResult(
|
||||
success=True,
|
||||
target_found=fusion_result,
|
||||
confidence=0.7, # Confiance réduite pour fallback
|
||||
method_used="fallback"
|
||||
)
|
||||
else:
|
||||
return ResolutionResult(
|
||||
success=False,
|
||||
error_message="Fallback legacy échoué"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return ResolutionResult(
|
||||
success=False,
|
||||
error_message=f"Erreur de fallback: {e}"
|
||||
)
|
||||
|
||||
async def _attempt_self_healing_resolution(
|
||||
self,
|
||||
visual_target: VisualTarget,
|
||||
screen_state: ScreenState
|
||||
) -> ResolutionResult:
|
||||
"""Tente une résolution avec auto-guérison"""
|
||||
try:
|
||||
# Utiliser le gestionnaire de validation pour la récupération
|
||||
validation_result = await self.validation_manager.validate_target_now(visual_target)
|
||||
|
||||
if validation_result.recovery_actions:
|
||||
# Essayer les actions de récupération
|
||||
for action in validation_result.recovery_actions:
|
||||
if action.auto_executable and action.confidence > 0.6:
|
||||
success = await self.validation_manager.execute_recovery_action(
|
||||
visual_target.signature, action
|
||||
)
|
||||
|
||||
if success:
|
||||
# Re-tenter la résolution
|
||||
updated_target = await self.visual_target_manager.get_target_by_signature(
|
||||
visual_target.signature
|
||||
)
|
||||
|
||||
if updated_target:
|
||||
return await self._attempt_visual_resolution(updated_target, screen_state)
|
||||
|
||||
return ResolutionResult(
|
||||
success=False,
|
||||
error_message="Auto-guérison échouée"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return ResolutionResult(
|
||||
success=False,
|
||||
error_message=f"Erreur d'auto-guérison: {e}"
|
||||
)
|
||||
|
||||
async def _execute_action_on_element(
|
||||
self,
|
||||
element: UIElement,
|
||||
action_type: str,
|
||||
parameters: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Exécute une action sur un élément UI"""
|
||||
try:
|
||||
# Adapter l'exécution selon le type d'action
|
||||
if action_type == "click":
|
||||
return await self._execute_click_action(element, parameters)
|
||||
elif action_type == "input":
|
||||
return await self._execute_input_action(element, parameters)
|
||||
elif action_type == "hover":
|
||||
return await self._execute_hover_action(element, parameters)
|
||||
else:
|
||||
logger.warning(f"Type d'action non supporté: {action_type}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de l'exécution de l'action {action_type}: {e}")
|
||||
return False
|
||||
|
||||
async def _execute_click_action(self, element: UIElement, parameters: Dict[str, Any]) -> bool:
|
||||
"""Exécute un clic sur un élément"""
|
||||
# Calculer la position de clic
|
||||
bbox = element.bounding_box
|
||||
click_x = bbox.x + bbox.width / 2
|
||||
click_y = bbox.y + bbox.height / 2
|
||||
|
||||
# Utiliser pyautogui ou un autre mécanisme de clic
|
||||
import pyautogui
|
||||
pyautogui.click(click_x, click_y)
|
||||
|
||||
# Attendre un délai si spécifié
|
||||
delay = parameters.get('delay', 0.5)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
return True
|
||||
|
||||
async def _execute_input_action(self, element: UIElement, parameters: Dict[str, Any]) -> bool:
|
||||
"""Exécute une saisie de texte"""
|
||||
# Cliquer d'abord sur l'élément
|
||||
await self._execute_click_action(element, {})
|
||||
|
||||
# Saisir le texte
|
||||
text = parameters.get('text', '')
|
||||
if text:
|
||||
import pyautogui
|
||||
pyautogui.write(text)
|
||||
|
||||
return True
|
||||
|
||||
async def _execute_hover_action(self, element: UIElement, parameters: Dict[str, Any]) -> bool:
|
||||
"""Exécute un survol d'élément"""
|
||||
bbox = element.bounding_box
|
||||
hover_x = bbox.x + bbox.width / 2
|
||||
hover_y = bbox.y + bbox.height / 2
|
||||
|
||||
import pyautogui
|
||||
pyautogui.moveTo(hover_x, hover_y)
|
||||
|
||||
return True
|
||||
|
||||
async def _validate_action_result(
|
||||
self,
|
||||
visual_target: VisualTarget,
|
||||
action_type: str
|
||||
):
|
||||
"""Valide le résultat d'une action"""
|
||||
# Attendre que l'action prenne effet
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
# Re-valider la cible pour détecter les changements
|
||||
validation_result = await self.validation_manager.validate_target_now(visual_target)
|
||||
|
||||
if not validation_result.is_valid:
|
||||
logger.warning(f"Validation post-action échouée pour {action_type}")
|
||||
|
||||
def _has_legacy_selectors(self, node: Dict[str, Any]) -> bool:
|
||||
"""Vérifie si un nœud contient des sélecteurs legacy"""
|
||||
parameters = node.get('parameters', {})
|
||||
|
||||
# Chercher des sélecteurs CSS/XPath
|
||||
legacy_keys = ['css_selector', 'xpath_selector', 'selector', 'target_selector']
|
||||
|
||||
for key in legacy_keys:
|
||||
if key in parameters and parameters[key]:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _convert_legacy_to_visual(self, node: Dict[str, Any]) -> Optional[VisualTarget]:
|
||||
"""Convertit un nœud legacy en cible visuelle"""
|
||||
try:
|
||||
# Extraire les informations du sélecteur legacy
|
||||
parameters = node.get('parameters', {})
|
||||
|
||||
# Tenter de localiser l'élément avec le sélecteur legacy
|
||||
# (Cette partie nécessiterait une implémentation spécifique selon le format legacy)
|
||||
|
||||
# Pour l'instant, créer une cible visuelle simulée
|
||||
# Dans une vraie implémentation, il faudrait:
|
||||
# 1. Utiliser le sélecteur legacy pour trouver l'élément
|
||||
# 2. Capturer une image de l'élément
|
||||
# 3. Générer un embedding visuel
|
||||
# 4. Créer la VisualTarget
|
||||
|
||||
logger.warning("Conversion legacy->visuel non implémentée complètement")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de la conversion legacy: {e}")
|
||||
return None
|
||||
|
||||
async def _setup_integration_hooks(self):
|
||||
"""Configure les hooks d'intégration avec les composants existants"""
|
||||
# Hook pour intercepter les résolutions de cibles
|
||||
# Hook pour monitorer les performances
|
||||
# Hook pour la synchronisation des caches
|
||||
pass
|
||||
|
||||
async def _update_integration_stats(self, result: ResolutionResult):
|
||||
"""Met à jour les statistiques d'intégration"""
|
||||
self.integration_stats['total_resolution_time_ms'] += result.resolution_time_ms
|
||||
|
||||
if result.success:
|
||||
if result.method_used == "visual":
|
||||
self.integration_stats['visual_resolutions'] += 1
|
||||
elif result.method_used == "fallback":
|
||||
self.integration_stats['fallback_resolutions'] += 1
|
||||
|
||||
# Mettre à jour la confiance moyenne
|
||||
current_avg = self.integration_stats['average_confidence']
|
||||
total_resolutions = (self.integration_stats['visual_resolutions'] +
|
||||
self.integration_stats['fallback_resolutions'])
|
||||
|
||||
if total_resolutions > 0:
|
||||
self.integration_stats['average_confidence'] = (
|
||||
(current_avg * (total_resolutions - 1) + result.confidence) / total_resolutions
|
||||
)
|
||||
|
||||
def get_integration_statistics(self) -> Dict[str, Any]:
|
||||
"""Récupère les statistiques d'intégration"""
|
||||
total_resolutions = (self.integration_stats['visual_resolutions'] +
|
||||
self.integration_stats['fallback_resolutions'])
|
||||
|
||||
return {
|
||||
'total_resolutions': total_resolutions,
|
||||
'visual_resolutions': self.integration_stats['visual_resolutions'],
|
||||
'fallback_resolutions': self.integration_stats['fallback_resolutions'],
|
||||
'self_healing_activations': self.integration_stats['self_healing_activations'],
|
||||
'visual_success_rate': (
|
||||
self.integration_stats['visual_resolutions'] / max(1, total_resolutions) * 100
|
||||
),
|
||||
'average_resolution_time_ms': (
|
||||
self.integration_stats['total_resolution_time_ms'] / max(1, total_resolutions)
|
||||
),
|
||||
'average_confidence': self.integration_stats['average_confidence'],
|
||||
'config': {
|
||||
'visual_only_mode': self.config.use_visual_only,
|
||||
'fallback_enabled': self.config.fallback_to_legacy,
|
||||
'self_healing_enabled': self.config.enable_self_healing,
|
||||
'confidence_threshold': self.config.confidence_threshold
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class VisualTargetResolver:
|
||||
"""
|
||||
Adaptateur pour intégrer la résolution visuelle avec TargetResolver existant.
|
||||
|
||||
Remplace la logique de résolution basée sur les sélecteurs par une résolution
|
||||
purement visuelle utilisant les embeddings et la reconnaissance d'images.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
visual_target_manager: VisualTargetManager,
|
||||
visual_embedding_manager: VisualEmbeddingManager,
|
||||
fusion_engine: FusionEngine,
|
||||
ui_detector: UIDetector,
|
||||
config: IntegrationConfig
|
||||
):
|
||||
self.visual_target_manager = visual_target_manager
|
||||
self.visual_embedding_manager = visual_embedding_manager
|
||||
self.fusion_engine = fusion_engine
|
||||
self.ui_detector = ui_detector
|
||||
self.config = config
|
||||
|
||||
async def resolve_target(
|
||||
self,
|
||||
target_signature: str,
|
||||
screen_state: ScreenState
|
||||
) -> Optional[UIElement]:
|
||||
"""
|
||||
Résout une cible par sa signature visuelle.
|
||||
|
||||
Args:
|
||||
target_signature: Signature de la cible visuelle
|
||||
screen_state: État d'écran actuel
|
||||
|
||||
Returns:
|
||||
Élément UI trouvé ou None
|
||||
"""
|
||||
try:
|
||||
# Récupérer la cible visuelle
|
||||
visual_target = await self.visual_target_manager.get_target_by_signature(target_signature)
|
||||
|
||||
if not visual_target:
|
||||
logger.error(f"Cible visuelle non trouvée: {target_signature}")
|
||||
return None
|
||||
|
||||
# Utiliser l'embedding manager pour la résolution
|
||||
match_result = await self.visual_embedding_manager.find_best_match(
|
||||
visual_target.embedding,
|
||||
screen_state.ui_elements
|
||||
)
|
||||
|
||||
if match_result and match_result.confidence >= self.config.confidence_threshold:
|
||||
return match_result.element
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de la résolution de cible: {e}")
|
||||
return None
|
||||
@@ -1,582 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Optimiseur de Performance Visuelle pour RPA Vision V3
|
||||
|
||||
Ce module optimise les performances du système visuel pour respecter les exigences:
|
||||
- Traitement des captures < 2s
|
||||
- Réactivité mode sélection < 100ms
|
||||
- Cache intelligent pour captures multiples
|
||||
- Traitement non-bloquant des embeddings
|
||||
|
||||
Exigences: 10.1, 10.2, 10.4, 10.5
|
||||
Auteur: Assistant IA
|
||||
Date: 2026-01-07
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any, Callable, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
import hashlib
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from core.visual.visual_target_manager import VisualTarget
|
||||
from core.models import BBox, ScreenState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class PerformanceMetrics:
    """Snapshot of the visual optimizer's performance counters.

    All time fields hold the duration of the *most recent* operation of
    that kind, in milliseconds (they are overwritten, not averaged).
    """
    capture_processing_time: float = 0.0    # last screenshot-processing duration (ms)
    selection_response_time: float = 0.0    # last selection-mode response time (ms)
    embedding_processing_time: float = 0.0  # last embedding-processing duration (ms)
    cache_hit_rate: float = 0.0             # cache hit rate estimate (%)
    memory_usage_mb: float = 0.0            # memory usage (MB)
    active_background_tasks: int = 0        # currently registered background tasks
|
||||
|
||||
@dataclass
class CacheEntry:
    """Single LRU-cache entry with access bookkeeping."""
    key: str                 # cache key (e.g. md5 hex digest of the raw screenshot bytes)
    data: Any                # cached payload (opaque to the cache)
    created_at: datetime     # when the entry was inserted
    last_accessed: datetime  # refreshed on every cache hit (drives LRU ordering)
    access_count: int = 0    # number of hits on this entry
    size_bytes: int = 0      # payload size used for total-cache-size accounting
|
||||
|
||||
@dataclass
class ProcessingTask:
    """Descriptor for a queued background-processing task."""
    task_id: str                         # unique identifier, e.g. "embedding_<sig>_<ms>"
    task_type: str                       # task discriminator; currently only "embedding" is handled
    created_at: datetime                 # enqueue timestamp
    callback: Optional[Callable] = None  # awaited with (target, result) on completion — must be a coroutine function
    priority: int = 1                    # 1=high, 2=normal, 3=low (lower value dequeued first)
|
||||
|
||||
class VisualPerformanceOptimizer:
    """
    Performance optimizer for the visual system.

    Combines an LRU result cache, thread/process worker pools and a
    prioritized background-task queue so that the latency requirements
    are met: screenshot processing < 2 s, selection-mode response < 100 ms.

    Fixes vs. the previous revision:
      * ``ProcessPoolExecutor`` is now created with ``max(1, max_workers // 2)``
        workers; with ``max_workers == 1`` the old ``max_workers // 2`` asked
        for a 0-worker pool, which raises ``ValueError``.
      * ``optimize_memory_usage`` no longer computes an unused time cutoff
        nor iterates ``dict.items()`` when only the keys are needed.
    """

    def __init__(self, max_workers: int = 4, cache_size_mb: int = 100):
        """
        Initialize the performance optimizer.

        Args:
            max_workers: Maximum number of workers for parallel processing.
            cache_size_mb: Maximum cache size in MB.
        """
        # Configuration
        self.max_workers = max_workers
        self.cache_size_mb = cache_size_mb
        self.cache_max_entries = 1000

        # Performance thresholds (requirements 10.1 / 10.2)
        self.capture_processing_threshold_ms = 2000  # 2 seconds
        self.selection_response_threshold_ms = 100   # 100 milliseconds

        # LRU cache (OrderedDict insertion order == recency order)
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._cache_lock = threading.RLock()
        self._cache_size_bytes = 0

        # Worker pools. max(1, ...) guards against max_workers == 1, which
        # would otherwise request a 0-worker ProcessPoolExecutor (ValueError).
        self._thread_pool = ThreadPoolExecutor(max_workers=max_workers)
        self._process_pool = ProcessPoolExecutor(max_workers=max(1, max_workers // 2))

        # Background task management
        self._background_tasks: Dict[str, ProcessingTask] = {}
        self._task_queue = asyncio.PriorityQueue()
        self._task_processor_running = False

        # Performance metrics (guarded by _metrics_lock)
        self.metrics = PerformanceMetrics()
        self._metrics_lock = threading.Lock()

        # Precomputed artifacts keyed by target signature
        self._precomputed_embeddings: Dict[str, np.ndarray] = {}
        self._screenshot_thumbnails: Dict[str, bytes] = {}

        logger.info(f"Optimiseur de performance initialisé - Workers: {max_workers}, Cache: {cache_size_mb}MB")

    async def start_optimizer(self):
        """Start the background task processor (idempotent)."""
        if not self._task_processor_running:
            self._task_processor_running = True
            asyncio.create_task(self._background_task_processor())
            logger.info("Optimiseur de performance démarré")

    async def stop_optimizer(self):
        """Stop the task processor and shut down the worker pools.

        NOTE(review): ``shutdown(wait=True)`` blocks the running event loop
        until in-flight work finishes — acceptable at teardown, but confirm
        no latency-sensitive coroutines run concurrently with shutdown.
        """
        self._task_processor_running = False

        # Close the pools
        self._thread_pool.shutdown(wait=True)
        self._process_pool.shutdown(wait=True)

        logger.info("Optimiseur de performance arrêté")

    async def optimize_capture_processing(
        self,
        screenshot_data: bytes,
        processing_func: Callable,
        cache_key: Optional[str] = None
    ) -> Tuple[Any, float]:
        """
        Process a screenshot with caching and pool offloading.

        Args:
            screenshot_data: Raw capture bytes.
            processing_func: Callable invoked with ``screenshot_data``.
                Must be picklable when the capture exceeds 1 MB (it is then
                run in a separate process).
            cache_key: Optional cache key; derived from the data when omitted.

        Returns:
            Tuple ``(result, processing_time_ms)``.

        Raises:
            Re-raises any exception from ``processing_func`` after logging.
        """
        start_time = time.perf_counter()

        try:
            # Derive a content-addressed key when the caller did not supply one.
            if cache_key is None:
                cache_key = self._generate_cache_key(screenshot_data)

            # Fast path: cached result.
            cached_result = self._get_from_cache(cache_key)
            if cached_result is not None:
                processing_time = (time.perf_counter() - start_time) * 1000
                logger.debug(f"Cache hit pour capture - {processing_time:.1f}ms")
                return cached_result, processing_time

            # Large captures go to a separate process, small ones to a thread,
            # so CPU-heavy decoding never stalls the event loop.
            if len(screenshot_data) > 1024 * 1024:  # > 1MB
                result = await self._process_in_background(
                    processing_func, screenshot_data, priority=1
                )
            else:
                loop = asyncio.get_event_loop()
                result = await loop.run_in_executor(
                    self._thread_pool, processing_func, screenshot_data
                )

            # Cache the result, accounted at the size of the *input* bytes.
            self._put_in_cache(cache_key, result, len(screenshot_data))

            processing_time = (time.perf_counter() - start_time) * 1000

            # Requirement check: captures must complete within the threshold.
            if processing_time > self.capture_processing_threshold_ms:
                logger.warning(f"Traitement de capture lent: {processing_time:.1f}ms > {self.capture_processing_threshold_ms}ms")

            with self._metrics_lock:
                self.metrics.capture_processing_time = processing_time

            return result, processing_time

        except Exception as e:
            processing_time = (time.perf_counter() - start_time) * 1000
            logger.error(f"Erreur lors du traitement de capture: {e}")
            raise

    async def optimize_selection_response(
        self,
        mouse_position: Tuple[int, int],
        screen_elements: List[Any],
        highlight_func: Callable
    ) -> float:
        """
        Optimize selection-mode responsiveness.

        Args:
            mouse_position: Current mouse position (x, y).
            screen_elements: Candidate on-screen elements.
            highlight_func: Called in a worker thread with the nearby elements.

        Returns:
            Response time in milliseconds (also returned on failure;
            errors are logged, not raised).
        """
        start_time = time.perf_counter()

        try:
            # Pre-filter by proximity so the highlight pass touches few elements.
            nearby_elements = self._filter_nearby_elements(mouse_position, screen_elements)

            loop = asyncio.get_event_loop()
            await loop.run_in_executor(
                self._thread_pool, highlight_func, nearby_elements
            )

            response_time = (time.perf_counter() - start_time) * 1000

            # Requirement check: selection must respond within the threshold.
            if response_time > self.selection_response_threshold_ms:
                logger.warning(f"Réponse sélection lente: {response_time:.1f}ms > {self.selection_response_threshold_ms}ms")

            with self._metrics_lock:
                self.metrics.selection_response_time = response_time

            return response_time

        except Exception as e:
            response_time = (time.perf_counter() - start_time) * 1000
            logger.error(f"Erreur lors de l'optimisation de sélection: {e}")
            return response_time

    async def process_embedding_async(
        self,
        target: VisualTarget,
        embedding_func: Callable,
        callback: Optional[Callable] = None
    ) -> str:
        """
        Queue a non-blocking embedding computation.

        Args:
            target: Visual target to embed.
            embedding_func: Embedding generator, invoked with ``target``.
            callback: Optional coroutine function awaited with
                ``(target, result)`` when the task completes.

        Returns:
            The generated task id.
        """
        task_id = f"embedding_{target.signature}_{int(time.time() * 1000)}"

        task = ProcessingTask(
            task_id=task_id,
            task_type="embedding",
            created_at=datetime.now(),
            callback=callback,
            priority=2  # normal priority
        )

        # PriorityQueue orders by (priority, task_id) — task_id is a string,
        # so ties never fall through to comparing ProcessingTask objects.
        await self._task_queue.put((task.priority, task_id, task, target, embedding_func))

        self._background_tasks[task_id] = task

        with self._metrics_lock:
            self.metrics.active_background_tasks = len(self._background_tasks)

        logger.debug(f"Tâche d'embedding créée: {task_id}")
        return task_id

    def precompute_common_embeddings(self, common_elements: List[VisualTarget]):
        """
        Precompute and cache embeddings (and thumbnails) for common targets.

        Args:
            common_elements: Targets to precompute; existing entries are kept.
        """
        logger.info(f"Pré-calcul de {len(common_elements)} embeddings communs")

        for target in common_elements:
            if target.signature not in self._precomputed_embeddings:
                # Copy so later mutations of the target don't alias the cache.
                self._precomputed_embeddings[target.signature] = target.embedding.copy()

                # Also keep a small thumbnail of the reference screenshot.
                thumbnail = self._create_thumbnail(target.screenshot)
                if thumbnail:
                    self._screenshot_thumbnails[target.signature] = thumbnail

        logger.info(f"Pré-calcul terminé - {len(self._precomputed_embeddings)} embeddings en cache")

    def get_cached_embedding(self, signature: str) -> Optional[np.ndarray]:
        """Return the precomputed embedding for *signature*, or None."""
        return self._precomputed_embeddings.get(signature)

    def get_thumbnail(self, signature: str) -> Optional[bytes]:
        """Return the cached thumbnail bytes for *signature*, or None."""
        return self._screenshot_thumbnails.get(signature)

    async def optimize_multiple_captures(
        self,
        capture_requests: List[Tuple[str, Callable]],
        batch_size: int = 5
    ) -> Dict[str, Any]:
        """
        Process many captures in bounded-parallelism batches.

        Args:
            capture_requests: List of ``(cache_key, processing_func)`` pairs.
            batch_size: Number of requests processed concurrently per batch.

        Returns:
            Mapping ``cache_key -> result`` (None for requests that failed).
        """
        results = {}

        for i in range(0, len(capture_requests), batch_size):
            batch = capture_requests[i:i + batch_size]

            # Launch the whole batch concurrently ...
            batch_tasks = []
            for cache_key, processing_func in batch:
                task = asyncio.create_task(
                    self._process_capture_with_cache(cache_key, processing_func)
                )
                batch_tasks.append((cache_key, task))

            # ... then gather; one failure does not abort the batch.
            for cache_key, task in batch_tasks:
                try:
                    result = await task
                    results[cache_key] = result
                except Exception as e:
                    logger.error(f"Erreur lors du traitement de {cache_key}: {e}")
                    results[cache_key] = None

        logger.info(f"Traitement de {len(capture_requests)} captures terminé")
        return results

    # Cache methods

    def _get_from_cache(self, key: str) -> Optional[Any]:
        """Return the cached payload for *key* (updating LRU stats), or None."""
        with self._cache_lock:
            if key in self._cache:
                entry = self._cache[key]
                entry.last_accessed = datetime.now()
                entry.access_count += 1

                # Move to the recent end (LRU).
                self._cache.move_to_end(key)

                self._update_cache_hit_rate(True)

                return entry.data

            self._update_cache_hit_rate(False)
            return None

    def _put_in_cache(self, key: str, data: Any, size_bytes: int):
        """Insert *data* under *key*, evicting LRU entries to fit the budget."""
        with self._cache_lock:
            max_size_bytes = self.cache_size_mb * 1024 * 1024

            # Evict least-recently-used entries until the new one fits
            # (both the byte budget and the entry-count limit).
            while (self._cache_size_bytes + size_bytes > max_size_bytes or
                   len(self._cache) >= self.cache_max_entries):
                if not self._cache:
                    break

                oldest_key, oldest_entry = self._cache.popitem(last=False)
                self._cache_size_bytes -= oldest_entry.size_bytes

            entry = CacheEntry(
                key=key,
                data=data,
                created_at=datetime.now(),
                last_accessed=datetime.now(),
                size_bytes=size_bytes
            )

            self._cache[key] = entry
            self._cache_size_bytes += size_bytes

    def _update_cache_hit_rate(self, hit: bool):
        """Nudge the hit-rate estimate by ±0.1, clamped to [0, 100].

        Deliberately simplistic — a sliding-window ratio would be more
        accurate, but this keeps the update O(1) and lock-cheap.
        """
        with self._metrics_lock:
            if hit:
                self.metrics.cache_hit_rate = min(100.0, self.metrics.cache_hit_rate + 0.1)
            else:
                self.metrics.cache_hit_rate = max(0.0, self.metrics.cache_hit_rate - 0.1)

    # Utility methods

    def _generate_cache_key(self, data: bytes) -> str:
        """Content-address *data* with md5 (cache key only, not security)."""
        return hashlib.md5(data).hexdigest()

    def _filter_nearby_elements(
        self,
        mouse_position: Tuple[int, int],
        elements: List[Any],
        radius: int = 50
    ) -> List[Any]:
        """Return the elements whose bounding-box center lies within *radius*
        pixels of *mouse_position* (elements without a bounding_box are skipped).
        """
        mx, my = mouse_position
        nearby = []

        for element in elements:
            if hasattr(element, 'bounding_box'):
                bbox = element.bounding_box
                # Euclidean distance to the element's center.
                cx = bbox.x + bbox.width / 2
                cy = bbox.y + bbox.height / 2
                distance = ((mx - cx) ** 2 + (my - cy) ** 2) ** 0.5

                if distance <= radius:
                    nearby.append(element)

        return nearby

    def _create_thumbnail(self, screenshot_b64: str, max_size: int = 64) -> Optional[bytes]:
        """Create a PNG thumbnail from a base64-encoded screenshot.

        Args:
            screenshot_b64: Base64-encoded image data (assumed — the caller
                passes ``target.screenshot``; TODO confirm its encoding).
            max_size: Maximum thumbnail edge length in pixels.

        Returns:
            PNG bytes, or None on any decoding/encoding failure.
        """
        try:
            import base64
            from PIL import Image
            import io

            image_data = base64.b64decode(screenshot_b64)
            image = Image.open(io.BytesIO(image_data))

            # In-place downscale preserving aspect ratio.
            image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

            output = io.BytesIO()
            image.save(output, format='PNG', optimize=True)
            return output.getvalue()

        except Exception as e:
            logger.warning(f"Erreur lors de la création de miniature: {e}")
            return None

    async def _process_in_background(
        self,
        func: Callable,
        data: Any,
        priority: int = 2
    ) -> Any:
        """Run ``func(data)`` in a pool without blocking the event loop.

        High-priority (1) work goes to the process pool — note that both
        ``func`` and ``data`` must then be picklable; everything else runs
        in the thread pool.
        """
        loop = asyncio.get_event_loop()

        if priority == 1:  # high priority
            return await loop.run_in_executor(self._process_pool, func, data)
        else:
            return await loop.run_in_executor(self._thread_pool, func, data)

    async def _process_capture_with_cache(self, cache_key: str, processing_func: Callable) -> Any:
        """Cache-through wrapper used by ``optimize_multiple_captures``.

        NOTE(review): ``processing_func`` is invoked with ``None`` and the
        cache size is a fixed 1024-byte estimate — presumably the callable
        closes over its own input; confirm against callers.
        """
        cached_result = self._get_from_cache(cache_key)
        if cached_result is not None:
            return cached_result

        result = await self._process_in_background(processing_func, None)
        self._put_in_cache(cache_key, result, 1024)  # estimated size

        return result

    async def _background_task_processor(self):
        """Drain the priority queue until ``stop_optimizer`` is called.

        The 1 s ``wait_for`` timeout lets the loop re-check the running
        flag so shutdown is observed promptly even with an empty queue.
        """
        while self._task_processor_running:
            try:
                priority, task_id, task, *args = await asyncio.wait_for(
                    self._task_queue.get(), timeout=1.0
                )

                start_time = time.perf_counter()

                if task.task_type == "embedding":
                    target, embedding_func = args
                    result = await self._process_in_background(embedding_func, target)

                    # Callback is awaited — it must be a coroutine function.
                    if task.callback:
                        await task.callback(target, result)

                if task_id in self._background_tasks:
                    del self._background_tasks[task_id]

                processing_time = (time.perf_counter() - start_time) * 1000

                with self._metrics_lock:
                    self.metrics.embedding_processing_time = processing_time
                    self.metrics.active_background_tasks = len(self._background_tasks)

                logger.debug(f"Tâche {task_id} terminée en {processing_time:.1f}ms")

            except asyncio.TimeoutError:
                continue
            except Exception as e:
                logger.error(f"Erreur dans le processeur de tâches: {e}")

    def get_performance_metrics(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of all counters and thresholds."""
        with self._metrics_lock:
            return {
                'capture_processing_time_ms': self.metrics.capture_processing_time,
                'selection_response_time_ms': self.metrics.selection_response_time,
                'embedding_processing_time_ms': self.metrics.embedding_processing_time,
                'cache_hit_rate_percent': self.metrics.cache_hit_rate,
                'memory_usage_mb': self.metrics.memory_usage_mb,
                'active_background_tasks': self.metrics.active_background_tasks,
                'cache_entries': len(self._cache),
                'cache_size_bytes': self._cache_size_bytes,
                'precomputed_embeddings': len(self._precomputed_embeddings),
                'performance_thresholds': {
                    'capture_processing_ms': self.capture_processing_threshold_ms,
                    'selection_response_ms': self.selection_response_threshold_ms
                }
            }

    def clear_cache(self):
        """Drop every cache entry and reset the size accounting."""
        with self._cache_lock:
            self._cache.clear()
            self._cache_size_bytes = 0
        logger.info("Cache vidé")

    def optimize_memory_usage(self):
        """Trim the precomputed-embedding and thumbnail caches.

        No per-signature timestamp is stored, so an age-based criterion is
        impossible here; we simply evict the first half of the entries (dict
        insertion order, i.e. oldest precomputations first). The previous
        revision computed an unused time cutoff and iterated ``.items()``
        while using only keys — both removed, observable behavior unchanged.
        """
        signatures = list(self._precomputed_embeddings)
        evicted = signatures[:len(signatures) // 2]

        for sig in evicted:
            self._precomputed_embeddings.pop(sig, None)
            self._screenshot_thumbnails.pop(sig, None)

        logger.info(f"Nettoyage mémoire - {len(evicted)} embeddings supprimés")
|
||||
@@ -1,661 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Gestionnaire de Persistance Visuelle pour RPA Vision V3
|
||||
|
||||
Ce gestionnaire gère la sauvegarde et la récupération complète des données visuelles,
|
||||
incluant les embeddings, captures d'écran, métadonnées et validation post-chargement.
|
||||
|
||||
Exigences: 9.1, 9.2, 9.3, 9.4, 9.5
|
||||
Auteur: Assistant IA
|
||||
Date: 2026-01-07
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
import gzip
|
||||
import pickle # noqa: S403 - usage legacy restreint au fallback de migration
|
||||
import io
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from core.visual.visual_target_manager import VisualTarget, VisualTargetManager
|
||||
from core.visual.screenshot_validation_manager import ScreenshotValidationManager, ValidationResult
|
||||
from core.security.signed_serializer import (
|
||||
SignatureVerificationError,
|
||||
UnsupportedFormatError,
|
||||
dumps_signed,
|
||||
loads_signed,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class VisualWorkflowData:
    """Complete visual data of one workflow, as persisted to a .vwd file."""
    workflow_id: str                                        # workflow identifier (also the file stem)
    version: str                                            # data-format version, currently "1.0"
    created_at: datetime                                    # creation timestamp
    visual_targets: Dict[str, VisualTarget]                 # target signature -> target
    target_signatures: Dict[str, str]                       # node_id -> target_signature
    validation_history: Dict[str, List[ValidationResult]]   # target signature -> validation results
    metadata: Dict[str, Any]                                # free-form extra metadata
|
||||
|
||||
@dataclass
class PersistenceStats:
    """Statistics for the most recent save/load operations."""
    total_targets: int        # targets written by the last successful save
    total_size_bytes: int     # on-disk size of the last written file
    compression_ratio: float  # uncompressed/compressed size; 1.0 when compression is off
    save_duration_ms: float   # duration of the last save (ms)
    load_duration_ms: float   # duration of the last load (ms)
|
||||
|
||||
class VisualPersistenceManager:
|
||||
"""
|
||||
Gestionnaire de persistance pour les données visuelles.
|
||||
|
||||
Gère la sauvegarde complète des embeddings, captures d'écran et métadonnées
|
||||
avec compression, validation et récupération intelligente.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_manager: VisualTargetManager,
|
||||
validation_manager: ScreenshotValidationManager,
|
||||
storage_path: str = "data/visual_workflows"
|
||||
):
|
||||
"""
|
||||
Initialise le gestionnaire de persistance.
|
||||
|
||||
Args:
|
||||
target_manager: Gestionnaire des cibles visuelles
|
||||
validation_manager: Gestionnaire de validation
|
||||
storage_path: Chemin de stockage des données
|
||||
"""
|
||||
self.target_manager = target_manager
|
||||
self.validation_manager = validation_manager
|
||||
self.storage_path = Path(storage_path)
|
||||
|
||||
# Configuration
|
||||
self.compression_enabled = True
|
||||
self.validation_on_load = True
|
||||
self.backup_enabled = True
|
||||
self.max_backup_versions = 5
|
||||
|
||||
# Statistiques
|
||||
self.stats = PersistenceStats(0, 0, 0.0, 0.0, 0.0)
|
||||
|
||||
# Créer le répertoire de stockage
|
||||
self.storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f"Gestionnaire de persistance visuelle initialisé - Stockage: {self.storage_path}")
|
||||
|
||||
    async def save_workflow_visual_data(
        self,
        workflow_id: str,
        node_targets: Dict[str, VisualTarget],
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Persist the complete visual data of a workflow.

        Args:
            workflow_id: Workflow identifier (also used as the file stem).
            node_targets: Mapping node_id -> VisualTarget; None values are skipped.
            metadata: Optional extra metadata stored alongside the targets.

        Returns:
            True when the save succeeded; False otherwise (errors are
            logged, never raised).
        """
        start_time = datetime.now()

        try:
            logger.info(f"💾 Sauvegarde des données visuelles pour workflow {workflow_id}")

            # Build the container: targets are keyed by signature, with a
            # separate node_id -> signature map preserving the association.
            workflow_data = VisualWorkflowData(
                workflow_id=workflow_id,
                version="1.0",
                created_at=datetime.now(),
                visual_targets={},
                target_signatures={},
                validation_history={},
                metadata=metadata or {}
            )

            # Register every non-falsy target with its validation history.
            for node_id, target in node_targets.items():
                if target:
                    workflow_data.visual_targets[target.signature] = target
                    workflow_data.target_signatures[node_id] = target.signature

                    # Validation history is optional — only stored when present.
                    validation_history = await self._get_validation_history(target.signature)
                    if validation_history:
                        workflow_data.validation_history[target.signature] = validation_history

            # Serialize + (optionally) compress + write to disk.
            success = await self._save_workflow_data(workflow_data)

            if success:
                # Keep a rotated backup copy when enabled.
                if self.backup_enabled:
                    await self._create_backup(workflow_id)

                # Wall-clock duration in ms; stats reflect successful saves only.
                duration = (datetime.now() - start_time).total_seconds() * 1000
                self.stats.save_duration_ms = duration
                self.stats.total_targets = len(workflow_data.visual_targets)

                logger.info(f"✅ Sauvegarde terminée en {duration:.0f}ms - {len(workflow_data.visual_targets)} cibles")
                return True

            return False

        except Exception as e:
            logger.error(f"❌ Erreur lors de la sauvegarde: {e}")
            return False
|
||||
|
||||
    async def load_workflow_visual_data(
        self,
        workflow_id: str,
        validate_on_load: Optional[bool] = None
    ) -> Tuple[Dict[str, VisualTarget], Dict[str, ValidationResult]]:
        """
        Load a workflow's visual data, optionally re-validating each target.

        Args:
            workflow_id: Workflow identifier.
            validate_on_load: Override for the instance-level
                ``validation_on_load`` flag; None means "use the default".

        Returns:
            Tuple ``(node_targets, validation_results)``; both empty dicts
            when nothing is found or an error occurs (errors are logged,
            never raised).
        """
        start_time = datetime.now()

        try:
            logger.info(f"📂 Chargement des données visuelles pour workflow {workflow_id}")

            workflow_data = await self._load_workflow_data(workflow_id)
            if not workflow_data:
                logger.warning(f"Aucune donnée visuelle trouvée pour {workflow_id}")
                return {}, {}

            # Rebuild the node_id -> VisualTarget mapping; dangling
            # signatures (target missing from visual_targets) are skipped.
            node_targets: Dict[str, VisualTarget] = {}
            validation_results: Dict[str, ValidationResult] = {}

            for node_id, target_signature in workflow_data.target_signatures.items():
                if target_signature in workflow_data.visual_targets:
                    target = workflow_data.visual_targets[target_signature]
                    node_targets[node_id] = target

                    # Explicit argument wins over the instance default.
                    should_validate = validate_on_load if validate_on_load is not None else self.validation_on_load
                    if should_validate:
                        validation_result = await self.validation_manager.validate_target_now(target)
                        validation_results[node_id] = validation_result

                        # Attempt recovery for invalid targets that carry
                        # recovery actions; the recovered target replaces the
                        # stale one in the returned mapping.
                        if not validation_result.is_valid and validation_result.recovery_actions:
                            updated_target = await self._attempt_target_recovery(target, validation_result)
                            if updated_target:
                                node_targets[node_id] = updated_target

            # Wall-clock duration in ms.
            duration = (datetime.now() - start_time).total_seconds() * 1000
            self.stats.load_duration_ms = duration

            logger.info(f"✅ Chargement terminé en {duration:.0f}ms - {len(node_targets)} cibles restaurées")

            return node_targets, validation_results

        except Exception as e:
            logger.error(f"❌ Erreur lors du chargement: {e}")
            return {}, {}
|
||||
|
||||
    async def export_workflow_visual_data(
        self,
        workflow_id: str,
        export_path: str,
        include_validation_history: bool = True
    ) -> bool:
        """
        Export a workflow's visual data to a human-readable JSON file.

        Args:
            workflow_id: Workflow identifier.
            export_path: Destination file path (parent dirs are created).
            include_validation_history: Whether to embed validation history.

        Returns:
            True on success; False otherwise (errors are logged, never raised).
        """
        try:
            logger.info(f"📤 Export des données visuelles vers {export_path}")

            workflow_data = await self._load_workflow_data(workflow_id)
            if not workflow_data:
                logger.error(f"Aucune donnée à exporter pour {workflow_id}")
                return False

            # Top-level envelope; timestamps serialized as ISO-8601 strings.
            export_data = {
                "workflow_id": workflow_data.workflow_id,
                "version": workflow_data.version,
                "created_at": workflow_data.created_at.isoformat(),
                "exported_at": datetime.now().isoformat(),
                "visual_targets": {},
                "target_signatures": workflow_data.target_signatures,
                "metadata": workflow_data.metadata
            }

            # Targets are converted to JSON-safe form by the export serializer.
            for signature, target in workflow_data.visual_targets.items():
                export_data["visual_targets"][signature] = await self._serialize_target_for_export(target)

            # Validation history is optional and serialized per result.
            if include_validation_history:
                export_data["validation_history"] = {}
                for signature, history in workflow_data.validation_history.items():
                    export_data["validation_history"][signature] = [
                        self._serialize_validation_result(result) for result in history
                    ]

            # Write the export file (parents created on demand).
            export_file = Path(export_path)
            export_file.parent.mkdir(parents=True, exist_ok=True)

            with open(export_file, 'w', encoding='utf-8') as f:
                json.dump(export_data, f, indent=2, ensure_ascii=False)

            logger.info(f"✅ Export terminé: {export_file}")
            return True

        except Exception as e:
            logger.error(f"❌ Erreur lors de l'export: {e}")
            return False
|
||||
|
||||
    async def import_workflow_visual_data(
        self,
        import_path: str,
        target_workflow_id: Optional[str] = None
    ) -> Optional[str]:
        """
        Import visual data from a JSON export file.

        Args:
            import_path: Path of the export file to read.
            target_workflow_id: Optional workflow id override; when omitted
                the id embedded in the file is used.

        Returns:
            The imported workflow's id, or None on failure (errors are
            logged, never raised).
        """
        try:
            logger.info(f"📥 Import des données visuelles depuis {import_path}")

            import_file = Path(import_path)
            if not import_file.exists():
                logger.error(f"Fichier d'import non trouvé: {import_path}")
                return None

            with open(import_file, 'r', encoding='utf-8') as f:
                import_data = json.load(f)

            # Explicit override wins over the id embedded in the file.
            workflow_id = target_workflow_id or import_data.get("workflow_id")
            if not workflow_id:
                logger.error("ID de workflow manquant pour l'import")
                return None

            # Rebuild the workflow container; missing fields get defaults
            # (created_at falls back to "now" when absent).
            workflow_data = VisualWorkflowData(
                workflow_id=workflow_id,
                version=import_data.get("version", "1.0"),
                created_at=datetime.fromisoformat(import_data.get("created_at", datetime.now().isoformat())),
                visual_targets={},
                target_signatures=import_data.get("target_signatures", {}),
                validation_history={},
                metadata=import_data.get("metadata", {})
            )

            # Targets that fail deserialization are silently dropped
            # (deserializer returns None).
            for signature, target_data in import_data.get("visual_targets", {}).items():
                target = await self._deserialize_target_from_import(target_data)
                if target:
                    workflow_data.visual_targets[signature] = target

            # Restore validation history per signature.
            for signature, history_data in import_data.get("validation_history", {}).items():
                workflow_data.validation_history[signature] = [
                    self._deserialize_validation_result(result_data)
                    for result_data in history_data
                ]

            # Persist the imported data through the normal save path.
            success = await self._save_workflow_data(workflow_data)

            if success:
                logger.info(f"✅ Import terminé pour workflow {workflow_id}")
                return workflow_id
            else:
                logger.error("Échec de la sauvegarde des données importées")
                return None

        except Exception as e:
            logger.error(f"❌ Erreur lors de l'import: {e}")
            return None
|
||||
|
||||
async def cleanup_old_data(self, days_to_keep: int = 30) -> int:
    """Delete on-disk visual workflow files older than *days_to_keep* days.

    Args:
        days_to_keep: Retention window in days; ``.vwd`` files whose
            mtime is older than this are removed from ``storage_path``.

    Returns:
        The number of files deleted (0 on any error).
    """
    try:
        logger.info(f"🧹 Nettoyage des données anciennes (> {days_to_keep} jours)")

        # Anything modified before this POSIX timestamp is expired.
        expiry_ts = datetime.now().timestamp() - days_to_keep * 24 * 3600
        removed = 0

        # .vwd = Visual Workflow Data
        for candidate in self.storage_path.glob("*.vwd"):
            if candidate.stat().st_mtime >= expiry_ts:
                continue
            candidate.unlink()
            removed += 1
            logger.debug(f"Supprimé: {candidate}")

        logger.info(f"✅ Nettoyage terminé - {removed} fichiers supprimés")
        return removed

    except Exception as e:
        logger.error(f"❌ Erreur lors du nettoyage: {e}")
        return 0
|
||||
|
||||
# Méthodes privées
|
||||
|
||||
async def _save_workflow_data(self, workflow_data: VisualWorkflowData) -> bool:
    """Serialize *workflow_data* and write it to ``<storage_path>/<id>.vwd``.

    Applies gzip compression when ``self.compression_enabled`` is set and
    records size/compression statistics in ``self.stats``.

    Returns:
        True on success, False if serialization or the write failed.
    """
    try:
        file_path = self.storage_path / f"{workflow_data.workflow_id}.vwd"

        # Serialize the payload (HMAC-signed JSON — see _serialize_workflow_data)
        serialized_data = await self._serialize_workflow_data(workflow_data)

        # Optionally compress; keep the achieved ratio for the stats endpoint
        if self.compression_enabled:
            compressed_data = gzip.compress(serialized_data)
            self.stats.compression_ratio = len(serialized_data) / len(compressed_data)
            data_to_write = compressed_data
        else:
            data_to_write = serialized_data
            self.stats.compression_ratio = 1.0

        # Write the (possibly compressed) bytes to disk
        with open(file_path, 'wb') as f:
            f.write(data_to_write)

        self.stats.total_size_bytes = len(data_to_write)
        return True

    except Exception as e:
        logger.error(f"Erreur lors de la sauvegarde: {e}")
        return False
|
||||
|
||||
async def _load_workflow_data(self, workflow_id: str) -> Optional[VisualWorkflowData]:
    """Read and deserialize the ``.vwd`` file for *workflow_id*.

    Transparently handles both gzip-compressed and plain files regardless
    of the current ``self.compression_enabled`` setting, so data written
    while compression was enabled stays readable after it is disabled
    (the previous version only attempted decompression when the flag was
    currently set).

    Returns:
        The loaded VisualWorkflowData, or None when the file is missing
        or loading fails.
    """
    try:
        file_path = self.storage_path / f"{workflow_id}.vwd"

        if not file_path.exists():
            return None

        # Read the raw (possibly compressed) bytes
        with open(file_path, 'rb') as f:
            data = f.read()

        # Always attempt decompression; gzip.BadGzipFile means the
        # payload is plain (uncompressed) data and is used as-is.
        try:
            data = gzip.decompress(data)
        except gzip.BadGzipFile:
            pass

        # Verify signature and rebuild the dataclasses
        workflow_data = await self._deserialize_workflow_data(data)
        return workflow_data

    except Exception as e:
        logger.error(f"Erreur lors du chargement: {e}")
        return None
|
||||
|
||||
async def _serialize_workflow_data(self, workflow_data: VisualWorkflowData) -> bytes:
    """Serialize a workflow's data to HMAC-signed JSON bytes."""
    # Flatten the top-level dataclass into a plain dict
    data_dict = asdict(workflow_data)

    # datetime is not JSON-serializable — store as ISO-8601 string
    data_dict['created_at'] = workflow_data.created_at.isoformat()

    # Visual targets need custom serialization (numpy embedding -> base64)
    serialized_targets = {}
    for signature, target in workflow_data.visual_targets.items():
        serialized_targets[signature] = await self._serialize_visual_target(target)
    data_dict['visual_targets'] = serialized_targets

    # Validation history entries are dataclasses as well
    serialized_history = {}
    for signature, history in workflow_data.validation_history.items():
        serialized_history[signature] = [
            self._serialize_validation_result(result) for result in history
        ]
    data_dict['validation_history'] = serialized_history

    # HMAC-signed JSON (see core.security.signed_serializer)
    return dumps_signed(data_dict)
|
||||
|
||||
async def _deserialize_workflow_data(self, data: bytes) -> VisualWorkflowData:
    """Deserialize workflow bytes (HMAC-signed JSON; noisy legacy-pickle
    fallback kept only so old on-disk files can be migrated).

    Raises:
        SignatureVerificationError: when the HMAC signature is invalid
            (tampered file or different key) — never falls back.
        UnsupportedFormatError: re-raised when the pickle fallback is
            disabled via ``RPA_ALLOW_PICKLE_FALLBACK=0``.
    """
    try:
        data_dict = loads_signed(data)
    except SignatureVerificationError:
        # Tampered file or wrong key: refuse outright, no fallback.
        logger.error("Workflow visuel : signature HMAC invalide — refus.")
        raise
    except UnsupportedFormatError:
        # Legacy pickle format: explicit, loud fallback (opt-out via env).
        import os
        if os.getenv("RPA_ALLOW_PICKLE_FALLBACK", "1") == "0":
            raise
        logger.warning(
            "Workflow visuel au format pickle legacy — lecture de compat, "
            "ré-écrire en JSON signé dès que possible."
        )
        # SECURITY: pickle on local legacy files only — never feed
        # untrusted data through this path.
        data_dict = pickle.loads(data)  # noqa: S301 - fallback legacy

    # Rebuild the top-level dataclass; nested objects are filled below
    workflow_data = VisualWorkflowData(
        workflow_id=data_dict['workflow_id'],
        version=data_dict['version'],
        created_at=datetime.fromisoformat(data_dict['created_at']),
        visual_targets={},
        target_signatures=data_dict['target_signatures'],
        validation_history={},
        metadata=data_dict['metadata']
    )

    # Rebuild visual targets (base64 -> numpy embedding etc.)
    for signature, target_data in data_dict['visual_targets'].items():
        target = await self._deserialize_visual_target(target_data)
        workflow_data.visual_targets[signature] = target

    # Rebuild validation history entries
    for signature, history_data in data_dict['validation_history'].items():
        workflow_data.validation_history[signature] = [
            self._deserialize_validation_result(result_data) for result_data in history_data
        ]

    return workflow_data
|
||||
|
||||
async def _serialize_visual_target(self, target: VisualTarget) -> Dict[str, Any]:
    """Serialize a visual target into a JSON-friendly dict.

    The numpy embedding is stored as base64 raw bytes plus its shape and
    dtype so _deserialize_visual_target can rebuild it losslessly.
    """
    return {
        'embedding': base64.b64encode(target.embedding.tobytes()).decode('utf-8'),
        'embedding_shape': target.embedding.shape,
        'embedding_dtype': str(target.embedding.dtype),
        'screenshot': target.screenshot,
        'bounding_box': asdict(target.bounding_box),
        'confidence': target.confidence,
        'contextual_info': asdict(target.contextual_info),
        'signature': target.signature,
        'metadata': asdict(target.metadata),
        'created_at': target.created_at.isoformat(),
        # Optional timestamp serializes to None when never validated
        'last_validated': target.last_validated.isoformat() if target.last_validated else None,
        'validation_count': target.validation_count
    }
|
||||
|
||||
async def _deserialize_visual_target(self, data: Dict[str, Any]) -> VisualTarget:
    """Rebuild a VisualTarget from its serialized dict form.

    Inverse of _serialize_visual_target: decodes the base64 embedding
    back into a numpy array with its original shape and dtype.
    """
    # Rebuild the embedding from raw bytes.
    # NOTE(review): np.frombuffer yields a read-only array; if callers
    # mutate the embedding in place a .copy() would be needed — confirm.
    embedding_bytes = base64.b64decode(data['embedding'])
    embedding = np.frombuffer(embedding_bytes, dtype=data['embedding_dtype'])
    embedding = embedding.reshape(data['embedding_shape'])

    # Rebuild the target with its nested dataclasses
    from core.models import BBox, ContextualInfo, VisualMetadata

    return VisualTarget(
        embedding=embedding,
        screenshot=data['screenshot'],
        bounding_box=BBox(**data['bounding_box']),
        confidence=data['confidence'],
        contextual_info=ContextualInfo(**data['contextual_info']),
        signature=data['signature'],
        metadata=VisualMetadata(**data['metadata']),
        created_at=datetime.fromisoformat(data['created_at']),
        last_validated=datetime.fromisoformat(data['last_validated']) if data['last_validated'] else None,
        validation_count=data['validation_count']
    )
|
||||
|
||||
def _serialize_validation_result(self, result: ValidationResult) -> Dict[str, Any]:
    """Serialize a validation result to a plain dict (dataclass flatten)."""
    return asdict(result)
|
||||
|
||||
def _deserialize_validation_result(self, data: Dict[str, Any]) -> ValidationResult:
    """Rebuild a ValidationResult from its dict form (inverse of asdict)."""
    return ValidationResult(**data)
|
||||
|
||||
async def _serialize_target_for_export(self, target: VisualTarget) -> Dict[str, Any]:
    """Serialize a target for JSON export.

    Currently identical to _serialize_visual_target, which already encodes
    the only binary payload (the embedding) as base64 — no extra
    conversion is needed here.
    """
    serialized = await self._serialize_visual_target(target)
    return serialized
|
||||
|
||||
async def _deserialize_target_from_import(self, data: Dict[str, Any]) -> Optional[VisualTarget]:
    """Deserialize one target from imported JSON.

    Returns:
        The rebuilt VisualTarget, or None when the payload is malformed
        (the error is logged rather than propagated so one bad target
        does not abort a whole import).
    """
    try:
        target = await self._deserialize_visual_target(data)
    except Exception as e:
        logger.error(f"Erreur lors de la désérialisation de cible: {e}")
        return None
    return target
|
||||
|
||||
async def _get_validation_history(self, target_signature: str) -> List[ValidationResult]:
    """Return the validation history for a target signature.

    TODO: stub — always returns an empty list until the validation
    subsystem provides persistent history.
    """
    return []
|
||||
|
||||
async def _attempt_target_recovery(
    self,
    target: VisualTarget,
    validation_result: ValidationResult
) -> Optional[VisualTarget]:
    """Try to recover an invalid target via its suggested recovery actions.

    Only actions flagged auto-executable with confidence > 0.7 are tried;
    the first one that succeeds yields the refreshed target.

    Returns:
        The updated VisualTarget on success, None otherwise.
    """
    try:
        # Walk the recovery actions proposed by the validation result
        for action in validation_result.recovery_actions:
            if action.auto_executable and action.confidence > 0.7:
                # Execute the recovery action through the validation manager
                success = await self.validation_manager.execute_recovery_action(
                    target.signature, action
                )
                if success:
                    # Fetch the refreshed target written by the recovery
                    updated_target = await self.target_manager.get_target_by_signature(target.signature)
                    if updated_target:
                        logger.info(f"Cible récupérée avec succès: {target.signature}")
                        return updated_target

        return None

    except Exception as e:
        logger.error(f"Erreur lors de la récupération de cible: {e}")
        return None
|
||||
|
||||
async def _create_backup(self, workflow_id: str) -> bool:
    """Snapshot a workflow's ``.vwd`` file into its backups/ directory.

    A timestamped copy is written under ``backups/<workflow_id>/`` and
    snapshots beyond ``self.max_backup_versions`` are then pruned.

    Returns:
        True when a backup was written, False when the source file is
        missing or the copy failed.
    """
    try:
        source_file = self.storage_path / f"{workflow_id}.vwd"
        if not source_file.exists():
            return False

        # One backup folder per workflow
        backup_dir = self.storage_path / "backups" / workflow_id
        backup_dir.mkdir(parents=True, exist_ok=True)

        # Timestamped snapshot name, e.g. wf_20260107_153000.vwd
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_file = backup_dir / f"{workflow_id}_{stamp}.vwd"

        # copy2 keeps file metadata (mtime), which the retention sort uses
        import shutil
        shutil.copy2(source_file, backup_file)

        # Drop snapshots beyond the configured retention count
        await self._cleanup_old_backups(backup_dir)

        logger.debug(f"Sauvegarde créée: {backup_file}")
        return True

    except Exception as e:
        logger.error(f"Erreur lors de la création de sauvegarde: {e}")
        return False
|
||||
|
||||
async def _cleanup_old_backups(self, backup_dir: Path):
|
||||
"""Nettoie les anciennes sauvegardes"""
|
||||
try:
|
||||
backup_files = sorted(backup_dir.glob("*.vwd"), key=lambda x: x.stat().st_mtime, reverse=True)
|
||||
|
||||
# Supprimer les fichiers excédentaires
|
||||
for file_to_delete in backup_files[self.max_backup_versions:]:
|
||||
file_to_delete.unlink()
|
||||
logger.debug(f"Ancienne sauvegarde supprimée: {file_to_delete}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors du nettoyage des sauvegardes: {e}")
|
||||
|
||||
def get_persistence_stats(self) -> Dict[str, Any]:
    """Expose persistence metrics and configuration as a plain dict."""
    stats = self.stats
    return {
        'total_targets': stats.total_targets,
        'total_size_bytes': stats.total_size_bytes,
        'compression_ratio': stats.compression_ratio,
        'save_duration_ms': stats.save_duration_ms,
        'load_duration_ms': stats.load_duration_ms,
        'compression_enabled': self.compression_enabled,
        'validation_on_load': self.validation_on_load,
        'backup_enabled': self.backup_enabled,
        'storage_path': str(self.storage_path),
    }
|
||||
@@ -1,657 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Outil de Migration de Workflows pour RPA Vision V3
|
||||
|
||||
Cet outil migre les workflows existants utilisant des sélecteurs CSS/XPath
|
||||
vers le système 100% visuel avec signatures visuelles et embeddings.
|
||||
|
||||
Fonctionnalités:
|
||||
- Conversion automatique avec validation
|
||||
- Interface de migration guidée
|
||||
- Préservation de la fonctionnalité des workflows
|
||||
- Sauvegarde et rollback
|
||||
|
||||
Exigences: 9.3, 9.4
|
||||
Auteur: Assistant IA
|
||||
Date: 2026-01-07
|
||||
"""
|
||||
|
||||
import asyncio
import json
import logging
import shutil
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from core.capture.screen_capturer import ScreenCapturer
from core.detection.ui_detector import UIDetector
from core.visual.screenshot_validation_manager import ScreenshotValidationManager
from core.visual.visual_embedding_manager import VisualEmbeddingManager
from core.visual.visual_target_manager import VisualTarget, VisualTargetManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class MigrationTask:
    """One node's migration from legacy selectors to a visual target."""
    node_id: str                      # workflow node identifier
    node_type: str                    # node kind as stored in the workflow
    legacy_selectors: Dict[str, str]  # selector_type -> selector value
    # Lifecycle: pending -> in_progress -> completed | failed | manual_review
    migration_status: str = "pending"  # pending, in_progress, completed, failed
    visual_target: Optional[VisualTarget] = None  # set on successful migration
    error_message: Optional[str] = None           # set when migration raised
    confidence_score: float = 0.0                 # best selector-match confidence
    manual_review_required: bool = False          # needs human confirmation
|
||||
|
||||
@dataclass
class MigrationReport:
    """Aggregated outcome of migrating one workflow to visual targets."""
    workflow_id: str
    workflow_name: str
    migration_started: datetime
    migration_completed: Optional[datetime] = None
    total_nodes: int = 0
    migrated_nodes: int = 0          # nodes fully converted to visual targets
    failed_nodes: int = 0            # nodes that could not be converted
    manual_review_nodes: int = 0     # nodes queued for human confirmation
    # default_factory instead of the previous `= None`: the annotation is
    # List[MigrationTask], and a freshly constructed report must support
    # .append() without callers having to pass an explicit empty list.
    migration_tasks: List[MigrationTask] = field(default_factory=list)
    backup_path: Optional[str] = None  # JSON snapshot used by rollback
    success_rate: float = 0.0          # percentage over total_nodes
|
||||
|
||||
class WorkflowMigrationTool:
|
||||
"""
|
||||
Outil de migration des workflows vers le système 100% visuel.
|
||||
|
||||
Convertit automatiquement les sélecteurs CSS/XPath en cibles visuelles
|
||||
avec validation et interface de migration guidée.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    visual_target_manager: VisualTargetManager,
    visual_embedding_manager: VisualEmbeddingManager,
    validation_manager: ScreenshotValidationManager,
    screen_capturer: ScreenCapturer,
    ui_detector: UIDetector,
    migration_storage_path: str = "data/migrations"
):
    """Wire up the migration tool with its collaborators.

    Args:
        visual_target_manager: Creates and stores visual targets.
        visual_embedding_manager: Embedding computation backend.
        validation_manager: Validates targets against live screens.
        screen_capturer: Screen capture service.
        ui_detector: UI element detector.
        migration_storage_path: Root directory for backups and reports.
    """
    self.visual_target_manager = visual_target_manager
    self.visual_embedding_manager = visual_embedding_manager
    self.validation_manager = validation_manager
    self.screen_capturer = screen_capturer
    self.ui_detector = ui_detector

    # Storage layout: <root>/backups/... and <root>/reports/...
    self.migration_storage_path = Path(migration_storage_path)
    self.migration_storage_path.mkdir(parents=True, exist_ok=True)

    # Thresholds: >= confidence_threshold auto-migrates;
    # >= manual_review_threshold is queued for human review instead.
    self.confidence_threshold = 0.8
    self.manual_review_threshold = 0.6

    # Dispatch table: selector type -> dedicated migration coroutine
    self.supported_selector_types = {
        'css_selector': self._migrate_css_selector,
        'xpath_selector': self._migrate_xpath_selector,
        'id_selector': self._migrate_id_selector,
        'class_selector': self._migrate_class_selector,
        'text_selector': self._migrate_text_selector
    }

    logger.info("Outil de migration de workflows initialisé")
|
||||
|
||||
async def migrate_workflow(
    self,
    workflow_data: Dict[str, Any],
    interactive_mode: bool = True,
    create_backup: bool = True
) -> MigrationReport:
    """Migrate a whole workflow to the visual-target system.

    Args:
        workflow_data: The workflow definition (nodes carrying legacy selectors).
        interactive_mode: When True, nodes with intermediate confidence are
            queued for manual review instead of being marked failed.
        create_backup: Write a JSON snapshot of the workflow before touching it.

    Returns:
        A MigrationReport with per-node outcomes and the overall success
        rate. Never raises; a critical error yields a 0% report instead.
    """
    workflow_id = workflow_data.get('id', 'unknown')
    workflow_name = workflow_data.get('name', 'Workflow sans nom')

    logger.info(f"🔄 Début de migration du workflow: {workflow_name} ({workflow_id})")

    # The report accumulates tasks and counters as migration proceeds
    report = MigrationReport(
        workflow_id=workflow_id,
        workflow_name=workflow_name,
        migration_started=datetime.now(),
        migration_tasks=[]
    )

    try:
        # Snapshot first so rollback_migration has something to restore
        if create_backup:
            backup_path = await self._create_workflow_backup(workflow_data)
            report.backup_path = backup_path
            logger.info(f"💾 Sauvegarde créée: {backup_path}")

        nodes = workflow_data.get('nodes', [])
        report.total_nodes = len(nodes)

        # Only nodes that still carry legacy selectors need migrating
        migration_tasks = []
        for node in nodes:
            task = await self._analyze_node_for_migration(node)
            if task:
                migration_tasks.append(task)
                report.migration_tasks.append(task)

        logger.info(f"📋 {len(migration_tasks)} nœuds nécessitent une migration")

        # Migrate node by node; individual failures never abort the run
        for task in migration_tasks:
            logger.info(f"🔧 Migration du nœud {task.node_id} ({task.node_type})")

            task.migration_status = "in_progress"

            try:
                # Automatic migration attempt (selector heuristics)
                success = await self._migrate_node_task(task, workflow_data)

                if success:
                    task.migration_status = "completed"
                    report.migrated_nodes += 1
                    logger.info(f"✅ Nœud {task.node_id} migré avec succès")
                else:
                    # Middling confidence + interactive mode → human review
                    if (task.confidence_score >= self.manual_review_threshold and
                        interactive_mode):

                        task.manual_review_required = True
                        task.migration_status = "manual_review"
                        report.manual_review_nodes += 1

                        logger.warning(f"⚠️ Nœud {task.node_id} nécessite une révision manuelle")
                    else:
                        task.migration_status = "failed"
                        report.failed_nodes += 1
                        logger.error(f"❌ Échec de migration du nœud {task.node_id}")

            except Exception as e:
                task.migration_status = "failed"
                task.error_message = str(e)
                report.failed_nodes += 1
                logger.error(f"❌ Erreur lors de la migration du nœud {task.node_id}: {e}")

        # Walk the review queue with the operator, if any
        if interactive_mode and report.manual_review_nodes > 0:
            await self._handle_manual_reviews(report, workflow_data)

        # Finalize counters; success rate is over ALL nodes, not just
        # the ones that needed migration
        report.migration_completed = datetime.now()
        report.success_rate = (report.migrated_nodes / max(1, report.total_nodes)) * 100

        # Persist the report for audit / later rollback
        await self._save_migration_report(report)

        logger.info(f"✅ Migration terminée - Succès: {report.success_rate:.1f}% "
                   f"({report.migrated_nodes}/{report.total_nodes})")

        return report

    except Exception as e:
        logger.error(f"❌ Erreur critique lors de la migration: {e}")
        report.migration_completed = datetime.now()
        report.success_rate = 0.0
        return report
|
||||
|
||||
async def validate_migrated_workflow(
    self,
    workflow_data: Dict[str, Any],
    migration_report: MigrationReport
) -> Dict[str, Any]:
    """Re-check every successfully migrated visual target against the
    current screen and summarize how many are still resolvable.

    Args:
        workflow_data: The migrated workflow definition.
        migration_report: Report from migrate_workflow; only its
            "completed" tasks carrying a visual target are validated.

    Returns:
        A dict with per-target validation entries, aggregate counts and a
        ``success_rate`` percentage (plus an ``error`` key on failure).
    """
    logger.info("🔍 Validation du workflow migré")

    validation_report = {
        'workflow_id': workflow_data.get('id'),
        'validation_started': datetime.now(),
        'total_targets': 0,
        'valid_targets': 0,
        'invalid_targets': 0,
        'target_validations': []
    }

    try:
        # Capture the current screen as validation context.
        # NOTE(review): screen_state is never read afterwards — the
        # validation manager appears to do its own capture; confirm
        # whether this detection pass is still needed.
        current_screen = await self.screen_capturer.capture_screen()
        screen_state = await self.ui_detector.detect_elements(current_screen)

        # Validate each target the migration produced
        for task in migration_report.migration_tasks:
            if task.migration_status == "completed" and task.visual_target:
                validation_report['total_targets'] += 1

                validation_result = await self.validation_manager.validate_target_now(
                    task.visual_target
                )

                target_validation = {
                    'node_id': task.node_id,
                    'target_signature': task.visual_target.signature,
                    'is_valid': validation_result.is_valid,
                    'confidence': validation_result.confidence,
                    'issues': validation_result.issues
                }

                validation_report['target_validations'].append(target_validation)

                if validation_result.is_valid:
                    validation_report['valid_targets'] += 1
                else:
                    validation_report['invalid_targets'] += 1

        validation_report['validation_completed'] = datetime.now()
        validation_report['success_rate'] = (
            validation_report['valid_targets'] /
            max(1, validation_report['total_targets']) * 100
        )

        logger.info(f"✅ Validation terminée - {validation_report['success_rate']:.1f}% "
                   f"de cibles valides")

        return validation_report

    except Exception as e:
        logger.error(f"❌ Erreur lors de la validation: {e}")
        validation_report['error'] = str(e)
        return validation_report
|
||||
|
||||
async def rollback_migration(
    self,
    migration_report: MigrationReport
) -> bool:
    """Undo a migration by deleting the visual targets it created.

    The JSON backup referenced by the report is loaded first, which
    verifies it exists and is parseable.

    NOTE(review): the loaded ``original_workflow`` is never written back
    to any workflow store here — restoring the workflow definition itself
    appears to be the caller's responsibility; confirm.

    Returns:
        True when the rollback completed, False when no usable backup
        exists or an error occurred.
    """
    try:
        if not migration_report.backup_path:
            logger.error("Aucune sauvegarde disponible pour le rollback")
            return False

        backup_file = Path(migration_report.backup_path)
        if not backup_file.exists():
            logger.error(f"Fichier de sauvegarde non trouvé: {backup_file}")
            return False

        logger.info(f"🔄 Rollback de la migration {migration_report.workflow_id}")

        # Load the snapshot (also validates the backup is readable)
        with open(backup_file, 'r', encoding='utf-8') as f:
            original_workflow = json.load(f)

        # Remove every visual target the migration created
        for task in migration_report.migration_tasks:
            if task.visual_target:
                await self.visual_target_manager.remove_target(task.visual_target.signature)

        logger.info("✅ Rollback terminé avec succès")
        return True

    except Exception as e:
        logger.error(f"❌ Erreur lors du rollback: {e}")
        return False
|
||||
|
||||
# Méthodes privées
|
||||
|
||||
async def _analyze_node_for_migration(self, node: Dict[str, Any]) -> Optional[MigrationTask]:
    """Inspect a workflow node and build a MigrationTask if it still
    relies on legacy (CSS/XPath/id/class/text) selectors.

    Returns:
        A pending MigrationTask listing every legacy selector found, or
        None when the node has nothing to migrate.
    """
    params = node.get('parameters', {})
    found: Dict[str, str] = {}

    # 1) Explicitly-typed selector parameters (css_selector, xpath_selector, ...)
    for sel_type in self.supported_selector_types:
        value = params.get(sel_type)
        if value:
            found[sel_type] = value

    # 2) Generic parameter names whose selector type must be guessed
    for generic_key in ('selector', 'target', 'element_selector', 'locator'):
        value = params.get(generic_key)
        if not value or not isinstance(value, str):
            continue
        if value.startswith('//') or value.startswith('.//'):
            found['xpath_selector'] = value
        elif value.startswith('#') or value.startswith('.'):
            found['css_selector'] = value
        else:
            found['text_selector'] = value

    if not found:
        return None

    return MigrationTask(
        node_id=node.get('id', 'unknown'),
        node_type=node.get('type', 'unknown'),
        legacy_selectors=found
    )
|
||||
|
||||
async def _migrate_node_task(
    self,
    task: MigrationTask,
    workflow_data: Dict[str, Any]
) -> bool:
    """Resolve one node's legacy selectors against the live screen and,
    when confident enough, mint a visual target for it.

    Side effects on *task*: sets ``confidence_score`` to the best match
    found, ``visual_target`` on success, ``error_message`` on exception.

    Returns:
        True when a visual target was created, False otherwise.
    """
    try:
        # Fresh capture + detection: selectors are matched against what
        # is on screen right now
        screenshot = await self.screen_capturer.capture_screen()
        screen_state = await self.ui_detector.detect_elements(screenshot)

        # Try every legacy selector and keep the highest-confidence hit
        target_element = None
        best_confidence = 0.0

        for selector_type, selector_value in task.legacy_selectors.items():
            if selector_type in self.supported_selector_types:
                migration_func = self.supported_selector_types[selector_type]

                element, confidence = await migration_func(
                    selector_value, screen_state, workflow_data
                )

                if element and confidence > best_confidence:
                    target_element = element
                    best_confidence = confidence

        task.confidence_score = best_confidence

        # Only mint a target when the best match clears the auto threshold
        if target_element and best_confidence >= self.confidence_threshold:
            visual_target = await self.visual_target_manager.create_target_from_element(
                target_element, screenshot
            )

            if visual_target:
                task.visual_target = visual_target
                return True

        return False

    except Exception as e:
        task.error_message = str(e)
        return False
|
||||
|
||||
async def _migrate_css_selector(
|
||||
self,
|
||||
css_selector: str,
|
||||
screen_state: Any,
|
||||
workflow_data: Dict[str, Any]
|
||||
) -> Tuple[Optional[Any], float]:
|
||||
"""Migre un sélecteur CSS"""
|
||||
try:
|
||||
# Analyser le sélecteur CSS pour extraire des indices
|
||||
confidence = 0.5 # Confiance de base
|
||||
|
||||
# Logique de migration spécifique aux sélecteurs CSS
|
||||
# Cette implémentation est simplifiée
|
||||
|
||||
# Chercher des éléments correspondants par type ou attributs
|
||||
for element in screen_state.ui_elements:
|
||||
# Heuristiques basées sur le sélecteur CSS
|
||||
if '#' in css_selector: # ID selector
|
||||
confidence += 0.2
|
||||
elif '.' in css_selector: # Class selector
|
||||
confidence += 0.1
|
||||
elif css_selector in ['button', 'input', 'a']: # Tag selector
|
||||
if element.element_type.lower() == css_selector:
|
||||
confidence += 0.3
|
||||
return element, confidence
|
||||
|
||||
return None, confidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de la migration CSS: {e}")
|
||||
return None, 0.0
|
||||
|
||||
async def _migrate_xpath_selector(
|
||||
self,
|
||||
xpath_selector: str,
|
||||
screen_state: Any,
|
||||
workflow_data: Dict[str, Any]
|
||||
) -> Tuple[Optional[Any], float]:
|
||||
"""Migre un sélecteur XPath"""
|
||||
try:
|
||||
confidence = 0.4 # Confiance de base pour XPath
|
||||
|
||||
# Analyser le XPath pour extraire des informations
|
||||
if 'text()' in xpath_selector:
|
||||
# Sélecteur basé sur le texte
|
||||
text_content = self._extract_text_from_xpath(xpath_selector)
|
||||
if text_content:
|
||||
return await self._find_element_by_text(text_content, screen_state)
|
||||
|
||||
if '@id' in xpath_selector:
|
||||
# Sélecteur basé sur l'ID
|
||||
confidence += 0.2
|
||||
|
||||
if 'button' in xpath_selector or 'input' in xpath_selector:
|
||||
# Sélecteur basé sur le type d'élément
|
||||
element_type = self._extract_element_type_from_xpath(xpath_selector)
|
||||
if element_type:
|
||||
for element in screen_state.ui_elements:
|
||||
if element.element_type.lower() == element_type.lower():
|
||||
confidence += 0.3
|
||||
return element, confidence
|
||||
|
||||
return None, confidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de la migration XPath: {e}")
|
||||
return None, 0.0
|
||||
|
||||
async def _migrate_id_selector(
|
||||
self,
|
||||
id_selector: str,
|
||||
screen_state: Any,
|
||||
workflow_data: Dict[str, Any]
|
||||
) -> Tuple[Optional[Any], float]:
|
||||
"""Migre un sélecteur ID"""
|
||||
# Les sélecteurs ID ont généralement une bonne confiance
|
||||
confidence = 0.8
|
||||
|
||||
# Chercher un élément avec un ID correspondant
|
||||
# (Implémentation simplifiée)
|
||||
return None, confidence
|
||||
|
||||
async def _migrate_class_selector(
|
||||
self,
|
||||
class_selector: str,
|
||||
screen_state: Any,
|
||||
workflow_data: Dict[str, Any]
|
||||
) -> Tuple[Optional[Any], float]:
|
||||
"""Migre un sélecteur de classe"""
|
||||
confidence = 0.6
|
||||
|
||||
# Logique de migration pour les sélecteurs de classe
|
||||
return None, confidence
|
||||
|
||||
async def _migrate_text_selector(
|
||||
self,
|
||||
text_selector: str,
|
||||
screen_state: Any,
|
||||
workflow_data: Dict[str, Any]
|
||||
) -> Tuple[Optional[Any], float]:
|
||||
"""Migre un sélecteur basé sur le texte"""
|
||||
return await self._find_element_by_text(text_selector, screen_state)
|
||||
|
||||
async def _find_element_by_text(
|
||||
self,
|
||||
text: str,
|
||||
screen_state: Any
|
||||
) -> Tuple[Optional[Any], float]:
|
||||
"""Trouve un élément par son contenu textuel"""
|
||||
try:
|
||||
for element in screen_state.ui_elements:
|
||||
if element.text_content and text.lower() in element.text_content.lower():
|
||||
# Calculer la confiance basée sur la correspondance
|
||||
if element.text_content.lower() == text.lower():
|
||||
confidence = 0.9 # Correspondance exacte
|
||||
else:
|
||||
confidence = 0.7 # Correspondance partielle
|
||||
|
||||
return element, confidence
|
||||
|
||||
return None, 0.0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de la recherche par texte: {e}")
|
||||
return None, 0.0
|
||||
|
||||
def _extract_text_from_xpath(self, xpath: str) -> Optional[str]:
|
||||
"""Extrait le texte d'un sélecteur XPath"""
|
||||
try:
|
||||
# Chercher des patterns comme text()='...' ou contains(text(),'...')
|
||||
import re
|
||||
|
||||
# Pattern pour text()='value'
|
||||
match = re.search(r"text\(\)\s*=\s*['\"]([^'\"]+)['\"]", xpath)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
# Pattern pour contains(text(),'value')
|
||||
match = re.search(r"contains\s*\(\s*text\(\)\s*,\s*['\"]([^'\"]+)['\"]", xpath)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _extract_element_type_from_xpath(self, xpath: str) -> Optional[str]:
|
||||
"""Extrait le type d'élément d'un sélecteur XPath"""
|
||||
try:
|
||||
# Chercher des patterns comme //button ou //input
|
||||
import re
|
||||
|
||||
match = re.search(r"//(\w+)", xpath)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _create_workflow_backup(self, workflow_data: Dict[str, Any]) -> str:
|
||||
"""Crée une sauvegarde du workflow"""
|
||||
workflow_id = workflow_data.get('id', 'unknown')
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
backup_filename = f"workflow_{workflow_id}_{timestamp}_backup.json"
|
||||
backup_path = self.migration_storage_path / "backups" / backup_filename
|
||||
|
||||
backup_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(backup_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(workflow_data, f, indent=2, ensure_ascii=False, default=str)
|
||||
|
||||
return str(backup_path)
|
||||
|
||||
async def _save_migration_report(self, report: MigrationReport):
    """Persist *report* as JSON under ``<migration_storage_path>/reports``.

    The filename embeds the workflow id and the current timestamp so
    successive runs never overwrite each other.
    """
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    destination = (
        self.migration_storage_path
        / "reports"
        / f"migration_report_{report.workflow_id}_{stamp}.json"
    )
    destination.parent.mkdir(parents=True, exist_ok=True)

    # asdict() leaves datetime objects as-is, so render the two
    # timestamp fields to ISO-8601 strings explicitly.
    payload = asdict(report)
    payload['migration_started'] = report.migration_started.isoformat()
    if report.migration_completed:
        payload['migration_completed'] = report.migration_completed.isoformat()

    with open(destination, 'w', encoding='utf-8') as fh:
        json.dump(payload, fh, indent=2, ensure_ascii=False, default=str)

    logger.info(f"Rapport de migration sauvegardé: {destination}")
async def _handle_manual_reviews(
    self,
    report: MigrationReport,
    workflow_data: Dict[str, Any]
):
    """Process tasks flagged for manual review (interactive mode).

    A real implementation would surface a validation UI to the operator;
    for now any flagged task whose confidence is at least 0.7 is
    auto-accepted and the report counters are updated accordingly.
    """
    logger.info(f"🔍 {report.manual_review_nodes} nœuds nécessitent une révision manuelle")

    flagged = [t for t in report.migration_tasks if t.manual_review_required]
    for task in flagged:
        logger.info(f"📝 Révision manuelle requise pour le nœud {task.node_id}")

        # Placeholder for the future interactive validation flow:
        # auto-accept anything confident enough.
        if task.confidence_score >= 0.7:
            task.migration_status = "completed"
            task.manual_review_required = False
            report.migrated_nodes += 1
            report.manual_review_nodes -= 1
            logger.info(f"✅ Révision automatique acceptée pour {task.node_id}")
def get_migration_statistics(self) -> Dict[str, Any]:
    """Summarise migration activity.

    Counts the JSON artefacts written so far (reports and backups) and
    echoes the migrator's configuration.

    Returns:
        Dict with counts, the storage path, the supported selector
        type names and both confidence thresholds.
    """
    storage = self.migration_storage_path

    def _count_json(subdir: str) -> int:
        # A missing directory simply means nothing was written yet.
        folder = storage / subdir
        return len(list(folder.glob("*.json"))) if folder.exists() else 0

    return {
        'total_migrations': _count_json("reports"),
        'total_backups': _count_json("backups"),
        'migration_storage_path': str(storage),
        'supported_selector_types': list(self.supported_selector_types.keys()),
        'confidence_threshold': self.confidence_threshold,
        'manual_review_threshold': self.manual_review_threshold,
    }
Reference in New Issue
Block a user