/
medqa-csc34f77a
import json
from collections import defaultdict
from main import Section
class DataViewer:
    """Load a JSON evaluation dataset and look records up by section, case and turn.

    The dataset file must contain a JSON list of records. Each record has
    ``"section"``, ``"case_id"`` and ``"conversation_turn_id"`` fields.
    """

    def __init__(self, dataset_path):
        """Load and index the dataset stored at *dataset_path*."""
        self.lookup = None
        self.dataset = None
        self.dataset_path = dataset_path
        self.open_dataset(dataset_path)

    def open_dataset(self, dataset_path, section_names=None):
        """Load *dataset_path* and build the (section, case, turn) index.

        Args:
            dataset_path: Path to a UTF-8 JSON file containing a list of records.
            section_names: Section names to pre-register in the index. Defaults
                to the ``value`` of every ``Section`` member.

        Returns:
            ``(data, lookup)``. ``data`` is the parsed record list.
            ``lookup[section][case_id][turn_id]`` is the position of that
            record in ``data``. Each case dict also has a ``"max_turn"`` entry
            holding the number of records for that (section, case_id) pair.
            This equals the highest turn id when turns are numbered 1..n
            without gaps.

        Raises:
            KeyError: if a record's section is not one of *section_names*.
        """
        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(dataset_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if section_names is None:
            section_names = [section.value for section in Section]

        # One case -> turn map per known section. Building it from the names
        # themselves (not a fixed count) guarantees every section is present.
        lookup = {name: defaultdict(dict) for name in section_names}
        turn_counts = defaultdict(int)
        for index, element in enumerate(data):
            section = element["section"]
            case_id = int(element["case_id"])
            turn_id = int(element["conversation_turn_id"])
            lookup[section][case_id][turn_id] = index
            turn_counts[section, case_id] += 1

        # Count over the whole dataset. A case whose records are not
        # contiguous still gets its full count, and no placeholder entry is
        # created for a case that never occurs.
        for (section, case_id), count in turn_counts.items():
            lookup[section][case_id]["max_turn"] = count

        self.lookup = lookup
        self.dataset = data
        return data, lookup

    def get_data(self, section, case, turn=1):
        """Return the record for *section*, *case* and *turn* (turn 1 by default).

        Raises:
            KeyError: if the section, case or turn is not in the index.
        """
        cases = self.lookup[section]
        # Test membership first: indexing the defaultdict with an unknown case
        # would silently insert an empty entry into the index.
        if case not in cases:
            raise KeyError(case)
        return self.dataset[cases[case][turn]]

    def view_data(self, section, case, turn=1):
        """Return the record for *section*/*case*/*turn* as indented JSON text."""
        return json.dumps(self.get_data(section, case, turn), indent=2)
if __name__ == "__main__":
    # Manual smoke check: load the evaluation set and pretty-print the
    # first turn of diagnosis case 1.
    viewer = DataViewer("data/evaluation.json")
    rendered = viewer.view_data("diagnosis", 1)
    print(rendered)