mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
import re
from pathlib import Path
from typing import Any, Iterator, List

from jiwer import wer
from Levenshtein import distance
from pydantic import BaseModel, Field, field_validator
from tqdm.auto import tqdm
from whisper.normalizers import EnglishTextNormalizer


class EvaluationResult(BaseModel):
    """
    Result object of the model evaluation
    """

    accuracy: float = Field(default=0.0)
    total_test_samples: int = Field(default=0)


class EvaluationTestSample(BaseModel):
    """
    Represents one test sample
    """

    reference_text: str
    predicted_text: str

    def update(self, reference_text: str, predicted_text: str) -> None:
        self.reference_text = reference_text
        self.predicted_text = predicted_text


class TestDatasetLoader(BaseModel):
    """
    Test samples loader
    """

    test_dir: Path = Field(default=Path(__file__).parent)
    total_samples: int = Field(default=0)

    @field_validator("test_dir")
    @classmethod
    def validate_file_path(cls, path):
        """
        Check that the test directory exists
        """
        if not path.exists():
            raise ValueError("Path does not exist")
        return path

    def _load_test_data(self) -> Iterator[tuple[Path, Path]]:
        """
        Validate the input files and yield (reference, predicted) file path pairs
        """
        PREDICTED_TEST_SAMPLES_DIR = self.test_dir / "predicted_texts"
        REFERENCE_TEST_SAMPLES_DIR = self.test_dir / "reference_texts"

        # Reset so that iterating more than once does not double-count samples.
        self.total_samples = 0
        for filename in PREDICTED_TEST_SAMPLES_DIR.iterdir():
            # Extract the numeric sample id from the predicted file name.
            match = re.search(r"(\d+)\.txt$", filename.as_posix())
            if match:
                sample_id = match.group(1)
                # iterdir() already yields the full path of the predicted file.
                pred_file_path = filename
                ref_file_name = "ref_sample_" + str(sample_id) + ".txt"
                ref_file_path = REFERENCE_TEST_SAMPLES_DIR / ref_file_name
                if ref_file_path.exists():
                    self.total_samples += 1
                    yield ref_file_path, pred_file_path

    def __iter__(self) -> Iterator[EvaluationTestSample]:
        """
        Iterate over the paired test samples
        """
        for ref_file_path, pred_file_path in self._load_test_data():
            with open(pred_file_path, "r", encoding="utf-8") as file:
                pred_text = file.read()
            with open(ref_file_path, "r", encoding="utf-8") as file:
                ref_text = file.read()
            yield EvaluationTestSample(reference_text=ref_text, predicted_text=pred_text)


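# Minimal usage sketch for the loader. The directory names and the
# "ref_sample_<id>.txt" pattern come from the code above; the "pred_sample_"
# prefix and the "tests/asr_eval" path are illustrative assumptions, since only
# the trailing "<id>.txt" of predicted file names is matched.
#
#   <test_dir>/
#       predicted_texts/pred_sample_1.txt
#       reference_texts/ref_sample_1.txt
#
#   loader = TestDatasetLoader(test_dir=Path("tests/asr_eval"))
#   for sample in loader:
#       print(sample.reference_text[:40], "->", sample.predicted_text[:40])

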
class EvaluationConfig(BaseModel):
    """
    Model for evaluation parameters
    """

    insertion_penalty: int = Field(default=1)
    substitution_penalty: int = Field(default=1)
    deletion_penalty: int = Field(default=1)
    normalizer: Any = Field(default=EnglishTextNormalizer())
    test_directory: str = Field(default=str(Path(__file__).parent))


class ModelEvaluator:
    """
    Class that comprises all model evaluation related processes and methods
    """

    def __init__(self, **kwargs):
        self.evaluation_config = EvaluationConfig(**kwargs)
        self.test_dataset_loader = TestDatasetLoader(test_dir=self.evaluation_config.test_directory)

        # The two popular WER implementations differ slightly; more dimensions of
        # accuracy will be added, and for now the average of the two serves as the
        # metric. Kept as per-instance state so evaluators do not share mutable lists.
        self.WEIGHTED_WER_LEVENSHTEIN = 0.0
        self.WER_LEVENSHTEIN = []
        self.WEIGHTED_WER_JIWER = 0.0
        self.WER_JIWER = []

        self.evaluation_result = EvaluationResult()

    def __repr__(self):
        return f"ModelEvaluator({self.evaluation_config})"

    def describe(self) -> dict:
        """
        Returns the parameters defining the evaluator
        """
        return self.evaluation_config.model_dump()

    def _normalize(self, sample: EvaluationTestSample) -> None:
        """
        Normalize both reference and predicted text
        """
        sample.update(
            self.evaluation_config.normalizer(sample.reference_text),
            self.evaluation_config.normalizer(sample.predicted_text),
        )

    def _calculate_wer(self, sample: EvaluationTestSample) -> float:
        """
        Based on weights for (insertion, deletion, substitution), calculate
        the Word Error Rate
        """
        reference_words = sample.reference_text.split()
        predicted_words = sample.predicted_text.split()
        # Levenshtein.distance accepts arbitrary sequences, so computing it over
        # word lists yields a word-level (rather than character-level) edit distance.
        levenshtein_distance = distance(
            s1=reference_words,
            s2=predicted_words,
            weights=(
                self.evaluation_config.insertion_penalty,
                self.evaluation_config.deletion_penalty,
                self.evaluation_config.substitution_penalty,
            ),
        )
        return levenshtein_distance / len(reference_words)

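    # Worked example, assuming unit weights: reference "the cat sat" vs.
    # prediction "the cat sat down" has one insertion over three reference
    # words, so this method returns 1 / 3 ≈ 0.33 for that sample.
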
    def _calculate_wers(self) -> None:
        """
        Compute the per-sample WERs
        """
        # Reset so that a recalculation does not append duplicate entries.
        self.WER_LEVENSHTEIN = []
        self.WER_JIWER = []
        for sample in tqdm(self.test_dataset_loader, desc="Evaluating"):
            self._normalize(sample)
            no_of_words = len(sample.reference_text.split())
            wer_item_l = {
                "wer": self._calculate_wer(sample),
                "no_of_words": no_of_words,
            }
            wer_item_j = {
                "wer": wer(sample.reference_text, sample.predicted_text),
                "no_of_words": no_of_words,
            }
            self.WER_LEVENSHTEIN.append(wer_item_l)
            self.WER_JIWER.append(wer_item_j)

    def _calculate_weighted_wer(self, wers: List[dict]) -> float:
        """
        Calculate the word-count-weighted average WER over all samples
        """
        total_wer = 0.0
        total_words = 0.0
        for item in wers:
            total_wer += item["no_of_words"] * item["wer"]
            total_words += item["no_of_words"]
        return total_wer / total_words

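    # Worked example with assumed numbers: two samples with WERs 0.10 (20 words)
    # and 0.30 (60 words) give (0.10 * 20 + 0.30 * 60) / (20 + 60) = 0.25,
    # so longer references carry proportionally more weight.
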
    def _calculate_model_accuracy(self) -> None:
        """
        Compute model accuracy
        """
        self._calculate_wers()
        weighted_wer_levenshtein = self._calculate_weighted_wer(self.WER_LEVENSHTEIN)
        weighted_wer_jiwer = self._calculate_weighted_wer(self.WER_JIWER)

        final_weighted_wer = (weighted_wer_levenshtein + weighted_wer_jiwer) / 2
        self.evaluation_result.accuracy = (1 - final_weighted_wer) * 100

    def evaluate(self, recalculate: bool = False) -> EvaluationResult:
        """
        Triggers the model evaluation
        """
        if not self.evaluation_result.accuracy or recalculate:
            self._calculate_model_accuracy()
        return EvaluationResult(
            accuracy=self.evaluation_result.accuracy,
            total_test_samples=self.test_dataset_loader.total_samples,
        )


if __name__ == "__main__":
    eval_config = {"insertion_penalty": 1, "deletion_penalty": 2, "substitution_penalty": 1}

    evaluator = ModelEvaluator(**eval_config)
    evaluation = evaluator.evaluate()

    print(evaluator)
    print(evaluation)
    print("Model accuracy : {:.2f} %".format(evaluation.accuracy))
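
# Hedged usage sketch: pointing the evaluator at a different test directory.
# The "path/to/eval_data" value is a placeholder; any directory containing the
# predicted_texts/ and reference_texts/ subfolders described above should work.
#
#   evaluator = ModelEvaluator(
#       test_directory="path/to/eval_data",
#       insertion_penalty=1,
#       deletion_penalty=2,
#       substitution_penalty=1,
#   )
#   print(evaluator.evaluate())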