# MMLU-Pro evaluation harness (revision c8dcff7)
import os
import json
import re
import random
from tqdm import tqdm
from typing import Dict, Any
import time
from datasets import load_dataset
import argparse
from benchflow import BenchClient
API_KEY = ""
def load_mmlu_pro():
    """Download MMLU-Pro from the HF Hub and return (test, validation) splits,
    each preprocessed into a {category: [question dicts]} mapping."""
    dataset = load_dataset("TIGER-Lab/MMLU-Pro")
    return preprocess(dataset["test"]), preprocess(dataset["validation"])
def preprocess(test_df):
    """Drop "N/A" placeholder options from every question and group by category.

    Args:
        test_df: iterable of question dicts, each with "options" and "category"
            keys (e.g. a HF dataset split).

    Returns:
        dict mapping category name -> list of question dicts. Each question's
        "options" list is rewritten with the "N/A" padding entries removed.
    """
    grouped = {}
    for each in test_df:
        # MMLU-Pro pads every question to 10 options with "N/A"; strip those.
        each["options"] = [opt for opt in each["options"] if opt != "N/A"]
        grouped.setdefault(each["category"], []).append(each)
    return grouped
def format_example(question, options, cot_content=""):
    """Format one question (plus optional chain-of-thought answer) as prompt text.

    Args:
        question: the question text.
        options: answer option strings (at most 10, lettered A-J).
        cot_content: chain-of-thought answer text; an empty string selects the
            default "Let's think step by step." priming phrase.

    Returns:
        str shaped "Question: ...\\nOptions: A. ...\\n...Answer: <cot>\\n\\n".
    """
    if cot_content == "":
        cot_content = "Let's think step by step."
    if cot_content.startswith("A: "):
        # Strip the dataset's "A: " answer prefix.
        cot_content = cot_content[3:]
    example = "Question: {}\nOptions: ".format(question)
    choice_map = "ABCDEFGHIJ"
    for i, opt in enumerate(options):
        example += "{}. {}\n".format(choice_map[i], opt)
    # cot_content is never empty here (defaulted above), so the answer line is
    # always appended; the original unreachable bare "Answer: " branch was removed.
    example += "Answer: " + cot_content + "\n\n"
    return example
def extract_answer(text):
    """Pull the letter choice out of an 'answer is (X)' phrase; on failure,
    log the text and fall back to the looser extract_again() patterns."""
    hit = re.search(r"answer is \(?([A-J])\)?", text)
    if hit is not None:
        return hit.group(1)
    print("1st answer extract failed\n" + text)
    return extract_again(text)
def extract_again(text):
    """Second-chance extraction: look for an 'Answer: X' / 'answer: X' marker,
    else defer to extract_final()."""
    hit = re.search(r'.*[aA]nswer:\s*([A-J])', text)
    return hit.group(1) if hit is not None else extract_final(text)
def extract_final(text):
    """Last resort: return the final standalone A-J letter anywhere in text,
    or None when no such letter exists."""
    hit = re.search(r"\b[A-J]\b(?!.*\b[A-J]\b)", text, re.DOTALL)
    return None if hit is None else hit.group(0)
def update_result(output_res_path):
    """Load prior results from disk and tally per-category correct/wrong counts.

    Retries forever (sleeping 2 seconds between attempts) on any read/parse
    error — presumably because another worker may be writing the file
    concurrently.

    Returns:
        (records, tally): the list loaded from output_res_path (empty when the
        file is absent) and {category: {"corr": float, "wrong": float}}.
    """
    tally = {}
    records = []
    while True:
        try:
            if os.path.exists(output_res_path):
                with open(output_res_path, "r") as fi:
                    records = json.load(fi)
            for entry in records:
                cat = entry["category"]
                if cat not in tally:
                    tally[cat] = {"corr": 0.0, "wrong": 0.0}
                if not entry["pred"]:
                    # No prediction was extracted: score a seeded random guess
                    # so repeated reruns stay deterministic.
                    random.seed(12345)
                    guess = random.randint(0, len(entry["options"]) - 1)
                    bucket = "corr" if guess == entry["answer_index"] else "wrong"
                    tally[cat][bucket] += 1
                elif entry["pred"] == entry["answer"]:
                    tally[cat]["corr"] += 1
                else:
                    tally[cat]["wrong"] += 1
            return records, tally
        except Exception as e:
            print("Error", e, "sleep 2 seconds")
            time.sleep(2)
def merge_result(res, curr):
    """Replace every entry in res matching curr's (question_id, question) pair
    with curr, or append curr when no entry matches. Mutates and returns res."""
    hits = [i for i, single in enumerate(res)
            if single["question_id"] == curr["question_id"]
            and single["question"] == curr["question"]]
    for i in hits:
        res[i] = curr
    if not hits:
        res.append(curr)
    return res
