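"""Generate prompts for BIRD-CRITIC instances.

Reads instances from a local JSONL file (or, optionally, from the
"birdsql/bird-critic-1.0-flash-exp" Hugging Face dataset), builds one prompt
per instance, and writes each instance back out as JSONL with a "prompt"
field attached.

Example invocation (the script and file names below are illustrative
assumptions, not paths guaranteed to exist in the repository):

    python generate_prompt.py \
        --data_path data/bird_critic_dev.jsonl \
        --prompt_path outputs/assistant_prompts.jsonl \
        --prompt_type assistant
"""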
import argparse
import json
import os

from datasets import load_dataset
from tqdm import tqdm

from prompt import assistant_prompt

# Utility functions
def load_jsonl(file_path):
    """Load a JSONL file and return its records as a list of dicts."""
    with open(file_path, "r") as file:
        return [json.loads(line) for line in file]

def create_directory(path):
    """Create the directory at `path` if it does not already exist."""
    if not os.path.exists(path):
        os.makedirs(path)

def write_prompts(prompts, data_list, prompt_path):
    """Attach each prompt to its instance and write the instances as JSONL."""
    # Only create a parent directory if the output path actually has one.
    output_dir = os.path.dirname(prompt_path)
    if output_dir:
        create_directory(output_dir)
    with open(prompt_path, "w") as f:
        for i, instance in enumerate(data_list):
            instance["prompt"] = prompts[i]
            f.write(json.dumps(instance, ensure_ascii=False) + "\n")

def generate_prompts(data_list, prompt_type):
    """Build one prompt per instance; return the prompts and the instances kept."""
    prompt_list = []
    final_data_list = []
    # Use tqdm to show progress while generating prompts
    for data in tqdm(data_list, desc="Generating prompts"):
        if prompt_type == "assistant":
            prompt_list.append(assistant_prompt(data))
            final_data_list.append(data)
        else:
            raise ValueError(f"Invalid prompt type: {prompt_type}")
    return prompt_list, final_data_list

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate prompts for the BIRD-CRITIC evaluation task."
    )
    parser.add_argument("--data_path", type=str, help="Path to the input JSONL data file.")
    parser.add_argument(
        "--prompt_path", type=str, help="Path to save the generated prompts (JSONL)."
    )
    parser.add_argument(
        "--prompt_type",
        type=str,
        help='Type of prompt to generate (currently only "assistant" is supported).',
    )
    args = parser.parse_args()

    # Load the data from the local JSONL file ...
    data_list = load_jsonl(args.data_path)
    # ... or load it from the Hugging Face dataset instead:
    # dataset = load_dataset("birdsql/bird-critic-1.0-flash-exp")
    # data_list = dataset["train"]
    # final_data_list = filter_instances(data_list)
    final_data_list = data_list

    prompt_list, final_data_list = generate_prompts(final_data_list, args.prompt_type)
    # Uncomment to debug on a small subset:
    # prompt_list = prompt_list[:3]
    # final_data_list = final_data_list[:3]

    write_prompts(prompt_list, final_data_list, args.prompt_path)
    print(f"Generated {len(prompt_list)} prompts and wrote them to {args.prompt_path}.")