| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import dataclasses |
| | import math |
| | import os |
| | from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from ....utils import is_note_seq_available |
| | from .pipeline_spectrogram_diffusion import TARGET_FEATURE_LENGTH |
| |
|
| |
|
| | if is_note_seq_available(): |
| | import note_seq |
| | else: |
| | raise ImportError("Please install note-seq via `pip install note-seq`") |
| |
|
| |
|
| | INPUT_FEATURE_LENGTH = 2048 |
| |
|
| | SAMPLE_RATE = 16000 |
| | HOP_SIZE = 320 |
| | FRAME_RATE = int(SAMPLE_RATE // HOP_SIZE) |
| |
|
| | DEFAULT_STEPS_PER_SECOND = 100 |
| | DEFAULT_MAX_SHIFT_SECONDS = 10 |
| | DEFAULT_NUM_VELOCITY_BINS = 1 |
| |
|
| | SLAKH_CLASS_PROGRAMS = { |
| | "Acoustic Piano": 0, |
| | "Electric Piano": 4, |
| | "Chromatic Percussion": 8, |
| | "Organ": 16, |
| | "Acoustic Guitar": 24, |
| | "Clean Electric Guitar": 26, |
| | "Distorted Electric Guitar": 29, |
| | "Acoustic Bass": 32, |
| | "Electric Bass": 33, |
| | "Violin": 40, |
| | "Viola": 41, |
| | "Cello": 42, |
| | "Contrabass": 43, |
| | "Orchestral Harp": 46, |
| | "Timpani": 47, |
| | "String Ensemble": 48, |
| | "Synth Strings": 50, |
| | "Choir and Voice": 52, |
| | "Orchestral Hit": 55, |
| | "Trumpet": 56, |
| | "Trombone": 57, |
| | "Tuba": 58, |
| | "French Horn": 60, |
| | "Brass Section": 61, |
| | "Soprano/Alto Sax": 64, |
| | "Tenor Sax": 66, |
| | "Baritone Sax": 67, |
| | "Oboe": 68, |
| | "English Horn": 69, |
| | "Bassoon": 70, |
| | "Clarinet": 71, |
| | "Pipe": 73, |
| | "Synth Lead": 80, |
| | "Synth Pad": 88, |
| | } |
| |
|
| |
|
| | @dataclasses.dataclass |
| | class NoteRepresentationConfig: |
| | """Configuration note representations.""" |
| |
|
| | onsets_only: bool |
| | include_ties: bool |
| |
|
| |
|
| | @dataclasses.dataclass |
| | class NoteEventData: |
| | pitch: int |
| | velocity: Optional[int] = None |
| | program: Optional[int] = None |
| | is_drum: Optional[bool] = None |
| | instrument: Optional[int] = None |
| |
|
| |
|
| | @dataclasses.dataclass |
| | class NoteEncodingState: |
| | """Encoding state for note transcription, keeping track of active pitches.""" |
| |
|
| | |
| | active_pitches: MutableMapping[Tuple[int, int], int] = dataclasses.field(default_factory=dict) |
| |
|
| |
|
| | @dataclasses.dataclass |
| | class EventRange: |
| | type: str |
| | min_value: int |
| | max_value: int |
| |
|
| |
|
| | @dataclasses.dataclass |
| | class Event: |
| | type: str |
| | value: int |
| |
|
| |
|
| | class Tokenizer: |
| | def __init__(self, regular_ids: int): |
| | |
| | self._num_special_tokens = 3 |
| | self._num_regular_tokens = regular_ids |
| |
|
| | def encode(self, token_ids): |
| | encoded = [] |
| | for token_id in token_ids: |
| | if not 0 <= token_id < self._num_regular_tokens: |
| | raise ValueError( |
| | f"token_id {token_id} does not fall within valid range of [0, {self._num_regular_tokens})" |
| | ) |
| | encoded.append(token_id + self._num_special_tokens) |
| |
|
| | |
| | encoded.append(1) |
| |
|
| | |
| | encoded = encoded + [0] * (INPUT_FEATURE_LENGTH - len(encoded)) |
| |
|
| | return encoded |
| |
|
| |
|
| | class Codec: |
| | """Encode and decode events. |
| | |
| | Useful for declaring what certain ranges of a vocabulary should be used for. This is intended to be used from |
| | Python before encoding or after decoding with GenericTokenVocabulary. This class is more lightweight and does not |
| | include things like EOS or UNK token handling. |
| | |
| | To ensure that 'shift' events are always the first block of the vocab and start at 0, that event type is required |
| | and specified separately. |
| | """ |
| |
|
| | def __init__(self, max_shift_steps: int, steps_per_second: float, event_ranges: List[EventRange]): |
| | """Define Codec. |
| | |
| | Args: |
| | max_shift_steps: Maximum number of shift steps that can be encoded. |
| | steps_per_second: Shift steps will be interpreted as having a duration of |
| | 1 / steps_per_second. |
| | event_ranges: Other supported event types and their ranges. |
| | """ |
| | self.steps_per_second = steps_per_second |
| | self._shift_range = EventRange(type="shift", min_value=0, max_value=max_shift_steps) |
| | self._event_ranges = [self._shift_range] + event_ranges |
| | |
| | assert len(self._event_ranges) == len({er.type for er in self._event_ranges}) |
| |
|
| | @property |
| | def num_classes(self) -> int: |
| | return sum(er.max_value - er.min_value + 1 for er in self._event_ranges) |
| |
|
| | |
| | |
| |
|
| | def is_shift_event_index(self, index: int) -> bool: |
| | return (self._shift_range.min_value <= index) and (index <= self._shift_range.max_value) |
| |
|
| | @property |
| | def max_shift_steps(self) -> int: |
| | return self._shift_range.max_value |
| |
|
| | def encode_event(self, event: Event) -> int: |
| | """Encode an event to an index.""" |
| | offset = 0 |
| | for er in self._event_ranges: |
| | if event.type == er.type: |
| | if not er.min_value <= event.value <= er.max_value: |
| | raise ValueError( |
| | f"Event value {event.value} is not within valid range " |
| | f"[{er.min_value}, {er.max_value}] for type {event.type}" |
| | ) |
| | return offset + event.value - er.min_value |
| | offset += er.max_value - er.min_value + 1 |
| |
|
| | raise ValueError(f"Unknown event type: {event.type}") |
| |
|
| | def event_type_range(self, event_type: str) -> Tuple[int, int]: |
| | """Return [min_id, max_id] for an event type.""" |
| | offset = 0 |
| | for er in self._event_ranges: |
| | if event_type == er.type: |
| | return offset, offset + (er.max_value - er.min_value) |
| | offset += er.max_value - er.min_value + 1 |
| |
|
| | raise ValueError(f"Unknown event type: {event_type}") |
| |
|
| | def decode_event_index(self, index: int) -> Event: |
| | """Decode an event index to an Event.""" |
| | offset = 0 |
| | for er in self._event_ranges: |
| | if offset <= index <= offset + er.max_value - er.min_value: |
| | return Event(type=er.type, value=er.min_value + index - offset) |
| | offset += er.max_value - er.min_value + 1 |
| |
|
| | raise ValueError(f"Unknown event index: {index}") |
| |
|
| |
|
| | @dataclasses.dataclass |
| | class ProgramGranularity: |
| | |
| | tokens_map_fn: Callable[[Sequence[int], Codec], Sequence[int]] |
| | program_map_fn: Callable[[int], int] |
| |
|
| |
|
| | def drop_programs(tokens, codec: Codec): |
| | """Drops program change events from a token sequence.""" |
| | min_program_id, max_program_id = codec.event_type_range("program") |
| | return tokens[(tokens < min_program_id) | (tokens > max_program_id)] |
| |
|
| |
|
| | def programs_to_midi_classes(tokens, codec): |
| | """Modifies program events to be the first program in the MIDI class.""" |
| | min_program_id, max_program_id = codec.event_type_range("program") |
| | is_program = (tokens >= min_program_id) & (tokens <= max_program_id) |
| | return np.where(is_program, min_program_id + 8 * ((tokens - min_program_id) // 8), tokens) |
| |
|
| |
|
| | PROGRAM_GRANULARITIES = { |
| | |
| | |
| | "flat": ProgramGranularity(tokens_map_fn=drop_programs, program_map_fn=lambda program: 0), |
| | |
| | "midi_class": ProgramGranularity( |
| | tokens_map_fn=programs_to_midi_classes, program_map_fn=lambda program: 8 * (program // 8) |
| | ), |
| | |
| | "full": ProgramGranularity(tokens_map_fn=lambda tokens, codec: tokens, program_map_fn=lambda program: program), |
| | } |
| |
|
| |
|
| | def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1): |
| | """ |
| | equivalent of tf.signal.frame |
| | """ |
| | signal_length = signal.shape[axis] |
| | if pad_end: |
| | frames_overlap = frame_length - frame_step |
| | rest_samples = np.abs(signal_length - frames_overlap) % np.abs(frame_length - frames_overlap) |
| | pad_size = int(frame_length - rest_samples) |
| |
|
| | if pad_size != 0: |
| | pad_axis = [0] * signal.ndim |
| | pad_axis[axis] = pad_size |
| | signal = F.pad(signal, pad_axis, "constant", pad_value) |
| | frames = signal.unfold(axis, frame_length, frame_step) |
| | return frames |
| |
|
| |
|
| | def program_to_slakh_program(program): |
| | |
| | for slakh_program in sorted(SLAKH_CLASS_PROGRAMS.values(), reverse=True): |
| | if program >= slakh_program: |
| | return slakh_program |
| |
|
| |
|
| | def audio_to_frames( |
| | samples, |
| | hop_size: int, |
| | frame_rate: int, |
| | ) -> Tuple[Sequence[Sequence[int]], torch.Tensor]: |
| | """Convert audio samples to non-overlapping frames and frame times.""" |
| | frame_size = hop_size |
| | samples = np.pad(samples, [0, frame_size - len(samples) % frame_size], mode="constant") |
| |
|
| | |
| | frames = frame( |
| | torch.Tensor(samples).unsqueeze(0), |
| | frame_length=frame_size, |
| | frame_step=frame_size, |
| | pad_end=False, |
| | ) |
| |
|
| | num_frames = len(samples) // frame_size |
| |
|
| | times = np.arange(num_frames) / frame_rate |
| | return frames, times |
| |
|
| |
|
| | def note_sequence_to_onsets_and_offsets_and_programs( |
| | ns: note_seq.NoteSequence, |
| | ) -> Tuple[Sequence[float], Sequence[NoteEventData]]: |
| | """Extract onset & offset times and pitches & programs from a NoteSequence. |
| | |
| | The onset & offset times will not necessarily be in sorted order. |
| | |
| | Args: |
| | ns: NoteSequence from which to extract onsets and offsets. |
| | |
| | Returns: |
| | times: A list of note onset and offset times. values: A list of NoteEventData objects where velocity is zero for |
| | note |
| | offsets. |
| | """ |
| | |
| | |
| | notes = sorted(ns.notes, key=lambda note: (note.is_drum, note.program, note.pitch)) |
| | times = [note.end_time for note in notes if not note.is_drum] + [note.start_time for note in notes] |
| | values = [ |
| | NoteEventData(pitch=note.pitch, velocity=0, program=note.program, is_drum=False) |
| | for note in notes |
| | if not note.is_drum |
| | ] + [ |
| | NoteEventData(pitch=note.pitch, velocity=note.velocity, program=note.program, is_drum=note.is_drum) |
| | for note in notes |
| | ] |
| | return times, values |
| |
|
| |
|
| | def num_velocity_bins_from_codec(codec: Codec): |
| | """Get number of velocity bins from event codec.""" |
| | lo, hi = codec.event_type_range("velocity") |
| | return hi - lo |
| |
|
| |
|
| | |
| | def segment(a, n): |
| | return [a[i : i + n] for i in range(0, len(a), n)] |
| |
|
| |
|
| | def velocity_to_bin(velocity, num_velocity_bins): |
| | if velocity == 0: |
| | return 0 |
| | else: |
| | return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY) |
| |
|
| |
|
| | def note_event_data_to_events( |
| | state: Optional[NoteEncodingState], |
| | value: NoteEventData, |
| | codec: Codec, |
| | ) -> Sequence[Event]: |
| | """Convert note event data to a sequence of events.""" |
| | if value.velocity is None: |
| | |
| | return [Event("pitch", value.pitch)] |
| | else: |
| | num_velocity_bins = num_velocity_bins_from_codec(codec) |
| | velocity_bin = velocity_to_bin(value.velocity, num_velocity_bins) |
| | if value.program is None: |
| | |
| | if state is not None: |
| | state.active_pitches[(value.pitch, 0)] = velocity_bin |
| | return [Event("velocity", velocity_bin), Event("pitch", value.pitch)] |
| | else: |
| | if value.is_drum: |
| | |
| | return [Event("velocity", velocity_bin), Event("drum", value.pitch)] |
| | else: |
| | |
| | if state is not None: |
| | state.active_pitches[(value.pitch, value.program)] = velocity_bin |
| | return [ |
| | Event("program", value.program), |
| | Event("velocity", velocity_bin), |
| | Event("pitch", value.pitch), |
| | ] |
| |
|
| |
|
| | def note_encoding_state_to_events(state: NoteEncodingState) -> Sequence[Event]: |
| | """Output program and pitch events for active notes plus a final tie event.""" |
| | events = [] |
| | for pitch, program in sorted(state.active_pitches.keys(), key=lambda k: k[::-1]): |
| | if state.active_pitches[(pitch, program)]: |
| | events += [Event("program", program), Event("pitch", pitch)] |
| | events.append(Event("tie", 0)) |
| | return events |
| |
|
| |
|
| | def encode_and_index_events( |
| | state, event_times, event_values, codec, frame_times, encode_event_fn, encoding_state_to_events_fn=None |
| | ): |
| | """Encode a sequence of timed events and index to audio frame times. |
| | |
| | Encodes time shifts as repeated single step shifts for later run length encoding. |
| | |
| | Optionally, also encodes a sequence of "state events", keeping track of the current encoding state at each audio |
| | frame. This can be used e.g. to prepend events representing the current state to a targets segment. |
| | |
| | Args: |
| | state: Initial event encoding state. |
| | event_times: Sequence of event times. |
| | event_values: Sequence of event values. |
| | encode_event_fn: Function that transforms event value into a sequence of one |
| | or more Event objects. |
| | codec: An Codec object that maps Event objects to indices. |
| | frame_times: Time for every audio frame. |
| | encoding_state_to_events_fn: Function that transforms encoding state into a |
| | sequence of one or more Event objects. |
| | |
| | Returns: |
| | events: Encoded events and shifts. event_start_indices: Corresponding start event index for every audio frame. |
| | Note: one event can correspond to multiple audio indices due to sampling rate differences. This makes |
| | splitting sequences tricky because the same event can appear at the end of one sequence and the beginning of |
| | another. |
| | event_end_indices: Corresponding end event index for every audio frame. Used |
| | to ensure when slicing that one chunk ends where the next begins. Should always be true that |
| | event_end_indices[i] = event_start_indices[i + 1]. |
| | state_events: Encoded "state" events representing the encoding state before |
| | each event. |
| | state_event_indices: Corresponding state event index for every audio frame. |
| | """ |
| | indices = np.argsort(event_times, kind="stable") |
| | event_steps = [round(event_times[i] * codec.steps_per_second) for i in indices] |
| | event_values = [event_values[i] for i in indices] |
| |
|
| | events = [] |
| | state_events = [] |
| | event_start_indices = [] |
| | state_event_indices = [] |
| |
|
| | cur_step = 0 |
| | cur_event_idx = 0 |
| | cur_state_event_idx = 0 |
| |
|
| | def fill_event_start_indices_to_cur_step(): |
| | while ( |
| | len(event_start_indices) < len(frame_times) |
| | and frame_times[len(event_start_indices)] < cur_step / codec.steps_per_second |
| | ): |
| | event_start_indices.append(cur_event_idx) |
| | state_event_indices.append(cur_state_event_idx) |
| |
|
| | for event_step, event_value in zip(event_steps, event_values): |
| | while event_step > cur_step: |
| | events.append(codec.encode_event(Event(type="shift", value=1))) |
| | cur_step += 1 |
| | fill_event_start_indices_to_cur_step() |
| | cur_event_idx = len(events) |
| | cur_state_event_idx = len(state_events) |
| | if encoding_state_to_events_fn: |
| | |
| | |
| | for e in encoding_state_to_events_fn(state): |
| | state_events.append(codec.encode_event(e)) |
| |
|
| | for e in encode_event_fn(state, event_value, codec): |
| | events.append(codec.encode_event(e)) |
| |
|
| | |
| | |
| | |
| | |
| | while cur_step / codec.steps_per_second <= frame_times[-1]: |
| | events.append(codec.encode_event(Event(type="shift", value=1))) |
| | cur_step += 1 |
| | fill_event_start_indices_to_cur_step() |
| | cur_event_idx = len(events) |
| |
|
| | |
| | |
| | |
| | event_end_indices = event_start_indices[1:] + [len(events)] |
| |
|
| | events = np.array(events).astype(np.int32) |
| | state_events = np.array(state_events).astype(np.int32) |
| | event_start_indices = segment(np.array(event_start_indices).astype(np.int32), TARGET_FEATURE_LENGTH) |
| | event_end_indices = segment(np.array(event_end_indices).astype(np.int32), TARGET_FEATURE_LENGTH) |
| | state_event_indices = segment(np.array(state_event_indices).astype(np.int32), TARGET_FEATURE_LENGTH) |
| |
|
| | outputs = [] |
| | for start_indices, end_indices, event_indices in zip(event_start_indices, event_end_indices, state_event_indices): |
| | outputs.append( |
| | { |
| | "inputs": events, |
| | "event_start_indices": start_indices, |
| | "event_end_indices": end_indices, |
| | "state_events": state_events, |
| | "state_event_indices": event_indices, |
| | } |
| | ) |
| |
|
| | return outputs |
| |
|
| |
|
| | def extract_sequence_with_indices(features, state_events_end_token=None, feature_key="inputs"): |
| | """Extract target sequence corresponding to audio token segment.""" |
| | features = features.copy() |
| | start_idx = features["event_start_indices"][0] |
| | end_idx = features["event_end_indices"][-1] |
| |
|
| | features[feature_key] = features[feature_key][start_idx:end_idx] |
| |
|
| | if state_events_end_token is not None: |
| | |
| | |
| | state_event_start_idx = features["state_event_indices"][0] |
| | state_event_end_idx = state_event_start_idx + 1 |
| | while features["state_events"][state_event_end_idx - 1] != state_events_end_token: |
| | state_event_end_idx += 1 |
| | features[feature_key] = np.concatenate( |
| | [ |
| | features["state_events"][state_event_start_idx:state_event_end_idx], |
| | features[feature_key], |
| | ], |
| | axis=0, |
| | ) |
| |
|
| | return features |
| |
|
| |
|
| | def map_midi_programs( |
| | feature, codec: Codec, granularity_type: str = "full", feature_key: str = "inputs" |
| | ) -> Mapping[str, Any]: |
| | """Apply MIDI program map to token sequences.""" |
| | granularity = PROGRAM_GRANULARITIES[granularity_type] |
| |
|
| | feature[feature_key] = granularity.tokens_map_fn(feature[feature_key], codec) |
| | return feature |
| |
|
| |
|
| | def run_length_encode_shifts_fn( |
| | features, |
| | codec: Codec, |
| | feature_key: str = "inputs", |
| | state_change_event_types: Sequence[str] = (), |
| | ) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: |
| | """Return a function that run-length encodes shifts for a given codec. |
| | |
| | Args: |
| | codec: The Codec to use for shift events. |
| | feature_key: The feature key for which to run-length encode shifts. |
| | state_change_event_types: A list of event types that represent state |
| | changes; tokens corresponding to these event types will be interpreted as state changes and redundant ones |
| | will be removed. |
| | |
| | Returns: |
| | A preprocessing function that run-length encodes single-step shifts. |
| | """ |
| | state_change_event_ranges = [codec.event_type_range(event_type) for event_type in state_change_event_types] |
| |
|
| | def run_length_encode_shifts(features: MutableMapping[str, Any]) -> Mapping[str, Any]: |
| | """Combine leading/interior shifts, trim trailing shifts. |
| | |
| | Args: |
| | features: Dict of features to process. |
| | |
| | Returns: |
| | A dict of features. |
| | """ |
| | events = features[feature_key] |
| |
|
| | shift_steps = 0 |
| | total_shift_steps = 0 |
| | output = np.array([], dtype=np.int32) |
| |
|
| | current_state = np.zeros(len(state_change_event_ranges), dtype=np.int32) |
| |
|
| | for event in events: |
| | if codec.is_shift_event_index(event): |
| | shift_steps += 1 |
| | total_shift_steps += 1 |
| |
|
| | else: |
| | |
| | |
| | is_redundant = False |
| | for i, (min_index, max_index) in enumerate(state_change_event_ranges): |
| | if (min_index <= event) and (event <= max_index): |
| | if current_state[i] == event: |
| | is_redundant = True |
| | current_state[i] = event |
| | if is_redundant: |
| | continue |
| |
|
| | |
| | |
| | if shift_steps > 0: |
| | shift_steps = total_shift_steps |
| | while shift_steps > 0: |
| | output_steps = np.minimum(codec.max_shift_steps, shift_steps) |
| | output = np.concatenate([output, [output_steps]], axis=0) |
| | shift_steps -= output_steps |
| | output = np.concatenate([output, [event]], axis=0) |
| |
|
| | features[feature_key] = output |
| | return features |
| |
|
| | return run_length_encode_shifts(features) |
| |
|
| |
|
| | def note_representation_processor_chain(features, codec: Codec, note_representation_config: NoteRepresentationConfig): |
| | tie_token = codec.encode_event(Event("tie", 0)) |
| | state_events_end_token = tie_token if note_representation_config.include_ties else None |
| |
|
| | features = extract_sequence_with_indices( |
| | features, state_events_end_token=state_events_end_token, feature_key="inputs" |
| | ) |
| |
|
| | features = map_midi_programs(features, codec) |
| |
|
| | features = run_length_encode_shifts_fn(features, codec, state_change_event_types=["velocity", "program"]) |
| |
|
| | return features |
| |
|
| |
|
| | class MidiProcessor: |
| | def __init__(self): |
| | self.codec = Codec( |
| | max_shift_steps=DEFAULT_MAX_SHIFT_SECONDS * DEFAULT_STEPS_PER_SECOND, |
| | steps_per_second=DEFAULT_STEPS_PER_SECOND, |
| | event_ranges=[ |
| | EventRange("pitch", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), |
| | EventRange("velocity", 0, DEFAULT_NUM_VELOCITY_BINS), |
| | EventRange("tie", 0, 0), |
| | EventRange("program", note_seq.MIN_MIDI_PROGRAM, note_seq.MAX_MIDI_PROGRAM), |
| | EventRange("drum", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), |
| | ], |
| | ) |
| | self.tokenizer = Tokenizer(self.codec.num_classes) |
| | self.note_representation_config = NoteRepresentationConfig(onsets_only=False, include_ties=True) |
| |
|
| | def __call__(self, midi: Union[bytes, os.PathLike, str]): |
| | if not isinstance(midi, bytes): |
| | with open(midi, "rb") as f: |
| | midi = f.read() |
| |
|
| | ns = note_seq.midi_to_note_sequence(midi) |
| | ns_sus = note_seq.apply_sustain_control_changes(ns) |
| |
|
| | for note in ns_sus.notes: |
| | if not note.is_drum: |
| | note.program = program_to_slakh_program(note.program) |
| |
|
| | samples = np.zeros(int(ns_sus.total_time * SAMPLE_RATE)) |
| |
|
| | _, frame_times = audio_to_frames(samples, HOP_SIZE, FRAME_RATE) |
| | times, values = note_sequence_to_onsets_and_offsets_and_programs(ns_sus) |
| |
|
| | events = encode_and_index_events( |
| | state=NoteEncodingState(), |
| | event_times=times, |
| | event_values=values, |
| | frame_times=frame_times, |
| | codec=self.codec, |
| | encode_event_fn=note_event_data_to_events, |
| | encoding_state_to_events_fn=note_encoding_state_to_events, |
| | ) |
| |
|
| | events = [ |
| | note_representation_processor_chain(event, self.codec, self.note_representation_config) for event in events |
| | ] |
| | input_tokens = [self.tokenizer.encode(event["inputs"]) for event in events] |
| |
|
| | return input_tokens |
| |
|