Skip to content
6 changes: 3 additions & 3 deletions benchmarks/models/hf_bart.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ source utils.sh
grep "facebook/bart-large-cnn cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 34.8 35
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.2 3.4
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 5.2 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.2 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.4 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.8 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 8.7 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.1 100

## Accuracy
#grep "facebook/bart-large-cnn cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 44.78 44.82
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/models/hf_distibart.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ source utils.sh
grep "sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 35.1 35.3
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.9 4.2
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.5 100
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.5 100
# todo: bigger bs doesn't increase speed
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.5 100
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.5 100

## Accuracy
#grep "sshleifer/distilbart-cnn-12-6 cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 45 45.1
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/models/hf_mbart.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ source utils.sh
grep "facebook/mbart-large-en-ro wmt_en_ro/raw val " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 27.79 27.95
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 5.8 6.2
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.0 100
Comment thread
feihugis marked this conversation as resolved.
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.2 100
115 changes: 102 additions & 13 deletions fastseq_cli/transformers_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,91 @@
import argparse
import json
from pathlib import Path

import torch
from multiprocessing import Process, Queue, cpu_count
from tqdm import tqdm

from fastseq_cli.transformers_utils import use_task_specific_params, trim_batch, calculate_rouge, calculate_bleu_score
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from fastseq_cli.transformers_utils import use_task_specific_params, trim_batch, calculate_rouge, calculate_bleu_score

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

GENERATE_FINISHED = 'done'
POSTPROCESS_FINISHED = None


def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n]

class IOProcess (Process) :
""" Write detokenized output to file in order."""
def __init__(self, msg_queue, fout):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing docs

super(IOProcess, self).__init__()
self.msg_queue = msg_queue
self.fout = fout
self.waiting_for=0
self.dec_buf = {}

def process_dec(self, dec) :
for hypothesis in dec:
self.fout.write(hypothesis + "\n")
self.fout.flush()

def process_buffer(self):
while self.waiting_for in self.dec_buf :
self.process_dec(self.dec_buf[self.waiting_for])
del self.dec_buf[self.waiting_for]
self.waiting_for+=1

def run(self) :
while True :
ind, dec = self.msg_queue.get()
if dec == GENERATE_FINISHED:
break
elif ind != self.waiting_for:
self.dec_buf[ind] = dec
else :
self.process_dec(dec)
self.waiting_for+=1
self.process_buffer()
self.process_buffer()
assert not self.dec_buf, "IO Buffer not empty"
self.msg_queue.close()
self.msg_queue.join_thread()

class PostProcess(Process) :
""" Parallel detokenization """
def __init__(self, tokenizer, data_queue, msg_queue,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing docs.

skip_special_tokens, clean_up_tokenization_spaces) :
super(PostProcess, self).__init__()
self.data_queue = data_queue
self.msg_queue = msg_queue
self.tokenizer = tokenizer
self.clean_up_tokenization_spaces = clean_up_tokenization_spaces
self.skip_special_tokens = skip_special_tokens

def run(self) :
while True :
ind, summaries = self.data_queue.get()
if summaries == GENERATE_FINISHED:
self.data_queue.put((-1, POSTPROCESS_FINISHED))
break
elif summaries == POSTPROCESS_FINISHED :
self.data_queue.put((-1, POSTPROCESS_FINISHED))
break
else :
dec = self.tokenizer.batch_decode(summaries,
skip_special_tokens = self.skip_special_tokens,
clean_up_tokenization_spaces =
self.clean_up_tokenization_spaces)
self.msg_queue.put((ind, dec))

self.data_queue.close()
self.data_queue.join_thread()
self.msg_queue.close()
self.msg_queue.join_thread()


def generate_summaries_or_translations(
examples: list,
Expand All @@ -29,6 +99,8 @@ def generate_summaries_or_translations(
decoder_start_token_id=None,
fastseq_opt=True,
no_repeat_ngram_size=None,
Comment thread
NickNickGo marked this conversation as resolved.
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
**gen_kwargs,
) -> None:
"""Run generation"""
Expand All @@ -46,8 +118,21 @@ def generate_summaries_or_translations(

# update config with summarization specific params
use_task_specific_params(model, task)
data_queue = Queue()
msg_queue = Queue()
p_list = []
threads = cpu_count()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may be better to allow users to specify CPU numbers.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't make a big difference right, although I can create an argument .,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be some differences. It will waste the CPU resources and it also brings overhead to create and manage these processes and sync data across these processes.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a parameter define when support parallel for fairseq. GPU machine has 32/64 or more CPU. Do you get better speed when have threads > 1?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@feihugis I added support for this.
@yuyan2do , I haven't yet analyzed effect of changing num of threads on speed, let me do that .

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't notice significant changes in overall time when number of threads are changed.


for i in range(threads) :
p = PostProcess(tokenizer, data_queue, msg_queue,
skip_special_tokens, clean_up_tokenization_spaces)
p_list.append(p)
p.start()

io_process = IOProcess( msg_queue, fout)
Comment thread
NickNickGo marked this conversation as resolved.
io_process.start()

for batch in tqdm(list(chunks(examples, batch_size))):
for ind, batch in tqdm(enumerate(list(chunks(examples, batch_size)))):
if "t5" in model_name:
batch = [model.config.prefix + text for text in batch]
batch = tokenizer(batch,
Comment thread
NickNickGo marked this conversation as resolved.
Outdated
Expand All @@ -63,13 +148,14 @@ def generate_summaries_or_translations(
no_repeat_ngram_size=no_repeat_ngram_size,
**gen_kwargs,
)
dec = tokenizer.batch_decode(summaries,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
for hypothesis in dec:
fout.write(hypothesis + "\n")
fout.flush()

summaries_cpu = summaries.cpu()
data_queue.put((ind, summaries_cpu))
data_queue.put((-1, GENERATE_FINISHED))
for p in p_list :
p.join()
msg_queue.put((-1, GENERATE_FINISHED))
io_process.join()
fout.close()

def run_generate():
"""Entrance is here."""
Expand Down Expand Up @@ -118,6 +204,7 @@ def run_generate():
parser.add_argument("--without_fastseq_opt", action="store_true")
parser.add_argument("--no_repeat_ngram_size", type=int, default=None,
Comment thread
NickNickGo marked this conversation as resolved.
required=False, help="size of no repeat ngram")

args = parser.parse_args()
examples = [
" " + x.rstrip() if "t5" in args.model_name else x.rstrip()
Expand All @@ -137,7 +224,9 @@ def run_generate():
decoder_start_token_id=args.decoder_start_token_id,
fastseq_opt=not args.without_fastseq_opt,
no_repeat_ngram_size=args.no_repeat_ngram_size,
)
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
Comment thread
feihugis marked this conversation as resolved.
Outdated
)
if args.reference_path is None:
return
# Compute scores
Expand Down