# /
# webarenae989873
import json
import re
from collections import defaultdict
from typing import Any, TypedDict, Union
import numpy as np
import numpy.typing as npt
from gymnasium import spaces
from playwright.sync_api import CDPSession, Page, ViewportSize
from browser_env.constants import (
ASCII_CHARSET,
FREQ_UNICODE_CHARSET,
IGNORED_ACTREE_PROPERTIES,
UTTERANCE_MAX_LENGTH,
)
from .utils import (
AccessibilityTree,
AccessibilityTreeNode,
BrowserConfig,
BrowserInfo,
DOMNode,
DOMTree,
Observation,
png_bytes_to_numpy,
)
# Minimum in-viewport ratio a node must reach to be kept when only the current
# viewport is observed; nodes scoring below this value are pruned from the tree.
IN_VIEWPORT_RATIO_THRESHOLD = 0.6
class ObservationProcessor:
    """Common interface of the observation processors defined in this module."""

    def process(self, page: Page, client: CDPSession) -> Observation:
        """Produce an observation for ``page``.

        This base implementation is a placeholder: it always raises
        ``NotImplementedError`` and is meant to be overridden.
        """
        raise NotImplementedError()
class ObservationMetadata(TypedDict):
    """Side-channel data that accompanies an observation."""

    # Per-node records produced by the text parsers, keyed by the node id
    # that is printed in the text observation.
    obs_nodes_info: dict[str, Any]
def create_empty_metadata() -> ObservationMetadata:
    """Return a new metadata record that holds no node information yet."""
    empty_metadata: ObservationMetadata = {"obs_nodes_info": {}}
    return empty_metadata
class TextObervationProcessor(ObservationProcessor):
def __init__(
    self,
    observation_type: str,
    current_viewport_only: bool,
    viewport_size: ViewportSize,
):
    """Remember how text observations should be produced.

    Args:
        observation_type: Which text rendering ``process`` builds; it
            accepts ``"html"`` and ``"accessibility_tree"``.
        current_viewport_only: Whether nodes outside the viewport are pruned.
        viewport_size: Mapping with the viewport ``width`` and ``height``.
    """
    # Tag identifying the modality handled by this processor.
    self.observation_tag = "text"
    # Metadata store for this observation type; refreshed by ``process``.
    self.meta_data = create_empty_metadata()
    self.viewport_size = viewport_size
    self.current_viewport_only = current_viewport_only
    self.observation_type = observation_type
def fetch_browser_info(
    self,
    page: Page,
    client: CDPSession,
) -> BrowserInfo:
    """Capture a DOM snapshot plus the window geometry of ``page``.

    Args:
        page: Page used to evaluate the window properties.
        client: CDP session used to request the DOM snapshot.

    Returns:
        A mapping with the (rescaled) snapshot under ``"DOMTree"`` and the
        window geometry under ``"config"``.

    Raises:
        AssertionError: If the page reports a device pixel ratio other than 1.0.
    """
    snapshot_options = {
        "computedStyles": [],
        "includeDOMRects": True,
        "includePaintOrder": True,
    }
    tree = client.send("DOMSnapshot.captureSnapshot", snapshot_options)

    # The snapshot bounds are sometimes scaled; normalise them so that the
    # width of the first rectangle equals the configured viewport width.
    layout = tree["documents"][0]["layout"]
    raw_bounds = layout["bounds"]
    scale = raw_bounds[0][2] / self.viewport_size["width"]
    layout["bounds"] = [
        [value / scale for value in rect] for rect in raw_bounds
    ]

    # Scroll offsets and screen size as reported by the page itself.
    scroll_top = page.evaluate("window.pageYOffset")
    scroll_left = page.evaluate("window.pageXOffset")
    screen_width = page.evaluate("window.screen.width")
    screen_height = page.evaluate("window.screen.height")
    right_edge = scroll_left + screen_width
    lower_edge = scroll_top + screen_height
    pixel_ratio = page.evaluate("window.devicePixelRatio")
    assert pixel_ratio == 1.0, "devicePixelRatio is not 1.0"

    config: BrowserConfig = {
        "win_top_bound": scroll_top,
        "win_left_bound": scroll_left,
        "win_width": screen_width,
        "win_height": screen_height,
        "win_right_bound": right_edge,
        "win_lower_bound": lower_edge,
        "device_pixel_ratio": pixel_ratio,
    }
    return {"DOMTree": tree, "config": config}
