// pocket-tts-web/inference-worker.js
// Pocket TTS ONNX Web Worker
console.log('Pocket TTS Worker Starting...');
self.postMessage({ type: 'status', status: 'Worker Thread Started', state: 'idle' });
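// Illustrative main-thread wiring (a sketch; the host page is not part of this
// file, but the message names below match the self.onmessage handlers):
//   const worker = new Worker('./inference-worker.js', { type: 'module' });
//   worker.onmessage = (e) => { /* 'status' | 'loaded' | 'voices_loaded' | 'audio_chunk' | 'error' ... */ };
//   worker.postMessage({ type: 'load' });
//   worker.postMessage({ type: 'set_voice', data: { voiceName: 'cosette' } });
//   worker.postMessage({ type: 'generate', data: { text: 'Hello there.', voice: 'cosette' } });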
// Load ONNX Runtime (will be loaded dynamically in loadModels for module worker)
let ort = null;
// Configuration
const MODELS = {
mimi_encoder: './onnx/mimi_encoder.onnx',
text_conditioner: './onnx/text_conditioner.onnx',
flow_lm_main: './onnx/flow_lm_main_int8.onnx',
flow_lm_flow: './onnx/flow_lm_flow_int8.onnx',
mimi_decoder: './onnx/mimi_decoder_int8.onnx',
tokenizer: './tokenizer.model',
voices: './voices.bin'
};
const SAMPLE_RATE = 24000;
const SAMPLES_PER_FRAME = 1920;
const MAX_FRAMES = 500;
const DEBUG_LOGS = false;
// Text chunking target; lower if long passages hit generation limits.
const CHUNK_TARGET_TOKENS = 50;
const CHUNK_GAP_SEC = 0.25;
// If true, reset the flow-LM state to the cached voice-conditioned baseline per chunk to avoid stale AR state.
const RESET_FLOW_STATE_EACH_CHUNK = true;
// If true, reset decoder state per chunk to avoid carry-over artifacts.
const RESET_MIMI_STATE_EACH_CHUNK = true;
// State
let mimiEncoderSession = null;
let textConditionerSession = null;
let flowLmMainSession = null;
let flowLmFlowSession = null;
let mimiDecoderSession = null;
let tokenizerProcessor = null;
let tokenizerModelB64 = null;
let predefinedVoices = {};
let stTensors = {}; // Optimization: pre-allocated s/t schedule tensors, keyed by LSD step count (filled in loadModels)
let isGenerating = false;
let isReady = false;
// Dynamic LSD (Latent Solver/Diffusion steps)
const MAX_LSD = 10; // Default/max quality
let currentLSD = MAX_LSD;
// Current voice embedding (cached)
let currentVoiceEmbedding = null;
let currentVoiceName = null;
let voiceConditioningCache = new Map();
// Text preprocessing utilities
const ONES = ['', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen'];
const TENS = ['', '', 'twenty', 'thirty', 'forty', 'fifty', 'sixty', 'seventy', 'eighty', 'ninety'];
const ORDINAL_ONES = ['', 'first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth', 'ninth', 'tenth', 'eleventh', 'twelfth', 'thirteenth', 'fourteenth', 'fifteenth', 'sixteenth', 'seventeenth', 'eighteenth', 'nineteenth'];
const ORDINAL_TENS = ['', '', 'twentieth', 'thirtieth', 'fortieth', 'fiftieth', 'sixtieth', 'seventieth', 'eightieth', 'ninetieth'];
function numberToWords(num, options = {}) {
const { andword = '', zero = 'zero', group = 0 } = options;
if (num === 0) return zero;
const convert = (n) => {
if (n < 20) return ONES[n];
if (n < 100) return TENS[Math.floor(n / 10)] + (n % 10 ? ' ' + ONES[n % 10] : '');
if (n < 1000) {
const remainder = n % 100;
return ONES[Math.floor(n / 100)] + ' hundred' + (remainder ? (andword ? ' ' + andword + ' ' : ' ') + convert(remainder) : '');
}
if (n < 1000000) {
const thousands = Math.floor(n / 1000);
const remainder = n % 1000;
return convert(thousands) + ' thousand' + (remainder ? ' ' + convert(remainder) : '');
}
if (n < 1000000000) {
const millions = Math.floor(n / 1000000);
const remainder = n % 1000000;
return convert(millions) + ' million' + (remainder ? ' ' + convert(remainder) : '');
}
const billions = Math.floor(n / 1000000000);
const remainder = n % 1000000000;
return convert(billions) + ' billion' + (remainder ? ' ' + convert(remainder) : '');
};
if (group === 2 && num > 1000 && num < 10000) {
const high = Math.floor(num / 100);
const low = num % 100;
if (low === 0) return convert(high) + ' hundred';
else if (low < 10) return convert(high) + ' ' + (zero === 'oh' ? 'oh' : zero) + ' ' + ONES[low];
else return convert(high) + ' ' + convert(low);
}
return convert(num);
}
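// Examples:
//   numberToWords(342)                            -> 'three hundred forty two'
//   numberToWords(1985, { zero: 'oh', group: 2 }) -> 'nineteen eighty five'
//   numberToWords(2005, { zero: 'oh', group: 2 }) -> 'twenty oh five'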
function ordinalToWords(num) {
if (num < 20) return ORDINAL_ONES[num] || numberToWords(num) + 'th';
if (num < 100) {
const tens = Math.floor(num / 10);
const ones = num % 10;
if (ones === 0) return ORDINAL_TENS[tens];
return TENS[tens] + ' ' + ORDINAL_ONES[ones];
}
const cardinal = numberToWords(num);
if (cardinal.endsWith('y')) return cardinal.slice(0, -1) + 'ieth';
if (cardinal.endsWith('one')) return cardinal.slice(0, -3) + 'first';
if (cardinal.endsWith('two')) return cardinal.slice(0, -3) + 'second';
if (cardinal.endsWith('three')) return cardinal.slice(0, -5) + 'third';
if (cardinal.endsWith('ve')) return cardinal.slice(0, -2) + 'fth';
if (cardinal.endsWith('e')) return cardinal.slice(0, -1) + 'th';
if (cardinal.endsWith('t')) return cardinal + 'h';
return cardinal + 'th';
}
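// Examples: ordinalToWords(3) -> 'third', ordinalToWords(21) -> 'twenty first',
// ordinalToWords(40) -> 'fortieth', ordinalToWords(112) -> 'one hundred twelfth'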
const UNICODE_MAP = {
'à': 'a', 'á': 'a', 'â': 'a', 'ã': 'a', 'ä': 'a', 'å': 'a', 'æ': 'ae', 'ç': 'c', 'è': 'e', 'é': 'e', 'ê': 'e', 'ë': 'e', 'ì': 'i', 'í': 'i', 'î': 'i', 'ï': 'i', 'ñ': 'n', 'ò': 'o', 'ó': 'o', 'ô': 'o', 'õ': 'o', 'ö': 'o', 'ø': 'o', 'ù': 'u', 'ú': 'u', 'û': 'u', 'ü': 'u', 'ý': 'y', 'ÿ': 'y', 'ß': 'ss', 'œ': 'oe', 'ð': 'd', 'þ': 'th', 'À': 'A', 'Á': 'A', 'Â': 'A', 'Ã': 'A', 'Ä': 'A', 'Å': 'A', 'Æ': 'AE', 'Ç': 'C', 'È': 'E', 'É': 'E', 'Ê': 'E', 'Ë': 'E', 'Ì': 'I', 'Í': 'I', 'Î': 'I', 'Ï': 'I', 'Ñ': 'N', 'Ò': 'O', 'Ó': 'O', 'Ô': 'O', 'Õ': 'O', 'Ö': 'O', 'Ø': 'O', 'Ù': 'U', 'Ú': 'U', 'Û': 'U', 'Ü': 'U', 'Ý': 'Y', '\u201C': '"', '\u201D': '"', '\u2018': "'", '\u2019': "'", '\u2026': '...', '\u2013': '-', '\u2014': '-'
};
function convertToAscii(text) {
return text.split('').map(c => UNICODE_MAP[c] || c).join('').normalize('NFD').replace(/[\u0300-\u036f]/g, '');
}
const ABBREVIATIONS = [
[/\bmrs\./gi, 'misuss'], [/\bms\./gi, 'miss'], [/\bmr\./gi, 'mister'], [/\bdr\./gi, 'doctor'], [/\bst\./gi, 'saint'], [/\bco\./gi, 'company'], [/\bjr\./gi, 'junior'], [/\bmaj\./gi, 'major'], [/\bgen\./gi, 'general'], [/\bdrs\./gi, 'doctors'], [/\brev\./gi, 'reverend'], [/\blt\./gi, 'lieutenant'], [/\bhon\./gi, 'honorable'], [/\bsgt\./gi, 'sergeant'], [/\bcapt\./gi, 'captain'], [/\besq\./gi, 'esquire'], [/\bltd\./gi, 'limited'], [/\bcol\./gi, 'colonel'], [/\bft\./gi, 'fort']
];
const CASED_ABBREVIATIONS = [
[/\bTTS\b/g, 'text to speech'], [/\bHz\b/g, 'hertz'], [/\bkHz\b/g, 'kilohertz'], [/\bKBs\b/g, 'kilobytes'], [/\bKB\b/g, 'kilobyte'], [/\bMBs\b/g, 'megabytes'], [/\bMB\b/g, 'megabyte'], [/\bGBs\b/g, 'gigabytes'], [/\bGB\b/g, 'gigabyte'], [/\bTBs\b/g, 'terabytes'], [/\bTB\b/g, 'terabyte'], [/\bAPIs\b/g, "a p i's"], [/\bAPI\b/g, 'a p i'], [/\bCLIs\b/g, "c l i's"], [/\bCLI\b/g, 'c l i'], [/\bCPUs\b/g, "c p u's"], [/\bCPU\b/g, 'c p u'], [/\bGPUs\b/g, "g p u's"], [/\bGPU\b/g, 'g p u'], [/\bAve\b/g, 'avenue'], [/\betc\b/g, 'etcetera']
];
function expandAbbreviations(text) {
for (const [regex, replacement] of [...ABBREVIATIONS, ...CASED_ABBREVIATIONS]) text = text.replace(regex, replacement);
return text;
}
const NUM_PREFIX_RE = /#(\d)/g;
const NUM_SUFFIX_RE = /(\d)([KMBT])/gi;
const NUM_LETTER_SPLIT_RE = /(\d)([a-z])|([a-z])(\d)/gi;
const COMMA_NUMBER_RE = /(\d[\d,]+\d)/g;
const DATE_RE = /(^|[^/])(\d\d?[/-]\d\d?[/-]\d\d(?:\d\d)?)($|[^/])/g;
const PHONE_NUMBER_RE = /\(?\d{3}\)?[-.\s]\d{3}[-.\s]?\d{4}/g;
const TIME_RE = /(\d\d?):(\d\d)(?::(\d\d))?/g;
const POUNDS_RE = /£([\d,]*\d+)/g;
const DOLLARS_RE = /\$([\d.,]*\d+)/g;
const DECIMAL_NUMBER_RE = /(\d+(?:\.\d+)+)/g;
const MULTIPLY_RE = /(\d)\s?\*\s?(\d)/g;
const DIVIDE_RE = /(\d)\s?\/\s?(\d)/g;
const ADD_RE = /(\d)\s?\+\s?(\d)/g;
const SUBTRACT_RE = /(\d)?\s?-\s?(\d)/g;
const FRACTION_RE = /(\d+)\/(\d+)/g;
const ORDINAL_RE = /(\d+)(st|nd|rd|th)/gi;
const NUMBER_RE = /\d+/g;
function normalizeNumbers(text) {
text = text.replace(NUM_PREFIX_RE, (_, d) => `number ${d}`);
text = text.replace(NUM_SUFFIX_RE, (_, num, suffix) => {
const map = { k: 'thousand', m: 'million', b: 'billion', t: 'trillion' };
return `${num} ${map[suffix.toLowerCase()]}`;
});
for (let i = 0; i < 2; i++) {
text = text.replace(NUM_LETTER_SPLIT_RE, (m, d1, l1, l2, d2) => {
if (d1 && l1) return `${d1} ${l1}`;
if (l2 && d2) return `${l2} ${d2}`;
return m;
});
}
text = text.replace(COMMA_NUMBER_RE, m => m.replace(/,/g, ''));
text = text.replace(DATE_RE, (_, pre, date, post) => pre + date.split(/[./-]/).join(' dash ') + post);
text = text.replace(PHONE_NUMBER_RE, m => {
const digits = m.replace(/\D/g, '');
return digits.length === 10 ? `${digits.slice(0, 3).split('').join(' ')}, ${digits.slice(3, 6).split('').join(' ')}, ${digits.slice(6).split('').join(' ')}` : m;
});
text = text.replace(TIME_RE, (_, hours, minutes, seconds) => {
const h = parseInt(hours), m = parseInt(minutes), s = seconds ? parseInt(seconds) : 0;
if (!seconds) return m === 0 ? (h === 0 ? '0' : h > 12 ? `${hours} minutes` : `${hours} o'clock`) : minutes.startsWith('0') ? `${hours} oh ${minutes[1]}` : `${hours} ${minutes}`;
let res = '';
if (h !== 0) res = hours + ' ' + (m === 0 ? 'oh oh' : minutes.startsWith('0') ? `oh ${minutes[1]}` : minutes);
else if (m !== 0) res = minutes + ' ' + (s === 0 ? 'oh oh' : seconds.startsWith('0') ? `oh ${seconds[1]}` : seconds);
else res = seconds;
// Seconds are appended here only for the H:MM:SS branch; the h === 0 branches above already include them.
if (h !== 0 && s !== 0) res += ' ' + (seconds.startsWith('0') ? `oh ${seconds[1]}` : seconds);
return res;
});
text = text.replace(POUNDS_RE, (_, amount) => `${amount.replace(/,/g, '')} pounds`);
text = text.replace(DOLLARS_RE, (_, amount) => {
const parts = amount.replace(/,/g, '').split('.');
const dollars = parseInt(parts[0]) || 0;
const cents = parts[1] ? parseInt(parts[1]) : 0;
if (dollars && cents) return `${dollars} ${dollars === 1 ? 'dollar' : 'dollars'}, ${cents} ${cents === 1 ? 'cent' : 'cents'}`;
if (dollars) return `${dollars} ${dollars === 1 ? 'dollar' : 'dollars'}`;
if (cents) return `${cents} ${cents === 1 ? 'cent' : 'cents'}`;
return 'zero dollars';
});
text = text.replace(DECIMAL_NUMBER_RE, m => m.split('.').join(' point ').split('').join(' '));
text = text.replace(MULTIPLY_RE, '$1 times $2');
text = text.replace(DIVIDE_RE, '$1 over $2');
text = text.replace(ADD_RE, '$1 plus $2');
text = text.replace(SUBTRACT_RE, (_, a, b) => (a ? a : '') + ' minus ' + b);
text = text.replace(FRACTION_RE, '$1 over $2');
text = text.replace(ORDINAL_RE, (_, num) => ordinalToWords(parseInt(num)));
text = text.replace(NUMBER_RE, m => {
const num = parseInt(m);
if (num > 1000 && num < 3000) {
if (num === 2000) return 'two thousand';
if (num > 2000 && num < 2010) return 'two thousand ' + numberToWords(num % 100);
if (num % 100 === 0) return numberToWords(Math.floor(num / 100)) + ' hundred';
return numberToWords(num, { zero: 'oh', group: 2 });
}
return numberToWords(num);
});
return text;
}
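// Examples (full pass through normalizeNumbers):
//   '3pm'   -> 'three pm'
//   '$4.50' -> 'four dollars, fifty cents'
//   '1984'  -> 'nineteen eighty four' (year-style reading for 1001-2999)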
const SPECIAL_CHARACTERS = [
[/@/g, ' at '], [/&/g, ' and '], [/%/g, ' percent '], [/:/g, '.'], [/;/g, ','], [/\+/g, ' plus '], [/\\/g, ' backslash '], [/~/g, ' about '], [/(^| )<3/g, ' heart '], [/<=/g, ' less than or equal to '], [/>=/g, ' greater than or equal to '], [/</g, ' less than '], [/>/g, ' greater than '], [/=/g, ' equals '], [/\//g, ' slash '], [/_/g, ' '],
];
const LINK_HEADER_RE = /https?:\/\//gi;
const DASH_RE = /(.) - (.)/g;
const DOT_RE = /([A-Z])\.([A-Z])/gi;
const PARENTHESES_RE = /[\(\[\{][^\)\]\}]*[\)\]\}](.)?/g;
function normalizeSpecial(text) {
text = text.replace(LINK_HEADER_RE, 'h t t p s colon slash slash ');
text = text.replace(DASH_RE, '$1, $2');
text = text.replace(DOT_RE, '$1 dot $2');
text = text.replace(PARENTHESES_RE, (m, after) => {
let result = m.replace(/[\(\[\{]/g, ', ').replace(/[\)\]\}]/g, ', ');
if (after && /[$.!?,]/.test(after)) result = result.slice(0, -2) + after;
return result;
});
return text;
}
function expandSpecialCharacters(text) {
for (const [regex, replacement] of SPECIAL_CHARACTERS) text = text.replace(regex, replacement);
return text;
}
function collapseWhitespace(text) {
return text.replace(/\s+/g, ' ').replace(/ ([.\?!,])/g, '$1');
}
function dedupPunctuation(text) {
return text.replace(/\.\.\.+/g, '[ELLIPSIS]').replace(/,+/g, ',').replace(/[.,]*\.[.,]*/g, '.').replace(/[.,!]*![.,!]*/g, '!').replace(/[.,!?]*\?[.,!?]*/g, '?').replace(/\[ELLIPSIS\]/g, '...');
}
const SENTENCE_SPLIT_RE = /[^.!?]+[.!?]+|[^.!?]+$/g;
function splitTextIntoSentences(text) {
const matches = text.match(SENTENCE_SPLIT_RE);
if (!matches) return [];
return matches.map(sentence => sentence.trim()).filter(Boolean);
}
function splitTokenIdsIntoChunks(tokenIds, maxTokens) {
const chunks = [];
for (let i = 0; i < tokenIds.length; i += maxTokens) {
const chunkText = tokenizerProcessor.decodeIds(tokenIds.slice(i, i + maxTokens)).trim();
if (chunkText) chunks.push(chunkText);
}
return chunks;
}
// Split text into sentence chunks (target <= CHUNK_TARGET_TOKENS tokens)
function splitIntoBestSentences(text) {
const preparedText = prepareText(text);
if (!preparedText) return [];
const sentences = splitTextIntoSentences(preparedText);
if (sentences.length === 0) return [];
// Merge sentences into chunks that stay within the token target
const chunks = [];
let currentChunk = '';
for (const sentenceText of sentences) {
const sentenceTokenIds = tokenizerProcessor.encodeIds(sentenceText);
const sentenceTokens = sentenceTokenIds.length;
if (sentenceTokens > CHUNK_TARGET_TOKENS) {
if (currentChunk !== '') {
chunks.push(currentChunk.trim());
currentChunk = '';
}
const splitChunks = splitTokenIdsIntoChunks(sentenceTokenIds, CHUNK_TARGET_TOKENS);
for (const splitChunk of splitChunks) {
if (splitChunk) chunks.push(splitChunk.trim());
}
continue;
}
if (currentChunk === '') {
currentChunk = sentenceText;
continue;
}
const combined = `${currentChunk} ${sentenceText}`;
const combinedTokens = tokenizerProcessor.encodeIds(combined).length;
if (combinedTokens > CHUNK_TARGET_TOKENS) {
chunks.push(currentChunk.trim());
currentChunk = sentenceText;
} else {
currentChunk = combined;
}
}
if (currentChunk !== '') {
chunks.push(currentChunk.trim());
}
return chunks;
}
// Pocket TTS specific text preprocessing
function prepareText(text) {
text = text.trim();
if (!text) return '';
// Convert to ASCII
text = convertToAscii(text);
// Normalize numbers first
text = normalizeNumbers(text);
// Normalize special characters
text = normalizeSpecial(text);
// Expand abbreviations
text = expandAbbreviations(text);
// Expand special characters
text = expandSpecialCharacters(text);
// Collapse whitespace
text = collapseWhitespace(text);
// Deduplicate punctuation
text = dedupPunctuation(text);
// Final cleanup
text = text.trim();
// Ensure proper punctuation at end
if (text && text[text.length - 1].match(/[a-zA-Z0-9]/)) {
text = text + '.';
}
// Capitalize first letter
if (text && !text[0].match(/[A-Z]/)) {
text = text[0].toUpperCase() + text.slice(1);
}
return text;
}
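// Example: prepareText('Dr. Smith paid $5 on 3/4')
//   -> 'Doctor Smith paid five dollars on three over four.'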
// ----------------------------------------------------------------------------
// Worker Logic
// ----------------------------------------------------------------------------
self.onmessage = async (e) => {
const { type, data } = e.data;
console.log('Worker received message:', type);
if (type === 'load') {
try {
await loadModels();
postMessage({ type: 'loaded' });
} catch (err) {
postMessage({ type: 'error', error: err.toString() });
}
} else if (type === 'generate') {
if (!isReady) {
postMessage({ type: 'error', error: 'Models are not loaded yet.' });
return;
}
if (isGenerating) return;
try {
await startGeneration(data.text, data.voice);
} catch (err) {
console.error('Generation Error:', err);
postMessage({ type: 'error', error: err.toString() });
}
} else if (type === 'encode_voice') {
if (!isReady) {
postMessage({ type: 'error', error: 'Models are not loaded yet.' });
return;
}
if (isGenerating) {
postMessage({ type: 'error', error: 'Cannot encode a voice while generation is running.' });
return;
}
try {
const embedding = await encodeVoiceAudio(data.audio);
currentVoiceEmbedding = embedding;
currentVoiceName = 'custom';
await ensureVoiceConditioningCached('custom', embedding, {
force: true,
statusText: 'Conditioning custom voice...'
});
postMessage({ type: 'voice_encoded', voiceName: 'custom' });
postMessage({ type: 'status', status: 'Ready', state: 'idle' });
} catch (err) {
console.error('Voice encoding error:', err);
postMessage({ type: 'error', error: 'Failed to encode voice: ' + err.toString() });
}
} else if (type === 'set_voice') {
if (!isReady) {
postMessage({ type: 'error', error: 'Models are not loaded yet.' });
return;
}
if (isGenerating) {
postMessage({ type: 'error', error: 'Cannot switch voice while generation is running.' });
return;
}
try {
if (data.voiceName === 'custom') {
if (!currentVoiceEmbedding || currentVoiceName !== 'custom') {
postMessage({ type: 'error', error: 'No custom voice loaded. Upload audio first.' });
return;
}
await ensureVoiceConditioningCached('custom', currentVoiceEmbedding, {
statusText: 'Conditioning custom voice...'
});
postMessage({ type: 'voice_set', voiceName: 'custom' });
} else if (predefinedVoices[data.voiceName]) {
currentVoiceEmbedding = predefinedVoices[data.voiceName];
currentVoiceName = data.voiceName;
await ensureVoiceConditioningCached(data.voiceName, currentVoiceEmbedding, {
statusText: `Conditioning voice (${data.voiceName})...`
});
postMessage({ type: 'voice_set', voiceName: data.voiceName });
} else {
postMessage({ type: 'error', error: `Unknown voice: ${data.voiceName}` });
return;
}
postMessage({ type: 'status', status: 'Ready', state: 'idle' });
} catch (err) {
console.error('Voice switch error:', err);
postMessage({ type: 'error', error: 'Failed to set voice: ' + err.toString() });
}
} else if (type === 'set_lsd') {
// Dynamic LSD adjustment for edge devices
const newLSD = Math.max(1, Math.min(MAX_LSD, data.lsd));
if (newLSD !== currentLSD) {
console.log(`LSD adjusted: ${currentLSD} -> ${newLSD}`);
currentLSD = newLSD;
}
} else if (type === 'stop') {
isGenerating = false;
postMessage({ type: 'status', status: 'Stopped', state: 'idle' });
}
};
async function loadModels() {
if (mimiEncoderSession) return;
postMessage({ type: 'status', status: 'Loading ONNX Runtime...', state: 'loading' });
// Load ONNX Runtime dynamically
const version = '1.20.0';
const cdnBase = `https://cdn.jsdelivr.net/npm/onnxruntime-web@${version}/dist/`;
try {
const ortModule = await import(`https://cdn.jsdelivr.net/npm/onnxruntime-web@${version}/dist/ort.min.mjs`);
ort = ortModule.default || ortModule;
} catch (e) {
console.error('Failed to load ONNX Runtime:', e);
throw new Error('Failed to load ONNX Runtime: ' + e.message);
}
if (!ort) {
throw new Error('ONNX Runtime failed to load');
}
postMessage({ type: 'status', status: 'Loading models...', state: 'loading' });
// Configure WASM Paths
ort.env.wasm.wasmPaths = cdnBase;
// Enable SIMD for significant performance boost (2-4x faster)
ort.env.wasm.simd = true;
// Configure multi-threading
if (!self.crossOriginIsolated) {
console.warn('Environment is not cross-origin isolated. Disabling WASM multi-threading.');
console.warn('To enable multi-threading, serve with headers:');
console.warn(' Cross-Origin-Opener-Policy: same-origin');
console.warn(' Cross-Origin-Embedder-Policy: require-corp');
ort.env.wasm.numThreads = 1;
} else {
const threads = Math.min(navigator.hardwareConcurrency || 4, 8);
ort.env.wasm.numThreads = threads;
if (DEBUG_LOGS) {
console.log(`Multi-threading enabled with ${threads} threads`);
}
}
console.log(`ORT: crossOriginIsolated=${self.crossOriginIsolated}, simd=${ort.env.wasm.simd}, threads=${ort.env.wasm.numThreads}`);
try {
const sessionOptions = {
executionProviders: ['wasm'],
graphOptimizationLevel: 'all'
};
// Load all models in parallel
postMessage({ type: 'status', status: 'Loading MIMI encoder...', state: 'loading' });
if (DEBUG_LOGS) {
console.log('Loading MIMI encoder...');
}
const [encoderRes, textCondRes, flowMainRes, flowFlowRes, decoderRes] = await Promise.all([
ort.InferenceSession.create(MODELS.mimi_encoder, sessionOptions),
ort.InferenceSession.create(MODELS.text_conditioner, sessionOptions),
ort.InferenceSession.create(MODELS.flow_lm_main, sessionOptions),
ort.InferenceSession.create(MODELS.flow_lm_flow, sessionOptions),
ort.InferenceSession.create(MODELS.mimi_decoder, sessionOptions)
]);
mimiEncoderSession = encoderRes;
textConditionerSession = textCondRes;
flowLmMainSession = flowMainRes;
flowLmFlowSession = flowFlowRes;
mimiDecoderSession = decoderRes;
if (DEBUG_LOGS) {
console.log('All models loaded successfully');
console.log('Flow LM Main inputs:', flowLmMainSession.inputNames);
console.log('Flow LM Main outputs:', flowLmMainSession.outputNames);
console.log('MIMI decoder inputs:', mimiDecoderSession.inputNames);
console.log('MIMI decoder outputs:', mimiDecoderSession.outputNames);
}
// Load tokenizer
postMessage({ type: 'status', status: 'Loading tokenizer...', state: 'loading' });
if (DEBUG_LOGS) {
console.log('Loading tokenizer...');
}
const tokenizerResponse = await fetch(MODELS.tokenizer);
if (!tokenizerResponse.ok) {
throw new Error(`Failed to load tokenizer: ${tokenizerResponse.statusText}`);
}
const tokenizerBuffer = await tokenizerResponse.arrayBuffer();
// Base64-encode in 32 KB chunks; spreading one huge buffer into String.fromCharCode can overflow the call stack.
const tkBytes = new Uint8Array(tokenizerBuffer);
let tkBinary = '';
for (let i = 0; i < tkBytes.length; i += 0x8000) tkBinary += String.fromCharCode(...tkBytes.subarray(i, i + 0x8000));
tokenizerModelB64 = btoa(tkBinary);
// Import and initialize sentencepiece processor
const spModule = await import('./sentencepiece.js?v=2');
const SentencePieceProcessor = spModule.SentencePieceProcessor;
if (!SentencePieceProcessor) {
throw new Error('SentencePieceProcessor not found in sentencepiece.js');
}
tokenizerProcessor = new SentencePieceProcessor();
await tokenizerProcessor.loadFromB64StringModel(tokenizerModelB64);
if (DEBUG_LOGS) {
console.log('Tokenizer loaded');
}
// Load predefined voices
postMessage({ type: 'status', status: 'Loading voices...', state: 'loading' });
if (DEBUG_LOGS) {
console.log('Loading predefined voices...');
}
try {
const voicesResponse = await fetch(MODELS.voices);
if (voicesResponse.ok) {
const voicesData = await voicesResponse.arrayBuffer();
predefinedVoices = parseVoicesBin(voicesData);
if (DEBUG_LOGS) {
console.log('Loaded voices:', Object.keys(predefinedVoices));
}
// Set default voice
if (predefinedVoices['cosette']) {
currentVoiceEmbedding = predefinedVoices['cosette'];
currentVoiceName = 'cosette';
} else {
// Use first available voice
const firstVoice = Object.keys(predefinedVoices)[0];
if (firstVoice) {
currentVoiceEmbedding = predefinedVoices[firstVoice];
currentVoiceName = firstVoice;
}
}
}
} catch (e) {
console.warn('Could not load predefined voices:', e);
}
if (currentVoiceEmbedding && currentVoiceName) {
await ensureVoiceConditioningCached(currentVoiceName, currentVoiceEmbedding, {
force: true,
statusText: `Loading voice conditioning (${currentVoiceName})...`
});
}
// Send list of available voices
postMessage({
type: 'voices_loaded',
voices: Object.keys(predefinedVoices),
defaultVoice: currentVoiceName
});
// Pre-allocate s/t tensors for Flow Matching Loop (Optimization)
// Pre-allocate for MAX_LSD to support dynamic switching
if (DEBUG_LOGS) {
console.log(`Pre-allocating Flow Matching tensors for LSD 1-${MAX_LSD}...`);
}
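// Each LSD value gets a uniform Euler schedule over flow time [0, 1]; e.g. lsd = 4
// yields (s, t) pairs (0, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1).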
stTensors = {};
for (let lsd = 1; lsd <= MAX_LSD; lsd++) {
stTensors[lsd] = [];
const dt = 1.0 / lsd;
for (let j = 0; j < lsd; j++) {
const s = j / lsd;
const t = s + dt;
stTensors[lsd].push({
s: new ort.Tensor('float32', new Float32Array([s]), [1, 1]),
t: new ort.Tensor('float32', new Float32Array([t]), [1, 1])
});
}
}
isReady = true;
postMessage({ type: 'status', status: 'Ready', state: 'idle' });
postMessage({ type: 'model_status', status: 'ready', text: 'Ready' });
postMessage({ type: 'loaded' });
} catch (err) {
console.error('Model load failed:', err);
throw err;
}
}
function parseVoicesBin(buffer) {
// Simple binary format:
// Header: 4 bytes (uint32) = number of voices
// For each voice:
// - 32 bytes: voice name (null-terminated string)
// - 4 bytes (uint32): number of frames
// - 4 bytes (uint32): embedding dim (1024)
// - frames * dim * 4 bytes: float32 embeddings
const voices = {};
const view = new DataView(buffer);
let offset = 0;
const numVoices = view.getUint32(offset, true);
offset += 4;
for (let i = 0; i < numVoices; i++) {
// Read voice name
const nameBytes = new Uint8Array(buffer, offset, 32);
const nameEnd = nameBytes.indexOf(0);
const name = new TextDecoder().decode(nameBytes.subarray(0, nameEnd > 0 ? nameEnd : 32)).trim();
offset += 32;
// Read dimensions
const numFrames = view.getUint32(offset, true);
offset += 4;
const embDim = view.getUint32(offset, true);
offset += 4;
// Read embeddings
const embSize = numFrames * embDim;
const embeddings = new Float32Array(buffer, offset, embSize);
offset += embSize * 4;
// Store as [1, numFrames, embDim] shaped array info
voices[name] = {
data: new Float32Array(embeddings),
shape: [1, numFrames, embDim]
};
console.log(`Loaded voice '${name}': ${numFrames} frames, ${embDim} dim`);
}
return voices;
}
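// Worked size example (hypothetical voice): one entry with 125 frames at dim 1024
// occupies 32 + 4 + 4 + 125 * 1024 * 4 = 512,040 bytes after the 4-byte header.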
async function encodeVoiceAudio(audioData) {
// audioData should be Float32Array at 24kHz, mono
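// A mono 24 kHz clip of N samples should produce roughly N / 1920 embedding
// frames; the predefined voices in voices.bin use the same [1, frames, dim] layout.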
// Reshape to [1, 1, samples]
const input = new ort.Tensor('float32', audioData, [1, 1, audioData.length]);
const outputs = await mimiEncoderSession.run({ audio: input });
const embeddings = outputs[mimiEncoderSession.outputNames[0]];
return {
data: new Float32Array(embeddings.data),
shape: embeddings.dims
};
}
async function buildVoiceConditionedState(voiceEmb) {
const flowLmState = initState(flowLmMainSession, FLOW_LM_STATE_SHAPES);
const emptySeq = new ort.Tensor('float32', new Float32Array(0), [1, 0, 32]);
const voiceTensor = new ort.Tensor('float32', voiceEmb.data, voiceEmb.shape);
const voiceCondInputs = {
sequence: emptySeq,
text_embeddings: voiceTensor,
...flowLmState
};
const condResult = await flowLmMainSession.run(voiceCondInputs);
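// Outputs 0 and 1 are the non-state outputs (conditioning and eos_logit); only
// the out_state_* outputs are folded back into the running flow-LM state.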
for (let i = 2; i < flowLmMainSession.outputNames.length; i++) {
const outputName = flowLmMainSession.outputNames[i];
if (outputName.startsWith('out_state_')) {
const stateIdx = parseInt(outputName.replace('out_state_', ''));
flowLmState[`state_${stateIdx}`] = condResult[outputName];
}
}
return flowLmState;
}
function cloneFlowState(baseState) {
// Shallow clone is enough: we only replace tensor refs in the local state map.
return { ...baseState };
}
async function ensureVoiceConditioningCached(voiceName, voiceEmb, options = {}) {
const { force = false, statusText = 'Conditioning voice...' } = options;
if (!voiceName) {
throw new Error('Cannot cache voice conditioning without a voice name.');
}
if (!voiceEmb) {
throw new Error(`Cannot cache voice conditioning for '${voiceName}' without embeddings.`);
}
if (!force && voiceConditioningCache.has(voiceName)) {
console.log(`[voice-conditioning] ready for '${voiceName}' (cache hit)`);
return voiceConditioningCache.get(voiceName);
}
postMessage({ type: 'status', status: statusText, state: 'loading' });
const startMs = performance.now();
const conditionedState = await buildVoiceConditionedState(voiceEmb);
voiceConditioningCache.set(voiceName, conditionedState);
const elapsedMs = performance.now() - startMs;
console.log(`[voice-conditioning] completed for '${voiceName}' in ${elapsedMs.toFixed(0)}ms`);
return conditionedState;
}
// Hardcoded state shapes extracted from ONNX model metadata
// These are the initial shapes - dynamic dimensions start at 0
const FLOW_LM_STATE_SHAPES = {
// KV cache layers: [kv=2, batch=1, max_seq=1000, heads=16, head_dim=64]
state_0: { shape: [2, 1, 1000, 16, 64], dtype: 'float32' },
state_1: { shape: [0], dtype: 'float32' }, // dynamic
state_2: { shape: [1], dtype: 'int64' }, // step counter
state_3: { shape: [2, 1, 1000, 16, 64], dtype: 'float32' },
state_4: { shape: [0], dtype: 'float32' },
state_5: { shape: [1], dtype: 'int64' },
state_6: { shape: [2, 1, 1000, 16, 64], dtype: 'float32' },
state_7: { shape: [0], dtype: 'float32' },
state_8: { shape: [1], dtype: 'int64' },
state_9: { shape: [2, 1, 1000, 16, 64], dtype: 'float32' },
state_10: { shape: [0], dtype: 'float32' },
state_11: { shape: [1], dtype: 'int64' },
state_12: { shape: [2, 1, 1000, 16, 64], dtype: 'float32' },
state_13: { shape: [0], dtype: 'float32' },
state_14: { shape: [1], dtype: 'int64' },
state_15: { shape: [2, 1, 1000, 16, 64], dtype: 'float32' },
state_16: { shape: [0], dtype: 'float32' },
state_17: { shape: [1], dtype: 'int64' },
};
const MIMI_DECODER_STATE_SHAPES = {
state_0: { shape: [1], dtype: 'bool' },
state_1: { shape: [1, 512, 6], dtype: 'float32' },
state_2: { shape: [1], dtype: 'bool' },
state_3: { shape: [1, 64, 2], dtype: 'float32' },
state_4: { shape: [1, 256, 6], dtype: 'float32' },
state_5: { shape: [1], dtype: 'bool' },
state_6: { shape: [1, 256, 2], dtype: 'float32' },
state_7: { shape: [1], dtype: 'bool' },
state_8: { shape: [1, 128, 0], dtype: 'float32' }, // dynamic
state_9: { shape: [1, 128, 5], dtype: 'float32' },
state_10: { shape: [1], dtype: 'bool' },
state_11: { shape: [1, 128, 2], dtype: 'float32' },
state_12: { shape: [1], dtype: 'bool' },
state_13: { shape: [1, 64, 0], dtype: 'float32' }, // dynamic
state_14: { shape: [1, 64, 4], dtype: 'float32' },
state_15: { shape: [1], dtype: 'bool' },
state_16: { shape: [1, 64, 2], dtype: 'float32' },
state_17: { shape: [1], dtype: 'bool' },
state_18: { shape: [1, 32, 0], dtype: 'float32' }, // dynamic
state_19: { shape: [2, 1, 8, 1000, 64], dtype: 'float32' },
state_20: { shape: [1], dtype: 'int64' },
state_21: { shape: [1], dtype: 'int64' },
state_22: { shape: [2, 1, 8, 1000, 64], dtype: 'float32' },
state_23: { shape: [1], dtype: 'int64' },
state_24: { shape: [1], dtype: 'int64' },
state_25: { shape: [1], dtype: 'bool' },
state_26: { shape: [1, 512, 16], dtype: 'float32' },
state_27: { shape: [1], dtype: 'bool' },
state_28: { shape: [1, 1, 6], dtype: 'float32' },
state_29: { shape: [1], dtype: 'bool' },
state_30: { shape: [1, 64, 2], dtype: 'float32' },
state_31: { shape: [1], dtype: 'bool' },
state_32: { shape: [1, 32, 0], dtype: 'float32' }, // dynamic
state_33: { shape: [1], dtype: 'bool' },
state_34: { shape: [1, 512, 2], dtype: 'float32' },
state_35: { shape: [1], dtype: 'bool' },
state_36: { shape: [1, 64, 4], dtype: 'float32' },
state_37: { shape: [1], dtype: 'bool' },
state_38: { shape: [1, 128, 2], dtype: 'float32' },
state_39: { shape: [1], dtype: 'bool' },
state_40: { shape: [1, 64, 0], dtype: 'float32' }, // dynamic
state_41: { shape: [1], dtype: 'bool' },
state_42: { shape: [1, 128, 5], dtype: 'float32' },
state_43: { shape: [1], dtype: 'bool' },
state_44: { shape: [1, 256, 2], dtype: 'float32' },
state_45: { shape: [1], dtype: 'bool' },
state_46: { shape: [1, 128, 0], dtype: 'float32' }, // dynamic
state_47: { shape: [1], dtype: 'bool' },
state_48: { shape: [1, 256, 6], dtype: 'float32' },
state_49: { shape: [2, 1, 8, 1000, 64], dtype: 'float32' },
state_50: { shape: [1], dtype: 'int64' },
state_51: { shape: [1], dtype: 'int64' },
state_52: { shape: [2, 1, 8, 1000, 64], dtype: 'float32' },
state_53: { shape: [1], dtype: 'int64' },
state_54: { shape: [1], dtype: 'int64' },
state_55: { shape: [1, 512, 16], dtype: 'float32' },
};
function initState(session, stateShapes) {
/**
* Initialize state tensors for a stateful ONNX model using hardcoded shapes.
*/
const state = {};
for (const inputName of session.inputNames) {
if (inputName.startsWith('state_')) {
const stateInfo = stateShapes[inputName];
if (!stateInfo) {
console.warn(`Unknown state input: ${inputName}, skipping`);
continue;
}
const { shape, dtype } = stateInfo;
const size = shape.reduce((a, b) => a * b, 1);
let data;
if (dtype === 'int64') {
data = new BigInt64Array(size);
} else if (dtype === 'bool') {
data = new Uint8Array(size);
} else {
data = new Float32Array(size);
}
state[inputName] = new ort.Tensor(dtype, data, shape);
if (DEBUG_LOGS) {
console.log(`Init state ${inputName}: shape=${JSON.stringify(shape)}, dtype=${dtype}`);
}
}
}
return state;
}
async function startGeneration(text, voiceName) {
isGenerating = true;
currentLSD = MAX_LSD; // Reset to max quality for each new generation
postMessage({ type: 'status', status: 'Generating...', state: 'running' });
postMessage({ type: 'generation_started', data: { time: performance.now() } });
try {
// Split text into sentence chunks (target <= CHUNK_TARGET_TOKENS tokens)
const chunks = splitIntoBestSentences(text);
console.log(`Split into ${chunks.length} chunks:`, chunks);
if (chunks.length === 0) {
throw new Error('No text to generate');
}
// Resolve voice
let resolvedVoiceName = currentVoiceName;
if (voiceName && voiceName !== currentVoiceName) {
if (predefinedVoices[voiceName]) {
currentVoiceEmbedding = predefinedVoices[voiceName];
currentVoiceName = voiceName;
resolvedVoiceName = voiceName;
await ensureVoiceConditioningCached(resolvedVoiceName, currentVoiceEmbedding, {
statusText: `Conditioning voice (${resolvedVoiceName})...`
});
} else if (voiceName === 'custom' && currentVoiceName === 'custom') {
resolvedVoiceName = 'custom';
}
}
if (!currentVoiceEmbedding || !resolvedVoiceName) {
throw new Error('No voice embedding available. Please select a voice or upload custom audio.');
}
if (!voiceConditioningCache.has(resolvedVoiceName)) {
throw new Error(`Voice conditioning cache missing for '${resolvedVoiceName}'. Switch voices to prepare cache.`);
}
// Run generation pipeline with chunks
await runGenerationPipeline(resolvedVoiceName, chunks);
} catch (err) {
console.error('Generation error:', err);
postMessage({ type: 'error', error: err.toString() });
} finally {
if (isGenerating) {
postMessage({ type: 'stream_ended' });
postMessage({ type: 'status', status: 'Finished', state: 'idle' });
}
isGenerating = false;
}
}
async function runGenerationPipeline(voiceName, chunks) {
// Initialize state - may be reset per chunk
let mimiState = initState(mimiDecoderSession, MIMI_DECODER_STATE_SHAPES);
const emptySeq = new ort.Tensor('float32', new Float32Array(0), [1, 0, 32]);
const emptyTextEmb = new ort.Tensor('float32', new Float32Array(0), [1, 0, 1024]);
const baseFlowState = voiceConditioningCache.get(voiceName);
if (!baseFlowState) {
throw new Error(`Voice conditioning cache missing for '${voiceName}'.`);
}
let flowLmState = cloneFlowState(baseFlowState);
// Streaming parameters
const FIRST_CHUNK_FRAMES = 3;
const NORMAL_CHUNK_FRAMES = 12;
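// At SAMPLES_PER_FRAME = 1920 and 24 kHz, one frame is 80 ms of audio, so the
// first decode fires after ~240 ms of generated audio and later ones every ~960 ms.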
// Tracking across all chunks
const allGeneratedLatents = [];
let isFirstAudioChunk = true;
let totalDecodedFrames = 0;
let totalFlowLmTime = 0;
let totalDecodeTime = 0;
const arStartTime = performance.now();
// Process each text chunk
for (let chunkIdx = 0; chunkIdx < chunks.length; chunkIdx++) {
if (!isGenerating) break;
if (RESET_FLOW_STATE_EACH_CHUNK && chunkIdx > 0) {
flowLmState = cloneFlowState(baseFlowState);
}
if (RESET_MIMI_STATE_EACH_CHUNK && chunkIdx > 0) {
mimiState = initState(mimiDecoderSession, MIMI_DECODER_STATE_SHAPES);
}
const chunkText = chunks[chunkIdx];
console.log(`Processing chunk ${chunkIdx + 1}/${chunks.length}: "${chunkText}"`);
let isFirstAudioChunkOfTextChunk = true;
// Tokenize this chunk
const tokenIds = tokenizerProcessor.encodeIds(chunkText);
console.log(`Chunk ${chunkIdx + 1} tokens:`, tokenIds.length);
// Text conditioning for this chunk
const textInput = new ort.Tensor('int64', BigInt64Array.from(tokenIds.map(x => BigInt(x))), [1, tokenIds.length]);
const textCondResult = await textConditionerSession.run({ token_ids: textInput });
let textEmb = textCondResult[textConditionerSession.outputNames[0]];
if (textEmb.dims.length === 2) {
textEmb = new ort.Tensor('float32', textEmb.data, [1, textEmb.dims[0], textEmb.dims[1]]);
}
const textCondInputs = {
sequence: emptySeq,
text_embeddings: textEmb,
...flowLmState
};
let condResult = await flowLmMainSession.run(textCondInputs);
// Update state from text conditioning
for (let i = 2; i < flowLmMainSession.outputNames.length; i++) {
const outputName = flowLmMainSession.outputNames[i];
if (outputName.startsWith('out_state_')) {
const stateIdx = parseInt(outputName.replace('out_state_', ''));
flowLmState[`state_${stateIdx}`] = condResult[outputName];
}
}
// AR generation for this chunk
const chunkLatents = [];
let currentLatent = new ort.Tensor('float32', new Float32Array(32).fill(NaN), [1, 1, 32]);
let chunkDecodedFrames = 0;
const FRAMES_AFTER_EOS = 3; // Match PyTorch behavior - generate extra frames after EOS
let eosStep = null;
let chunkEnded = false;
let chunkGenTimeMs = 0;
for (let step = 0; step < MAX_FRAMES; step++) {
if (!isGenerating) break;
// Yield every 4 steps to allow message processing (e.g., set_lsd)
if (step > 0 && step % 4 === 0) {
await new Promise(r => setTimeout(r, 0));
}
const arInputs = {
sequence: currentLatent,
text_embeddings: emptyTextEmb,
...flowLmState
};
const stepStart = performance.now();
const arResult = await flowLmMainSession.run(arInputs);
const conditioning = arResult['conditioning'];
const eosLogit = arResult['eos_logit'].data[0];
const isEos = eosLogit > -4.0;
// Track when EOS is first detected
if (isEos && eosStep === null) {
eosStep = step;
}
// Only stop after FRAMES_AFTER_EOS additional frames
const shouldStop = eosStep !== null && step >= eosStep + FRAMES_AFTER_EOS;
// Flow matching (LSD loop) - uses currentLSD which can be adjusted dynamically
const TEMP = 0.7;
const STD = Math.sqrt(TEMP);
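// Box-Muller transform: sample the initial noise x ~ N(0, TEMP * I) for the flow loop.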
let xData = new Float32Array(32);
for (let i = 0; i < 32; i++) {
let u = 0, v = 0;
while (u === 0) u = Math.random();
while (v === 0) v = Math.random();
xData[i] = Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v) * STD;
}
const lsdSteps = currentLSD;
const dt = 1.0 / lsdSteps;
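// One Euler step per (s, t) pair of the pre-built schedule: the flow network
// predicts flow_dir and x moves by flow_dir * dt, integrating noise into a latent.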
for (let j = 0; j < lsdSteps; j++) {
const flowInputs = {
c: conditioning,
s: stTensors[lsdSteps][j].s,
t: stTensors[lsdSteps][j].t,
x: new ort.Tensor('float32', xData, [1, 32])
};
const flowResult = await flowLmFlowSession.run(flowInputs);
const v = flowResult['flow_dir'].data;
for (let k = 0; k < 32; k++) {
xData[k] += v[k] * dt;
}
}
// Count the AR step plus the flow-matching loop as generation time (the log below labels this 'Flow LM')
const stepElapsed = performance.now() - stepStart;
chunkGenTimeMs += stepElapsed;
totalFlowLmTime += stepElapsed;
const latentData = xData;
chunkLatents.push(new Float32Array(latentData));
allGeneratedLatents.push(new Float32Array(latentData));
// Update state
currentLatent = new ort.Tensor('float32', latentData, [1, 1, 32]);
for (let i = 2; i < flowLmMainSession.outputNames.length; i++) {
const outputName = flowLmMainSession.outputNames[i];
if (outputName.startsWith('out_state_')) {
const stateIdx = parseInt(outputName.replace('out_state_', ''));
flowLmState[`state_${stateIdx}`] = arResult[outputName];
}
}
// Decode audio chunks
const pending = chunkLatents.length - chunkDecodedFrames;
let decodeSize = 0;
if (shouldStop) {
decodeSize = pending;
} else if (isFirstAudioChunk && pending >= FIRST_CHUNK_FRAMES) {
decodeSize = FIRST_CHUNK_FRAMES;
} else if (pending >= NORMAL_CHUNK_FRAMES) {
decodeSize = NORMAL_CHUNK_FRAMES;
}
if (decodeSize > 0) {
const decodeLatents = new Float32Array(decodeSize * 32);
for (let i = 0; i < decodeSize; i++) {
decodeLatents.set(chunkLatents[chunkDecodedFrames + i], i * 32);
}
const latentTensor = new ort.Tensor('float32', decodeLatents, [1, decodeSize, 32]);
const decodeInputs = { latent: latentTensor, ...mimiState };
const decStart = performance.now();
const decodeResult = await mimiDecoderSession.run(decodeInputs);
const decElapsed = performance.now() - decStart;
totalDecodeTime += decElapsed;
chunkGenTimeMs += decElapsed;
const audioChunk = decodeResult[mimiDecoderSession.outputNames[0]].data;
// Update MIMI state
for (let i = 1; i < mimiDecoderSession.outputNames.length; i++) {
const outputName = mimiDecoderSession.outputNames[i];
const stateIdx = i - 1;
mimiState[`state_${stateIdx}`] = decodeResult[outputName];
}
chunkDecodedFrames += decodeSize;
totalDecodedFrames += decodeSize;
const audioFloat32 = new Float32Array(audioChunk);
const isLastChunk = shouldStop && chunkIdx === chunks.length - 1;
postMessage({
type: 'audio_chunk',
data: audioFloat32,
metrics: {
bbTime: 0,
decTime: 0,
chunkDuration: audioFloat32.length / SAMPLE_RATE,
genTimeSec: chunkGenTimeMs / 1000,
isFirst: isFirstAudioChunk,
isLast: isLastChunk,
chunkStart: isFirstAudioChunkOfTextChunk
}
}, [audioFloat32.buffer]);
isFirstAudioChunk = false;
isFirstAudioChunkOfTextChunk = false;
chunkGenTimeMs = 0;
}
if (shouldStop) {
console.log(`Chunk ${chunkIdx + 1} EOS at step ${eosStep}, stopped at step ${step}, ${chunkLatents.length} frames`);
chunkEnded = true;
break;
}
}
if (chunkEnded && isGenerating && chunkIdx < chunks.length - 1) {
const gapSamples = Math.max(1, Math.floor(CHUNK_GAP_SEC * SAMPLE_RATE));
const silence = new Float32Array(gapSamples);
postMessage({
type: 'audio_chunk',
data: silence,
metrics: {
bbTime: 0,
decTime: 0,
chunkDuration: gapSamples / SAMPLE_RATE,
isFirst: false,
isLast: false,
isSilence: true
}
}, [silence.buffer]);
}
}
const totalTime = (performance.now() - arStartTime) / 1000;
const audioSeconds = allGeneratedLatents.length * SAMPLES_PER_FRAME / SAMPLE_RATE;
// RTFx based on actual generation time (flow LM + decoder), not including conditioning
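// e.g. 10 s of audio generated in 4 s of compute gives RTFx = 2.5; values > 1 are faster than real time.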
const genTime = (totalFlowLmTime + totalDecodeTime) / 1000;
const rtfx = audioSeconds / genTime;
console.log(`Generation complete: ${allGeneratedLatents.length} frames (${audioSeconds.toFixed(2)}s audio)`);
console.log(` Total time: ${totalTime.toFixed(2)}s`);
console.log(` Gen time: ${genTime.toFixed(2)}s, RTFx: ${rtfx.toFixed(2)}x`);
console.log(` Flow LM: ${(totalFlowLmTime / 1000).toFixed(2)}s (${(totalFlowLmTime / allGeneratedLatents.length).toFixed(1)}ms/step)`);
console.log(` Decoder: ${(totalDecodeTime / 1000).toFixed(2)}s`);
postMessage({
type: 'status',
status: `Finished (RTFx: ${rtfx.toFixed(2)}x)`,
state: 'idle',
metrics: { rtfx, genTime, totalTime, audioDuration: audioSeconds }
});
}
// Pre-allocated buffers for step counter updates (avoid GC pressure in hot loop)
const stepBuffers = {};
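// Note: updateStateSteps is not currently invoked anywhere in this worker file.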
function updateStateSteps(state, increment) {
// Update step counters in state dict - reuse buffers to avoid allocation
const incBigInt = BigInt(increment);
for (const key in state) {
if (key.includes('step') && state[key]) {
const tensor = state[key];
if (tensor.data instanceof BigInt64Array) {
// Reuse buffer if same size, otherwise create new one
if (!stepBuffers[key] || stepBuffers[key].length !== tensor.data.length) {
stepBuffers[key] = new BigInt64Array(tensor.data.length);
}
const buf = stepBuffers[key];
for (let i = 0; i < tensor.data.length; i++) {
buf[i] = tensor.data[i] + incBigInt;
}
state[key] = new ort.Tensor('int64', buf, tensor.dims);
}
}
}
}