# mirrored 18 minutes ago
# 0
# BlackSoi1 initial commit 0c7e038
import argparse
import json
import re
import sys


def extract_sql_from_response(response_string):
    """
    Extract all SQL code blocks wrapped with ```sql and ``` from the response string.

    The ``sql`` fence tag is matched case-insensitively and must be a whole
    word, so fences such as ```sqlite or ```sqlx are not treated as SQL
    blocks. A block may span several lines; an unterminated block (no
    closing ```) is ignored.

    Args:
        response_string: Text to scan, e.g. an LLM response. ``None`` is
            treated as an empty response.

    Returns:
        A list of SQL statements, one per code block in order of appearance,
        each stripped of leading/trailing whitespace. An empty block yields
        an empty string.
    """
    if response_string is None:
        # A JSON ``null`` response has no SQL; re.findall(None) would raise.
        return []

    # \b rejects longer tags like "```sqlite"; DOTALL lets the body span
    # lines; the lazy ".*?" stops at the first closing fence.
    sql_pattern = re.compile(
        r"```sql\b\s*(.*?)```",
        re.IGNORECASE | re.DOTALL,
    )
    return [stmt.strip() for stmt in sql_pattern.findall(response_string)]


def process_file(input_file, output_file):
    """
    Process a JSONL file to extract SQL statements from responses.

    Each non-blank line of ``input_file`` is parsed as JSON. For every JSON
    object, the SQL code blocks found in its ``"response"`` field are stored
    under a new ``"pred_sqls"`` key. The updated record is then written to
    ``output_file`` as one JSON line.

    Blank lines are ignored. Lines that are not valid JSON, or that decode
    to something other than a JSON object, are reported on stderr and
    skipped. A missing or ``null`` ``"response"`` yields an empty
    ``"pred_sqls"`` list.

    Args:
        input_file: Path of the JSONL file to read (UTF-8).
        output_file: Path of the JSONL file to write (UTF-8); it is
            truncated/overwritten.
    """
    with open(input_file, "r", encoding="utf-8") as infile, open(
        output_file, "w", encoding="utf-8"
    ) as outfile:
        for line_number, line in enumerate(infile, 1):
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline at EOF) are not records.
                continue

            # Only the parse step can raise JSONDecodeError; keep the try narrow.
            try:
                data = json.loads(stripped)
            except json.JSONDecodeError:
                print(
                    f"Skipping invalid JSON line {line_number}: {stripped}",
                    file=sys.stderr,
                )
                continue

            if not isinstance(data, dict):
                # Valid JSON but not an object (list, string, number...):
                # there is no "response" field to read or key to add.
                print(
                    f"Skipping non-object JSON line {line_number}: {stripped}",
                    file=sys.stderr,
                )
                continue

            # Treat a missing or null "response" as an empty response.
            response = data.get("response") or ""

            # Extract SQL statements
            sql_list = extract_sql_from_response(response)
            print(
                f"Extracted {len(sql_list)} SQL statements from line {line_number}"
            )
            # Add the list to the data
            data["pred_sqls"] = sql_list

            # Write the updated data
            outfile.write(json.dumps(data, ensure_ascii=False) + "\n")


def main():
    """
    Command-line entry point.

    Parses the required ``--input_path`` and ``--output_path`` options and
    hands them to ``process_file``.
    """
    cli = argparse.ArgumentParser(
        description="Extract SQL statements from LLM responses."
    )
    # Both options are mandatory string paths; only the flag and help text differ.
    for flag, help_text in (
        ("--input_path", "Path to the input JSONL file."),
        ("--output_path", "Path to the output JSONL file."),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)

    options = cli.parse_args()
    process_file(options.input_path, options.output_path)


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script,
    # not when this module is imported.
    main()