mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
flake8 checks
This commit is contained in:
20
server.py
20
server.py
@@ -53,11 +53,11 @@ def get_title_and_summary(llm_input_text, last_timestamp):
|
||||
|
||||
prompt = f"""
|
||||
### Human:
|
||||
Create a JSON object as response. The JSON object must have 2 fields:
|
||||
i) title and ii) summary. For the title field,generate a short title
|
||||
for the given text. For the summary field, summarize the given text
|
||||
Create a JSON object as response. The JSON object must have 2 fields:
|
||||
i) title and ii) summary. For the title field,generate a short title
|
||||
for the given text. For the summary field, summarize the given text
|
||||
in three sentences.
|
||||
|
||||
|
||||
{llm_input_text}
|
||||
|
||||
### Assistant:
|
||||
@@ -144,12 +144,12 @@ def get_transcription(frames):
|
||||
result_text = ""
|
||||
|
||||
try:
|
||||
segments, _ = model.transcribe(audiofilename,
|
||||
language="en",
|
||||
beam_size=5,
|
||||
vad_filter=True,
|
||||
vad_parameters=dict(min_silence_duration_ms=500)
|
||||
)
|
||||
segments, _ = \
|
||||
model.transcribe(audiofilename,
|
||||
language="en",
|
||||
beam_size=5,
|
||||
vad_filter=True,
|
||||
vad_parameters=dict(min_silence_duration_ms=500))
|
||||
os.remove(audiofilename)
|
||||
segments = list(segments)
|
||||
result_text = ""
|
||||
|
||||
@@ -16,8 +16,8 @@ from av import AudioFifo
|
||||
from sortedcontainers import SortedDict
|
||||
from whisper_jax import FlaxWhisperPipline
|
||||
|
||||
from utils.log_utils import logger
|
||||
from utils.run_utils import config, Mutex
|
||||
from ..utils.log_utils import logger
|
||||
from ..utils.run_utils import config, Mutex
|
||||
|
||||
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"]
|
||||
pcs = set()
|
||||
@@ -21,8 +21,8 @@ from whisper_jax import FlaxWhisperPipline
|
||||
from ...utils.file_utils import download_files, upload_files
|
||||
from ...utils.log_utils import logger
|
||||
from ...utils.run_utils import config
|
||||
from ...utils.text_utilities import post_process_transcription, summarize
|
||||
from ...utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud
|
||||
from ...utils.text_utils import post_process_transcription, summarize
|
||||
from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud
|
||||
|
||||
nltk.download('punkt', quiet=True)
|
||||
nltk.download('stopwords', quiet=True)
|
||||
|
||||
@@ -13,8 +13,8 @@ from whisper_jax import FlaxWhisperPipline
|
||||
from ...utils.file_utils import upload_files
|
||||
from ...utils.log_utils import logger
|
||||
from ...utils.run_utils import config
|
||||
from ...utils.text_utilities import post_process_transcription, summarize
|
||||
from ...utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud
|
||||
from ...utils.text_utils import post_process_transcription, summarize
|
||||
from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud
|
||||
|
||||
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
|
||||
|
||||
|
||||
@@ -154,7 +154,7 @@ def chunk_text(text,
|
||||
|
||||
def summarize(transcript_text, timestamp,
|
||||
real_time=False,
|
||||
summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
|
||||
chunk_summarize=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
|
||||
if not summary_model:
|
||||
@@ -166,27 +166,35 @@ def summarize(transcript_text, timestamp,
|
||||
model = BartForConditionalGeneration.from_pretrained(summary_model)
|
||||
model = model.to(device)
|
||||
|
||||
output_filename = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||
output_file = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||
if real_time:
|
||||
output_filename = "real_time_" + output_filename
|
||||
output_file = "real_time_" + output_file
|
||||
|
||||
if summarize_using_chunks != "YES":
|
||||
if chunk_summarize != "YES":
|
||||
max_length = int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"])
|
||||
inputs = tokenizer. \
|
||||
batch_encode_plus([transcript_text], truncation=True,
|
||||
padding='longest',
|
||||
max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
|
||||
max_length=max_length,
|
||||
return_tensors='pt')
|
||||
inputs = inputs.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
num_beans = int(config["DEFAULT"]["BEAM_SIZE"])
|
||||
max_length = int(config["DEFAULT"]["MAX_LENGTH"])
|
||||
summaries = model.generate(inputs['input_ids'],
|
||||
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
|
||||
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
|
||||
num_beams=num_beans,
|
||||
length_penalty=2.0,
|
||||
max_length=max_length,
|
||||
early_stopping=True)
|
||||
|
||||
decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
for summary in summaries]
|
||||
decoded_summaries = \
|
||||
[tokenizer.decode(summary,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False)
|
||||
for summary in summaries]
|
||||
summary = " ".join(decoded_summaries)
|
||||
with open("./artefacts/" + output_filename, 'w') as f:
|
||||
with open("./artefacts/" + output_file, 'w') as f:
|
||||
f.write(summary.strip() + "\n")
|
||||
else:
|
||||
logger.info("Breaking transcript into smaller chunks")
|
||||
@@ -195,8 +203,8 @@ def summarize(transcript_text, timestamp,
|
||||
logger.info(f"Transcript broken into {len(chunks)} "
|
||||
f"chunks of at most 500 words")
|
||||
|
||||
logger.info(f"Writing summary text to: {output_filename}")
|
||||
with open(output_filename, 'w') as f:
|
||||
logger.info(f"Writing summary text to: {output_file}")
|
||||
with open(output_file, 'w') as f:
|
||||
summaries = summarize_chunks(chunks, tokenizer, model)
|
||||
for summary in summaries:
|
||||
f.write(summary.strip() + " ")
|
||||
@@ -45,14 +45,14 @@ def create_wordcloud(timestamp, real_time=False):
|
||||
plt.axis("off")
|
||||
plt.tight_layout(pad=0)
|
||||
|
||||
wordcloud_name = "wordcloud"
|
||||
wordcloud = "wordcloud"
|
||||
if real_time:
|
||||
wordcloud_name = "real_time_" + wordcloud_name + "_" + \
|
||||
wordcloud = "real_time_" + wordcloud + "_" + \
|
||||
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
|
||||
else:
|
||||
wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
|
||||
wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
|
||||
|
||||
plt.savefig("./artefacts/" + wordcloud_name)
|
||||
plt.savefig("./artefacts/" + wordcloud)
|
||||
|
||||
|
||||
def create_talk_diff_scatter_viz(timestamp, real_time=False):
|
||||
Reference in New Issue
Block a user