import argparse
import json
import os
import re
from typing import Dict, Any

import numpy as np

from benchflow import BenchClient
from llm_utils.api import Openai_api_handler
from prompt import RarePrompt
from utils.evaluation import diagnosis_evaluate
from utils.mydataset import RareDataset

# Fix the RNG so random few-shot sampling is reproducible across runs.
np.random.seed(42)
class RareBenchClient(BenchClient):
    """BenchFlow client that forwards RareBench prompts to the agent under test."""
    def __init__(self, intelligence_url: str, max_retry: int = 1):
        super().__init__(intelligence_url, max_retry)

    def prepare_input(self, raw_input_data: Dict[str, Any]) -> Dict[str, Any]:
        # The prompt dict is already in the shape the agent expects; pass it through.
        return raw_input_data

    def parse_response(self, raw_response: str) -> Dict[str, Any]:
        # Wrap the raw agent reply in the record format the rest of the pipeline
        # expects; model metadata and token usage are placeholders here.
        return {
            'system_prompt': "",
            'question': "",
            'model': "user_model",
            'seed': 42,
            'usage': {
                'input_tokens': 0,
                'output_tokens': 0,
            },
            'answer': raw_response,
        }
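
# A minimal usage sketch for RareBenchClient (the URL is a placeholder; the
# get_response call is the same one run_task uses below). Defined only, never
# executed on import.
def _example_client_roundtrip():
    client = RareBenchClient("http://localhost:8000")
    payload = {"system_prompt": "You are a rare-disease diagnosis assistant.",
               "prompt": "Patient phenotypes: ... Give the 10 most likely diagnoses."}
    return client.get_response(payload)  # dict whose 'raw_response' holds the raw answer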
def diagnosis_metric_calculate(folder, judge_model="chatgpt"):
    handler = Openai_api_handler(judge_model)
    CNT = 0  # files whose predict_rank was not a clean "1"-"10" / "No" / "否"
    metric = {}
    recall_top_k = []
    # Specialty index ranges, kept for reference; they are not used below.
    Pediatrics = range(0, 15)
    Cardiology = range(15, 30)
    Neurology = range(30, 45)
    Nephrology = range(45, 60)
    Hematology = range(60, 75)
    for file in os.listdir(folder):
        file_path = os.path.join(folder, file)
        with open(file_path, "r", encoding="utf-8-sig") as f:
            res = json.load(f)
        predict_rank = res["predict_rank"]
        if res['predict_diagnosis'] is None:
            print(file_path, "predict_diagnosis is None")
        if predict_rank is None:
            # Ask the judge model where the golden diagnosis ranks in the prediction.
            predict_rank = diagnosis_evaluate(res["predict_diagnosis"], res["golden_diagnosis"], handler)
            res["predict_rank"] = predict_rank
            with open(file_path, "w", encoding="utf-8-sig") as f:
                json.dump(res, f, indent=4, ensure_ascii=False)
        if predict_rank not in ["否", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "No"]:
            print(file_path)
            CNT += 1
        if "否" in predict_rank or "No" in predict_rank:
            # "否" / "No": the golden diagnosis is not in the top 10; record rank 11.
            recall_top_k.append(11)
        else:
            # Recover the first standalone number in 1-10 from a noisy judge reply.
            pattern = r'\b(?:10|[1-9])\b'
            found = re.findall(pattern, predict_rank)
            if not found:
                res["predict_rank"] = None
                raise Exception("predict_rank error")
            predict_rank = found[0]
            recall_top_k.append(int(predict_rank))
metric['recall_top_1'] = len([i for i in recall_top_k if i <= 1]) / len(recall_top_k)
metric['recall_top_3'] = len([i for i in recall_top_k if i <= 3]) / len(recall_top_k)
metric['recall_top_10'] = len([i for i in recall_top_k if i <= 10]) / len(recall_top_k)
    metric['median_rank'] = np.median(recall_top_k)
print("predict_rank error: ", CNT)
print("evaluate tokens: ", handler.gpt4_tokens, handler.chatgpt_tokens, handler.chatgpt_instruct_tokens)
# Package the results into a dictionary
result = {
"folder": folder,
"metric": metric,
"predict_rank_error": CNT,
}
    # Save the aggregated results to ./results/result.json.
    os.makedirs("./results", exist_ok=True)
    result_file = os.path.join("./results", "result.json")
    with open(result_file, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
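
# A tiny sanity check for the recall@k bookkeeping above: a rank of 11 encodes
# "golden diagnosis not in the top 10", so it lowers every recall metric and
# raises the median. Toy ranks only.
def _demo_recall_at_k():
    ranks = [1, 2, 11, 4]
    recall_top_1 = len([r for r in ranks if r <= 1]) / len(ranks)    # 0.25
    recall_top_3 = len([r for r in ranks if r <= 3]) / len(ranks)    # 0.5
    recall_top_10 = len([r for r in ranks if r <= 10]) / len(ranks)  # 0.75
    return recall_top_1, recall_top_3, recall_top_10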
def generate_random_few_shot_id(exclude_id, total_num, k_shot=3):
    # Sample k_shot distinct patient ids uniformly at random, skipping exclude_id.
    few_shot_id = []
    while len(few_shot_id) < k_shot:
        candidate = np.random.randint(0, total_num)
        if candidate not in few_shot_id and candidate not in exclude_id:
            few_shot_id.append(candidate)
    return few_shot_id
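
# An equivalent, vectorised sketch using np.random.choice; shown for comparison
# only, since it materialises the whole candidate pool up front.
def _random_few_shot_alternative(exclude_id, total_num, k_shot=3):
    pool = np.setdiff1d(np.arange(total_num), np.asarray(exclude_id))
    return np.random.choice(pool, size=k_shot, replace=False).tolist()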
def generate_dynamic_few_shot_id(methods, exclude_id, dataset, k_shot=3):
    # Pick the k_shot patients whose IC-weighted phenotype embeddings are most
    # similar to the query patient's embedding.
    few_shot_id = []
    patient = dataset.load_hpo_code_data()
    if methods == "dynamic":
        phe2embedding = json.load(open("mapping/phe2embedding.json", "r", encoding="utf-8-sig"))
    elif methods == "medprompt":
        phe2embedding = json.load(open("mapping/medprompt_emb.json", "r", encoding="utf-8-sig"))
    ic_dict = json.load(open("mapping/ic_dict.json", "r", encoding="utf-8-sig"))
    if methods == "medprompt":
        # medprompt uses unweighted averaging, so flatten every IC weight to 1.
        ic_dict = {k: 1 for k in ic_dict}
    exclude_patient = patient[exclude_id]
    exclude_patient_embedding = np.array([np.array(phe2embedding[phe]) for phe in exclude_patient[0] if phe in phe2embedding])
    exclude_patient_ic = np.array([ic_dict[phe] for phe in exclude_patient[0] if phe in phe2embedding])
    # IC-weighted average of the query patient's phenotype embeddings.
    exclude_patient_embedding = np.sum(exclude_patient_embedding * exclude_patient_ic.reshape(-1, 1), axis=0) / np.sum(exclude_patient_ic)
    candidate_embedding_list = []
    for p in patient:
        phe_embedding = np.array([np.array(phe2embedding[phe]) for phe in p[0] if phe in phe2embedding])
        ic_coefficient_list = np.array([ic_dict[phe] for phe in p[0] if phe in phe2embedding])
        phe_embedding = np.sum(phe_embedding * ic_coefficient_list.reshape(-1, 1), axis=0) / np.sum(ic_coefficient_list)
        candidate_embedding_list.append(phe_embedding)
    candidate_embedding_list = np.array(candidate_embedding_list)
    # Dot-product similarity (cosine similarity when the embeddings are unit-normalised).
    similarity = np.dot(candidate_embedding_list, exclude_patient_embedding)
    ranked_ids = np.argsort(similarity)[::-1]
    for i in ranked_ids:
        if i not in few_shot_id and i != exclude_id:
            few_shot_id.append(i)
        if len(few_shot_id) == k_shot:
            break
    return few_shot_id
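
# A worked toy example of the IC-weighted pooling used above: each phenotype
# vector is scaled by its information content, summed, and normalised by the
# total IC, so rarer (higher-IC) phenotypes dominate the patient embedding.
# Toy numbers only; real embeddings come from mapping/phe2embedding.json.
def _demo_ic_weighted_pooling():
    embeddings = np.array([[1.0, 0.0], [0.0, 1.0]])  # two phenotype embeddings
    ic = np.array([3.0, 1.0])                        # the first phenotype is rarer
    pooled = np.sum(embeddings * ic.reshape(-1, 1), axis=0) / np.sum(ic)
    return pooled  # array([0.75, 0.25]): weighted towards the high-IC phenotype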
def run_task(task_type, dataset: RareDataset, intelligence_url, results_folder, few_shot, cot, judge_model, run_eval=False):
few_shot_dict = {}
rare_prompt = RarePrompt()
client = RareBenchClient(intelligence_url)
if task_type == "diagnosis":
patient_info_type = dataset.dataset_type
os.makedirs(results_folder, exist_ok=True)
print("Begin diagnosis.....")
print("total patient: ", len(dataset.patient))
ERR_CNT = 0
questions = []
for i, patient in enumerate(dataset.patient):
result_file = os.path.join(results_folder, f"patient_{i}.json")
if os.path.exists(result_file):
continue
patient_info = patient[0]
golden_diagnosis = patient[1]
few_shot_info = []
if few_shot == "random":
few_shot_id = generate_random_few_shot_id([i], len(dataset.patient))
few_shot_dict[i] = few_shot_id
for id in few_shot_id:
few_shot_info.append((dataset.patient[id][0], dataset.patient[id][1]))
elif few_shot == "dynamic" or few_shot == "medprompt":
few_shot_id = generate_dynamic_few_shot_id(few_shot, i, dataset)
few_shot_dict[str(i)] = [str(idx) for idx in few_shot_id]
for id in few_shot_id:
few_shot_info.append((dataset.patient[id][0], dataset.patient[id][1]))
system_prompt, prompt = rare_prompt.diagnosis_prompt(patient_info_type, patient_info, cot, few_shot_info)
questions.append(system_prompt + prompt)
# if few_shot == "auto-cot":
# autocot_example = json.load(open("mapping/autocot_example.json", "r", encoding="utf-8-sig"))
# system_prompt = "Here a some examples: " + autocot_example[handler.model_name] + system_prompt
# prompt = prompt + "Let us think step by step.\n"
# predict_diagnosis = handler.get_completion(system_prompt, prompt)
input = {"system_prompt": system_prompt, "prompt": prompt}
predict_diagnosis = client.get_response(input)
if predict_diagnosis is None:
print(f"patient {i} predict diagnosis is None")
ERR_CNT += 1
continue
predict_rank = None
res = {
"patient_info": patient_info,
"golden_diagnosis": golden_diagnosis,
"predict_diagnosis": predict_diagnosis['raw_response'],
"predict_rank": predict_rank
}
            with open(result_file, "w", encoding="utf-8-sig") as f:
                json.dump(res, f, indent=4, ensure_ascii=False)
print(f"patient {i} finished")
        if run_eval:
            diagnosis_metric_calculate(results_folder, judge_model=judge_model)
print("diagnosis ERR_CNT: ", ERR_CNT)
elif task_type == "mdt":
pass
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--intelligence_url', type=str, default="http://localhost:8000")
parser.add_argument('--task_type', type=str, default="diagnosis", choices=["diagnosis", "mdt"])
parser.add_argument('--dataset_name', type=str, default="LIRICAL", choices=["RAMEDIS", "MME", "HMS", "LIRICAL", "PUMCH_ADM"])
parser.add_argument('--dataset_path', default=None)
parser.add_argument('--dataset_type', type=str, default="PHENOTYPE", choices=["EHR", "PHENOTYPE", "MDT"])
parser.add_argument('--results_folder', default='./results')
parser.add_argument('--judge_model', type=str, default="chatgpt", choices=["gpt4", "chatgpt"])
parser.add_argument('--few_shot', type=str, default="none", choices=["none", "random", "dynamic", "medprompt"])
parser.add_argument('--cot', type=str, default="none", choices=["none", "zero-shot"])
parser.add_argument('--eval', action='store_true')
args = parser.parse_args()
dataset = RareDataset(args.dataset_name, args.dataset_path, args.dataset_type)
if args.few_shot == "none":
few_shot = ""
elif args.few_shot == "random":
few_shot = "_few_shot"
elif args.few_shot == "dynamic":
few_shot = "_dynamic_few_shot"
elif args.few_shot == "medprompt":
few_shot = "_medprompt"
elif args.few_shot == "auto-cot":
few_shot = "_auto-cot"
if args.cot == "none":
cot = ""
elif args.cot == "zero-shot":
cot = "_cot"
results_folder = os.path.join(args.results_folder, args.dataset_type, args.dataset_name, args.task_type+few_shot+cot)
run_task(args.task_type, dataset, args.intelligence_url, results_folder, args.few_shot, args.cot, args.judge_model, args.eval)
if __name__ == "__main__":
main()
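
# Example invocation (script name, URL, and paths are placeholders for your setup):
#   python rarebench_benchflow.py --intelligence_url http://localhost:8000 \
#       --dataset_name LIRICAL --dataset_type PHENOTYPE --few_shot dynamic --cot zero-shot --eval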