# db_utils.py
"""
PostgreSQL helpers: per-database connection pools, query execution, and
management of template-based ephemeral database copies.

Connections handed out by ``perform_query_on_postgresql_databases`` (when it
is called with ``conn=None``) and ``get_connection_for_phase`` are NOT put
back into the pool automatically; release them with
``close_postgresql_connection`` or by closing the pool.
"""
import contextlib
import os
import subprocess

import psycopg2
import psycopg2.errors  # noqa: F401  (explicit: execute_queries uses psycopg2.errors.QueryCanceled)
from psycopg2 import OperationalError
from psycopg2.pool import SimpleConnectionPool

from logger import (
    log_section_header,
    log_section_footer,
    PrintLogger,
)

# One SimpleConnectionPool per database name, created lazily by _get_or_init_pool.
_postgresql_pools = {}

# Connection settings shared by the pools and by the psql/createdb/dropdb CLI calls.
# NOTE(review): credentials are hard-coded here; consider reading them from the environment.
DEFAULT_DB_CONFIG = {
    "minconn": 1,
    "maxconn": 5,
    "user": "root",
    "password": "123123",
    "host": "bird_critic_postgresql",
    "port": 5432,
}

# Upper bound on rows returned for queries that start with SELECT / WITH.
MAX_RESULT_ROWS = 10000


def _get_or_init_pool(db_name):
    """
    Return the connection pool for ``db_name``, creating it on first use
    from DEFAULT_DB_CONFIG.
    """
    pool = _postgresql_pools.get(db_name)
    if pool is None:
        config = DEFAULT_DB_CONFIG
        pool = SimpleConnectionPool(
            config["minconn"],
            config["maxconn"],
            dbname=db_name,
            user=config["user"],
            password=config["password"],
            host=config["host"],
            port=config["port"],
        )
        _postgresql_pools[db_name] = pool
    return pool


def _pg_cli_conn_args():
    """Return the ``-h HOST -p PORT -U USER`` arguments for psql/createdb/dropdb."""
    return [
        "-h", DEFAULT_DB_CONFIG["host"],
        "-p", str(DEFAULT_DB_CONFIG["port"]),
        "-U", DEFAULT_DB_CONFIG["user"],
    ]


def _pg_cli_env(pg_password):
    """Return a copy of the current environment with PGPASSWORD set for the CLI tools."""
    env_vars = os.environ.copy()
    env_vars["PGPASSWORD"] = pg_password
    return env_vars


def _run_quiet(command, env_vars, *, check=True, timeout=None):
    """
    Run ``command`` with stdout/stderr discarded.

    Raises ``subprocess.CalledProcessError`` on a non-zero exit when ``check``
    is true, and ``subprocess.TimeoutExpired`` when ``timeout`` is exceeded.
    """
    subprocess.run(
        command,
        check=check,
        env=env_vars,
        timeout=timeout,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )


def _sql_string_literal(value):
    """
    Quote ``value`` as a SQL string literal by doubling embedded single quotes.

    Assumes standard_conforming_strings is on (the PostgreSQL default), so
    backslashes need no escaping.
    """
    return "'" + value.replace("'", "''") + "'"


def perform_query_on_postgresql_databases(query, db_name, conn=None):
    """
    Execute ``query`` on ``db_name`` and return ``(result, conn)``.

    If ``conn`` is None, a connection is borrowed from the database's pool and
    handed back to the caller so it can be reused for follow-up queries. The
    caller must release it with ``close_postgresql_connection``. If the query
    fails, a connection borrowed by this call is put back into the pool before
    the exception is re-raised, because the caller never receives it.

    Each call sets a 60s ``statement_timeout`` and commits immediately after
    the statement runs.

    Returns:
        result: ``None`` if the statement produced no result set. At most
            MAX_RESULT_ROWS rows for queries starting with SELECT/WITH. All
            rows otherwise (e.g. ``INSERT ... RETURNING``).
        conn: the connection that was used.

    Raises:
        Any exception from psycopg2 (e.g. ``psycopg2.errors.QueryCanceled``
        when the timeout fires), after the transaction has been rolled back.
    """
    pool = _get_or_init_pool(db_name)
    borrowed_here = conn is None
    if borrowed_here:
        conn = pool.getconn()

    try:
        with conn.cursor() as cursor:
            cursor.execute("SET statement_timeout = '60s';")  # 60s query timeout
            cursor.execute(query)
            conn.commit()
            if cursor.description is None:
                # No result set: DDL, SET, DML without RETURNING, ...
                result = None
            elif query.strip().lower().startswith(("select", "with")):
                result = cursor.fetchmany(MAX_RESULT_ROWS)
            else:
                result = cursor.fetchall()
    except Exception:
        # Cleanup is best-effort. A failure here (e.g. the server already
        # dropped the connection) must not replace the original exception,
        # because callers use its type to tell timeouts from other errors.
        with contextlib.suppress(Exception):
            conn.rollback()
        if borrowed_here:
            with contextlib.suppress(Exception):
                pool.putconn(conn)
        raise
    return result, conn


