Skip to content
6 changes: 3 additions & 3 deletions benchmarks/models/hf_bart.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ source utils.sh
grep "facebook/bart-large-cnn cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 34.8 35
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.2 3.4
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 5.2 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.2 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.4 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.8 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 8.7 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.1 100

## Accuracy
#grep "facebook/bart-large-cnn cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 44.78 44.82
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/models/hf_distibart.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ source utils.sh
grep "sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 35.1 35.3
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.9 4.2
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.5 100
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.5 100
# todo: bigger bs doesn't increase speed
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.5 100
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.5 100

## Accuracy
#grep "sshleifer/distilbart-cnn-12-6 cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 45 45.1
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/models/hf_mbart.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ source utils.sh
grep "facebook/mbart-large-en-ro wmt_en_ro/raw val " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 27.79 27.95
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 5.8 6.2
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.0 100
Comment thread
feihugis marked this conversation as resolved.
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.2 100
165 changes: 141 additions & 24 deletions fastseq_cli/transformers_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,105 @@
import argparse
import json
from pathlib import Path

import torch
from multiprocessing import Process, Queue
from tqdm import tqdm

from fastseq_cli.transformers_utils import use_task_specific_params, trim_batch, calculate_rouge, calculate_bleu_score
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from fastseq_cli.transformers_utils import use_task_specific_params, trim_batch, calculate_rouge, calculate_bleu_score

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Sentinel payload queued once generation is over; PostProcess and IOProcess
# stop their receive loops when they read it.
GENERATE_FINISHED = 'done'
# Sentinel payload a PostProcess worker puts back on the data queue while
# shutting down, so that the remaining PostProcess workers stop as well.
POSTPROCESS_FINISHED = None

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes one raw example per lookup.

    Args:
        examples: indexable collection of source texts.
        tokenizer: callable invoked as ``tokenizer(text, **options)``; its
            result must support ``['input_ids']`` and ``['attention_mask']``.
        model_name: model identifier. When it contains ``"t5"`` the prefix
            is prepended to every example before tokenization.
        prefix: task prefix used for T5-style models.
        return_tensors: forwarded to the tokenizer (default ``"pt"``).
        truncation: forwarded to the tokenizer (default ``True``).
        padding: forwarded to the tokenizer (default ``"max_length"``).
    """

    def __init__(self, examples, tokenizer, model_name, prefix,
                 return_tensors="pt", truncation=True, padding="max_length"):
        self.examples = examples
        self.tokenizer = tokenizer
        self.model_name = model_name
        self.prefix = prefix
        # Tokenizer options live on the instance instead of being hard-coded
        # in __getitem__; the defaults reproduce the previous behaviour.
        self.return_tensors = return_tensors
        self.truncation = truncation
        self.padding = padding

    def __len__(self):
        """Return the number of examples."""
        return len(self.examples)

    def __getitem__(self, index):
        """Tokenize example ``index``.

        Returns:
            Tuple ``(input_ids, attention_mask)`` taken from the tokenizer
            output.
        """
        # Fetch the example first: the previous code read ``batch`` before
        # assigning it, which raised UnboundLocalError for any T5 model.
        text = self.examples[index]
        if "t5" in self.model_name:
            # A single example is one string, so the prefix is prepended to
            # the string itself rather than to each of its characters.
            text = self.prefix + text
        encoded = self.tokenizer(text,
                                 return_tensors=self.return_tensors,
                                 truncation=self.truncation,
                                 padding=self.padding)
        return encoded['input_ids'], encoded['attention_mask']

class IOProcess(Process):
    """Writer process that emits decoded batches to a file in batch order.

    Batches may arrive out of order on the message queue; any batch that is
    not the next expected one is parked in ``dec_buf`` until its turn comes.
    """

    def __init__(self, msg_queue, fout):
        """Set up the writer.

        Args:
            msg_queue: queue delivering ``(batch_index, hypotheses)`` tuples.
            fout: writable text file object receiving one hypothesis per line.
        """
        super().__init__()
        self.msg_queue = msg_queue
        self.fout = fout
        # Index of the next batch that is allowed to be written.
        self.waiting_for = 0
        # Batches received ahead of their turn, keyed by batch index.
        self.dec_buf = {}

    def process_dec(self, dec):
        """Write every hypothesis of one batch on its own line, flushing each."""
        sink = self.fout
        for hypothesis in dec:
            sink.write(hypothesis + "\n")
            sink.flush()

    def process_buffer(self):
        """Write out parked batches for as long as the next index is present."""
        while self.waiting_for in self.dec_buf:
            self.process_dec(self.dec_buf.pop(self.waiting_for))
            self.waiting_for += 1

    def run(self):
        """Receive batches until the finish sentinel, writing them in order."""
        inbox = self.msg_queue
        while True:
            ind, dec = inbox.get()
            if dec == GENERATE_FINISHED:
                break
            # Park the batch, then drain: when ``ind`` is the awaited index
            # it is written immediately, otherwise the drain is a no-op.
            self.dec_buf[ind] = dec
            self.process_buffer()
        self.process_buffer()
        assert not self.dec_buf, "IO Buffer not empty"
        inbox.close()
        inbox.join_thread()


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    start = 0
    total = len(lst)
    while start < total:
        yield lst[start:start + n]
        start += n

class PostProcess(Process):
    """Worker process that detokenizes generated batches in parallel.

    Reads ``(batch_index, token_ids)`` tuples from ``data_queue``, decodes
    them with the tokenizer and puts ``(batch_index, hypotheses)`` on
    ``msg_queue``.
    """

    def __init__(self, tokenizer, data_queue, msg_queue,
                 skip_special_tokens, clean_up_tokenization_spaces):
        """Set up the worker.

        Args:
            tokenizer: object providing ``batch_decode``.
            data_queue: queue delivering generated batches and stop sentinels.
            msg_queue: queue receiving the decoded batches.
            skip_special_tokens: forwarded to ``tokenizer.batch_decode``.
            clean_up_tokenization_spaces: forwarded to
                ``tokenizer.batch_decode``.
        """
        super().__init__()
        self.data_queue = data_queue
        self.msg_queue = msg_queue
        self.tokenizer = tokenizer
        self.clean_up_tokenization_spaces = clean_up_tokenization_spaces
        self.skip_special_tokens = skip_special_tokens

    @staticmethod
    def _is_stop_signal(payload):
        """Return True when ``payload`` is one of the two stop sentinels.

        The check uses type and identity tests so that a regular payload
        (a batch of generated ids) is never compared against a str or None
        with ``==``.
        """
        if payload is POSTPROCESS_FINISHED:
            return True
        return isinstance(payload, str) and payload == GENERATE_FINISHED

    def run(self):
        """Decode batches until a stop sentinel arrives, then shut down."""
        while True:
            ind, summaries = self.data_queue.get()
            if self._is_stop_signal(summaries):
                # Hand the stop signal on so that every sibling worker
                # sharing this queue terminates too.
                self.data_queue.put((-1, POSTPROCESS_FINISHED))
                break
            dec = self.tokenizer.batch_decode(
                summaries,
                skip_special_tokens=self.skip_special_tokens,
                clean_up_tokenization_spaces=self.clean_up_tokenization_spaces)
            self.msg_queue.put((ind, dec))

        self.data_queue.close()
        self.data_queue.join_thread()
        self.msg_queue.close()
        self.msg_queue.join_thread()

def generate_summaries_or_translations(
examples: list,
Expand All @@ -29,6 +113,10 @@ def generate_summaries_or_translations(
decoder_start_token_id=None,
fastseq_opt=True,
no_repeat_ngram_size=None,
Comment thread
NickNickGo marked this conversation as resolved.
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
pre_process_threads=2,
post_process_threads=2,
Comment thread
feihugis marked this conversation as resolved.
Outdated
**gen_kwargs,
) -> None:
"""Run generation"""
Expand All @@ -46,30 +134,42 @@ def generate_summaries_or_translations(

# update config with summarization specific params
use_task_specific_params(model, task)
data_queue = Queue()
msg_queue = Queue()
p_list = []

for batch in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name:
batch = [model.config.prefix + text for text in batch]
batch = tokenizer(batch,
return_tensors="pt",
truncation=True,
padding="max_length").to(device)
for i in range(post_process_threads):
p = PostProcess(tokenizer, data_queue, msg_queue,
skip_special_tokens, clean_up_tokenization_spaces)
p_list.append(p)
p.start()

io_process = IOProcess( msg_queue, fout)
Comment thread
NickNickGo marked this conversation as resolved.
io_process.start()
dataset = Dataset(examples, tokenizer, model_name, model.config.prefix)
training_generator = torch.utils.data.DataLoader(dataset,
batch_size=batch_size, num_workers = pre_process_threads)
for ind, batch in tqdm(enumerate(training_generator)):
input_ids, attention_mask = batch
input_ids = input_ids.view(batch_size, -1).to(device)
attention_mask = attention_mask.view(batch_size, -1).to(device)
input_ids, attention_mask = trim_batch(
**batch, pad_token_id=tokenizer.pad_token_id)
input_ids, tokenizer.pad_token_id, attention_mask)
summaries = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_start_token_id=decoder_start_token_id,
no_repeat_ngram_size=no_repeat_ngram_size,
**gen_kwargs,
)
dec = tokenizer.batch_decode(summaries,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
for hypothesis in dec:
fout.write(hypothesis + "\n")
fout.flush()

summaries_cpu = summaries.cpu()
data_queue.put((ind, summaries_cpu))
data_queue.put((-1, GENERATE_FINISHED))
for p in p_list:
p.join()
msg_queue.put((-1, GENERATE_FINISHED))
io_process.join()
fout.close()

def run_generate():
"""Entrance is here."""
Expand Down Expand Up @@ -118,6 +218,19 @@ def run_generate():
parser.add_argument("--without_fastseq_opt", action="store_true")
parser.add_argument("--no_repeat_ngram_size", type=int, default=None,
Comment thread
NickNickGo marked this conversation as resolved.
required=False, help="size of no repeat ngram")
parser.add_argument("--include_special_tokens", action="store_true")
parser.add_argument("--clean_up_tokenization_spaces", action="store_true")
parser.add_argument("--pre_process_threads",
type=int,
default=2,
required=False,
help="pre-processing worker threads")
parser.add_argument("--post_process_threads",
type=int,
default=2,
required=False,
help="post-processing worker threads")
Comment thread
feihugis marked this conversation as resolved.
Outdated

args = parser.parse_args()
examples = [
" " + x.rstrip() if "t5" in args.model_name else x.rstrip()
Expand All @@ -137,7 +250,11 @@ def run_generate():
decoder_start_token_id=args.decoder_start_token_id,
fastseq_opt=not args.without_fastseq_opt,
no_repeat_ngram_size=args.no_repeat_ngram_size,
)
skip_special_tokens=not args.include_special_tokens,
clean_up_tokenization_spaces=args.clean_up_tokenization_spaces,
pre_process_threads=args.pre_process_threads,
post_process_threads=args.post_process_threads,
)
if args.reference_path is None:
return
# Compute scores
Expand Down