# evaluation.py
import argparse
import contextlib
import sys
import os
import io
import multiprocessing
import threading
import queue
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime

from tqdm import tqdm as tqdm_progress

# Local imports
from logger import (
    configure_logger,
    NullLogger,
)
from utils import load_jsonl, split_field, save_report_and_status
from db_utils import (
    perform_query_on_postgresql_databases,
    close_postgresql_connection,
    execute_queries,
    close_all_postgresql_pools,
    get_connection_for_phase,
    reset_and_restore_database,
    create_ephemeral_db_copies,
    drop_ephemeral_dbs,
)
from test_utils import (
    check_sql_function_usage,
    remove_distinct,
    preprocess_results,
    ex_base,
    performance_compare_by_qep,
)

# Only needed for the optional Hugging Face loading path sketched in main().
from datasets import load_dataset

# Password handed to the ephemeral-database helpers (create / reset / drop).
DEFAULT_PG_PASSWORD = "123123"
# Seconds a single test case may run before its process is terminated.
TEST_CASE_TIMEOUT_SECONDS = 60
# Seconds to wait for a free ephemeral database before giving up on an instance.
EPHEMERAL_DB_ACQUIRE_TIMEOUT_SECONDS = 60

# Global counters
number_of_execution_errors = 0
number_of_timeouts = 0
number_of_assertion_errors = 0
total_passed_instances = 0
number_error_unexpected_pass = 0

question_test_case_results = []


def run_test_case(
    test_code, result, logger, idx, return_dict, conn, pred_sqls, sol_sqls, db_name
):
    """
    In a separate Process, runs the test_code with the given environment
    and captures pass/fail status.

    The test code must define ``test_case(pred_sqls, sol_sqls, db_name, conn)``;
    it is called once after the code has been executed.

    Args:
        test_code: Python source of the test case.
        result: Result of the predicted SQL, exposed to the test code as
            ``pred_query_result``.
        logger: Logger used for progress and error messages.
        idx: Key under which the outcome is stored in ``return_dict``.
        return_dict: Mapping that receives ``"passed"`` or ``"failed"``.
        conn: Database connection passed through to ``test_case``.
        pred_sqls: Value passed to ``test_case`` as ``pred_sqls``.
        sol_sqls: Value passed to ``test_case`` as ``sol_sqls``.
        db_name: Database name passed through to ``test_case``.
    """
    # One single namespace on purpose: when exec() receives separate globals and
    # locals, the code runs like a class body, so names bound at the top level
    # of the test code (the injected ``date`` import, helper functions) would
    # NOT be visible from inside ``test_case`` and would raise NameError.
    env = {
        "perform_query_on_postgresql_databases": perform_query_on_postgresql_databases,
        "execute_queries": execute_queries,
        "ex_base": ex_base,
        "performance_compare_by_qep": performance_compare_by_qep,
        "check_sql_function_usage": check_sql_function_usage,
        "remove_distinct": remove_distinct,
        "preprocess_results": preprocess_results,
        "pred_query_result": result,
        "conn": conn,
        "pred_sqls": pred_sqls,
        "sol_sqls": sol_sqls,
        "db_name": db_name,
    }

    logger.info(f"Passing result is {result}")

    test_case_code = "from datetime import date\n" + test_code
    test_case_code += (
        "\n__test_case_result__ = test_case(pred_sqls, sol_sqls, db_name, conn)"
    )

    logger.info(f"Test case content:\n{test_case_code}")
    logger.info(f"Executing test case {idx}")

    captured = io.StringIO()
    try:
        # SECURITY: exec() runs the test code with full interpreter privileges.
        # Only feed this function test cases from a trusted dataset.
        with contextlib.redirect_stdout(captured):
            exec(test_case_code, env)
        logger.info(f"Test case {idx} passed.")
        return_dict[idx] = "passed"
    except AssertionError as e:
        logger.error(f"Test case {idx} failed due to assertion error: {e}")
        return_dict[idx] = "failed"
    except Exception as e:
        logger.error(f"Test case {idx} failed due to error: {e}")
        return_dict[idx] = "failed"
    finally:
        # Surface anything the test code printed, whatever the outcome.
        captured_output = captured.getvalue()
        if captured_output.strip():
            logger.info(f"Captured output from test_code:\n{captured_output}")