def close_postgresql_connection(db_name, conn):
    """
    Release a connection back to the pool of ``db_name`` when you are done with it.

    If that pool no longer exists (e.g. it was removed by
    ``close_postgresql_pool``), the connection is closed directly rather than
    leaked.
    """
    pool = _postgresql_pools.get(db_name)
    if pool is not None:
        pool.putconn(conn)
    elif conn is not None and not conn.closed:
        conn.close()


def close_all_postgresql_pools():
    """
    Close all connections in all pools (e.g. at application shutdown) and
    forget the pools.
    """
    for pool in _postgresql_pools.values():
        pool.closeall()
    _postgresql_pools.clear()


def close_postgresql_pool(db_name):
    """
    Close the pool for ``db_name`` (if any) and remove its reference.
    """
    pool = _postgresql_pools.pop(db_name, None)
    if pool is not None:
        pool.closeall()


def get_connection_for_phase(db_name, logger):
    """
    Borrow a connection from the pool of ``db_name`` for a specific phase.

    The connection is checked with ``SELECT 1``. The caller owns it and must
    release it with ``close_postgresql_connection``.
    """
    logger.info(f"Acquiring dedicated connection for phase on db: {db_name}")
    _, conn = perform_query_on_postgresql_databases("SELECT 1", db_name, conn=None)
    return conn


def reset_and_restore_database(db_name, pg_password, logger):
    """
    Reset ``db_name`` by dropping it and re-creating it from its template.

    The template is ``<base>_template``, where ``<base>`` is ``db_name`` with
    any ``_process_<n>`` suffix removed.

    Steps:
        1) close this module's pool for the database
        2) terminate remaining backend connections to it
        3) dropdb --if-exists
        4) createdb --template <base>_template

    Raises ``subprocess.CalledProcessError`` / ``subprocess.TimeoutExpired``
    if any CLI step fails or exceeds 60s.
    """
    env_vars = _pg_cli_env(pg_password)
    conn_args = _pg_cli_conn_args()

    base_db_name = db_name.split('_process_')[0]
    template_db_name = f"{base_db_name}_template"
    logger.info(f"Resetting database {db_name} using template {template_db_name}")

    # 1) Close the pool
    logger.info(f"Closing connection pool for database {db_name} before resetting.")
    close_postgresql_pool(db_name)

    # 2) Terminate existing connections (db_name is quoted, not spliced in raw).
    terminate_sql = f"""
        SELECT pg_terminate_backend(pid)
        FROM pg_stat_activity
        WHERE datname = {_sql_string_literal(db_name)}
          AND pid <> pg_backend_pid();
    """
    _run_quiet(
        ["psql", *conn_args, "-d", "postgres", "-c", terminate_sql],
        env_vars,
        timeout=60,
    )
    logger.info(f"All connections to database {db_name} have been terminated.")

    # 3) dropdb
    _run_quiet(["dropdb", "--if-exists", *conn_args, db_name], env_vars, timeout=60)
    logger.info(f"Database {db_name} dropped if it existed.")

    # 4) createdb --template=xxx_template
    _run_quiet(
        ["createdb", *conn_args, db_name, "--template", template_db_name],
        env_vars,
        timeout=60,
    )
    logger.info(f"Database {db_name} created from template {template_db_name} successfully.")


