feat(vwb): Intégration UI-DETR-1 + Toggle mode Basique/Intelligent/Debug
- Toggle 3 modes dans le header: Basique (coords fixes), Intelligent (vision IA), Debug (overlay) - Service UI-DETR-1 pour détection d'éléments UI (510MB model, ~800ms/image) - API endpoints: /api/ui-detection/detect, /preload, /status, /find-element - Overlay des bboxes détectées en mode Debug (miniature + plein écran) - Clic sur élément détecté pour le sélectionner comme ancre - Document de vision produit: docs/VISION_RPA_INTELLIGENT.md - Configuration CORS étendue pour ports locaux Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
237
visual_workflow_builder/backend/api/ui_detection.py
Normal file
237
visual_workflow_builder/backend/api/ui_detection.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""
|
||||
API Blueprint pour la détection UI avec UI-DETR-1
|
||||
"""
|
||||
|
||||
from flask import Blueprint, request, jsonify
|
||||
from flask_cors import cross_origin
|
||||
import base64
|
||||
import io
|
||||
from PIL import Image
|
||||
|
||||
ui_detection_bp = Blueprint('ui_detection', __name__, url_prefix='/api/ui-detection')
|
||||
|
||||
# Import lazy du service (le modèle est lourd)
|
||||
_service = None
|
||||
|
||||
|
||||
def get_service():
    """Lazily import and cache the heavy detection-service callables.

    The underlying module loads a large model, so the import is deferred
    until the first request that actually needs it. Returns a dict mapping
    function names to the service callables.
    """
    global _service
    if _service is not None:
        return _service

    from services.ui_detection_service import (
        detect_from_base64,
        detect_from_file,
        annotated_image_to_base64,
        preload_model
    )
    _service = {
        'detect_from_base64': detect_from_base64,
        'detect_from_file': detect_from_file,
        'annotated_image_to_base64': annotated_image_to_base64,
        'preload_model': preload_model
    }
    return _service
|
||||
|
||||
|
||||
@ui_detection_bp.route('/detect', methods=['POST'])
@cross_origin()
def detect_ui_elements():
    """
    Detect UI elements in an image.

    Request body (JSON):
        - image_base64: base64-encoded image (required)
        - threshold: confidence threshold (optional, default: 0.35)
        - annotate: return the annotated image (optional, default: false)
        - show_confidence: draw scores on the annotated image (optional, default: false)

    Response:
        - success: bool
        - result: {
            elements: [...],
            count: int,
            processing_time_ms: float,
            image_size: {width, height},
            model: str,
            annotated_image_base64?: str (if annotate=true)
          }
    """
    try:
        data = request.get_json()

        if not data or 'image_base64' not in data:
            return jsonify({
                'success': False,
                'error': 'image_base64 est requis'
            }), 400

        image_base64 = data['image_base64']
        annotate = data.get('annotate', False)
        show_confidence = data.get('show_confidence', False)

        # Validate the threshold explicitly: a non-numeric value used to fall
        # through to the broad handler below and surface as a 500; it is a
        # client error, so answer 400 instead.
        try:
            threshold = float(data.get('threshold', 0.35))
        except (TypeError, ValueError):
            return jsonify({
                'success': False,
                'error': 'threshold invalide'
            }), 400
        # Clamp into the supported range.
        threshold = max(0.1, min(1.0, threshold))

        service = get_service()

        # Run the detection
        result = service['detect_from_base64'](image_base64, threshold)
        response_data = result.to_dict()

        # Generate the annotated image if requested
        if annotate:
            # Decode the original image (strip an optional data: URL prefix)
            if ',' in image_base64:
                image_base64_clean = image_base64.split(',')[1]
            else:
                image_base64_clean = image_base64

            image_bytes = base64.b64decode(image_base64_clean)
            image = Image.open(io.BytesIO(image_bytes))

            # Draw the detected bboxes onto a copy of the original
            annotated_b64 = service['annotated_image_to_base64'](
                image, result,
                show_ids=True,
                show_confidence=show_confidence
            )
            response_data['annotated_image_base64'] = f"data:image/png;base64,{annotated_b64}"

        return jsonify({
            'success': True,
            'result': response_data
        })

    except Exception as e:
        # Top-level boundary: log the full traceback, return a 500 payload.
        import traceback
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
|
||||
|
||||
|
||||
@ui_detection_bp.route('/preload', methods=['POST'])
@cross_origin()
def preload_model():
    """
    Preload the UI-DETR-1 model into memory.

    Useful to avoid the latency of the first detection call.
    """
    try:
        # Kick off the (background) model load via the service.
        get_service()['preload_model']()
        return jsonify({
            'success': True,
            'message': 'Modèle en cours de chargement'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
|
||||
|
||||
|
||||
@ui_detection_bp.route('/status', methods=['GET'])
@cross_origin()
def get_status():
    """
    Report the current state of the detection service: whether the model
    file exists on disk and whether it is loaded in memory.
    """
    try:
        import os
        # NOTE(review): reaches into a private module global (_model) to
        # report load state without triggering a load.
        from services.ui_detection_service import _model, MODEL_PATH

        return jsonify({
            'success': True,
            'status': {
                'model_path': MODEL_PATH,
                'model_exists': os.path.exists(MODEL_PATH),
                'model_loaded': _model is not None,
                'model_name': 'UI-DETR-1',
                'default_threshold': 0.35
            }
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
|
||||
|
||||
|
||||
@ui_detection_bp.route('/find-element', methods=['POST'])
@cross_origin()
def find_element():
    """
    Find a specific element in the image using a reference anchor.

    Request body (JSON):
        - image_base64: current screenshot
        - anchor_base64: image of the anchor to find
        - threshold: confidence threshold (optional)

    Response:
        - success: bool
        - result: {
            found: bool,
            element: {...} or null,
            all_elements: [...],
            match_score: float
          }

    Note: intended to combine detection with CLIP embedding comparison;
    the matching step is currently a placeholder.
    """
    try:
        payload = request.get_json()

        if not payload or 'image_base64' not in payload:
            return jsonify({
                'success': False,
                'error': 'image_base64 est requis'
            }), 400

        image_base64 = payload['image_base64']
        anchor_base64 = payload.get('anchor_base64')
        threshold = payload.get('threshold', 0.35)

        service = get_service()

        # Detect every element in the screenshot
        detection = service['detect_from_base64'](image_base64, threshold)

        result_payload = {
            'found': False,
            'element': None,
            'all_elements': [elem.to_dict() for elem in detection.elements],
            'count': len(detection.elements),
            'match_score': 0.0
        }

        # If an anchor was supplied, try to match it against the detections
        if anchor_base64 and detection.elements:
            # TODO: integrate CLIP for anchor matching.
            # For now, return the first element as a placeholder.
            result_payload['found'] = True
            result_payload['element'] = detection.elements[0].to_dict()
            result_payload['match_score'] = 0.5  # placeholder

        return jsonify({
            'success': True,
            'result': result_payload
        })

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
|
||||
@@ -39,10 +39,10 @@ socketio = SocketIO(
|
||||
engineio_logger=True
|
||||
)
|
||||
|
||||
# Enable CORS
|
||||
# Enable CORS - autoriser tous les ports locaux en développement
|
||||
CORS(app, resources={
|
||||
r"/api/*": {
|
||||
"origins": os.getenv('CORS_ORIGINS', 'http://localhost:3000').split(','),
|
||||
"origins": os.getenv('CORS_ORIGINS', 'http://localhost:3000,http://localhost:3001,http://localhost:3002,http://localhost:3003,http://localhost:3004,http://localhost:5173').split(','),
|
||||
"methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
"allow_headers": ["Content-Type", "Authorization"]
|
||||
}
|
||||
@@ -150,6 +150,14 @@ try:
|
||||
except ImportError as e:
|
||||
print(f"⚠️ Blueprint anchor_images désactivé: {e}")
|
||||
|
||||
# API UI Detection - UI-DETR-1
|
||||
try:
|
||||
from api.ui_detection import ui_detection_bp
|
||||
app.register_blueprint(ui_detection_bp)
|
||||
print("✅ Blueprint ui_detection (UI-DETR-1) enregistré - /api/ui-detection/*")
|
||||
except ImportError as e:
|
||||
print(f"⚠️ Blueprint ui_detection désactivé: {e}")
|
||||
|
||||
# ============================================================
|
||||
# API V3 - Thin Client Architecture (Source de Vérité Unique)
|
||||
# ============================================================
|
||||
|
||||
298
visual_workflow_builder/backend/services/ui_detection_service.py
Normal file
298
visual_workflow_builder/backend/services/ui_detection_service.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""
|
||||
Service de détection UI utilisant UI-DETR-1
|
||||
Détecte les éléments d'interface utilisateur dans un screenshot
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import base64
|
||||
import io
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
# Configuration du modèle
|
||||
MODEL_PATH = "/home/dom/ai/rpa_vision_v3/models/ui-detr-1/model.pth"
|
||||
CONFIDENCE_THRESHOLD = 0.35
|
||||
RESOLUTION = 1600
|
||||
|
||||
# Instance globale du modèle (lazy loading)
|
||||
_model = None
|
||||
_model_loading = False
|
||||
|
||||
|
||||
@dataclass
class UIElement:
    """A single detected UI element."""
    id: int
    bbox: Dict[str, int]    # x1, y1, x2, y2
    center: Dict[str, int]  # x, y
    confidence: float
    area: int

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (confidence rounded to 3 dp)."""
        payload = {
            "id": self.id,
            "bbox": self.bbox,
            "center": self.center,
            "confidence": round(self.confidence, 3),
            "area": self.area,
        }
        return payload


