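"""Run the RareBench rare-disease diagnosis task against a BenchFlow agent and
score the saved per-patient predictions."""
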
import argparse
import json
import os
import re
from typing import Any, Dict

import numpy as np
from benchflow import BenchClient

from llm_utils.api import Openai_api_handler
from prompt import RarePrompt
from utils.evaluation import diagnosis_evaluate
from utils.mydataset import RareDataset

np.random.seed(42)

class RareBenchClient(BenchClient):
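    """Thin BenchFlow client for RareBench: inputs are forwarded to the agent
    unchanged, and the agent's raw reply is wrapped in the result schema used
    downstream."""
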
    def __init__(self, intelligence_url: str, max_retry: int = 1):
        super().__init__(intelligence_url, max_retry)

    def prepare_input(self, raw_input_data: Dict[str, Any]) -> Dict[str, Any]:
        return raw_input_data

    def parse_response(self, raw_response: str) -> Dict[str, Any]:
        result = {
            'system_prompt': "",
            'question': "",
            'model': "user_model",
            'seed': 42,
            'usage': {
                'input_tokens': 0,
                'output_tokens': 0,
            },
            'answer': raw_response,
        }
        return result
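
# Minimal usage sketch, assuming a BenchFlow-compatible agent is serving at the
# URL (run_task below reads the "raw_response" key from the reply):
#   client = RareBenchClient("http://localhost:8000")
#   reply = client.get_response({"system_prompt": "...", "prompt": "..."})
#   print(reply["raw_response"])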

def diagnosis_metric_calculate(folder, judge_model="chatgpt"):
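    """Score the saved per-patient predictions in `folder`: the judge model ranks
    the golden diagnosis within each predicted list, then recall@{1,3,10} and the
    median rank are computed (rank 11 means the diagnosis was not found)."""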
    handler = Openai_api_handler(judge_model)
    
    CNT = 0
    metric = {}
    recall_top_k = []

    # Specialty index ranges for per-specialty breakdowns (not used further in this script).
    Pediatrics = range(0, 15)
    Cardiology = range(15, 30)
    Neurology = range(30, 45)
    Nephrology = range(45, 60)
    Hematology = range(60, 75)

    for file in os.listdir(folder):
        file_path = os.path.join(folder, file)
        with open(file_path, "r", encoding="utf-8-sig") as f:
            res = json.load(f)
        predict_rank = res["predict_rank"]
        if res['predict_diagnosis'] is None:
            print(file_path, "predict_diagnosis is None")
        
        if predict_rank is None:
            predict_rank = diagnosis_evaluate(res["predict_diagnosis"], res["golden_diagnosis"], handler)
            res["predict_rank"] = predict_rank
            with open(file_path, "w", encoding="utf-8-sig") as f:
                json.dump(res, f, indent=4, ensure_ascii=False)
        
        if predict_rank not in ["", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "No"]:
            print(file_path)
            CNT += 1

        if "" in predict_rank or "No" in predict_rank:
            recall_top_k.append(11)
        else:
            pattern = r'\b(?:10|[1-9])\b'
            found = re.findall(pattern, predict_rank)
            if not found:  # the regex already restricts matches to 1-10
                # Persist the reset so this file is re-evaluated on the next run.
                res["predict_rank"] = None
                with open(file_path, "w", encoding="utf-8-sig") as f:
                    json.dump(res, f, indent=4, ensure_ascii=False)
                raise Exception(f"predict_rank error: {file_path}")
            predict_rank = found[0]
            recall_top_k.append(int(predict_rank))
    
    metric['recall_top_1'] = len([i for i in recall_top_k if i <= 1]) / len(recall_top_k)
    metric['recall_top_3'] = len([i for i in recall_top_k if i <= 3]) / len(recall_top_k)
    metric['recall_top_10'] = len([i for i in recall_top_k if i <= 10]) / len(recall_top_k)
    metric['median_rank'] = np.median(recall_top_k)
    
    print("predict_rank error: ", CNT)
    print("evaluate tokens: ", handler.gpt4_tokens, handler.chatgpt_tokens, handler.chatgpt_instruct_tokens)
    
    # Package the results into a dictionary
    result = {
        "folder": folder,
        "metric": metric,
        "predict_rank_error": CNT,
    }
    
    # Save the aggregate results to ./results/result.json
    os.makedirs("./results", exist_ok=True)
    result_file = os.path.join("./results", "result.json")
    with open(result_file, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
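
# Example: diagnosis_metric_calculate("./results/PHENOTYPE/LIRICAL/diagnosis")
# scores the default folder layout produced by main() below.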
        
def generate_random_few_shot_id(exclude_id, total_num, k_shot=3):
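    """Sample k_shot distinct patient ids uniformly from [0, total_num), skipping any id in exclude_id."""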
    few_shot_id = []
    while len(few_shot_id) < k_shot:
        candidate = np.random.randint(0, total_num)  # avoid shadowing the builtin `id`
        if candidate not in few_shot_id and candidate not in exclude_id:
            few_shot_id.append(candidate)
    return few_shot_id

def generate_dynamic_few_shot_id(methods, exclude_id, dataset, k_shot=3):
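    """Return the k_shot patients most similar to patient `exclude_id`, ranked by
    dot-product similarity of IC-weighted mean phenotype embeddings ("medprompt"
    uses its own embeddings with uniform weights)."""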
    few_shot_id = []

    patient = dataset.load_hpo_code_data()
    if methods == "dynamic":
        with open("mapping/phe2embedding.json", "r", encoding="utf-8-sig") as f:
            phe2embedding = json.load(f)
    elif methods == "medprompt":
        with open("mapping/medprompt_emb.json", "r", encoding="utf-8-sig") as f:
            phe2embedding = json.load(f)
    else:
        raise ValueError(f"unknown few-shot method: {methods}")
    with open("mapping/ic_dict.json", "r", encoding="utf-8-sig") as f:
        ic_dict = json.load(f)
    if methods == "medprompt":
        # medprompt weights all phenotypes equally.
        ic_dict = {k: 1 for k in ic_dict}
   
    exclude_patient = patient[exclude_id]
    # IC-weighted mean embedding of the query patient's phenotypes.
    exclude_patient_embedding = np.array([np.array(phe2embedding[phe]) for phe in exclude_patient[0] if phe in phe2embedding])
    exclude_patient_ic = np.array([ic_dict[phe] for phe in exclude_patient[0] if phe in phe2embedding])
    exclude_patient_embedding = np.sum(exclude_patient_embedding * exclude_patient_ic.reshape(-1, 1), axis=0) / np.sum(exclude_patient_ic)
    candidate_embedding_list = []
    for p in patient:
        phe_embedding = np.array([np.array(phe2embedding[phe]) for phe in p[0] if phe in phe2embedding])
        ic_coefficient_list = np.array([ic_dict[phe] for phe in p[0] if phe in phe2embedding])
        phe_embedding = np.sum(phe_embedding * ic_coefficient_list.reshape(-1, 1), axis=0) / np.sum(ic_coefficient_list)
        candidate_embedding_list.append(phe_embedding)
    candidate_embedding_list = np.array(candidate_embedding_list)
    # Dot-product similarity with every candidate (equal to cosine similarity
    # when the embeddings are L2-normalized), ranked most-similar first.
    similarity = np.dot(candidate_embedding_list, exclude_patient_embedding)
    ranked_ids = np.argsort(similarity)[::-1]
    for i in ranked_ids:
        if i not in few_shot_id and i != exclude_id:
            few_shot_id.append(i)
        if len(few_shot_id) == k_shot:
            break
    
    return few_shot_id


def run_task(task_type, dataset: RareDataset, intelligence_url, results_folder, few_shot, cot, judge_model, do_eval=False):
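    """Build (optionally few-shot) diagnosis prompts for every patient, query the
    agent through BenchFlow, write one JSON result per patient into
    results_folder, and score the folder when do_eval is set."""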
    few_shot_dict = {}
    rare_prompt = RarePrompt()
    client = RareBenchClient(intelligence_url)
    if task_type == "diagnosis":
        patient_info_type = dataset.dataset_type
        os.makedirs(results_folder, exist_ok=True)
        print("Begin diagnosis.....")
        print("total patient: ", len(dataset.patient))
        ERR_CNT = 0
        questions = []
        for i, patient in enumerate(dataset.patient):
            result_file = os.path.join(results_folder, f"patient_{i}.json")
            if os.path.exists(result_file):
                continue
            patient_info = patient[0]
            golden_diagnosis = patient[1]
            few_shot_info = []
            if few_shot == "random":
                few_shot_id = generate_random_few_shot_id([i], len(dataset.patient))
                
                few_shot_dict[i] = few_shot_id
                for id in few_shot_id:
                    few_shot_info.append((dataset.patient[id][0], dataset.patient[id][1]))
            elif few_shot == "dynamic" or few_shot == "medprompt":
                few_shot_id = generate_dynamic_few_shot_id(few_shot, i, dataset)
                
                few_shot_dict[str(i)] = [str(idx) for idx in few_shot_id]
                for id in few_shot_id:
                    few_shot_info.append((dataset.patient[id][0], dataset.patient[id][1]))

            system_prompt, prompt = rare_prompt.diagnosis_prompt(patient_info_type, patient_info, cot, few_shot_info)
            
            questions.append(system_prompt + prompt)
            # if few_shot == "auto-cot":
            #     autocot_example = json.load(open("mapping/autocot_example.json", "r", encoding="utf-8-sig"))
            #     system_prompt = "Here a some examples: " + autocot_example[handler.model_name] + system_prompt
            #     prompt = prompt + "Let us think step by step.\n"
            

            # predict_diagnosis = handler.get_completion(system_prompt, prompt)
            agent_input = {"system_prompt": system_prompt, "prompt": prompt}  # avoid shadowing the builtin `input`
            predict_diagnosis = client.get_response(agent_input)
            if predict_diagnosis is None:
                print(f"patient {i} predict diagnosis is None")
                ERR_CNT += 1
                continue
            
            predict_rank = None
            res = {
                "patient_info": patient_info,
                "golden_diagnosis": golden_diagnosis,
                "predict_diagnosis": predict_diagnosis['raw_response'],
                "predict_rank": predict_rank
            }
            with open(result_file, "w", encoding="utf-8-sig") as f:
                json.dump(res, f, indent=4, ensure_ascii=False)
            print(f"patient {i} finished")
            
        if do_eval:
            diagnosis_metric_calculate(results_folder, judge_model=judge_model)
        print("diagnosis ERR_CNT: ", ERR_CNT)
    elif task_type == "mdt":
        pass  # the MDT task is not implemented in this script


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--intelligence_url', type=str, default="http://localhost:8000")
    parser.add_argument('--task_type', type=str, default="diagnosis", choices=["diagnosis", "mdt"])
    parser.add_argument('--dataset_name', type=str, default="LIRICAL", choices=["RAMEDIS", "MME", "HMS", "LIRICAL", "PUMCH_ADM"])
    parser.add_argument('--dataset_path', default=None)
    parser.add_argument('--dataset_type', type=str, default="PHENOTYPE", choices=["EHR", "PHENOTYPE", "MDT"])
    parser.add_argument('--results_folder', default='./results')
    parser.add_argument('--judge_model', type=str, default="chatgpt", choices=["gpt4", "chatgpt"])
    parser.add_argument('--few_shot', type=str, default="none", choices=["none", "random", "dynamic", "medprompt"])
    parser.add_argument('--cot', type=str, default="none", choices=["none", "zero-shot"])
    parser.add_argument('--eval', action='store_true')

    args = parser.parse_args()

    dataset = RareDataset(args.dataset_name, args.dataset_path, args.dataset_type)
    
    # Map the CLI options onto the results-folder suffixes.
    few_shot_suffix = {
        "none": "",
        "random": "_few_shot",
        "dynamic": "_dynamic_few_shot",
        "medprompt": "_medprompt",
        "auto-cot": "_auto-cot",  # unreachable unless "auto-cot" is added to the --few_shot choices
    }
    few_shot = few_shot_suffix[args.few_shot]
    cot = "_cot" if args.cot == "zero-shot" else ""
    results_folder = os.path.join(args.results_folder, args.dataset_type, args.dataset_name, args.task_type+few_shot+cot)
    run_task(args.task_type, dataset, args.intelligence_url, results_folder, args.few_shot, args.cot, args.judge_model, args.eval)

if __name__ == "__main__":
    main()