reflector/server/tests/evaluate/evaluate_transcription.py

import json
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator, List, Union
from jiwer import wer
from Levenshtein import distance
from tqdm.auto import tqdm
from whisper.normalizers import EnglishTextNormalizer
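
# Third-party dependencies used above (assumed PyPI package names):
# jiwer, Levenshtein, tqdm and openai-whisper (which provides
# whisper.normalizers.EnglishTextNormalizer).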


@dataclass
class EvaluationResult:
    """
    Result object of a model evaluation run.
    """

    accuracy: float
    total_test_samples: int

    def __repr__(self):
        return (
            "EvaluationResult("
            + json.dumps(
                {
                    "accuracy": self.accuracy,
                    "total_test_samples": self.total_test_samples,
                }
            )
            + ")"
        )


@dataclass
class EvaluationTestSample:
    """
    Represents one test sample: a reference/predicted text pair.
    """

    reference_text: str
    predicted_text: str

    def update(self, reference_text, predicted_text):
        self.reference_text = reference_text
        self.predicted_text = predicted_text


class TestDatasetLoader:
    """
    Loads reference/predicted test sample pairs from disk.
    """

    def __init__(self, parent_dir: Union[Path, str]):
        if isinstance(parent_dir, str):
            self.parent_dir = Path(parent_dir)
        else:
            self.parent_dir = parent_dir
        self.total_samples = 0

    def _load_test_data(self) -> Iterator[tuple[str, str]]:
        """
        Validate the input files and yield (reference, predicted) file path pairs.
        """
        PREDICTED_TEST_SAMPLES_DIR = self.parent_dir / "predicted_texts"
        REFERENCE_TEST_SAMPLES_DIR = self.parent_dir / "reference_texts"
        for filename in os.listdir(PREDICTED_TEST_SAMPLES_DIR.as_posix()):
            match = re.search(r"(\d+)\.txt$", filename)
            if match:
                sample_id = match.group(1)
                pred_file_path = (PREDICTED_TEST_SAMPLES_DIR / filename).as_posix()
                ref_file_name = "ref_sample_" + str(sample_id) + ".txt"
                ref_file_path = (REFERENCE_TEST_SAMPLES_DIR / ref_file_name).as_posix()
                if os.path.exists(ref_file_path):
                    self.total_samples += 1
                    yield ref_file_path, pred_file_path

    def __iter__(self) -> Iterator[EvaluationTestSample]:
        """
        Iterate over the test samples, one EvaluationTestSample per pair.
        """
        # _load_test_data yields (reference, predicted) path pairs.
        for ref_file_path, pred_file_path in self._load_test_data():
            with open(pred_file_path, "r", encoding="utf-8") as file:
                pred_text = file.read()
            with open(ref_file_path, "r", encoding="utf-8") as file:
                ref_text = file.read()
            yield EvaluationTestSample(ref_text, pred_text)
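

# Expected on-disk layout, inferred from the filename patterns above (the
# predicted-file prefix is arbitrary as long as the name ends in "<id>.txt";
# reference files must be named "ref_sample_<id>.txt"):
#
#   <parent_dir>/
#       predicted_texts/
#           pred_sample_1.txt
#           pred_sample_2.txt
#       reference_texts/
#           ref_sample_1.txt
#           ref_sample_2.txt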


class ModelEvaluator:
    """
    Encapsulates the model evaluation configuration and workflow.
    """

    def __init__(self, **kwargs):
        self.evaluation_config = {k: v for k, v in kwargs.items() if v is not None}

        # The two popular WER implementations differ slightly. More dimensions
        # of accuracy will be added; for now, the average of the two serves as
        # the metric.
        self.WEIGHTED_WER_LEVENSHTEIN = 0.0
        self.WER_LEVENSHTEIN = []
        self.WEIGHTED_WER_JIWER = 0.0
        self.WER_JIWER = []
        self.accuracy = None

        if "normalizer" not in self.evaluation_config:
            self.normalizer = EnglishTextNormalizer()
        else:
            self.normalizer = self.evaluation_config["normalizer"]
        self.evaluation_config["normalizer"] = str(type(self.normalizer))

        if "parent_dir" not in self.evaluation_config:
            self.test_directory = Path(__file__).parent
        else:
            self.test_directory = Path(self.evaluation_config["parent_dir"])
        self.test_dataset_loader = TestDatasetLoader(self.test_directory)
        self.evaluation_config["test_directory"] = str(self.test_directory)

    def __repr__(self):
        return "ModelEvaluator(" + json.dumps(self.describe(), indent=4) + ")"

    def describe(self) -> dict:
        """
        Returns the parameters defining the evaluator
        """
        return self.evaluation_config

    def _normalize(self, sample: EvaluationTestSample) -> None:
        """
        Normalize both reference and predicted text
        """
        sample.update(
            self.normalizer(sample.reference_text),
            self.normalizer(sample.predicted_text),
        )

    def _calculate_wer(self, sample: EvaluationTestSample) -> float:
        """
        Calculate an error rate from the weighted Levenshtein edit distance
        between reference and prediction, using the configured
        (insertion, deletion, substitution) penalties. Note that
        Levenshtein.distance operates on characters, so this is a
        character-level rate normalized by the reference length, used here
        as a proxy for WER.
        """
        levenshtein_distance = distance(
            s1=sample.reference_text,
            s2=sample.predicted_text,
            weights=(
                self.evaluation_config["insertion_penalty"],
                self.evaluation_config["deletion_penalty"],
                self.evaluation_config["substitution_penalty"],
            ),
        )
        # Guard against an empty reference (e.g. a blank file after normalization).
        ref_length = max(len(sample.reference_text), 1)
        return levenshtein_distance / ref_length
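
    # Illustration of the asymmetric penalties with the demo configuration's
    # weights (insertion=1, deletion=2, substitution=1): for a reference "cat",
    # a prediction "cats" is one insertion (distance 1, rate 1/3), while a
    # prediction "ca" is one deletion (distance 2, rate 2/3), so dropped
    # characters are penalised twice as heavily as added ones.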

    def _calculate_wers(self) -> None:
        """
        Compute the per-sample WER with both methods.
        """
        # Reset so that recalculation does not append to previous results.
        self.WER_LEVENSHTEIN = []
        self.WER_JIWER = []
        for sample in tqdm(self.test_dataset_loader, desc="Evaluating", ncols=100):
            self._normalize(sample)
            # Weight each sample by its word count in the weighted average.
            no_of_words = len(sample.reference_text.split())
            wer_item_l = {
                "wer": self._calculate_wer(sample),
                "no_of_words": no_of_words,
            }
            wer_item_j = {
                "wer": wer(sample.reference_text, sample.predicted_text),
                "no_of_words": no_of_words,
            }
            self.WER_LEVENSHTEIN.append(wer_item_l)
            self.WER_JIWER.append(wer_item_j)

    def _calculate_weighted_wer(self, wers: List[dict]) -> float:
        """
        Calculate the length-weighted average WER from the per-sample WERs.
        """
        total_wer = 0.0
        total_words = 0.0
        for item in wers:
            total_wer += item["no_of_words"] * item["wer"]
            total_words += item["no_of_words"]
        return total_wer / total_words
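
    # Worked example with illustrative numbers: samples with WER 0.10 over
    # 100 words and WER 0.30 over 50 words combine to
    #   (0.10 * 100 + 0.30 * 50) / (100 + 50) ≈ 0.167,
    # so longer samples pull the average towards their own error rate.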

    def _calculate_model_accuracy(self) -> None:
        """
        Compute the model accuracy from the combined weighted WER.
        """
        self._calculate_wers()
        self.WEIGHTED_WER_LEVENSHTEIN = self._calculate_weighted_wer(self.WER_LEVENSHTEIN)
        self.WEIGHTED_WER_JIWER = self._calculate_weighted_wer(self.WER_JIWER)
        final_weighted_wer = (self.WEIGHTED_WER_LEVENSHTEIN + self.WEIGHTED_WER_JIWER) / 2
        self.accuracy = (1 - final_weighted_wer) * 100
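        # For example, a combined weighted WER of 0.18 maps to an accuracy of
        # (1 - 0.18) * 100 = 82.0 %.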

    def evaluate(self, recalculate: bool = False) -> EvaluationResult:
        """
        Triggers the model evaluation and returns its result.
        """
        if self.accuracy is None or recalculate:
            self._calculate_model_accuracy()
        return EvaluationResult(self.accuracy, self.test_dataset_loader.total_samples)


if __name__ == "__main__":
    eval_config = {"insertion_penalty": 1, "deletion_penalty": 2, "substitution_penalty": 1}
    evaluator = ModelEvaluator(**eval_config)
    evaluation = evaluator.evaluate()
    print(evaluator)
    print(evaluation)
    print("Model accuracy : {:.2f} %".format(evaluation.accuracy))