@staticmethod
def get_bounding_client_rect(
client: CDPSession, backend_node_id: str
) -> dict[str, Any]:
try:
remote_object = client.send(
"DOM.resolveNode", {"backendNodeId": int(backend_node_id)}
)
remote_object_id = remote_object["object"]["objectId"]
response = client.send(
"Runtime.callFunctionOn",
{
"objectId": remote_object_id,
"functionDeclaration": """
function() {
if (this.nodeType == 3) {
var range = document.createRange();
range.selectNode(this);
var rect = range.getBoundingClientRect().toJSON();
range.detach();
return rect;
} else {
return this.getBoundingClientRect().toJSON();
}
}
""",
"returnByValue": True,
},
)
return response
except Exception as e:
return {"result": {"subtype": "error"}}
@staticmethod
def get_element_in_viewport_ratio(
elem_left_bound: float,
elem_top_bound: float,
width: float,
height: float,
config: BrowserConfig,
) -> float:
elem_right_bound = elem_left_bound + width
elem_lower_bound = elem_top_bound + height
win_left_bound = 0
win_right_bound = config["win_width"]
win_top_bound = 0
win_lower_bound = config["win_height"]
# Compute the overlap in x and y axes
overlap_width = max(
0,
min(elem_right_bound, win_right_bound)
- max(elem_left_bound, win_left_bound),
)
overlap_height = max(
0,
min(elem_lower_bound, win_lower_bound)
- max(elem_top_bound, win_top_bound),
)
# Compute the overlap area
ratio = overlap_width * overlap_height / width * height
return ratio
def fetch_page_html(
    self,
    info: BrowserInfo,
    page: Page,
    client: CDPSession,
    current_viewport_only: bool,
) -> DOMTree:
    """Turn the captured DOM snapshot into a flat, navigable list of nodes.

    Every snapshot node becomes one ``DOMNode`` dict whose ``nodeId`` is its
    index in the snapshot. For each node that has a parent, the bounding
    rectangle is fetched with ``get_bounding_client_rect`` (one lookup per
    node) and stored as ``[x, y, width, height]`` in ``union_bound``.

    Args:
        info: Browser info holding the snapshot (``"DOMTree"``) and the
            window geometry (``"config"``).
        page: Not used by this method.
        client: CDP session used for the bounding-rectangle lookups.
        current_viewport_only: When true, nodes without a usable rectangle,
            with a zero width or height, or whose in-viewport ratio is below
            ``IN_VIEWPORT_RATIO_THRESHOLD`` are removed and their children
            are attached to the removed node's parent.

    Returns:
        The list of remaining nodes.
    """
    # adopted from [natbot](https://github.com/nat/natbot)
    tree = info["DOMTree"]
    strings = tree["strings"]
    document = tree["documents"][0]
    nodes = document["nodes"]
    # make a dom tree that is easier to navigate
    dom_tree: DOMTree = []
    # parent node id -> ids of its children, in snapshot order
    graph = defaultdict(list)
    for node_idx in range(len(nodes["nodeName"])):
        cur_node: DOMNode = {
            "nodeId": "",
            "nodeType": "",
            "nodeName": "",
            "nodeValue": "",
            "attributes": "",
            "backendNodeId": "",
            "parentId": "",
            "childIds": [],
            "cursor": 0,
            "union_bound": None,
        }
        # NOTE(review): the nodeType value is used as an index into the
        # string table here; confirm that this is what the snapshot provides.
        node_type_idx = nodes["nodeType"][node_idx]
        node_type = "generic"
        if node_type_idx >= 0 and node_type_idx < len(strings):
            node_type = strings[node_type_idx]
        node_name = strings[nodes["nodeName"][node_idx]]
        node_value_idx = nodes["nodeValue"][node_idx]
        node_value = ""
        if node_value_idx >= 0 and node_value_idx < len(strings):
            # collapse every run of whitespace into a single space
            node_value = " ".join(strings[node_value_idx].split())
        node_attributes = [
            strings[i] for i in nodes["attributes"][node_idx]
        ]
        # the attribute list alternates name, value, name, value, ...
        node_attributes_str = ""
        for i in range(0, len(node_attributes), 2):
            a = node_attributes[i]
            b = node_attributes[i + 1]
            b = " ".join(b.split())
            node_attributes_str += f'{a}="{b}" '
        node_attributes_str = node_attributes_str.strip()
        cur_node["nodeId"] = str(node_idx)
        cur_node["nodeType"] = node_type
        cur_node["nodeName"] = node_name
        cur_node["nodeValue"] = node_value
        cur_node["attributes"] = node_attributes_str
        cur_node["backendNodeId"] = str(nodes["backendNodeId"][node_idx])
        cur_node["parentId"] = str(nodes["parentIndex"][node_idx])
        if cur_node["parentId"] != "-1":
            graph[cur_node["parentId"]].append(str(cur_node["nodeId"]))
        # get the bound
        if cur_node["parentId"] == "-1":
            # a node without a parent gets a fixed placeholder rectangle
            # instead of a lookup
            cur_node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
        else:
            response = self.get_bounding_client_rect(
                client, cur_node["backendNodeId"]
            )
            if response.get("result", {}).get("subtype", "") == "error":
                cur_node["union_bound"] = None
            else:
                x = response["result"]["value"]["x"]
                y = response["result"]["value"]["y"]
                width = response["result"]["value"]["width"]
                height = response["result"]["value"]["height"]
                cur_node["union_bound"] = [x, y, width, height]
        dom_tree.append(cur_node)
    # add parent children index to the node
    for parent_id, child_ids in graph.items():
        dom_tree[int(parent_id)]["childIds"] = child_ids
    # remove the nodes that are not in the current viewport
    if current_viewport_only:

        def remove_node_in_graph(node: DOMNode) -> None:
            """Detach ``node`` and hand its children over to its parent."""
            # update the node information in the DOM tree
            node_id = node["nodeId"]
            parent_id = node["parentId"]
            child_ids = node["childIds"]
            # update the children of the parent node
            # the parent must still be part of the tree
            assert dom_tree[int(parent_id)]["parentId"] != "[REMOVED]"
            # remove the nodeid from parent
            index = dom_tree[int(parent_id)]["childIds"].index(node_id)
            dom_tree[int(parent_id)]["childIds"].pop(index)
            # Insert children_nodeids in the same location
            for child_id in child_ids:
                dom_tree[int(parent_id)]["childIds"].insert(
                    index, child_id
                )
                index += 1
            # update children node's parent
            for child_id in child_ids:
                dom_tree[int(child_id)]["parentId"] = parent_id
            # mark as removed
            dom_tree[int(node_id)]["parentId"] = "[REMOVED]"

        config = info["config"]
        # list positions equal node ids at this point, so the helper can
        # index dom_tree by id while the list itself keeps its length
        for cursor, node in enumerate(dom_tree):
            if not node["union_bound"]:
                remove_node_in_graph(node)
                continue
            [x, y, width, height] = node["union_bound"]
            # invisible node
            if width == 0.0 or height == 0.0:
                remove_node_in_graph(node)
                continue
            in_viewport_ratio = self.get_element_in_viewport_ratio(
                elem_left_bound=float(x),
                elem_top_bound=float(y),
                width=float(width),
                height=float(height),
                config=config,
            )
            if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
                remove_node_in_graph(node)
        # drop the nodes that were marked as removed above
        dom_tree = [
            node
            for node in dom_tree
            if node.get("parentId", "-1") != "[REMOVED]"
        ]
    return dom_tree
