/
BIRD-critiq0c7e038
import argparse
import json
import re
import sys
def extract_sql_from_response(response_string):
    """
    Extract every SQL code block fenced with ```sql ... ``` from a response.

    The ``sql`` fence tag is matched case-insensitively and a block may span
    several lines. A fence that is opened but never closed is not matched.

    Args:
        response_string: Text to scan. ``None`` is treated as an empty
            response so records without a usable response yield no SQL.

    Returns:
        A list holding the body of each fenced block, in order of
        appearance, with surrounding whitespace stripped. The list is empty
        when nothing matches.
    """
    # A JSON null response arrives here as None; re would raise TypeError on it.
    if response_string is None:
        return []
    sql_pattern = re.compile(
        r"```sql\s*(.*?)```",
        re.IGNORECASE | re.DOTALL,
    )
    # findall returns only the captured group, i.e. the text inside the fence.
    return [stmt.strip() for stmt in sql_pattern.findall(response_string)]
def process_file(input_file, output_file):
    """
    Extract SQL statements from every record of a JSONL file.

    Each input line is parsed as JSON. The SQL blocks found in the record's
    "response" field are stored under the "pred_sqls" key and the updated
    record is written as one line of the output file. One progress message
    per processed record is printed to stdout; skipped lines are reported on
    stderr and do not appear in the output.

    Args:
        input_file: Path of the JSONL file to read (opened as UTF-8).
        output_file: Path of the JSONL file to write (opened as UTF-8 in
            "w" mode, so existing content is replaced).
    """
    with open(input_file, "r", encoding="utf-8") as infile, open(
        output_file, "w", encoding="utf-8"
    ) as outfile:
        for line_number, line in enumerate(infile, 1):
            # Only the parse step can raise JSONDecodeError, so guard just that.
            try:
                data = json.loads(line.strip())
            except json.JSONDecodeError:
                print(
                    f"Skipping invalid JSON line {line_number}: {line.strip()}",
                    file=sys.stderr,
                )
                continue

            # Valid JSON that is not an object (list, number, string, ...) has
            # no "response" field; report it instead of crashing on .get().
            if not isinstance(data, dict):
                print(
                    f"Skipping non-object JSON line {line_number}: {line.strip()}",
                    file=sys.stderr,
                )
                continue

            # The .get() default only covers a missing key; an explicit JSON
            # null still comes back as None and must be normalised.
            response = data.get("response", "")
            if response is None:
                response = ""

            sql_list = extract_sql_from_response(response)
            print(
                f"Extracted {len(sql_list)} SQL statements from line {line_number}"
            )
            data["pred_sqls"] = sql_list
            outfile.write(json.dumps(data, ensure_ascii=False) + "\n")
def main():
    """Parse the command-line options and run the extraction on the given files."""
    cli = argparse.ArgumentParser(
        description="Extract SQL statements from LLM responses."
    )
    # Both options are mandatory string paths; register them in one pass.
    for flag, help_text in (
        ("--input_path", "Path to the input JSONL file."),
        ("--output_path", "Path to the output JSONL file."),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)
    options = cli.parse_args()
    process_file(options.input_path, options.output_path)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()