feat(grounding): InfiGUI-G1-3B remplace UI-TARS 7B — 3.5x moins de VRAM
Serveur de grounding (server.py) : - InfiGUI-G1-3B au lieu de UI-TARS-1.5-7B - VRAM : 2.25 GB au lieu de 8.4 GB (6.6 GB libres) - Prompt officiel InfiGUI (system <think> + user point_2d JSON) - max_new_tokens=512, parsing JSON point_2d - 4/4 éléments trouvés : Demo 5px, Chrome 98px, Corbeille 15px, Search 66px - Fallback UI-TARS via env GROUNDING_MODEL=ByteDance-Seed/UI-TARS-1.5-7B EasyOCR : retour sur GPU (assez de VRAM maintenant) → 192ms au lieu de 2.5s Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -32,9 +32,9 @@ import uvicorn
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
PORT = int(os.environ.get("GROUNDING_PORT", 8200))
|
PORT = int(os.environ.get("GROUNDING_PORT", 8200))
|
||||||
MODEL_ID = "ByteDance-Seed/UI-TARS-1.5-7B"
|
MODEL_ID = os.environ.get("GROUNDING_MODEL", "InfiX-ai/InfiGUI-G1-3B")
|
||||||
MIN_PIXELS = 100 * 28 * 28
|
MIN_PIXELS = 100 * 28 * 28
|
||||||
MAX_PIXELS = 16384 * 28 * 28
|
MAX_PIXELS = 5600 * 28 * 28 # InfiGUI recommande 5600*28*28
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Smart resize — identique a /tmp/test_uitars.py
|
# Smart resize — identique a /tmp/test_uitars.py
|
||||||
@@ -57,23 +57,11 @@ def _smart_resize(height: int, width: int, factor: int = 28,
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Prompt officiel UI-TARS — identique a /tmp/test_uitars.py
|
# Prompts — InfiGUI-G1-3B (format officiel de la doc HuggingFace)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
_GROUNDING_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
_SYSTEM_PROMPT = """You FIRST think about the reasoning process as an internal monologue and then provide the final answer.
|
||||||
|
The reasoning process MUST BE enclosed within <think> </think> tags."""
|
||||||
## Output Format
|
|
||||||
|
|
||||||
Thought: ...
|
|
||||||
Action: ...
|
|
||||||
|
|
||||||
|
|
||||||
## Action Space
|
|
||||||
click(start_box='(x1, y1)')
|
|
||||||
|
|
||||||
|
|
||||||
## User Instruction
|
|
||||||
{instruction}"""
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -121,7 +109,7 @@ def _evict_ollama_models():
|
|||||||
|
|
||||||
|
|
||||||
def _load_model():
|
def _load_model():
|
||||||
"""Charge UI-TARS-1.5-7B en 4-bit NF4 — code identique a /tmp/test_uitars.py."""
|
"""Charge le modele de grounding en 4-bit NF4."""
|
||||||
global _model, _processor, _model_loaded
|
global _model, _processor, _model_loaded
|
||||||
|
|
||||||
if _model_loaded:
|
if _model_loaded:
|
||||||
@@ -161,6 +149,7 @@ def _load_model():
|
|||||||
MODEL_ID,
|
MODEL_ID,
|
||||||
min_pixels=MIN_PIXELS,
|
min_pixels=MIN_PIXELS,
|
||||||
max_pixels=MAX_PIXELS,
|
max_pixels=MAX_PIXELS,
|
||||||
|
padding_side="left",
|
||||||
)
|
)
|
||||||
|
|
||||||
_model_loaded = True
|
_model_loaded = True
|
||||||
@@ -292,7 +281,7 @@ def ground(req: GroundRequest):
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from qwen_vl_utils import process_vision_info
|
from qwen_vl_utils import process_vision_info
|
||||||
|
|
||||||
# Construire l'instruction
|
# Construire la description de la cible
|
||||||
parts = []
|
parts = []
|
||||||
if req.target_text:
|
if req.target_text:
|
||||||
parts.append(req.target_text)
|
parts.append(req.target_text)
|
||||||
@@ -301,7 +290,7 @@ def ground(req: GroundRequest):
|
|||||||
if not parts:
|
if not parts:
|
||||||
raise HTTPException(status_code=400, detail="target_text ou target_description requis")
|
raise HTTPException(status_code=400, detail="target_text ou target_description requis")
|
||||||
|
|
||||||
instruction = f"Click on the {' — '.join(parts)}"
|
target_label = ' — '.join(parts)
|
||||||
|
|
||||||
# Obtenir l'image (fournie en b64 ou capture ecran)
|
# Obtenir l'image (fournie en b64 ou capture ecran)
|
||||||
if req.image_b64:
|
if req.image_b64:
|
||||||
@@ -319,30 +308,23 @@ def ground(req: GroundRequest):
|
|||||||
W, H = screen_pil.size
|
W, H = screen_pil.size
|
||||||
rH, rW = _smart_resize(H, W, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS)
|
rH, rW = _smart_resize(H, W, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS)
|
||||||
|
|
||||||
# Sauver temporairement l'image pour qwen_vl_utils
|
|
||||||
import tempfile
|
|
||||||
tmp_path = os.path.join(tempfile.gettempdir(), f"grounding_screen_{os.getpid()}.png")
|
|
||||||
screen_pil.save(tmp_path)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
system_prompt = _GROUNDING_PROMPT.format(instruction=instruction)
|
import json as _json
|
||||||
|
|
||||||
|
# Prompt officiel InfiGUI-G1-3B (doc HuggingFace)
|
||||||
|
user_text = (
|
||||||
|
f'The screen\'s resolution is {rW}x{rH}.\n'
|
||||||
|
f'Locate the UI element(s) for "{target_label}", '
|
||||||
|
f'output the coordinates using JSON format: '
|
||||||
|
f'[{{"point_2d": [x, y]}}, ...]'
|
||||||
|
)
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{"role": "system", "content": _SYSTEM_PROMPT},
|
||||||
"role": "user",
|
{"role": "user", "content": [
|
||||||
"content": [
|
{"type": "image", "image": screen_pil},
|
||||||
{
|
{"type": "text", "text": user_text},
|
||||||
"type": "image",
|
]},
|
||||||
"image": f"file://{tmp_path}",
|
|
||||||
"min_pixels": MIN_PIXELS,
|
|
||||||
"max_pixels": MAX_PIXELS,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": system_prompt,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
]
|
]
|
||||||
|
|
||||||
text = _processor.apply_chat_template(
|
text = _processor.apply_chat_template(
|
||||||
@@ -360,7 +342,7 @@ def ground(req: GroundRequest):
|
|||||||
# Inference
|
# Inference
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
gen = _model.generate(**inputs, max_new_tokens=64)
|
gen = _model.generate(**inputs, max_new_tokens=512)
|
||||||
infer_ms = (time.time() - t0) * 1000
|
infer_ms = (time.time() - t0) * 1000
|
||||||
|
|
||||||
# Decoder
|
# Decoder
|
||||||
@@ -369,46 +351,56 @@ def ground(req: GroundRequest):
|
|||||||
trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||||
)[0].strip()
|
)[0].strip()
|
||||||
|
|
||||||
print(f"[grounding-server] '{instruction}' -> raw='{raw[:150]}' ({infer_ms:.0f}ms)")
|
print(f"[grounding-server] '{target_label}' -> raw='{raw[:150]}' ({infer_ms:.0f}ms)")
|
||||||
|
|
||||||
# Détecter les réponses négatives (le modèle dit qu'il ne voit pas l'élément)
|
# Parser le JSON InfiGUI : split sur </think>, extraire point_2d
|
||||||
|
px, py = None, None
|
||||||
|
json_part = raw.split("</think>")[-1] if "</think>" in raw else raw
|
||||||
|
json_part = json_part.replace("```json", "").replace("```", "").strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = _json.loads(json_part)
|
||||||
|
if isinstance(data, list) and len(data) > 0:
|
||||||
|
pt = data[0].get("point_2d", [])
|
||||||
|
if len(pt) >= 2:
|
||||||
|
# Coordonnées en pixels resizés → convertir en pixels originaux
|
||||||
|
px = int(pt[0] * W / rW)
|
||||||
|
py = int(pt[1] * H / rH)
|
||||||
|
except _json.JSONDecodeError:
|
||||||
|
# Fallback regex
|
||||||
|
m = re.search(r'"point_2d"\s*:\s*\[(\d+),\s*(\d+)\]', raw)
|
||||||
|
if m:
|
||||||
|
px = int(int(m.group(1)) * W / rW)
|
||||||
|
py = int(int(m.group(2)) * H / rH)
|
||||||
|
|
||||||
|
if px is None:
|
||||||
|
# Détection réponses négatives
|
||||||
_raw_lower = raw.lower()
|
_raw_lower = raw.lower()
|
||||||
_negative_markers = ["don't see", "do not see", "cannot find", "can't find",
|
for _neg in ["don't see", "cannot find", "not visible", "not found",
|
||||||
"not visible", "not found", "doesn't appear", "does not appear",
|
"unable to find", "unable to locate", "does not appear"]:
|
||||||
"i don't", "unable to find", "unable to locate"]
|
|
||||||
for _neg in _negative_markers:
|
|
||||||
if _neg in _raw_lower:
|
if _neg in _raw_lower:
|
||||||
print(f"[grounding-server] NÉGATIF détecté: '{_neg}' → élément non trouvé")
|
print(f"[grounding-server] NÉGATIF: '{_neg}'")
|
||||||
return GroundResponse(x=None, y=None, method="ui_tars", confidence=0.0,
|
return GroundResponse(x=None, y=None, method="infigui",
|
||||||
time_ms=round(infer_ms, 1), raw_output=raw[:300])
|
confidence=0.0, time_ms=round(infer_ms, 1),
|
||||||
|
raw_output=raw[:300])
|
||||||
|
|
||||||
# Parser les coordonnees
|
print(f"[grounding-server] Coordonnées non parsées: {json_part[:100]}")
|
||||||
parsed = _parse_coordinates(raw, W, H, rW, rH)
|
return GroundResponse(x=None, y=None, method="infigui",
|
||||||
if parsed is None:
|
confidence=0.0, time_ms=round(infer_ms, 1),
|
||||||
raise HTTPException(
|
raw_output=raw[:300])
|
||||||
status_code=422,
|
|
||||||
detail=f"Coordonnees non parsees dans la reponse: {raw[:200]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
px, py, method_detail, confidence = parsed
|
confidence = 0.90
|
||||||
|
print(f"[grounding-server] Résultat: ({px}, {py}) conf={confidence:.2f} ({infer_ms:.0f}ms)")
|
||||||
print(f"[grounding-server] Resultat: ({px}, {py}) conf={confidence:.2f} "
|
|
||||||
f"[{method_detail}] ({infer_ms:.0f}ms)")
|
|
||||||
|
|
||||||
return GroundResponse(
|
return GroundResponse(
|
||||||
x=px,
|
x=px, y=py, method="infigui",
|
||||||
y=py,
|
confidence=confidence, time_ms=round(infer_ms, 1),
|
||||||
method="ui_tars",
|
|
||||||
confidence=confidence,
|
|
||||||
time_ms=round(infer_ms, 1),
|
|
||||||
raw_output=raw[:300],
|
raw_output=raw[:300],
|
||||||
)
|
)
|
||||||
|
|
||||||
finally:
|
except Exception as e:
|
||||||
try:
|
print(f"[grounding-server] ERREUR: {e}")
|
||||||
os.unlink(tmp_path)
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user