import argparse
from typing import Dict, Tuple

import jiwer

from utils import preprocess_chinese_text, read_srt_text
from visualize import generate_html_report


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Calculate Character Error Rate (CER) for Chinese SRT files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-r",
        "--reference",
        required=True,
        help="Path to the reference (golden) SRT file",
    )
    parser.add_argument(
        "-i", "--input", required=True, help="Path to the input (hypothesis) SRT file"
    )
    parser.add_argument("-o", "--output", help="Path to save the results (optional)")
    parser.add_argument("--html", help="Path to save the HTML visualization (optional)")
    return parser.parse_args()


def calculate_cer_both_versions(
    reference_path: str, hypothesis_path: str
) -> Tuple[Dict, Dict]:
    """
    Calculate CER and related metrics between reference and hypothesis SRT files,
    both with and without punctuation.

    Args:
        reference_path (str): Path to the reference SRT file
        hypothesis_path (str): Path to the hypothesis SRT file

    Returns:
        Tuple[Dict, Dict]: Metrics without punctuation, then metrics with punctuation
    """
    reference_text = read_srt_text(reference_path)
    hypothesis_text = read_srt_text(hypothesis_path)
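
    # NOTE: preprocess_chinese_text is expected to return the text as
    # space-separated characters, so jiwer aligns single-character "words"
    # and the "wer" it reports is effectively a character error rate (CER).
    # jiwer.compute_measures is deprecated in newer jiwer releases in favor
    # of jiwer.process_words; this script assumes a version that still ships it.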

    # Metrics without punctuation.
    reference_chars_no_punct = preprocess_chinese_text(
        reference_text, include_punctuation=False
    )
    hypothesis_chars_no_punct = preprocess_chinese_text(
        hypothesis_text, include_punctuation=False
    )
    metrics_no_punct = jiwer.compute_measures(
        reference_chars_no_punct, hypothesis_chars_no_punct
    )
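    # The returned dict includes "wer", "hits", "substitutions", "deletions",
    # and "insertions"; format_metrics() below relies on exactly these keys.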

    # Metrics with punctuation included.
    reference_chars_with_punct = preprocess_chinese_text(
        reference_text, include_punctuation=True
    )
    hypothesis_chars_with_punct = preprocess_chinese_text(
        hypothesis_text, include_punctuation=True
    )
    metrics_with_punct = jiwer.compute_measures(
        reference_chars_with_punct, hypothesis_chars_with_punct
    )

    # Character totals exclude the separator spaces added by preprocessing.
    metrics_no_punct["total_ref_chars"] = len(reference_chars_no_punct.replace(" ", ""))
    metrics_no_punct["total_hyp_chars"] = len(
        hypothesis_chars_no_punct.replace(" ", "")
    )
    metrics_with_punct["total_ref_chars"] = len(
        reference_chars_with_punct.replace(" ", "")
    )
    metrics_with_punct["total_hyp_chars"] = len(
        hypothesis_chars_with_punct.replace(" ", "")
    )

    return metrics_no_punct, metrics_with_punct
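
# A minimal usage sketch (file names are illustrative):
#   no_punct, with_punct = calculate_cer_both_versions("golden.srt", "asr.srt")
#   print(no_punct["wer"])    # CER ignoring punctuation
#   print(with_punct["wer"])  # CER with punctuation counted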


def format_metrics(metrics: Dict, version: str) -> str:
    """
    Format metrics into a human-readable string.

    Args:
        metrics (Dict): Dictionary of metric values
        version (str): Label indicating which version of metrics these are

    Returns:
        str: Formatted metrics string
    """
    output = []
    output.append(f"\n=== {version} ===")
    output.append(f"Character Error Rate (CER): {metrics['wer']:.3f}")
    output.append(f"Total Reference Characters: {metrics['total_ref_chars']}")
    output.append(f"Total Hypothesis Characters: {metrics['total_hyp_chars']}")

    output.append("\nDetailed Statistics:")
    output.append(f"Correct Characters: {metrics['hits']}")
    output.append(f"Substitutions: {metrics['substitutions']}")
    output.append(f"Deletions: {metrics['deletions']}")
    output.append(f"Insertions: {metrics['insertions']}")

    total_errors = (
        metrics["substitutions"] + metrics["deletions"] + metrics["insertions"]
    )
    # Guard against an empty reference so the rates below cannot divide by zero.
    total_chars = max(metrics["total_ref_chars"], 1)

    output.append("\nError Analysis:")
    output.append(f"Total Errors: {total_errors}")
    output.append(f"Substitution Rate: {metrics['substitutions'] / total_chars:.3f}")
    output.append(f"Deletion Rate: {metrics['deletions'] / total_chars:.3f}")
    output.append(f"Insertion Rate: {metrics['insertions'] / total_chars:.3f}")

    return "\n".join(output)
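
# A quick sanity check with made-up numbers (illustrative only): 95 hits,
# 3 substitutions, 1 deletion, and 1 insertion over 99 reference characters
# give a CER of 5/99, about 0.051.
#   demo = {"wer": 5 / 99, "hits": 95, "substitutions": 3, "deletions": 1,
#           "insertions": 1, "total_ref_chars": 99, "total_hyp_chars": 99}
#   print(format_metrics(demo, "Demo"))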


if __name__ == "__main__":
    args = parse_arguments()

    try:
        # Raw subtitle text is needed again for the HTML visualization.
        reference_text = read_srt_text(args.reference)
        hypothesis_text = read_srt_text(args.input)

        # Compute both metric variants.
        metrics_no_punct, metrics_with_punct = calculate_cer_both_versions(
            args.reference, args.input
        )

        # Optional HTML report.
        if args.html:
            html_content = generate_html_report(
                reference_text, hypothesis_text, metrics_no_punct, metrics_with_punct
            )
            with open(args.html, "w", encoding="utf-8") as f:
                f.write(html_content)
            print(f"\nHTML visualization has been saved to: {args.html}")

        # Assemble the plain-text summary.
        output_text = []
        output_text.append(
            format_metrics(metrics_no_punct, "Metrics Without Punctuation")
        )
        output_text.append(
            format_metrics(metrics_with_punct, "Metrics With Punctuation")
        )
        output_text.append("\n=== Comparison ===")
        output_text.append(f"CER without punctuation: {metrics_no_punct['wer']:.3f}")
        output_text.append(f"CER with punctuation: {metrics_with_punct['wer']:.3f}")
        output_text.append(
            f"Difference: {abs(metrics_with_punct['wer'] - metrics_no_punct['wer']):.3f}"
        )

        final_output = "\n".join(output_text)
        print(final_output)

        if args.output:
            with open(args.output, "w", encoding="utf-8") as f:
                f.write(final_output)
            print(f"\nResults have been saved to: {args.output}")

    except FileNotFoundError as e:
        print(f"Error: Could not find one of the input files - {e}")
    except Exception as e:
        print(f"Error occurred: {e}")
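
# Example invocation (the script and file names here are illustrative):
#   python calculate_cer.py -r golden.srt -i transcript.srt -o results.txt --html report.html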