import { useState, useRef, useCallback, useEffect } from "react";
import {
  AutoModelForCausalLM,
  AutoTokenizer,
  TextStreamer,
} from "@huggingface/transformers";

const MODEL_ID = "shreyask/Maincoder-1B-ONNX-web";

interface LLMState {
  isLoading: boolean;
  isReady: boolean;
  error: string | null;
  progress: number;
}

interface LLMInstance {
  model: any;
  tokenizer: any;
}
// Module-level cache so the model is downloaded and initialized only once,
// even if several components mount this hook.
let cachedInstance: LLMInstance | null = null;
let loadingPromise: Promise<LLMInstance> | null = null;
export const useLLM = () => {
  const [state, setState] = useState<LLMState>({
    isLoading: false,
    isReady: false,
    error: null,
    progress: 0,
  });

  const instanceRef = useRef<LLMInstance | null>(null);
  // KV cache carried across generate() calls so the model does not
  // re-process the whole conversation on every turn.
  const pastKeyValuesRef = useRef<any>(null);
  const loadModel = useCallback(async () => {
    // Reuse an instance that is already initialized, either in this hook
    // or in the module-level cache.
    if (instanceRef.current || cachedInstance) {
      const instance = instanceRef.current || cachedInstance;
      instanceRef.current = instance;
      cachedInstance = instance;
      setState((prev) => ({ ...prev, isReady: true, isLoading: false }));
      return instance;
    }

    // If a load is already in flight, wait on the same promise instead of
    // starting a second download.
    if (loadingPromise) {
      const instance = await loadingPromise;
      instanceRef.current = instance;
      cachedInstance = instance;
      setState((prev) => ({ ...prev, isReady: true, isLoading: false }));
      return instance;
    }
    setState((prev) => ({
      ...prev,
      isLoading: true,
      error: null,
      progress: 0,
    }));
    loadingPromise = (async () => {
      try {
        // Report download progress only for the ONNX weight files, which
        // dominate the total download size.
        const progress_callback = (progress: any) => {
          if (
            progress.status === "progress" &&
            (progress.file?.endsWith(".onnx") ||
              progress.file?.endsWith(".onnx_data"))
          ) {
            const percentage = Math.round(
              (progress.loaded / progress.total) * 100,
            );
            setState((prev) => ({ ...prev, progress: percentage }));
          }
        };

        const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID, {
          progress_callback,
        });
        const model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
          dtype: "q4",
          device: "webgpu",
          progress_callback,
        });

        const instance = { model, tokenizer };
        instanceRef.current = instance;
        cachedInstance = instance;
        loadingPromise = null;
        setState({
          isLoading: false,
          isReady: true,
          error: null,
          progress: 100,
        });
        return instance;
      } catch (error) {
        loadingPromise = null;
        const message =
          error instanceof Error ? error.message : "Failed to load model";
        setState((prev) => ({
          ...prev,
          isLoading: false,
          error: message,
        }));
        throw error;
      }
    })();

    return loadingPromise;
  }, []);
  const generateResponse = useCallback(
    async (
      messages: Array<{ role: string; content: string }>,
      onToken?: (token: string) => void,
    ): Promise<string> => {
      const instance = instanceRef.current;
      if (!instance) {
        throw new Error("Model not loaded. Call loadModel() first.");
      }
      const { model, tokenizer } = instance;

      const input = tokenizer.apply_chat_template(messages, {
        add_generation_prompt: true,
        return_dict: true,
      });

      // Stream decoded tokens back to the caller as they are generated.
      const streamer = onToken
        ? new TextStreamer(tokenizer, {
            skip_prompt: true,
            skip_special_tokens: true,
            callback_function: onToken,
          })
        : undefined;

      const { sequences, past_key_values } = await model.generate({
        ...input,
        past_key_values: pastKeyValuesRef.current,
        max_new_tokens: 1024,
        do_sample: false,
        repetition_penalty: 1.2,
        eos_token_id: [151643, 151645], // <|endoftext|> and <|im_end|>
        streamer,
        return_dict_in_generate: true,
      });
      pastKeyValuesRef.current = past_key_values;

      // Decode only the newly generated tokens (everything after the prompt).
      const response = tokenizer.batch_decode(
        sequences.slice(null, [input.input_ids.dims[1], null]),
        { skip_special_tokens: true },
      )[0];
      return response;
    },
    [],
  );
  // Drop the cached KV state so the next generation starts a fresh conversation.
  const clearHistory = useCallback(() => {
    pastKeyValuesRef.current = null;
  }, []);

  // On mount, pick up an instance that was already loaded elsewhere in the app.
  useEffect(() => {
    if (cachedInstance) {
      instanceRef.current = cachedInstance;
      setState((prev) => ({ ...prev, isReady: true }));
    }
  }, []);

  return {
    ...state,
    loadModel,
    generateResponse,
    clearHistory,
  };
};
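
// Usage sketch (illustrative only; the component body below is an assumption,
// not part of the original source): load the model on mount, stream tokens
// into local state, and reset the KV cache to start a fresh conversation.
//
//   const { isReady, isLoading, progress, error, loadModel, generateResponse, clearHistory } = useLLM();
//   const [output, setOutput] = useState("");
//
//   useEffect(() => {
//     loadModel().catch(() => {}); // failures surface via `error`
//   }, [loadModel]);
//
//   const ask = async (prompt: string) => {
//     setOutput("");
//     await generateResponse(
//       [{ role: "user", content: prompt }],
//       (token) => setOutput((prev) => prev + token),
//     );
//   };
//
//   const reset = () => clearHistory();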