From 664407ae9b3f4a5a9d19d1a2803ae3aa4ca30e97 Mon Sep 17 00:00:00 2001
From: yadonglu
Date: Wed, 9 Oct 2024 22:31:38 +0000
Subject: [PATCH] update readme; demo

---
 README.md                                     |   3 +-
 __pycache__/utils.cpython-312.pyc             | Bin 31466 -> 22216 bytes
 demo.ipynb                                    | 472 ++++++------------
 util/__pycache__/__init__.cpython-312.pyc     | Bin 0 -> 147 bytes
 .../__pycache__/box_annotator.cpython-312.pyc | Bin 0 -> 9812 bytes
 utils.py                                      | 234 +--------
 6 files changed, 174 insertions(+), 535 deletions(-)
 create mode 100644 util/__pycache__/__init__.cpython-312.pyc
 create mode 100644 util/__pycache__/box_annotator.cpython-312.pyc

diff --git a/README.md b/README.md
index e6f0d30..0465b9f 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,8 @@
 ## Install
 ```python
 conda create -n "omni" python==3.12
-pip install -r requirements.txt
+conda activate omni
+pip install -r requirements.txt
 ```
 ## Examples:

diff --git a/__pycache__/utils.cpython-312.pyc b/__pycache__/utils.cpython-312.pyc
index 395a981858ea8353df8bbce91894b7cc9e55f17b..3b245085ae73a89c3c55caebd66ae9f96493d1c6 100644
Binary files a/__pycache__/utils.cpython-312.pyc and b/__pycache__/utils.cpython-312.pyc differ
diff --git a/demo.ipynb b/demo.ipynb
--- a/demo.ipynb
+++ b/demo.ipynb
+ ""
 ]
 },
- "execution_count": 9,
+ "execution_count": 11,
 "metadata": {},
 "output_type": "execute_result"
 },
@@ -640,238 +722,6 @@
 "plt.imshow(image)\n"
 ]
 },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# wrapped Omniparser"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Parsing image: examples/pc_1.png\n",
- "\n",
- "image 1/1 /home/yadonglu/sandbox/screenparsing_collab/screenparsing/omniparser/examples/pc_1.png: 800x1280 210 icons, 55.6ms\n",
- "Speed: 7.7ms preprocess, 55.6ms inference, 1.5ms postprocess per image at shape (1, 3, 800, 1280)\n",
- "boxes cpu\n",
- "Time taken for Omniparser on cpu: 2.029506206512451\n"
- ]
- }
- ],
- "source": [
- "from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_dino_model, get_yolo_model\n",
- "import torch\n",
- "from ultralytics import YOLO\n",
- "from PIL import Image\n",
- 
"from typing import Dict, Tuple, List\n", - "import io\n", - "import base64\n", - "\n", - "\n", - "config = {\n", - " 'som_model_path': 'finetuned_icon_detect.pt',\n", - " 'device': 'cpu',\n", - " 'caption_model_path': 'Salesforce/blip2-opt-2.7b',\n", - " 'draw_bbox_config': {\n", - " 'text_scale': 0.8,\n", - " 'text_thickness': 2,\n", - " 'text_padding': 3,\n", - " 'thickness': 3,\n", - " },\n", - " 'BOX_TRESHOLD': 0.05\n", - "}\n", - "\n", - "\n", - "class Omniparser(object):\n", - " def __init__(self, config: Dict):\n", - " self.config = config\n", - " \n", - " self.som_model = get_yolo_model(model_path=config['som_model_path'])\n", - " # self.caption_model_processor = get_caption_model_processor(config['caption_model_path'], device=cofig['device'])\n", - " # self.caption_model_processor['model'].to(torch.float32)\n", - "\n", - " def parse(self, image_path: str):\n", - " print('Parsing image:', image_path)\n", - " ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9})\n", - " text, ocr_bbox = ocr_bbox_rslt\n", - "\n", - " draw_bbox_config = self.config['draw_bbox_config']\n", - " BOX_TRESHOLD = self.config['BOX_TRESHOLD']\n", - " dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=None, ocr_text=text,use_local_semantics=False)\n", - " \n", - " image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))\n", - " # formating output\n", - " return_list = [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},\n", - " 'text': parsed_content_list[i].split(': ')[1], 'type':'text'} for i, (k, coord) in enumerate(label_coordinates.items()) if i < len(parsed_content_list)]\n", - " return_list.extend(\n", - " [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},\n", - " 'text': 'None', 'type':'icon'} for i, (k, coord) in enumerate(label_coordinates.items()) if i >= len(parsed_content_list)]\n", - " )\n", - "\n", - " return [image, return_list]\n", - " \n", - "parser = Omniparser(config)\n", - "image_path = 'imgs/pc_1.png'\n", - "\n", - "# time the parser\n", - "import time\n", - "s = time.time()\n", - "image, parsed_content_list = parser.parse(image_path)\n", - "device = config['device']\n", - "print(f'Time taken for Omniparser on {device}:', time.time() - s)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0: 800x1280 210 icons, 49.4ms\n", - "Speed: 5.7ms preprocess, 49.4ms inference, 1.1ms postprocess per image at shape (1, 3, 800, 1280)\n", - "boxes cpu\n", - "Time taken for Omniparser finetuned YOLO module on cpu: 0.2898883819580078\n" - ] - } - ], - "source": [ - "from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_dino_model, get_yolo_model, predict_yolo\n", - "import torch\n", - "from ultralytics import YOLO\n", - "from PIL import Image\n", - "from typing import Dict, Tuple, List\n", - "import io\n", - "import base64\n", - "\n", - "\n", - "config = {\n", - " 'som_model_path': 'finetuned_icon_detect.pt',\n", - " 'device': 'cpu',\n", - " 'caption_model_path': 
'Salesforce/blip2-opt-2.7b',\n", - " 'draw_bbox_config': {\n", - " 'text_scale': 0.8,\n", - " 'text_thickness': 2,\n", - " 'text_padding': 3,\n", - " 'thickness': 3,\n", - " },\n", - " 'BOX_TRESHOLD': 0.05\n", - "}\n", - "\n", - "class OmniparserYOLO(object):\n", - " def __init__(self, config: Dict):\n", - " self.config = config\n", - " self.som_model = get_yolo_model(model_path=config['som_model_path'])\n", - "\n", - " def parse(self, image):\n", - " draw_bbox_config = self.config['draw_bbox_config']\n", - " BOX_TRESHOLD = self.config['BOX_TRESHOLD']\n", - " xyxy, logits, phrases = predict_yolo(model=self.som_model, image_path=image, box_threshold=BOX_TRESHOLD)\n", - " # print('xyxy:', xyxy)\n", - " xyxy = xyxy.tolist()\n", - " # formating output\n", - " return_list = [{'from': 'omniparserYOLO', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2]-coord[0], 'height':coord[3] - coord[1]},\n", - " 'text': 'None', 'type':'icon'} for i, coord in enumerate(xyxy)]\n", - " \n", - " return [None, return_list]\n", - " \n", - "parser = OmniparserYOLO(config)\n", - "image_path = 'imgs/pc_1.png'\n", - "image = Image.open(image_path)\n", - "\n", - "# time the parser\n", - "import time\n", - "s = time.time()\n", - "_, parsed_content_list = parser.parse(image)\n", - "device = config['device']\n", - "print(f'Time taken for Omniparser finetuned YOLO module on {device}:', time.time() - s)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# florence caption model" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/yadonglu/anaconda3/envs/pilot/lib/python3.9/site-packages/transformers/utils/generic.py:342: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", - " _torch_pytree._register_pytree_node(\n", - "/home/yadonglu/anaconda3/envs/pilot/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n",
- " warnings.warn(\n"
- ]
- }
- ],
- "source": [
- "from transformers import AutoProcessor, AutoModelForCausalLM \n",
- "import torch\n",
- "device = 'cpu'\n",
- "torch_dtype = torch.float16 if device == 'cuda' else torch.float32\n",
- "model = AutoModelForCausalLM.from_pretrained(\"/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_rai_win_ep5/epoch_5\", torch_dtype=torch_dtype, trust_remote_code=True).to(device)\n",
- "processor = AutoProcessor.from_pretrained(\"microsoft/Florence-2-base\", trust_remote_code=True)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "['settings or configuration options.']"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from PIL import Image\n",
- "prompt = \"\"\n",
- "image_path = 'imgs/settings.png'\n",
- "image = [Image.open(image_path).convert('RGB')]\n",
- "inputs = processor(images=image, text=[prompt]*len(image), return_tensors=\"pt\").to(device=device)\n",
- "generated_ids = model.generate(input_ids=inputs[\"input_ids\"],pixel_values=inputs[\"pixel_values\"],max_new_tokens=1024,num_beams=3, do_sample=False)\n",
- "generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)\n",
- "generated_text = [gen.strip() for gen in generated_text]\n",
- "generated_text"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "import cv2"
- ]
- },
 {
 "cell_type": "code",
 "execution_count": null,
@@ -896,7 +746,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
- "version": "3.9.18"
+ "version": "3.12.0"
 }
 },
 "nbformat": 4,
diff --git a/util/__pycache__/__init__.cpython-312.pyc b/util/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf380e082148fe03e30ecbbb776bec440be7dd9a
Binary files /dev/null and b/util/__pycache__/__init__.cpython-312.pyc differ
diff --git a/util/__pycache__/box_annotator.cpython-312.pyc b/util/__pycache__/box_annotator.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..678898d44e0d60389bb742a5dd645dee6d0c41db
Binary files /dev/null and b/util/__pycache__/box_annotator.cpython-312.pyc differ
diff --git a/utils.py b/utils.py
index 40fb079..2fc180d 100755
--- a/utils.py
+++ b/utils.py
@@ -18,7 +18,7 @@ import numpy as np
 # %matplotlib inline
 from matplotlib import pyplot as plt
 import easyocr
-reader = easyocr.Reader(['en']) # this needs to run only once to load the model into memory # 'ch_sim',
+reader = easyocr.Reader(['en'])
 import time
 import base64
 
@@ -33,44 +33,19 @@ import supervision as sv
 import torchvision.transforms as T
 
 
-def get_caption_model_processor(model_name="Salesforce/blip2-opt-2.7b", device=None):
+def get_caption_model_processor(model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
     if not device:
         device = "cuda" if 
torch.cuda.is_available() else "cpu" - if model_name == "Salesforce/blip2-opt-2.7b": - from transformers import Blip2Processor, Blip2ForConditionalGeneration - processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + from transformers import Blip2Processor, Blip2ForConditionalGeneration + processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + if device == 'cpu': model = Blip2ForConditionalGeneration.from_pretrained( - "Salesforce/blip2-opt-2.7b", device_map=None, torch_dtype=torch.float16 - # '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float16 - ) - elif model_name == "blip2-opt-2.7b-ui": - from transformers import Blip2Processor, Blip2ForConditionalGeneration - processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") - if device == 'cpu': - model = Blip2ForConditionalGeneration.from_pretrained( - '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float32 - ) - else: - model = Blip2ForConditionalGeneration.from_pretrained( - '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float16 - ) - elif model_name == "florence": - from transformers import AutoProcessor, AutoModelForCausalLM - processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True) - if device == 'cpu': - model = AutoModelForCausalLM.from_pretrained("/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_ep1_rai", torch_dtype=torch.float32, trust_remote_code=True)#.to(device) - else: - model = AutoModelForCausalLM.from_pretrained("/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_ep1_rai_win_ep5_fixed", torch_dtype=torch.float16, trust_remote_code=True).to(device) - elif model_name == 'phi3v_ui': - from transformers import AutoModelForCausalLM, AutoProcessor - model_id = "microsoft/Phi-3-vision-128k-instruct" - model = AutoModelForCausalLM.from_pretrained('/home/yadonglu/sandbox/data/orca/phi3v_ui', device_map=device, trust_remote_code=True, torch_dtype="auto") - processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) - elif model_name == 'phi3v': - from transformers import AutoModelForCausalLM, AutoProcessor - model_id = "microsoft/Phi-3-vision-128k-instruct" - model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, trust_remote_code=True, torch_dtype="auto") - processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + model_name_or_path, device_map=None, torch_dtype=torch.float32 + ) + else: + model = Blip2ForConditionalGeneration.from_pretrained( + model_name_or_path, device_map=None, torch_dtype=torch.float16 + ) return {'model': model.to(device), 'processor': processor} @@ -94,14 +69,12 @@ def get_parsed_content_icon(filtered_boxes, ocr_bbox, image_source, caption_mode cropped_image = image_source[ymin:ymax, xmin:xmax, :] croped_pil_image.append(to_pil(cropped_image)) - # import pdb; pdb.set_trace() model, processor = caption_model_processor['model'], caption_model_processor['processor'] if not prompt: if 'florence' in model.config.name_or_path: prompt = "" else: prompt = "The image shows" - # prompt = "NO gender!NO gender!NO gender! The image shows a icon:" batch_size = 10 # Number of samples per batch generated_texts = [] @@ -387,117 +360,15 @@ def get_xywh_yolo(input): return x, y, w, h -def run_api(body, max_tokens=1024): - ''' - API call, check https://platform.openai.com/docs/guides/vision for the latest api usage. 
- ''' - max_num_trial = 3 - num_trial = 0 - while num_trial < max_num_trial: - try: - response = client.chat.completions.create( - model=deployment, - messages=body, - temperature=0.01, - max_tokens=max_tokens, - ) - return response.choices[0].message.content - except: - print('retry call gptv', num_trial) - num_trial += 1 - time.sleep(10) - return '' - -def call_gpt4v_new(message_text, image_path=None, max_tokens=2048): - if image_path: - try: - with open(image_path, "rb") as img_file: - encoded_image = base64.b64encode(img_file.read()).decode('ascii') - except: - encoded_image = image_path - - if image_path: - content = [{"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}, {"type": "text","text": message_text},] - else: - content = [{"type": "text","text": message_text},] - - max_num_trial = 3 - num_trial = 0 - call_api_success = True - - while num_trial < max_num_trial: - try: - response = client.chat.completions.create( - model=deployment, - messages=[ - { - "role": "system", - "content": [ - { - "type": "text", - "text": "You are an AI assistant that is good at making plans and analyzing screens, and helping people find information." - }, - ] - }, - { - "role": "user", - "content": content - } - ], - temperature=0.01, - max_tokens=max_tokens, - ) - ans_1st_pass = response.choices[0].message.content - break - except: - print('retry call gptv', num_trial) - num_trial += 1 - ans_1st_pass = '' - time.sleep(10) - if num_trial == max_num_trial: - call_api_success = False - return ans_1st_pass, call_api_success - def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None): if easyocr_args is None: easyocr_args = {} result = reader.readtext(image_path, **easyocr_args) is_goal_filtered = False - if goal_filtering: - ocr_filter_fs = "Example 1:\n Based on task and ocr results, ```In summary, the task related bboxes are: [([[3060, 111], [3135, 111], [3135, 141], [3060, 141]], 'Share', 0.949013667261589), ([[3068, 197], [3135, 197], [3135, 227], [3068, 227]], 'Link _', 0.3567054243152049), ([[3006, 321], [3178, 321], [3178, 354], [3006, 354]], 'Manage Access', 0.8800734456437066)] ``` \n Example 2:\n Based on task and ocr results, ```In summary, the task related bboxes are: [([[3060, 111], [3135, 111], [3135, 141], [3060, 141]], 'Search Google or type a URL', 0.949013667261589)] ```" - # message_text = f"Based on the ocr results which contains text+bounding box in a dictionary, please filter it so that it only contains the task related bboxes. The task is: {goal_filtering}, the ocr results are: {str(result)}. Your final answer should be in the exact same format as the ocr results, please do not include any other redundant information, please do not include any analysis." - message_text = f"Based on the task and ocr results which contains text+bounding box in a dictionary, please filter it so that it only contains the task related bboxes. Requirement: 1. first give a brief analysis. 2. provide an answer in the format: ```In summary, the task related bboxes are: ..```, you must put it inside ``` ```. Do not include any info after ```.\n {ocr_filter_fs}\n The task is: {goal_filtering}, the ocr results are: {str(result)}." 
- - prompt = [{"role":"system", "content": "You are an AI assistant that helps people find the correct way to operate computer or smartphone."}, {"role":"user","content": message_text},] - print('[Perform OCR filtering by goal] ongoing ...') - # pred, _, _ = call_gpt4(prompt) - pred, _, = call_gpt4v(message_text) - # import pdb; pdb.set_trace() - try: - # match = re.search(r"```(.*?)```", pred, re.DOTALL) - # result = match.group(1).strip() - # pred = result.split('In summary, the task related bboxes are:')[-1].strip() - pred = pred.split('In summary, the task related bboxes are:')[-1].strip().strip('```') - result = ast.literal_eval(pred) - print('[Perform OCR filtering by goal] success!!! Filtered buttons: ', pred) - is_goal_filtered = True - except: - print('[Perform OCR filtering by goal] failed or unused!!!') - pass - # added_prompt = [{"role":"assistant","content":pred}, - # {"role":"user","content": "given the previous answers, please provide the final answer in the exact same format as the ocr results, please do not include any other redundant information, please do not include any analysis."}] - # prompt.extend(added_prompt) - # pred, _, _ = call_gpt4(prompt) - # print('goal filtering pred 2nd:', pred) - # result = ast.literal_eval(pred) # print('goal filtering pred:', result[-5:]) coord = [item[0] for item in result] text = [item[1] for item in result] - # confidence = [item[2] for item in result] - # if confidence_filtering: - # coord = [coord[i] for i in range(len(coord)) if confidence[i] > confidence_filtering] - # text = [text[i] for i in range(len(text)) if confidence[i] > confidence_filtering] # read the image using cv2 if display_img: opencv_img = cv2.imread(image_path) @@ -520,87 +391,4 @@ def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_ return (text, bb), is_goal_filtered -def get_pred_gptv(message_text, yolo_labled_img, label_coordinates, summarize_history=True, verbose=True, history=None, id_key='Click ID'): - """ This func first - 1. call gptv(yolo_labled_img, text bbox+task) -> ans_1st_cal - 2. 
call gpt4(ans_1st_cal, label_coordinates) -> final ans - """ - - # Configuration - encoded_image = yolo_labled_img - - # Payload for the request - if not history: - messages = [ - {"role": "system", "content": [{"type": "text","text": "You are an AI assistant that is great at interpreting screenshot and predict action."},]}, - {"role": "user","content": [{"type": "text","text": message_text}, {"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},]} - ] - else: - messages = [ - {"role": "system", "content": [{"type": "text","text": "You are an AI assistant that is great at interpreting screenshot and predict action."},]}, - history, - {"role": "user","content": [{"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},{"type": "text","text": message_text},]} - ] - - payload = { - "messages": messages, - "temperature": 0.01, # 0.01 - "top_p": 0.95, - "max_tokens": 800 - } - - max_num_trial = 3 - num_trial = 0 - call_api_success = True - while num_trial < max_num_trial: - try: - # response = requests.post(GPT4V_ENDPOINT, headers=headers, json=payload) - # response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code - # ans_1st_pass = response.json()['choices'][0]['message']['content'] - response = client.chat.completions.create( - model=deployment, - messages=messages, - temperature=0.01, - max_tokens=512, - ) - ans_1st_pass = response.choices[0].message.content - break - except requests.RequestException as e: - print('retry call gptv', num_trial) - num_trial += 1 - ans_1st_pass = '' - time.sleep(30) - # raise SystemExit(f"Failed to make the request. Error: {e}") - if num_trial == max_num_trial: - call_api_success = False - if verbose: - print('Answer by GPTV: ', ans_1st_pass) - # extract by simple parsing - try: - match = re.search(r"```(.*?)```", ans_1st_pass, re.DOTALL) - if match: - result = match.group(1).strip() - pred = result.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '') - pred = ast.literal_eval(pred) - else: - pred = ans_1st_pass.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '') - pred = ast.literal_eval(pred) - - if id_key in pred: - icon_id = pred[id_key] - bbox = label_coordinates[str(icon_id)] - pred['click_point'] = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2] - except: - # import pdb; pdb.set_trace() - print('gptv action regex extract fail!!!') - print('ans_1st_pass:', ans_1st_pass) - pred = {'action_type': 'CLICK', 'click_point': [0, 0], 'value': 'None', 'is_completed': False} - - step_pred_summary = None - if summarize_history: - step_pred_summary, _ = call_gpt4v_new('Summarize what action you decide to perform in the current step, in one sentence, and do not include any icon box number: ' + ans_1st_pass, max_tokens=128) - print('step_pred_summary', step_pred_summary) - return pred, [call_api_success, ans_1st_pass, None, step_pred_summary] - # return pred, [call_api_success, message_2nd, completion_2nd.choices[0].message.content, step_pred_summary] -
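
A note on the `utils.py` change above: `get_caption_model_processor` now always builds the BLIP-2 captioner, loading float32 weights when running on CPU and float16 otherwise, and returns the model and processor as a dict. Below is a minimal usage sketch, not part of the patch; the CPU device choice and the public `Salesforce/blip2-opt-2.7b` checkpoint are illustrative assumptions.

```python
# Sketch: driving the simplified helper introduced by this patch.
# Assumes utils.py from this repo is importable and the BLIP-2
# weights can be downloaded from the Hugging Face Hub.
from utils import get_caption_model_processor

caption_model_processor = get_caption_model_processor(
    model_name_or_path="Salesforce/blip2-opt-2.7b",
    device="cpu",  # the helper loads float32 on CPU, float16 on "cuda"
)
model = caption_model_processor["model"]          # Blip2ForConditionalGeneration
processor = caption_model_processor["processor"]  # Blip2Processor
```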