diff --git a/demo/gradio/app.py b/demo/gradio/app.py
index 1dfb9e4..b54a19a 100644
--- a/demo/gradio/app.py
+++ b/demo/gradio/app.py
@@ -14,6 +14,7 @@ from functools import partial
from pathlib import Path
from typing import cast, Dict
from PIL import Image
+import socket
import gradio as gr
from anthropic import APIResponse
@@ -294,8 +295,10 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
with gr.Column(scale=1):
chatbot = gr.Chatbot(label="Chatbot History", autoscroll=True, height=580)
with gr.Column(scale=3):
+ # Get the fully qualified domain name of the machine
+ machine_fqdn = socket.getfqdn()
iframe = gr.HTML(
- '',
+ f'',
container=False,
elem_classes="no-padding"
)
@@ -361,4 +364,44 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
submit_button.click(process_input, [chat_input, state], chatbot)
-demo.launch(share=False, allowed_paths=["./"], server_port=7888)
+from fastapi import FastAPI
+import uvicorn
+from multiprocessing import Process
+
+app = FastAPI()
+
+# Mount the Gradio app under the "/gradio" path
+app = gr.mount_gradio_app(app, demo, path="/gradio")
+
+# Optional: Add a root endpoint that redirects to the Gradio interface
+@app.get("/")
+async def root():
+ return {"message": "Welcome to OmniParser Demo API",
+ "gradio_interface": "/gradio"}
+
+# Create a second FastAPI app for VNC
+vnc_app = FastAPI()
+
+@vnc_app.get("/")
+async def vnc_root():
+ return {"message": "VNC Server"}
+
+def run_app(app, host, port):
+ uvicorn.run(app, host=host, port=port)
+
+# To run this with uvicorn:
+if __name__ == "__main__":
+ # Start the main app on port 7889
+ p1 = Process(target=run_app, args=(app, "0.0.0.0", 7889))
+ # Start the VNC app on port 8006
+ p2 = Process(target=run_app, args=(vnc_app, "0.0.0.0", 8006))
+
+ p1.start()
+ p2.start()
+
+ try:
+ p1.join()
+ p2.join()
+ except KeyboardInterrupt:
+ p1.terminate()
+ p2.terminate()
\ No newline at end of file
diff --git a/demo/gradio/computer_use_demo/loop.py b/demo/gradio/computer_use_demo/loop.py
index a0bc58e..22754f8 100644
--- a/demo/gradio/computer_use_demo/loop.py
+++ b/demo/gradio/computer_use_demo/loop.py
@@ -113,7 +113,9 @@ def sampling_loop_sync(
)
elif model == "omniparser + gpt-4o" or model == "omniparser + phi35v":
- omniparser = OmniParser(url="http://localhost:8000/send_text/",
+ # omniparser = OmniParser(url="http://localhost:8000/send_text/",
+ # selected_screen=selected_screen,)
+ omniparser = OmniParser(url=None,
selected_screen=selected_screen,)
actor = VLMAgent(
diff --git a/demo/gradio/computer_use_demo/omniparser_agent/omniparser.py b/demo/gradio/computer_use_demo/omniparser_agent/omniparser.py
new file mode 100644
index 0000000..81ca124
--- /dev/null
+++ b/demo/gradio/computer_use_demo/omniparser_agent/omniparser.py
@@ -0,0 +1,77 @@
+# uvicorn remote_request:app --host 0.0.0.0 --port 8000 --reload
+
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from computer_use_demo.omniparser_agent.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
+import torch
+from PIL import Image
+from typing import Dict, Tuple, List
+import base64
+import io
+
+
+# config = {
+# 'som_model_path': '../weights/icon_detect_v1_5/model_v1_5.pt',
+# 'device': 'cpu',
+# 'caption_model_name': 'florence2',
+# 'caption_model_path': '../weights/icon_caption_florence',
+# 'BOX_TRESHOLD': 0.05
+# }
+
+
+class Omniparser(object):
+ def __init__(self, config: Dict):
+ self.config = config
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ self.som_model = get_yolo_model(model_path=config['som_model_path'])
+ self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
+ print('Omniparser initialized!!!')
+
+ def parse(self, image_base64: str):
+ image_path = './demo_image.jpg'
+ with open(image_path, "wb") as fh:
+ fh.write(base64.b64decode(image_base64))
+ print('Parsing image:', image_path)
+
+ image = Image.open(image_path)
+ print('image size:', image.size)
+
+ box_overlay_ratio = max(image.size) / 3200
+ draw_bbox_config = {
+ 'text_scale': 0.8 * box_overlay_ratio,
+ 'text_thickness': max(int(2 * box_overlay_ratio), 1),
+ 'text_padding': max(int(3 * box_overlay_ratio), 1),
+ 'thickness': max(int(3 * box_overlay_ratio), 1),
+ }
+ BOX_TRESHOLD = self.config['BOX_TRESHOLD']
+
+ ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.8}, use_paddleocr=True)
+ text, ocr_bbox = ocr_bbox_rslt
+ dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
+
+ return dino_labled_img, parsed_content_list
+
+
+# from fastapi import FastAPI
+# from pydantic import BaseModel
+
+# app = FastAPI()
+
+# class Item(BaseModel):
+# base64_image: str
+# prompt: str
+
+# Omniparser = Omniparser(config)
+
+# @app.post("/send_text/")
+# async def send_text(item: Item):
+# print('start parsing...')
+# import time
+# start = time.time()
+# dino_labled_img, parsed_content_list = Omniparser.parse(item.base64_image)
+# latency = time.time() - start
+# print('time:', latency)
+# return {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, 'latency': latency}
\ No newline at end of file
diff --git a/demo/gradio/computer_use_demo/omniparser_agent/utils.py b/demo/gradio/computer_use_demo/omniparser_agent/utils.py
new file mode 100755
index 0000000..87c6f83
--- /dev/null
+++ b/demo/gradio/computer_use_demo/omniparser_agent/utils.py
@@ -0,0 +1,935 @@
+# from ultralytics import YOLO
+import os
+import io
+import base64
+import time
+from PIL import Image, ImageDraw, ImageFont
+import json
+import requests
+# utility function
+import os
+from openai import AzureOpenAI
+
+import json
+import sys
+import os
+import cv2
+import numpy as np
+# %matplotlib inline
+from matplotlib import pyplot as plt
+import easyocr
+from paddleocr import PaddleOCR
+reader = easyocr.Reader(['en', 'ch_sim'], gpu=True)
+paddle_ocr = PaddleOCR(
+ lang='en', # other lang also available
+ use_angle_cls=False,
+ use_gpu=False, # using cuda will conflict with pytorch in the same process
+ show_log=False,
+ max_batch_size=1024,
+ use_dilation=True, # improves accuracy
+ det_db_score_mode='slow', # improves accuracy
+ rec_batch_num=1024)
+import time
+import base64
+
+import os
+import ast
+import torch
+from typing import Tuple, List
+from torchvision.ops import box_convert
+import re
+from torchvision.transforms import ToPILImage
+import supervision as sv
+import torchvision.transforms as T
+
+
+def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
+ if not device:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ if model_name == "blip2":
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
+ processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+ if device == 'cpu':
+ model = Blip2ForConditionalGeneration.from_pretrained(
+ model_name_or_path, device_map=None, torch_dtype=torch.float32
+ )
+ else:
+ model = Blip2ForConditionalGeneration.from_pretrained(
+ model_name_or_path, device_map=None, torch_dtype=torch.float16
+ ).to(device)
+ elif model_name == "florence2":
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
+ if device == 'cpu':
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
+ else:
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
+ return {'model': model.to(device), 'processor': processor}
+
+
+def get_yolo_model(model_path):
+ from ultralytics import YOLO
+ # Load the model.
+ model = YOLO(model_path)
+ return model
+
+
+@torch.inference_mode()
+def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=None):
+ # Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
+
+ to_pil = ToPILImage()
+ if starting_idx:
+ non_ocr_boxes = filtered_boxes[starting_idx:]
+ else:
+ non_ocr_boxes = filtered_boxes
+ croped_pil_image = []
+ t0 = time.time()
+ for i, coord in enumerate(non_ocr_boxes):
+ try:
+ xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
+ ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
+ cropped_image = image_source[ymin:ymax, xmin:xmax, :]
+ # resize the image to 224x224 to avoid long overhead in clipimageprocessor # TODO
+ cropped_image = cv2.resize(cropped_image, (64, 64))
+ croped_pil_image.append(to_pil(cropped_image))
+ except:
+ continue
+ # print('time to prepare bbox:', time.time()-t0)
+
+ model, processor = caption_model_processor['model'], caption_model_processor['processor']
+ if not prompt:
+ if 'florence' in model.config.name_or_path:
+ prompt = "
"
+ else:
+ prompt = "The image shows"
+
+ generated_texts = []
+ device = model.device
+ # batch_size = 64
+ for i in range(0, len(croped_pil_image), batch_size):
+ start = time.time()
+ batch = croped_pil_image[i:i+batch_size]
+ t1 = time.time()
+ if model.device.type == 'cuda':
+ inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
+ else:
+ inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
+ t2 = time.time()
+ # print('time to process image + tokenize text inputs:', t2-t1)
+ if 'florence' in model.config.name_or_path:
+ generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False)
+ else:
+ generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
+ t3 = time.time()
+ # print('time to generate:', t3-t2)
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
+ generated_text = [gen.strip() for gen in generated_text]
+ generated_texts.extend(generated_text)
+
+ return generated_texts
+
+
+
+def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
+ to_pil = ToPILImage()
+ if ocr_bbox:
+ non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
+ else:
+ non_ocr_boxes = filtered_boxes
+ croped_pil_image = []
+ for i, coord in enumerate(non_ocr_boxes):
+ xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
+ ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
+ cropped_image = image_source[ymin:ymax, xmin:xmax, :]
+ croped_pil_image.append(to_pil(cropped_image))
+
+ model, processor = caption_model_processor['model'], caption_model_processor['processor']
+ device = model.device
+ messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
+ prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+ batch_size = 5 # Number of samples per batch
+ generated_texts = []
+
+ for i in range(0, len(croped_pil_image), batch_size):
+ images = croped_pil_image[i:i+batch_size]
+ image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
+ inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
+ texts = [prompt] * len(images)
+ for i, txt in enumerate(texts):
+ input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
+ inputs['input_ids'].append(input['input_ids'])
+ inputs['attention_mask'].append(input['attention_mask'])
+ inputs['pixel_values'].append(input['pixel_values'])
+ inputs['image_sizes'].append(input['image_sizes'])
+ max_len = max([x.shape[1] for x in inputs['input_ids']])
+ for i, v in enumerate(inputs['input_ids']):
+ inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
+ inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
+ inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
+
+ generation_args = {
+ "max_new_tokens": 25,
+ "temperature": 0.01,
+ "do_sample": False,
+ }
+ generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
+ # # remove input tokens
+ generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
+ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+ response = [res.strip('\n').strip() for res in response]
+ generated_texts.extend(response)
+
+ return generated_texts
+
+def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
+ assert ocr_bbox is None or isinstance(ocr_bbox, List)
+
+ def box_area(box):
+ return (box[2] - box[0]) * (box[3] - box[1])
+
+ def intersection_area(box1, box2):
+ x1 = max(box1[0], box2[0])
+ y1 = max(box1[1], box2[1])
+ x2 = min(box1[2], box2[2])
+ y2 = min(box1[3], box2[3])
+ return max(0, x2 - x1) * max(0, y2 - y1)
+
+ def IoU(box1, box2):
+ intersection = intersection_area(box1, box2)
+ union = box_area(box1) + box_area(box2) - intersection + 1e-6
+ if box_area(box1) > 0 and box_area(box2) > 0:
+ ratio1 = intersection / box_area(box1)
+ ratio2 = intersection / box_area(box2)
+ else:
+ ratio1, ratio2 = 0, 0
+ return max(intersection / union, ratio1, ratio2)
+
+ def is_inside(box1, box2):
+ # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
+ intersection = intersection_area(box1, box2)
+ ratio1 = intersection / box_area(box1)
+ return ratio1 > 0.95
+
+ boxes = boxes.tolist()
+ filtered_boxes = []
+ if ocr_bbox:
+ filtered_boxes.extend(ocr_bbox)
+ # print('ocr_bbox!!!', ocr_bbox)
+ for i, box1 in enumerate(boxes):
+ # if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
+ is_valid_box = True
+ for j, box2 in enumerate(boxes):
+ # keep the smaller box
+ if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
+ is_valid_box = False
+ break
+ if is_valid_box:
+ # add the following 2 lines to include ocr bbox
+ if ocr_bbox:
+ # only add the box if it does not overlap with any ocr bbox
+ if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for k, box3 in enumerate(ocr_bbox)):
+ filtered_boxes.append(box1)
+ else:
+ filtered_boxes.append(box1)
+ return torch.tensor(filtered_boxes)
+
+
+def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
+ '''
+ ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]
+ boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
+
+ '''
+ assert ocr_bbox is None or isinstance(ocr_bbox, List)
+
+ def box_area(box):
+ return (box[2] - box[0]) * (box[3] - box[1])
+
+ def intersection_area(box1, box2):
+ x1 = max(box1[0], box2[0])
+ y1 = max(box1[1], box2[1])
+ x2 = min(box1[2], box2[2])
+ y2 = min(box1[3], box2[3])
+ return max(0, x2 - x1) * max(0, y2 - y1)
+
+ def IoU(box1, box2):
+ intersection = intersection_area(box1, box2)
+ union = box_area(box1) + box_area(box2) - intersection + 1e-6
+ if box_area(box1) > 0 and box_area(box2) > 0:
+ ratio1 = intersection / box_area(box1)
+ ratio2 = intersection / box_area(box2)
+ else:
+ ratio1, ratio2 = 0, 0
+ return max(intersection / union, ratio1, ratio2)
+
+ def is_inside(box1, box2):
+ # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
+ intersection = intersection_area(box1, box2)
+ ratio1 = intersection / box_area(box1)
+ return ratio1 > 0.80
+
+ # boxes = boxes.tolist()
+ filtered_boxes = []
+ if ocr_bbox:
+ filtered_boxes.extend(ocr_bbox)
+ # print('ocr_bbox!!!', ocr_bbox)
+ for i, box1_elem in enumerate(boxes):
+ box1 = box1_elem['bbox']
+ is_valid_box = True
+ for j, box2_elem in enumerate(boxes):
+ # keep the smaller box
+ box2 = box2_elem['bbox']
+ if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
+ is_valid_box = False
+ break
+ if is_valid_box:
+ # add the following 2 lines to include ocr bbox
+ if ocr_bbox:
+ # keep yolo boxes + prioritize ocr label
+ box_added = False
+ ocr_labels = ''
+ for box3_elem in ocr_bbox:
+ if not box_added:
+ box3 = box3_elem['bbox']
+ if is_inside(box3, box1): # ocr inside icon
+ # box_added = True
+ # delete the box3_elem from ocr_bbox
+ try:
+ # filtered_boxes.append({'type': 'text', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': box3_elem['content'], 'source':'box_yolo_content_ocr'})
+ # gather all ocr labels
+ ocr_labels += box3_elem['content'] + ' '
+ filtered_boxes.remove(box3_elem)
+ # print('remove ocr bbox:', box3_elem)
+ except:
+ continue
+ # break
+ elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
+ box_added = True
+ # try:
+ # filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
+ # filtered_boxes.remove(box3_elem)
+ # except:
+ # continue
+ break
+ else:
+ continue
+ if not box_added:
+ if ocr_labels:
+ filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source':'box_yolo_content_ocr'})
+ else:
+ filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source':'box_yolo_content_yolo'})
+
+ else:
+ filtered_boxes.append(box1)
+ return filtered_boxes # torch.tensor(filtered_boxes)
+
+
+def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
+ transform = T.Compose(
+ [
+ T.RandomResize([800], max_size=1333),
+ T.ToTensor(),
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ]
+ )
+ image_source = Image.open(image_path).convert("RGB")
+ image = np.asarray(image_source)
+ image_transformed, _ = transform(image_source, None)
+ return image, image_transformed
+
+
+def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
+ text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
+ """
+ This function annotates an image with bounding boxes and labels.
+
+ Parameters:
+ image_source (np.ndarray): The source image to be annotated.
+ boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
+ logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
+ phrases (List[str]): A list of labels for each bounding box.
+ text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
+
+ Returns:
+ np.ndarray: The annotated image.
+ """
+ h, w, _ = image_source.shape
+ boxes = boxes * torch.Tensor([w, h, w, h])
+ xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+ xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
+ detections = sv.Detections(xyxy=xyxy)
+
+ labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
+
+ # from util.box_annotator import BoxAnnotator
+ box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
+ annotated_frame = image_source.copy()
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
+
+ label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
+ return annotated_frame, label_coordinates
+
+
+def predict(model, image, caption, box_threshold, text_threshold):
+ """ Use huggingface model to replace the original model
+ """
+ model, processor = model['model'], model['processor']
+ device = model.device
+
+ inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ results = processor.post_process_grounded_object_detection(
+ outputs,
+ inputs.input_ids,
+ box_threshold=box_threshold, # 0.4,
+ text_threshold=text_threshold, # 0.3,
+ target_sizes=[image.size[::-1]]
+ )[0]
+ boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
+ return boxes, logits, phrases
+
+
+def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
+ """ Use huggingface model to replace the original model
+ """
+ # model = model['model']
+ if scale_img:
+ result = model.predict(
+ source=image_path,
+ conf=box_threshold,
+ imgsz=imgsz,
+ iou=iou_threshold, # default 0.7
+ )
+ else:
+ result = model.predict(
+ source=image_path,
+ conf=box_threshold,
+ iou=iou_threshold, # default 0.7
+ )
+ boxes = result[0].boxes.xyxy#.tolist() # in pixel space
+ conf = result[0].boxes.conf
+ phrases = [str(i) for i in range(len(boxes))]
+
+ return boxes, conf, phrases
+
+
+def int_box_area(box, w, h):
+ x1, y1, x2, y2 = box
+ int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
+ area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
+ return area
+
+
+def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
+ """ ocr_bbox: list of xyxy format bbox
+ """
+ image_source = Image.open(img_path).convert("RGB")
+ w, h = image_source.size
+ if not imgsz:
+ imgsz = (h, w)
+ # print('image size:', w, h)
+ xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
+ xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
+ image_source = np.asarray(image_source)
+ phrases = [str(i) for i in range(len(phrases))]
+
+
+ # annotate the image with labels
+ if ocr_bbox:
+ ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
+ ocr_bbox=ocr_bbox.tolist()
+ else:
+ print('no ocr bbox!!!')
+ ocr_bbox = None
+
+ ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
+ xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
+ filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
+
+
+ # sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
+ filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
+ # get the index of the first 'content': None
+ starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
+ filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
+ print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
+
+ # get parsed icon local semantics
+ time1 = time.time()
+ if use_local_semantics:
+ caption_model = caption_model_processor['model']
+ if 'phi3_v' in caption_model.config.model_type:
+ parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
+ else:
+ parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size)
+ ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
+ icon_start = len(ocr_text)
+ parsed_content_icon_ls = []
+ # fill the filtered_boxes_elem None content with parsed_content_icon in order
+ for i, box in enumerate(filtered_boxes_elem):
+ if box['content'] is None:
+ box['content'] = parsed_content_icon.pop(0)
+ for i, txt in enumerate(parsed_content_icon):
+ parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
+ parsed_content_merged = ocr_text + parsed_content_icon_ls
+ else:
+ ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
+ parsed_content_merged = ocr_text
+ print('time to get parsed content:', time.time()-time1)
+
+ filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
+
+ phrases = [i for i in range(len(filtered_boxes))]
+
+ # draw boxes
+ if draw_bbox_config:
+ annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
+ else:
+ annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
+
+ pil_img = Image.fromarray(annotated_frame)
+ buffered = io.BytesIO()
+ pil_img.save(buffered, format="PNG")
+ encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
+ if output_coord_in_ratio:
+ # h, w, _ = image_source.shape
+ label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
+ assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
+
+ return encoded_image, label_coordinates, filtered_boxes_elem
+
+
+def get_xywh(input):
+ x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
+ x, y, w, h = int(x), int(y), int(w), int(h)
+ return x, y, w, h
+
+def get_xyxy(input):
+ x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
+ x, y, xp, yp = int(x), int(y), int(xp), int(yp)
+ return x, y, xp, yp
+
+def get_xywh_yolo(input):
+ x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
+ x, y, w, h = int(x), int(y), int(w), int(h)
+ return x, y, w, h
+
+
+
+def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
+ if use_paddleocr:
+ if easyocr_args is None:
+ text_threshold = 0.5
+ else:
+ text_threshold = easyocr_args['text_threshold']
+ result = paddle_ocr.ocr(image_path, cls=False)[0]
+ # conf = [item[1] for item in result]
+ coord = [item[0] for item in result if item[1][1] > text_threshold]
+ text = [item[1][0] for item in result if item[1][1] > text_threshold]
+ else: # EasyOCR
+ if easyocr_args is None:
+ easyocr_args = {}
+ result = reader.readtext(image_path, **easyocr_args)
+ # print('goal filtering pred:', result[-5:])
+ coord = [item[0] for item in result]
+ text = [item[1] for item in result]
+ # read the image using cv2
+ if display_img:
+ opencv_img = cv2.imread(image_path)
+ opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
+ bb = []
+ for item in coord:
+ x, y, a, b = get_xywh(item)
+ # print(x, y, a, b)
+ bb.append((x, y, a, b))
+ cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
+
+ # Display the image
+ plt.imshow(opencv_img)
+ else:
+ if output_bb_format == 'xywh':
+ bb = [get_xywh(item) for item in coord]
+ elif output_bb_format == 'xyxy':
+ bb = [get_xyxy(item) for item in coord]
+ # print('bounding box!!!', bb)
+ return (text, bb), goal_filtering
+
+
+
+from typing import List, Optional, Union, Tuple
+
+import cv2
+import numpy as np
+
+from supervision.detection.core import Detections
+from supervision.draw.color import Color, ColorPalette
+
+
+class BoxAnnotator:
+ """
+ A class for drawing bounding boxes on an image using detections provided.
+
+ Attributes:
+ color (Union[Color, ColorPalette]): The color to draw the bounding box,
+ can be a single color or a color palette
+ thickness (int): The thickness of the bounding box lines, default is 2
+ text_color (Color): The color of the text on the bounding box, default is white
+ text_scale (float): The scale of the text on the bounding box, default is 0.5
+ text_thickness (int): The thickness of the text on the bounding box,
+ default is 1
+ text_padding (int): The padding around the text on the bounding box,
+ default is 5
+
+ """
+
+ def __init__(
+ self,
+ color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
+ thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo
+ text_color: Color = Color.BLACK,
+ text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
+ text_thickness: int = 2, #1, # 2 for demo
+ text_padding: int = 10,
+ avoid_overlap: bool = True,
+ ):
+ self.color: Union[Color, ColorPalette] = color
+ self.thickness: int = thickness
+ self.text_color: Color = text_color
+ self.text_scale: float = text_scale
+ self.text_thickness: int = text_thickness
+ self.text_padding: int = text_padding
+ self.avoid_overlap: bool = avoid_overlap
+
+ def annotate(
+ self,
+ scene: np.ndarray,
+ detections: Detections,
+ labels: Optional[List[str]] = None,
+ skip_label: bool = False,
+ image_size: Optional[Tuple[int, int]] = None,
+ ) -> np.ndarray:
+ """
+ Draws bounding boxes on the frame using the detections provided.
+
+ Args:
+ scene (np.ndarray): The image on which the bounding boxes will be drawn
+ detections (Detections): The detections for which the
+ bounding boxes will be drawn
+ labels (Optional[List[str]]): An optional list of labels
+ corresponding to each detection. If `labels` are not provided,
+ corresponding `class_id` will be used as label.
+ skip_label (bool): Is set to `True`, skips bounding box label annotation.
+ Returns:
+ np.ndarray: The image with the bounding boxes drawn on it
+
+ Example:
+ ```python
+ import supervision as sv
+
+ classes = ['person', ...]
+ image = ...
+ detections = sv.Detections(...)
+
+ box_annotator = sv.BoxAnnotator()
+ labels = [
+ f"{classes[class_id]} {confidence:0.2f}"
+ for _, _, confidence, class_id, _ in detections
+ ]
+ annotated_frame = box_annotator.annotate(
+ scene=image.copy(),
+ detections=detections,
+ labels=labels
+ )
+ ```
+ """
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ for i in range(len(detections)):
+ x1, y1, x2, y2 = detections.xyxy[i].astype(int)
+ class_id = (
+ detections.class_id[i] if detections.class_id is not None else None
+ )
+ idx = class_id if class_id is not None else i
+ color = (
+ self.color.by_idx(idx)
+ if isinstance(self.color, ColorPalette)
+ else self.color
+ )
+ cv2.rectangle(
+ img=scene,
+ pt1=(x1, y1),
+ pt2=(x2, y2),
+ color=color.as_bgr(),
+ thickness=self.thickness,
+ )
+ if skip_label:
+ continue
+
+ text = (
+ f"{class_id}"
+ if (labels is None or len(detections) != len(labels))
+ else labels[i]
+ )
+
+ text_width, text_height = cv2.getTextSize(
+ text=text,
+ fontFace=font,
+ fontScale=self.text_scale,
+ thickness=self.text_thickness,
+ )[0]
+
+ if not self.avoid_overlap:
+ text_x = x1 + self.text_padding
+ text_y = y1 - self.text_padding
+
+ text_background_x1 = x1
+ text_background_y1 = y1 - 2 * self.text_padding - text_height
+
+ text_background_x2 = x1 + 2 * self.text_padding + text_width
+ text_background_y2 = y1
+ # text_x = x1 - self.text_padding - text_width
+ # text_y = y1 + self.text_padding + text_height
+ # text_background_x1 = x1 - 2 * self.text_padding - text_width
+ # text_background_y1 = y1
+ # text_background_x2 = x1
+ # text_background_y2 = y1 + 2 * self.text_padding + text_height
+ else:
+ text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)
+
+ cv2.rectangle(
+ img=scene,
+ pt1=(text_background_x1, text_background_y1),
+ pt2=(text_background_x2, text_background_y2),
+ color=color.as_bgr(),
+ thickness=cv2.FILLED,
+ )
+ box_color = color.as_rgb()
+ luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
+ text_color = (0,0,0) if luminance > 160 else (255,255,255)
+ cv2.putText(
+ img=scene,
+ text=text,
+ org=(text_x, text_y),
+ fontFace=font,
+ fontScale=self.text_scale,
+ # color=self.text_color.as_rgb(),
+ color=text_color,
+ thickness=self.text_thickness,
+ lineType=cv2.LINE_AA,
+ )
+ return scene
+
+
+def box_area(box):
+ return (box[2] - box[0]) * (box[3] - box[1])
+
+def intersection_area(box1, box2):
+ x1 = max(box1[0], box2[0])
+ y1 = max(box1[1], box2[1])
+ x2 = min(box1[2], box2[2])
+ y2 = min(box1[3], box2[3])
+ return max(0, x2 - x1) * max(0, y2 - y1)
+
+def IoU(box1, box2, return_max=True):
+ intersection = intersection_area(box1, box2)
+ union = box_area(box1) + box_area(box2) - intersection
+ if box_area(box1) > 0 and box_area(box2) > 0:
+ ratio1 = intersection / box_area(box1)
+ ratio2 = intersection / box_area(box2)
+ else:
+ ratio1, ratio2 = 0, 0
+ if return_max:
+ return max(intersection / union, ratio1, ratio2)
+ else:
+ return intersection / union
+
+
+def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
+ """ check overlap of text and background detection box, and get_optimal_label_pos,
+ pos: str, position of the text, must be one of 'top left', 'top right', 'outer left', 'outer right' TODO: if all are overlapping, return the last one, i.e. outer right
+ Threshold: default to 0.3
+ """
+
+ def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size):
+ is_overlap = False
+ for i in range(len(detections)):
+ detection = detections.xyxy[i].astype(int)
+ if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3:
+ is_overlap = True
+ break
+ # check if the text is out of the image
+ if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]:
+ is_overlap = True
+ return is_overlap
+
+ # if pos == 'top left':
+ text_x = x1 + text_padding
+ text_y = y1 - text_padding
+
+ text_background_x1 = x1
+ text_background_y1 = y1 - 2 * text_padding - text_height
+
+ text_background_x2 = x1 + 2 * text_padding + text_width
+ text_background_y2 = y1
+ is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
+ if not is_overlap:
+ return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
+
+ # elif pos == 'outer left':
+ text_x = x1 - text_padding - text_width
+ text_y = y1 + text_padding + text_height
+
+ text_background_x1 = x1 - 2 * text_padding - text_width
+ text_background_y1 = y1
+
+ text_background_x2 = x1
+ text_background_y2 = y1 + 2 * text_padding + text_height
+ is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
+ if not is_overlap:
+ return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
+
+
+ # elif pos == 'outer right':
+ text_x = x2 + text_padding
+ text_y = y1 + text_padding + text_height
+
+ text_background_x1 = x2
+ text_background_y1 = y1
+
+ text_background_x2 = x2 + 2 * text_padding + text_width
+ text_background_y2 = y1 + 2 * text_padding + text_height
+
+ is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
+ if not is_overlap:
+ return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
+
+ # elif pos == 'top right':
+ text_x = x2 - text_padding - text_width
+ text_y = y1 - text_padding
+
+ text_background_x1 = x2 - 2 * text_padding - text_width
+ text_background_y1 = y1 - 2 * text_padding - text_height
+
+ text_background_x2 = x2
+ text_background_y2 = y1
+
+ is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
+ if not is_overlap:
+ return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
+
+ return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
+
+
+
+import re
+def extract_dict_from_text(text):
+ # Define the regex pattern for a dictionary-like structure
+ pattern = r"\{\s*'(?P.*?)':\s*'(?P.*?)',\s*'(?P.*?)':\s*'(?P.*?)'\s*\}"
+
+ # Search for the dictionary in the text
+ match = re.search(pattern, text)
+
+ if match:
+ # Extract matched groups into a dictionary
+ return {
+ match.group('key1'): match.group('value1'),
+ match.group('key2'): match.group('value2'),
+ }
+ else:
+ raise ValueError("No valid dictionary structure found in the text.")
+
+
+def get_phi3v_model_dict():
+ from PIL import Image
+ import requests
+ from transformers import AutoModelForCausalLM
+ from transformers import AutoProcessor
+
+ model_id = "microsoft/Phi-3.5-vision-instruct"
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto", _attn_implementation='flash_attention_2')
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+ print('phi3v model loaded!!!')
+ return {'model': model, 'processor': processor}
+
+
+def call_phi3v(messages, image_base64, model_dict):
+ model, processor = model_dict['model'], model_dict['processor']
+ device = model.device
+ prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ if isinstance(image_base64, tuple):
+ image_base64, dino_labled_img = image_base64
+ image = Image.open(io.BytesIO(base64.b64decode(image_base64)))
+ dino_labled_img = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
+ inputs = processor(prompt, [image, dino_labled_img], return_tensors="pt").to(device)
+ else:
+ image = Image.open(io.BytesIO(base64.b64decode(image_base64)))
+ inputs = processor(prompt, [image], return_tensors="pt").to(device)
+
+ generation_args = {
+ "max_new_tokens": 512,
+ "temperature": 0.01,
+ "do_sample": False,
+ }
+
+ generate_ids = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
+ # remove input tokens
+ generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
+ ans = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ return ans
+
+
+def get_pred_phi3v(message_text, image_base64, label_coordinates, id_key='Click ID', model_dict=None):
+ # messages = [
+ # {"role": "system", "content": '''You are an expert at completing instructions on GUI screens.
+ # You will be presented with two images. The first is the original screenshot. The second is the same screenshot with some numeric tags. You will also be provided with some descriptions of the bbox, and your task is to choose the numeric bbox idx you want to click in order to complete the user instruction.'''},
+ # ]
+ messages = [
+ {"role": "system", "content": '''You are an expert at completing instructions on GUI screens. You will also be provided with some descriptions of the bbox, and your task is to choose the numeric bbox idx you want to click in order to complete the user instruction.'''},
+ ]
+ messages = []
+ if isinstance(image_base64, tuple):
+ messages.append({"role": "user", "content": '<|image_1|>\n' + '<|image_2|>\n' + message_text})
+ else:
+ messages.append({"role": "user", "content": '<|image_1|>\n' + message_text})
+
+ response_text = call_phi3v(messages, image_base64, model_dict)
+ print(response_text)
+
+ try:
+ response_text = ast.literal_eval(response_text)
+
+ icon_id = response_text['Click BBox ID']
+ bbox = label_coordinates[str(icon_id)]
+ click_point = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
+ except:
+ print('error parsing, use regex to parse!!!')
+ import pdb; pdb.set_trace()
+ response_text = extract_dict_from_text(response_text)
+ icon_id = response_text['Click BBox ID']
+ bbox = label_coordinates[str(icon_id)]
+ click_point = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
+ return icon_id, bbox, click_point, response_text
+
+ # try:
+ # match = re.search(r"```(.*?)```", ans, re.DOTALL)
+ # if match:
+ # result = match.group(1).strip()
+ # pred = result.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
+ # pred = ast.literal_eval(pred)
+ # else:
+ # pred = ans.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
+ # pred = ast.literal_eval(pred)
+
+ # if pred[id_key]:
+ # icon_id = pred[id_key]
+ # bbox = label_coordinates[str(icon_id)]
+ # pred['click_point'] = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
+ # except:
+ # print('phi3v action regex extract fail!!!')
+ # pred = {'action_type': 'CLICK', 'click_point': [0, 0], 'value': 'None', 'is_completed': False}
+
+ # step_pred_summary = None
+ # return pred, [True, ans, None, step_pred_summary]
\ No newline at end of file
diff --git a/demo/gradio/computer_use_demo/omniparser_agent/vlm_agent.py b/demo/gradio/computer_use_demo/omniparser_agent/vlm_agent.py
index 26542e8..66c740a 100644
--- a/demo/gradio/computer_use_demo/omniparser_agent/vlm_agent.py
+++ b/demo/gradio/computer_use_demo/omniparser_agent/vlm_agent.py
@@ -21,7 +21,8 @@ from computer_use_demo.gui_agent.llm_utils.oai import run_oai_interleaved, encod
from computer_use_demo.gui_agent.llm_utils.qwen import run_qwen
from computer_use_demo.gui_agent.llm_utils.llm_utils import extract_data
from computer_use_demo.colorful_text import colorful_text_vlm
-
+import time
+# start = time.time()
SYSTEM_PROMPT = f"""
* You are utilizing a Windows system with internet access.
@@ -36,14 +37,34 @@ class OmniParser:
selected_screen: int = 0) -> None:
self.url = url
self.selected_screen = selected_screen
+ if not self.url:
+ config = {
+ 'som_model_path': '../weights/icon_detect_v1_5/model_v1_5.pt',
+ 'device': 'cpu',
+ 'caption_model_name': 'florence2',
+ 'caption_model_path': '../weights/icon_caption_florence',
+ 'BOX_TRESHOLD': 0.05
+ }
+ from computer_use_demo.omniparser_agent.omniparser import Omniparser as Omniparser_class
+ self.omniparser = Omniparser_class(config=config)
def __call__(self,):
screenshot, screenshot_path = get_screenshot(selected_screen=self.selected_screen)
screenshot_path = str(screenshot_path)
image_base64 = encode_image(screenshot_path)
+ if self.url:
+ response = requests.post(self.url, json={"base64_image": image_base64, 'prompt': 'omniparser process'})
+ response_json = response.json()
+ else:
+ start = time.time()
+ dino_labled_img, parsed_content_list = self.omniparser.parse(image_base64)
+ latency = time.time() - start
+ response_json = {
+ 'som_image_base64': dino_labled_img,
+ 'parsed_content_list': parsed_content_list,
+ 'latency': latency
+ }
- response = requests.post(self.url, json={"base64_image": image_base64, 'prompt': 'omniparser process'})
- response_json = response.json()
som_image_data = base64.b64decode(response_json['som_image_base64'])
screenshot_path_uuid = Path(screenshot_path).stem.replace("screenshot_", "")
som_screenshot_path = f"{OUTPUT_DIR}/screenshot_som_{screenshot_path_uuid}.png"
@@ -64,9 +85,11 @@ class OmniParser:
for idx, element in enumerate(response_json["parsed_content_list"]):
element['idx'] = idx
if element['type'] == 'text':
- screen_info += f'''
\n'''
+ # screen_info += f'''
\n'''
+ screen_info += f'ID: {idx}, Text: {element["content"]}\n'
elif element['type'] == 'icon':
- screen_info += f'''
\n'''
+ # screen_info += f'''
\n'''
+ screen_info += f'ID: {idx}, Icon: {element["content"]}\n'
response_json['screen_info'] = screen_info
return response_json
@@ -106,12 +129,16 @@ class VLMAgent:
self.system = system_prompt_suffix
- def __call__(self, messages: list, parsed_screen: list[str, list]):
+ def __call__(self, messages: list, parsed_screen: list[str, list, dict]):
# Show results of Omniparser
image_base64 = parsed_screen['original_screenshot_base64']
+ latency_omniparser = parsed_screen['latency']
self.output_callback(f'Screenshot for {colorful_text_vlm}:\n
',
sender="bot")
self.output_callback(f'Set of Marks Screenshot for {colorful_text_vlm}:\n
', sender="bot")
+ screen_info = str(parsed_screen['screen_info'])
+ self.output_callback(f'Screen Info for {colorful_text_vlm}:\n{screen_info}', sender="bot")
+
screenshot_uuid = parsed_screen['screenshot_uuid']
screen_width, screen_height = parsed_screen['width'], parsed_screen['height']
@@ -125,7 +152,7 @@ class VLMAgent:
planner_messages = _keep_latest_images(planner_messages)
# if self.only_n_most_recent_images:
# _maybe_filter_to_n_most_recent_images(planner_messages, self.only_n_most_recent_images)
- print(f"filtered_messages: {planner_messages}\n\n", "full messages:", messages)
+ # print(f"filtered_messages: {planner_messages}\n\n", "full messages:", messages)
if isinstance(planner_messages[-1], dict):
if not isinstance(planner_messages[-1]["content"], list):
@@ -134,7 +161,7 @@ class VLMAgent:
planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_som_{screenshot_uuid}.png")
print(f"Sending messages to VLMPlanner : {planner_messages}")
-
+ start = time.time()
if "gpt" in self.model:
vlm_response, token_usage = run_oai_interleaved(
messages=planner_messages,
@@ -164,6 +191,8 @@ class VLMAgent:
pass # TODO
else:
raise ValueError(f"Model {self.model} not supported")
+ latency_vlm = time.time() - start
+ self.output_callback(f"VLMPlanner latency: {latency_vlm}, Omniparser latency: {latency_omniparser}", sender="bot")
print(f"VLMPlanner response: {vlm_response}")
@@ -189,7 +218,7 @@ class VLMAgent:
vlm_plan_str += f'{value}'
else:
vlm_plan_str += f'\n{key}: {value}'
- self.output_callback(f"{colorful_text_vlm}:\n{vlm_plan_str}", sender="bot")
+ # self.output_callback(f"{colorful_text_vlm}:\n{vlm_plan_str}", sender="bot")
# construct the response so that anthropicExcutor can execute the tool
response_content = [BetaTextBlock(text=vlm_plan_str, type='text')]
@@ -231,10 +260,12 @@ Here is the list of all detected bounding boxes by IDs on the screen and their d
Your available "Next Action" only include:
- type: type a string of text.
- left_click: Describe the ui element to be clicked.
-- enter: Press an enter key.
+- double_click: Describe the ui element to be double clicked.
+- right_click: Describe the ui element to be right clicked.
- escape: Press an ESCAPE key.
- hover: Describe the ui element to be hovered.
-- scroll: Scroll the screen, you must specify up or down.
+- scroll_up: Scroll the screen up.
+- scroll_down: Scroll the screen down.
- press: Describe the ui element to be pressed.
Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on, and the value (if the action is 'type') in order to complete the task.
diff --git a/demo/gradio/computer_use_demo/tools/computer.py b/demo/gradio/computer_use_demo/tools/computer.py
index 912c096..7086a62 100644
--- a/demo/gradio/computer_use_demo/tools/computer.py
+++ b/demo/gradio/computer_use_demo/tools/computer.py
@@ -208,6 +208,12 @@ class ComputerTool(BaseAnthropicTool):
time.sleep(1)
self.send_to_vm("pyautogui.mouseUp()")
return ToolResult(output=f"Performed {action}")
+ if action in ("scroll_up", "scroll_down"):
+ if action == "scroll_up":
+ self.send_to_vm("pyautogui.scroll(100)")
+ elif action == "scroll_down":
+ self.send_to_vm("pyautogui.scroll(-100)")
+ return ToolResult(output=f"Performed {action}")
raise ToolError(f"Invalid action: {action}")
diff --git a/demo/gradio/demo_image.jpg b/demo/gradio/demo_image.jpg
new file mode 100644
index 0000000..06a901c
Binary files /dev/null and b/demo/gradio/demo_image.jpg differ
diff --git a/demo/gradio/demo_image_som.jpg b/demo/gradio/demo_image_som.jpg
new file mode 100644
index 0000000..2d580f3
Binary files /dev/null and b/demo/gradio/demo_image_som.jpg differ
diff --git a/demo/gradio/fast_api_demo.py b/demo/gradio/fast_api_demo.py
new file mode 100644
index 0000000..d50bbd4
--- /dev/null
+++ b/demo/gradio/fast_api_demo.py
@@ -0,0 +1,12 @@
+from fastapi import FastAPI
+import gradio as gr
+
+app = FastAPI()
+
+def greet(name):
+ return f"Hello, {name}!"
+
+gradio_interface = gr.Interface(fn=greet, inputs="text", outputs="text")
+gradio_app = gr.routes.App.create_app(gradio_interface)
+
+app.mount("/gradio", gradio_app)