mirror of https://github.com/Monadical-SAS/reflector.git
resolve review comments
@@ -1,89 +1,67 @@
-import json
-import os
 import re
-from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Union
+from typing import Any, List

 from jiwer import wer
 from Levenshtein import distance
+from pydantic import BaseModel, Field, field_validator
 from tqdm.auto import tqdm
 from whisper.normalizers import EnglishTextNormalizer


-@dataclass
-class EvaluationResult:
+class EvaluationResult(BaseModel):
     """
     Result object of the model evaluation
     """
-    accuracy = float
-    total_test_samples = int
-
-    def __init__(self, accuracy, total_test_samples):
-        self.accuracy = accuracy
-        self.total_test_samples = total_test_samples
-
-    def __repr__(self):
-        return (
-            "EvaluationResult("
-            + json.dumps(
-                {
-                    "accuracy": self.accuracy,
-                    "total_test_samples": self.total_test_samples,
-                }
-            )
-            + ")"
-        )
+    accuracy: float = Field(default=0.0)
+    total_test_samples: int = Field(default=0)


-@dataclass
-class EvaluationTestSample:
+class EvaluationTestSample(BaseModel):
     """
     Represents one test sample
     """

-    reference_text = str
-    predicted_text = str
+    reference_text: str
+    predicted_text: str

-    def __init__(self, reference_text, predicted_text):
-        self.reference_text = reference_text
-        self.predicted_text = predicted_text
-
-    def update(self, reference_text, predicted_text):
+    def update(self, reference_text: str, predicted_text: str) -> None:
         self.reference_text = reference_text
         self.predicted_text = predicted_text


-class TestDatasetLoader:
+class TestDatasetLoader(BaseModel):
     """
     Test samples loader
     """

-    parent_dir = None
-    total_samples = 0
+    test_dir: Path = Field(default=Path(__file__).parent)
+    total_samples: int = Field(default=0)

-    def __init__(self, parent_dir: Union[Path | str]):
-        if isinstance(parent_dir, str):
-            self.parent_dir = Path(parent_dir)
-        else:
-            self.parent_dir = parent_dir
+    @field_validator("test_dir")
+    def validate_file_path(cls, path):
+        """
+        Check the file path
+        """
+        if not path.exists():
+            raise ValueError("Path does not exist")
+        return path

-    def _load_test_data(self) -> tuple[str, str]:
+    def _load_test_data(self) -> tuple[Path, Path]:
         """
         Loader function to validate input files and generate samples
         """
-        PREDICTED_TEST_SAMPLES_DIR = self.parent_dir / "predicted_texts"
-        REFERENCE_TEST_SAMPLES_DIR = self.parent_dir / "reference_texts"
+        PREDICTED_TEST_SAMPLES_DIR = self.test_dir / "predicted_texts"
+        REFERENCE_TEST_SAMPLES_DIR = self.test_dir / "reference_texts"

-        for filename in os.listdir(PREDICTED_TEST_SAMPLES_DIR.as_posix()):
-            match = re.search(r"(\d+)\.txt$", filename)
+        for filename in PREDICTED_TEST_SAMPLES_DIR.iterdir():
+            match = re.search(r"(\d+)\.txt$", filename.as_posix())
             if match:
                 sample_id = match.group(1)
-                pred_file_path = (PREDICTED_TEST_SAMPLES_DIR / filename).as_posix()
+                pred_file_path = PREDICTED_TEST_SAMPLES_DIR / filename
                 ref_file_name = "ref_sample_" + str(sample_id) + ".txt"
-                ref_file_path = (REFERENCE_TEST_SAMPLES_DIR / ref_file_name).as_posix()
-                if os.path.exists(ref_file_path):
+                ref_file_path = REFERENCE_TEST_SAMPLES_DIR / ref_file_name
+                if ref_file_path.exists():
                     self.total_samples += 1
                     yield ref_file_path, pred_file_path
@@ -96,7 +74,18 @@ class TestDatasetLoader:
                 pred_text = file.read()
             with open(ref_file_path, "r", encoding="utf-8") as file:
                 ref_text = file.read()
-            yield EvaluationTestSample(ref_text, pred_text)
+            yield EvaluationTestSample(reference_text=ref_text, predicted_text=pred_text)


+class EvaluationConfig(BaseModel):
+    """
+    Model for evaluation parameters
+    """
+    insertion_penalty: int = Field(default=1)
+    substitution_penalty: int = Field(default=1)
+    deletion_penalty: int = Field(default=1)
+    normalizer: Any = Field(default=EnglishTextNormalizer())
+    test_directory: str = Field(default=str(Path(__file__).parent))
+
+
 class ModelEvaluator:
@@ -111,38 +100,29 @@ class ModelEvaluator:
     WEIGHTED_WER_JIWER = 0.0
     WER_JIWER = []

-    normalizer = None
-    accuracy = None
+    evaluation_result = EvaluationResult()
     test_dataset_loader = None
-    test_directory = None
-    evaluation_config = {}
+    evaluation_config = None

     def __init__(self, **kwargs):
-        self.evaluation_config = {k: v for k, v in kwargs.items() if v is not None}
-        if "normalizer" not in self.evaluation_config:
-            self.normalizer = EnglishTextNormalizer()
-            self.evaluation_config["normalizer"] = str(type(self.normalizer))
-        if "parent_dir" not in self.evaluation_config:
-            self.test_directory = Path(__file__).parent
-            self.test_dataset_loader = TestDatasetLoader(self.test_directory)
-            self.evaluation_config["test_directory"] = str(self.test_directory)
+        self.evaluation_config = EvaluationConfig(**kwargs)
+        self.test_dataset_loader = TestDatasetLoader(test_dir=self.evaluation_config.test_directory)

     def __repr__(self):
-        return "ModelEvaluator(" + json.dumps(self.describe(), indent=4) + ")"
+        return f"ModelEvaluator({self.evaluation_config})"

     def describe(self) -> dict:
         """
         Returns the parameters defining the evaluator
         """
-        return self.evaluation_config
+        return self.evaluation_config.model_dump()

     def _normalize(self, sample: EvaluationTestSample) -> None:
         """
         Normalize both reference and predicted text
         """
         sample.update(
-            self.normalizer(sample.reference_text),
-            self.normalizer(sample.predicted_text),
+            self.evaluation_config.normalizer(sample.reference_text),
+            self.evaluation_config.normalizer(sample.predicted_text),
         )

     def _calculate_wer(self, sample: EvaluationTestSample) -> float:
@@ -154,9 +134,9 @@ class ModelEvaluator:
             s1=sample.reference_text,
             s2=sample.predicted_text,
             weights=(
-                self.evaluation_config["insertion_penalty"],
-                self.evaluation_config["deletion_penalty"],
-                self.evaluation_config["substitution_penalty"],
+                self.evaluation_config.insertion_penalty,
+                self.evaluation_config.deletion_penalty,
+                self.evaluation_config.substitution_penalty,
             ),
         )
         wer = levenshtein_distance / len(sample.reference_text)
@@ -166,7 +146,7 @@ class ModelEvaluator:
         """
         Compute WER
         """
-        for sample in tqdm(self.test_dataset_loader, desc="Evaluating", ncols=100):
+        for sample in tqdm(self.test_dataset_loader, desc="Evaluating"):
             self._normalize(sample)
             wer_item_l = {
                 "wer": self._calculate_wer(sample),
@@ -199,15 +179,18 @@ class ModelEvaluator:
         weighted_wer_jiwer = self._calculate_weighted_wer(self.WER_JIWER)

         final_weighted_wer = (weighted_wer_levenshtein + weighted_wer_jiwer) / 2
-        self.accuracy = (1 - final_weighted_wer) * 100
+        self.evaluation_result.accuracy = (1 - final_weighted_wer) * 100

     def evaluate(self, recalculate: bool = False) -> EvaluationResult:
         """
         Triggers the model evaluation
         """
-        if not self.accuracy or recalculate:
+        if not self.evaluation_result.accuracy or recalculate:
             self._calculate_model_accuracy()
-        return EvaluationResult(self.accuracy, self.test_dataset_loader.total_samples)
+        return EvaluationResult(
+            accuracy=self.evaluation_result.accuracy,
+            total_test_samples=self.test_dataset_loader.total_samples
+        )


 eval_config = {"insertion_penalty": 1, "deletion_penalty": 2, "substitution_penalty": 1}
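The hunk above stops at the module-level eval_config dict, so the rest of the file is not shown in this diff. As a rough usage sketch of how the refactored pieces fit together (assuming the classes from the diff are in scope and that TestDatasetLoader yields EvaluationTestSample objects when iterated, as the tqdm loop implies), it would look roughly like:

# Hypothetical usage, not part of the commit; names come from the diff above.
eval_config = {"insertion_penalty": 1, "deletion_penalty": 2, "substitution_penalty": 1}

# ModelEvaluator builds an EvaluationConfig from the kwargs; anything not passed
# (normalizer, test_directory) falls back to the pydantic Field defaults.
evaluator = ModelEvaluator(**eval_config)

# evaluate() normalizes each sample, averages the Levenshtein- and jiwer-based
# weighted WER, and returns an EvaluationResult with accuracy and sample count.
result = evaluator.evaluate()
print(result.accuracy, result.total_test_samples)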
server/poetry.lock (generated, 2 changes)
@@ -2996,4 +2996,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "a5cd48fcfc629c2cd2f4fcc4263c57f867d84acf60824eaf952e365578374d1d"
+content-hash = "c9924049dacf7310590416f096f5b20f6ed905d8a50edf5e8afcf2c28b70799f"
@@ -51,6 +51,7 @@ aioboto3 = "^11.2.0"
 jiwer = "^3.0.2"
 levenshtein = "^0.21.1"
 tqdm = "^4.66.0"
+pydantic = "^2.1.1"

 [build-system]
 requires = ["poetry-core"]