@staticmethod
def parse_html(dom_tree: DOMTree) -> tuple[str, dict[str, Any]]:
"""Parse the html tree into a string text"""
obs_nodes_info = {}
nodeid_to_cursor = {
node["nodeId"]: idx for idx, node in enumerate(dom_tree)
}
def dfs(node_cursor: int, depth: int) -> str:
tree_str = ""
node = dom_tree[node_cursor]
indent = "\t" * depth
valid_node = True
try:
node_str = f"[{node_cursor}] <{node['nodeName']}"
if node["attributes"]:
node_str += f" {node['attributes']}"
node_str += f"> {node['nodeValue']}"
valid_node = bool(node["attributes"] or node["nodeValue"])
if valid_node:
obs_nodes_info[str(node_cursor)] = {
"backend_id": node["backendNodeId"],
"union_bound": node["union_bound"],
"text": node_str,
}
tree_str += f"{indent}{node_str}\n"
except Exception as e:
valid_node = False
for child_ids in node["childIds"]:
child_cursor = nodeid_to_cursor[child_ids]
child_depth = depth + 1 if valid_node else depth
child_str = dfs(child_cursor, child_depth)
tree_str += child_str
return tree_str
html = dfs(0, 0)
return html, obs_nodes_info
def fetch_page_accessibility_tree(
    self,
    info: BrowserInfo,
    client: CDPSession,
    current_viewport_only: bool,
) -> AccessibilityTree:
    """Fetch the full accessibility tree and attach a bounding box to each node.

    Nodes are de-duplicated by ``nodeId``. Every remaining node receives a
    ``union_bound`` entry: ``None`` when it has no ``backendDOMNodeId`` or
    the rectangle lookup failed, a fixed placeholder for the
    ``RootWebArea`` node, otherwise ``[x, y, width, height]`` from
    ``get_bounding_client_rect`` (one lookup per node).

    Args:
        info: Browser info; only ``info["config"]`` is read, for the
            viewport test.
        client: CDP session used for the tree and the rectangle lookups.
        current_viewport_only: When true, nodes without a rectangle, with a
            zero width or height, or whose in-viewport ratio is below
            ``IN_VIEWPORT_RATIO_THRESHOLD`` are removed and their children
            are attached to the removed node's parent.

    Returns:
        The list of remaining accessibility nodes.
    """
    accessibility_tree: AccessibilityTree = client.send(
        "Accessibility.getFullAXTree", {}
    )["nodes"]
    # a few nodes are repeated in the accessibility tree
    seen_ids = set()
    _accessibility_tree = []
    for node in accessibility_tree:
        if node["nodeId"] not in seen_ids:
            _accessibility_tree.append(node)
            seen_ids.add(node["nodeId"])
    accessibility_tree = _accessibility_tree
    # node id -> position of that node in accessibility_tree
    nodeid_to_cursor = {}
    for cursor, node in enumerate(accessibility_tree):
        nodeid_to_cursor[node["nodeId"]] = cursor
        # usually because the node is not visible etc
        if "backendDOMNodeId" not in node:
            node["union_bound"] = None
            continue
        backend_node_id = str(node["backendDOMNodeId"])
        if node["role"]["value"] == "RootWebArea":
            # always inside the viewport
            node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
        else:
            response = self.get_bounding_client_rect(
                client, backend_node_id
            )
            if response.get("result", {}).get("subtype", "") == "error":
                node["union_bound"] = None
            else:
                x = response["result"]["value"]["x"]
                y = response["result"]["value"]["y"]
                width = response["result"]["value"]["width"]
                height = response["result"]["value"]["height"]
                node["union_bound"] = [x, y, width, height]
    # filter nodes that are not in the current viewport
    if current_viewport_only:

        def remove_node_in_graph(node: AccessibilityTreeNode) -> None:
            """Detach ``node`` and hand its children over to its parent."""
            # update the node information in the accessibility tree
            nodeid = node["nodeId"]
            node_cursor = nodeid_to_cursor[nodeid]
            # NOTE(review): this raises KeyError for a node that has no
            # "parentId" key or whose parent is not in the tree; confirm
            # that such a node can never reach this helper.
            parent_nodeid = node["parentId"]
            children_nodeids = node["childIds"]
            parent_cursor = nodeid_to_cursor[parent_nodeid]
            # update the children of the parent node
            assert (
                accessibility_tree[parent_cursor].get("parentId", "Root")
                is not None
            )
            # remove the nodeid from parent's childIds
            index = accessibility_tree[parent_cursor]["childIds"].index(
                nodeid
            )
            accessibility_tree[parent_cursor]["childIds"].pop(index)
            # Insert children_nodeids in the same location
            for child_nodeid in children_nodeids:
                accessibility_tree[parent_cursor]["childIds"].insert(
                    index, child_nodeid
                )
                index += 1
            # update children node's parent
            for child_nodeid in children_nodeids:
                child_cursor = nodeid_to_cursor[child_nodeid]
                accessibility_tree[child_cursor][
                    "parentId"
                ] = parent_nodeid
            # mark as removed
            accessibility_tree[node_cursor]["parentId"] = "[REMOVED]"

        config = info["config"]
        for node in accessibility_tree:
            if not node["union_bound"]:
                remove_node_in_graph(node)
                continue
            [x, y, width, height] = node["union_bound"]
            # invisible node
            if width == 0 or height == 0:
                remove_node_in_graph(node)
                continue
            in_viewport_ratio = self.get_element_in_viewport_ratio(
                elem_left_bound=float(x),
                elem_top_bound=float(y),
                width=float(width),
                height=float(height),
                config=config,
            )
            if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
                remove_node_in_graph(node)
        # drop the nodes that were marked as removed above; a node without
        # a "parentId" key is always kept
        accessibility_tree = [
            node
            for node in accessibility_tree
            if node.get("parentId", "Root") != "[REMOVED]"
        ]
    return accessibility_tree