def execute_test_cases(
    test_cases,
    sql_result,
    logger,
    conn,
    error_sql,
    sol_sql,
    db_name,
    timeout=TEST_CASE_TIMEOUT_SECONDS,
):
    """
    Spawns each test case in a separate Process, one after another.

    Args:
        test_cases: Sequence of test-case source strings.
        sql_result: Result of the evaluated SQL, forwarded to run_test_case.
        logger: Logger used for progress and error messages.
        conn: Database connection forwarded to run_test_case.
        error_sql: Forwarded to run_test_case as its ``pred_sqls`` argument.
        sol_sql: Forwarded to run_test_case as its ``sol_sqls`` argument.
        db_name: Database name forwarded to run_test_case.
        timeout: Seconds to wait for each test case before terminating it.

    Returns (passed_count, failed_tests), where failed_tests holds
    ``"test_<n>"`` labels (1-based) for every test that did not pass,
    including timed-out ones.
    """
    total = len(test_cases)
    if total == 0:
        return 0, []

    # The manager owns a helper process; the ``with`` block shuts it down so
    # it is not leaked on every call.
    with multiprocessing.Manager() as manager:
        return_dict = manager.dict()

        for i, test_case in enumerate(test_cases, start=1):
            logger.info(f"Starting test case {i}/{total}")

            p = multiprocessing.Process(
                target=run_test_case,
                args=(
                    test_case,
                    sql_result,
                    logger,
                    i,
                    return_dict,
                    conn,
                    error_sql,
                    sol_sql,
                    db_name,
                ),
            )
            p.start()
            p.join(timeout=timeout)

            if p.is_alive():
                logger.error(f"Test case {i} execution timed out.")
                p.terminate()
                p.join()
                return_dict[i] = "timeout"

        # Copy into a plain dict while the manager is still running.
        statuses = dict(return_dict.items())

    passed_count = 0
    failed_tests = []
    for idx in range(1, total + 1):
        # A missing entry means the child died without reporting: count as failed.
        if statuses.get(idx, "failed") == "passed":
            passed_count += 1
        else:
            failed_tests.append(f"test_{idx}")

    return passed_count, failed_tests


def run_preprocessing(preprocess_sql, db_name, logger, conn):
    """
    Execute any pre-processing SQL statements.

    Does nothing when ``preprocess_sql`` is empty.
    """
    if preprocess_sql:
        execute_queries(
            preprocess_sql, db_name, conn, logger, section_title="Preprocess SQL"
        )


def run_evaluation_phase(
    pred_sqls, sol_sqls, error_sqls, db_name, test_cases, logger, conn, efficiency
):
    """
    1. Execute 'pred_sql'
    2. If no error, run test cases.

    When ``efficiency`` is truthy the test cases receive ``error_sqls`` as
    their ``pred_sqls`` argument and ``pred_sqls`` as their ``sol_sqls``
    argument; otherwise they receive ``pred_sqls`` and ``sol_sqls`` as-is.

    Returns tuple of flags + (passed_count, failed_tests):
        (execution_error, timeout_error, assertion_error,
         passed_count, failed_tests)
    """
    pred_sql_result, exec_error_flag, timeout_flag = execute_queries(
        pred_sqls, db_name, conn, logger, section_title="LLM Generated SQL"
    )

    instance_execution_error = exec_error_flag
    instance_timeout_error = timeout_flag
    instance_assertion_error = False
    passed_count = 0
    failed_tests = []

    if not instance_execution_error and not instance_timeout_error and test_cases:
        if efficiency:
            # pass error_sql as "pred_sqls" (slow), pred_sql as "sol_sqls" (fast)
            case_pred_sqls, case_sol_sqls = error_sqls, pred_sqls
        else:
            case_pred_sqls, case_sol_sqls = pred_sqls, sol_sqls

        passed_count, failed_tests = execute_test_cases(
            test_cases,
            pred_sql_result,
            logger,
            conn,
            case_pred_sqls,
            case_sol_sqls,
            db_name,
        )

        if failed_tests:
            instance_assertion_error = True

    return (
        instance_execution_error,
        instance_timeout_error,
        instance_assertion_error,
        passed_count,
        failed_tests,
    )


