import json
import re
from pathlib import Path
from typing import Any, TypedDict

from browser_env import ActionParsingError, Trajectory
from browser_env.env_config import URL_MAPPINGS
from browser_env.utils import StateInfo
from llms import lm_config
from llms.tokenizers import Tokenizer
from llms.utils import APIInput


class Instruction(TypedDict):
    """Instruction for constructing a prompt"""

    intro: str
    examples: list[tuple[str, str]]
    template: str
    meta_data: dict[str, Any]


class PromptConstructor:
    def __init__(
        self,
        instruction_path: str | Path,
        lm_config: lm_config.LMConfig,
        tokenizer: Tokenizer,
    ):
        self.instruction_path = Path(instruction_path)
        self.obs_modality = "text"
        self.lm_config = lm_config
        with open(self.instruction_path) as f:
            instruction = json.load(f)
        instruction["examples"] = [tuple(e) for e in instruction["examples"]]
        self.instruction: Instruction = instruction
        self.tokenizer = tokenizer

    def get_lm_api_input(
        self, intro: str, examples: list[tuple[str, str]], current: str
    ) -> APIInput:
        """Return the prompt in the format required by the target API."""
        message: list[dict[str, str]] | str
        if "openai" in self.lm_config.provider:
            if self.lm_config.mode == "chat":
                # few-shot examples become named system messages, followed by
                # the current observation as the final user turn
                message = [{"role": "system", "content": intro}]
                for (x, y) in examples:
                    message.append(
                        {
                            "role": "system",
                            "name": "example_user",
                            "content": x,
                        }
                    )
                    message.append(
                        {
                            "role": "system",
                            "name": "example_assistant",
                            "content": y,
                        }
                    )
                message.append({"role": "user", "content": current})
                return message
            elif self.lm_config.mode == "completion":
                # flatten intro, examples, and the current observation into a
                # single completion-style prompt
                message = f"{intro}\n\n"
                message += "Here are a few examples:\n"
                for example in examples:
                    message += f"Observation:\n{example[0]}\n\n"
                    message += f"Action: {example[1]}\n\n"
                message += "Now make prediction given the observation\n\n"
                message += f"Observation:\n{current}\n\n"
                message += "Action:"
                return message
            else:
                raise ValueError(
                    f"OpenAI models do not support mode {self.lm_config.mode}"
                )
        elif "huggingface" in self.lm_config.provider:
            # https://huggingface.co/blog/llama2#how-to-prompt-llama-2
            # https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L320
            if "Llama-2" in self.lm_config.model:
                if self.lm_config.mode == "chat":
                    B_INST, E_INST = "[INST]", "[/INST]"
                    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
                    BOS, EOS = "<s>", "</s>"
                    # fold the system message into the start of the first example
                    examples = [
                        (
                            B_SYS + intro + E_SYS + examples[0][0],
                            examples[0][1],
                        )
                    ] + examples[1:]
                    message = "".join(
                        [
                            f"{BOS}{B_INST} {x.strip()} {E_INST} {y.strip()} {EOS}"
                            for (x, y) in examples
                        ]
                    )
                    # append the current observation as an open instruction turn
                    message += f"{BOS}{B_INST} {current.strip()} {E_INST} {self.instruction['meta_data'].get('force_prefix', '')}"
                    return message
                else:
                    raise ValueError("Only chat mode is supported for Llama-2")
            else:
                raise ValueError(
                    f"Huggingface provider does not support model {self.lm_config.model}"
                )
        else:
            raise NotImplementedError(
                f"Provider {self.lm_config.provider} not implemented"
            )

    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
    ) -> APIInput:
        raise NotImplementedError

    def map_url_to_real(self, url: str) -> str:
        """Map sandboxed urls to their real-world counterparts"""
        for i, j in URL_MAPPINGS.items():
            if i in url:
                url = url.replace(i, j)
        return url

    def map_url_to_local(self, url: str) -> str:
        """Map real-world urls back to their local counterparts"""
        for i, j in URL_MAPPINGS.items():
            if j in url:
                url = url.replace(j, i)
            # also handle the https variant of the real url
            if j.replace("http", "https") in url:
                url = url.replace(j.replace("http", "https"), i)
        return url

    def _extract_action(self, response: str) -> str:
        raise NotImplementedError

    def extract_action(self, response: str) -> str:
        response = self._extract_action(response)
        response = self.map_url_to_local(response)
        return response
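
# For reference, a minimal instruction JSON matching the ``Instruction`` schema
# above might look like the sketch below. All field values are illustrative
# placeholders (not the actual prompt files shipped with the benchmark); the
# meta_data keys are the ones this module reads: "keywords", "action_splitter",
# "answer_phrase", and "force_prefix".
#
# {
#     "intro": "You are an autonomous web agent. ...",
#     "examples": [
#         ["OBSERVATION: ...\nURL: ...\nOBJECTIVE: ...", "```click [1234]```"]
#     ],
#     "template": "OBSERVATION: {observation}\nURL: {url}\nOBJECTIVE: {objective}\nPREVIOUS ACTION: {previous_action}",
#     "meta_data": {
#         "keywords": ["url", "objective", "observation", "previous_action"],
#         "action_splitter": "```",
#         "answer_phrase": "In summary, the next action I will perform is",
#         "force_prefix": ""
#     }
# }
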
class DirectPromptConstructor(PromptConstructor):
    """The agent directly predicts the next action"""

    def __init__(
        self,
        instruction_path: str | Path,
        lm_config: lm_config.LMConfig,
        tokenizer: Tokenizer,
    ):
        super().__init__(instruction_path, lm_config, tokenizer)

    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
    ) -> APIInput:
        """Construct the prompt given the current trajectory"""
        intro = self.instruction["intro"]
        examples = self.instruction["examples"]
        template = self.instruction["template"]
        keywords = self.instruction["meta_data"]["keywords"]
        state_info: StateInfo = trajectory[-1]  # type: ignore[assignment]

        obs = state_info["observation"][self.obs_modality]
        max_obs_length = self.lm_config.gen_config["max_obs_length"]
        if max_obs_length:
            # truncate the observation to at most max_obs_length tokens
            obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length])  # type: ignore[arg-type]

        page = state_info["info"]["page"]
        url = page.url
        previous_action_str = meta_data["action_history"][-1]

        # fill the template to form the current query
        current = template.format(
            objective=intent,
            url=self.map_url_to_real(url),
            observation=obs,
            previous_action=previous_action_str,
        )

        # make sure all template keywords were substituted
        assert all(f"{{{k}}}" not in current for k in keywords)

        prompt = self.get_lm_api_input(intro, examples, current)
        return prompt

    def _extract_action(self, response: str) -> str:
        action_splitter = self.instruction["meta_data"]["action_splitter"]
        pattern = rf"{action_splitter}((.|\n)*?){action_splitter}"
        match = re.search(pattern, response)
        if match:
            return match.group(1).strip()
        else:
            raise ActionParsingError(
                f"Cannot parse action from response {response}"
            )


class CoTPromptConstructor(PromptConstructor):
    """The agent performs step-by-step reasoning before giving the answer"""

    def __init__(
        self,
        instruction_path: str | Path,
        lm_config: lm_config.LMConfig,
        tokenizer: Tokenizer,
    ):
        super().__init__(instruction_path, lm_config, tokenizer)
        self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]

    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
    ) -> APIInput:
        intro = self.instruction["intro"]
        examples = self.instruction["examples"]
        template = self.instruction["template"]
        keywords = self.instruction["meta_data"]["keywords"]
        state_info: StateInfo = trajectory[-1]  # type: ignore[assignment]

        obs = state_info["observation"][self.obs_modality]
        max_obs_length = self.lm_config.gen_config["max_obs_length"]
        if max_obs_length:
            # truncate the observation to at most max_obs_length tokens
            obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length])  # type: ignore[arg-type]

        page = state_info["info"]["page"]
        url = page.url
        previous_action_str = meta_data["action_history"][-1]
        current = template.format(
            objective=intent,
            url=self.map_url_to_real(url),
            observation=obs,
            previous_action=previous_action_str,
        )

        # make sure all template keywords were substituted
        assert all(f"{{{k}}}" not in current for k in keywords)

        prompt = self.get_lm_api_input(intro, examples, current)
        return prompt

    def _extract_action(self, response: str) -> str:
        # find the first occurrence of an action wrapped in the splitter
        action_splitter = self.instruction["meta_data"]["action_splitter"]
        pattern = rf"{action_splitter}((.|\n)*?){action_splitter}"
        match = re.search(pattern, response)
        if match:
            return match.group(1).strip()
        else:
            raise ActionParsingError(
                f'Cannot find the answer phrase "{self.answer_phrase}" in "{response}"'
            )
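
# Minimal usage sketch (illustrative only). It assumes an LMConfig with
# provider "openai" and mode "chat", a matching Tokenizer, and a trajectory
# whose last element is a StateInfo dict; ``call_llm`` stands in for whatever
# helper actually queries the model and is hypothetical here, as is the
# instruction file path.
#
#     constructor = CoTPromptConstructor(
#         "agent/prompts/jsons/p_cot.json",  # path is illustrative
#         config,
#         tokenizer,
#     )
#     prompt = constructor.construct(
#         trajectory,
#         intent="Find the cheapest laptop",
#         meta_data={"action_history": ["None"]},
#     )
#     response = call_llm(config, prompt)  # hypothetical helper
#     action_str = constructor.extract_action(response)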