import os
import sys
import json
import numpy as np
import argparse
import sqlite3
import multiprocessing as mp
from func_timeout import func_timeout, FunctionTimedOut
import time
import math


def result_callback(result):
    # Collect results from the worker processes into the shared list.
    exec_result.append(result)


def clean_abnormal(values):
    # Drop outlier timing ratios outside three standard deviations of the mean.
    values = np.asarray(values)
    mean = np.mean(values)
    std = np.std(values)
    return [x for x in values if mean - 3 * std < x < mean + 3 * std]


def execute_sql(sql, db_path):
    # Time a single execution of the query against the SQLite database.
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    start_time = time.time()
    cursor.execute(sql)
    exec_time = time.time() - start_time
    conn.close()
    return exec_time


def iterated_execute_sql(predicted_sql, ground_truth, db_path, iterate_num):
    # The efficiency ratio is only meaningful when the predicted query returns
    # the same result set as the ground truth; otherwise the ratio stays 0.
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute(predicted_sql)
    predicted_res = cursor.fetchall()
    cursor.execute(ground_truth)
    ground_truth_res = cursor.fetchall()
    conn.close()
    time_ratio = 0
    if set(predicted_res) == set(ground_truth_res):
        diff_list = []
        for _ in range(iterate_num):
            predicted_time = execute_sql(predicted_sql, db_path)
            ground_truth_time = execute_sql(ground_truth, db_path)
            diff_list.append(ground_truth_time / predicted_time)
        processed_diff_list = clean_abnormal(diff_list)
        time_ratio = sum(processed_diff_list) / len(processed_diff_list)
    return time_ratio


def execute_model(predicted_sql, ground_truth, db_place, idx, iterate_num, meta_time_out):
    try:
        # The total timeout can be personalized: a larger timeout yields a more
        # stable VES, but requires more patience.
        time_ratio = func_timeout(meta_time_out * iterate_num, iterated_execute_sql,
                                  args=(predicted_sql, ground_truth, db_place, iterate_num))
    except KeyboardInterrupt:
        sys.exit(0)
    except FunctionTimedOut:
        time_ratio = 0
    except Exception:
        # Possibly len(query) > 512, or the query is not executable.
        time_ratio = 0
    return {'sql_idx': idx, 'time_ratio': time_ratio}


def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev', cot=False):
    clean_sqls = []
    db_path_list = []
    if mode == 'gpt':
        # Predicted SQLs are stored as a JSON dict whose values look like
        # "<sql>\t----- bird -----\t<db_name>".
        if cot:
            sql_file = sql_path + 'predict_' + data_mode + '_cot.json'
        else:
            sql_file = sql_path + 'predict_' + data_mode + '.json'
        with open(sql_file, 'r') as f:
            sql_data = json.load(f)
        for idx, sql_str in sql_data.items():
            if isinstance(sql_str, str):
                sql, db_name = sql_str.split('\t----- bird -----\t')
            else:
                sql, db_name = " ", "financial"
            clean_sqls.append(sql)
            db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
    elif mode == 'gt':
        # Ground-truth SQLs are stored one per line as "<sql>\t<db_name>".
        with open(sql_path + data_mode + '_gold.sql') as sqls:
            for sql_str in sqls.readlines():
                sql, db_name = sql_str.strip().split('\t')
                clean_sqls.append(sql)
                db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
    return clean_sqls, db_path_list


def run_sqls_parallel(sqls, db_places, num_cpus=1, iterate_num=100, meta_time_out=30.0):
    pool = mp.Pool(processes=num_cpus)
    for i, sql_pair in enumerate(sqls):
        predicted_sql, ground_truth = sql_pair
        pool.apply_async(execute_model,
                         args=(predicted_sql, ground_truth, db_places[i], i,
                               iterate_num, meta_time_out),
                         callback=result_callback)
    pool.close()
    pool.join()


def sort_results(list_of_dicts):
    return sorted(list_of_dicts, key=lambda x: x['sql_idx'])


def compute_ves(exec_results):
    # VES is the mean of sqrt(time_ratio) * 100 over all queries; queries whose
    # results did not match (or that timed out or errored) contribute 0.
    num_queries = len(exec_results)
    if num_queries == 0:
        return 0
    total_ratio = 0
    for result in exec_results:
        if result['time_ratio'] != 0:
            total_ratio += math.sqrt(result['time_ratio']) * 100
    return total_ratio / num_queries


def load_json(path):
    with open(path, 'r') as j:
        contents = json.loads(j.read())
    return contents


def compute_ves_by_diff(exec_results, diff_json_path):
    # Bucket the results by question difficulty, then score each bucket.
    num_queries = len(exec_results)
    contents = load_json(diff_json_path)
    simple_results, moderate_results, challenging_results = [], [], []
    for i, content in enumerate(contents):
        if content['difficulty'] == 'simple':
            simple_results.append(exec_results[i])
        if content['difficulty'] == 'moderate':
            moderate_results.append(exec_results[i])
        if content['difficulty'] == 'challenging':
            challenging_results.append(exec_results[i])
    simple_ves = compute_ves(simple_results)
    moderate_ves = compute_ves(moderate_results)
    challenging_ves = compute_ves(challenging_results)
    all_ves = compute_ves(exec_results)
    count_lists = [len(simple_results), len(moderate_results),
                   len(challenging_results), num_queries]
    return simple_ves, moderate_ves, challenging_ves, all_ves, count_lists


def print_data(score_lists, count_lists):
    levels = ['simple', 'moderate', 'challenging', 'total']
    print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
    print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists))
    print('========================================= VES ========================================')
    print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('ves', *score_lists))
    results = {
        "accuracy": {
            "simple": score_lists[0],
            "moderate": score_lists[1],
            "challenging": score_lists[2],
            "total": score_lists[3]
        },
        "counts": {
            "simple": count_lists[0],
            "moderate": count_lists[1],
            "challenging": count_lists[2],
            "total": count_lists[3]
        }
    }
    output_file = "../result/evaluation_result_ves.json"
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, "w") as f:
        json.dump(results, f, indent=4)


if __name__ == '__main__':
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument('--predicted_sql_path', type=str, required=True)
    args_parser.add_argument('--ground_truth_path', type=str, required=True)
    args_parser.add_argument('--data_mode', type=str, default='dev')
    args_parser.add_argument('--db_root_path', type=str, required=True)
    args_parser.add_argument('--num_cpus', type=int, default=4)
    args_parser.add_argument('--meta_time_out', type=float, default=30.0)
    args_parser.add_argument('--mode_gt', type=str, default='gt')
    args_parser.add_argument('--mode_predict', type=str, default='gpt')
    args_parser.add_argument('--diff_json_path', type=str, default='')
    args_parser.add_argument('--cot', action='store_true')
    args = args_parser.parse_args()
    exec_result = []

    pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path,
                                          mode=args.mode_predict, data_mode=args.data_mode,
                                          cot=args.cot)
    # Package the ground-truth SQLs:
    gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path,
                                           mode='gt', data_mode=args.data_mode)

    query_pairs = list(zip(pred_queries, gt_queries))
    run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus,
                      meta_time_out=args.meta_time_out)
    exec_result = sort_results(exec_result)

    print('start calculating VES')
    simple_ves, moderate_ves, challenging_ves, ves, count_lists = \
        compute_ves_by_diff(exec_result, args.diff_json_path)
    score_lists = [simple_ves, moderate_ves, challenging_ves, ves]
    print_data(score_lists, count_lists)
    print('===========================================================================================')
    print("Finished evaluation")
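
# A hypothetical invocation for reference. The script name and data paths
# below are illustrative assumptions, not fixed by this file; adjust them to
# match your local BIRD checkout:
#
#   python evaluation_ves.py \
#       --predicted_sql_path ./exp_result/ \
#       --ground_truth_path ./data/ \
#       --db_root_path ./data/dev_databases/ \
#       --diff_json_path ./data/dev.json \
#       --data_mode dev --num_cpus 4 --meta_time_out 30.0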