reflector/server/evaluate/evaluate_transcription.py

import re
from pathlib import Path
from typing import Any, Iterator, List

from jiwer import wer
from Levenshtein import distance
from pydantic import BaseModel, Field, field_validator
from tqdm.auto import tqdm
from whisper.normalizers import EnglishTextNormalizer


class EvaluationResult(BaseModel):
    """
    Result object of the model evaluation
    """

    accuracy: float = Field(default=0.0)
    total_test_samples: int = Field(default=0)


class EvaluationTestSample(BaseModel):
    """
    Represents one test sample
    """

    reference_text: str
    predicted_text: str

    def update(self, reference_text: str, predicted_text: str) -> None:
        self.reference_text = reference_text
        self.predicted_text = predicted_text


class TestDatasetLoader(BaseModel):
    """
    Test samples loader
    """

    test_dir: Path = Field(default=Path(__file__).parent)
    total_samples: int = Field(default=0)

    @field_validator("test_dir")
    def validate_file_path(cls, path):
        """
        Check that the test directory exists
        """
        if not path.exists():
            raise ValueError("Path does not exist")
        return path
    def _load_test_data(self) -> Iterator[tuple[Path, Path]]:
        """
        Validate the input files and yield (reference, predicted) file pairs
        """
        predicted_samples_dir = self.test_dir / "predicted_texts"
        reference_samples_dir = self.test_dir / "reference_texts"
        # Reset the counter so repeated iterations do not double-count.
        self.total_samples = 0
        # iterdir() already yields full paths, so each entry is used
        # directly instead of being re-joined onto the directory.
        for pred_file_path in predicted_samples_dir.iterdir():
            match = re.search(r"(\d+)\.txt$", pred_file_path.name)
            if match:
                sample_id = match.group(1)
                ref_file_name = f"ref_sample_{sample_id}.txt"
                ref_file_path = reference_samples_dir / ref_file_name
                if ref_file_path.exists():
                    self.total_samples += 1
                    yield ref_file_path, pred_file_path
    def __iter__(self) -> Iterator[EvaluationTestSample]:
        """
        Iterate over the test samples
        """
        # _load_test_data yields (reference, predicted), in that order.
        for ref_file_path, pred_file_path in self._load_test_data():
            with open(pred_file_path, "r", encoding="utf-8") as file:
                pred_text = file.read()
            with open(ref_file_path, "r", encoding="utf-8") as file:
                ref_text = file.read()
            yield EvaluationTestSample(
                reference_text=ref_text, predicted_text=pred_text
            )
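

# The loader's directory layout is only implicit in the code above; this
# sketch is inferred from the regex and the reference file-name construction
# (the predicted file name and the sample id are illustrative):
#
#   <test_dir>/predicted_texts/pred_sample_1.txt   # any name ending in "<digits>.txt"
#   <test_dir>/reference_texts/ref_sample_1.txt    # must exist for the pair to count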


class EvaluationConfig(BaseModel):
    """
    Model for evaluation parameters
    """

    insertion_penalty: int = Field(default=1)
    substitution_penalty: int = Field(default=1)
    deletion_penalty: int = Field(default=1)
    # A factory avoids sharing a single normalizer instance across configs.
    normalizer: Any = Field(default_factory=EnglishTextNormalizer)
    test_directory: str = Field(default=str(Path(__file__).parent))


class ModelEvaluator:
    """
    Class that comprises all model-evaluation processes and methods
    """

    # The two popular WER implementations differ slightly. More dimensions
    # of accuracy will be added; for now, the average of the two serves as
    # the metric.
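    # Concretely, as computed in _calculate_model_accuracy below:
    #
    #   final_wer = (weighted_wer_levenshtein + weighted_wer_jiwer) / 2
    #   accuracy  = (1 - final_wer) * 100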

    def __init__(self, **kwargs):
        self.evaluation_config = EvaluationConfig(**kwargs)
        self.test_dataset_loader = TestDatasetLoader(
            test_dir=self.evaluation_config.test_directory
        )
        # Per-sample WER records, kept per instance so that separate
        # evaluators do not share (and append to) the same lists.
        self.wer_levenshtein: List[dict] = []
        self.wer_jiwer: List[dict] = []
        self.evaluation_result = EvaluationResult()
    def __repr__(self):
        return f"ModelEvaluator({self.evaluation_config})"

    def describe(self) -> dict:
        """
        Return the parameters defining the evaluator
        """
        return self.evaluation_config.model_dump()
    def _normalize(self, sample: EvaluationTestSample) -> None:
        """
        Normalize both the reference and the predicted text
        """
        sample.update(
            self.evaluation_config.normalizer(sample.reference_text),
            self.evaluation_config.normalizer(sample.predicted_text),
        )
    def _calculate_wer(self, sample: EvaluationTestSample) -> float:
        """
        Calculate the Word Error Rate: the weighted Levenshtein distance
        between the reference and predicted word sequences, using the
        configured (insertion, deletion, substitution) penalties.
        """
        # Levenshtein.distance accepts any sequence of hashables, so word
        # lists yield a word-level (rather than character-level) distance.
        reference_words = sample.reference_text.split()
        predicted_words = sample.predicted_text.split()
        levenshtein_distance = distance(
            s1=reference_words,
            s2=predicted_words,
            weights=(
                self.evaluation_config.insertion_penalty,
                self.evaluation_config.deletion_penalty,
                self.evaluation_config.substitution_penalty,
            ),
        )
        return levenshtein_distance / len(reference_words)
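
    # Worked example (toy strings, not from the test set): with unit
    # penalties, reference "the cat sat" vs. prediction "the cat sit" is
    # one substitution over three reference words, so WER = 1 / 3.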
    def _calculate_wers(self) -> None:
        """
        Compute the per-sample WER with both implementations
        """
        # Start from empty lists so a recalculation does not accumulate
        # the previous run's entries.
        self.wer_levenshtein.clear()
        self.wer_jiwer.clear()
        for sample in tqdm(self.test_dataset_loader, desc="Evaluating"):
            self._normalize(sample)
            no_of_words = len(sample.reference_text.split())
            self.wer_levenshtein.append(
                {"wer": self._calculate_wer(sample), "no_of_words": no_of_words}
            )
            self.wer_jiwer.append(
                {
                    "wer": wer(sample.reference_text, sample.predicted_text),
                    "no_of_words": no_of_words,
                }
            )
    def _calculate_weighted_wer(self, wers: List[dict]) -> float:
        """
        Calculate the word-count-weighted average WER over all samples
        """
        total_wer = 0.0
        total_words = 0
        for item in wers:
            total_wer += item["no_of_words"] * item["wer"]
            total_words += item["no_of_words"]
        return total_wer / total_words
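
    # Worked example (illustrative numbers): samples with (no_of_words, wer)
    # of (10, 0.2) and (30, 0.1) give (10 * 0.2 + 30 * 0.1) / (10 + 30)
    # = 5.0 / 40 = 0.125, so longer samples pull the average harder.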
    def _calculate_model_accuracy(self) -> None:
        """
        Compute the model accuracy from the averaged weighted WERs
        """
        self._calculate_wers()
        weighted_wer_levenshtein = self._calculate_weighted_wer(self.wer_levenshtein)
        weighted_wer_jiwer = self._calculate_weighted_wer(self.wer_jiwer)
        final_weighted_wer = (weighted_wer_levenshtein + weighted_wer_jiwer) / 2
        self.evaluation_result.accuracy = (1 - final_weighted_wer) * 100
    def evaluate(self, recalculate: bool = False) -> EvaluationResult:
        """
        Trigger the model evaluation
        """
        if not self.evaluation_result.accuracy or recalculate:
            self._calculate_model_accuracy()
        return EvaluationResult(
            accuracy=self.evaluation_result.accuracy,
            total_test_samples=self.test_dataset_loader.total_samples,
        )


if __name__ == "__main__":
    # Guarding the demo run keeps the module importable without side effects.
    eval_config = {
        "insertion_penalty": 1,
        "deletion_penalty": 2,
        "substitution_penalty": 1,
    }
    evaluator = ModelEvaluator(**eval_config)
    evaluation = evaluator.evaluate()
    print(evaluator)
    print(evaluation)
    print("Model accuracy : {:.2f} %".format(evaluation.accuracy))