mirrored 2 minutes ago
0
BlackSoi1update evaluation script and readme 1c6ba4c
# db_utils.py
import os
import subprocess
import psycopg2
from psycopg2 import OperationalError
from psycopg2.pool import SimpleConnectionPool
from logger import (
    log_section_header,
    log_section_footer,
    PrintLogger
)


# Cache of connection pools, keyed by database name (filled lazily by
# _get_or_init_pool).
_postgresql_pools = dict()

# Pool sizing ("minconn"/"maxconn") plus the connection parameters that are
# copied into every pool created by this module.
DEFAULT_DB_CONFIG = dict(
    minconn=1,
    maxconn=5,
    user="root",
    password="123123",
    host="bird_critic_postgresql",
    port=5432,
)

def _get_or_init_pool(db_name):
    """
    Look up the connection pool cached for ``db_name``.

    On a cache miss a new SimpleConnectionPool is built from
    DEFAULT_DB_CONFIG (with "dbname" set to ``db_name``), stored in
    _postgresql_pools, and returned.
    """
    try:
        return _postgresql_pools[db_name]
    except KeyError:
        pass  # first request for this database: build the pool below

    settings = {**DEFAULT_DB_CONFIG, "dbname": db_name}
    new_pool = SimpleConnectionPool(
        settings["minconn"],
        settings["maxconn"],
        dbname=settings["dbname"],
        user=settings["user"],
        password=settings["password"],
        host=settings["host"],
        port=settings["port"],
    )
    _postgresql_pools[db_name] = new_pool
    return new_pool

def perform_query_on_postgresql_databases(query, db_name, conn=None):
    """
    Executes the given query on the specified database, returns (result, conn).

    A 60s statement_timeout is set on the connection before the query runs,
    and the transaction is committed after every statement that executes
    successfully (not only after writes).

    Args:
        query: SQL text to execute.
        db_name: Database name; selects (or lazily creates) the pool to use.
        conn: Optional connection to reuse. When None, a connection is
            borrowed from the pool for ``db_name``.

    Returns:
        A ``(result, conn)`` tuple. For queries starting with "select" or
        "with", ``result`` holds at most MAX_ROWS rows. For any other
        statement it is whatever ``fetchall()`` yields, or None when the
        statement produced no result set. ``conn`` is the connection that was
        used; on success it stays checked out so the caller can keep using it
        and must eventually release it with close_postgresql_connection().

    Raises:
        Exception: whatever the driver raised while preparing or executing
            the query, re-raised after the transaction has been rolled back.
            If the connection was borrowed inside this call it is returned to
            the pool first, because the caller never receives it.
    """
    MAX_ROWS = 10000
    pool = _get_or_init_pool(db_name)
    borrowed_here = conn is None
    if borrowed_here:
        conn = pool.getconn()

    try:
        cursor = conn.cursor()
        try:
            cursor.execute("SET statement_timeout = '60s';")  # 60s query timeout
            cursor.execute(query)
            conn.commit()

            lower_q = query.strip().lower()
            if lower_q.startswith(("select", "with")):
                # Cap the number of rows handed back to the caller.
                result = cursor.fetchmany(MAX_ROWS)
            else:
                try:
                    result = cursor.fetchall()
                except psycopg2.ProgrammingError:
                    # The statement produced no result set (e.g. plain DDL/DML).
                    result = None

            return (result, conn)
        finally:
            cursor.close()

    except Exception:
        try:
            conn.rollback()
        finally:
            # The caller only learns about a borrowed connection through the
            # return value; on failure it would otherwise stay checked out
            # forever and eventually exhaust the pool.
            if borrowed_here:
                pool.putconn(conn)
        raise

def close_postgresql_connection(db_name, conn):
    """
    Hand ``conn`` back to the pool registered for ``db_name``.

    Does nothing when no pool is registered under that name.
    """
    try:
        owning_pool = _postgresql_pools[db_name]
    except KeyError:
        return
    owning_pool.putconn(conn)

