| """ |
| Task 1: Count - Generate counting questions |
| |
| This task joins multiple audio sources and asks questions about counting |
| the number of unique sound sources in the audio. |
| """ |
|
|
| import csv |
| import random |
| from pathlib import Path |
| from typing import Dict, List |
|
|
| import sys |
| sys.path.append(str(Path(__file__).parent.parent)) |
|
|
| from utils import ( |
| AudioProcessor, ESC50Dataset, QuestionGenerator, LLMQuestionGenerator, |
| setup_logger, set_random_seed, generate_sample_durations_for_task, |
| generate_single_clip_duration, build_count_task_audio, |
| get_max_clip_num_to_be_joined |
| ) |
|
|
|
|
class CountTaskGenerator:
    """Generator for the counting task dataset.

    Each sample joins clips drawn from several sound categories into one
    audio file and pairs it with an MCQ and an open-text question whose
    answer is the number of unique sound sources in that file.
    """

    def __init__(self, config: Dict, logger):
        """
        Initialize count task generator.

        Builds the dataset, audio processor and question generators from
        ``config`` and creates the output directories
        ``<output.base_path>/count`` and ``<output.base_path>/count/audios``.

        Args:
            config: Configuration dictionary
            logger: Logger instance
        """
        self.config = config
        self.logger = logger
        self.task_config = config['tasks']['count']

        # Core collaborators, all configured from the shared config dict.
        self.dataset = ESC50Dataset(
            config['esc50']['metadata_path'],
            config['esc50']['audio_path'],
            config
        )
        audio_cfg = config['audio']
        self.audio_processor = AudioProcessor(
            crossfade_duration=audio_cfg['crossfade_duration'],
            silence_duration=audio_cfg['silence_duration'],
            with_silence=audio_cfg['with_silence'],
            normalize=audio_cfg['normalize'],
            normalize_target_dBFS=audio_cfg['normalize_target_dBFS'],
            synthetic_silence_path=config['synthetic_silence']['path']
        )
        mcq_cfg = config['mcq']
        self.question_generator = QuestionGenerator(
            num_options=mcq_cfg['num_options'],
            option_labels=mcq_cfg['option_labels'],
            distractor_strategy=mcq_cfg['distractor_strategy']
        )

        # LLM question generation is optional; templates are used otherwise.
        self.llm_enabled = config.get('llm', {}).get('enabled', False)
        self.llm_generator = LLMQuestionGenerator(
            enabled=self.llm_enabled,
            template_questions=self.task_config
        )
        if self.llm_enabled:
            logger.info("LLM question generation enabled (local Llama 3.1 8B)")
        else:
            logger.info("Using template-based question generation")

        # Duration and silence settings (optional keys fall back to defaults).
        self.min_clip_duration = audio_cfg['min_clip_duration']
        self.max_clip_duration = audio_cfg['max_clip_duration']
        self.source_clip_duration = audio_cfg.get('source_clip_duration', 5.0)
        self.min_silence_ms = audio_cfg.get('min_silence_duration', 100)
        self.max_extra_silence_per_gap_ms = audio_cfg.get('max_extra_silence_per_gap', 500)
        self.crossfade_within_source_ms = audio_cfg.get('crossfade_within_source', 50)
        self.task_duration_hours = self.task_config['task_duration_size']

        # Clip ordering mode, forwarded to build_count_task_audio.
        self.ordering_mode = self.task_config.get('ordering_mode', 'random')
        logger.info(f"Count task ordering mode: {self.ordering_mode}")

        # Output locations.
        self.output_base = Path(config['output']['base_path']) / 'count'
        self.output_base.mkdir(parents=True, exist_ok=True)
        self.audio_output = self.output_base / 'audios'
        self.audio_output.mkdir(parents=True, exist_ok=True)

    def create_sampling_list(self, parent_list: List, n_sampling: int) -> List:
        """
        Sample elements from parent list with replacement.

        Args:
            parent_list: List to sample from
            n_sampling: Number of samples

        Returns:
            List of sampled elements
        """
        return [random.choice(parent_list) for _ in range(n_sampling)]

    def _capacity_for_duration(self, duration_seconds: float) -> tuple:
        """
        Compute how many clips / unique sources fit in a sample.

        Shared by generate_sample and generate_dataset so that both apply
        the same limits (in particular the floor of one clip).

        Args:
            duration_seconds: Target duration of the sample

        Returns:
            Tuple of (max_clips, max_unique) where max_clips is floored at 1
            and max_unique is additionally capped by ``max_clips_per_sample``
            and by the number of dataset categories.
        """
        max_clips, _ = get_max_clip_num_to_be_joined(
            duration_seconds,
            self.source_clip_duration,
            self.min_silence_ms
        )
        # Every sample needs at least one clip, however short the target is.
        max_clips = max(1, max_clips)
        max_clips_per_sample = self.task_config.get('max_clips_per_sample', 10)
        max_unique = min(max_clips, max_clips_per_sample, len(self.dataset.CATEGORIES))
        return max_clips, max_unique

    @staticmethod
    def _assign_balanced_targets(capacities: List, max_answer: int) -> List:
        """
        Spread answers 1..max_answer evenly over samples, respecting capacity.

        The pool holds each answer ``len(capacities) // max_answer`` times,
        with the remainder going to the smallest answers. The largest
        answers are handed to the samples with the largest capacity, and
        each assignment is clamped to that sample's capacity.

        Args:
            capacities: Per-sample maximum number of unique sources
            max_answer: Largest answer to aim for

        Returns:
            List of target counts, index-aligned with ``capacities``

        Raises:
            ValueError: If ``max_answer`` is smaller than 1
        """
        if max_answer < 1:
            raise ValueError(
                f"max_clips_per_sample must be at least 1, got {max_answer}"
            )

        num_samples = len(capacities)
        per_answer, extra = divmod(num_samples, max_answer)

        pool = []
        for answer in range(1, max_answer + 1):
            pool.extend([answer] * (per_answer + (1 if answer <= extra else 0)))
        pool.sort(reverse=True)

        # Stable sort: samples with equal capacity keep their original order.
        by_capacity = sorted(range(num_samples), key=lambda i: capacities[i], reverse=True)

        targets = [None] * num_samples
        for wanted, sample_idx in zip(pool, by_capacity):
            targets[sample_idx] = min(wanted, capacities[sample_idx])
        return targets

    def generate_sample(self, sample_id: int, target_unique_count: int = None, target_duration_seconds: float = None) -> Dict:
        """
        Generate a single count task sample.

        Pipeline for COUNT task:
        1. Use pre-generated target duration (or generate if not provided)
        2. Calculate max clips that can fit
        3. Pick N unique classes (N <= max_clips, since each source needs at least 1 clip)
        4. For each class, sample one audio clip and load it
        5. Hand the clips to build_count_task_audio together with the
           ordering mode ("random" or "consecutive") and silence/crossfade
           settings; repetition, ordering and silence insertion are
           delegated to that helper
        6. Export the audio as ``<sample_id>.wav`` and build the questions

        Args:
            sample_id: Sample ID number
            target_unique_count: Target number of unique sounds (for balanced
                distribution); None picks a random count. The value is
                clamped to what the sample can hold and floored at 1.
            target_duration_seconds: Pre-generated target duration (from
                generate_sample_durations_for_task); None generates one

        Returns:
            Dictionary with sample metadata

        Raises:
            ValueError: If no unique source can be placed in the sample
        """
        if target_duration_seconds is not None:
            clip_duration_seconds = target_duration_seconds
        else:
            clip_duration_seconds = generate_single_clip_duration(
                self.min_clip_duration,
                self.max_clip_duration
            )

        max_clips, max_unique_for_sample = self._capacity_for_duration(clip_duration_seconds)
        max_clips_per_sample = self.task_config.get('max_clips_per_sample', 10)

        if max_unique_for_sample < 1:
            raise ValueError(
                f"Sample {sample_id}: Cannot generate sample - max_unique_for_sample={max_unique_for_sample}. "
                f"max_clips={max_clips}, max_clips_per_sample={max_clips_per_sample}, "
                f"available_categories={len(self.dataset.CATEGORIES)}, duration={clip_duration_seconds:.1f}s. "
                f"Increase min_clip_duration or reduce max_clips_per_sample."
            )

        if target_unique_count is not None:
            # Clamp into [1, capacity]: a sample cannot hold zero sources.
            n_unique_audios = max(1, min(target_unique_count, max_unique_for_sample))
            if n_unique_audios != target_unique_count:
                self.logger.debug(
                    f"Sample {sample_id}: Clamped target from {target_unique_count} to {n_unique_audios} "
                    f"(duration={clip_duration_seconds:.1f}s can only fit {max_clips} clips)"
                )
        else:
            n_unique_audios = random.randint(1, max_unique_for_sample)

        self.logger.debug(
            f"Sample {sample_id}: target={clip_duration_seconds:.1f}s, max_clips={max_clips}, "
            f"n_unique_audios={n_unique_audios}"
        )

        # Favour the least-used categories and record their use.
        selected_categories = self.dataset.get_least_used_categories(n_unique_audios)
        for cat in selected_categories:
            self.dataset.category_usage_counts[cat] += 1

        # One source file per selected category.
        source_files = []
        source_paths = []
        source_categories = []
        for category in selected_categories:
            filename, filepath = self.dataset.sample_file_from_category(category)
            source_files.append(filename)
            source_paths.append(filepath)
            source_categories.append(category)

        source_audios = [self.audio_processor.load_audio(path) for path in source_paths]

        final_audio, clip_sequence, build_metadata = build_count_task_audio(
            source_audios,
            source_categories,
            clip_duration_seconds,
            ordering_mode=self.ordering_mode,
            source_clip_duration_seconds=self.source_clip_duration,
            min_silence_ms=self.min_silence_ms,
            max_extra_silence_per_gap_ms=self.max_extra_silence_per_gap_ms,
            crossfade_within_source_ms=self.crossfade_within_source_ms
        )

        output_audio_path = self.audio_output / f"{sample_id}.wav"
        final_audio.export(str(output_audio_path), format="wav")

        mcq_question_text = None
        open_text_question_text = None
        if self.llm_enabled and self.llm_generator:
            llm_questions = self.llm_generator.generate_count_questions(
                correct_count=n_unique_audios,
                # De-duplicate in first-appearance order. A set's order
                # varies between runs for strings, which would defeat the
                # seeded reproducibility of the rest of the pipeline.
                categories_present=list(dict.fromkeys(clip_sequence))
            )
            mcq_question_text = llm_questions.get('mcq_question')
            open_text_question_text = llm_questions.get('open_text_question')

        # Template questions are the default, and the fallback when the LLM
        # result lacks a question.
        if not mcq_question_text:
            mcq_question_text = random.choice(self.task_config['mcq_questions'])
        if not open_text_question_text:
            open_text_question_text = random.choice(self.task_config['open_text_questions'])

        mcq_data = self.question_generator.generate_count_mcq(
            mcq_question_text,
            n_unique_audios,
            self.dataset.CATEGORIES
        )
        open_text_data = self.question_generator.generate_count_open_text(
            open_text_question_text,
            n_unique_audios
        )

        actual_duration_seconds = len(final_audio) / 1000.0
        metadata = {
            'id': sample_id,
            # Stored relative to the parent of the 'count' output directory.
            'audio_path': str(output_audio_path.relative_to(self.output_base.parent)),
            'n_unique_sounds': n_unique_audios,
            'total_clips': build_metadata['total_clips'],
            'repetitions_per_source': build_metadata['repetitions_per_source'],
            'ordering_mode': self.ordering_mode,
            'source_files': source_files,
            'source_categories': source_categories,
            'clip_sequence': clip_sequence,
            'unique_categories': sorted(set(source_categories)),
            'target_duration_seconds': clip_duration_seconds,
            'actual_duration_seconds': actual_duration_seconds,
            'mcq_question': mcq_data['question'],
            'mcq_options': mcq_data['options'],
            'mcq_correct_answer': mcq_data['correct_answer'],
            'open_text_question': open_text_data['question'],
            'open_text_answer': open_text_data['correct_answer'],
            'llm_generated': self.llm_enabled
        }

        self.logger.info(
            f"Generated count sample {sample_id}: {n_unique_audios} unique sounds, "
            f"{build_metadata['total_clips']} clips, {len(final_audio)/1000:.1f}s"
        )

        return metadata

    def generate_dataset(self) -> tuple:
        """
        Generate the complete count task dataset.

        Pre-generates the sample durations, assigns every sample a target
        answer so that answers 1..max_clips_per_sample are spread as evenly
        as the per-sample capacities allow, generates each sample, and
        writes the MCQ, open-text and metadata CSV files.

        Returns:
            Tuple of (mcq_csv_path, open_text_csv_path)

        Raises:
            ValueError: If ``max_clips_per_sample`` is smaller than 1
        """
        sample_durations = generate_sample_durations_for_task(
            self.task_duration_hours,
            self.min_clip_duration,
            self.max_clip_duration
        )
        num_samples = len(sample_durations)
        self.logger.info(f"Generating {num_samples} count task samples (target: {self.task_duration_hours}h, actual: {sum(sample_durations)/3600:.2f}h)...")

        # Use the same capacity rule as generate_sample (including the floor
        # of one clip) so no sample is ever assigned a target of zero.
        max_clips_per_sample = self.task_config.get('max_clips_per_sample', 10)
        sample_max_clips = [
            self._capacity_for_duration(duration)[1] for duration in sample_durations
        ]

        balanced_assignments = self._assign_balanced_targets(
            sample_max_clips, max_clips_per_sample
        )

        from collections import Counter
        distribution = Counter(balanced_assignments)
        self.logger.info(f"Balanced answer distribution (after capacity-aware assignment): {dict(sorted(distribution.items()))}")

        all_metadata = []
        for i, (target, duration) in enumerate(zip(balanced_assignments, sample_durations)):
            all_metadata.append(
                self.generate_sample(
                    i,
                    target_unique_count=target,
                    target_duration_seconds=duration
                )
            )

        mcq_csv_path = self.output_base / 'count_mcq.csv'
        self._save_mcq_csv(all_metadata, mcq_csv_path)

        open_text_csv_path = self.output_base / 'count_open_text.csv'
        self._save_open_text_csv(all_metadata, open_text_csv_path)

        metadata_csv_path = self.output_base / 'count_metadata.csv'
        self._save_metadata_csv(all_metadata, metadata_csv_path)

        self.logger.info("Count task dataset generation complete!")
        self.logger.info(f"  - MCQ CSV: {mcq_csv_path}")
        self.logger.info(f"  - Open-text CSV: {open_text_csv_path}")
        self.logger.info(f"  - Metadata CSV: {metadata_csv_path}")
        self.logger.info(f"  - Audio files: {self.audio_output}")

        return mcq_csv_path, open_text_csv_path

    @staticmethod
    def _write_csv(output_path, header: List, rows) -> None:
        """Write ``header`` followed by ``rows`` to ``output_path`` as UTF-8 CSV."""
        # Explicit UTF-8 so question text with non-ASCII characters is written
        # identically on every platform; newline='' as the csv module requires.
        with open(output_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(header)
            writer.writerows(rows)

    def _save_mcq_csv(self, metadata_list: List[Dict], output_path: Path):
        """Save MCQ format CSV (one row per sample, options A-D)."""
        header = [
            'question', 'id', 'audio_path',
            'optionA', 'optionB', 'optionC', 'optionD',
            'correct', 'source_wavs', 'source_categories'
        ]
        rows = (
            [
                meta['mcq_question'],
                meta['id'],
                meta['audio_path'],
                meta['mcq_options']['A'],
                meta['mcq_options']['B'],
                meta['mcq_options']['C'],
                meta['mcq_options']['D'],
                meta['mcq_correct_answer'],
                str(meta['source_files']),
                # NOTE: this column holds the sorted unique categories.
                str(meta['unique_categories'])
            ]
            for meta in metadata_list
        )
        self._write_csv(output_path, header, rows)

    def _save_open_text_csv(self, metadata_list: List[Dict], output_path: Path):
        """Save open-text format CSV (one row per sample)."""
        header = [
            'question', 'id', 'audio_path', 'answer',
            'source_wavs', 'source_categories'
        ]
        rows = (
            [
                meta['open_text_question'],
                meta['id'],
                meta['audio_path'],
                meta['open_text_answer'],
                str(meta['source_files']),
                # NOTE: this column holds the sorted unique categories.
                str(meta['unique_categories'])
            ]
            for meta in metadata_list
        )
        self._write_csv(output_path, header, rows)

    def _save_metadata_csv(self, metadata_list: List[Dict], output_path: Path):
        """Save detailed metadata CSV (one row per sample)."""
        header = [
            'id', 'audio_path', 'total_clips', 'n_unique_sounds',
            'source_files', 'source_categories', 'unique_categories',
            'ordering_mode', 'target_duration_s', 'actual_duration_s', 'llm_generated'
        ]
        rows = (
            [
                meta['id'],
                meta['audio_path'],
                meta['total_clips'],
                meta['n_unique_sounds'],
                str(meta['source_files']),
                str(meta['source_categories']),
                str(meta['unique_categories']),
                meta.get('ordering_mode', 'random'),
                meta.get('target_duration_seconds', 0),
                meta.get('actual_duration_seconds', 0),
                meta.get('llm_generated', False)
            ]
            for meta in metadata_list
        )
        self._write_csv(output_path, header, rows)
|
|
|
|
def main(config_path: str = None):
    """Main entry point for count task generation.

    Loads the YAML configuration, seeds the random number generator,
    sets up logging and runs the full count-task dataset generation.

    Args:
        config_path: Path to the YAML config file. When None, ``config.yaml``
            in the parent of this file's directory is used.
    """
    import yaml

    if config_path is None:
        config_path = Path(__file__).parent.parent / 'config.yaml'

    # Read the config as UTF-8 explicitly rather than relying on the
    # platform's default encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # Seed before any sampling happens.
    set_random_seed(config['random_seed'])

    logging_cfg = config['logging']
    logger = setup_logger(
        'count_task',
        log_file=str(Path(config['output']['base_path']) / logging_cfg['log_file']),
        level=logging_cfg['level'],
        console_output=logging_cfg['console_output']
    )

    generator = CountTaskGenerator(config, logger)
    generator.generate_dataset()
|
|
|
|
# Script entry point: run count-task generation with the default config path.
if __name__ == "__main__":
    main()
|
|