| import torch |
| from torch.utils.data import Dataset |
| import numpy as np |
| from tqdm import tqdm |
|
|
|
|
class BeatTrackingDataset(Dataset):
    """Frame-classification dataset for beat / downbeat tracking.

    Each sample is a fixed-length audio window centered on one analysis
    frame, paired with a soft label: 1.0 on annotated beat frames, 0.25 on
    the immediately adjacent ("fuzzy") frames, and 0.0 on randomly sampled
    negative frames (about twice as many negatives as positives).
    """

    def __init__(
        self,
        hf_dataset,
        target_type="beats",
        sample_rate=16000,
        hop_length=160,
        context_frames=50,
        extra_samples=1488,
    ):
        """
        Args:
            hf_dataset: HuggingFace dataset object. Each item must provide
                item["audio"]["array"] plus "beats" and "downbeats" sequences
                of annotation times in seconds.
            target_type (str): "beats" or "downbeats". Determines which labels are treated as positive.
            sample_rate (int): Audio sample rate in Hz.
            hop_length (int): Samples per analysis frame (hop).
            context_frames (int): Number of frames before and after the center frame.
                Total frames = 2 * context_frames + 1.
                Default 50 means 101 frames (~1s).
            extra_samples (int): Extra samples added to the window length.
                Default 1488 matches the previously hard-coded constant —
                presumably the receptive-field margin of the downstream
                feature extractor; TODO confirm against the model.
        """
        self.sr = sample_rate
        self.hop_length = hop_length
        self.target_type = target_type
        self.context_frames = context_frames

        # Window length in samples: one hop per frame, plus the extra margin.
        self.context_samples = (self.context_frames * 2 + 1) * hop_length + extra_samples

        self.audio_cache = []  # decoded waveform per track, indexed by track position
        self.indices = []  # (track_idx, frame_idx, soft_label) triples
        self._prepare_indices(hf_dataset)

    def _prepare_indices(self, hf_dataset):
        """
        Prepares balanced indices and caches audio.
        Uses the same "Fuzzier" training examples strategy as the baseline:
        each ground-truth frame contributes label 1.0, its two neighbors
        contribute 0.25, and ~2x as many random negative frames get 0.0.

        NOTE(review): negative sampling uses the global NumPy RNG, so the
        index set is only reproducible if the caller seeds np.random.
        Adjacent annotations can also produce duplicate entries with
        conflicting labels (1.0 and 0.25) for the same frame; this matches
        the baseline behavior and is deliberately kept as-is.
        """
        print(f"Preparing dataset indices for target: {self.target_type}...")

        for i, item in tqdm(
            enumerate(hf_dataset), total=len(hf_dataset), desc="Building indices"
        ):
            audio = item["audio"]["array"]
            if hasattr(audio, "numpy"):  # torch tensor -> ndarray
                audio = audio.numpy()
            self.audio_cache.append(audio)

            audio_len = len(audio)
            # Exact floor division (the old float `int(a / b)` risked
            # precision loss on very long tracks).
            n_frames = audio_len // self.hop_length

            # Select which annotation stream counts as "positive".
            if self.target_type == "downbeats":
                gt_times = item["downbeats"]
            else:
                gt_times = item["beats"]

            if hasattr(gt_times, "tolist"):
                gt_times = gt_times.tolist()

            # Annotation times (seconds) -> frame indices.
            gt_frames = {int(t * self.sr / self.hop_length) for t in gt_times}

            pos_frames = set()
            for bf in gt_frames:
                if 0 <= bf < n_frames:
                    self.indices.append((i, bf, 1.0))
                    pos_frames.add(bf)

                # Fuzzy neighbors get a reduced soft label.
                if 0 <= bf - 1 < n_frames:
                    self.indices.append((i, bf - 1, 0.25))
                    pos_frames.add(bf - 1)
                if 0 <= bf + 1 < n_frames:
                    self.indices.append((i, bf + 1, 0.25))
                    pos_frames.add(bf + 1)

            # Sample ~2 negatives per positive; the attempt cap keeps a
            # pathologically dense track from looping forever.
            num_pos = len(pos_frames)
            num_neg = num_pos * 2

            count = 0
            attempts = 0
            while count < num_neg and attempts < num_neg * 5:
                f = np.random.randint(0, n_frames)
                if f not in pos_frames:
                    self.indices.append((i, f, 0.0))
                    count += 1
                attempts += 1

        print(
            f"Dataset ready. {len(self.indices)} samples, {len(self.audio_cache)} tracks cached."
        )

    def __len__(self):
        """Total number of (track, frame, label) samples."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return (waveform, label) for sample `idx`.

        The waveform is a float32 tensor of exactly `self.context_samples`
        samples centered on the sample's frame, zero-padded where the window
        overruns the track boundaries. The label is a 1-element float32
        tensor in {0.0, 0.25, 1.0}.
        """
        track_idx, frame_idx, label = self.indices[idx]

        audio = self.audio_cache[track_idx]
        audio_len = len(audio)

        center_sample = frame_idx * self.hop_length
        half_context = self.context_samples // 2

        start = center_sample - half_context
        # FIX: anchor the end on the start so the window is always exactly
        # context_samples long. The previous `center + half` form silently
        # dropped one sample whenever context_samples was odd; for the even
        # default this is byte-identical.
        end = start + self.context_samples

        # Zero-pad any portion of the window outside the track.
        pad_left = max(0, -start)
        pad_right = max(0, end - audio_len)

        valid_start = max(0, start)
        valid_end = min(audio_len, end)

        chunk = audio[valid_start:valid_end]

        if pad_left > 0 or pad_right > 0:
            chunk = np.pad(chunk, (pad_left, pad_right), mode="constant")

        # torch.tensor copies, so the returned window never aliases the cache.
        waveform = torch.tensor(chunk, dtype=torch.float32)
        return waveform, torch.tensor([label], dtype=torch.float32)
|