# mirrored 17 minutes ago
# 0
# BlackSoi1 initial commit 0c7e038
import json
import os
import argparse
from tqdm import tqdm
from prompt import assistant_prompt
from datasets import load_dataset


# Utility functions
def load_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        file_path: Path to a file containing one JSON document per line.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly: write_prompts emits non-ASCII characters
    # verbatim (ensure_ascii=False), so relying on the platform's default
    # encoding can fail or mis-decode on non-UTF-8 locales.
    with open(file_path, "r", encoding="utf-8") as file:
        # Skip blank / whitespace-only lines (e.g. a trailing empty line),
        # which json.loads would otherwise reject.
        return [json.loads(line) for line in file if line.strip()]


def create_directory(path):
    """Create the directory ``path`` (including parents) if it is missing.

    Args:
        path: Directory path to create. An empty string is treated as
            "the current directory" and is a no-op.

    Returns:
        None.
    """
    # os.path.dirname("file.jsonl") yields "", and os.makedirs("") raises
    # FileNotFoundError, so an empty path must be skipped explicitly.
    if not path:
        return
    if not os.path.exists(path):
        # exist_ok=True tolerates another process creating the directory
        # between the existence check above and this call.
        os.makedirs(path, exist_ok=True)


def write_prompts(prompts, data_list, prompt_path):
    """Attach each prompt to its instance and write the result as JSONL.

    Each instance in ``data_list`` is mutated in place: its ``"prompt"`` key
    is set to the prompt at the same index, and the instance is then written
    as one JSON document per line to ``prompt_path``.

    Args:
        prompts: Indexable collection of prompts, aligned with ``data_list``.
        data_list: Iterable of dict-like instances to annotate and write.
        prompt_path: Destination file path; its parent directory is created
            when it does not exist.

    Raises:
        IndexError: If ``prompts`` has fewer entries than ``data_list``.
        OSError: If the output file cannot be opened or written.
    """
    output_dir = os.path.dirname(prompt_path)
    # A bare filename has an empty dirname; there is nothing to create then,
    # and passing "" down to os.makedirs would raise FileNotFoundError.
    if output_dir:
        create_directory(output_dir)
    # Write UTF-8 explicitly: ensure_ascii=False emits non-ASCII characters
    # verbatim, which fails with UnicodeEncodeError under a non-UTF-8
    # platform default encoding.
    with open(prompt_path, "w", encoding="utf-8") as f:
        for i, instance in enumerate(data_list):
            instance["prompt"] = prompts[i]
            f.write(json.dumps(instance, ensure_ascii=False) + "\n")


def generate_prompts(data_list, prompt_type):
    """Build one prompt per instance using the requested prompt type.

    Args:
        data_list: Iterable of instances to generate prompts for.
        prompt_type: Name of the prompt template; only ``"assistant"`` is
            supported, which delegates to ``assistant_prompt``.

    Returns:
        A tuple ``(prompt_list, final_data_list)`` where ``prompt_list`` holds
        the generated prompts and ``final_data_list`` holds the corresponding
        instances, both in input order.

    Raises:
        ValueError: If ``prompt_type`` is not a supported prompt type.
    """
    # Validate once, up front: the check does not depend on the instance, and
    # doing it inside the loop would let an invalid type pass silently
    # whenever data_list is empty.
    if prompt_type != "assistant":
        raise ValueError(f"Invalid prompt type: {prompt_type}")

    prompt_list = []
    final_data_list = []

    # Use tqdm to show progress while generating prompts
    for data in tqdm(data_list, desc="Generating prompts"):
        prompt_list.append(assistant_prompt(data))
        final_data_list.append(data)
    return prompt_list, final_data_list


# Command-line entry point: read instances from a JSONL file, generate a
# prompt for each one, and write the annotated instances back out as JSONL.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate prompts for the SO-Evaluation task."
    )
    # All options are required: without them the script would otherwise fail
    # later with an opaque error (e.g. open(None)) instead of a usage message.
    parser.add_argument(
        "--data_path", type=str, required=True, help="Path to the data file."
    )
    parser.add_argument(
        "--prompt_path",
        type=str,
        required=True,
        help="Path to save the generated prompts.",
    )
    parser.add_argument(
        "--prompt_type",
        type=str,
        required=True,
        help="Type of prompt to generate.",
    )
    args = parser.parse_args()

    # Load the data from the JSONL file
    data_list = load_jsonl(args.data_path)

    # Alternatively, load the data from the Hugging Face dataset:
    # dataset = load_dataset("birdsql/bird-critic-1.0-flash-exp")
    # data_list = dataset["train"]

    # Optional pre-filtering step (filter_instances is not defined in this
    # file; all loaded instances are used as-is):
    # data_list = filter_instances(data_list)
    prompt_list, final_data_list = generate_prompts(data_list, args.prompt_type)

    # To try the pipeline on a small subset, truncate both lists together:
    # prompt_list = prompt_list[:3]
    # final_data_list = final_data_list[:3]
    write_prompts(prompt_list, final_data_list, args.prompt_path)
    print(f"Generated {len(prompt_list)} prompts.")
    print("Prompts generated successfully.")