"""Minimal grounding server — single-threaded Flask, same CUDA context.

Exposes two HTTP endpoints:

* ``GET /health``  — reports model / CUDA status.
* ``POST /ground`` — locates a UI element in a screenshot and returns its
  pixel coordinates in the original image space.
"""
import base64
import gc
import io
import json
import math
import os
import re
import time

import torch
from flask import Flask, request, jsonify
from PIL import Image

app = Flask(__name__)

MODEL_ID = os.environ.get("GROUNDING_MODEL", "InfiX-ai/InfiGUI-G1-3B")
# Pixel budget handed to the processor, expressed in 28x28 patches.
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 5600 * 28 * 28

# Lazily populated by load_model().
_model = None
_processor = None

# Marker that closes the model's reasoning section; the answer follows it.
_THINK_CLOSE = "</think>"
# Fallback extractor used when the answer is not valid JSON.
_POINT_RE = re.compile(
    r'"point_2d"\s*:\s*\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]'
)


def _smart_resize(h, w, factor=28):
    """Return ``(height, width)`` rounded to multiples of *factor*.

    The result is rescaled (preserving aspect ratio) so that its area stays
    within ``[MIN_PIXELS, MAX_PIXELS]``. This size is what ``ground`` reports
    to the model as the screen resolution and uses to rescale predictions.
    """
    h_bar = max(factor, round(h / factor) * factor)
    w_bar = max(factor, round(w / factor) * factor)
    if h_bar * w_bar > MAX_PIXELS:
        beta = math.sqrt((h * w) / MAX_PIXELS)
        h_bar = math.floor(h / beta / factor) * factor
        w_bar = math.floor(w / beta / factor) * factor
    elif h_bar * w_bar < MIN_PIXELS:
        beta = math.sqrt(MIN_PIXELS / (h * w))
        h_bar = math.ceil(h * beta / factor) * factor
        w_bar = math.ceil(w * beta / factor) * factor
    return h_bar, w_bar


def load_model():
    """Load the 4-bit quantized model and its processor once.

    Subsequent calls are no-ops. Populates the module-level ``_model`` and
    ``_processor`` globals.
    """
    global _model, _processor
    if _model is not None:
        return
    from transformers import (
        Qwen2_5_VLForConditionalGeneration,
        AutoProcessor,
        BitsAndBytesConfig,
    )

    # Collect garbage first so that blocks held by dead tensors can actually
    # be returned by empty_cache().
    gc.collect()
    torch.cuda.empty_cache()
    print(f"[grounding] Chargement {MODEL_ID}...")
    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    _model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID, quantization_config=bnb, device_map="auto"
    )
    _model.eval()
    _processor = AutoProcessor.from_pretrained(
        MODEL_ID,
        min_pixels=MIN_PIXELS,
        max_pixels=MAX_PIXELS,
        padding_side="left",
    )
    print(f"[grounding] Prêt — VRAM: {torch.cuda.memory_allocated()/1e9:.2f}GB")


@app.route('/health')
def health():
    """Return a JSON status report (model id, load state, CUDA, VRAM)."""
    return jsonify({
        "status": "ok",
        "model": MODEL_ID,
        "model_loaded": _model is not None,
        "cuda_available": torch.cuda.is_available(),
        "vram_allocated_gb": round(torch.cuda.memory_allocated() / 1e9, 2),
    })


def _extract_point(raw, width, height, resized_w, resized_h):
    """Parse the first ``point_2d`` from *raw* and map it to image pixels.

    The model answers in the resized coordinate space
    (``resized_w`` x ``resized_h``); the result is rescaled to the original
    ``width`` x ``height``. Returns ``(x, y)`` as ints, or ``(None, None)``
    when no usable point is found.
    """
    # Keep only what follows the reasoning section, then drop code fences.
    answer = raw.split(_THINK_CLOSE)[-1] if _THINK_CLOSE in raw else raw
    answer = answer.replace("```json", "").replace("```", "").strip()

    coords = None
    try:
        parsed = json.loads(answer)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, list) and parsed and isinstance(parsed[0], dict):
        pt = parsed[0].get("point_2d")
        if isinstance(pt, (list, tuple)) and len(pt) >= 2:
            x, y = pt[0], pt[1]
            numeric = all(
                isinstance(v, (int, float))
                and not isinstance(v, bool)
                and math.isfinite(v)
                for v in (x, y)
            )
            if numeric:
                coords = (x, y)

    if coords is None:
        # Malformed or unexpected JSON: fall back to a regex over the raw text.
        m = _POINT_RE.search(raw)
        if m:
            coords = (float(m.group(1)), float(m.group(2)))

    if coords is None:
        return None, None
    return (
        int(coords[0] * width / resized_w),
        int(coords[1] * height / resized_h),
    )