def _instance_result(
    instance_id,
    status,
    error_message,
    total_test_cases,
    passed_test_cases=0,
    failed_test_cases=None,
    execution_error=False,
    timeout_error=False,
    assertion_error=False,
):
    """Build the per-instance result record returned by process_one_instance."""
    return {
        "instance_id": instance_id,
        "status": status,
        "error_message": error_message,
        "total_test_cases": total_test_cases,
        "passed_test_cases": passed_test_cases,
        "failed_test_cases": [] if failed_test_cases is None else failed_test_cases,
        "evaluation_phase_execution_error": execution_error,
        "evaluation_phase_timeout_error": timeout_error,
        "evaluation_phase_assertion_error": assertion_error,
    }


def process_one_instance(data_item, ephemeral_db_queues, args, global_stats_lock):
    """
    Orchestrate the entire logic for a single instance:
      - Acquire ephemeral DB
      - Evaluation Phase
      - Cleanup
      - Update global counters

    Args:
        data_item: Dict describing one instance; must contain "instance_id".
        ephemeral_db_queues: Mapping of base database name to a queue of
            ephemeral database names.
        args: Parsed CLI arguments; ``jsonl_file`` and ``logging`` are read.
        global_stats_lock: Lock guarding the module-level counters.

    Returns:
        Dict with the instance id, "status" ("success" or "failed"),
        "error_message", test-case counts and the three phase error flags.
    """
    global number_of_execution_errors, number_of_timeouts
    global number_of_assertion_errors
    global total_passed_instances, number_error_unexpected_pass

    instance_id = data_item["instance_id"]
    log_filename = os.path.splitext(args.jsonl_file)[0] + f"_instance_{instance_id}.log"
    if args.logging == "true":
        logger = configure_logger(log_filename)
    else:
        logger = NullLogger()

    required_fields = [
        "selected_database",
        "preprocess_sql",
        "error_sql",
        "sol_sql",
        "pred_sqls",
    ]
    missing_fields = [field for field in required_fields if field not in data_item]
    if missing_fields:
        logger.error(f"Missing required fields: {', '.join(missing_fields)}")
        with global_stats_lock:
            number_of_execution_errors += 1
        return _instance_result(
            instance_id,
            "failed",
            f"Missing fields: {', '.join(missing_fields)}",
            len(data_item.get("test_cases", [])),
        )

    efficiency = data_item.get("efficiency", False)
    db_name = data_item["selected_database"]
    preprocess_sql = split_field(data_item, "preprocess_sql")
    error_sqls = split_field(data_item, "error_sql")
    pred_sqls = split_field(data_item, "pred_sqls")
    sol_sqls = split_field(data_item, "sol_sql")
    clean_up_sql = split_field(data_item, "clean_up_sql")
    test_cases = data_item.get("test_cases", [])

    evaluation_phase_execution_error = False
    evaluation_phase_timeout_error = False
    evaluation_phase_assertion_error = False

    total_test_cases = len(test_cases)
    passed_test_cases_count = 0
    failed_test_cases = []
    error_message_text = ""

    # Acquire ephemeral db
    try:
        ephemeral_db = ephemeral_db_queues[db_name].get(
            timeout=EPHEMERAL_DB_ACQUIRE_TIMEOUT_SECONDS
        )
    except queue.Empty:
        logger.error(f"No available ephemeral databases for base_db: {db_name}")
        with global_stats_lock:
            print("run here")
            number_of_execution_errors += 1
        return _instance_result(
            instance_id,
            "failed",
            "No available ephemeral databases.",
            total_test_cases,
            execution_error=True,
        )

    logger.info(f"Instance {instance_id} is using ephemeral db: {ephemeral_db}")

    # Becomes True once run_evaluation_phase has returned its verdict; an
    # exception raised before that point means the instance was never judged.
    evaluation_completed = False

    try:
        # ---------- Evaluation Phase ----------
        logger.info("=== Starting Evaluation Phase ===")
        evaluation_conn = get_connection_for_phase(ephemeral_db, logger)
        try:
            run_preprocessing(preprocess_sql, ephemeral_db, logger, evaluation_conn)
            (
                evaluation_phase_execution_error,
                evaluation_phase_timeout_error,
                evaluation_phase_assertion_error,
                passed_count,
                failed_tests,
            ) = run_evaluation_phase(
                pred_sqls,
                sol_sqls,
                error_sqls,
                ephemeral_db,
                test_cases,
                logger,
                evaluation_conn,
                efficiency,
            )
            evaluation_completed = True
        finally:
            # Release the connection even when preprocessing/evaluation raised.
            close_postgresql_connection(ephemeral_db, evaluation_conn)

        passed_test_cases_count += passed_count
        failed_test_cases.extend(failed_tests)

        # Cleanup SQL
        if clean_up_sql:
            logger.info("Executing Clean Up SQL after solution phase.")
            new_temp_conn = get_connection_for_phase(ephemeral_db, logger)
            try:
                execute_queries(
                    clean_up_sql,
                    ephemeral_db,
                    new_temp_conn,
                    logger,
                    section_title="Clean Up SQL",
                )
            finally:
                close_postgresql_connection(ephemeral_db, new_temp_conn)

        reset_and_restore_database(ephemeral_db, DEFAULT_PG_PASSWORD, logger)
        logger.info("=== Evaluation Phase Completed ===")

    except Exception as e:
        print(f"RUN HERE instance {instance_id} with ERROR {e}")
        logger.error(f"Error during execution for question {instance_id}: {e}")
        error_message_text += str(e)
        if not evaluation_completed:
            # Without this the instance would keep all flags False and be
            # counted (and reported) as a success although it never ran.
            evaluation_phase_execution_error = True

    finally:
        # Return the ephemeral database to the queue
        ephemeral_db_queues[db_name].put(ephemeral_db)
        logger.info(
            f"Instance {instance_id} finished. Returned ephemeral db: {ephemeral_db}"
        )

    instance_failed = (
        evaluation_phase_execution_error
        or evaluation_phase_timeout_error
        or evaluation_phase_assertion_error
    )

    # ---------- Update Global Stats ----------
    with global_stats_lock:
        if evaluation_phase_execution_error:
            number_of_execution_errors += 1
        if evaluation_phase_timeout_error:
            number_of_timeouts += 1
        if evaluation_phase_assertion_error:
            number_of_assertion_errors += 1
        if not instance_failed:
            total_passed_instances += 1

    # ---------- Determine status ----------
    return _instance_result(
        instance_id,
        "failed" if instance_failed else "success",
        error_message_text if error_message_text else None,
        total_test_cases,
        passed_test_cases=passed_test_cases_count,
        failed_test_cases=failed_test_cases,
        execution_error=evaluation_phase_execution_error,
        timeout_error=evaluation_phase_timeout_error,
        assertion_error=evaluation_phase_assertion_error,
    )