def evaluate(subjects, intelligence_url, output_dir=None):
    """Run the benchmark over the given subjects, checkpointing after every question.

    Args:
        subjects: list of category names to evaluate; an empty/falsy list means
            "all categories found in the test split".
        intelligence_url: endpoint URL handed to MMLUClient.
        output_dir: directory for the per-subject result/summary JSON files.
            Defaults to the module-level ``args.output_dir`` for backward
            compatibility (the original version read the global directly).
    """
    if output_dir is None:
        # Fall back to the CLI args global set in __main__, as before.
        output_dir = args.output_dir
    test_df, dev_df = load_mmlu_pro()
    if not subjects:
        subjects = list(test_df.keys())
    print("assigned subjects", subjects)
    bench_client = MMLUClient(intelligence_url)
    for subject in subjects:
        test_data = test_df[subject]
        output_res_path = os.path.join(output_dir, subject + "_result.json")
        output_summary_path = os.path.join(output_dir, subject + "_summary.json")
        # Reload any prior checkpoint so interrupted runs resume where they left off.
        res, category_record = update_result(output_res_path)
        for each in tqdm(test_data):
            label = each["answer"]
            category = subject
            env = {
                "each": each,
                "input_text": dev_df  # validation split: few-shot CoT examples by category
            }
            action = bench_client.get_response(env)
            pred = action["action"]
            response = action["response"]
            if response is not None:
                # Re-read from disk in case a parallel worker updated the file.
                res, category_record = update_result(output_res_path)
                if category not in category_record:
                    category_record[category] = {"corr": 0.0, "wrong": 0.0}
                each["pred"] = pred
                each["model_outputs"] = response
                merge_result(res, each)
                if pred is not None:
                    if pred == label:
                        category_record[category]["corr"] += 1
                    else:
                        category_record[category]["wrong"] += 1
                else:
                    # No extractable answer counts as wrong at this stage.
                    category_record[category]["wrong"] += 1
                save_res(res, output_res_path)
                save_summary(category_record, output_summary_path)
        # Final pass: recompute the tally from disk so the written summary
        # reflects every merged result for this subject.
        res, category_record = update_result(output_res_path)
        save_res(res, output_res_path)
        save_summary(category_record, output_summary_path)
def save_res(res, output_res_path):
    """Write res to output_res_path as JSON, keeping only the first entry seen
    for each question_id (order preserved).

    Args:
        res: list of result dicts, each carrying a "question_id".
        output_res_path: destination file path (overwritten).
    """
    deduped = []
    seen_ids = set()  # O(1) membership vs. the old O(n) list scan per entry
    for each in res:
        if each["question_id"] not in seen_ids:
            seen_ids.add(each["question_id"])
            deduped.append(each)
    with open(output_res_path, "w") as fo:
        fo.write(json.dumps(deduped))
def save_summary(category_record, output_summary_path):
    """Compute per-category and overall accuracy, then write the summary JSON.

    Mutates category_record: adds an "acc" field to every category and a
    "total" entry. A category (or an entirely empty record) with zero scored
    questions now gets an accuracy of 0.0 instead of raising ZeroDivisionError.

    Args:
        category_record: {category: {"corr": float, "wrong": float}}.
        output_summary_path: destination file path (overwritten).
    """
    total_corr = 0.0
    total_wrong = 0.0
    for k, v in category_record.items():
        if k == "total":
            # Skip a stale "total" from a previous write; it is recomputed below.
            continue
        count = v["corr"] + v["wrong"]
        category_record[k]["acc"] = v["corr"] / count if count else 0.0
        total_corr += v["corr"]
        total_wrong += v["wrong"]
    total = total_corr + total_wrong
    acc = total_corr / total if total else 0.0
    category_record["total"] = {"corr": total_corr, "wrong": total_wrong, "acc": acc}
    with open(output_summary_path, "w") as fo:
        fo.write(json.dumps(category_record))
class MMLUClient(BenchClient):
    """BenchClient wrapper that builds MMLU-Pro few-shot CoT prompts and
    parses the model's lettered answer out of its raw response."""

    def __init__(self, intelligence_url):
        # NOTE(review): the constant 3 is forwarded to BenchClient verbatim —
        # presumably a retry/attempt count; confirm against benchflow docs.
        super().__init__(intelligence_url, 3)

    def prepare_input(self, env: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble the prompt for one question from env["each"] plus the
        per-category CoT examples in env["input_text"]."""
        entry = env["each"]
        cot_examples_dict = env["input_text"]
        subject = entry["category"]
        prompt = ("The following are multiple choice questions (with answers) about {}."
                  " Think step by step and then output the answer in the format of"
                  " \"The answer is (X)\" at the end.\n\n").format(subject)
        for shot in cot_examples_dict[subject]:
            prompt += format_example(shot["question"], shot["options"], shot["cot_content"])
        return {
            "prompt": prompt,
            "input_text": format_example(entry["question"], entry["options"]),
            "entry": entry,
            "cot_examples_dict": cot_examples_dict,
        }

    def parse_response(self, raw_response: str) -> Dict[str, Any]:
        """Map the raw model text to {"action": letter-or-None, "response": raw}."""
        return {"action": extract_answer(raw_response), "response": raw_response}
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", "-o", type=str, default="eval_results/")
    parser.add_argument("--intelligence_url", "-b", type=str)
    # NOTE(review): --model_name is accepted for CLI compatibility but is not
    # read anywhere in this script.
    parser.add_argument("--model_name", "-m", type=str, default="gpt-4",
                        choices=["gpt-4", "gpt-4o", "o1-preview",
                                 "deepseek-chat", "deepseek-coder",
                                 "gemini-1.5-flash-latest",
                                 "gemini-1.5-pro-latest",
                                 "claude-3-opus-20240229",
                                 "gemini-1.5-flash-8b",
                                 "claude-3-sonnet-20240229",
                                 "gemini-002-pro",
                                 "gemini-002-flash"])
    parser.add_argument("--assigned_subjects", "-a", type=str, default="all")
    # `args` must stay module-global: evaluate() reads args.output_dir.
    # (The redundant pre-parse `assigned_subjects = []` assignment was removed.)
    args = parser.parse_args()
    if args.assigned_subjects == "all":
        assigned_subjects = []  # empty list means "evaluate every subject"
    else:
        assigned_subjects = args.assigned_subjects.split(",")
    os.makedirs(args.output_dir, exist_ok=True)
    evaluate(assigned_subjects, args.intelligence_url)