import argparse
import os
import json
import time
import itertools
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from openai import OpenAI
import anthropic
import google.generativeai as genai
from google.generativeai.types import GenerationConfig
from config import model_config
def load_jsonl(file_path):
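    """Load a JSONL file into a list of dicts, one object per line."""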
data = []
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
data.append(json.loads(line))
return data
def new_directory(path):
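    """Create 'path' (including parents) if it does not already exist."""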
if path and not os.path.exists(path):
os.makedirs(path)
GEMINI_API_KEYS = model_config["gemini"]
# Create an infinite key cycle
gemini_key_cycle = itertools.cycle(GEMINI_API_KEYS)
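# next(gemini_key_cycle) yields the keys in order and wraps around indefinitely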
def write_response(results, data_list, output_path):
"""
By default, each result is a single response.
"""
formatted_data = []
for i, data in enumerate(data_list):
data["responses"] = results[i]
data.pop("prompt", None)
formatted_data.append(data)
if output_path:
directory_path = os.path.dirname(output_path)
new_directory(directory_path)
with open(output_path, "w") as f:
for instance in formatted_data:
f.write(json.dumps(instance, ensure_ascii=False) + "\n")
def api_request(messages, engine, client, backend, **kwargs):
"""
Calls the underlying LLM endpoint depending on the 'backend'.
"""
while True:
try:
if backend == "openai":
completion = client.chat.completions.create(
model=engine,
messages=messages,
temperature=kwargs.get("temperature", 0),
max_tokens=kwargs.get("max_tokens", 512),
top_p=kwargs.get("top_p", 1),
frequency_penalty=kwargs.get("frequency_penalty", 0),
presence_penalty=kwargs.get("presence_penalty", 0),
stop=kwargs.get("stop", None),
)
return completion.choices[0].message.content
elif backend == "anthropic":
                message = client.messages.create(
                    model=engine,
                    messages=messages,
                    temperature=kwargs.get("temperature", 0),
                    max_tokens=kwargs.get("max_tokens", 512),
                    top_p=kwargs.get("top_p", 1),
                    # Use NOT_GIVEN instead of None so the field is omitted
                    # from the request when no stop sequences are supplied
                    stop_sequences=kwargs.get("stop") or anthropic.NOT_GIVEN,
                )
return message.content[0].text
elif backend == "genai":
response = client.generate_content(
messages[0]["content"],
generation_config=GenerationConfig(
temperature=kwargs.get("temperature", 0),
top_p=kwargs.get("top_p", 1),
max_output_tokens=kwargs.get("max_tokens", 512),
presence_penalty=kwargs.get("presence_penalty", 0),
frequency_penalty=kwargs.get("frequency_penalty", 0),
stop_sequences=kwargs.get("stop", None),
),
)
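                # Accessing response.text raises ValueError when generation was blocked (e.g., by safety filters)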
try:
return response.text
except ValueError as ve:
return f"Model refused to generate a response {ve}"
except Exception:
return ""
except Exception as e:
print(e)
time.sleep(1)
# Rotate API keys and retry if using the genai backend
if backend == "genai":
genai.configure(api_key=next(gemini_key_cycle))
time.sleep(10)
def call_api_model(
    messages,
    model_name,
    temperature=0,
    max_tokens=512,
    top_p=1,
    frequency_penalty=0,
    presence_penalty=0,
    stop=None,
):
"""
Sets up the correct backend client + model engine, then calls 'api_request'.
"""
if "gpt" in model_name:
engine = model_name
client = OpenAI(
base_url=model_config[model_name]["base_url"],
api_key=model_config[model_name]["api_key"],
)
backend = "openai"
elif "claude" in model_name:
engine = model_name
client = anthropic.Anthropic(
api_key=model_config[model_name],
)
backend = "anthropic"
elif "gemini" in model_name:
engine = model_name
client = genai.GenerativeModel(engine)
genai.configure(api_key=GEMINI_API_KEYS[1])
backend = "genai"
    else:
        raise ValueError(f"Unsupported model name: {model_name}")
kwargs = {
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"stop": stop,
}
return api_request(messages, engine, client, backend, **kwargs)
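# Example (model name is illustrative and must have an entry in model_config):
#   call_api_model([{"role": "user", "content": "Return 'pong'."}], "gpt-4o-mini")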
def worker_function(task, data_list, output_path, lock):
"""
Processes a single prompt.
"""
prompt, idx, model_name = task
messages = [{"role": "user", "content": prompt}]
response = call_api_model(messages, model_name)
# Write to the file in real-time (append mode)
with lock:
with open(output_path, "a", encoding="utf-8") as f:
row = data_list[idx]
row["response"] = response
# Use the _index field to record the original index
row["_index"] = idx
row.pop("prompt", None)
f.write(json.dumps(row, ensure_ascii=False) + "\n")
return idx, response
def final_sort_jsonl_by_index(file_path):
"""
Reads an existing JSONL file, sorts it by the '_index' field,
then overwrites the file. After sorting, you can remove the '_index' field.
"""
all_data = []
with open(file_path, "r", encoding="utf-8") as fin:
for line in fin:
if not line.strip():
continue
row = json.loads(line)
all_data.append(row)
# Sort by '_index'
all_data.sort(key=lambda x: x["_index"])
# Overwrite the file, removing the '_index' field
with open(file_path, "w", encoding="utf-8") as fout:
for row in all_data:
row.pop("_index", None)
fout.write(json.dumps(row, ensure_ascii=False) + "\n")
def collect_response_from_api(
prompt_list,
model_name,
data_list,
output_path,
num_threads=8,
start_index=0,
):
"""
In multi-threading, write to a file in real-time, then sort the final output.
"""
# Only process tasks from 'start_index' onward
tasks = [
(prompt_list[i], i, model_name) for i in range(start_index, len(prompt_list))
]
# Ensure the output directory exists
new_directory(os.path.dirname(output_path))
    # If starting from scratch, truncate the output file; workers always append
    if start_index == 0:
        open(output_path, "w", encoding="utf-8").close()
# Lock for protecting the write operation
lock = threading.Lock()
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = []
for t in tasks:
futures.append(
executor.submit(worker_function, t, data_list, output_path, lock)
)
# Wait until all threads are done
for _ in tqdm(as_completed(futures), total=len(futures)):
pass
# After all threads finish, perform a final sort of the output file
final_sort_jsonl_by_index(output_path)
if __name__ == "__main__":
args_parser = argparse.ArgumentParser()
args_parser.add_argument("--prompt_path", type=str)
args_parser.add_argument("--output_path", type=str)
args_parser.add_argument("--model_name", type=str, default="claude")
args_parser.add_argument("--start_index", type=int, default=0)
args = args_parser.parse_args()
data_list = load_jsonl(args.prompt_path)
prompts = [data["prompt"] for data in data_list]
print(prompts[0])
collect_response_from_api(
prompts,
args.model_name,
data_list,
args.output_path,
start_index=args.start_index,
)
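# Example invocation (script, file names, and model are illustrative):
#   python collect_responses.py \
#       --prompt_path prompts.jsonl \
#       --output_path responses.jsonl \
#       --model_name gpt-4o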