diff --git a/server/evaluate/evaluate_transcription.py b/server/evaluate/evaluate_transcription.py
index 26c1ce7a..c402f34e 100644
--- a/server/evaluate/evaluate_transcription.py
+++ b/server/evaluate/evaluate_transcription.py
@@ -1,89 +1,67 @@
-import json
-import os
 import re
-from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Union
+from typing import Any, Iterator, List
 
 from jiwer import wer
 from Levenshtein import distance
+from pydantic import BaseModel, Field, field_validator
 from tqdm.auto import tqdm
 from whisper.normalizers import EnglishTextNormalizer
 
 
-@dataclass
-class EvaluationResult:
+class EvaluationResult(BaseModel):
     """
     Result object of the model evaluation
     """
-
-    accuracy = float
-    total_test_samples = int
-
-    def __init__(self, accuracy, total_test_samples):
-        self.accuracy = accuracy
-        self.total_test_samples = total_test_samples
-
-    def __repr__(self):
-        return (
-            "EvaluationResult("
-            + json.dumps(
-                {
-                    "accuracy": self.accuracy,
-                    "total_test_samples": self.total_test_samples,
-                }
-            )
-            + ")"
-        )
+
+    accuracy: float = Field(default=0.0)
+    total_test_samples: int = Field(default=0)
 
 
-@dataclass
-class EvaluationTestSample:
+class EvaluationTestSample(BaseModel):
     """
     Represents one test sample
     """
 
-    reference_text = str
-    predicted_text = str
+    reference_text: str
+    predicted_text: str
 
-    def __init__(self, reference_text, predicted_text):
-        self.reference_text = reference_text
-        self.predicted_text = predicted_text
-
-    def update(self, reference_text, predicted_text):
+    def update(self, reference_text: str, predicted_text: str) -> None:
         self.reference_text = reference_text
         self.predicted_text = predicted_text
 
 
-class TestDatasetLoader:
+class TestDatasetLoader(BaseModel):
     """
     Test samples loader
     """
 
-    parent_dir = None
-    total_samples = 0
+    test_dir: Path = Field(default=Path(__file__).parent)
+    total_samples: int = Field(default=0)
 
-    def __init__(self, parent_dir: Union[Path | str]):
-        if isinstance(parent_dir, str):
-            self.parent_dir = Path(parent_dir)
-        else:
-            self.parent_dir = parent_dir
+    @field_validator("test_dir")
+    @classmethod
+    def validate_file_path(cls, path: Path) -> Path:
+        """
+        Check that the test directory exists
+        """
+        if not path.exists():
+            raise ValueError(f"Path does not exist: {path}")
+        return path
 
-    def _load_test_data(self) -> tuple[str, str]:
+    def _load_test_data(self) -> Iterator[tuple[Path, Path]]:
         """
         Loader function to validate input files and generate samples
         """
-        PREDICTED_TEST_SAMPLES_DIR = self.parent_dir / "predicted_texts"
-        REFERENCE_TEST_SAMPLES_DIR = self.parent_dir / "reference_texts"
+        PREDICTED_TEST_SAMPLES_DIR = self.test_dir / "predicted_texts"
+        REFERENCE_TEST_SAMPLES_DIR = self.test_dir / "reference_texts"
 
-        for filename in os.listdir(PREDICTED_TEST_SAMPLES_DIR.as_posix()):
-            match = re.search(r"(\d+)\.txt$", filename)
+        # iterdir() yields full paths, so match against the file name only
+        for filename in PREDICTED_TEST_SAMPLES_DIR.iterdir():
+            match = re.search(r"(\d+)\.txt$", filename.name)
             if match:
                 sample_id = match.group(1)
-                pred_file_path = (PREDICTED_TEST_SAMPLES_DIR / filename).as_posix()
+                pred_file_path = filename
                 ref_file_name = "ref_sample_" + str(sample_id) + ".txt"
-                ref_file_path = (REFERENCE_TEST_SAMPLES_DIR / ref_file_name).as_posix()
-                if os.path.exists(ref_file_path):
+                ref_file_path = REFERENCE_TEST_SAMPLES_DIR / ref_file_name
+                if ref_file_path.exists():
                     self.total_samples += 1
                     yield ref_file_path, pred_file_path
@@ -96,7 +74,18 @@ class TestDatasetLoader:
             pred_text = file.read()
         with open(ref_file_path, "r", encoding="utf-8") as file:
             ref_text = file.read()
-            yield EvaluationTestSample(ref_text, pred_text)
+            yield EvaluationTestSample(reference_text=ref_text, predicted_text=pred_text)
+
+
+class EvaluationConfig(BaseModel):
+    """
+    Model for evaluation parameters
+    """
+
+    insertion_penalty: int = Field(default=1)
+    substitution_penalty: int = Field(default=1)
+    deletion_penalty: int = Field(default=1)
+    # default_factory avoids sharing one normalizer instance across configs
+    normalizer: Any = Field(default_factory=EnglishTextNormalizer)
+    test_directory: str = Field(default=str(Path(__file__).parent))
 
 
 class ModelEvaluator:
@@ -111,38 +100,29 @@ class ModelEvaluator:
     WEIGHTED_WER_JIWER = 0.0
     WER_JIWER = []
-    normalizer = None
-    accuracy = None
     test_dataset_loader = None
-    test_directory = None
-    evaluation_config = {}
+    evaluation_config = None
 
     def __init__(self, **kwargs):
-        self.evaluation_config = {k: v for k, v in kwargs.items() if v is not None}
-        if "normalizer" not in self.evaluation_config:
-            self.normalizer = EnglishTextNormalizer()
-            self.evaluation_config["normalizer"] = str(type(self.normalizer))
-        if "parent_dir" not in self.evaluation_config:
-            self.test_directory = Path(__file__).parent
-            self.test_dataset_loader = TestDatasetLoader(self.test_directory)
-            self.evaluation_config["test_directory"] = str(self.test_directory)
+        self.evaluation_config = EvaluationConfig(**kwargs)
+        self.test_dataset_loader = TestDatasetLoader(test_dir=self.evaluation_config.test_directory)
+        # Keep the result per instance; a class-level EvaluationResult() would
+        # be shared (and mutated) across every evaluator
+        self.evaluation_result = EvaluationResult()
 
     def __repr__(self):
-        return "ModelEvaluator(" + json.dumps(self.describe(), indent=4) + ")"
+        return f"ModelEvaluator({self.evaluation_config})"
 
     def describe(self) -> dict:
         """
         Returns the parameters defining the evaluator
         """
-        return self.evaluation_config
+        return self.evaluation_config.model_dump()
 
     def _normalize(self, sample: EvaluationTestSample) -> None:
         """
         Normalize both reference and predicted text
         """
         sample.update(
-            self.normalizer(sample.reference_text),
-            self.normalizer(sample.predicted_text),
+            self.evaluation_config.normalizer(sample.reference_text),
+            self.evaluation_config.normalizer(sample.predicted_text),
         )
 
     def _calculate_wer(self, sample: EvaluationTestSample) -> float:
@@ -154,9 +134,9 @@
             s1=sample.reference_text,
             s2=sample.predicted_text,
             weights=(
-                self.evaluation_config["insertion_penalty"],
-                self.evaluation_config["deletion_penalty"],
-                self.evaluation_config["substitution_penalty"],
+                self.evaluation_config.insertion_penalty,
+                self.evaluation_config.deletion_penalty,
+                self.evaluation_config.substitution_penalty,
             ),
         )
         wer = levenshtein_distance / len(sample.reference_text)
@@ -166,7 +146,7 @@
         """
         Compute WER
         """
-        for sample in tqdm(self.test_dataset_loader, desc="Evaluating", ncols=100):
+        for sample in tqdm(self.test_dataset_loader, desc="Evaluating"):
             self._normalize(sample)
             wer_item_l = {
                 "wer": self._calculate_wer(sample),
@@ -199,15 +179,18 @@
         weighted_wer_jiwer = self._calculate_weighted_wer(self.WER_JIWER)
 
         final_weighted_wer = (weighted_wer_levenshtein + weighted_wer_jiwer) / 2
-        self.accuracy = (1 - final_weighted_wer) * 100
+        self.evaluation_result.accuracy = (1 - final_weighted_wer) * 100
 
     def evaluate(self, recalculate: bool = False) -> EvaluationResult:
         """
         Triggers the model evaluation
        """
-        if not self.accuracy or recalculate:
+        if not self.evaluation_result.accuracy or recalculate:
             self._calculate_model_accuracy()
-        return EvaluationResult(self.accuracy, self.test_dataset_loader.total_samples)
+        return EvaluationResult(
+            accuracy=self.evaluation_result.accuracy,
+            total_test_samples=self.test_dataset_loader.total_samples,
+        )
 
 
 eval_config = {"insertion_penalty": 1, "deletion_penalty": 2, "substitution_penalty": 1}
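For reviewers, a minimal usage sketch of the refactored API (not part of the diff). It assumes the `predicted_texts/` and `reference_texts/` fixtures sit next to the module and that it is importable as `evaluate_transcription`; both are assumptions on my part. The `eval_config` dict mirrors the trailing context above.

```python
# Hypothetical driver: kwargs are now validated by EvaluationConfig instead of
# being copied into a plain dict, so bad values fail fast at construction time.
from evaluate_transcription import ModelEvaluator

eval_config = {"insertion_penalty": 1, "deletion_penalty": 2, "substitution_penalty": 1}
evaluator = ModelEvaluator(**eval_config)

print(evaluator.describe())    # model_dump() of the validated EvaluationConfig
result = evaluator.evaluate()  # EvaluationResult(accuracy=..., total_test_samples=...)
print(f"accuracy={result.accuracy:.2f}% over {result.total_test_samples} samples")
```

A nice side effect of the Pydantic move: `ModelEvaluator(insertion_penalty="high")` now raises a `ValidationError` at construction, where the old dict-based `__init__` would have stored the unusable value and failed later inside `Levenshtein.distance`.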
diff --git a/server/poetry.lock b/server/poetry.lock
index 4c309f6c..dc5cae28 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -2996,4 +2996,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "a5cd48fcfc629c2cd2f4fcc4263c57f867d84acf60824eaf952e365578374d1d"
+content-hash = "c9924049dacf7310590416f096f5b20f6ed905d8a50edf5e8afcf2c28b70799f"
diff --git a/server/pyproject.toml b/server/pyproject.toml
index 3da319bb..cdd510a0 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -51,6 +51,7 @@ aioboto3 = "^11.2.0"
 jiwer = "^3.0.2"
 levenshtein = "^0.21.1"
 tqdm = "^4.66.0"
+pydantic = "^2.1.1"
 
 [build-system]
 requires = ["poetry-core"]
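One note on the new dependency pin: `field_validator` and `model_dump` are Pydantic v2 APIs, so `pydantic = "^2.1.1"` is what keeps the refactor importable (v1 only offers `validator` and `.dict()`). A standalone sketch of the validation behavior `TestDatasetLoader.test_dir` gains, using a hypothetical `Loader` model:

```python
from pathlib import Path

from pydantic import BaseModel, Field, ValidationError, field_validator


class Loader(BaseModel):
    # Mirrors TestDatasetLoader.test_dir in the diff above
    test_dir: Path = Field(default=Path("."))

    @field_validator("test_dir")
    @classmethod
    def validate_file_path(cls, path: Path) -> Path:
        if not path.exists():
            raise ValueError(f"Path does not exist: {path}")
        return path


Loader(test_dir=Path("."))  # passes: the working directory exists

try:
    Loader(test_dir=Path("/no/such/dir"))
except ValidationError as err:  # Pydantic wraps the validator's ValueError
    print(err)
```

Worth knowing: v2 validators skip field defaults unless the field sets `validate_default=True`, so the `test_dir` default is only checked when a caller passes a value explicitly, which `ModelEvaluator.__init__` always does.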