@dataclass
class DetectionResult:
    """Outcome of one detection pass over an image."""
    elements: List[UIElement]
    processing_time_ms: float
    image_size: Dict[str, int]
    model_name: str = "UI-DETR-1"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict for API responses."""
        serialized_elements = [element.to_dict() for element in self.elements]
        return {
            "elements": serialized_elements,
            "count": len(serialized_elements),
            "processing_time_ms": round(self.processing_time_ms, 1),
            "image_size": self.image_size,
            "model": self.model_name,
        }
|
||||
|
||||
|
||||
def load_model():
    """
    Load the UI-DETR-1 model (lazy loading).

    Returns the cached model instance, loading it on first call. If another
    caller is already loading, polls until that load finishes and returns
    the shared instance.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist.
        Exception: anything raised by the rfdetr model constructor.

    NOTE(review): the check-then-set on _model_loading is not atomic, so two
    threads racing before the flag is set could both load the model. That
    only duplicates work; a module-level threading.Lock would close the gap.
    """
    global _model, _model_loading

    if _model is not None:
        return _model

    if _model_loading:
        # Another caller is loading; wait until it finishes.
        while _model_loading and _model is None:
            time.sleep(0.1)
        return _model

    _model_loading = True

    try:
        print(f"[UI-DETR-1] Chargement du modèle depuis {MODEL_PATH}...")
        start = time.time()

        from rfdetr.detr import RFDETRMedium

        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Modèle non trouvé: {MODEL_PATH}")

        _model = RFDETRMedium(pretrain_weights=MODEL_PATH, resolution=RESOLUTION)

        elapsed = time.time() - start
        print(f"[UI-DETR-1] Modèle chargé en {elapsed:.1f}s")

        return _model

    except Exception as e:
        print(f"[UI-DETR-1] Erreur chargement modèle: {e}")
        raise
    finally:
        # Single reset point for the flag: covers both success and failure.
        # (The original also reset it inside the except branch, which was
        # redundant — finally always runs.)
        _model_loading = False
|
||||
|
||||
|
||||
def detect_ui_elements(
    image: Image.Image,
    threshold: float = CONFIDENCE_THRESHOLD
) -> DetectionResult:
    """
    Run UI element detection on a PIL image.

    Args:
        image: PIL image.
        threshold: confidence threshold (0-1).

    Returns:
        DetectionResult whose elements are sorted top-left to bottom-right
        and re-numbered in that reading order.
    """
    started = time.time()

    # Load (or fetch the cached) model
    model = load_model()

    # The model expects an RGB numpy array
    rgb_array = np.array(image.convert('RGB'))

    # Run inference
    detections = model.predict(rgb_array, threshold=threshold)

    # Build one UIElement per detection ([x1, y1, x2, y2] boxes)
    elements = []
    for idx, (box, score) in enumerate(zip(detections.xyxy, detections.confidence)):
        x1, y1, x2, y2 = (int(coord) for coord in box)
        elements.append(UIElement(
            id=idx,
            bbox={"x1": x1, "y1": y1, "x2": x2, "y2": y2},
            center={"x": (x1 + x2) // 2, "y": (y1 + y2) // 2},
            confidence=float(score),
            area=(x2 - x1) * (y2 - y1),
        ))

    # Sort into reading order (top-left first), then renumber to match
    elements.sort(key=lambda el: (el.bbox["y1"], el.bbox["x1"]))
    for new_id, el in enumerate(elements):
        el.id = new_id

    return DetectionResult(
        elements=elements,
        processing_time_ms=(time.time() - started) * 1000,
        image_size={"width": image.width, "height": image.height},
    )
|
||||
|
||||
|
||||
def detect_from_base64(
    image_base64: str,
    threshold: float = CONFIDENCE_THRESHOLD
) -> DetectionResult:
    """
    Detect UI elements in a base64-encoded image.

    Args:
        image_base64: base64-encoded image (with or without a
            data:image/... prefix).
        threshold: confidence threshold.

    Returns:
        DetectionResult
    """
    # Strip the data:image/... prefix when present (keep the original
    # split-on-first-comma semantics).
    raw_b64 = image_base64.split(',')[1] if ',' in image_base64 else image_base64

    # Decode into a PIL image
    image = Image.open(io.BytesIO(base64.b64decode(raw_b64)))

    return detect_ui_elements(image, threshold)
|
||||
|
||||
|
||||
def detect_from_file(
    file_path: str,
    threshold: float = CONFIDENCE_THRESHOLD
) -> DetectionResult:
    """
    Detect UI elements in an image file.

    Args:
        file_path: path to the image.
        threshold: confidence threshold.

    Returns:
        DetectionResult
    """
    # Use a context manager so the underlying file handle is closed even on
    # error — Image.open keeps the file open until the pixel data is read,
    # and the original left it to the garbage collector.
    with Image.open(file_path) as image:
        return detect_ui_elements(image, threshold)
|
||||
|
||||
|
||||
def create_annotated_image(
    image: Image.Image,
    detection_result: DetectionResult,
    show_ids: bool = True,
    show_confidence: bool = False
) -> Image.Image:
    """
    Create an annotated copy of the image with bboxes and ID labels.

    Args:
        image: original image (left untouched).
        detection_result: detection output to render.
        show_ids: draw the element ID numbers.
        show_confidence: append confidence percentages to the labels.

    Returns:
        Annotated copy of the image.
    """
    from PIL import ImageDraw, ImageFont

    # Work on a copy so the caller's image is not modified
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)

    # Try to load a TrueType font, otherwise fall back to the default.
    # Narrowed from a bare `except:` — ImageFont.truetype raises OSError
    # when the font file is missing or unreadable; a bare except would also
    # have swallowed KeyboardInterrupt/SystemExit.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
        small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 10)
    except OSError:
        font = ImageFont.load_default()
        small_font = font
    # NOTE(review): small_font is currently unused below — kept for parity
    # with the original; confirm whether it is intended for future use.

    # Colors for the bboxes and labels
    bbox_color = (233, 69, 96)  # red/pink
    text_bg_color = (233, 69, 96)
    text_color = (255, 255, 255)

    for elem in detection_result.elements:
        bbox = elem.bbox
        x1, y1, x2, y2 = bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]

        # Draw the bbox
        draw.rectangle([x1, y1, x2, y2], outline=bbox_color, width=2)

        if show_ids:
            # Label text
            label = str(elem.id)
            if show_confidence:
                label += f" ({elem.confidence:.0%})"

            # Measure the text
            text_bbox = draw.textbbox((0, 0), label, font=font)
            text_width = text_bbox[2] - text_bbox[0]
            text_height = text_bbox[3] - text_bbox[1]

            # Label position: above the bbox's top-left corner, flipped just
            # inside the bbox when it would fall off the top of the image
            label_x = x1
            label_y = y1 - text_height - 4
            if label_y < 0:
                label_y = y1 + 2

            # Label background
            draw.rectangle(
                [label_x - 2, label_y - 2, label_x + text_width + 4, label_y + text_height + 2],
                fill=text_bg_color
            )

            # Label text
            draw.text((label_x, label_y), label, fill=text_color, font=font)

    return annotated
|
||||
|
||||
|
||||
def annotated_image_to_base64(
    image: Image.Image,
    detection_result: DetectionResult,
    show_ids: bool = True,
    show_confidence: bool = False
) -> str:
    """
    Render the annotated image and return it as a base64-encoded PNG.
    """
    rendered = create_annotated_image(image, detection_result, show_ids, show_confidence)

    # Encode the rendered image to PNG in memory
    png_buffer = io.BytesIO()
    rendered.save(png_buffer, format='PNG')

    return base64.b64encode(png_buffer.getvalue()).decode('utf-8')
|
||||
|
||||
|
||||
# Optional preloading
def preload_model():
    """Kick off model loading on a background daemon thread."""
    import threading
    # Daemon thread: never blocks interpreter shutdown
    threading.Thread(target=load_model, daemon=True).start()
|
||||
Reference in New Issue
Block a user