def close_all_postgresql_pools():
    """
    Shut down every registered pool (closing all of its connections) and
    then empty the pool registry, e.g. at application shutdown.
    """
    for registered_pool in _postgresql_pools.values():
        registered_pool.closeall()

    # Forget the now-closed pools so a later request builds fresh ones.
    _postgresql_pools.clear()

def close_postgresql_pool(db_name):
    """
    Unregister the pool for ``db_name`` and close all of its connections.

    Does nothing when no pool is registered under that name.
    """
    try:
        retired_pool = _postgresql_pools.pop(db_name)
    except KeyError:
        return
    retired_pool.closeall()

def get_connection_for_phase(db_name, logger):
    """
    Borrow a connection for ``db_name`` to be used throughout one phase.

    The connection is obtained by running a trivial "SELECT 1" through
    perform_query_on_postgresql_databases() without supplying a connection;
    the query result is discarded and only the connection is returned.
    """
    logger.info(f"Acquiring dedicated connection for phase on db: {db_name}")
    _probe_rows, phase_conn = perform_query_on_postgresql_databases(
        "SELECT 1", db_name, conn=None
    )
    return phase_conn

def reset_and_restore_database(db_name, pg_password, logger):
    """
    Drop ``db_name`` and rebuild it from its template database.

    The template name is derived by cutting ``db_name`` at the first
    "_process_" and appending "_template". Steps, in order:
      1) close this module's connection pool for ``db_name``
      2) terminate every other backend connected to ``db_name`` (via psql)
      3) dropdb --if-exists
      4) createdb --template <template>

    ``pg_password`` is handed to the command-line tools through the
    PGPASSWORD environment variable. Every command runs with check=True and
    a 60 second timeout, so subprocess.CalledProcessError or
    subprocess.TimeoutExpired propagates to the caller on failure.
    """
    pg_host = "bird_critic_postgresql"
    pg_port = 5432
    pg_user = "root"
    server_args = ["-h", pg_host, "-p", str(pg_port), "-U", pg_user]

    child_env = os.environ.copy()
    child_env["PGPASSWORD"] = pg_password

    # Options shared by all three command invocations below.
    run_opts = dict(
        check=True,
        env=child_env,
        timeout=60,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )

    template_db_name = db_name.split('_process_')[0] + "_template"
    logger.info(f"Resetting database {db_name} using template {template_db_name}")

    # Step 1: release our own pooled connections first.
    logger.info(f"Closing connection pool for database {db_name} before resetting.")
    close_postgresql_pool(db_name)

    # Step 2: kick out any remaining sessions so the drop can succeed.
    terminate_sql = f"""
        SELECT pg_terminate_backend(pid)
        FROM pg_stat_activity
        WHERE datname = '{db_name}' AND pid <> pg_backend_pid();
        """
    subprocess.run(
        ["psql", *server_args, "-d", "postgres", "-c", terminate_sql],
        **run_opts,
    )
    logger.info(f"All connections to database {db_name} have been terminated.")

    # Step 3: drop the database (no error if it is already gone).
    subprocess.run(["dropdb", "--if-exists", *server_args, db_name], **run_opts)
    logger.info(f"Database {db_name} dropped if it existed.")

    # Step 4: recreate it from the template.
    subprocess.run(
        ["createdb", *server_args, db_name, "--template", template_db_name],
        **run_opts,
    )
    logger.info(f"Database {db_name} created from template {template_db_name} successfully.")

