diff --git a/.gitignore b/.gitignore index b6e47617..a7958953 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,8 @@ dmypy.json # Pyre type checker .pyre/ + +# tmp +tmp/ + +.simple_evals_cache/ diff --git a/browsecomp_eval.py b/browsecomp_eval.py index e246d52f..1b261581 100644 --- a/browsecomp_eval.py +++ b/browsecomp_eval.py @@ -9,8 +9,12 @@ import random import re import pandas -from . import common -from .types import Eval, EvalResult, SamplerBase, SingleEvalResult +import pathlib +import urllib.request +import common +from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult + +CACHE_DIR = pathlib.Path(".simple_evals_cache") # from: https://github.com/centerforaisafety/hle/blob/7b6be5aad6f9b43af3857de7867f3b52f6e4acb3/hle_eval/run_model_predictions.py#L11 QUERY_TEMPLATE = """ @@ -23,7 +27,7 @@ """.strip() # from: https://github.com/centerforaisafety/hle/blob/7b6be5aad6f9b43af3857de7867f3b52f6e4acb3/hle_eval/run_judge_results.py#L16-L33 -GRADER_TEMPLATE = """ +GRADER_TEMPLATE = r""" Judge whether the following [response] to [question] is correct or not based on the precise and unambiguous [correct_answer] below. [question]: {question} @@ -41,7 +45,7 @@ correct: Answer 'yes' if extracted_final_answer matches the [correct_answer] given above, or is within a small margin of error for numerical problems. Answer 'no' otherwise, i.e. if there if there is any inconsistency, ambiguity, non-equivalency, or if the extracted answer is incorrect. -confidence: The extracted confidence score between 0|\%| and 100|\%| from [response]. Put 100 if there is no confidence score available. +confidence: The extracted confidence score between 0|\\%| and 100|\\%| from [response]. Put 100 if there is no confidence score available. """.strip() CHOICE_STRINGS = ["yes", "no"] @@ -64,10 +68,21 @@ def decrypt(ciphertext_b64: str, password: str) -> str: class BrowseCompEval(Eval): - def __init__(self, grader_model: SamplerBase, num_examples: int | None = None, n_repeats: int = 1): - df = pandas.read_csv( - "https://openaipublic.blob.core.windows.net/simple-evals/browse_comp_test_set.csv" - ) + def __init__(self, grader_model: SamplerBase, num_examples: int | None = None, n_repeats: int = 1, batch_size: int = 20, checkpoint_file: str | None = None): + url = "https://openaipublic.blob.core.windows.net/simple-evals/browse_comp_test_set.csv" + + CACHE_DIR.mkdir(parents=True, exist_ok=True) + file_name = url.split("/")[-1] + cached_file_path = CACHE_DIR / file_name + + if cached_file_path.exists(): + print(f"Loading cached BrowseComp data from {cached_file_path}") + df = pandas.read_csv(cached_file_path) + else: + print(f"Downloading BrowseComp data from {url} to {cached_file_path}") + urllib.request.urlretrieve(url, str(cached_file_path)) + df = pandas.read_csv(cached_file_path) + examples = [row.to_dict() for _, row in df.iterrows()] if num_examples: assert n_repeats == 1, "n_repeats only supported when max_examples = None" @@ -93,17 +108,24 @@ def grade_sample(self, question: str, correct_answer: str, response: str) -> str def __call__(self, sampler: SamplerBase) -> EvalResult: def fn(row: dict): + print(1) problem = decrypt(row.get("problem", ""), row.get("canary", "")) answer = decrypt(row.get("answer", ""), row.get("canary", "")) prompt_messages = [ sampler._pack_message(content=QUERY_TEMPLATE.format(Question=problem), role="user") ] response_text = sampler(prompt_messages) + + # Find the first occurrence of "Explanation" and remove everything before it + explanation_index = response_text.find("Explanation:") + if explanation_index != -1: + response_text = response_text[explanation_index:] + grade_result = self.grade_sample(problem, answer, response_text) # Metrics based on grading response - is_correct = grade_result == "yes" - is_incorrect = grade_result == "no" + is_correct = float(grade_result == "yes") + is_incorrect = float(grade_result == "no") score = is_correct @@ -112,10 +134,12 @@ def fn(row: dict): prompt_messages=prompt_messages, next_message=dict(content=response_text, role="assistant"), score=score, - correct_answer=row["answer"], + # correct_answer=row["answer"], + correct_answer=answer, extracted_answer=response_text, ) convo = prompt_messages + [dict(content=response_text, role="assistant")] + print(2) return SingleEvalResult(html=html, score=score, convo=convo, metrics={ "is_correct": is_correct, "is_incorrect": is_incorrect, diff --git a/common.py b/common.py index b6b4c0e1..410b993f 100644 --- a/common.py +++ b/common.py @@ -8,8 +8,9 @@ import numpy as np import requests from tqdm import tqdm +import json -from .types import EvalResult, Message, SamplerBase, SingleEvalResult +from eval_types import EvalResult, Message, SamplerBase, SingleEvalResult QUERY_TEMPLATE_MULTICHOICE = """ Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. @@ -29,48 +30,49 @@ ) # All the different ways "Answer" is written in different languages MULTILINGUAL_ANSWER_REGEXES = [ - "Answer\s*:", - "Answer\s*:​​​​​​", # Korean invisible character - "উত্তর\s*:", - "उत्तर\s*:", - "উত্তরঃ", - "উত্তর\s*:", - "Antwort\s*:", - "답변\s*:", - "정답\s*:", - "답\s*:", - "答案\s*:", - "答案\s*:", - "答\s*:", - "答\s*:", - "答复\s*:", - "答曰\s*:", - "الإجابة:", - "الجواب:", - "إجابة:", - "الإجابة النهائية:", - "الإجابة الصحيحة:", - "الإجابة الصحيحة هي:", - "الإجابة هي:", - "الجواب النهائي:", - "Respuesta\s*:", - "Risposta\s*:", - "答え\s*:", - "答え\s*:", - "回答\s*:", - "回答\s*:", - "解答\s*:", - "Jawaban\s*:", - "Réponse\s*:", - "Resposta\s*:", - "Jibu\s*:", - "Idahun\s*:", - "Ìdáhùn\s*:", - "Idáhùn\s*:", - "Àmọ̀nà\s*:", - "Àdáhùn\s*:", - "Ànúgọ\s*:", - "Àṣàyàn\s*:", + r"Answer\s*:", + r"Answer\\s*:", + r"Answer\\s*:​​​​​​", # Korean invisible character + r"উত্তর\\s*:", + r"उत्तर\\s*:", + r"উত্তরঃ", + r"উত্তর\\s*:", + r"Antwort\\s*:", + r"답변\\s*:", + r"정답\\s*:", + r"답\\s*:", + r"答案\\s*:", + r"答案\\s*:", + r"答\\s*:", + r"答\\s*:", + r"答复\\s*:", + r"答曰\\s*:", + "الإجابة:", # No backslash, no r needed + "الجواب:", # No backslash, no r needed + "إجابة:", # No backslash, no r needed + "الإجابة النهائية:", # No backslash, no r needed + "الإجابة الصحيحة:", # No backslash, no r needed + "الإجابة الصحيحة هي:", # No backslash, no r needed + "الإجابة هي:", # No backslash, no r needed + "الجواب النهائي:", # No backslash, no r needed + r"Respuesta\\s*:", + r"Risposta\\s*:", + r"答え\\s*:", + r"答え\\s*:", + r"回答\\s*:", + r"回答\\s*:", + r"解答\\s*:", + r"Jawaban\\s*:", + r"Réponse\\s*:", + r"Resposta\\s*:", + r"Jibu\\s*:", + r"Idahun\\s*:", + r"Ìdáhùn\\s*:", + r"Idáhùn\\s*:", + r"Àmọ̀nà\\s*:", + r"Àdáhùn\\s*:", + r"Ànúgọ\\s*:", + r"Àṣàyàn\\s*:", ] @@ -372,3 +374,76 @@ def url_to_fileobj(url: str, binary=False) -> Any: response = requests.get(url) response.raise_for_status() return io.BytesIO(response.content) if binary else io.StringIO(response.text) + +def load_checkpoint(checkpoint_file: str | None) -> list[SingleEvalResult]: + """Loads processed results from the checkpoint file.""" + if not checkpoint_file or not os.path.exists(checkpoint_file): + if checkpoint_file: # Only print if a path was given + print(f"Checkpoint file {checkpoint_file} not found. Starting fresh.") + return [] + + loaded_results: list[SingleEvalResult] = [] + try: + with open(checkpoint_file, "r") as f: + for line_number, line in enumerate(f, 1): + line_content = line.strip() + if not line_content: # Skip empty lines + continue + try: + data = json.loads(line_content) + # Basic validation for required keys from SingleEvalResult + if not all(k in data for k in ["html", "score"]): + print( + f"Skipping entry with missing core keys ('html', 'score') in checkpoint file {checkpoint_file} at line {line_number}." + ) + continue + + # metrics and convo are optional in SingleEvalResult + result = SingleEvalResult( + html=data["html"], + score=data["score"], + metrics=data.get("metrics"), + convo=data.get("convo"), + ) + loaded_results.append(result) + except json.JSONDecodeError: + print( + f"Skipping malformed JSON in checkpoint file {checkpoint_file} at line {line_number}: {line_content}" + ) + except KeyError as e: + print( + f"Skipping entry with missing key {e} in checkpoint file {checkpoint_file} at line {line_number}: {line_content}" + ) + + if loaded_results: + print( + f"Resumed from checkpoint. Loaded {len(loaded_results)} results from {checkpoint_file}." + ) + else: + # File existed but was empty or all lines were invalid + print( + f"Checkpoint file {checkpoint_file} was empty or contained no valid entries. Starting fresh." + ) + return loaded_results + except Exception as e: # Catch other errors like permission issues + print(f"Error loading checkpoint from {checkpoint_file}: {e}. Starting fresh.") + return [] # Ensure clean state on major load error + + +def save_checkpoint(checkpoint_file: str | None, new_results: list[SingleEvalResult]): + """Appends new results to the checkpoint file.""" + if not checkpoint_file: + return + try: + with open(checkpoint_file, "a") as f: + for result in new_results: + # Convert SingleEvalResult to dict for JSON serialization + result_dict = { + "html": result.html, + "score": result.score, + "metrics": result.metrics, + "convo": result.convo, + } + f.write(json.dumps(result_dict) + "\n") + except Exception as e: + print(f"Error saving checkpoint to {checkpoint_file}: {e}") diff --git a/drop_eval.py b/drop_eval.py index 27918e5b..2a1575df 100644 --- a/drop_eval.py +++ b/drop_eval.py @@ -9,14 +9,16 @@ import random import re import string +import pathlib +import urllib.request from typing import Any, Dict, List, Optional, Set, Tuple, Union import numpy as np from scipy.optimize import linear_sum_assignment -from . import common -from .common import ANSWER_PATTERN, HTML_JINJA -from .types import Eval, EvalResult, SamplerBase, SingleEvalResult +import common +from common import ANSWER_PATTERN, HTML_JINJA +from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult """ From here through _normalize_answer was originally copied from: @@ -233,20 +235,45 @@ def drop_metric(sample: str, reference: list[str]) -> Tuple[float, float]: return (max(em_scores), max(f1_scores)) +CACHE_DIR = pathlib.Path(".simple_evals_cache") + class DropEval(Eval): def __init__(self, num_examples: int | None = None, train_samples_per_prompt: int = 3): self.seed = 42 self._num_examples = num_examples self._train_samples_per_prompt = train_samples_per_prompt - self.train_jsonl = ( + + train_url = ( "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_train.jsonl.gz" ) - self.test_jsonl = ( + test_url = ( "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_dev.jsonl.gz" ) - with gzip.GzipFile(fileobj=common.url_to_fileobj(self.train_jsonl, binary=True), mode="rb") as f: + + CACHE_DIR.mkdir(parents=True, exist_ok=True) + + train_file_name = train_url.split("/")[-1] + cached_train_file_path = CACHE_DIR / train_file_name + + if cached_train_file_path.exists(): + print(f"Loading cached DROP train data from {cached_train_file_path}") + else: + print(f"Downloading DROP train data from {train_url} to {cached_train_file_path}") + urllib.request.urlretrieve(train_url, str(cached_train_file_path)) + + with gzip.GzipFile(filename=str(cached_train_file_path), mode="rb") as f: self.train_samples = list(map(json.loads, f.readlines())) - with gzip.GzipFile(fileobj=common.url_to_fileobj(self.test_jsonl, binary=True), mode="rb") as f: + + test_file_name = test_url.split("/")[-1] + cached_test_file_path = CACHE_DIR / test_file_name + + if cached_test_file_path.exists(): + print(f"Loading cached DROP test data from {cached_test_file_path}") + else: + print(f"Downloading DROP test data from {test_url} to {cached_test_file_path}") + urllib.request.urlretrieve(test_url, str(cached_test_file_path)) + + with gzip.GzipFile(filename=str(cached_test_file_path), mode="rb") as f: self.test_samples = list(map(json.loads, f.readlines())) if self._num_examples: self.test_samples = random.Random(self.seed).sample( diff --git a/types.py b/eval_types.py similarity index 100% rename from types.py rename to eval_types.py diff --git a/gpqa_eval.py b/gpqa_eval.py index 21c717ef..632e3edf 100644 --- a/gpqa_eval.py +++ b/gpqa_eval.py @@ -6,13 +6,16 @@ import random import re +import pathlib +import urllib.request import pandas -from . import common -from .common import ANSWER_PATTERN_MULTICHOICE, HTML_JINJA, format_multichoice_question -from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult +import common +from common import ANSWER_PATTERN_MULTICHOICE, HTML_JINJA, format_multichoice_question +from eval_types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult +CACHE_DIR = pathlib.Path(".simple_evals_cache") class GPQAEval(Eval): def __init__( @@ -20,19 +23,49 @@ def __init__( n_repeats: int = 4, variant: str = "diamond", num_examples: int | None = None, # restrict to a subset of the data for debugging + batch_size: int = 20, # Default batch size, can be configured + checkpoint_file: str | None = None, # Path to the checkpoint file ): - df = pandas.read_csv( - f"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{variant}.csv" - ) - examples = [row.to_dict() for _, row in df.iterrows()] - rng = random.Random(0) + rng = random.Random(0) + url = f"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{variant}.csv" + CACHE_DIR.mkdir(parents=True, exist_ok=True) + file_name = url.split("/")[-1] + cached_file_path = CACHE_DIR / file_name + + if cached_file_path.exists(): + print(f"Loading cached GPQA data from {cached_file_path}") + df_repeated = pandas.read_csv(cached_file_path) + else: + print(f"Downloading GPQA data from {url} to {cached_file_path}") + urllib.request.urlretrieve(url, str(cached_file_path)) + df_repeated = pandas.read_csv(cached_file_path) + + all_examples_from_csv = [row.to_dict() for _, row in df_repeated.iterrows()] + if num_examples: - assert n_repeats == 1, "n_repeats only supported for num_examples = None" - examples = rng.sample(examples, num_examples) - examples = examples * n_repeats - examples = [example | {"permutation": rng.sample(range(4), 4)} for example in examples] - self.examples = examples - self.n_repeats = n_repeats + # If num_examples is used, n_repeats is 1 (assertion exists) + # Sample *before* adding permutations. + final_examples_base = rng.sample(all_examples_from_csv, num_examples) + else: + # Use all examples, then repeat + final_examples_base = all_examples_from_csv * n_repeats + + # Now, assign a unique, deterministic permutation to each *instance* in final_examples_base + # To make this deterministic across runs (even with checkpointing), we need a seed for each permutation. + # We can use the index in this final_examples_base list. + self.examples = [] + for i, ex in enumerate(final_examples_base): + # Seed the RNG for permutation with a combination of global seed and instance index + perm_rng = random.Random(i) # Using i as seed for this specific permutation + self.examples.append(ex | {"permutation": perm_rng.sample(range(4), 4)}) + + self.n_repeats = n_repeats # Though less directly used if num_examples is set + self.batch_size = batch_size + self.checkpoint_file = checkpoint_file + self.processed_results: list[SingleEvalResult] = [] + + if self.checkpoint_file: + self.processed_results = common.load_checkpoint(self.checkpoint_file) def __call__(self, sampler: SamplerBase) -> EvalResult: def fn(row: dict): @@ -42,6 +75,7 @@ def fn(row: dict): row["Incorrect Answer 2"], row["Incorrect Answer 3"], ] + # Permutation is now pre-assigned and part of the 'row' choices = [choices[i] for i in row["permutation"]] correct_index = choices.index(row["Correct Answer"]) correct_answer = "ABCD"[correct_index] @@ -69,5 +103,37 @@ def fn(row: dict): html=html, score=score, convo=convo, metrics={"chars": len(response_text)} ) - results = common.map_with_progress(fn, self.examples) - return common.aggregate_results(results) + num_already_processed = len(self.processed_results) + + if num_already_processed >= len(self.examples): + if not self.examples: # No examples to run at all + print("No examples to evaluate.") + return common.aggregate_results([]) # Return empty aggregated result + print("All examples were already processed according to checkpoint.") + return common.aggregate_results(self.processed_results) + + examples_to_process_this_run = self.examples[num_already_processed:] + num_total_examples_in_run = len(self.examples) + + print(f"Starting GPQA evaluation. Total examples: {num_total_examples_in_run}. Already processed: {num_already_processed}. Remaining: {len(examples_to_process_this_run)}.") + + for i in range(0, len(examples_to_process_this_run), self.batch_size): + batch_examples = examples_to_process_this_run[i : i + self.batch_size] + if not batch_examples: + continue + + current_global_start_index_for_batch = num_already_processed + i + batch_start_num_display = current_global_start_index_for_batch + 1 + batch_end_num_display = min(current_global_start_index_for_batch + len(batch_examples), num_total_examples_in_run) + + print(f"Processing batch: examples {batch_start_num_display}-{batch_end_num_display} of {num_total_examples_in_run} (Batch size: {self.batch_size})") + + # batch_new_results: list[SingleEvalResult] = [] # Not needed, common.map_with_progress returns a new list + batch_new_results = common.map_with_progress(fn, batch_examples) + self.processed_results.extend(batch_new_results) + + if self.checkpoint_file and batch_new_results: + common.save_checkpoint(self.checkpoint_file, batch_new_results) + + print(f"GPQA evaluation finished. Processed {len(self.processed_results)} results in total out of {num_total_examples_in_run} examples.") + return common.aggregate_results(self.processed_results) diff --git a/human_eval_windows_patch.py b/human_eval_windows_patch.py new file mode 100644 index 00000000..b5c8b584 --- /dev/null +++ b/human_eval_windows_patch.py @@ -0,0 +1,70 @@ +import threading +import contextlib +import human_eval.execution # Module to be patched +from typing import Any # Only for the print statement type hint, not strictly needed by patch logic + +# 1. Define our custom TimeoutException. +# This will be raised by our patched time_limit and will replace +# the original TimeoutException in human_eval.execution +# so that existing `except TimeoutException:` blocks in human_eval continue to work. +class PatchedTimeoutException(Exception): + """Custom TimeoutException for the patched time_limit function.""" + pass + +# 2. Define the new time_limit implementation using threading.Timer. +@contextlib.contextmanager +def patched_time_limit(seconds: float): + """ + A time_limit context manager compatible with Windows, using threading.Timer. + Replaces human_eval.execution.time_limit. + """ + + # The callback function for threading.Timer. + # It raises our PatchedTimeoutException. + def _timer_callback(): + raise PatchedTimeoutException(f"Code execution timed out after {seconds} seconds.") + + # Create and start the timer. + # threading.Timer will raise appropriate errors if 'seconds' is not a valid number. + timer = threading.Timer(seconds, _timer_callback) + timer.start() + + try: + # Yield control to the block within the 'with' statement. + yield + finally: + # This block executes whether the 'try' block succeeded or failed. + # Crucially, cancel the timer to prevent it from firing if the guarded code finished on time + # or if an exception other than timeout occurred. + timer.cancel() + +# 3. Apply the monkeypatch. +# This should happen when this patch module is imported. + +# Replace the TimeoutException in the human_eval.execution module +# with our PatchedTimeoutException. +human_eval.execution.TimeoutException = PatchedTimeoutException + +# Replace the time_limit function in the human_eval.execution module +# with our patched_time_limit function. +human_eval.execution.time_limit = patched_time_limit + + +# --- Verification (Optional, but helpful for debugging) --- +# You can include these lines to confirm that the patch has been applied +# when this module is imported. +def _confirm_patch(): + patched_module = human_eval.execution + time_limit_func: Any = getattr(patched_module, 'time_limit', None) + timeout_exc: Any = getattr(patched_module, 'TimeoutException', None) + + if time_limit_func is patched_time_limit and timeout_exc is PatchedTimeoutException: + print(f"[human_eval_windows_patch] Successfully patched 'human_eval.execution.time_limit' and 'human_eval.execution.TimeoutException'.") + print(f"[human_eval_windows_patch] human_eval.execution.time_limit is now: {time_limit_func}") + print(f"[human_eval_windows_patch] human_eval.execution.TimeoutException is now: {timeout_exc}") + else: + print(f"[human_eval_windows_patch] WARNING: Patching human_eval.execution may not have been successful.") + print(f" Current time_limit: {time_limit_func}") + print(f" Current TimeoutException: {timeout_exc}") + +# _confirm_patch() \ No newline at end of file diff --git a/humaneval_eval.py b/humaneval_eval.py index 75eab56a..b98f4ce5 100644 --- a/humaneval_eval.py +++ b/humaneval_eval.py @@ -14,20 +14,22 @@ from io import BytesIO from typing import Any, Tuple +import human_eval_windows_patch + from human_eval.data import HUMAN_EVAL, read_problems from human_eval.evaluation import estimate_pass_at_k from human_eval.execution import check_correctness # , unsafe_execute -from . import common -from .common import HTML_JINJA -from .types import Eval, EvalResult, SamplerBase, SingleEvalResult +import common +from common import HTML_JINJA +from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult def evaluate_functional_correctness( sample: dict[str, str], completions: list[str], n_workers: int = 4, - timeout: float = 3.0, + timeout: float = 30.0, ): """ Evaluates the functional correctness of generated samples, and writes @@ -43,9 +45,11 @@ def evaluate_functional_correctness( future = executor.submit(check_correctness, *args) futures.append(future) results = [] - for future in as_completed(futures): + for i, future in enumerate(as_completed(futures)): result = future.result() results.append(result) + if result.get("result") == "timed out": + print(f"Completion {i} for sample {sample.get('task_id', 'unknown')} timed out.") passed = [int(r["passed"]) for r in results] return passed @@ -54,9 +58,11 @@ class HumanEval(Eval): def __init__( self, num_examples: int = 250, # restrict to a subset of the data for debugging - num_samples_per_task: int = 5, + num_samples_per_task: int = 3, ks_passes: list[int] = [1, 2, 5], timeout: int = 120, + batch_size: int = 20, # Default batch size + checkpoint_file: str | None = None, # Path to the checkpoint file ): self.seed = 0 self.examples = read_problems() @@ -68,6 +74,12 @@ def __init__( self._num_samples_per_task = num_samples_per_task self._ks_passes = ks_passes self._timeout = timeout + self.batch_size = batch_size + self.checkpoint_file = checkpoint_file + self.processed_results: list[SingleEvalResult] = [] + + if self.checkpoint_file: + self.processed_results = common.load_checkpoint(self.checkpoint_file) def __call__(self, sampler: SamplerBase) -> EvalResult: instruction = "Read the following function signature and docstring, and fully implement the function described. Your response should only contain the code for this function.\n" @@ -107,12 +119,61 @@ def fn(sample: dict[str, str]): score=score, convo=convo, metrics={ - f"pass@{k}": estimate_pass_at_k([total], [correct], k) + f"pass@{k}": estimate_pass_at_k([total], [correct], k).tolist() # this will be aggrated so no need of .mean() for k in self._ks_passes if total >= k }, ) - results = common.map_with_progress(fn, self.examples, num_threads=3) - return common.aggregate_results(results) + num_already_processed = len(self.processed_results) + + if num_already_processed >= len(self.examples): + if not self.examples: # No examples to run at all + print("No examples to evaluate.") + return common.aggregate_results([]) # Return empty aggregated result + print("All examples were already processed according to checkpoint.") + return common.aggregate_results(self.processed_results) + + examples_to_process_this_run = self.examples[num_already_processed:] + num_total_examples_in_run = len(self.examples) + + print( + f"Starting evaluation. Total examples: {num_total_examples_in_run}. " + f"Already processed: {num_already_processed}. " + f"Remaining: {len(examples_to_process_this_run)}." + ) + + for i in range(0, len(examples_to_process_this_run), self.batch_size): + batch_examples = examples_to_process_this_run[i : i + self.batch_size] + if not batch_examples: + continue + + current_global_start_index_for_batch = num_already_processed + i + + batch_start_num_display = current_global_start_index_for_batch + 1 + batch_end_num_display = min( + current_global_start_index_for_batch + len(batch_examples), + num_total_examples_in_run, + ) + + print( + f"Processing batch: examples {batch_start_num_display}-" + f"{batch_end_num_display} of {num_total_examples_in_run} " + f"(Batch size: {self.batch_size})" + ) + + batch_new_results: list[SingleEvalResult] = common.map_with_progress( + fn, batch_examples, num_threads=3 + ) + self.processed_results.extend(batch_new_results) # Add new results to the main list + + if self.checkpoint_file and batch_new_results: # Only save if there are new results + common.save_checkpoint( + self.checkpoint_file, batch_new_results + ) # Append only the newly processed results + + print( + f"Evaluation finished. Processed {len(self.processed_results)} results in total out of {num_total_examples_in_run} examples." + ) + return common.aggregate_results(self.processed_results) diff --git a/math_eval.py b/math_eval.py index 4328dcdf..44ccf8a4 100644 --- a/math_eval.py +++ b/math_eval.py @@ -6,13 +6,17 @@ import random import re +import os +import json +import pathlib +import urllib.request from typing import Literal import pandas -from . import common -from .common import ANSWER_PATTERN, HTML_JINJA, check_equality -from .types import Eval, EvalResult, SamplerBase, SingleEvalResult +import common +from common import ANSWER_PATTERN, HTML_JINJA, check_equality +from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult QUERY_TEMPLATE = """ Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem. @@ -22,6 +26,7 @@ Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command. """.strip() +CACHE_DIR = pathlib.Path(".simple_evals_cache") class MathEval(Eval): def __init__( @@ -30,10 +35,22 @@ def __init__( num_examples: int | None = None, n_repeats: int = 16, split: Literal["math_test", "math_500_test"] = "math_test", + batch_size: int = 20, + checkpoint_file: str | None = None, ): - df = pandas.read_csv( - f"https://openaipublic.blob.core.windows.net/simple-evals/{split}.csv" - ) + url = f"https://openaipublic.blob.core.windows.net/simple-evals/{split}.csv" + CACHE_DIR.mkdir(parents=True, exist_ok=True) + file_name = url.split("/")[-1] + cached_file_path = CACHE_DIR / file_name + + if cached_file_path.exists(): + print(f"Loading cached MATH data from {cached_file_path}") + df = pandas.read_csv(cached_file_path) + else: + print(f"Downloading MATH data from {url} to {cached_file_path}") + urllib.request.urlretrieve(url, str(cached_file_path)) + df = pandas.read_csv(cached_file_path) + examples = [row.to_dict() for _, row in df.iterrows()] if num_examples: assert n_repeats == 1, "n_repeats only supported for num_examples = None" @@ -41,6 +58,12 @@ def __init__( examples = rng.sample(examples, num_examples) self.examples = examples * n_repeats self.equality_checker = equality_checker + self.batch_size = batch_size + self.checkpoint_file = checkpoint_file + self.processed_results: list[SingleEvalResult] = [] + + if self.checkpoint_file: + self.processed_results = common.load_checkpoint(self.checkpoint_file) def __call__(self, sampler: SamplerBase) -> EvalResult: def fn(row: dict): @@ -61,5 +84,38 @@ def fn(row: dict): convo = prompt_messages + [dict(content=response_text, role="assistant")] return SingleEvalResult(html=html, score=score, convo=convo) - results = common.map_with_progress(fn, self.examples) - return common.aggregate_results(results) + num_already_processed = len(self.processed_results) + num_total_examples = len(self.examples) + + if num_already_processed >= num_total_examples: + if not self.examples: # No examples to run at all + print("No examples to evaluate.") + return common.aggregate_results([]) + print("All examples were already processed according to checkpoint.") + return common.aggregate_results(self.processed_results) + + examples_to_process_this_run = self.examples[num_already_processed:] + + print(f"Starting Math evaluation. Total examples (including repeats): {num_total_examples}. Already processed: {num_already_processed}. Remaining: {len(examples_to_process_this_run)}.") + + for i in range(0, len(examples_to_process_this_run), self.batch_size): + batch_examples = examples_to_process_this_run[i : i + self.batch_size] + if not batch_examples: + continue + + current_global_start_index_for_batch = num_already_processed + i + batch_start_num_display = current_global_start_index_for_batch + 1 + batch_end_num_display = min(current_global_start_index_for_batch + len(batch_examples), num_total_examples) + + print(f"Processing batch: examples {batch_start_num_display}-{batch_end_num_display} of {num_total_examples} (Batch size: {self.batch_size})") + + # Note: map_with_progress will show its own progress bar for the batch + batch_new_results = common.map_with_progress(fn, batch_examples) + + self.processed_results.extend(batch_new_results) + + if self.checkpoint_file and batch_new_results: + common.save_checkpoint(self.checkpoint_file, batch_new_results) + + print(f"Math evaluation finished. Processed {len(self.processed_results)} results in total out of {num_total_examples} examples.") + return common.aggregate_results(self.processed_results) diff --git a/mgsm_eval.py b/mgsm_eval.py index 674ac964..f88802d9 100644 --- a/mgsm_eval.py +++ b/mgsm_eval.py @@ -7,10 +7,12 @@ import re from typing import Optional +import pathlib +import urllib.request -from . import common -from .mmlu_eval import HTML_JINJA -from .types import Eval, EvalResult, SamplerBase, SingleEvalResult +import common +from mmlu_eval import HTML_JINJA +from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult ALL_LANGUAGES = ["bn", "de", "en", "es", "fr", "ja", "ru", "sw", "te", "th", "zh"] LATIN_LANGUAGES = ["de", "en", "es", "fr", "sw"] @@ -79,6 +81,7 @@ "zh": "答案", } +CACHE_DIR = pathlib.Path(".simple_evals_cache") def parse_answer(answer: str, answer_prefix: str) -> str: if answer_prefix not in answer: @@ -106,14 +109,25 @@ def score_mgsm(target: str, prediction: str) -> bool: def get_lang_examples(lang: str) -> list[dict[str, str]]: fpath = LANG_TO_FPATH[lang] + CACHE_DIR.mkdir(parents=True, exist_ok=True) + file_name = fpath.split("/")[-1] + cached_file_path = CACHE_DIR / file_name + + if not cached_file_path.exists(): + print(f"Downloading MGSM data for lang '{lang}' from {fpath} to {cached_file_path}") + urllib.request.urlretrieve(fpath, str(cached_file_path)) + else: + print(f"Loading cached MGSM data for lang '{lang}' from {cached_file_path}") + examples = [] - with common.url_to_fileobj(fpath, binary=True) as f: - for raw_line in f: - line = raw_line.decode("utf-8").strip() + with open(cached_file_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue inputs, targets = line.split("\t") if "." in targets: raise ValueError(f"targets {targets} contains a decimal point.") - # targets = int(targets.replace(",", "")) examples.append({"inputs": inputs, "targets": targets, "lang": lang}) return examples @@ -132,6 +146,8 @@ def __init__( self, num_examples_per_lang: int = 250, # restrict to a subset of the data for debugging languages: Optional[list[str]] = ALL_LANGUAGES, + batch_size: int = 20, # Default batch size + checkpoint_file: str | None = None, # Path to the checkpoint file ): if languages is None: languages = ALL_LANGUAGES @@ -150,6 +166,13 @@ def __init__( lang_examples = get_lang_examples(lang) examples.extend(lang_examples[: self._num_examples_per_lang]) self.examples = examples + + self.batch_size = batch_size + self.checkpoint_file = checkpoint_file + self.processed_results: list[SingleEvalResult] = [] + + if self.checkpoint_file: + self.processed_results = common.load_checkpoint(self.checkpoint_file) def __call__(self, sampler: SamplerBase) -> EvalResult: def fn(example: dict[str, str]): @@ -186,5 +209,37 @@ def fn(example: dict[str, str]): metrics={language: score, latin_language: score}, ) - results = common.map_with_progress(fn, self.examples) - return common.aggregate_results(results, default_stats=("mean", "std")) + num_already_processed = len(self.processed_results) + + if num_already_processed >= len(self.examples): + if not self.examples: # No examples to run at all + print("No examples to evaluate.") + return common.aggregate_results([]) # Return empty aggregated result + print("All examples were already processed according to checkpoint.") + return common.aggregate_results(self.processed_results) + + examples_to_process_this_run = self.examples[num_already_processed:] + num_total_examples_in_run = len(self.examples) + + print(f"Starting MGSM evaluation. Total examples: {num_total_examples_in_run}. Already processed: {num_already_processed}. Remaining: {len(examples_to_process_this_run)}.") + + for i in range(0, len(examples_to_process_this_run), self.batch_size): + batch_examples = examples_to_process_this_run[i : i + self.batch_size] + if not batch_examples: + continue + + current_global_start_index_for_batch = num_already_processed + i + + batch_start_num_display = current_global_start_index_for_batch + 1 + batch_end_num_display = min(current_global_start_index_for_batch + len(batch_examples), num_total_examples_in_run) + + print(f"Processing batch: examples {batch_start_num_display}-{batch_end_num_display} of {num_total_examples_in_run} (Batch size: {self.batch_size})") + + batch_new_results: list[SingleEvalResult] = common.map_with_progress(fn, batch_examples) + self.processed_results.extend(batch_new_results) + + if self.checkpoint_file and batch_new_results: + common.save_checkpoint(self.checkpoint_file, batch_new_results) + + print(f"MGSM evaluation finished. Processed {len(self.processed_results)} results in total out of {num_total_examples_in_run} examples.") + return common.aggregate_results(self.processed_results, default_stats=("mean", "std")) diff --git a/mmlu_eval.py b/mmlu_eval.py index 9423c660..78c0bd1c 100644 --- a/mmlu_eval.py +++ b/mmlu_eval.py @@ -6,11 +6,15 @@ import random import re +import os +import json +import pathlib +import urllib.request import pandas -from . import common -from .common import ( +import common +from common import ( HTML_JINJA, MULTILINGUAL_ANSWER_PATTERN_TEMPLATE, MULTILINGUAL_ANSWER_REGEXES, @@ -18,7 +22,7 @@ normalize_extracted_answer, normalize_response, ) -from .types import Eval, EvalResult, SamplerBase, SingleEvalResult +from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult subject2category = { "abstract_algebra": "stem", @@ -80,21 +84,52 @@ "world_religions": "humanities", } +CACHE_DIR = pathlib.Path(".simple_evals_cache") class MMLUEval(Eval): - def __init__(self, num_examples: int | None = None, language: str = "EN-US"): + def __init__( + self, + num_examples: int | None = None, + language: str = "EN-US", + batch_size: int = 20, # Default batch size, can be configured + checkpoint_file: str | None = None, # Path to the checkpoint file + ): if language != "EN-US": url = f"https://openaipublic.blob.core.windows.net/simple-evals/mmlu_{language}.csv" else: url = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv" - df = pandas.read_csv(url) - examples = [row.to_dict() for _, row in df.iterrows()] - if num_examples: - examples = random.Random(0).sample(examples, num_examples) - self.examples = examples + + CACHE_DIR.mkdir(parents=True, exist_ok=True) + file_name = url.split("/")[-1] + cached_file_path = CACHE_DIR / file_name + + if cached_file_path.exists(): + print(f"Loading cached MMLU data from {cached_file_path}") + df = pandas.read_csv(cached_file_path) + else: + print(f"Downloading MMLU data from {url} to {cached_file_path}") + urllib.request.urlretrieve(url, str(cached_file_path)) + df = pandas.read_csv(cached_file_path) + + all_examples_from_csv = [row.to_dict() for _, row in df.iterrows()] + + if num_examples is not None and num_examples > 0: + # Deterministic sampling if num_examples is specified + k = min(num_examples, len(all_examples_from_csv)) + self.examples = random.Random(0).sample(all_examples_from_csv, k) + else: + # Use all examples if num_examples is None, 0, or negative + self.examples = all_examples_from_csv + + self.batch_size = batch_size + self.checkpoint_file = checkpoint_file + self.processed_results: list[SingleEvalResult] = [] # Initialize empty + + if self.checkpoint_file: + self.processed_results = common.load_checkpoint(self.checkpoint_file) def __call__(self, sampler: SamplerBase) -> EvalResult: - def fn(row: dict): + def process_example(row: dict) -> SingleEvalResult: prompt_messages = [ sampler._pack_message( content=format_multichoice_question(row), role="user" @@ -122,5 +157,38 @@ def fn(row: dict): html=html, score=score, metrics={category: score}, convo=convo ) - results = common.map_with_progress(fn, self.examples) - return common.aggregate_results(results) + num_already_processed = len(self.processed_results) + + if num_already_processed >= len(self.examples): + if not self.examples: # No examples to run at all + print("No examples to evaluate.") + return common.aggregate_results([]) # Return empty aggregated result + print("All examples were already processed according to checkpoint.") + return common.aggregate_results(self.processed_results) + + examples_to_process_this_run = self.examples[num_already_processed:] + num_total_examples_in_run = len(self.examples) + + print(f"Starting evaluation. Total examples: {num_total_examples_in_run}. Already processed: {num_already_processed}. Remaining: {len(examples_to_process_this_run)}.") + + for i in range(0, len(examples_to_process_this_run), self.batch_size): + batch_examples = examples_to_process_this_run[i : i + self.batch_size] + if not batch_examples: + continue + + current_global_start_index_for_batch = num_already_processed + i + + batch_start_num_display = current_global_start_index_for_batch + 1 + batch_end_num_display = min(current_global_start_index_for_batch + len(batch_examples), num_total_examples_in_run) + + print(f"Processing batch: examples {batch_start_num_display}-{batch_end_num_display} of {num_total_examples_in_run} (Batch size: {self.batch_size})") + + batch_new_results: list[SingleEvalResult] = [] + batch_new_results = common.map_with_progress(process_example, batch_examples) + self.processed_results.extend(batch_new_results) # Add new results to the main list + + if self.checkpoint_file and batch_new_results: # Only save if there are new results + common.save_checkpoint(self.checkpoint_file, batch_new_results) # Append only the newly processed results + + print(f"Evaluation finished. Processed {len(self.processed_results)} results in total out of {num_total_examples_in_run} examples.") + return common.aggregate_results(self.processed_results) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..5c767f5d Binary files /dev/null and b/requirements.txt differ diff --git a/run_multilingual_mmlu.py b/run_multilingual_mmlu.py index 2ee367dc..badcff10 100644 --- a/run_multilingual_mmlu.py +++ b/run_multilingual_mmlu.py @@ -108,13 +108,13 @@ def get_evals(eval_name): result = eval_obj(sampler) # ^^^ how to use a sampler file_stem = f"{eval_name}_{sampler_name}" - report_filename = f"/tmp/{file_stem}{debug_suffix}.html" + report_filename = f"./tmp/{file_stem}{debug_suffix}.html" print(f"Writing report to {report_filename}") with open(report_filename, "w") as fh: fh.write(common.make_report(result)) metrics = result.metrics | {"score": result.score} print(metrics) - result_filename = f"/tmp/{file_stem}{debug_suffix}.json" + result_filename = f"./tmp/{file_stem}{debug_suffix}.json" with open(result_filename, "w") as f: f.write(json.dumps(metrics, indent=2)) print(f"Writing results to {result_filename}") diff --git a/sampler/__init__.py b/sampler/__init__.py new file mode 100644 index 00000000..191395d1 --- /dev/null +++ b/sampler/__init__.py @@ -0,0 +1 @@ +# This file makes Python treat the 'sampler' directory as a package. \ No newline at end of file diff --git a/sampler/chat_completion_sampler.py b/sampler/chat_completion_sampler.py index d75ce918..9987ab0b 100644 --- a/sampler/chat_completion_sampler.py +++ b/sampler/chat_completion_sampler.py @@ -1,11 +1,18 @@ import base64 import time +import os + +os.environ["HTTP_PROXY"] = "http://localhost:1080" +os.environ["HTTPS_PROXY"] = "http://localhost:1080" + +import subprocess from typing import Any +from datetime import datetime, timedelta import openai from openai import OpenAI -from ..types import MessageList, SamplerBase +from eval_types import MessageList, SamplerBase OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant." OPENAI_SYSTEM_MESSAGE_CHATGPT = ( @@ -25,16 +32,46 @@ def __init__( system_message: str | None = None, temperature: float = 0.5, max_tokens: int = 1024, + base_url: str | None = None, ): self.api_key_name = "OPENAI_API_KEY" - self.client = OpenAI() - # using api_key=os.environ.get("OPENAI_API_KEY") # please set your API_KEY + self.api_key = os.environ.get(self.api_key_name) + self.base_url = base_url + self.token_expiry = None + self._refresh_token_if_needed() + self.client = OpenAI(api_key=self.api_key, base_url=self.base_url) self.model = model self.system_message = system_message self.temperature = temperature self.max_tokens = max_tokens self.image_format = "url" + def _refresh_token_if_needed(self): + """Refresh the Google Cloud token if it's expired or about to expire (within 5 minutes)""" + if not self.api_key and self.base_url: + current_time = datetime.now() + if not self.token_expiry or current_time + timedelta(minutes=5) >= self.token_expiry: + try: + print("Fetching new token from gcloud.") + result = subprocess.run( + "gcloud auth print-access-token", + capture_output=True, + text=True, + check=True, + shell=True + ) + self.api_key = result.stdout.strip() + # Set token expiry to 55 minutes from now (giving 5-minute buffer) + self.token_expiry = current_time + timedelta(minutes=55) + except FileNotFoundError: + print("gcloud command not found. Please ensure gcloud SDK is installed and in your PATH.") + self.api_key = None + except subprocess.CalledProcessError as e: + print(f"Error fetching token from gcloud: {e}") + self.api_key = None + elif not self.api_key: + self.api_key = "" + def _handle_image( self, image: str, encoding: str = "base64", format: str = "png", fovea: int = 768 ): @@ -58,6 +95,11 @@ def __call__(self, message_list: MessageList) -> str: trial = 0 while True: try: + # Refresh token if needed before making the API call + self._refresh_token_if_needed() + # Update client with potentially new token + self.client = OpenAI(api_key=self.api_key, base_url=self.base_url) + response = self.client.chat.completions.create( model=self.model, messages=message_list, @@ -70,6 +112,7 @@ def __call__(self, message_list: MessageList) -> str: print("Bad Request Error", e) return "" except Exception as e: + print("Error", e) exception_backoff = 2**trial # expontial back off print( f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec", diff --git a/sampler/claude_sampler.py b/sampler/claude_sampler.py index cf61b441..349b1d5e 100644 --- a/sampler/claude_sampler.py +++ b/sampler/claude_sampler.py @@ -2,7 +2,7 @@ import anthropic -from ..types import MessageList, SamplerBase +from eval_types import MessageList, SamplerBase CLAUDE_SYSTEM_MESSAGE_LMSYS = ( "The assistant is Claude, created by Anthropic. The current date is " diff --git a/sampler/claude_vertex_sampler.py b/sampler/claude_vertex_sampler.py new file mode 100644 index 00000000..4ecdbdbc --- /dev/null +++ b/sampler/claude_vertex_sampler.py @@ -0,0 +1,173 @@ +import time + +import anthropic +from anthropic import AnthropicVertex + +from eval_types import MessageList, SamplerBase + +CLAUDE_SYSTEM_MESSAGE_LMSYS = ( + "The assistant is Claude, created by Anthropic. The current date is " + "{currentDateTime}. Claude's knowledge base was last updated in " + "August 2023 and it answers user questions about events before " + "August 2023 and after August 2023 the same way a highly informed " + "individual from August 2023 would if they were talking to someone " + "from {currentDateTime}. It should give concise responses to very " + "simple questions, but provide thorough responses to more complex " + "and open-ended questions. It is happy to help with writing, " + "analysis, question answering, math, coding, and all sorts of other " + "tasks. It uses markdown for coding. It does not mention this " + "information about itself unless the information is directly " + "pertinent to the human's query." +).format(currentDateTime="2024-04-01") +# reference: https://github.com/lm-sys/FastChat/blob/7899355ebe32117fdae83985cf8ee476d2f4243f/fastchat/conversation.py#L894 + + +class ClaudeVertexCompletionSampler(SamplerBase): + """ + Sample from Claude API + """ + + def __init__( + self, + model: str = "claude-3-opus-20240229", + system_message: str | None = None, + temperature: float = 0.0, # default in Anthropic example + max_tokens: int = 1024, + location: str = "us-east5", + project_id: str = "{your-project-id}", + ): + self.api_key_name = "ANTHROPIC_API_KEY" + self.client = AnthropicVertex(region=location, project_id=project_id) + # using api_key=os.environ.get("ANTHROPIC_API_KEY") # please set your API_KEY + self.model = model + self.system_message = system_message + self.temperature = temperature + self.max_tokens = max_tokens + self.image_format = "base64" + message = self.client.messages.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Send me a recipe for banana bread.", + } + ], + model=self.model + ) + print(message.model_dump_json(indent=2)) + + def _handle_image( + self, image: str, encoding: str = "base64", format: str = "png", fovea: int = 768 + ): + new_image = { + "type": "image", + "source": { + "type": encoding, + "media_type": f"image/{format}", + "data": image, + }, + } + return new_image + + def _handle_text(self, text): + return {"type": "text", "text": text} + + def _pack_message(self, role, content): + return {"role": str(role), "content": content} + + def _convert_messages(self, message_list: MessageList) -> list[dict]: + processed_api_messages = [] + for original_message in message_list: + role = original_message.get("role") + content = original_message.get("content") + + if role not in ["user", "assistant"]: + print(f"Warning: Skipping message with unsupported role '{role}'. Only 'user' and 'assistant' roles are processed for the 'messages' parameter.") + continue + + api_content: str | list = "" # Placeholder, will be updated + + if isinstance(content, str): + api_content = content + elif isinstance(content, list): + api_content_parts = [] + for part in content: + part_type = part.get("type") + if part_type == "text": + text = part.get("text") + if text: + api_content_parts.append(self._handle_text(text)) + elif part_type == "image_url": + image_url_spec = part.get("image_url") + if image_url_spec and "url" in image_url_spec: + image_url_data = image_url_spec["url"] + if image_url_data.startswith("data:image/") and ";base64," in image_url_data: + try: + header, base64_data = image_url_data.split(",", 1) + # header is "data:image/;base64" + mime_part = header.split(":")[1].split(";")[0] # "image/" + image_format = mime_part.split("/")[1] # "" + api_content_parts.append( + self._handle_image( + image=base64_data, + encoding="base64", + format=image_format, + ) + ) + except (ValueError, IndexError) as e: + print(f"Warning: Could not parse image_url data URI '{image_url_data}': {e}") + else: + print(f"Warning: Unsupported image_url format: {image_url_data}. Expected 'data:image/;base64,'.") + else: + print(f"Warning: Skipping image part with missing or invalid 'image_url' spec: {part}") + else: + print(f"Warning: Skipping unsupported part type '{part_type}' in message content.") + + if not api_content_parts: + print(f"Warning: Message from role '{role}' resulted in no content parts after processing. Skipping message.") + continue + api_content = api_content_parts + else: + print(f"Warning: Skipping message from role '{role}' with unsupported content type: {type(content)}. Expected str or list.") + continue + + processed_api_messages.append({"role": str(role), "content": api_content}) + + return processed_api_messages + + def __call__(self, message_list: MessageList) -> str: + # Convert the input message_list to the format expected by Anthropic API + api_messages = self._convert_messages(message_list) + + if not api_messages: + # If, after conversion, there are no messages to send + print("Warning: No valid messages to send to Claude API after conversion. Returning empty string.") + return "" + + trial = 0 + while True: + try: + message = self.client.messages.create( + model=self.model, + system=self.system_message, + max_tokens=self.max_tokens, + temperature=self.temperature, + messages=api_messages, # Use the converted messages + ) + # Assuming Claude's response structure provides content in a list, + # and we need the text from the first content block. + if message.content and isinstance(message.content, list) and len(message.content) > 0 and message.content[0].type == "text": + return message.content[0].text + else: + # Handle cases where response might not be as expected or is empty + print(f"Warning: Unexpected response content structure from Claude API: {message.content}") + return "" # Or raise an error, or return a string representation + except anthropic.RateLimitError as e: + exception_backoff = 2**trial # expontial back off + print( + f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec", + e, + ) + time.sleep(exception_backoff) + trial += 1 + # unknown error shall throw exception diff --git a/sampler/gemini_sampler.py b/sampler/gemini_sampler.py new file mode 100644 index 00000000..c0ae9105 --- /dev/null +++ b/sampler/gemini_sampler.py @@ -0,0 +1,178 @@ +import os +import time +from typing import Any +from datetime import datetime, timedelta +import subprocess +import base64 + +from google import genai +from google.genai.types import HttpOptions, Part, Content + +from eval_types import MessageList, SamplerBase + +class GeminiSampler(SamplerBase): + """ + Sample from Google's Gemini API + """ + + def __init__( + self, + model: str = "gemini-2.0-flash-001", + temperature: float = 0.1, + max_tokens: int = 4096, + project_id: str | None = None, + location: str = "us-central1", + api_key: str | None = None, + use_gemini_grounding: bool = False, + ): + self.project_id = project_id + self.location = location + self.api_key = api_key or os.environ.get("GOOGLE_API_KEY") + self.token_expiry = None + self.use_gemini_grounding = use_gemini_grounding + # self._refresh_token_if_needed() + + # Initialize the Gemini client + if self.api_key: + # Use API key authentication + self.client = genai.Client( + api_key=self.api_key, + http_options=HttpOptions(api_version="v1") + ) + else: + # self._refresh_token_if_needed() + # Use Vertex AI authentication + if not self.project_id or not self.location: + raise ValueError("Project ID and location must be provided for Vertex AI mode when API key is not used.") + self.client = genai.Client( + vertexai=True, + project=self.project_id, + location=self.location + ) + response = self.client.models.generate_content( + model=model, + contents="How does AI work?", + ) + print(response.text) + self.model = model + self.temperature = temperature + self.max_tokens = max_tokens + + def _pack_message(self, role: str, content: Any): + """Pack a message in the standard format expected by eval scripts.""" + return {"role": str(role), "content": content} + + def _refresh_token_if_needed(self): + """Refresh the Google Cloud token if it's expired or about to expire (within 5 minutes). + This is primarily for scenarios where ADC isn't fully set up and explicit token fetching is a fallback, + and when not using an API key and not in a Vertex AI context where project_id implies ADC. + """ + if not self.api_key and not self.project_id: + current_time = datetime.now() + if not self.token_expiry or current_time + timedelta(minutes=5) >= self.token_expiry: + try: + print("Attempting to fetch new token from gcloud (fallback mechanism).") + result = subprocess.run( + "gcloud auth print-access-token", + capture_output=True, + text=True, + check=True, + shell=True + ) + os.environ["GOOGLE_API_KEY"] = result.stdout.strip() + os.environ["API_KEY"] = result.stdout.strip() + print(f"Fetched token via gcloud (expires in ~1 hour). Note: SDK should ideally use ADC via gcloud setup.") + self.token_expiry = current_time + timedelta(minutes=55) + except FileNotFoundError: + print("gcloud command not found. Please ensure gcloud SDK is installed and in your PATH for fallback token refresh.") + except subprocess.CalledProcessError as e: + print(f"Error fetching token from gcloud: {e}. Fallback token refresh failed.") + + def _convert_messages(self, message_list: MessageList) -> list[Content]: + """Convert MessageList format to Gemini API's expected format (list of Content objects). + Simplified for text-to-text only. + """ + converted_contents: list[Content] = [] + for msg in message_list: + role = msg.get("role", "user").lower() + content = msg.get("content") + + # Map 'assistant' role to 'model' for Gemini + if role == "assistant": + role = "model" + + # Ensure role is either 'user' or 'model' + if role not in ["user", "model"]: + print(f"Warning: Invalid role '{role}' encountered. Defaulting to 'user'.") + role = "user" + + parts_list: list[Part] = [] # Ensure this is a list of Part objects + + if isinstance(content, str): + # For simple text content, wrap it in a Part object + parts_list.append(Part(text=content)) + elif isinstance(content, list): + # For list content (potentially from multimodal, but we'll only process text parts) + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + text_content = item.get("text", "") + if text_content: # Only add if text is not empty + parts_list.append(Part(text=text_content)) + # Other types (like image_url) are ignored as per simplification request + else: + print(f"Warning: Skipping message with unknown or non-string/non-list content type: {type(content)}") + continue # Skip this message + + if parts_list: # Only add if there are actual parts to send + converted_contents.append(Content(role=role, parts=parts_list)) + else: + print(f"Warning: Message from role '{role}' resulted in no parts and was skipped.") + + return converted_contents + + def __call__(self, message_list: MessageList) -> str: + trial = 0 + while True: + try: + # self._refresh_token_if_needed() + + contents = self._convert_messages(message_list) + + if not contents: + print("Warning: No content to send after message conversion. Returning empty string.") + return "" + + generation_config = { + "temperature": self.temperature, + "max_output_tokens": self.max_tokens, + } + + if self.use_gemini_grounding: + from google.genai.types import Tool, GenerateContentConfig, GoogleSearch + google_search_tool = Tool( + google_search = GoogleSearch() + ) + + print("INFO: Gemini grounding is enabled (conceptual).") + generation_config["tools"] = [google_search_tool] + # generation_config["google_search"] = GoogleSearch() + + response = self.client.models.generate_content( + model=self.model, + contents=contents, + config=generation_config, + ) + # print(response.candidates[0].grounding_metadata.search_entry_point.rendered_content) + return response.text + except Exception as e: + print(f"Error during API call (trial {trial+1}): {e}") + trial += 1 + if trial >= 5: + print("Maximum retries reached. Raising exception.") + raise e + + exception_backoff = 2**trial + print( + f"Rate limit or other exception. Waiting {exception_backoff} sec before retry {trial+1}...", + ) + time.sleep(exception_backoff) \ No newline at end of file diff --git a/sampler/o_chat_completion_sampler.py b/sampler/o_chat_completion_sampler.py index 718d02a8..7eb80f16 100644 --- a/sampler/o_chat_completion_sampler.py +++ b/sampler/o_chat_completion_sampler.py @@ -1,10 +1,11 @@ import time from typing import Any +import os import openai from openai import OpenAI -from ..types import MessageList, SamplerBase +from eval_types import MessageList, SamplerBase class OChatCompletionSampler(SamplerBase): @@ -19,7 +20,11 @@ def __init__( model: str = "o1-mini", ): self.api_key_name = "OPENAI_API_KEY" - self.client = OpenAI() + # set api key to empty if env is not configured + api_key = os.environ.get(self.api_key_name) + if not api_key: + api_key = "" + self.client = OpenAI(api_key=api_key) # using api_key=os.environ.get("OPENAI_API_KEY") # please set your API_KEY self.model = model self.image_format = "url" diff --git a/sampler/responses_sampler.py b/sampler/responses_sampler.py index 1e49b21c..5f50437a 100644 --- a/sampler/responses_sampler.py +++ b/sampler/responses_sampler.py @@ -6,7 +6,7 @@ import openai from openai import OpenAI -from ..types import MessageList, SamplerBase +from eval_types import MessageList, SamplerBase class ResponsesSampler(SamplerBase): @@ -24,8 +24,12 @@ def __init__( reasoning_effort: str | None = None, ): self.api_key_name = "OPENAI_API_KEY" - assert os.environ.get("OPENAI_API_KEY"), "Please set OPENAI_API_KEY" - self.client = OpenAI() + # assert os.environ.get("OPENAI_API_KEY"), "Please set OPENAI_API_KEY" + # set api key to empty if env is not configured + api_key = os.environ.get(self.api_key_name) + if not api_key: + api_key = "" + self.client = OpenAI(api_key=api_key) self.model = model self.system_message = system_message self.temperature = temperature diff --git a/simple_evals.py b/simple_evals.py index 7dc9d4b2..5cead04d 100644 --- a/simple_evals.py +++ b/simple_evals.py @@ -1,23 +1,26 @@ import json import argparse import pandas as pd -from . import common -from .browsecomp_eval import BrowseCompEval -from .drop_eval import DropEval -from .gpqa_eval import GPQAEval -from .humaneval_eval import HumanEval -from .math_eval import MathEval -from .mgsm_eval import MGSMEval -from .mmlu_eval import MMLUEval -from .simpleqa_eval import SimpleQAEval -from .sampler.chat_completion_sampler import ( +import os +import common +from browsecomp_eval import BrowseCompEval +from drop_eval import DropEval +from gpqa_eval import GPQAEval +from humaneval_eval import HumanEval +from math_eval import MathEval +from mgsm_eval import MGSMEval +from mmlu_eval import MMLUEval +from simpleqa_eval import SimpleQAEval +from sampler.chat_completion_sampler import ( OPENAI_SYSTEM_MESSAGE_API, OPENAI_SYSTEM_MESSAGE_CHATGPT, ChatCompletionSampler, ) -from .sampler.o_chat_completion_sampler import OChatCompletionSampler -from .sampler.responses_sampler import ResponsesSampler -from .sampler.claude_sampler import ClaudeCompletionSampler, CLAUDE_SYSTEM_MESSAGE_LMSYS +from sampler.o_chat_completion_sampler import OChatCompletionSampler +from sampler.responses_sampler import ResponsesSampler +from sampler.claude_sampler import ClaudeCompletionSampler, CLAUDE_SYSTEM_MESSAGE_LMSYS +from sampler.gemini_sampler import GeminiSampler +from sampler.claude_vertex_sampler import ClaudeVertexCompletionSampler def main(): @@ -32,6 +35,24 @@ def main(): parser.add_argument( "--examples", type=int, help="Number of examples to use (overrides default)" ) + parser.add_argument( + "--evals", + type=str, + nargs="+", + help="Specify one or more evaluation suites to run (e.g., mmlu math)", + default=["simpleqa", "mmlu", "math", "gpqa", "mgsm", "drop", "humaneval", "browsecomp"], + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Directory to store and load checkpoint files for each eval. If not provided, checkpointing is disabled.", + default=None, + ) + parser.add_argument( + "--use-gemini-grounding", + action="store_true", + help="Enable Gemini grounding API for Gemini models." + ) args = parser.parse_args() @@ -140,6 +161,29 @@ def main(): model="claude-3-opus-20240229", system_message=CLAUDE_SYSTEM_MESSAGE_LMSYS, ), + # Llama models: + "llama-4-maverick-17b-128e-instruct-maas": ChatCompletionSampler( + model="meta/llama-4-maverick-17b-128e-instruct-maas", + system_message=OPENAI_SYSTEM_MESSAGE_API, + base_url="https://us-east5-aiplatform.googleapis.com/v1/projects/{your-project-id}/locations/us-east5/endpoints/openapi" + ), + "gemini-2.5-pro-preview-05-06": GeminiSampler( + model="gemini-2.5-pro-preview-05-06", + project_id="{your-project-id}", + location="us-central1", + use_gemini_grounding=args.use_gemini_grounding, + ), + "gemini-2.0-flash-001": GeminiSampler( + model="gemini-2.0-flash-001", + project_id="{your-project-id}", + location="us-central1", + use_gemini_grounding=args.use_gemini_grounding, + ), + "claude-3-7-sonnet": ClaudeVertexCompletionSampler( + model="claude-3-7-sonnet@20250219", + project_id="{your-project-id}", + location="us-east5", + ), } if args.list_models: @@ -154,71 +198,99 @@ def main(): return models = {args.model: models[args.model]} - grading_sampler = ChatCompletionSampler(model="gpt-4o") - equality_checker = ChatCompletionSampler(model="gpt-4-turbo-preview") + # grading_sampler = ChatCompletionSampler(model="gpt-4o") + grading_sampler = ChatCompletionSampler( + model="meta/llama-4-maverick-17b-128e-instruct-maas", + system_message=OPENAI_SYSTEM_MESSAGE_API, + base_url="https://us-east5-aiplatform.googleapis.com/v1/projects/{your-project-id}/locations/us-east5/endpoints/openapi" + ) + # equality_checker = ChatCompletionSampler(model="gpt-4-turbo-preview") + equality_checker = ChatCompletionSampler( + model="meta/llama-4-maverick-17b-128e-instruct-maas", + system_message=OPENAI_SYSTEM_MESSAGE_API, + base_url="https://us-east5-aiplatform.googleapis.com/v1/projects/{your-project-id}/locations/us-east5/endpoints/openapi" + ) # ^^^ used for fuzzy matching, just for math - def get_evals(eval_name, debug_mode): + def get_evals(eval_name, debug_mode, checkpoint_dir, model_name_for_checkpoint): num_examples = ( args.examples if args.examples is not None else (5 if debug_mode else None) ) + + checkpoint_file_path = None + if checkpoint_dir and model_name_for_checkpoint: + # Construct a debug suffix consistent with the one used for report filenames + debug_suffix_for_file = "_DEBUG" if debug_mode else "" + # Sanitize model_name_for_checkpoint by replacing slashes with underscores for valid filenames + sanitized_model_name = model_name_for_checkpoint.replace("/", "_") + checkpoint_filename = f"{eval_name}_{sanitized_model_name}{debug_suffix_for_file}.jsonl" + checkpoint_file_path = os.path.join(checkpoint_dir, checkpoint_filename) + os.makedirs(checkpoint_dir, exist_ok=True) # Ensure directory exists + # Set num_examples = None to reproduce full evals match eval_name: case "mmlu": - return MMLUEval(num_examples=1 if debug_mode else num_examples) + return MMLUEval(num_examples=1 if debug_mode else num_examples, checkpoint_file=checkpoint_file_path) case "math": return MathEval( equality_checker=equality_checker, num_examples=num_examples, n_repeats=1 if debug_mode else 10, + checkpoint_file=checkpoint_file_path ) case "gpqa": return GPQAEval( - n_repeats=1 if debug_mode else 10, num_examples=num_examples + n_repeats=1 if debug_mode else 10, num_examples=num_examples, checkpoint_file=checkpoint_file_path ) - case "mgsm": - return MGSMEval(num_examples_per_lang=10 if debug_mode else 250) + case "mgsm": # MGSMEval might need specific handling for num_examples_per_lang with checkpointing + return MGSMEval(num_examples_per_lang=10 if debug_mode else 250, checkpoint_file=checkpoint_file_path) case "drop": return DropEval( num_examples=10 if debug_mode else num_examples, train_samples_per_prompt=3, + checkpoint_file=checkpoint_file_path ) - case "humaneval": - return HumanEval(num_examples=10 if debug_mode else num_examples) + case "humaneval": # HumanEval might process problems, checkpointing logic might differ + return HumanEval(num_examples=10 if debug_mode else num_examples, checkpoint_file=checkpoint_file_path) case "simpleqa": return SimpleQAEval( grader_model=grading_sampler, num_examples=10 if debug_mode else num_examples, + checkpoint_file=checkpoint_file_path ) case "browsecomp": return BrowseCompEval( grader_model=grading_sampler, num_examples=10 if debug_mode else num_examples, + checkpoint_file=checkpoint_file_path ) case _: raise Exception(f"Unrecognized eval type: {eval_name}") - evals = { - eval_name: get_evals(eval_name, args.debug) - for eval_name in ["simpleqa", "mmlu", "math", "gpqa", "mgsm", "drop", "humaneval", "browsecomp"] - } - print(evals) + evals_dict = {} # Changed from evals to evals_dict to avoid conflict + for model_name, sampler in models.items(): + current_model_evals = { + eval_name: get_evals(eval_name, args.debug, args.checkpoint_dir, model_name) # Pass checkpoint_dir and model_name + for eval_name in args.evals + } + evals_dict[model_name] = current_model_evals + debug_suffix = "_DEBUG" if args.debug else "" - print(debug_suffix) + mergekey2resultpath = {} for model_name, sampler in models.items(): - for eval_name, eval_obj in evals.items(): + for eval_name, eval_obj in evals_dict[model_name].items(): # Iterate through the new structure result = eval_obj(sampler) # ^^^ how to use a sampler file_stem = f"{eval_name}_{model_name}" - report_filename = f"/tmp/{file_stem}{debug_suffix}.html" + report_filename = f"./tmp/{file_stem}{debug_suffix}.html" print(f"Writing report to {report_filename}") - with open(report_filename, "w") as fh: + with open(report_filename, "w", encoding="utf-8") as fh: fh.write(common.make_report(result)) metrics = result.metrics | {"score": result.score} print(metrics) - result_filename = f"/tmp/{file_stem}{debug_suffix}.json" - with open(result_filename, "w") as f: + result_filename = f"./tmp/{file_stem}{debug_suffix}.json" + with open(result_filename, "w", encoding="utf-8") as f: f.write(json.dumps(metrics, indent=2)) print(f"Writing results to {result_filename}") mergekey2resultpath[f"{file_stem}"] = result_filename diff --git a/simpleqa_eval.py b/simpleqa_eval.py index 2a1390a0..d48fe83d 100644 --- a/simpleqa_eval.py +++ b/simpleqa_eval.py @@ -7,8 +7,10 @@ import random import re import pandas -from . import common -from .types import Eval, EvalResult, SamplerBase, SingleEvalResult +import pathlib +import urllib.request +import common +from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult GRADER_TEMPLATE = """ Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"]. @@ -96,18 +98,42 @@ CHOICE_STRINGS = ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"] CHOICE_LETTER_TO_STRING = dict(zip(CHOICE_LETTERS, CHOICE_STRINGS)) +CACHE_DIR = pathlib.Path(".simple_evals_cache") + class SimpleQAEval(Eval): - def __init__(self, grader_model: SamplerBase, num_examples: int | None = None, n_repeats: int = 1): - df = pandas.read_csv( - "https://openaipublic.blob.core.windows.net/simple-evals/simple_qa_test_set.csv" - ) + def __init__(self, grader_model: SamplerBase, num_examples: int | None = None, n_repeats: int = 1, batch_size: int = 20, checkpoint_file: str | None = None): + url = "https://openaipublic.blob.core.windows.net/simple-evals/simple_qa_test_set.csv" + CACHE_DIR.mkdir(parents=True, exist_ok=True) + file_name = url.split("/")[-1] + cached_file_path = CACHE_DIR / file_name + + if cached_file_path.exists(): + print(f"Loading cached data from {cached_file_path}") + df = pandas.read_csv(cached_file_path) + else: + print(f"Downloading data from {url} to {cached_file_path}") + urllib.request.urlretrieve(url, str(cached_file_path)) + df = pandas.read_csv(cached_file_path) + examples = [row.to_dict() for _, row in df.iterrows()] if num_examples: - assert n_repeats == 1, "n_repeats only supported when max_examples = None" + # n_repeats is asserted to be 1 if num_examples is specified. + # This means we first sample, then potentially repeat (though current logic makes repeats=1). + # If checkpointing, the initial sample is fixed. + assert n_repeats == 1, "n_repeats must be 1 when num_examples is specified" rng = random.Random(0) examples = rng.sample(examples, num_examples) + self.examples = examples * n_repeats self.grader_model = grader_model + self.batch_size = batch_size + self.checkpoint_file = checkpoint_file + self.processed_results: list[SingleEvalResult] = [] + + if self.checkpoint_file: + # Assuming common.load_checkpoint loads all results from the file + self.processed_results = common.load_checkpoint(self.checkpoint_file) + def grade_sample(self, question: str, target: str, predicted_answer: str) -> str: grader_prompt = GRADER_TEMPLATE.format( @@ -125,17 +151,18 @@ def grade_sample(self, question: str, target: str, predicted_answer: str) -> str return match.group(0) if match else "C" # Default to "NOT_ATTEMPTED" if no match def __call__(self, sampler: SamplerBase) -> EvalResult: - def fn(row: dict): + def process_example(row: dict): prompt_messages = [ sampler._pack_message(content=row.get("problem", ""), role="user") ] + response_text = sampler(prompt_messages) grade_letter = self.grade_sample(row.get("problem", ""), row.get("answer", ""), response_text) - + # Metrics based on grading response - is_correct = grade_letter == "A" - is_incorrect = grade_letter == "B" - is_not_attempted = grade_letter == "C" + is_correct = float(grade_letter == "A") + is_incorrect = float(grade_letter == "B") + is_not_attempted = float(grade_letter == "C") score = is_correct @@ -148,20 +175,60 @@ def fn(row: dict): extracted_answer=response_text, ) convo = prompt_messages + [dict(content=response_text, role="assistant")] - return SingleEvalResult(html=html, score=score, convo=convo, metrics={ - "is_correct": is_correct, - "is_incorrect": is_incorrect, - "is_not_attempted": is_not_attempted - }) + try: + result = SingleEvalResult(html=html, score=score, convo=convo, metrics={ + "is_correct": is_correct, + "is_incorrect": is_incorrect, + "is_not_attempted": is_not_attempted + }) + except Exception as e: + print("Error: ", e) + return None + print(6) + return result + + num_already_processed = len(self.processed_results) + + if not self.examples: # No examples to run at all + print("No examples to evaluate.") + return common.aggregate_results([]) # Return empty aggregated result + + if num_already_processed >= len(self.examples): + print(f"All {len(self.examples)} examples were already processed according to checkpoint.") + # Final aggregation logic will use self.processed_results + else: + examples_to_process_this_run = self.examples[num_already_processed:] + num_total_examples_in_run = len(self.examples) + + print(f"Starting evaluation. Total examples: {num_total_examples_in_run}. Already processed: {num_already_processed}. Remaining: {len(examples_to_process_this_run)}.") + + for i in range(0, len(examples_to_process_this_run), self.batch_size): + batch_examples = examples_to_process_this_run[i : i + self.batch_size] + if not batch_examples: + continue + + current_global_start_index_for_batch = num_already_processed + i + batch_start_num_display = current_global_start_index_for_batch + 1 + batch_end_num_display = min(current_global_start_index_for_batch + len(batch_examples), num_total_examples_in_run) + + print(f"Processing batch: examples {batch_start_num_display}-{batch_end_num_display} of {num_total_examples_in_run} (Batch size: {self.batch_size})") + + batch_new_results: list[SingleEvalResult] = common.map_with_progress(process_example, batch_examples, num_threads=5) + self.processed_results.extend(batch_new_results) + + if self.checkpoint_file and batch_new_results: + # Assuming common.save_checkpoint appends the new batch to the file + common.save_checkpoint(self.checkpoint_file, batch_new_results) + + print(f"Evaluation finished. Processed {len(self.processed_results) - num_already_processed} new results. Total processed now: {len(self.processed_results)} out of {num_total_examples_in_run} examples.") - # Run evaluation and collect results - results = common.map_with_progress(fn, self.examples) - # Aggregate metrics + # Aggregate metrics using all processed results (loaded + newly processed) + # The variable 'results' is now self.processed_results aggregate_metrics = { - "is_correct": sum(result.metrics["is_correct"] for result in results) / len(results), - "is_incorrect": sum(result.metrics["is_incorrect"] for result in results) / len(results), - "is_not_attempted": sum(result.metrics["is_not_attempted"] for result in results) / len(results), + "is_correct": sum(result.metrics["is_correct"] for result in self.processed_results) / len(self.processed_results) if self.processed_results else 0, + "is_incorrect": sum(result.metrics["is_incorrect"] for result in self.processed_results) / len(self.processed_results) if self.processed_results else 0, + "is_not_attempted": sum(result.metrics["is_not_attempted"] for result in self.processed_results) / len(self.processed_results) if self.processed_results else 0, } aggregate_metrics["is_given_attempted"] = aggregate_metrics["is_correct"] + aggregate_metrics["is_incorrect"] # Calculate accuracy_given_attempted @@ -188,6 +255,6 @@ def fn(row: dict): print(f"Accuracy Given Attempted: {output_d['accuracy_given_attempted']:.3f}") print(f"F1 Score: {output_d['f1']:.3f}") - return common.aggregate_results(results) + return common.aggregate_results(self.processed_results)