Skip to content
95 changes: 76 additions & 19 deletions fastseq_cli/transformers_generate.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,71 @@
"""From Huggingface Transformers."""
import argparse
import json
import time
from pathlib import Path

import torch
import time
from tqdm import tqdm

from fastseq_cli.transformers_utils import use_task_specific_params, trim_batch, calculate_rouge, calculate_bleu_score
from multiprocessing import Process, Queue, JoinableQueue, cpu_count
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from fastseq_cli.transformers_utils import use_task_specific_params, trim_batch, calculate_rouge, calculate_bleu_score

# Use the GPU when torch reports CUDA as available, otherwise the CPU.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Sentinel placed on the queues after the generation loop; the queue
# consumers in this file compare received items against it to stop.
GENERATE_FINISHED = 'done'
# Sentinel a PostProcess worker puts back on the data queue once it has seen
# a stop sentinel, just before it leaves its loop.
POSTPROCESS_FINISHED = None


def chunks(lst, n):
    """Yield successive n-sized chunks from lst.

    Args:
        lst: A sequence supporting ``len()`` and slicing.
        n: Chunk size; must be a positive integer.

    Yields:
        Consecutive slices ``lst[i:i + n]``. The final slice may hold fewer
        than ``n`` items.

    Raises:
        ValueError: If ``n`` is not positive. Because this is a generator,
            the error surfaces when iteration starts.
    """
    if n <= 0:
        # A negative step makes range() empty, which would silently drop
        # every item; fail loudly instead, as a zero step already does.
        raise ValueError("chunk size must be positive, got {}".format(n))
    for start in range(0, len(lst), n):
        yield lst[start:start + n]

class IOProcess(Process):
    """Process that writes decoded hypotheses from a queue to an output file.

    Each item taken from ``msg_queue`` is iterated, and every element is
    written to ``fout`` followed by a newline. The file is flushed after each
    item. The loop ends when an item equal to ``GENERATE_FINISHED`` arrives.
    """

    def __init__(self, msg_queue, fout):
        super().__init__()
        self.msg_queue = msg_queue
        self.fout = fout

    def run(self):
        # iter(callable, sentinel) keeps calling get() until it returns a
        # value equal to the sentinel.
        for batch in iter(self.msg_queue.get, GENERATE_FINISHED):
            for line in batch:
                self.fout.write(line + "\n")
            self.fout.flush()
        self.msg_queue.close()
        self.msg_queue.join_thread()

class PostProcess(Process):
    """Process that decodes generated outputs and forwards the text.

    Items are taken from ``data_queue``, decoded with
    ``tokenizer.batch_decode`` and the result is put on ``msg_queue``.

    Args:
        tokenizer: Object providing ``batch_decode``.
        data_queue: Queue the generated outputs and stop sentinels arrive on.
        msg_queue: Queue the decoded results are put on.
    """

    def __init__(self, tokenizer, data_queue, msg_queue):
        super().__init__()
        self.data_queue = data_queue
        self.msg_queue = msg_queue
        self.tokenizer = tokenizer

    def run(self):
        while True:
            summaries = self.data_queue.get()
            # Match the sentinels by identity/type rather than a bare `==`,
            # so the check does not depend on how the payload object
            # implements equality against a str or None.
            stop_requested = summaries is POSTPROCESS_FINISHED or (
                isinstance(summaries, str) and summaries == GENERATE_FINISHED)
            if stop_requested:
                # Put a stop marker back so another consumer of the same
                # data queue can observe it as well.
                self.data_queue.put(POSTPROCESS_FINISHED)
                break
            dec = self.tokenizer.batch_decode(
                summaries,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False)
            self.msg_queue.put(dec)

        self.data_queue.close()
        self.data_queue.join_thread()
        self.msg_queue.close()
        self.msg_queue.join_thread()
        # No msg_queue.join() here: a plain multiprocessing.Queue has no
        # join(), and on a JoinableQueue it would wait for task_done() calls
        # that the consumer in this file never makes.


def generate_summaries_or_translations(
examples: list,
Expand All @@ -28,7 +77,6 @@ def generate_summaries_or_translations(
task="summarization",
decoder_start_token_id=None,
fastseq_opt=True,
no_repeat_ngram_size=None,
Comment thread
NickNickGo marked this conversation as resolved.
**gen_kwargs,
) -> None:
"""Run generation"""
Expand All @@ -41,36 +89,48 @@ def generate_summaries_or_translations(
model = model.half()
if decoder_start_token_id is None:
decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)

tokenizer = AutoTokenizer.from_pretrained(model_name)

Comment thread
feihugis marked this conversation as resolved.
Outdated
# update config with summarization specific params
use_task_specific_params(model, task)

data_queue = Queue()
msg_queue = Queue()
p_list = []
threads = cpu_count()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may be better to allow users to specify CPU numbers.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't make a big difference, right? Although I can add an argument for it.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be some differences. It will waste the CPU resources and it also brings overhead to create and manage these processes and sync data across these processes.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A parameter for this was defined when parallel support was added for fairseq. GPU machines have 32/64 or more CPUs. Do you get better speed with threads > 1?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@feihugis I added support for this.
@yuyan2do, I haven't yet analyzed the effect of changing the number of threads on speed; let me do that.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't notice significant changes in overall time when the number of threads is changed.


for i in range (threads) :
Comment thread
NickNickGo marked this conversation as resolved.
Outdated
p = PostProcess(tokenizer, data_queue, msg_queue)
p_list.append(p)
p.start()

io_process = IOProcess( msg_queue, fout)
Comment thread
NickNickGo marked this conversation as resolved.
io_process.start()

for batch in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name:
batch = [model.config.prefix + text for text in batch]
torch.cuda.nvtx.range_push("tokenization_step")
Comment thread
NickNickGo marked this conversation as resolved.
Outdated
batch = tokenizer(batch,
Comment thread
NickNickGo marked this conversation as resolved.
Outdated
return_tensors="pt",
truncation=True,
padding="max_length").to(device)
input_ids, attention_mask = trim_batch(
**batch, pad_token_id=tokenizer.pad_token_id)
torch.cuda.nvtx.range_pop()
Comment thread
NickNickGo marked this conversation as resolved.
Outdated
summaries = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_start_token_id=decoder_start_token_id,
no_repeat_ngram_size=no_repeat_ngram_size,
**gen_kwargs,
)
dec = tokenizer.batch_decode(summaries,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
for hypothesis in dec:
fout.write(hypothesis + "\n")
fout.flush()


summaries_cpu = summaries.cpu()
data_queue.put(summaries_cpu)
data_queue.put(GENERATE_FINISHED)
for p in p_list :
p.join()
msg_queue.put(GENERATE_FINISHED)
io_process.join()
def run_generate():
"""Entrance is here."""
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -116,8 +176,6 @@ def run_generate():
help="How many observations. Defaults to all.")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--without_fastseq_opt", action="store_true")
parser.add_argument("--no_repeat_ngram_size", type=int, default=None,
Comment thread
NickNickGo marked this conversation as resolved.
required=False, help="size of no repeat ngram")
args = parser.parse_args()
examples = [
" " + x.rstrip() if "t5" in args.model_name else x.rstrip()
Expand All @@ -136,7 +194,6 @@ def run_generate():
task=args.task,
decoder_start_token_id=args.decoder_start_token_id,
fastseq_opt=not args.without_fastseq_opt,
no_repeat_ngram_size=args.no_repeat_ngram_size,
)
if args.reference_path is None:
return
Expand Down