@staticmethod
def parse_accessibility_tree(
accessibility_tree: AccessibilityTree,
) -> tuple[str, dict[str, Any]]:
"""Parse the accessibility tree into a string text"""
node_id_to_idx = {}
for idx, node in enumerate(accessibility_tree):
node_id_to_idx[node["nodeId"]] = idx
obs_nodes_info = {}
def dfs(idx: int, obs_node_id: str, depth: int) -> str:
tree_str = ""
node = accessibility_tree[idx]
indent = "\t" * depth
valid_node = True
try:
role = node["role"]["value"]
name = node["name"]["value"]
node_str = f"[{obs_node_id}] {role} {repr(name)}"
properties = []
for property in node.get("properties", []):
try:
if property["name"] in IGNORED_ACTREE_PROPERTIES:
continue
properties.append(
f'{property["name"]}: {property["value"]["value"]}'
)
except KeyError:
pass
if properties:
node_str += " " + " ".join(properties)
# check valid
if not node_str.strip():
valid_node = False
# empty generic node
if not name.strip():
if not properties:
if role in [
"generic",
"img",
"list",
"strong",
"paragraph",
"banner",
"navigation",
"Section",
"LabelText",
"Legend",
"listitem",
]:
valid_node = False
elif role in ["listitem"]:
valid_node = False
if valid_node:
tree_str += f"{indent}{node_str}"
obs_nodes_info[obs_node_id] = {
"backend_id": node["backendDOMNodeId"],
"union_bound": node["union_bound"],
"text": node_str,
}
except Exception as e:
valid_node = False
for _, child_node_id in enumerate(node["childIds"]):
if child_node_id not in node_id_to_idx:
continue
# mark this to save some tokens
child_depth = depth + 1 if valid_node else depth
child_str = dfs(
node_id_to_idx[child_node_id], child_node_id, child_depth
)
if child_str.strip():
if tree_str.strip():
tree_str += "\n"
tree_str += child_str
return tree_str
tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0)
return tree_str, obs_nodes_info
@staticmethod
def clean_accesibility_tree(tree_str: str) -> str:
"""further clean accesibility tree"""
clean_lines: list[str] = []
for line in tree_str.split("\n"):
# remove statictext if the content already appears in the previous line
if "statictext" in line.lower():
prev_lines = clean_lines[-3:]
pattern = r"\[\d+\] StaticText (.+)"
match = re.search(pattern, line, re.DOTALL)
if match:
static_text = match.group(1)[1:-1] # remove the quotes
if static_text and all(
static_text not in prev_line
for prev_line in prev_lines
):
clean_lines.append(line)
else:
clean_lines.append(line)
return "\n".join(clean_lines)
def process(self, page: Page, client: CDPSession) -> str:
    """Build the text observation for ``page``.

    The result is a header listing the open tabs, a blank line, and the
    page content rendered according to ``self.observation_type``. As side
    effects ``self.obs_nodes_info``, ``self.meta_data["obs_nodes_info"]``
    and ``self.browser_config`` are refreshed.

    Args:
        page: The page to observe.
        client: CDP session attached to that page.

    Returns:
        The tab header followed by the rendered page content.

    Raises:
        ValueError: If ``self.observation_type`` is neither ``"html"`` nor
            ``"accessibility_tree"``.
    """
    # get the tab info
    open_tabs = page.context.pages
    try:
        current_tab_idx = open_tabs.index(page)
        tab_labels = []
        for idx, tab in enumerate(open_tabs):
            marker = " (current)" if idx == current_tab_idx else ""
            tab_labels.append(f"Tab {idx}{marker}: {tab.title()}")
        tab_title_str = " | ".join(tab_labels)
    except Exception:
        # Titles are unavailable: fall back to numbered placeholders. This
        # must be an f-string, otherwise the literal "{idx}" is emitted.
        tab_title_str = " | ".join(
            f"Tab {idx}" for idx in range(len(open_tabs))
        )

    try:
        browser_info = self.fetch_browser_info(page, client)
    except Exception:
        # Wait briefly for the load state, then retry once.
        page.wait_for_load_state("load", timeout=500)
        browser_info = self.fetch_browser_info(page, client)

    if self.observation_type == "html":
        dom_tree = self.fetch_page_html(
            browser_info,
            page,
            client,
            current_viewport_only=self.current_viewport_only,
        )
        content, obs_nodes_info = self.parse_html(dom_tree)
    elif self.observation_type == "accessibility_tree":
        accessibility_tree = self.fetch_page_accessibility_tree(
            browser_info,
            client,
            current_viewport_only=self.current_viewport_only,
        )
        content, obs_nodes_info = self.parse_accessibility_tree(
            accessibility_tree
        )
        content = self.clean_accesibility_tree(content)
    else:
        raise ValueError(
            f"Invalid observatrion type: {self.observation_type}"
        )

    # Keep the node lookup table in both places that read it.
    self.obs_nodes_info = obs_nodes_info
    self.meta_data["obs_nodes_info"] = obs_nodes_info
    self.browser_config = browser_info["config"]
    return f"{tab_title_str}\n\n{content}"
