Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,8 @@ dmypy.json

# Pyre type checker
.pyre/

# tmp
tmp/

.simple_evals_cache/
46 changes: 35 additions & 11 deletions browsecomp_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
import random
import re
import pandas
from . import common
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
import pathlib
import urllib.request
import common
from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult

CACHE_DIR = pathlib.Path(".simple_evals_cache")

# from: https://github.com/centerforaisafety/hle/blob/7b6be5aad6f9b43af3857de7867f3b52f6e4acb3/hle_eval/run_model_predictions.py#L11
QUERY_TEMPLATE = """
Expand All @@ -23,7 +27,7 @@
""".strip()

# from: https://github.com/centerforaisafety/hle/blob/7b6be5aad6f9b43af3857de7867f3b52f6e4acb3/hle_eval/run_judge_results.py#L16-L33
GRADER_TEMPLATE = """
GRADER_TEMPLATE = r"""
Judge whether the following [response] to [question] is correct or not based on the precise and unambiguous [correct_answer] below.

[question]: {question}
Expand All @@ -41,7 +45,7 @@
correct: Answer 'yes' if extracted_final_answer matches the [correct_answer] given above, or is within a small margin of error for numerical problems. Answer 'no' otherwise, i.e. if there if there is any inconsistency, ambiguity, non-equivalency, or if the extracted answer is incorrect.


