resolve review comments

Gokul Mohanarangan
2023-08-10 14:33:46 +05:30
parent bb983194f8
commit af954e2818
3 changed files with 60 additions and 76 deletions

View File

@@ -1,89 +1,67 @@
-import json
-import os
 import re
-from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Union
+from typing import Any, List

 from jiwer import wer
 from Levenshtein import distance
+from pydantic import BaseModel, Field, field_validator
 from tqdm.auto import tqdm
 from whisper.normalizers import EnglishTextNormalizer


-@dataclass
-class EvaluationResult:
+class EvaluationResult(BaseModel):
     """
     Result object of the model evaluation
     """

-    accuracy = float
-    total_test_samples = int
-
-    def __init__(self, accuracy, total_test_samples):
-        self.accuracy = accuracy
-        self.total_test_samples = total_test_samples
-
-    def __repr__(self):
-        return (
-            "EvaluationResult("
-            + json.dumps(
-                {
-                    "accuracy": self.accuracy,
-                    "total_test_samples": self.total_test_samples,
-                }
-            )
-            + ")"
-        )
+    accuracy: float = Field(default=0.0)
+    total_test_samples: int = Field(default=0)


-@dataclass
-class EvaluationTestSample:
+class EvaluationTestSample(BaseModel):
     """
     Represents one test sample
     """

-    reference_text = str
-    predicted_text = str
-
-    def __init__(self, reference_text, predicted_text):
-        self.reference_text = reference_text
-        self.predicted_text = predicted_text
-
-    def update(self, reference_text, predicted_text):
+    reference_text: str
+    predicted_text: str
+
+    def update(self, reference_text: str, predicted_text: str) -> None:
         self.reference_text = reference_text
         self.predicted_text = predicted_text


-class TestDatasetLoader:
+class TestDatasetLoader(BaseModel):
     """
     Test samples loader
     """

-    parent_dir = None
-    total_samples = 0
-
-    def __init__(self, parent_dir: Union[Path | str]):
-        if isinstance(parent_dir, str):
-            self.parent_dir = Path(parent_dir)
-        else:
-            self.parent_dir = parent_dir
+    test_dir: Path = Field(default=Path(__file__).parent)
+    total_samples: int = Field(default=0)
+
+    @field_validator("test_dir")
+    def validate_file_path(cls, path):
+        """
+        Check the file path
+        """
+        if not path.exists():
+            raise ValueError("Path does not exist")
+        return path

-    def _load_test_data(self) -> tuple[str, str]:
+    def _load_test_data(self) -> tuple[Path, Path]:
         """
         Loader function to validate inout files and generate samples
         """
-        PREDICTED_TEST_SAMPLES_DIR = self.parent_dir / "predicted_texts"
-        REFERENCE_TEST_SAMPLES_DIR = self.parent_dir / "reference_texts"
-        for filename in os.listdir(PREDICTED_TEST_SAMPLES_DIR.as_posix()):
-            match = re.search(r"(\d+)\.txt$", filename)
+        PREDICTED_TEST_SAMPLES_DIR = self.test_dir / "predicted_texts"
+        REFERENCE_TEST_SAMPLES_DIR = self.test_dir / "reference_texts"
+        for filename in PREDICTED_TEST_SAMPLES_DIR.iterdir():
+            match = re.search(r"(\d+)\.txt$", filename.as_posix())
             if match:
                 sample_id = match.group(1)
-                pred_file_path = (PREDICTED_TEST_SAMPLES_DIR / filename).as_posix()
+                pred_file_path = PREDICTED_TEST_SAMPLES_DIR / filename
                 ref_file_name = "ref_sample_" + str(sample_id) + ".txt"
-                ref_file_path = (REFERENCE_TEST_SAMPLES_DIR / ref_file_name).as_posix()
-                if os.path.exists(ref_file_path):
+                ref_file_path = REFERENCE_TEST_SAMPLES_DIR / ref_file_name
+                if ref_file_path.exists():
                     self.total_samples += 1
                     yield ref_file_path, pred_file_path
@@ -96,7 +74,18 @@ class TestDatasetLoader:
             pred_text = file.read()
         with open(ref_file_path, "r", encoding="utf-8") as file:
             ref_text = file.read()
-        yield EvaluationTestSample(ref_text, pred_text)
+        yield EvaluationTestSample(reference_text=ref_text, predicted_text=pred_text)
+
+
+class EvaluationConfig(BaseModel):
+    """
+    Model for evaluation parameters
+    """
+
+    insertion_penalty: int = Field(default=1)
+    substitution_penalty: int = Field(default=1)
+    deletion_penalty: int = Field(default=1)
+    normalizer: Any = Field(default=EnglishTextNormalizer())
+    test_directory: str = Field(default=str(Path(__file__).parent))


 class ModelEvaluator:
@@ -111,38 +100,29 @@ class ModelEvaluator:
     WEIGHTED_WER_JIWER = 0.0
     WER_JIWER = []
-    normalizer = None
-    accuracy = None
+    evaluation_result = EvaluationResult()
     test_dataset_loader = None
-    test_directory = None
-    evaluation_config = {}
+    evaluation_config = None

     def __init__(self, **kwargs):
-        self.evaluation_config = {k: v for k, v in kwargs.items() if v is not None}
-        if "normalizer" not in self.evaluation_config:
-            self.normalizer = EnglishTextNormalizer()
-            self.evaluation_config["normalizer"] = str(type(self.normalizer))
-        if "parent_dir" not in self.evaluation_config:
-            self.test_directory = Path(__file__).parent
-            self.test_dataset_loader = TestDatasetLoader(self.test_directory)
-            self.evaluation_config["test_directory"] = str(self.test_directory)
+        self.evaluation_config = EvaluationConfig(**kwargs)
+        self.test_dataset_loader = TestDatasetLoader(test_dir=self.evaluation_config.test_directory)

     def __repr__(self):
-        return "ModelEvaluator(" + json.dumps(self.describe(), indent=4) + ")"
+        return f"ModelEvaluator({self.evaluation_config})"

     def describe(self) -> dict:
         """
         Returns the parameters defining the evaluator
         """
-        return self.evaluation_config
+        return self.evaluation_config.model_dump()

     def _normalize(self, sample: EvaluationTestSample) -> None:
         """
         Normalize both reference and predicted text
         """
         sample.update(
-            self.normalizer(sample.reference_text),
-            self.normalizer(sample.predicted_text),
+            self.evaluation_config.normalizer(sample.reference_text),
+            self.evaluation_config.normalizer(sample.predicted_text),
         )

     def _calculate_wer(self, sample: EvaluationTestSample) -> float:
@@ -154,9 +134,9 @@ class ModelEvaluator:
             s1=sample.reference_text,
             s2=sample.predicted_text,
             weights=(
-                self.evaluation_config["insertion_penalty"],
-                self.evaluation_config["deletion_penalty"],
-                self.evaluation_config["substitution_penalty"],
+                self.evaluation_config.insertion_penalty,
+                self.evaluation_config.deletion_penalty,
+                self.evaluation_config.substitution_penalty,
             ),
         )
         wer = levenshtein_distance / len(sample.reference_text)
@@ -166,7 +146,7 @@ class ModelEvaluator:
         """
         Compute WER
         """
-        for sample in tqdm(self.test_dataset_loader, desc="Evaluating", ncols=100):
+        for sample in tqdm(self.test_dataset_loader, desc="Evaluating"):
             self._normalize(sample)
             wer_item_l = {
                 "wer": self._calculate_wer(sample),
@@ -199,15 +179,18 @@ class ModelEvaluator:
         weighted_wer_jiwer = self._calculate_weighted_wer(self.WER_JIWER)
         final_weighted_wer = (weighted_wer_levenshtein + weighted_wer_jiwer) / 2
-        self.accuracy = (1 - final_weighted_wer) * 100
+        self.evaluation_result.accuracy = (1 - final_weighted_wer) * 100

     def evaluate(self, recalculate: bool = False) -> EvaluationResult:
         """
         Triggers the model evaluation
         """
-        if not self.accuracy or recalculate:
+        if not self.evaluation_result.accuracy or recalculate:
             self._calculate_model_accuracy()
-        return EvaluationResult(self.accuracy, self.test_dataset_loader.total_samples)
+        return EvaluationResult(
+            accuracy=self.evaluation_result.accuracy,
+            total_test_samples=self.test_dataset_loader.total_samples
+        )

 eval_config = {"insertion_penalty": 1, "deletion_penalty": 2, "substitution_penalty": 1}
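
For reference, a minimal usage sketch of the refactored evaluator. The module name model_evaluator and the directory layout are assumptions for illustration, not part of this commit; keyword arguments are now validated through EvaluationConfig, and evaluate() returns a pydantic EvaluationResult.

    # Hypothetical usage; assumes the file above is importable as model_evaluator
    # and that predicted_texts/ and reference_texts/ sit next to it
    # (the default test_directory).
    from model_evaluator import ModelEvaluator, eval_config

    evaluator = ModelEvaluator(**eval_config)  # penalties validated by EvaluationConfig
    result = evaluator.evaluate()              # EvaluationResult(accuracy=..., total_test_samples=...)
    print(result.accuracy, result.total_test_samples)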

server/poetry.lock generated
View File

@@ -2996,4 +2996,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "a5cd48fcfc629c2cd2f4fcc4263c57f867d84acf60824eaf952e365578374d1d"
+content-hash = "c9924049dacf7310590416f096f5b20f6ed905d8a50edf5e8afcf2c28b70799f"

View File

@@ -51,6 +51,7 @@ aioboto3 = "^11.2.0"
 jiwer = "^3.0.2"
 levenshtein = "^0.21.1"
 tqdm = "^4.66.0"
+pydantic = "^2.1.1"

 [build-system]
 requires = ["poetry-core"]