import os import re import sys import json import dotenv import logging import argparse from enum import Enum from pathlib import Path from datetime import datetime from typing import Callable, Union, Tuple from langchain.prompts import PromptTemplate from langchain_openai import ChatOpenAI from langchain_core.output_parsers import StrOutputParser from benchflow import BenchClient from typing import Dict, Any import utils class MedQACSClient(BenchClient): def __init__(self, intelligence_url: str): print(f"Initializing MedQACSClient with intelligence_url: {intelligence_url}") super().__init__(intelligence_url) def prepare_input(self, raw_step_inputs: Dict[str, Any]) -> Dict[str, Any]: return {"user_prompt": raw_step_inputs["user_prompt"], "prompt_template": raw_step_inputs["prompt_template"], "filling_data": raw_step_inputs["input_data"]} def parse_response(self, raw_step_outputs: str) -> Dict[str, Any]: return {"output": raw_step_outputs} class Section(Enum): qa = "qa" physical_exam = "physical_exam" closure = "closure" diagnosis = "diagnosis" other = "other" def load_data(dataset_path: str, is_examiner: bool, section: str, case: str): dataset_path = Path(dataset_path).resolve() if dataset_path.is_dir(): # Locate the appropriate dataset file using default naming if not is_examiner: dataset_path = dataset_path / "med-student.json" else: dataset_path = dataset_path / "med-exam.json" if is_examiner: logging.info(f"Loading examiner dataset from {dataset_path}") else: logging.info(f"Loading medical student dataset from {dataset_path}") with dataset_path.open("r") as f: dataset = json.load(f) # filter dataset by section dataset = [data for data in dataset if data["section"] == section] # filter dataset by case dataset = [data for data in dataset if data["case_id"] == int(case)] return dataset def save_result(path: str, dataset: dict, is_examiner: bool, section: str): output_path = Path(path).resolve() if output_path.suffix != "": output_path.parent.mkdir(parents=True, exist_ok=True) output_file = output_path else: output_path.mkdir(parents=True, exist_ok=True) file_prefix = "med-exam" if is_examiner else "med-student" output_file = output_path / f"{file_prefix}.json" dataset = [data for data in dataset if data["section"] == section] logging.debug(f"Saving dataset to {output_file}") with output_file.open("w", encoding="UTF-8") as f: json.dump(dataset, f, indent=2) return str(output_file) def parse_range(val: str) -> Union[Tuple[int, int], None]: """ Parse a string input into a range of numbers or a single number. Args: val (str): The input string to parse. Returns: Union[Tuple[int, int], None]: A tuple representing the range (start, end), or (number, number) for a single number, or None if the input is invalid. Examples: parse_range("1-44") -> (1, 44) parse_range("5") -> (5, 5) parse_range("") -> None parse_range("invalid") -> None """ # Remove leading and trailing whitespace val = str(val).strip() if not val: logging.info("Empty range!") return None # Match a range like "1-44" range_match = re.fullmatch(r"(\d+)\s*-\s*(\d+)", val) if range_match: start, end = map(int, range_match.groups()) return start, end # Match a single number like "1" single_num_match = re.fullmatch(r"(\d+)", val) if single_num_match: num = int(single_num_match.group(1)) return num, num logging.info("Invalid range!") return None def call_api( prompt_template: str, input_data: dict[str, str], pre_processing_func: Callable = lambda x: x, post_processing_func: Callable = lambda x: x["output"], intelligence_url: str = None, **kwargs, ) -> str: """ Run a language model with the given prompt template and input data. Args: model: The language model to use. prompt_template (str): The template for generating the prompt. input_data (dict[str, str]): The input data to be used in the prompt. pre_processing (Callable, optional): Function to pre-process the input data. Defaults to identity function. post_processing (Callable, optional): Function to post-process the model output. Defaults to identity function. **kwargs: Additional keyword arguments. Returns: str: The processed result from the model. Process: 1. Create a PromptTemplate from the given template. 2. Set up a StrOutputParser for parsing the model's output. 3. Create an evaluation chain: prompt -> model -> parser. 4. Pre-process the input data. 5. Run the evaluation chain with the pre-processed input. 6. Post-process the result. 7. Return the final processed result. """ # Create prompt template and parser prompt = PromptTemplate.from_template(prompt_template) # Preprocess input data pre_processed_input = pre_processing_func(input_data) filled_prompt = prompt.format(**pre_processed_input) client = MedQACSClient(intelligence_url) env = {"user_prompt": filled_prompt, "prompt_template": prompt_template, "input_data": input_data} raw_result = client.get_response(env)["output"] logging.debug(raw_result) # Post-process the result result = post_processing_func( { "prompt": prompt_template, "input": input_data, "output": raw_result, } ) return result def llm_as_medical_student( section: str, case: str, conversation_turn: str = "all", # only used for QA, other sections only has 1 conversation turn med_student_dataset_path: str = "data/med-student.json", output_path: str = "output/", # the path for the output file or a path to a folder that store the output file model=None, # one of the langchain model class, will override model_parameters model_parameters: dict = None, # only used if model is not None prompt_template: dict[ int, str ] = None, # Custom prompt templates for each case number input_data: dict[ int, dict[str, str] ] = None, # Custom input data for each case number pre_processing: Callable = None, post_processing: Callable = None, intelligence_url: str = None, **kwargs, ): """ Simulates a medical student using a language model to generate responses for different sections of a medical examination. Args: section (str): The section of the medical examination (e.g., 'qa', 'physical_exam', 'closure', 'diagnosis'). case (str): The case number or range to process. conversation_turn (str): The specific turn in the conversation or 'all' for all turns. med_student_dataset_path (str): Path to the dataset file or directory. output_path (str): Path to save the output file or directory. model: The language model to use (default is None, which will use a default model). model_parameters (dict): Parameters for the language model. prompt_template (dict): Custom prompt templates for each case. Example: { 1: "Prompt for case 1", 2: "Prompt for case 2" } input_data (dict): Custom input data for each case. Example: { 1: { "input_var1": "value1", "input_var2": "value2" }, 2: { "input_var1": "value3", "input_var2": "value4" } } pre_processing (Callable): Function to preprocess input data. post_processing (Callable): Function to post-process model output. **kwargs: Additional keyword arguments. Returns: None: Results are saved to the specified output path. """ logging.info(f"Running llm as medical student on {section}:") # Load the dataset dataset = load_data(med_student_dataset_path, is_examiner=False, section=section, case=case) # Parse case range start_case, end_case = (1, 44) if str(case) == "all" else parse_range(case) # Check if parse_range returned None if start_case is None or end_case is None: logging.error("Invalid case range provided. Exiting.") sys.exit(1) # Determine whether to use dataset prompt template or custom prompt use_dataset_prompt_template = prompt_template is None # Determine whether to use dataset input data or custom input data use_dataset_input_data = input_data is None # Set default pre-processing and post-processing functions if not provided if pre_processing is None: pre_processing = lambda x: x # Define post-processing functions for different sections post_processing_func = { Section.qa.value: utils.medical_student_qa_post_processing, Section.closure.value: utils.output_only_post_processing, Section.physical_exam.value: utils.medical_student_physical_exam_post_processing, Section.diagnosis.value: utils.medical_student_diagnosis_post_processing, } if post_processing is None: post_processing = post_processing_func[section] for data in dataset: if ( data["section"] == section and start_case <= int(data["case_id"]) <= end_case ): # Handle specific conversation turns for QA section if section == Section.qa.value and str(conversation_turn) != "all": start_conversation_turn, end_conversation_turn = parse_range( conversation_turn ) if not ( start_conversation_turn <= int(data["conversation_turn_id"]) <= end_conversation_turn ): continue logging.info( f'Running {data["section"]} case {data["case_id"]}, turn {data["conversation_turn_id"]}' ) if use_dataset_prompt_template: prompt = data["prompt"]["template"] else: prompt = prompt_template[int(data["case_id"])] if use_dataset_input_data: input_data_dict = data["input"] else: input_data_dict = input_data[int(data["case_id"])] # Run the model result = call_api( prompt_template=prompt, input_data=input_data_dict, pre_processing_func=pre_processing, post_processing_func=post_processing, intelligence_url=intelligence_url, **kwargs, ) # parse result logging.debug(result) # save result data["output"]["benchflow"] = result # save updated dataset output_file_path = save_result(output_path, dataset, is_examiner=False, section=section) logging.info(f"Finished. Output saved to: {output_file_path}") return output_file_path def run_model( model, prompt_template: str, input_data: dict[str, str], pre_processing_func: Callable = lambda x: x, post_processing_func: Callable = lambda x: x["output"], **kwargs, ) -> str: """ Run a language model with the given prompt template and input data. Args: model: The language model to use. prompt_template (str): The template for generating the prompt. input_data (dict[str, str]): The input data to be used in the prompt. pre_processing (Callable, optional): Function to pre-process the input data. Defaults to identity function. post_processing (Callable, optional): Function to post-process the model output. Defaults to identity function. **kwargs: Additional keyword arguments. Returns: str: The processed result from the model. Process: 1. Create a PromptTemplate from the given template. 2. Set up a StrOutputParser for parsing the model's output. 3. Create an evaluation chain: prompt -> model -> parser. 4. Pre-process the input data. 5. Run the evaluation chain with the pre-processed input. 6. Post-process the result. 7. Return the final processed result. """ # Create prompt template and parser prompt = PromptTemplate.from_template(prompt_template) parser = StrOutputParser() # Set up the evaluation chain eval_chain = prompt | model | parser # Preprocess input data pre_processed_input = pre_processing_func(input_data) # Run the model and get the result if use_langfuse: config = {"callbacks": [langfuse_handler]} else: config = {} raw_result = eval_chain.invoke(pre_processed_input, config=config) logging.debug(raw_result) # Post-process the result result = post_processing_func( { "prompt": prompt_template, "input": input_data, "output": raw_result, } ) return result def llm_as_examiner( section: str, # one of the section in Section enum case: str, # case number from 1 to 44 or "all" conversation_turn: str = "all", # only used for QA, other sections only has 1 conversation turn med_student_dataset_path: str = None, med_exam_dataset_path: str = "data/med-exam.json", output_path: str = "output/", model=None, # one of the langchain model class model_parameters=None, # parameters used to initialize the model via langchainChatOpenAI class, only used if model is not None input_student_model_name: str = None, # the model name of the medical student's output that will be used as the examiner's input prompt_template: dict[ int, str ] = None, # Custom prompt templates for each case number input_data: dict[ int, dict[str, str] ] = None, # Custom input data for each case number pre_processing: Callable = None, # function to preprocess input data post_processing: Callable = None, # function to post-process model output **kwargs, ): """ Simulates a medical examiner using a language model to evaluate responses from medical students. Args: section (str): The section of the medical examination (e.g., 'qa', 'physical_exam', 'closure', 'diagnosis'). case (str): The case number or range to process. conversation_turn (str): The specific turn in the conversation or 'all' for all turns. med_student_dataset_path (str): Path to the dataset file or directory. med_exam_dataset_path (str): Path to the dataset file or directory. output_path (str): Path to save the output file or directory. model: The language model to use (default is None, which will use a default model). model_parameters (dict): Parameters for the language model. input_student_model_name (str): The model name of the medical student's output that will be used as the examiner's input. prompt_template (dict): Custom prompt templates for each case. Example: { 1: "Prompt for case 1", 2: "Prompt for case 2" } input_data (dict): Custom input data for each case. Example: { 1: { "input_var1": "value1", "input_var2": "value2" }, 2: { "input_var1": "value3", "input_var2": "value4" } } pre_processing (Callable): Function to preprocess input data. post_processing (Callable): Function to post-process model output. **kwargs: Additional keyword arguments. Returns: None: Results are saved to the specified output path. """ from collections import defaultdict result_summary = defaultdict(int) logging.info(f"Running llm as examiner on {section}:") # Check if input_student_model_name is provided if input_student_model_name is None: logging.error( "Missing student model name specified which input data used to evaluate!" ) sys.exit(1) # Check if med_exam_dataset_path is provided if med_exam_dataset_path is None: logging.error("Need specify med-exam dataset!") sys.exit(1) # Load medical student dataset if provided if med_student_dataset_path: logging.info("Use med-student and med-exam dataset") student_dataset = load_data(med_student_dataset_path, is_examiner=False, section=section, case=case) else: logging.info("Use med-exam dataset only") student_dataset = None # Load medical examiner dataset dataset = load_data(med_exam_dataset_path, is_examiner=True, section=section, case=case) start_case, end_case = (1, 44) if str(case) == "all" else parse_range(case) # Check if parse_range returned None if start_case is None or end_case is None: logging.error("Invalid case range provided. Exiting.") sys.exit(1) # Determine whether to use dataset prompt template or custom prompt use_dataset_prompt_template = prompt_template is None # Determine whether to use dataset input data or custom input data use_dataset_input_data = input_data is None if model is None: if model_parameters is None: logging.info( "Using default model parameters: model_name=gpt-4-1106-preview, temperature=0" ) model = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0) else: model = ChatOpenAI(**model_parameters) # Set default pre-processing and post-processing functions if not provided if pre_processing is None: pre_processing = lambda x: x # Define post-processing functions for different sections post_processing_func = { Section.qa.value: utils.output_only_post_processing, Section.physical_exam.value: utils.output_only_post_processing, Section.closure.value: utils.output_only_post_processing, Section.diagnosis.value: utils.examiner_diagnosis_post_processing, } if post_processing is None: post_processing = post_processing_func[section] # get examiner model name examiner_model_name = model.model_name for index, data in enumerate(dataset): if ( data["section"] == section and start_case <= int(data["case_id"]) <= end_case ): # Handle specific conversation turns for QA section if section == Section.qa.value and str(conversation_turn) != "all": start_conversation_turn, end_conversation_turn = parse_range( conversation_turn ) if not ( start_conversation_turn <= int(data["conversation_turn_id"]) <= end_conversation_turn ): continue logging.info( f'Running {data["section"]} case {data["case_id"]}, turn {data["conversation_turn_id"]}' ) if use_dataset_prompt_template: prompt = data["prompt"]["template"] else: prompt = prompt_template[int(data["case_id"])] if use_dataset_input_data: # input_data_dict = data["input"] if section == Section.qa.value: input_dict_name = "question" else: input_dict_name = "pred" # choose one of the student model's output as examiner input if input_student_model_name not in data["input"][input_dict_name]: # find input data from med-student dataset if student_dataset is None: logging.error( f"Cannot find input data for model {input_student_model_name}!!!" ) logging.error( "Consider including the med-student dataset in the examiner task." ) sys.exit(1) student_data = student_dataset[index] if ( student_data["section"] != section or student_data["case_id"] != data["case_id"] or student_data["conversation_turn_id"] != data["conversation_turn_id"] ): logging.error("Error, student dataset info don't match!!!") # TODO: student dataset full search sys.exit(1) # add med-student output to med-exam dataset as examiner input if input_student_model_name in student_data["output"]: data["input"][input_dict_name][input_student_model_name] = ( student_data["output"][input_student_model_name] ) else: logging.error( f"Cannot find input data for model {input_student_model_name}!!!" ) logging.error( "Consider including the med-student dataset in the examiner task." ) sys.exit(1) # setting examiner model input data input_data_dict = {} for key, value in data["input"].items(): if key == input_dict_name: input_data_dict[key] = value[input_student_model_name] else: input_data_dict[key] = value else: input_data_dict = input_data[int(data["case_id"])] result = run_model( model=model, prompt_template=prompt, input_data=input_data_dict, pre_processing_func=pre_processing, post_processing_func=post_processing, **kwargs, ) match = re.search(r'```json\s*(\{.*?\})\s*```', result, re.DOTALL) if match: result = match.group(1) parsed_dict = json.loads(result) logging.debug(parsed_dict) if section == "qa": result_summary["total_score"] += 1 result_summary["score_obtained"] += int(parsed_dict['score']) elif section == "physical_exam" or section == "closure": result_summary["score_obtained"] += int(parsed_dict['overall score']) elif section == "diagnosis": match = re.match(r'(\d+)\s*/\s*(\d+)', parsed_dict['total score']) if match: score_obtained = int(match.group(1)) score_total = int(match.group(2)) result_summary["total_score"] += score_total result_summary["quality_score"] += parsed_dict['quality score'] result_summary["score_obtained"] += score_obtained # save result if input_student_model_name not in data["output"]: data["output"][input_student_model_name] = {} logging.debug(f"Creating new entry for {input_student_model_name}") data["output"][input_student_model_name][examiner_model_name] = result # save updated dataset output_file_path = save_result(output_path, dataset, is_examiner=True, section=section) summary_path = Path("result/result_summary.json") summary_path.parent.mkdir(parents=True, exist_ok=True) with open(summary_path, "w") as f: json.dump(result_summary, f) logging.info(f"Finished. Metrics saved to: {summary_path}") return output_file_path def main(args): """ Main function to run the medical examination simulation. This function handles the execution of tasks for both the medical student and examiner roles, based on the provided command-line arguments. Args: args (argparse.Namespace): Parsed command-line arguments containing task specifications and other parameters. Returns: None Raises: None, but exits the function early if required arguments are missing. """ if args.task not in ["student", "examiner", "all"]: logging.error("Invalid task! Please choose 'student', 'examiner', or 'all'.") sys.exit(1) if args.task in ["student", "all"]: if not args.student_model: logging.error( "Missing student model name. Please specify which model to use for generating responses." ) sys.exit(1) if not args.med_student_dataset: logging.error( "Missing medical student dataset. Please specify which dataset to use for generation." ) sys.exit(1) if args.task in ["examiner", "all"]: if not args.student_model: logging.error( "Missing student model name. Please specify which input data to use for evaluation." ) sys.exit(1) if not args.examiner_model: logging.error( "Missing examiner model name. Please specify which model to use for evaluation." ) sys.exit(1) if not args.med_exam_dataset: logging.error( "Missing medical examination dataset. Please specify which dataset to use for evaluation." ) sys.exit(1) logging.info(f"Starting task: {args.task}") logging.info(f"Section: {args.section}, Case: {args.case}, Turn: {args.turn}") if args.task == "student": logging.info(f"Model: {args.student_model}") else: logging.info(f"Model: {args.student_model}, {args.examiner_model}") try: if args.task in ["student", "all"]: med_student_output_file_path = llm_as_medical_student( section=args.section, case=args.case, conversation_turn=args.turn, med_student_dataset_path=args.med_student_dataset, output_path=args.output, model_parameters={"model_name": args.student_model, "temperature": 0.9}, intelligence_url=args.intelligence_url, # prompt_template="", # Uncomment and provide a template if needed ) if args.task in ["examiner", "all"]: if args.task == "all": # Use the path of the dataset generated by llm_as_medical_student med_student_dataset_path = med_student_output_file_path else: if args.med_student_dataset == "data/med-student.json": med_student_dataset_path = None else: med_student_dataset_path = args.med_student_dataset llm_as_examiner( section=args.section, case=args.case, conversation_turn=args.turn, med_student_dataset_path=med_student_dataset_path, med_exam_dataset_path=args.med_exam_dataset, output_path=args.output, model_parameters={"model_name": args.examiner_model, "temperature": 0}, input_student_model_name=args.student_model, # prompt_template="", # Uncomment and provide a template if needed ) except Exception as e: logging.error(f"An error occurred: {str(e)}") logging.error("Use verbose mode to see detailed error information.") if args.verbose: logging.exception("Detailed error information:") sys.exit(1) logging.info("Task completed successfully.") def setup_logging(verbose=False): level = logging.DEBUG if verbose else logging.INFO log_format = "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s" # Set up logging to file logging.basicConfig( level=level, format=log_format, filename="MedQA-CS.log", filemode="a" ) # Set up logging to console console = logging.StreamHandler() console.setLevel(level) formatter = logging.Formatter(log_format) console.setFormatter(formatter) logging.getLogger("").addHandler(console) def parse_args(): """ Parse command-line arguments for the medical examination simulation. Returns: argparse.Namespace: An object containing the parsed arguments. """ parser = argparse.ArgumentParser( description="Medical Examination Simulation CLI", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) # Task selection parser.add_argument( "-t", "--task", type=str, required=True, default="all", choices=["student", "examiner", "all"], help="Task to run: student (generate responses), examiner (evaluate responses), or all (both)", ) parser.add_argument( "-s", "--section", type=str, required=True, choices=["qa", "physical_exam", "closure", "diagnosis"], help="Section of the medical examination (qa, physical_exam, closure, diagnosis)", ) parser.add_argument( "-c", "--case", type=str, required=True, help="Case number or range (e.g., '1-44' for cases 1 through 44)", ) parser.add_argument( "--turn", type=str, default="all", help="Specific conversation turn or 'all' for entire conversation", ) parser.add_argument( "-sd", "--med_student_dataset", type=str, default="data/med-student.json", help="Path to the medical student dataset for generation task", ) parser.add_argument( "-ed", "--med_exam_dataset", type=str, default="data/med-exam.json", help="Path to the medical examination dataset for examiner task", ) parser.add_argument( "-o", "--output", type=str, default="output/", help="Path to output file or directory. If a directory is specified, output files will be saved with default names.", ) parser.add_argument( "-sm", "--student_model", type=str, help="Name of the model to use for generating student responses", ) parser.add_argument( "-em", "--examiner_model", type=str, default="gpt-4-1106-preview", help="Name of the model to use for evaluating responses", ) parser.add_argument( "-v", "--verbose", action="store_true", help="Enable verbose output" ) parser.add_argument( "-i", "--intelligence_url", type=str, help="URL of the intelligence service" ) args = parser.parse_args() return args def setup_langfuse(): if ( os.environ.get("LANGFUSE_PUBLIC_KEY") and os.environ.get("LANGFUSE_SECRET_KEY") and os.environ.get("LANGFUSE_HOST") ): from langfuse.callback import CallbackHandler logging.info("Using LangFuse") return True, CallbackHandler() return False, None if __name__ == "__main__": args = parse_args() dotenv.load_dotenv() setup_logging(args.verbose) logging.info(args) use_langfuse, langfuse_handler = setup_langfuse() main(args)