resolve review comments

Gokul Mohanarangan
2023-08-10 14:33:46 +05:30
parent bb983194f8
commit af954e2818
3 changed files with 60 additions and 76 deletions


@@ -1,89 +1,67 @@
-import json
-import os
 import re
-from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Union
+from typing import Any, List
 from jiwer import wer
 from Levenshtein import distance
+from pydantic import BaseModel, Field, field_validator
 from tqdm.auto import tqdm
 from whisper.normalizers import EnglishTextNormalizer
-@dataclass
-class EvaluationResult:
+class EvaluationResult(BaseModel):
     """
     Result object of the model evaluation
     """
-    accuracy = float
-    total_test_samples = int
-    def __init__(self, accuracy, total_test_samples):
-        self.accuracy = accuracy
-        self.total_test_samples = total_test_samples
-    def __repr__(self):
-        return (
-            "EvaluationResult("
-            + json.dumps(
-                {
-                    "accuracy": self.accuracy,
-                    "total_test_samples": self.total_test_samples,
-                }
-            )
-            + ")"
-        )
+    accuracy: float = Field(default=0.0)
+    total_test_samples: int = Field(default=0)
-@dataclass
-class EvaluationTestSample:
+class EvaluationTestSample(BaseModel):
     """
     Represents one test sample
    """
-    reference_text = str
-    predicted_text = str
+    reference_text: str
+    predicted_text: str
-    def __init__(self, reference_text, predicted_text):
-        self.reference_text = reference_text
-        self.predicted_text = predicted_text
-    def update(self, reference_text, predicted_text):
+    def update(self, reference_text:str, predicted_text:str) -> None:
         self.reference_text = reference_text
         self.predicted_text = predicted_text
-class TestDatasetLoader:
+class TestDatasetLoader(BaseModel):
    """
    Test samples loader
    """
-    parent_dir = None
-    total_samples = 0
+    test_dir: Path = Field(default=Path(__file__).parent)
+    total_samples: int = Field(default=0)
-    def __init__(self, parent_dir: Union[Path | str]):
-        if isinstance(parent_dir, str):
-            self.parent_dir = Path(parent_dir)
-        else:
-            self.parent_dir = parent_dir
+    @field_validator("test_dir")
+    def validate_file_path(cls, path):
+        """
+        Check the file path
+        """
+        if not path.exists():
+            raise ValueError("Path does not exist")
+        return path
-    def _load_test_data(self) -> tuple[str, str]:
+    def _load_test_data(self) -> tuple[Path, Path]:
         """
         Loader function to validate inout files and generate samples
         """
-        PREDICTED_TEST_SAMPLES_DIR = self.parent_dir / "predicted_texts"
-        REFERENCE_TEST_SAMPLES_DIR = self.parent_dir / "reference_texts"
+        PREDICTED_TEST_SAMPLES_DIR = self.test_dir / "predicted_texts"
+        REFERENCE_TEST_SAMPLES_DIR = self.test_dir / "reference_texts"
-        for filename in os.listdir(PREDICTED_TEST_SAMPLES_DIR.as_posix()):
-            match = re.search(r"(\d+)\.txt$", filename)
+        for filename in PREDICTED_TEST_SAMPLES_DIR.iterdir():
+            match = re.search(r"(\d+)\.txt$", filename.as_posix())
             if match:
                 sample_id = match.group(1)
-                pred_file_path = (PREDICTED_TEST_SAMPLES_DIR / filename).as_posix()
+                pred_file_path = PREDICTED_TEST_SAMPLES_DIR / filename
                 ref_file_name = "ref_sample_" + str(sample_id) + ".txt"
-                ref_file_path = (REFERENCE_TEST_SAMPLES_DIR / ref_file_name).as_posix()
-                if os.path.exists(ref_file_path):
+                ref_file_path = REFERENCE_TEST_SAMPLES_DIR / ref_file_name
+                if ref_file_path.exists():
                     self.total_samples += 1
                     yield ref_file_path, pred_file_path
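
Since TestDatasetLoader is now a pydantic model, it is constructed with keyword arguments and the test directory is checked up front by the field_validator. A minimal usage sketch under those assumptions; the module name evaluator below is a placeholder, not the real file name in this repo:

    from pathlib import Path
    from evaluator import TestDatasetLoader  # hypothetical module name

    # keyword-only construction, unlike the old positional parent_dir
    loader = TestDatasetLoader(test_dir=Path("tests/data"))
    # A non-existent path is rejected at construction time: the
    # "Path does not exist" ValueError surfaces as a pydantic ValidationError.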
@@ -96,7 +74,18 @@ class TestDatasetLoader:
                 pred_text = file.read()
             with open(ref_file_path, "r", encoding="utf-8") as file:
                 ref_text = file.read()
-            yield EvaluationTestSample(ref_text, pred_text)
+            yield EvaluationTestSample(reference_text=ref_text, predicted_text=pred_text)
+class EvaluationConfig(BaseModel):
+    """
+    Model for evaluation parameters
+    """
+    insertion_penalty: int = Field(default=1)
+    substitution_penalty: int = Field(default=1)
+    deletion_penalty: int = Field(default=1)
+    normalizer: Any = Field(default=EnglishTextNormalizer())
+    test_directory: str = Field(default=str(Path(__file__).parent))
 class ModelEvaluator:
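
The new EvaluationConfig centralises what was previously an ad-hoc kwargs dict. A small illustrative sketch of its behaviour given only the fields and defaults shown above:

    config = EvaluationConfig(insertion_penalty=1, deletion_penalty=2)
    config.substitution_penalty   # 1, from the Field default
    config.normalizer             # the default EnglishTextNormalizer instance
    config.model_dump()           # plain dict of all five fields, as returned by describe()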
@@ -111,38 +100,29 @@ class ModelEvaluator:
     WEIGHTED_WER_JIWER = 0.0
     WER_JIWER = []
-    normalizer = None
-    accuracy = None
+    evaluation_result = EvaluationResult()
     test_dataset_loader = None
-    test_directory = None
-    evaluation_config = {}
+    evaluation_config = None
     def __init__(self, **kwargs):
-        self.evaluation_config = {k: v for k, v in kwargs.items() if v is not None}
-        if "normalizer" not in self.evaluation_config:
-            self.normalizer = EnglishTextNormalizer()
-            self.evaluation_config["normalizer"] = str(type(self.normalizer))
-        if "parent_dir" not in self.evaluation_config:
-            self.test_directory = Path(__file__).parent
-        self.test_dataset_loader = TestDatasetLoader(self.test_directory)
-        self.evaluation_config["test_directory"] = str(self.test_directory)
+        self.evaluation_config = EvaluationConfig(**kwargs)
+        self.test_dataset_loader = TestDatasetLoader(test_dir=self.evaluation_config.test_directory)
     def __repr__(self):
-        return "ModelEvaluator(" + json.dumps(self.describe(), indent=4) + ")"
+        return f"ModelEvaluator({self.evaluation_config})"
     def describe(self) -> dict:
         """
         Returns the parameters defining the evaluator
         """
-        return self.evaluation_config
+        return self.evaluation_config.model_dump()
     def _normalize(self, sample: EvaluationTestSample) -> None:
         """
         Normalize both reference and predicted text
         """
         sample.update(
-            self.normalizer(sample.reference_text),
-            self.normalizer(sample.predicted_text),
+            self.evaluation_config.normalizer(sample.reference_text),
+            self.evaluation_config.normalizer(sample.predicted_text),
         )
     def _calculate_wer(self, sample: EvaluationTestSample) -> float:
@@ -154,9 +134,9 @@ class ModelEvaluator:
             s1=sample.reference_text,
             s2=sample.predicted_text,
             weights=(
-                self.evaluation_config["insertion_penalty"],
-                self.evaluation_config["deletion_penalty"],
-                self.evaluation_config["substitution_penalty"],
+                self.evaluation_config.insertion_penalty,
+                self.evaluation_config.deletion_penalty,
+                self.evaluation_config.substitution_penalty,
             ),
         )
         wer = levenshtein_distance / len(sample.reference_text)
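
For reference, the weights tuple is passed to Levenshtein.distance as (insertion, deletion, substitution) costs, in the same order as the config fields above. A small illustrative example of how the penalties change the distance that _calculate_wer then divides by the reference length:

    from Levenshtein import distance

    distance("kitten", "sitting")                     # 3 (2 substitutions + 1 insertion)
    distance("kitten", "sitting", weights=(1, 1, 2))  # 5, substitutions now cost 2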
@@ -166,7 +146,7 @@ class ModelEvaluator:
         """
         Compute WER
         """
-        for sample in tqdm(self.test_dataset_loader, desc="Evaluating", ncols=100):
+        for sample in tqdm(self.test_dataset_loader, desc="Evaluating"):
             self._normalize(sample)
             wer_item_l = {
                 "wer": self._calculate_wer(sample),
@@ -199,15 +179,18 @@ class ModelEvaluator:
         weighted_wer_jiwer = self._calculate_weighted_wer(self.WER_JIWER)
         final_weighted_wer = (weighted_wer_levenshtein + weighted_wer_jiwer) / 2
-        self.accuracy = (1 - final_weighted_wer) * 100
+        self.evaluation_result.accuracy = (1 - final_weighted_wer) * 100
     def evaluate(self, recalculate: bool = False) -> EvaluationResult:
         """
         Triggers the model evaluation
         """
-        if not self.accuracy or recalculate:
+        if not self.evaluation_result.accuracy or recalculate:
             self._calculate_model_accuracy()
-        return EvaluationResult(self.accuracy, self.test_dataset_loader.total_samples)
+        return EvaluationResult(
+            accuracy=self.evaluation_result.accuracy,
+            total_test_samples=self.test_dataset_loader.total_samples
+        )
 eval_config = {"insertion_penalty": 1, "deletion_penalty": 2, "substitution_penalty": 1}
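
A brief end-to-end sketch of how the refactored evaluator might be driven with the eval_config dict above: the kwargs now flow through EvaluationConfig, and evaluate() returns a pydantic EvaluationResult (the print call is illustrative):

    evaluator = ModelEvaluator(**eval_config)
    result = evaluator.evaluate()
    print(result.accuracy, result.total_test_samples)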

server/poetry.lock (generated)

@@ -2996,4 +2996,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "a5cd48fcfc629c2cd2f4fcc4263c57f867d84acf60824eaf952e365578374d1d"
+content-hash = "c9924049dacf7310590416f096f5b20f6ed905d8a50edf5e8afcf2c28b70799f"


@@ -51,6 +51,7 @@ aioboto3 = "^11.2.0"
 jiwer = "^3.0.2"
 levenshtein = "^0.21.1"
 tqdm = "^4.66.0"
+pydantic = "^2.1.1"
 [build-system]
 requires = ["poetry-core"]