mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 12:49:06 +00:00
code style updates
This commit is contained in:
@@ -121,9 +121,9 @@ def summarize_chunks(chunks, tokenizer, model):
|
||||
with torch.no_grad():
|
||||
summary_ids = \
|
||||
model.generate(input_ids,
|
||||
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]),
|
||||
num_beams=int(config["SUMMARIZER"]["BEAM_SIZE"]),
|
||||
length_penalty=2.0,
|
||||
max_length=int(config["DEFAULT"]["MAX_LENGTH"]),
|
||||
max_length=int(config["SUMMARIZER"]["MAX_LENGTH"]),
|
||||
early_stopping=True)
|
||||
summary = tokenizer.decode(summary_ids[0],
|
||||
skip_special_tokens=True)
|
||||
@@ -132,7 +132,7 @@ def summarize_chunks(chunks, tokenizer, model):
|
||||
|
||||
|
||||
def chunk_text(text,
|
||||
max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
|
||||
max_chunk_length=int(config["SUMMARIZER"]["MAX_CHUNK_LENGTH"])):
|
||||
"""
|
||||
Split text into smaller chunks.
|
||||
:param text: Text to be chunked
|
||||
@@ -154,9 +154,9 @@ def chunk_text(text,
|
||||
|
||||
def summarize(transcript_text, timestamp,
|
||||
real_time=False,
|
||||
chunk_summarize=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
|
||||
chunk_summarize=config["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
|
||||
summary_model = config["SUMMARIZER"]["SUMMARY_MODEL"]
|
||||
if not summary_model:
|
||||
summary_model = "facebook/bart-large-cnn"
|
||||
|
||||
@@ -171,7 +171,7 @@ def summarize(transcript_text, timestamp,
|
||||
output_file = "real_time_" + output_file
|
||||
|
||||
if chunk_summarize != "YES":
|
||||
max_length = int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"])
|
||||
max_length = int(config["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
|
||||
inputs = tokenizer. \
|
||||
batch_encode_plus([transcript_text], truncation=True,
|
||||
padding='longest',
|
||||
@@ -180,8 +180,8 @@ def summarize(transcript_text, timestamp,
|
||||
inputs = inputs.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
num_beans = int(config["DEFAULT"]["BEAM_SIZE"])
|
||||
max_length = int(config["DEFAULT"]["MAX_LENGTH"])
|
||||
num_beans = int(config["SUMMARIZER"]["BEAM_SIZE"])
|
||||
max_length = int(config["SUMMARIZER"]["MAX_LENGTH"])
|
||||
summaries = model.generate(inputs['input_ids'],
|
||||
num_beams=num_beans,
|
||||
length_penalty=2.0,
|
||||
|
||||
Reference in New Issue
Block a user