@app.route('/ground', methods=['POST'])
def ground():
    """Locate a UI element and return its coordinates as JSON.

    Request JSON fields:
        target_text: text of the element to find (required unless a
            description alone yields a non-empty label).
        target_description: optional extra description.
        image_b64: optional base64 image (raw or ``data:`` URL); when absent
            the screen is captured with ``mss``.

    Responses:
        200 with ``x``, ``y`` (``null`` if nothing was parsed), ``method``,
        ``confidence``, ``time_ms`` and ``raw_output``;
        400 on a missing label or undecodable image;
        503 when the model has not been loaded.
    """
    if _model is None:
        return jsonify({"error": "Modèle pas chargé"}), 503
    from qwen_vl_utils import process_vision_info

    # silent=True: a missing / invalid JSON body becomes a clean 400 below
    # instead of an unhandled exception.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    target = str(data.get('target_text') or '')
    desc = str(data.get('target_description') or '')
    label = f"{target} — {desc}" if desc else target
    if not label.strip():
        return jsonify({"error": "target_text requis"}), 400

    # Image: either supplied by the caller or grabbed from the screen.
    image_b64 = data.get('image_b64')
    if image_b64:
        if not isinstance(image_b64, str):
            return jsonify({"error": "image_b64 invalide"}), 400
        # Strip an optional "data:image/...;base64," prefix.
        payload = image_b64.split(',', 1)[1] if ',' in image_b64 else image_b64
        try:
            img = Image.open(io.BytesIO(base64.b64decode(payload))).convert('RGB')
        except (ValueError, OSError):
            # binascii.Error is a ValueError; PIL decode failures are OSError.
            return jsonify({"error": "image_b64 invalide"}), 400
    else:
        import mss
        with mss.mss() as sct:
            grab = sct.grab(sct.monitors[0])
            img = Image.frombytes('RGB', grab.size, grab.bgra, 'raw', 'BGRX')

    W, H = img.size
    rH, rW = _smart_resize(H, W)

    user_text = f'The screen\'s resolution is {rW}x{rH}.\nLocate the UI element(s) for "{label}", output the coordinates using JSON format: [{{"point_2d": [x, y]}}, ...]'
    # NOTE(review): the <think> tags were missing from the checked-in prompt
    # (apparently stripped as markup); restored here — confirm the wording
    # against the model card.
    system = (
        "You FIRST think about the reasoning process as an internal monologue "
        "and then provide the final answer.\n"
        "The reasoning process MUST BE enclosed within <think> </think> tags."
    )
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": [
            {"type": "image", "image": img},
            {"type": "text", "text": user_text},
        ]},
    ]
    text = _processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = _processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(_model.device)

    t0 = time.time()
    with torch.no_grad():
        gen = _model.generate(**inputs, max_new_tokens=512)
    infer_ms = (time.time() - t0) * 1000

    # Drop the prompt tokens so only the newly generated text is decoded.
    trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, gen)]
    raw = _processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0].strip()
    print(f"[grounding] '{label[:40]}' → {raw[:100]} ({infer_ms:.0f}ms)")

    px, py = _extract_point(raw, W, H, rW, rH)

    return jsonify({
        "x": px,
        "y": py,
        "method": "infigui",
        # `is not None`: x == 0 is a valid coordinate, not a miss.
        "confidence": 0.90 if px is not None else 0.0,
        "time_ms": round(infer_ms, 1),
        "raw_output": raw[:300],
    })


if __name__ == '__main__':
    load_model()
    # NOTE(review): binds on all interfaces with no authentication — confirm
    # this is intended for the deployment environment.
    app.run(host='0.0.0.0', port=8200, threaded=False)