diff --git a/core/grounding/server.py b/core/grounding/server.py index 532f7b5e4..dc685621f 100644 --- a/core/grounding/server.py +++ b/core/grounding/server.py @@ -32,9 +32,9 @@ import uvicorn # --------------------------------------------------------------------------- PORT = int(os.environ.get("GROUNDING_PORT", 8200)) -MODEL_ID = "ByteDance-Seed/UI-TARS-1.5-7B" +MODEL_ID = os.environ.get("GROUNDING_MODEL", "InfiX-ai/InfiGUI-G1-3B") MIN_PIXELS = 100 * 28 * 28 -MAX_PIXELS = 16384 * 28 * 28 +MAX_PIXELS = 5600 * 28 * 28 # InfiGUI recommande 5600*28*28 # --------------------------------------------------------------------------- # Smart resize — identique a /tmp/test_uitars.py @@ -57,23 +57,11 @@ def _smart_resize(height: int, width: int, factor: int = 28, # --------------------------------------------------------------------------- -# Prompt officiel UI-TARS — identique a /tmp/test_uitars.py +# Prompts — InfiGUI-G1-3B (format officiel de la doc HuggingFace) # --------------------------------------------------------------------------- -_GROUNDING_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. - -## Output Format - -Thought: ... -Action: ... - - -## Action Space -click(start_box='(x1, y1)') - - -## User Instruction -{instruction}""" +_SYSTEM_PROMPT = """You FIRST think about the reasoning process as an internal monologue and then provide the final answer. 
+The reasoning process MUST BE enclosed within <think> </think> tags.""" # --------------------------------------------------------------------------- @@ -121,7 +109,7 @@ def _evict_ollama_models(): def _load_model(): - """Charge UI-TARS-1.5-7B en 4-bit NF4 — code identique a /tmp/test_uitars.py.""" + """Charge le modele de grounding en 4-bit NF4.""" global _model, _processor, _model_loaded if _model_loaded: @@ -161,6 +149,7 @@ def _load_model(): MODEL_ID, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS, + padding_side="left", ) _model_loaded = True @@ -292,7 +281,7 @@ def ground(req: GroundRequest): from PIL import Image from qwen_vl_utils import process_vision_info - # Construire l'instruction + # Construire la description de la cible parts = [] if req.target_text: parts.append(req.target_text) @@ -301,7 +290,7 @@ if not parts: raise HTTPException(status_code=400, detail="target_text ou target_description requis") - instruction = f"Click on the {' — '.join(parts)}" + target_label = ' — '.join(parts) # Obtenir l'image (fournie en b64 ou capture ecran) if req.image_b64: @@ -319,30 +308,23 @@ W, H = screen_pil.size rH, rW = _smart_resize(H, W, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS) - # Sauver temporairement l'image pour qwen_vl_utils - import tempfile - tmp_path = os.path.join(tempfile.gettempdir(), f"grounding_screen_{os.getpid()}.png") - screen_pil.save(tmp_path) - try: - system_prompt = _GROUNDING_PROMPT.format(instruction=instruction) + import json as _json + + # Prompt officiel InfiGUI-G1-3B (doc HuggingFace) + user_text = ( + f'The screen\'s resolution is {rW}x{rH}.\n' + f'Locate the UI element(s) for "{target_label}", ' + f'output the coordinates using JSON format: ' + f'[{{"point_2d": [x, y]}}, ...]' + ) messages = [ - { - "role": "user", - "content": [ - { - "type": "image", - "image": f"file://{tmp_path}", - "min_pixels": MIN_PIXELS, - "max_pixels": MAX_PIXELS, - }, - { - "type": "text", - "text": 
system_prompt, - }, - ], - } + {"role": "system", "content": _SYSTEM_PROMPT}, + {"role": "user", "content": [ + {"type": "image", "image": screen_pil}, + {"type": "text", "text": user_text}, + ]}, ] text = _processor.apply_chat_template( @@ -360,7 +342,7 @@ # Inference t0 = time.time() with torch.no_grad(): - gen = _model.generate(**inputs, max_new_tokens=64) + gen = _model.generate(**inputs, max_new_tokens=512) infer_ms = (time.time() - t0) * 1000 # Decoder @@ -369,46 +351,56 @@ trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0].strip() - print(f"[grounding-server] '{instruction}' -> raw='{raw[:150]}' ({infer_ms:.0f}ms)") + print(f"[grounding-server] '{target_label}' -> raw='{raw[:150]}' ({infer_ms:.0f}ms)") - # Détecter les réponses négatives (le modèle dit qu'il ne voit pas l'élément) - _raw_lower = raw.lower() - _negative_markers = ["don't see", "do not see", "cannot find", "can't find", - "not visible", "not found", "doesn't appear", "does not appear", - "i don't", "unable to find", "unable to locate"] - for _neg in _negative_markers: - if _neg in _raw_lower: - print(f"[grounding-server] NÉGATIF détecté: '{_neg}' → élément non trouvé") - return GroundResponse(x=None, y=None, method="ui_tars", confidence=0.0, - time_ms=round(infer_ms, 1), raw_output=raw[:300]) + # Parser le JSON InfiGUI : split sur </think>, extraire point_2d + px, py = None, None + json_part = raw.split("</think>")[-1] if "</think>" in raw else raw + json_part = json_part.replace("```json", "").replace("```", "").strip() - # Parser les coordonnees - parsed = _parse_coordinates(raw, W, H, rW, rH) - if parsed is None: - raise HTTPException( - status_code=422, - detail=f"Coordonnees non parsees dans la reponse: {raw[:200]}" - ) + try: + data = _json.loads(json_part) + if isinstance(data, list) and len(data) > 0: + pt = data[0].get("point_2d", []) + if len(pt) >= 2: + # Coordonnées en pixels resizés → convertir en pixels originaux + 
px = int(pt[0] * W / rW) + py = int(pt[1] * H / rH) + except _json.JSONDecodeError: + # Fallback regex + m = re.search(r'"point_2d"\s*:\s*\[(\d+),\s*(\d+)\]', raw) + if m: + px = int(int(m.group(1)) * W / rW) + py = int(int(m.group(2)) * H / rH) - px, py, method_detail, confidence = parsed + if px is None: + # Détection réponses négatives + _raw_lower = raw.lower() + for _neg in ["don't see", "cannot find", "not visible", "not found", + "unable to find", "unable to locate", "does not appear"]: + if _neg in _raw_lower: + print(f"[grounding-server] NÉGATIF: '{_neg}'") + return GroundResponse(x=None, y=None, method="infigui", + confidence=0.0, time_ms=round(infer_ms, 1), + raw_output=raw[:300]) - print(f"[grounding-server] Resultat: ({px}, {py}) conf={confidence:.2f} " - f"[{method_detail}] ({infer_ms:.0f}ms)") + print(f"[grounding-server] Coordonnées non parsées: {json_part[:100]}") + return GroundResponse(x=None, y=None, method="infigui", + confidence=0.0, time_ms=round(infer_ms, 1), + raw_output=raw[:300]) + + confidence = 0.90 + print(f"[grounding-server] Résultat: ({px}, {py}) conf={confidence:.2f} ({infer_ms:.0f}ms)") return GroundResponse( - x=px, - y=py, - method="ui_tars", - confidence=confidence, - time_ms=round(infer_ms, 1), + x=px, y=py, method="infigui", + confidence=confidence, time_ms=round(infer_ms, 1), raw_output=raw[:300], ) - finally: - try: - os.unlink(tmp_path) - except OSError: - pass + except Exception as e: + print(f"[grounding-server] ERREUR: {e}") + raise HTTPException(status_code=500, detail=str(e)) # ---------------------------------------------------------------------------