def main():
    """
    Command-line entry point.

    Loads the instances from ``--jsonl_file``, creates ephemeral database
    copies, evaluates every instance on a thread pool, writes the report
    (and, with ``--logging true``, a JSONL file with per-instance status),
    then closes the connection pools and drops the ephemeral databases.
    """
    global number_of_execution_errors, number_of_timeouts
    global number_of_assertion_errors
    global total_passed_instances, number_error_unexpected_pass
    global question_test_case_results

    parser = argparse.ArgumentParser(
        description="Execute SQL solution and test cases (PostgreSQL)."
    )
    parser.add_argument(
        "--jsonl_file",
        required=True,
        help="Path to the JSONL file containing the dataset instances.",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Limit the number of instances to process.",
    )
    parser.add_argument(
        "--num_threads", type=int, default=2, help="Number of parallel threads to use."
    )
    parser.add_argument(
        "--logging",
        type=str,
        default="false",
        help="Enable or disable per-instance logging ('true' or 'false').",
    )
    args = parser.parse_args()

    data_list = load_jsonl(args.jsonl_file)
    # or to load the data from the Hugging Face dataset
    # dataset = load_dataset("birdsql/bird-critic-1.0-flash-exp")
    # data_list = dataset["flash"]
    if not data_list:
        print("No data found in the JSONL file.")
        sys.exit(1)

    if args.limit is not None:
        data_list = data_list[: args.limit]

    # Collect base DB names
    all_db_names = {
        d["selected_database"] for d in data_list if "selected_database" in d
    }

    # summary logger
    base_output_folder = os.path.splitext(args.jsonl_file)[0]
    ephemeral_db_log_filename = f"{base_output_folder}_multi_thread.log"
    ephemeral_db_logger = configure_logger(ephemeral_db_log_filename)
    ephemeral_db_logger.info(
        f"=== Starting Multi-Thread Evaluation with {args.num_threads} threads ==="
    )

    # Create ephemeral DB copies
    ephemeral_db_pool_dict = create_ephemeral_db_copies(
        base_db_names=all_db_names,
        num_copies=args.num_threads,
        pg_password=DEFAULT_PG_PASSWORD,
        logger=ephemeral_db_logger,
    )

    try:
        # Initialize queues
        ephemeral_db_queues = {}
        for base_db, ephemeral_list in ephemeral_db_pool_dict.items():
            q = queue.Queue()
            for ep_db in ephemeral_list:
                q.put(ep_db)
            ephemeral_db_queues[base_db] = q

        global_stats_lock = threading.Lock()

        total_instances = len(data_list)
        # Results are stored at the index of the instance they belong to.
        # as_completed() yields futures in completion order, so appending would
        # misalign results with data_list, which is paired by position below.
        results = [None] * total_instances

        with ThreadPoolExecutor(
            max_workers=args.num_threads
        ) as executor, tqdm_progress(
            total=total_instances, desc="Evaluating Questions"
        ) as pbar:
            future_to_index = {
                executor.submit(
                    process_one_instance,
                    data_item,
                    ephemeral_db_queues,
                    args,
                    global_stats_lock,
                ): index
                for index, data_item in enumerate(data_list)
            }

            for fut in as_completed(future_to_index):
                results[future_to_index[fut]] = fut.result()
                pbar.update(1)

        question_test_case_results = results[:]

        # Summarize results
        total_errors = (
            number_of_execution_errors
            + number_of_timeouts
            + number_of_assertion_errors
        )
        overall_accuracy = (
            ((total_instances - total_errors) / total_instances * 100)
            if total_instances > 0
            else 0.0
        )
        timestamp = datetime.now().isoformat(sep=" ", timespec="microseconds")
        report_file_path = f"{base_output_folder}_report.txt"

        # Generate the report + update data_list
        save_report_and_status(
            report_file_path,
            question_test_case_results,
            data_list,
            number_of_execution_errors,
            number_of_timeouts,
            number_of_assertion_errors,
            overall_accuracy,
            timestamp,
            ephemeral_db_logger,
        )
        print("Overall report generated:", report_file_path)

        # If logging enabled, output JSONL with status
        if args.logging == "true":
            output_jsonl_file = f"{base_output_folder}_output_with_status.jsonl"
            with open(output_jsonl_file, "w", encoding="utf-8") as f:
                for data, instance_result in zip(data_list, question_test_case_results):
                    data["status"] = instance_result["status"]
                    data["error_message"] = instance_result["error_message"]
                    f.write(json.dumps(data) + "\n")
    finally:
        # Close all pools, drop ephemeral DBs -- also when the run aborted, so
        # the ephemeral databases are not left behind.
        try:
            close_all_postgresql_pools()
        except Exception as e:
            print(f"Failed to close all PostgreSQL pools: {e}")

        drop_ephemeral_dbs(
            ephemeral_db_pool_dict, DEFAULT_PG_PASSWORD, ephemeral_db_logger
        )
        ephemeral_db_logger.info("All ephemeral databases have been dropped.")


if __name__ == "__main__":
    main()