confidence: The extracted confidence score between 0|\%| and 100|\%| from [response]. Put 100 if there is no confidence score available.
confidence: The extracted confidence score between 0|\\%| and 100|\\%| from [response]. Put 100 if there is no confidence score available.
""".strip()

CHOICE_STRINGS = ["yes", "no"]
Expand All @@ -64,10 +68,21 @@ def decrypt(ciphertext_b64: str, password: str) -> str:


class BrowseCompEval(Eval):
def __init__(self, grader_model: SamplerBase, num_examples: int | None = None, n_repeats: int = 1):
df = pandas.read_csv(
"https://openaipublic.blob.core.windows.net/simple-evals/browse_comp_test_set.csv"
)
def __init__(self, grader_model: SamplerBase, num_examples: int | None = None, n_repeats: int = 1, batch_size: int = 20, checkpoint_file: str | None = None):
url = "https://openaipublic.blob.core.windows.net/simple-evals/browse_comp_test_set.csv"

CACHE_DIR.mkdir(parents=True, exist_ok=True)
file_name = url.split("/")[-1]
cached_file_path = CACHE_DIR / file_name

if cached_file_path.exists():
print(f"Loading cached BrowseComp data from {cached_file_path}")
df = pandas.read_csv(cached_file_path)
else:
print(f"Downloading BrowseComp data from {url} to {cached_file_path}")
urllib.request.urlretrieve(url, str(cached_file_path))
df = pandas.read_csv(cached_file_path)

examples = [row.to_dict() for _, row in df.iterrows()]
if num_examples:
assert n_repeats == 1, "n_repeats only supported when max_examples = None"
Expand All @@ -93,17 +108,24 @@ def grade_sample(self, question: str, correct_answer: str, response: str) -> str

def __call__(self, sampler: SamplerBase) -> EvalResult:
def fn(row: dict):
print(1)
problem = decrypt(row.get("problem", ""), row.get("canary", ""))
answer = decrypt(row.get("answer", ""), row.get("canary", ""))
prompt_messages = [
sampler._pack_message(content=QUERY_TEMPLATE.format(Question=problem), role="user")
]
response_text = sampler(prompt_messages)

# Find the first occurrence of "Explanation" and remove everything before it
explanation_index = response_text.find("Explanation:")
if explanation_index != -1:
response_text = response_text[explanation_index:]

grade_result = self.grade_sample(problem, answer, response_text)

# Metrics based on grading response
is_correct = grade_result == "yes"
is_incorrect = grade_result == "no"
is_correct = float(grade_result == "yes")
is_incorrect = float(grade_result == "no")

score = is_correct

Expand All @@ -112,10 +134,12 @@ def fn(row: dict):
prompt_messages=prompt_messages,
next_message=dict(content=response_text, role="assistant"),
score=score,
correct_answer=row["answer"],
# correct_answer=row["answer"],
correct_answer=answer,
extracted_answer=response_text,
)
convo = prompt_messages + [dict(content=response_text, role="assistant")]
print(2)
return SingleEvalResult(html=html, score=score, convo=convo, metrics={
"is_correct": is_correct,
"is_incorrect": is_incorrect,
Expand Down
161 changes: 118 additions & 43 deletions common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import numpy as np
import requests
from tqdm import tqdm
import json

from .types import EvalResult, Message, SamplerBase, SingleEvalResult
from eval_types import EvalResult, Message, SamplerBase, SingleEvalResult

QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
Expand All @@ -29,48 +30,49 @@
)
# All the different ways "Answer" is written in different languages
MULTILINGUAL_ANSWER_REGEXES = [
"Answer\s*:",
"Answer\s*:​​​​​​", # Korean invisible character
"উত্তর\s*:",
"उत्तर\s*:",
"উত্তরঃ",
"উত্তর\s*:",
"Antwort\s*:",
"답변\s*:",
"정답\s*:",
"답\s*:",
"答案\s*:",
"答案\s*:",
"答\s*:",
"答\s*:",
"答复\s*:",
"答曰\s*:",
"الإجابة:",
"الجواب:",
"إجابة:",
"الإجابة النهائية:",
"الإجابة الصحيحة:",
"الإجابة الصحيحة هي:",
"الإجابة هي:",
"الجواب النهائي:",
"Respuesta\s*:",
"Risposta\s*:",
"答え\s*:",
"答え\s*:",
"回答\s*:",
"回答\s*:",
"解答\s*:",
"Jawaban\s*:",
"Réponse\s*:",
"Resposta\s*:",
"Jibu\s*:",
"Idahun\s*:",
"Ìdáhùn\s*:",
"Idáhùn\s*:",
"Àmọ̀nà\s*:",
"Àdáhùn\s*:",
"Ànúgọ\s*:",
"Àṣàyàn\s*:",
r"Answer\s*:",
r"Answer\\s*:",
r"Answer\\s*:​​​​​​", # Korean invisible character
r"উত্তর\\s*:",
r"उत्तर\\s*:",
r"উত্তরঃ",
r"উত্তর\\s*:",
r"Antwort\\s*:",
r"답변\\s*:",
r"정답\\s*:",
r"답\\s*:",
r"答案\\s*:",
r"答案\\s*:",
r"答\\s*:",
r"答\\s*:",
r"答复\\s*:",
r"答曰\\s*:",
"الإجابة:", # No backslash, no r needed
"الجواب:", # No backslash, no r needed
"إجابة:", # No backslash, no r needed
"الإجابة النهائية:", # No backslash, no r needed
"الإجابة الصحيحة:", # No backslash, no r needed
"الإجابة الصحيحة هي:", # No backslash, no r needed
"الإجابة هي:", # No backslash, no r needed
"الجواب النهائي:", # No backslash, no r needed
r"Respuesta\\s*:",
r"Risposta\\s*:",
r"答え\\s*:",
r"答え\\s*:",
r"回答\\s*:",
r"回答\\s*:",
r"解答\\s*:",
r"Jawaban\\s*:",
r"Réponse\\s*:",
r"Resposta\\s*:",
r"Jibu\\s*:",
r"Idahun\\s*:",
r"Ìdáhùn\\s*:",
r"Idáhùn\\s*:",
r"Àmọ̀nà\\s*:",
r"Àdáhùn\\s*:",
r"Ànúgọ\\s*:",
r"Àṣàyàn\\s*:",
]


Expand Down Expand Up @@ -372,3 +374,76 @@ def url_to_fileobj(url: str, binary=False) -> Any:
response = requests.get(url)
response.raise_for_status()
return io.BytesIO(response.content) if binary else io.StringIO(response.text)

def load_checkpoint(checkpoint_file: str | None) -> list[SingleEvalResult]:
"""Loads processed results from the checkpoint file."""
if not checkpoint_file or not os.path.exists(checkpoint_file):
if checkpoint_file: # Only print if a path was given
print(f"Checkpoint file {checkpoint_file} not found. Starting fresh.")
return []

loaded_results: list[SingleEvalResult] = []
try:
with open(checkpoint_file, "r") as f:
for line_number, line in enumerate(f, 1):
line_content = line.strip()
if not line_content: # Skip empty lines
continue
try:
data = json.loads(line_content)
# Basic validation for required keys from SingleEvalResult
if not all(k in data for k in ["html", "score"]):
print(
f"Skipping entry with missing core keys ('html', 'score') in checkpoint file {checkpoint_file} at line {line_number}."
)
continue

# metrics and convo are optional in SingleEvalResult
result = SingleEvalResult(
html=data["html"],
score=data["score"],
metrics=data.get("metrics"),
convo=data.get("convo"),
)
loaded_results.append(result)
except json.JSONDecodeError:
print(
f"Skipping malformed JSON in checkpoint file {checkpoint_file} at line {line_number}: {line_content}"
)
except KeyError as e:
print(
f"Skipping entry with missing key {e} in checkpoint file {checkpoint_file} at line {line_number}: {line_content}"
)

if loaded_results:
print(
f"Resumed from checkpoint. Loaded {len(loaded_results)} results from {checkpoint_file}."
)
else:
# File existed but was empty or all lines were invalid
print(
f"Checkpoint file {checkpoint_file} was empty or contained no valid entries. Starting fresh."
)
return loaded_results
except Exception as e: # Catch other errors like permission issues
print(f"Error loading checkpoint from {checkpoint_file}: {e}. Starting fresh.")
return [] # Ensure clean state on major load error


def save_checkpoint(checkpoint_file: str | None, new_results: list[SingleEvalResult]):
"""Appends new results to the checkpoint file."""
if not checkpoint_file:
return
try:
with open(checkpoint_file, "a") as f:
for result in new_results:
# Convert SingleEvalResult to dict for JSON serialization
result_dict = {
"html": result.html,
"score": result.score,
"metrics": result.metrics,
"convo": result.convo,
}
f.write(json.dumps(result_dict) + "\n")
except Exception as e:
print(f"Error saving checkpoint to {checkpoint_file}: {e}")
41 changes: 34 additions & 7 deletions drop_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@
import random
import re
import string
import pathlib
import urllib.request
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import numpy as np
from scipy.optimize import linear_sum_assignment

from . import common
from .common import ANSWER_PATTERN, HTML_JINJA
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
import common
from common import ANSWER_PATTERN, HTML_JINJA
from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult

"""
From here through _normalize_answer was originally copied from:
Expand Down Expand Up @@ -233,20 +235,45 @@ def drop_metric(sample: str, reference: list[str]) -> Tuple[float, float]:
return (max(em_scores), max(f1_scores))


