"""Gradio entry point for the residual-score app, wired for ZeroGPU Spaces.

The model weights are downloaded on the CPU at import time so the
time-limited GPU slot is not spent on network I/O; the CUDA copy is only
instantiated lazily, from inside the ``@spaces.GPU`` wrapper.
"""

import gradio
import spaces

import psaiops.common.model
import psaiops.common.style
import psaiops.common.tokenizer
import psaiops.score.residual.app as app

# META #########################################################################

app.MODEL = 'qwen/qwen3.5-9b'

# additional keyword arguments forwarded to the model loader
_CONFIG = {}

# frontload the weights on the CPU so the GPU slot is not spent downloading
psaiops.common.model.get_model(name=app.MODEL, device='cpu', **_CONFIG)

# the CUDA model is only instantiated on demand, inside the GPU wrapper
_MODEL = None
# the tokenizer is device-agnostic, so it can be loaded eagerly
_TOKENIZER = psaiops.common.tokenizer.get_tokenizer(name=app.MODEL)

# LAZY #########################################################################

def fetch_model() -> object:
    """Return the CUDA model, instantiating it on first use.

    Must only be called from inside the ``@spaces.GPU`` wrapper, so that
    the model is never moved to CUDA outside an allocated GPU slot.
    """
    global _MODEL
    if _MODEL is None:
        _MODEL = psaiops.common.model.get_model(name=app.MODEL, device='cuda', **_CONFIG)
    return _MODEL

def fetch_tokenizer() -> object:
    """Return the tokenizer, loading it if the eager init was skipped.

    Normally a no-op lookup (the tokenizer is loaded at import time);
    kept lazy for symmetry with :func:`fetch_model`.
    """
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = psaiops.common.tokenizer.get_tokenizer(name=app.MODEL)
    return _TOKENIZER

# EVENT HANDLERS ###############################################################

def highlight_tokens(
    left_idx: float,
    right_idx: float,
    output_data: object,
) -> list:
    """Update the token-focus highlighting for the [left_idx, right_idx] span.

    Runs outside the GPU wrapper: only the tokenizer is needed here, and
    the model must not be downloaded / moved to CUDA from this handler.
    """
    tokenizer = fetch_tokenizer()
    # fill in the arguments that cannot be pickled across the GPU boundary
    return app.update_token_focus(
        left_idx=left_idx,
        right_idx=right_idx,
        output_data=output_data,
        tokenizer_obj=tokenizer)

@spaces.GPU(duration=30)
def compute_states(
    token_num: float,
    topk_num: float,
    topp_num: float,
    prompt_str: str,
) -> tuple:
    """Run the residual-state computation inside a 30-second GPU slot."""
    # load the model and tokenizer inside the GPU wrapper
    model = fetch_model()
    tokenizer = fetch_tokenizer()
    # fill in the arguments that cannot be pickled across the GPU boundary
    return app.update_computation_state(
        token_num=token_num,
        topk_num=topk_num,
        topp_num=topp_num,
        prompt_str=prompt_str,
        device_str='cuda',
        model_obj=model,
        tokenizer_obj=tokenizer)

# MAIN #########################################################################

demo = app.create_app(highlight=highlight_tokens, compute=compute_states)
demo.queue()
# NOTE(review): recent Gradio versions accept `theme` / `css` on the Blocks
# constructor rather than on `launch()` — confirm against the pinned version.
demo.launch(theme=gradio.themes.Soft(), css=psaiops.common.style.ALL, debug=False)