def get_element_center(self, element_id: str) -> tuple[float, float]:
    """Return the centre of an observed element as viewport fractions.

    Args:
        element_id: Key into ``self.obs_nodes_info``.

    Returns:
        ``(x, y)`` of the element centre, with x divided by the viewport
        width and y divided by the viewport height.
    """
    left, top, box_width, box_height = self.obs_nodes_info[element_id][
        "union_bound"
    ]
    viewport = self.viewport_size
    return (
        (left + box_width / 2) / viewport["width"],
        (top + box_height / 2) / viewport["height"],
    )
class ImageObservationProcessor(ObservationProcessor):
    """Observation processor that returns a screenshot of the page."""

    def __init__(self, observation_type: str):
        """Store the observation type and start with empty metadata."""
        self.observation_type = observation_type
        self.observation_tag = "image"
        self.meta_data = create_empty_metadata()

    def process(self, page: Page, client: CDPSession) -> npt.NDArray[np.uint8]:
        """Take a screenshot of ``page`` and return it as a numpy array.

        If the first attempt fails, the method waits for the page's
        ``load`` event and tries once more; a second failure propagates.
        ``client`` is not used.
        """
        try:
            screenshot = png_bytes_to_numpy(page.screenshot())
        except Exception:
            # Only ordinary errors trigger the retry; KeyboardInterrupt and
            # SystemExit must not be swallowed into a blocking wait.
            page.wait_for_event("load")
            screenshot = png_bytes_to_numpy(page.screenshot())
        return screenshot