CACHE_DIR = pathlib.Path(".simple_evals_cache")

class DropEval(Eval):
def __init__(self, num_examples: int | None = None, train_samples_per_prompt: int = 3):
self.seed = 42
self._num_examples = num_examples
self._train_samples_per_prompt = train_samples_per_prompt
self.train_jsonl = (

train_url = (
"https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_train.jsonl.gz"
)
self.test_jsonl = (
test_url = (
"https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_dev.jsonl.gz"
)
with gzip.GzipFile(fileobj=common.url_to_fileobj(self.train_jsonl, binary=True), mode="rb") as f:

CACHE_DIR.mkdir(parents=True, exist_ok=True)

train_file_name = train_url.split("/")[-1]
cached_train_file_path = CACHE_DIR / train_file_name

if cached_train_file_path.exists():
print(f"Loading cached DROP train data from {cached_train_file_path}")
else:
print(f"Downloading DROP train data from {train_url} to {cached_train_file_path}")
urllib.request.urlretrieve(train_url, str(cached_train_file_path))

with gzip.GzipFile(filename=str(cached_train_file_path), mode="rb") as f:
self.train_samples = list(map(json.loads, f.readlines()))
with gzip.GzipFile(fileobj=common.url_to_fileobj(self.test_jsonl, binary=True), mode="rb") as f:

test_file_name = test_url.split("/")[-1]
cached_test_file_path = CACHE_DIR / test_file_name

if cached_test_file_path.exists():
print(f"Loading cached DROP test data from {cached_test_file_path}")
else:
print(f"Downloading DROP test data from {test_url} to {cached_test_file_path}")
urllib.request.urlretrieve(test_url, str(cached_test_file_path))

with gzip.GzipFile(filename=str(cached_test_file_path), mode="rb") as f:
self.test_samples = list(map(json.loads, f.readlines()))
if self._num_examples:
self.test_samples = random.Random(self.seed).sample(
Expand Down
File renamed without changes.
Loading