mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
resolve review comments
This commit is contained in:
@@ -1,89 +1,67 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
from typing import Any, List
|
||||
|
||||
from jiwer import wer
|
||||
from Levenshtein import distance
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from tqdm.auto import tqdm
|
||||
from whisper.normalizers import EnglishTextNormalizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationResult:
|
||||
class EvaluationResult(BaseModel):
|
||||
"""
|
||||
Result object of the model evaluation
|
||||
"""
|
||||
|
||||
accuracy = float
|
||||
total_test_samples = int
|
||||
|
||||
def __init__(self, accuracy, total_test_samples):
|
||||
self.accuracy = accuracy
|
||||
self.total_test_samples = total_test_samples
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"EvaluationResult("
|
||||
+ json.dumps(
|
||||
{
|
||||
"accuracy": self.accuracy,
|
||||
"total_test_samples": self.total_test_samples,
|
||||
}
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
accuracy: float = Field(default=0.0)
|
||||
total_test_samples: int = Field(default=0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationTestSample:
|
||||
class EvaluationTestSample(BaseModel):
|
||||
"""
|
||||
Represents one test sample
|
||||
"""
|
||||
|
||||
reference_text = str
|
||||
predicted_text = str
|
||||
reference_text: str
|
||||
predicted_text: str
|
||||
|
||||
def __init__(self, reference_text, predicted_text):
|
||||
self.reference_text = reference_text
|
||||
self.predicted_text = predicted_text
|
||||
|
||||
def update(self, reference_text, predicted_text):
|
||||
def update(self, reference_text:str, predicted_text:str) -> None:
|
||||
self.reference_text = reference_text
|
||||
self.predicted_text = predicted_text
|
||||
|
||||
|
||||
class TestDatasetLoader:
|
||||
class TestDatasetLoader(BaseModel):
|
||||
"""
|
||||
Test samples loader
|
||||
"""
|
||||
|
||||
parent_dir = None
|
||||
total_samples = 0
|
||||
test_dir: Path = Field(default=Path(__file__).parent)
|
||||
total_samples: int = Field(default=0)
|
||||
|
||||
def __init__(self, parent_dir: Union[Path | str]):
|
||||
if isinstance(parent_dir, str):
|
||||
self.parent_dir = Path(parent_dir)
|
||||
else:
|
||||
self.parent_dir = parent_dir
|
||||
@field_validator("test_dir")
|
||||
def validate_file_path(cls, path):
|
||||
"""
|
||||
Check the file path
|
||||
"""
|
||||
if not path.exists():
|
||||
raise ValueError("Path does not exist")
|
||||
return path
|
||||
|
||||
def _load_test_data(self) -> tuple[str, str]:
|
||||
def _load_test_data(self) -> tuple[Path, Path]:
|
||||
"""
|
||||
Loader function to validate inout files and generate samples
|
||||
"""
|
||||
PREDICTED_TEST_SAMPLES_DIR = self.parent_dir / "predicted_texts"
|
||||
REFERENCE_TEST_SAMPLES_DIR = self.parent_dir / "reference_texts"
|
||||
PREDICTED_TEST_SAMPLES_DIR = self.test_dir / "predicted_texts"
|
||||
REFERENCE_TEST_SAMPLES_DIR = self.test_dir / "reference_texts"
|
||||
|
||||
for filename in os.listdir(PREDICTED_TEST_SAMPLES_DIR.as_posix()):
|
||||
match = re.search(r"(\d+)\.txt$", filename)
|
||||
for filename in PREDICTED_TEST_SAMPLES_DIR.iterdir():
|
||||
match = re.search(r"(\d+)\.txt$", filename.as_posix())
|
||||
if match:
|
||||
sample_id = match.group(1)
|
||||
pred_file_path = (PREDICTED_TEST_SAMPLES_DIR / filename).as_posix()
|
||||
pred_file_path = PREDICTED_TEST_SAMPLES_DIR / filename
|
||||
ref_file_name = "ref_sample_" + str(sample_id) + ".txt"
|
||||
ref_file_path = (REFERENCE_TEST_SAMPLES_DIR / ref_file_name).as_posix()
|
||||
if os.path.exists(ref_file_path):
|
||||
ref_file_path = REFERENCE_TEST_SAMPLES_DIR / ref_file_name
|
||||
if ref_file_path.exists():
|
||||
self.total_samples += 1
|
||||
yield ref_file_path, pred_file_path
|
||||
|
||||
@@ -96,7 +74,18 @@ class TestDatasetLoader:
|
||||
pred_text = file.read()
|
||||
with open(ref_file_path, "r", encoding="utf-8") as file:
|
||||
ref_text = file.read()
|
||||
yield EvaluationTestSample(ref_text, pred_text)
|
||||
yield EvaluationTestSample(reference_text=ref_text, predicted_text=pred_text)
|
||||
|
||||
|
||||
class EvaluationConfig(BaseModel):
|
||||
"""
|
||||
Model for evaluation parameters
|
||||
"""
|
||||
insertion_penalty: int = Field(default=1)
|
||||
substitution_penalty: int = Field(default=1)
|
||||
deletion_penalty: int = Field(default=1)
|
||||
normalizer: Any = Field(default=EnglishTextNormalizer())
|
||||
test_directory: str = Field(default=str(Path(__file__).parent))
|
||||
|
||||
|
||||
class ModelEvaluator:
|
||||
@@ -111,38 +100,29 @@ class ModelEvaluator:
|
||||
WEIGHTED_WER_JIWER = 0.0
|
||||
WER_JIWER = []
|
||||
|
||||
normalizer = None
|
||||
accuracy = None
|
||||
evaluation_result = EvaluationResult()
|
||||
test_dataset_loader = None
|
||||
test_directory = None
|
||||
evaluation_config = {}
|
||||
evaluation_config = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.evaluation_config = {k: v for k, v in kwargs.items() if v is not None}
|
||||
if "normalizer" not in self.evaluation_config:
|
||||
self.normalizer = EnglishTextNormalizer()
|
||||
self.evaluation_config["normalizer"] = str(type(self.normalizer))
|
||||
if "parent_dir" not in self.evaluation_config:
|
||||
self.test_directory = Path(__file__).parent
|
||||
self.test_dataset_loader = TestDatasetLoader(self.test_directory)
|
||||
self.evaluation_config["test_directory"] = str(self.test_directory)
|
||||
self.evaluation_config = EvaluationConfig(**kwargs)
|
||||
self.test_dataset_loader = TestDatasetLoader(test_dir=self.evaluation_config.test_directory)
|
||||
|
||||
def __repr__(self):
|
||||
return "ModelEvaluator(" + json.dumps(self.describe(), indent=4) + ")"
|
||||
return f"ModelEvaluator({self.evaluation_config})"
|
||||
|
||||
def describe(self) -> dict:
|
||||
"""
|
||||
Returns the parameters defining the evaluator
|
||||
"""
|
||||
return self.evaluation_config
|
||||
|
||||
return self.evaluation_config.model_dump()
|
||||
def _normalize(self, sample: EvaluationTestSample) -> None:
|
||||
"""
|
||||
Normalize both reference and predicted text
|
||||
"""
|
||||
sample.update(
|
||||
self.normalizer(sample.reference_text),
|
||||
self.normalizer(sample.predicted_text),
|
||||
self.evaluation_config.normalizer(sample.reference_text),
|
||||
self.evaluation_config.normalizer(sample.predicted_text),
|
||||
)
|
||||
|
||||
def _calculate_wer(self, sample: EvaluationTestSample) -> float:
|
||||
@@ -154,9 +134,9 @@ class ModelEvaluator:
|
||||
s1=sample.reference_text,
|
||||
s2=sample.predicted_text,
|
||||
weights=(
|
||||
self.evaluation_config["insertion_penalty"],
|
||||
self.evaluation_config["deletion_penalty"],
|
||||
self.evaluation_config["substitution_penalty"],
|
||||
self.evaluation_config.insertion_penalty,
|
||||
self.evaluation_config.deletion_penalty,
|
||||
self.evaluation_config.substitution_penalty,
|
||||
),
|
||||
)
|
||||
wer = levenshtein_distance / len(sample.reference_text)
|
||||
@@ -166,7 +146,7 @@ class ModelEvaluator:
|
||||
"""
|
||||
Compute WER
|
||||
"""
|
||||
for sample in tqdm(self.test_dataset_loader, desc="Evaluating", ncols=100):
|
||||
for sample in tqdm(self.test_dataset_loader, desc="Evaluating"):
|
||||
self._normalize(sample)
|
||||
wer_item_l = {
|
||||
"wer": self._calculate_wer(sample),
|
||||
@@ -199,15 +179,18 @@ class ModelEvaluator:
|
||||
weighted_wer_jiwer = self._calculate_weighted_wer(self.WER_JIWER)
|
||||
|
||||
final_weighted_wer = (weighted_wer_levenshtein + weighted_wer_jiwer) / 2
|
||||
self.accuracy = (1 - final_weighted_wer) * 100
|
||||
self.evaluation_result.accuracy = (1 - final_weighted_wer) * 100
|
||||
|
||||
def evaluate(self, recalculate: bool = False) -> EvaluationResult:
|
||||
"""
|
||||
Triggers the model evaluation
|
||||
"""
|
||||
if not self.accuracy or recalculate:
|
||||
if not self.evaluation_result.accuracy or recalculate:
|
||||
self._calculate_model_accuracy()
|
||||
return EvaluationResult(self.accuracy, self.test_dataset_loader.total_samples)
|
||||
return EvaluationResult(
|
||||
accuracy=self.evaluation_result.accuracy,
|
||||
total_test_samples=self.test_dataset_loader.total_samples
|
||||
)
|
||||
|
||||
|
||||
eval_config = {"insertion_penalty": 1, "deletion_penalty": 2, "substitution_penalty": 1}
|
||||
|
||||
2
server/poetry.lock
generated
2
server/poetry.lock
generated
@@ -2996,4 +2996,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "a5cd48fcfc629c2cd2f4fcc4263c57f867d84acf60824eaf952e365578374d1d"
|
||||
content-hash = "c9924049dacf7310590416f096f5b20f6ed905d8a50edf5e8afcf2c28b70799f"
|
||||
|
||||
@@ -51,6 +51,7 @@ aioboto3 = "^11.2.0"
|
||||
jiwer = "^3.0.2"
|
||||
levenshtein = "^0.21.1"
|
||||
tqdm = "^4.66.0"
|
||||
pydantic = "^2.1.1"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
Reference in New Issue
Block a user