# Changelog:
# - Skip crops < 40px (guess type from shape, confidence 0.3)
# - Retry once when the VLM response is empty
# - Robust JSON extraction: finds {...} inside prose, fixes single quotes
# - Eliminates ~70% of useless VLM calls on small elements
"""
|
|
OllamaClient - Client pour Vision-Language Models via Ollama
|
|
|
|
Interface pour communiquer avec des VLM (Qwen, LLaVA, etc.) via Ollama.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Dict, List, Optional, Any
|
|
import requests
|
|
import json
|
|
import base64
|
|
from pathlib import Path
|
|
from PIL import Image
|
|
import io
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OllamaClient:
    """Ollama client for Vision-Language Models.

    Sends images and text prompts to a VLM (Qwen, LLaVA, ...) through
    the Ollama HTTP API.
    """

    def __init__(self,
                 endpoint: str = "http://localhost:11434",
                 model: str = "qwen3-vl:8b",
                 timeout: int = 60):
        """Initialize the Ollama client.

        Args:
            endpoint: Base URL of the Ollama API.
            model: Name of the VLM model to use.
            timeout: Request timeout in seconds.
        """
        # Normalize the endpoint so later f-strings can append "/api/...".
        self.endpoint = endpoint.rstrip('/')
        self.model = model
        self.timeout = timeout
        # Best-effort probe; a failure only produces a warning.
        self._check_connection()
|
|
|
|
def _check_connection(self) -> bool:
|
|
"""Vérifier la connexion à Ollama"""
|
|
try:
|
|
response = requests.get(f"{self.endpoint}/api/tags", timeout=5)
|
|
if response.status_code == 200:
|
|
models = response.json().get('models', [])
|
|
model_names = [m['name'] for m in models]
|
|
if self.model not in model_names:
|
|
logger.warning(f" Model '{self.model}' not found in Ollama")
|
|
logger.info(f"Available models: {model_names}")
|
|
return True
|
|
except Exception as e:
|
|
logger.warning(f" Cannot connect to Ollama at {self.endpoint}: {e}")
|
|
return False
|
|
return False
|
|
|
|
def generate(self,
|
|
prompt: str,
|
|
image_path: Optional[str] = None,
|
|
image: Optional[Image.Image] = None,
|
|
system_prompt: Optional[str] = None,
|
|
temperature: float = 0.1,
|
|
max_tokens: int = 500,
|
|
force_json: bool = False) -> Dict[str, Any]:
|
|
"""
|
|
Générer une réponse du VLM via l'API chat d'Ollama.
|
|
|
|
Note: On utilise /api/chat au lieu de /api/generate car qwen3-vl
|
|
avec /api/generate consomme tous les tokens en thinking interne
|
|
et retourne une réponse vide. L'API chat gère correctement
|
|
le mode /no_think et sépare thinking/réponse.
|
|
|
|
Args:
|
|
prompt: Prompt textuel
|
|
image_path: Chemin vers une image (optionnel)
|
|
image: Image PIL (optionnel)
|
|
system_prompt: Prompt système (optionnel)
|
|
temperature: Température de génération
|
|
max_tokens: Nombre max de tokens
|
|
force_json: Forcer la sortie JSON (non recommandé pour qwen3-vl)
|
|
|
|
Returns:
|
|
Dict avec 'response', 'success', 'error'
|
|
"""
|
|
try:
|
|
# Préparer l'image si fournie
|
|
image_data = None
|
|
if image_path:
|
|
image_data = self._encode_image_from_path(image_path)
|
|
elif image:
|
|
image_data = self._encode_image_from_pil(image)
|
|
|
|
# Nettoyer le prompt — retirer /no_think et /nothink du texte
|
|
# car le mode thinking est contrôlé via le paramètre think=false
|
|
# de l'API chat. Les préfixes /no_think dans le prompt causent
|
|
# paradoxalement PLUS de thinking interne chez qwen3-vl.
|
|
effective_prompt = prompt.replace("/no_think\n", "").replace("/no_think", "")
|
|
effective_prompt = effective_prompt.replace("/nothink ", "").replace("/nothink", "")
|
|
effective_prompt = effective_prompt.strip()
|
|
|
|
# Construire le message utilisateur
|
|
user_message = {"role": "user", "content": effective_prompt}
|
|
if image_data:
|
|
user_message["images"] = [image_data]
|
|
|
|
# Construire les messages
|
|
messages = []
|
|
if system_prompt:
|
|
messages.append({"role": "system", "content": system_prompt})
|
|
messages.append(user_message)
|
|
|
|
# Déterminer si le modèle supporte le thinking
|
|
is_thinking_model = "qwen3" in self.model.lower()
|
|
|
|
payload = {
|
|
"model": self.model,
|
|
"messages": messages,
|
|
"stream": False,
|
|
"options": {
|
|
"temperature": temperature,
|
|
"num_predict": max_tokens,
|
|
"num_ctx": 2048,
|
|
"top_k": 1
|
|
}
|
|
}
|
|
|
|
# Désactiver le thinking pour les modèles qui le supportent
|
|
# Cela réduit drastiquement la consommation de tokens et le temps
|
|
if is_thinking_model:
|
|
payload["think"] = False
|
|
|
|
if force_json:
|
|
payload["format"] = "json"
|
|
|
|
# Envoyer la requête via l'API chat
|
|
response = requests.post(
|
|
f"{self.endpoint}/api/chat",
|
|
json=payload,
|
|
timeout=self.timeout
|
|
)
|
|
|
|
if response.status_code == 200:
|
|
result = response.json()
|
|
content = result.get("message", {}).get("content", "")
|
|
return {
|
|
"response": content,
|
|
"success": True,
|
|
"error": None
|
|
}
|
|
else:
|
|
return {
|
|
"response": "",
|
|
"success": False,
|
|
"error": f"HTTP {response.status_code}: {response.text}"
|
|
}
|
|
|
|
except Exception as e:
|
|
return {
|
|
"response": "",
|
|
"success": False,
|
|
"error": str(e)
|
|
}
|
|
|
|
def detect_ui_elements(self, image_path: str) -> Dict[str, Any]:
|
|
"""
|
|
Détecter les éléments UI dans une image
|
|
|
|
Args:
|
|
image_path: Chemin vers le screenshot
|
|
|
|
Returns:
|
|
Dict avec liste d'éléments détectés
|
|
"""
|
|
prompt = """Analyze this screenshot and list all interactive UI elements you can see.
|
|
For each element, provide:
|
|
- Type (button, text_input, checkbox, radio, dropdown, tab, link, icon, table_row, menu_item)
|
|
- Position (approximate x, y coordinates)
|
|
- Label or text content
|
|
- Semantic role (primary_action, cancel, submit, form_input, search_field, navigation, settings, close)
|
|
|
|
Format your response as JSON."""
|
|
|
|
result = self.generate(prompt, image_path=image_path, temperature=0.1)
|
|
|
|
if result["success"]:
|
|
try:
|
|
# Parser la réponse JSON
|
|
elements = json.loads(result["response"])
|
|
return {"elements": elements, "success": True}
|
|
except json.JSONDecodeError:
|
|
# Si pas JSON valide, retourner texte brut
|
|
return {"elements": [], "success": False, "raw_response": result["response"]}
|
|
|
|
return {"elements": [], "success": False, "error": result["error"]}
|
|
|
|
def classify_element_type(self,
|
|
element_image: Image.Image,
|
|
context: Optional[str] = None) -> Dict[str, Any]:
|
|
"""
|
|
Classifier le type d'un élément UI
|
|
|
|
Args:
|
|
element_image: Image de l'élément
|
|
context: Contexte additionnel
|
|
|
|
Returns:
|
|
Dict avec 'type' et 'confidence'
|
|
"""
|
|
types_list = "button, text_input, checkbox, radio, dropdown, tab, link, icon, table_row, menu_item"
|
|
|
|
prompt = f"""What type of UI element is this?
|
|
Choose ONLY ONE from: {types_list}
|
|
|
|
Respond with just the type name, nothing else."""
|
|
|
|
if context:
|
|
prompt += f"\n\nContext: {context}"
|
|
|
|
result = self.generate(prompt, image=element_image, temperature=0.1)
|
|
|
|
if result["success"]:
|
|
element_type = result["response"].strip().lower()
|
|
# Valider que c'est un type connu
|
|
valid_types = types_list.split(", ")
|
|
if element_type in valid_types:
|
|
return {"type": element_type, "confidence": 0.9, "success": True}
|
|
else:
|
|
# Essayer de trouver le type le plus proche
|
|
for vtype in valid_types:
|
|
if vtype in element_type:
|
|
return {"type": vtype, "confidence": 0.7, "success": True}
|
|
|
|
return {"type": "unknown", "confidence": 0.0, "success": False}
|
|
|
|
def classify_element_role(self,
|
|
element_image: Image.Image,
|
|
element_type: str,
|
|
context: Optional[str] = None) -> Dict[str, Any]:
|
|
"""
|
|
Classifier le rôle sémantique d'un élément
|
|
|
|
Args:
|
|
element_image: Image de l'élément
|
|
element_type: Type de l'élément
|
|
context: Contexte additionnel
|
|
|
|
Returns:
|
|
Dict avec 'role' et 'confidence'
|
|
"""
|
|
roles_list = "primary_action, cancel, submit, form_input, search_field, navigation, settings, close, delete, edit, save"
|
|
|
|
prompt = f"""This is a {element_type}. What is its semantic role or purpose?
|
|
Choose ONLY ONE from: {roles_list}
|
|
|
|
Respond with just the role name, nothing else."""
|
|
|
|
if context:
|
|
prompt += f"\n\nContext: {context}"
|
|
|
|
result = self.generate(prompt, image=element_image, temperature=0.1)
|
|
|
|
if result["success"]:
|
|
role = result["response"].strip().lower()
|
|
# Valider que c'est un rôle connu
|
|
valid_roles = roles_list.split(", ")
|
|
if role in valid_roles:
|
|
return {"role": role, "confidence": 0.9, "success": True}
|
|
else:
|
|
# Essayer de trouver le rôle le plus proche
|
|
for vrole in valid_roles:
|
|
if vrole in role:
|
|
return {"role": vrole, "confidence": 0.7, "success": True}
|
|
|
|
return {"role": "unknown", "confidence": 0.0, "success": False}
|
|
|
|
def extract_text(self, image: Image.Image) -> Dict[str, Any]:
|
|
"""
|
|
Extraire le texte d'une image
|
|
|
|
Args:
|
|
image: Image PIL
|
|
|
|
Returns:
|
|
Dict avec 'text' extrait
|
|
"""
|
|
prompt = "Extract all visible text from this image. Return only the text, nothing else."
|
|
|
|
result = self.generate(prompt, image=image, temperature=0.1)
|
|
|
|
if result["success"]:
|
|
return {"text": result["response"].strip(), "success": True}
|
|
|
|
return {"text": "", "success": False, "error": result["error"]}
|
|
|
|
# Taille minimum pour une classification fiable par le VLM
|
|
_MIN_CLASSIFY_SIZE = 40
|
|
|
|
def classify_element_complete(self, element_image: Image.Image) -> Dict[str, Any]:
|
|
"""
|
|
Classifier complètement un élément UI en UN SEUL appel VLM.
|
|
|
|
Optimisations :
|
|
- Skip les crops < 40px (le VLM ne peut rien en tirer)
|
|
- Retry 1 fois si réponse vide
|
|
- Extraction JSON robuste (cherche {…} même dans du texte)
|
|
"""
|
|
# Skip les images trop petites — deviner par la forme
|
|
if (element_image.width < self._MIN_CLASSIFY_SIZE
|
|
or element_image.height < self._MIN_CLASSIFY_SIZE):
|
|
ratio = element_image.width / max(element_image.height, 1)
|
|
if ratio > 3:
|
|
guessed = "link"
|
|
elif element_image.width < 24 and element_image.height < 24:
|
|
guessed = "icon"
|
|
else:
|
|
guessed = "button"
|
|
return {
|
|
"type": guessed, "role": "unknown", "text": "",
|
|
"confidence": 0.3, "success": True,
|
|
}
|
|
|
|
prompt = """Classify this UI element. Reply with ONLY a JSON object.
|
|
Types: button, text_input, checkbox, radio, dropdown, tab, link, icon, table_row, menu_item
|
|
Roles: primary_action, cancel, submit, form_input, search_field, navigation, settings, close, delete, edit, save
|
|
Example: {"type": "button", "role": "submit", "text": "OK"}
|
|
Your answer:"""
|
|
|
|
# Retry une fois si réponse vide
|
|
for attempt in range(2):
|
|
result = self.generate(
|
|
prompt,
|
|
image=element_image,
|
|
temperature=0.1,
|
|
max_tokens=200,
|
|
force_json=False
|
|
)
|
|
|
|
if not result["success"]:
|
|
continue
|
|
|
|
response_text = result["response"].strip()
|
|
if not response_text:
|
|
if attempt == 0:
|
|
continue
|
|
break
|
|
|
|
# Extraction JSON robuste
|
|
parsed = self._extract_json_from_response(response_text)
|
|
if parsed is not None:
|
|
return self._validate_classification(parsed)
|
|
|
|
return {
|
|
"type": "unknown", "role": "unknown", "text": "",
|
|
"confidence": 0.0, "success": False,
|
|
"error": "VLM returned empty or unparseable response"
|
|
}
|
|
|
|
def _extract_json_from_response(self, text: str) -> Optional[Dict]:
|
|
"""Extrait un objet JSON d'une réponse VLM, même si entouré de texte."""
|
|
import re as _re
|
|
|
|
# Nettoyer le markdown
|
|
if "```" in text:
|
|
lines = text.split("\n")
|
|
text = "\n".join([l for l in lines if not l.startswith("```")])
|
|
text = text.strip()
|
|
|
|
# Essai 1 : parse direct
|
|
try:
|
|
return json.loads(text)
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
# Essai 2 : trouver {…} dans le texte
|
|
match = _re.search(r'\{[^{}]+\}', text)
|
|
if match:
|
|
try:
|
|
return json.loads(match.group())
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
# Essai 3 : fixer les single quotes
|
|
fixed = text.replace("'", '"')
|
|
match = _re.search(r'\{[^{}]+\}', fixed)
|
|
if match:
|
|
try:
|
|
return json.loads(match.group())
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
logger.debug(f"Cannot extract JSON from VLM response: {text[:100]}")
|
|
return None
|
|
|
|
def _validate_classification(self, data: Dict) -> Dict[str, Any]:
|
|
"""Valide et normalise un résultat de classification."""
|
|
valid_types = ["button", "text_input", "checkbox", "radio", "dropdown",
|
|
"tab", "link", "icon", "table_row", "menu_item"]
|
|
valid_roles = ["primary_action", "cancel", "submit", "form_input",
|
|
"search_field", "navigation", "settings", "close",
|
|
"delete", "edit", "save"]
|
|
|
|
elem_type = str(data.get("type", "unknown")).lower().strip()
|
|
elem_role = str(data.get("role", "unknown")).lower().strip()
|
|
elem_text = str(data.get("text", ""))
|
|
|
|
if elem_type not in valid_types:
|
|
elem_type = "unknown"
|
|
if elem_role not in valid_roles:
|
|
elem_role = "unknown"
|
|
|
|
return {
|
|
"type": elem_type, "role": elem_role, "text": elem_text,
|
|
"confidence": 0.85, "success": True
|
|
}
|
|
|
|
def _encode_image_from_path(self, image_path: str) -> str:
|
|
"""Encoder une image depuis un fichier en base64"""
|
|
with open(image_path, 'rb') as f:
|
|
return base64.b64encode(f.read()).decode('utf-8')
|
|
|
|
def _encode_image_from_pil(self, image: Image.Image) -> str:
|
|
"""Encoder une image PIL en base64 avec prétraitement optimisé"""
|
|
# 1. Convertir en RGB si nécessaire (évite erreurs PNG transparent)
|
|
if image.mode != 'RGB':
|
|
image = image.convert('RGB')
|
|
|
|
# 1b. Minimum 32x32 (requis par qwen3-vl, sinon Ollama panic)
|
|
min_size = 32
|
|
if image.width < min_size or image.height < min_size:
|
|
new_w = max(image.width, min_size)
|
|
new_h = max(image.height, min_size)
|
|
image = image.resize((new_w, new_h), Image.NEAREST)
|
|
|
|
# 2. Redimensionnement intelligent : max 1280px sur le côté long
|
|
max_size = 1280
|
|
if max(image.size) > max_size:
|
|
ratio = max_size / max(image.size)
|
|
new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
|
|
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
|
|
|
# 3. Sauvegarder en JPEG qualité 90 (plus léger, meilleur pour VLM)
|
|
buffer = io.BytesIO()
|
|
image.save(buffer, format='JPEG', quality=90)
|
|
return base64.b64encode(buffer.getvalue()).decode('utf-8')
|
|
|
|
def list_models(self) -> List[str]:
|
|
"""Lister les modèles disponibles dans Ollama"""
|
|
try:
|
|
response = requests.get(f"{self.endpoint}/api/tags", timeout=5)
|
|
if response.status_code == 200:
|
|
models = response.json().get('models', [])
|
|
return [m['name'] for m in models]
|
|
except Exception as e:
|
|
logger.error(f"Error listing models: {e}")
|
|
return []
|
|
|
|
def pull_model(self, model_name: str) -> bool:
|
|
"""
|
|
Télécharger un modèle dans Ollama
|
|
|
|
Args:
|
|
model_name: Nom du modèle à télécharger
|
|
|
|
Returns:
|
|
True si succès
|
|
"""
|
|
try:
|
|
logger.info(f"Pulling model {model_name}...")
|
|
response = requests.post(
|
|
f"{self.endpoint}/api/pull",
|
|
json={"name": model_name},
|
|
stream=True,
|
|
timeout=600
|
|
)
|
|
|
|
if response.status_code == 200:
|
|
for line in response.iter_lines():
|
|
if line:
|
|
data = json.loads(line)
|
|
if 'status' in data:
|
|
logger.info(f" {data['status']}")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Error pulling model: {e}")
|
|
return False
|
|
|
|
|
|
# ============================================================================
|
|
# Fonctions utilitaires
|
|
# ============================================================================
|
|
|
|
def create_ollama_client(model: str = "qwen3-vl:8b",
                         endpoint: str = "http://localhost:11434") -> OllamaClient:
    """Build a configured OllamaClient.

    Args:
        model: VLM model name.
        endpoint: Ollama API URL.

    Returns:
        A ready-to-use OllamaClient.
    """
    return OllamaClient(endpoint=endpoint, model=model)
|
|
|
|
|
|
def check_ollama_available(endpoint: str = "http://localhost:11434") -> bool:
    """Return True when an Ollama server answers at *endpoint*.

    Args:
        endpoint: Ollama API URL.
    """
    try:
        resp = requests.get(f"{endpoint}/api/tags", timeout=5)
    except (requests.RequestException, ConnectionError, TimeoutError):
        return False
    return resp.status_code == 200
|