def create_ephemeral_db_copies(base_db_names, num_copies, pg_password, logger):
    """
    For each base database in ``base_db_names``, create ``num_copies``
    ephemeral copies named ``<base>_process_<i>`` (i = 1..num_copies) from
    ``<base>_template``.

    An existing copy with the same name is dropped first (best effort).

    Returns:
        dict: ``{base_db: [ephemeral1, ephemeral2, ...], ...}``

    Raises ``subprocess.CalledProcessError`` if a ``createdb`` call fails.
    """
    env_vars = _pg_cli_env(pg_password)
    conn_args = _pg_cli_conn_args()
    ephemeral_db_pool = {}

    for base_db in base_db_names:
        base_template = f"{base_db}_template"
        ephemeral_names = []
        ephemeral_db_pool[base_db] = ephemeral_names

        for i in range(1, num_copies + 1):
            ephemeral_name = f"{base_db}_process_{i}"
            # Pooled connections from this process would block dropdb below.
            close_postgresql_pool(ephemeral_name)
            # If it already exists, drop it first (failures ignored: check=False).
            _run_quiet(
                ["dropdb", "--if-exists", *conn_args, ephemeral_name],
                env_vars,
                check=False,
            )
            logger.info(f"Creating ephemeral db {ephemeral_name} from {base_template}...")
            _run_quiet(
                ["createdb", *conn_args, ephemeral_name, "--template", base_template],
                env_vars,
            )
            ephemeral_names.append(ephemeral_name)

        logger.info(f"For base_db={base_db}, ephemeral db list = {ephemeral_db_pool[base_db]}")

    return ephemeral_db_pool


def drop_ephemeral_dbs(ephemeral_db_pool_dict, pg_password, logger):
    """
    Drop every ephemeral database listed in ``ephemeral_db_pool_dict`` (the
    mapping returned by ``create_ephemeral_db_copies``).

    Each database's pool in this module is closed first, so our own pooled
    connections do not block the drop. A failing ``dropdb`` is logged and the
    remaining databases are still processed.
    """
    env_vars = _pg_cli_env(pg_password)
    conn_args = _pg_cli_conn_args()

    logger.info("=== Cleaning up ephemeral databases ===")
    for ephemeral_list in ephemeral_db_pool_dict.values():
        for ephemeral_db in ephemeral_list:
            logger.info(f"Dropping ephemeral db: {ephemeral_db}")
            close_postgresql_pool(ephemeral_db)
            try:
                _run_quiet(["dropdb", "--if-exists", *conn_args, ephemeral_db], env_vars)
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to drop ephemeral db {ephemeral_db}: {e}")


def execute_queries(queries, db_name, conn, logger=None, section_title=""):
    """
    Execute a list of queries using the SAME connection (``conn``).

    Execution stops at the first error. A single query string is accepted
    and treated as a one-element list. If ``conn`` is None, the connection
    borrowed for the first query is reused for the rest and released to the
    pool before returning, since it is not returned to the caller.

    Returns:
        tuple: ``(query_result, execution_error_flag, timeout_flag)``, where
        ``query_result`` is the result of the last successful query (or None).
    """
    if logger is None:
        logger = PrintLogger()
    if isinstance(queries, str):
        # A bare string would otherwise be iterated character by character.
        queries = [queries]

    log_section_header(section_title, logger)

    borrowed_here = conn is None
    query_result = None
    execution_error = False
    timeout_error = False
    total = len(queries)

    for i, query in enumerate(queries, start=1):
        try:
            logger.info(f"Executing query {i}/{total}: {query}")
            query_result, conn = perform_query_on_postgresql_databases(query, db_name, conn=conn)
            logger.info(f"Query result: {query_result}")
        except psycopg2.errors.QueryCanceled as e:
            # Timeout error (statement_timeout). Must stay ahead of the
            # OperationalError handler, which would otherwise also match it.
            logger.error(f"Timeout error executing query {i}: {e}")
            timeout_error = True
            break
        except OperationalError as e:
            # Operational errors (e.g., server not available, etc.)
            logger.error(f"OperationalError executing query {i}: {e}")
            execution_error = True
            break
        except psycopg2.Error as e:
            # Other psycopg2 errors (e.g., syntax errors, constraint violations)
            logger.error(f"psycopg2 Error executing query {i}: {e}")
            execution_error = True
            break
        except Exception as e:
            # Any other generic error
            logger.error(f"Generic error executing query {i}: {e}")
            execution_error = True
            break
        finally:
            logger.info(f"[{section_title}] DB: {db_name}, conn info: {conn}")

    if borrowed_here and conn is not None:
        close_postgresql_connection(db_name, conn)

    log_section_footer(logger)
    return query_result, execution_error, timeout_error