import {
  createElement,
  useRef,
  useState,
  useCallback,
  type ReactNode,
} from "react";
import {
  pipeline,
  TextStreamer,
  type AutomaticSpeechRecognitionPipeline,
  type AutomaticSpeechRecognitionOutput,
} from "@huggingface/transformers";
import {
  TranscriberContext,
  type TranscriberState,
} from "./transcriberContext.ts";

/** Model identifier handed to `pipeline()` when the transcriber is loaded. */
const MODEL_ID = "onnx-community/cohere-transcribe-03-2026-ONNX";

/**
 * Context provider that owns the speech-recognition pipeline and exposes its
 * loading state plus `load` / `transcribe` callbacks through
 * `TranscriberContext`.
 *
 * The pipeline instance and the in-flight load promise live in refs so that
 * the callbacks keep a stable identity and do not trigger re-renders.
 *
 * @param children - Subtree rendered inside the provider.
 */
export function TranscriberProvider({ children }: { children: ReactNode }) {
  // NOTE(review): `TranscriberState["status"]` assumes the context type has a
  // `status` member covering "idle" | "loading" | "ready" | "error" — confirm
  // against ./transcriberContext.ts.
  const [status, setStatus] = useState<TranscriberState["status"]>("idle");
  const [error, setError] = useState<string | null>(null);
  const [progress, setProgress] = useState(0);
  const [statusText, setStatusText] = useState("Initializing...");

  const pipelineRef = useRef<AutomaticSpeechRecognitionPipeline | null>(null);
  const loadingRef = useRef<Promise<void> | null>(null);

  /**
   * Loads the model once. Concurrent callers share the same in-flight
   * promise; once the pipeline exists the call is a no-op. Failures are
   * reported through `status` / `error` / `statusText` (the returned promise
   * does not reject), and a later call starts a fresh attempt.
   */
  const load = useCallback(async (): Promise<void> => {
    if (pipelineRef.current) return;
    if (loadingRef.current) return loadingRef.current;

    const runLoad = async (): Promise<void> => {
      setStatus("loading");
      setError(null); // drop the message left behind by a previous failed attempt
      setProgress(0);
      setStatusText("Downloading model...");

      try {
        const transcriber = await pipeline(
          "automatic-speech-recognition",
          MODEL_ID,
          {
            dtype: "q4",
            device: "webgpu",
            progress_callback: (info: {
              status: string;
              progress?: number;
            }) => {
              // Only "progress_total" events update the UI; other statuses
              // are ignored.
              if (info.status === "progress_total") {
                const pct = Math.round(info.progress ?? 0);
                setProgress(pct);
                setStatusText(`Loading model... ${pct}%`);
              }
            },
          },
        );

        pipelineRef.current = transcriber;
        setProgress(100);
        setStatusText("Ready");
        setStatus("ready");
      } catch (err) {
        console.error("Failed to load transcription model:", err);
        const message =
          err instanceof Error ? err.message : "Failed to load model";
        setStatus("error");
        setError(message);
        setStatusText(message);
      }
    };

    // Clear the in-flight marker once the attempt settles. Without this a
    // failed load would leave a settled promise in the ref and every later
    // `load()` would return it without retrying. `.finally` callbacks run in
    // a microtask, i.e. after the assignment below.
    const attempt = runLoad().finally(() => {
      loadingRef.current = null;
    });
    loadingRef.current = attempt;
    return attempt;
  }, []);

  /**
   * Runs the loaded pipeline on `audio` and resolves with the recognised text.
   *
   * @param audio - Audio samples passed straight to the pipeline.
   * @param language - Optional language forwarded to the pipeline options.
   * @param onToken - Optional callback wired to a `TextStreamer` so partial
   *   text can be consumed while generation is running.
   * @returns The `text` field of the pipeline output.
   * @throws Error("Model not loaded") (as a rejection) when called before
   *   `load()` has succeeded.
   */
  const transcribe = useCallback(
    async (
      audio: Float32Array,
      language?: string,
      onToken?: (token: string) => void,
    ): Promise<string> => {
      // Read the ref once so the tokenizer and the call use the same instance.
      const transcriber = pipelineRef.current;
      if (!transcriber) {
        throw new Error("Model not loaded");
      }

      const streamer = onToken
        ? new TextStreamer(transcriber.tokenizer, {
            skip_prompt: true,
            skip_special_tokens: true,
            callback_function: onToken,
          })
        : undefined;

      const result = (await transcriber(audio, {
        max_new_tokens: 1024,
        language,
        streamer,
      })) as AutomaticSpeechRecognitionOutput;

      return result.text;
    },
    [],
  );

  // NOTE(review): the context value is reconstructed from the state and
  // callbacks defined above — confirm the keys match `TranscriberState`.
  // `createElement` is equivalent to the JSX provider element and keeps this
  // module valid under either a .ts or .tsx extension.
  return createElement(
    TranscriberContext.Provider,
    { value: { status, error, progress, statusText, load, transcribe } },
    children,
  );
}