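// useLLM: React hook that runs the Maincoder-1B ONNX model entirely in the
// browser (WebGPU) via @huggingface/transformers, with streaming generation.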
import { useState, useRef, useCallback, useEffect } from "react";
import {
AutoModelForCausalLM,
AutoTokenizer,
TextStreamer,
} from "@huggingface/transformers";
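// Hugging Face Hub repo containing the ONNX export of Maincoder-1B for web inference.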
const MODEL_ID = "shreyask/Maincoder-1B-ONNX-web";
interface LLMState {
isLoading: boolean;
isReady: boolean;
error: string | null;
progress: number;
}
interface LLMInstance {
model: any;
tokenizer: any;
}
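// Shared module-level cache: the model and tokenizer are loaded once per page
// and reused by every component that calls useLLM. loadingPromise lets
// concurrent callers await the same in-flight load instead of starting a new one.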
let cachedInstance: LLMInstance | null = null;
let loadingPromise: Promise<LLMInstance> | null = null;
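/**
 * React hook exposing the in-browser LLM: load state, a loader, a chat-style
 * generate function with optional token streaming, and a history reset.
 *
 * Usage sketch (illustrative component code, not part of this module):
 *
 *   const { isReady, progress, loadModel, generateResponse } = useLLM();
 *   await loadModel(); // e.g. on mount or a button click
 *   const reply = await generateResponse(
 *     [{ role: "user", content: "Hello!" }],
 *     (token) => console.log(token), // streamed partial text
 *   );
 */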
export const useLLM = () => {
const [state, setState] = useState<LLMState>({
isLoading: false,
isReady: false,
error: null,
progress: 0,
});
const instanceRef = useRef<LLMInstance | null>(null);
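// Attention KV cache carried across calls so multi-turn chats only process new tokens.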
const pastKeyValuesRef = useRef<any>(null);
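// Resolve immediately from the cache, join an in-flight load, or start a fresh one.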
const loadModel = useCallback(async () => {
if (instanceRef.current || cachedInstance) {
// Non-null assertion: TS cannot narrow the re-read refs, but the guard above ensures one exists.
const instance = (instanceRef.current ?? cachedInstance)!;
instanceRef.current = instance;
cachedInstance = instance;
setState((prev) => ({ ...prev, isReady: true, isLoading: false }));
return instance;
}
if (loadingPromise) {
const instance = await loadingPromise;
instanceRef.current = instance;
cachedInstance = instance;
setState((prev) => ({ ...prev, isReady: true, isLoading: false }));
return instance;
}
setState((prev) => ({
...prev,
isLoading: true,
error: null,
progress: 0,
}));
loadingPromise = (async () => {
try {
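// Report download progress only for the large ONNX weight files; the small
// config/tokenizer files would otherwise make the percentage jump around.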
const progress_callback = (progress: any) => {
if (
progress.status === "progress" &&
(progress.file?.endsWith(".onnx") ||
progress.file?.endsWith(".onnx_data"))
) {
const percentage = Math.round(
(progress.loaded / progress.total) * 100,
);
setState((prev) => ({ ...prev, progress: percentage }));
}
};
const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID, {
progress_callback,
});
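// 4-bit quantized weights running on WebGPU keep the download size and memory
// footprint small enough for the browser.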
const model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
dtype: "q4",
device: "webgpu",
progress_callback,
});
const instance = { model, tokenizer };
instanceRef.current = instance;
cachedInstance = instance;
loadingPromise = null;
setState({
isLoading: false,
isReady: true,
error: null,
progress: 100,
});
return instance;
} catch (error) {
loadingPromise = null;
const message =
error instanceof Error ? error.message : "Failed to load model";
setState((prev) => ({
...prev,
isLoading: false,
error: message,
}));
throw error;
}
})();
return loadingPromise;
}, []);
const generateResponse = useCallback(
async (
messages: Array<{ role: string; content: string }>,
onToken?: (token: string) => void,
): Promise<string> => {
const instance = instanceRef.current;
if (!instance) {
throw new Error("Model not loaded. Call loadModel() first.");
}
const { model, tokenizer } = instance;
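// Turn the chat history into model inputs using the tokenizer's chat template.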
const input = tokenizer.apply_chat_template(messages, {
add_generation_prompt: true,
return_dict: true,
});
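// Stream decoded text to the caller as tokens are generated, skipping the
// prompt and special tokens.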
const streamer = onToken
? new TextStreamer(tokenizer, {
skip_prompt: true,
skip_special_tokens: true,
callback_function: onToken,
})
: undefined;
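// Greedy decoding with a mild repetition penalty; reuse the KV cache from the
// previous turn when one exists.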
const { sequences, past_key_values } = await model.generate({
...input,
past_key_values: pastKeyValuesRef.current,
max_new_tokens: 1024,
do_sample: false,
repetition_penalty: 1.2,
eos_token_id: [151643, 151645], // <|endoftext|> and <|im_end|>
streamer,
return_dict_in_generate: true,
});
pastKeyValuesRef.current = past_key_values;
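// Decode only the newly generated tokens, i.e. everything after the prompt.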
const response = tokenizer
.batch_decode(sequences.slice(null, [input.input_ids.dims[1], null]), {
skip_special_tokens: true,
})[0];
return response;
},
[],
);
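// Reset the KV cache so the next generateResponse starts a fresh conversation.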
const clearHistory = useCallback(() => {
pastKeyValuesRef.current = null;
}, []);
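// If another hook instance already loaded the model, mark this one ready on mount.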
useEffect(() => {
if (cachedInstance) {
instanceRef.current = cachedInstance;
setState((prev) => ({ ...prev, isReady: true }));
}
}, []);
return {
...state,
loadModel,
generateResponse,
clearHistory,
};
};