def create_ephemeral_db_copies(base_db_names, num_copies, pg_password, logger):
    """
    Create ``num_copies`` throwaway databases for every name in
    ``base_db_names``, each cloned from "<base>_template" and named
    "<base>_process_<i>" with i counting from 1.

    An existing database with the same name is dropped first (best effort:
    a failing dropdb is ignored). A failing createdb raises
    subprocess.CalledProcessError. ``pg_password`` is passed to the tools
    through the PGPASSWORD environment variable.

    Returns a dict: {base_db: [ephemeral1, ephemeral2, ...], ...}
    """
    pg_host = "bird_critic_postgresql"
    pg_port = 5432
    pg_user = "root"
    server_args = ["-h", pg_host, "-p", str(pg_port), "-U", pg_user]

    child_env = os.environ.copy()
    child_env["PGPASSWORD"] = pg_password
    silenced = dict(stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    ephemeral_db_pool = {}

    for base_db in base_db_names:
        base_template = f"{base_db}_template"
        copies = []
        ephemeral_db_pool[base_db] = copies

        for copy_no in range(1, num_copies + 1):
            ephemeral_name = f"{base_db}_process_{copy_no}"

            # Clear out any leftover database of the same name; failure here
            # is tolerated (check=False).
            subprocess.run(
                ["dropdb", "--if-exists", *server_args, ephemeral_name],
                check=False,
                env=child_env,
                **silenced,
            )

            logger.info(f"Creating ephemeral db {ephemeral_name} from {base_template}...")
            subprocess.run(
                ["createdb", *server_args, ephemeral_name, "--template", base_template],
                check=True,
                env=child_env,
                **silenced,
            )

            copies.append(ephemeral_name)

        logger.info(f"For base_db={base_db}, ephemeral db list = {copies}")

    return ephemeral_db_pool

def drop_ephemeral_dbs(ephemeral_db_pool_dict, pg_password, logger):
    """
    Drop every database listed in the values of ``ephemeral_db_pool_dict``
    (the mapping produced when the ephemeral copies were created).

    A dropdb that exits with an error is logged and skipped so the remaining
    databases are still processed. ``pg_password`` is passed to dropdb
    through the PGPASSWORD environment variable.
    """
    pg_host = "bird_critic_postgresql"
    pg_port = 5432
    pg_user = "root"
    server_args = ["-h", pg_host, "-p", str(pg_port), "-U", pg_user]

    child_env = os.environ.copy()
    child_env["PGPASSWORD"] = pg_password

    logger.info("=== Cleaning up ephemeral databases ===")
    for ephemeral_list in ephemeral_db_pool_dict.values():
        for ephemeral_db in ephemeral_list:
            logger.info(f"Dropping ephemeral db: {ephemeral_db}")
            try:
                subprocess.run(
                    ["dropdb", "--if-exists", *server_args, ephemeral_db],
                    check=True,
                    env=child_env,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                )
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to drop ephemeral db {ephemeral_db}: {e}")


def execute_queries(queries, db_name, conn, logger=None, section_title=""):
    """
    Run ``queries`` one after another on the SAME connection (conn).

    Execution stops at the first query that raises. A cancelled query
    (psycopg2 QueryCanceled, e.g. statement timeout) sets the timeout flag;
    any other exception sets the execution-error flag.

    Returns (query_result, execution_error_flag, timeout_flag), where
    query_result is the result of the last query that completed
    successfully (None if none did). A PrintLogger is used when no logger
    is supplied.
    """
    logger = PrintLogger() if logger is None else logger

    log_section_header(section_title, logger)
    query_result = None
    execution_error = timeout_error = False

    for position, query in enumerate(queries, start=1):
        try:
            logger.info(f"Executing query {position}/{len(queries)}: {query}")
            query_result, conn = perform_query_on_postgresql_databases(query, db_name, conn=conn)
            logger.info(f"Query result: {query_result}")

        except Exception as e:
            # Classify from most to least specific; QueryCanceled must be
            # checked before OperationalError.
            if isinstance(e, psycopg2.errors.QueryCanceled):
                logger.error(f"Timeout error executing query {position}: {e}")
                timeout_error = True
            elif isinstance(e, OperationalError):
                # e.g. server not available
                logger.error(f"OperationalError executing query {position}: {e}")
                execution_error = True
            elif isinstance(e, psycopg2.Error):
                # e.g. syntax errors, constraint violations
                logger.error(f"psycopg2 Error executing query {position}: {e}")
                execution_error = True
            else:
                logger.error(f"Generic error executing query {position}: {e}")
                execution_error = True
            # First failure ends the run; remaining queries are skipped.
            break

        finally:
            logger.info(f"[{section_title}] DB: {db_name}, conn info: {conn}")

    log_section_footer(logger)
    return query_result, execution_error, timeout_error