class ObservationHandler:
    """Main entry point to access all observation processor"""

    def __init__(
        self,
        main_observation_type: str,
        text_observation_type: str,
        image_observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
    ) -> None:
        """Create one text and one image processor.

        Args:
            main_observation_type: ``"text"`` or ``"image"``; selects the
                processor returned by ``action_processor``.
            text_observation_type: Observation type for the text processor.
            image_observation_type: Observation type for the image processor.
            current_viewport_only: Passed on to the text processor.
            viewport_size: Mapping with the viewport ``width`` and ``height``.
        """
        self.main_observation_type = main_observation_type
        self.text_processor = TextObervationProcessor(
            text_observation_type, current_viewport_only, viewport_size
        )
        self.image_processor = ImageObservationProcessor(
            image_observation_type
        )
        self.viewport_size = viewport_size

    def get_observation_space(self) -> spaces.Dict:
        """Return the observation space with a ``text`` and an ``image`` entry."""
        text_space = spaces.Text(
            min_length=0,
            max_length=UTTERANCE_MAX_LENGTH,
            charset=ASCII_CHARSET + FREQ_UNICODE_CHARSET,
        )
        # Each position stores the RGB values. Note the swapped axes (height first).
        frame_shape = (
            self.viewport_size["height"],
            self.viewport_size["width"],
            3,
        )
        image_space = spaces.Box(
            np.zeros(frame_shape, dtype=np.uint8),
            # Build the upper bound directly as uint8 rather than scaling
            # an array by a float and casting the result back.
            np.full(frame_shape, 255, dtype=np.uint8),
            dtype=np.uint8,
        )
        return spaces.Dict({"text": text_space, "image": image_space})

    def get_observation(
        self, page: Page, client: CDPSession
    ) -> dict[str, Observation]:
        """Run the text processor, then the image processor, on ``page``."""
        text_obs = self.text_processor.process(page, client)
        image_obs = self.image_processor.process(page, client)
        return {"text": text_obs, "image": image_obs}

    def get_observation_metadata(self) -> dict[str, ObservationMetadata]:
        """Return the current metadata of both processors."""
        return {
            "text": self.text_processor.meta_data,
            "image": self.image_processor.meta_data,
        }

    @property
    def action_processor(self) -> ObservationProcessor:
        """Return the main processor that is associated with the action space"""
        if self.main_observation_type == "text":
            return self.text_processor
        elif self.main_observation_type == "image":
            return self.image_processor
        else:
            raise ValueError("Invalid main observation type")