update readme; demo
This commit is contained in:
234
utils.py
234
utils.py
@@ -18,7 +18,7 @@ import numpy as np
|
||||
# %matplotlib inline
|
||||
from matplotlib import pyplot as plt
|
||||
import easyocr
|
||||
reader = easyocr.Reader(['en']) # this needs to run only once to load the model into memory # 'ch_sim',
|
||||
reader = easyocr.Reader(['en'])
|
||||
import time
|
||||
import base64
|
||||
|
||||
@@ -33,44 +33,19 @@ import supervision as sv
|
||||
import torchvision.transforms as T
|
||||
|
||||
|
||||
def get_caption_model_processor(model_name="Salesforce/blip2-opt-2.7b", device=None):
|
||||
def get_caption_model_processor(model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
|
||||
if not device:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if model_name == "Salesforce/blip2-opt-2.7b":
|
||||
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
if device == 'cpu':
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/blip2-opt-2.7b", device_map=None, torch_dtype=torch.float16
|
||||
# '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float16
|
||||
)
|
||||
elif model_name == "blip2-opt-2.7b-ui":
|
||||
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
if device == 'cpu':
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
'/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
'/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float16
|
||||
)
|
||||
elif model_name == "florence":
|
||||
from transformers import AutoProcessor, AutoModelForCausalLM
|
||||
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
|
||||
if device == 'cpu':
|
||||
model = AutoModelForCausalLM.from_pretrained("/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_ep1_rai", torch_dtype=torch.float32, trust_remote_code=True)#.to(device)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained("/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_ep1_rai_win_ep5_fixed", torch_dtype=torch.float16, trust_remote_code=True).to(device)
|
||||
elif model_name == 'phi3v_ui':
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor
|
||||
model_id = "microsoft/Phi-3-vision-128k-instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained('/home/yadonglu/sandbox/data/orca/phi3v_ui', device_map=device, trust_remote_code=True, torch_dtype="auto")
|
||||
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
||||
elif model_name == 'phi3v':
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor
|
||||
model_id = "microsoft/Phi-3-vision-128k-instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, trust_remote_code=True, torch_dtype="auto")
|
||||
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
||||
model_name_or_path, device_map=None, torch_dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
model_name_or_path, device_map=None, torch_dtype=torch.float16
|
||||
)
|
||||
return {'model': model.to(device), 'processor': processor}
|
||||
|
||||
|
||||
@@ -94,14 +69,12 @@ def get_parsed_content_icon(filtered_boxes, ocr_bbox, image_source, caption_mode
|
||||
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
|
||||
croped_pil_image.append(to_pil(cropped_image))
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
model, processor = caption_model_processor['model'], caption_model_processor['processor']
|
||||
if not prompt:
|
||||
if 'florence' in model.config.name_or_path:
|
||||
prompt = "<CAPTION>"
|
||||
else:
|
||||
prompt = "The image shows"
|
||||
# prompt = "NO gender!NO gender!NO gender! The image shows a icon:"
|
||||
|
||||
batch_size = 10 # Number of samples per batch
|
||||
generated_texts = []
|
||||
@@ -387,117 +360,15 @@ def get_xywh_yolo(input):
|
||||
return x, y, w, h
|
||||
|
||||
|
||||
def run_api(body, max_tokens=1024):
|
||||
'''
|
||||
API call, check https://platform.openai.com/docs/guides/vision for the latest api usage.
|
||||
'''
|
||||
max_num_trial = 3
|
||||
num_trial = 0
|
||||
while num_trial < max_num_trial:
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=deployment,
|
||||
messages=body,
|
||||
temperature=0.01,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
except:
|
||||
print('retry call gptv', num_trial)
|
||||
num_trial += 1
|
||||
time.sleep(10)
|
||||
return ''
|
||||
|
||||
def call_gpt4v_new(message_text, image_path=None, max_tokens=2048):
|
||||
if image_path:
|
||||
try:
|
||||
with open(image_path, "rb") as img_file:
|
||||
encoded_image = base64.b64encode(img_file.read()).decode('ascii')
|
||||
except:
|
||||
encoded_image = image_path
|
||||
|
||||
if image_path:
|
||||
content = [{"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}, {"type": "text","text": message_text},]
|
||||
else:
|
||||
content = [{"type": "text","text": message_text},]
|
||||
|
||||
max_num_trial = 3
|
||||
num_trial = 0
|
||||
call_api_success = True
|
||||
|
||||
while num_trial < max_num_trial:
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=deployment,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an AI assistant that is good at making plans and analyzing screens, and helping people find information."
|
||||
},
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": content
|
||||
}
|
||||
],
|
||||
temperature=0.01,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
ans_1st_pass = response.choices[0].message.content
|
||||
break
|
||||
except:
|
||||
print('retry call gptv', num_trial)
|
||||
num_trial += 1
|
||||
ans_1st_pass = ''
|
||||
time.sleep(10)
|
||||
if num_trial == max_num_trial:
|
||||
call_api_success = False
|
||||
return ans_1st_pass, call_api_success
|
||||
|
||||
|
||||
def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None):
|
||||
if easyocr_args is None:
|
||||
easyocr_args = {}
|
||||
result = reader.readtext(image_path, **easyocr_args)
|
||||
is_goal_filtered = False
|
||||
if goal_filtering:
|
||||
ocr_filter_fs = "Example 1:\n Based on task and ocr results, ```In summary, the task related bboxes are: [([[3060, 111], [3135, 111], [3135, 141], [3060, 141]], 'Share', 0.949013667261589), ([[3068, 197], [3135, 197], [3135, 227], [3068, 227]], 'Link _', 0.3567054243152049), ([[3006, 321], [3178, 321], [3178, 354], [3006, 354]], 'Manage Access', 0.8800734456437066)] ``` \n Example 2:\n Based on task and ocr results, ```In summary, the task related bboxes are: [([[3060, 111], [3135, 111], [3135, 141], [3060, 141]], 'Search Google or type a URL', 0.949013667261589)] ```"
|
||||
# message_text = f"Based on the ocr results which contains text+bounding box in a dictionary, please filter it so that it only contains the task related bboxes. The task is: {goal_filtering}, the ocr results are: {str(result)}. Your final answer should be in the exact same format as the ocr results, please do not include any other redundant information, please do not include any analysis."
|
||||
message_text = f"Based on the task and ocr results which contains text+bounding box in a dictionary, please filter it so that it only contains the task related bboxes. Requirement: 1. first give a brief analysis. 2. provide an answer in the format: ```In summary, the task related bboxes are: ..```, you must put it inside ``` ```. Do not include any info after ```.\n {ocr_filter_fs}\n The task is: {goal_filtering}, the ocr results are: {str(result)}."
|
||||
|
||||
prompt = [{"role":"system", "content": "You are an AI assistant that helps people find the correct way to operate computer or smartphone."}, {"role":"user","content": message_text},]
|
||||
print('[Perform OCR filtering by goal] ongoing ...')
|
||||
# pred, _, _ = call_gpt4(prompt)
|
||||
pred, _, = call_gpt4v(message_text)
|
||||
# import pdb; pdb.set_trace()
|
||||
try:
|
||||
# match = re.search(r"```(.*?)```", pred, re.DOTALL)
|
||||
# result = match.group(1).strip()
|
||||
# pred = result.split('In summary, the task related bboxes are:')[-1].strip()
|
||||
pred = pred.split('In summary, the task related bboxes are:')[-1].strip().strip('```')
|
||||
result = ast.literal_eval(pred)
|
||||
print('[Perform OCR filtering by goal] success!!! Filtered buttons: ', pred)
|
||||
is_goal_filtered = True
|
||||
except:
|
||||
print('[Perform OCR filtering by goal] failed or unused!!!')
|
||||
pass
|
||||
# added_prompt = [{"role":"assistant","content":pred},
|
||||
# {"role":"user","content": "given the previous answers, please provide the final answer in the exact same format as the ocr results, please do not include any other redundant information, please do not include any analysis."}]
|
||||
# prompt.extend(added_prompt)
|
||||
# pred, _, _ = call_gpt4(prompt)
|
||||
# print('goal filtering pred 2nd:', pred)
|
||||
# result = ast.literal_eval(pred)
|
||||
# print('goal filtering pred:', result[-5:])
|
||||
coord = [item[0] for item in result]
|
||||
text = [item[1] for item in result]
|
||||
# confidence = [item[2] for item in result]
|
||||
# if confidence_filtering:
|
||||
# coord = [coord[i] for i in range(len(coord)) if confidence[i] > confidence_filtering]
|
||||
# text = [text[i] for i in range(len(text)) if confidence[i] > confidence_filtering]
|
||||
# read the image using cv2
|
||||
if display_img:
|
||||
opencv_img = cv2.imread(image_path)
|
||||
@@ -520,87 +391,4 @@ def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_
|
||||
return (text, bb), is_goal_filtered
|
||||
|
||||
|
||||
def get_pred_gptv(message_text, yolo_labled_img, label_coordinates, summarize_history=True, verbose=True, history=None, id_key='Click ID'):
|
||||
""" This func first
|
||||
1. call gptv(yolo_labled_img, text bbox+task) -> ans_1st_cal
|
||||
2. call gpt4(ans_1st_cal, label_coordinates) -> final ans
|
||||
"""
|
||||
|
||||
# Configuration
|
||||
encoded_image = yolo_labled_img
|
||||
|
||||
# Payload for the request
|
||||
if not history:
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text","text": "You are an AI assistant that is great at interpreting screenshot and predict action."},]},
|
||||
{"role": "user","content": [{"type": "text","text": message_text}, {"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},]}
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text","text": "You are an AI assistant that is great at interpreting screenshot and predict action."},]},
|
||||
history,
|
||||
{"role": "user","content": [{"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},{"type": "text","text": message_text},]}
|
||||
]
|
||||
|
||||
payload = {
|
||||
"messages": messages,
|
||||
"temperature": 0.01, # 0.01
|
||||
"top_p": 0.95,
|
||||
"max_tokens": 800
|
||||
}
|
||||
|
||||
max_num_trial = 3
|
||||
num_trial = 0
|
||||
call_api_success = True
|
||||
while num_trial < max_num_trial:
|
||||
try:
|
||||
# response = requests.post(GPT4V_ENDPOINT, headers=headers, json=payload)
|
||||
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
|
||||
# ans_1st_pass = response.json()['choices'][0]['message']['content']
|
||||
response = client.chat.completions.create(
|
||||
model=deployment,
|
||||
messages=messages,
|
||||
temperature=0.01,
|
||||
max_tokens=512,
|
||||
)
|
||||
ans_1st_pass = response.choices[0].message.content
|
||||
break
|
||||
except requests.RequestException as e:
|
||||
print('retry call gptv', num_trial)
|
||||
num_trial += 1
|
||||
ans_1st_pass = ''
|
||||
time.sleep(30)
|
||||
# raise SystemExit(f"Failed to make the request. Error: {e}")
|
||||
if num_trial == max_num_trial:
|
||||
call_api_success = False
|
||||
if verbose:
|
||||
print('Answer by GPTV: ', ans_1st_pass)
|
||||
# extract by simple parsing
|
||||
try:
|
||||
match = re.search(r"```(.*?)```", ans_1st_pass, re.DOTALL)
|
||||
if match:
|
||||
result = match.group(1).strip()
|
||||
pred = result.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
|
||||
pred = ast.literal_eval(pred)
|
||||
else:
|
||||
pred = ans_1st_pass.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
|
||||
pred = ast.literal_eval(pred)
|
||||
|
||||
if id_key in pred:
|
||||
icon_id = pred[id_key]
|
||||
bbox = label_coordinates[str(icon_id)]
|
||||
pred['click_point'] = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
|
||||
except:
|
||||
# import pdb; pdb.set_trace()
|
||||
print('gptv action regex extract fail!!!')
|
||||
print('ans_1st_pass:', ans_1st_pass)
|
||||
pred = {'action_type': 'CLICK', 'click_point': [0, 0], 'value': 'None', 'is_completed': False}
|
||||
|
||||
step_pred_summary = None
|
||||
if summarize_history:
|
||||
step_pred_summary, _ = call_gpt4v_new('Summarize what action you decide to perform in the current step, in one sentence, and do not include any icon box number: ' + ans_1st_pass, max_tokens=128)
|
||||
print('step_pred_summary', step_pred_summary)
|
||||
return pred, [call_api_success, ans_1st_pass, None, step_pred_summary]
|
||||
# return pred, [call_api_success, message_2nd, completion_2nd.choices[0].message.content, step_pred_summary]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user