| from typing import Dict, List, Any |
| from peft import AutoPeftModelForCausalLM |
| import transformers |
| import os |
| import tempfile |
| from PIL import Image, ImageDraw |
| from io import BytesIO |
| import base64, json |
|
|
# Prompt asking the model for the location of the element matching {command};
# the response is parsed downstream as a "(x,y)" point.
COORDINATE_PROMPT = 'In this UI screenshot, what is the position of the element corresponding to the command \"{command}\" (with point)?'


# Prompt asking the model which quadrant (1-4) of the screenshot contains the
# element matching {command}; the quadrant number is parsed from the response.
PARTITION_PROMPT = 'In this UI screenshot, what is the partition of the element corresponding to the command \"{command}\" (with quadrant number)?'
|
|
class EndpointHandler():
    """Inference endpoint that locates a UI element on a screenshot.

    The handler zooms into the screenshot ``k`` times by asking the model
    which quadrant contains the element described by ``task``
    (``PARTITION_PROMPT``), then asks for the element's point in the final
    crop (``COORDINATE_PROMPT``) and maps that point back to relative
    coordinates in the original screenshot.
    """

    # Crop box for each quadrant number the model returns, derived from the
    # crop arithmetic: 1 = top-right, 2 = top-left, 3 = bottom-left,
    # 4 = bottom-right.
    _CROP_BOXES = {
        1: lambda w, h: (w // 2, 0, w, h // 2),
        2: lambda w, h: (0, 0, w // 2, h // 2),
        3: lambda w, h: (0, h // 2, w // 2, h),
        4: lambda w, h: (w // 2, h // 2, w, h),
    }
    # (dx, dy) added after halving when mapping a point in a quadrant back
    # to the coordinate frame of the image it was cropped from.
    _OFFSETS = {
        1: (0.5, 0.0),
        2: (0.0, 0.0),
        3: (0.0, 0.5),
        4: (0.5, 0.5),
    }

    def __init__(self, path: str = ""):
        """Load the model and tokenizer from ``path``."""
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            path,
            device_map="cuda",
            trust_remote_code=True,
            fp16=True,  # model-specific kwarg (Qwen-VL): load weights in fp16
        ).eval()
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            path,
            cache_dir=None,
            model_max_length=2048,
            padding_side="right",
            use_fast=False,
            trust_remote_code=True,
        )
        # The tokenizer ships without a pad token; reuse its end-of-document id.
        self.tokenizer.pad_token_id = self.tokenizer.eod_id

    def __call__(self, data: Dict[str, Any]) -> Dict[str, float]:
        """
        data["inputs"] args:
            image (:obj: `str`): base64-encoded screenshot
            task (:obj: `str`): command describing the element to locate
            k (:obj: `str`): number of zoom-in (quadrant) iterations
            context (:obj: `str`): if truthy, every query sees all crops so
                far instead of only the latest one
        Return:
            A :obj:`dict` with keys ``x`` and ``y``: the element's position
            as fractions of the original image width/height.
        """
        inputs = data["inputs"]
        command = inputs["task"]
        k = int(inputs["k"])
        # NOTE(review): bool() on a non-empty string such as "false" is True —
        # confirm callers send a real boolean / empty value here.
        keep_context = bool(inputs["context"])

        with tempfile.TemporaryDirectory() as temp_dir:
            image = os.path.join(temp_dir, "image.png")
            Image.open(BytesIO(base64.b64decode(inputs["image"]))).save(image)

            print(image)
            print(command)
            print(k)
            print(keep_context)

            images = [image]
            partitions = []

            for step in range(k):
                partition = self._predict_partition(image, images, command, keep_context)
                partitions.append(partition)
                new_path = os.path.join(temp_dir, f"partition{step}.png")
                image = self._crop_partition(image, partition, new_path)
                images.append(image)

            query = self.tokenizer.from_list_format(
                self._image_entries(image, images, keep_context)
                + [{'text': COORDINATE_PROMPT.format(command=command)}]
            )
            response, _ = self.model.chat(self.tokenizer, query=query, history=None)
            print("Coordinate Response:", response)
            x, y = self._parse_point(response)

        x, y = self._rescale_point(x, y, partitions)
        print("rescaled point:", x, y)
        return {'x': x, 'y': y}

    @staticmethod
    def _image_entries(image, images, keep_context):
        """Image part of a query: all crops so far, or only the latest crop."""
        if keep_context:
            return [{'image': context_image} for context_image in images]
        return [{'image': image}]

    def _predict_partition(self, image, images, command, keep_context):
        """Ask the model which quadrant (1-4) contains the element."""
        query = self.tokenizer.from_list_format(
            self._image_entries(image, images, keep_context)
            + [{'text': PARTITION_PROMPT.format(command=command)}]
        )
        response, _ = self.model.chat(self.tokenizer, query=query, history=None)
        # The quadrant number is the last whitespace-separated token.
        return int(response.split(" ")[-1])

    @classmethod
    def _crop_partition(cls, image, partition, new_path):
        """Crop ``image`` to ``partition``, save it to ``new_path``, return ``new_path``.

        An unrecognized partition number leaves the image uncropped (it is
        also skipped during rescaling), but a copy is still written so the
        zoom chain stays consistent.
        """
        with Image.open(image) as img:
            box = cls._CROP_BOXES.get(partition)
            if box is not None:
                img = img.crop(box(*img.size))
            # Save while the source file is still open: PIL crops are lazy,
            # so saving after the `with` exits can read from a closed file.
            img.save(new_path)
        return new_path

    @staticmethod
    def _parse_point(response: str):
        """Extract the "(x,y)" point from the model's coordinate response."""
        inside = response.split("(")[1].split(")")[0]
        x_text, y_text = inside.split(",")[:2]
        return float(x_text), float(y_text)

    @classmethod
    def _rescale_point(cls, x: float, y: float, partitions):
        """Map a point in the final crop back to the original image frame.

        Undoes the zoom steps in reverse order: each quadrant halves the
        coordinates and shifts them by the quadrant's origin offset.
        Unknown partition numbers are skipped, matching the elif chain
        this replaces.
        """
        for partition in reversed(partitions):
            offset = cls._OFFSETS.get(partition)
            if offset is None:
                continue
            x = x / 2 + offset[0]
            y = y / 2 + offset[1]
        return x, y
| |