| import copy |
| from pathlib import Path |
|
|
| import gradio as gr |
| import torch |
| import yaml |
| from transformers import is_torch_xpu_available |
|
|
| import extensions |
| from modules import shared |
|
|
def _read_ui_asset(relative_path):
    """
    Return the text content of a UI asset file.

    ``relative_path`` is resolved against the directory containing this module.
    The file is decoded as UTF-8 explicitly so that loading does not depend on
    the platform's default (locale) encoding, which is not UTF-8 on many
    Windows setups.
    """
    with open(Path(__file__).resolve().parent / relative_path, 'r', encoding='utf-8') as f:
        return f.read()


# Both stylesheets are concatenated into a single CSS string.
css = _read_ui_asset('../css/NotoSans/stylesheet.css')
css += _read_ui_asset('../css/main.css')
js = _read_ui_asset('../js/main.js')
save_files_js = _read_ui_asset('../js/save_files.js')
switch_tabs_js = _read_ui_asset('../js/switch_tabs.js')
show_controls_js = _read_ui_asset('../js/show_controls.js')
update_big_picture_js = _read_ui_asset('../js/update_big_picture.js')
|
|
# Button labels. Written as escapes so the source stays ASCII and cannot be
# mangled again by a wrong-codepage round trip.
refresh_symbol = '\U0001F504'  # U+1F504 ANTICLOCKWISE DOWNWARDS AND UPWARDS OPEN CIRCLE ARROWS
delete_symbol = '\U0001F5D1\uFE0F'  # U+1F5D1 WASTEBASKET + U+FE0F emoji presentation selector
save_symbol = '\U0001F4BE'  # U+1F4BE FLOPPY DISK
|
|
# Gradio base theme: Noto Sans for body text (presumably served by the NotoSans
# stylesheet loaded into `css`), IBM Plex Mono for code, with a few overrides.
theme = gr.themes.Default(
    font=['Noto Sans', 'Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'],
    font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
)
theme = theme.set(**{
    'border_color_primary': '#c5c5d2',
    'button_large_padding': '6px 12px',
    'body_text_color_subdued': '#484848',
    'background_fill_secondary': '#eaeaea',
})
|
|
# JS snippet that plays the notification sound. It is an empty string when no
# notification.mp3 exists (checked relative to the current working directory).
audio_notification_js = (
    "document.querySelector('#audio_notification audio')?.play();"
    if Path("notification.mp3").exists()
    else ""
)
|
|
|
|
def list_model_elements():
    """
    Return the names of the model-loading UI elements, in order.

    One extra ``gpu_memory_<i>`` name is appended per device. The device count
    comes from ``torch.xpu`` when an XPU is available, otherwise from
    ``torch.cuda``.
    """
    names = [
        'loader',
        'filter_by_loader',
        'cpu_memory',
        'auto_devices',
        'disk',
        'cpu',
        'bf16',
        'load_in_8bit',
        'trust_remote_code',
        'no_use_fast',
        'use_flash_attention_2',
        'load_in_4bit',
        'compute_dtype',
        'quant_type',
        'use_double_quant',
        'wbits',
        'groupsize',
        'model_type',
        'pre_layer',
        'triton',
        'desc_act',
        'no_inject_fused_attention',
        'no_inject_fused_mlp',
        'no_use_cuda_fp16',
        'disable_exllama',
        'disable_exllamav2',
        'cfg_cache',
        'no_flash_attn',
        'num_experts_per_token',
        'cache_8bit',
        'threads',
        'threads_batch',
        'n_batch',
        'no_mmap',
        'mlock',
        'no_mul_mat_q',
        'n_gpu_layers',
        'tensor_split',
        'n_ctx',
        'gpu_split',
        'max_seq_len',
        'compress_pos_emb',
        'alpha_value',
        'rope_freq_base',
        'numa',
        'logits_all',
        'no_offload_kqv',
        'tensorcores',
        'hqq_backend',
    ]

    # Choose the accelerator backend once, then add one memory entry per device.
    if is_torch_xpu_available():
        device_count = torch.xpu.device_count()
    else:
        device_count = torch.cuda.device_count()

    names.extend(f'gpu_memory_{idx}' for idx in range(device_count))
    return names
|
|
|
|
def list_interface_input_elements():
    """
    Return the names of every UI input that makes up the interface state.

    Order is significant: gather_interface_values pairs its positional
    arguments with these names by index, and apply_interface_values returns
    one value per name in this same order.
    """
    # Generation / sampling parameters
    generation_params = [
        'max_new_tokens',
        'auto_max_new_tokens',
        'max_tokens_second',
        'max_updates_second',
        'seed',
        'temperature',
        'temperature_last',
        'dynamic_temperature',
        'dynatemp_low',
        'dynatemp_high',
        'dynatemp_exponent',
        'top_p',
        'min_p',
        'top_k',
        'typical_p',
        'epsilon_cutoff',
        'eta_cutoff',
        'repetition_penalty',
        'presence_penalty',
        'frequency_penalty',
        'repetition_penalty_range',
        'encoder_repetition_penalty',
        'no_repeat_ngram_size',
        'min_length',
        'do_sample',
        'penalty_alpha',
        'num_beams',
        'length_penalty',
        'early_stopping',
        'mirostat_mode',
        'mirostat_tau',
        'mirostat_eta',
        'grammar_string',
        'negative_prompt',
        'guidance_scale',
        'add_bos_token',
        'ban_eos_token',
        'custom_token_bans',
        'truncation_length',
        'custom_stopping_strings',
        'skip_special_tokens',
        'stream',
        'tfs',
        'top_a',
    ]

    # Chat tab elements
    chat_elements = [
        'textbox',
        'start_with',
        'character_menu',
        'history',
        'name1',
        'name2',
        'greeting',
        'context',
        'mode',
        'custom_system_message',
        'instruction_template_str',
        'chat_template_str',
        'chat_style',
        'chat-instruct_command',
    ]

    # Default / notebook tab elements
    default_notebook_elements = [
        'textbox-notebook',
        'textbox-default',
        'output_textbox',
        'prompt_menu-default',
        'prompt_menu-notebook',
    ]

    # Model tab elements come last.
    return [
        *generation_params,
        *chat_elements,
        *default_notebook_elements,
        *list_model_elements(),
    ]
|
|
|
|
def gather_interface_values(*args):
    """
    Build the interface-state dict from positional UI input values.

    ``args[i]`` is stored under the i-th name of list_interface_input_elements().
    Raises IndexError if fewer values than names are supplied. Unless
    ``shared.args.multi_user`` is set, the dict is also saved as
    ``shared.persistent_interface_state``.
    """
    state = {name: args[idx] for idx, name in enumerate(list_interface_input_elements())}

    # Only single-user sessions share one persistent state object.
    if not shared.args.multi_user:
        shared.persistent_interface_state = state

    return state
|
|
|
|
def apply_interface_values(state, use_persistent=False):
    """
    Produce one output per interface input element from a state dict.

    When ``use_persistent`` is true, ``shared.persistent_interface_state`` is
    used instead of ``state``. Each element present in the state yields its
    stored value. Every other element yields ``gr.update()``, which leaves that
    component unchanged.
    """
    source = shared.persistent_interface_state if use_persistent else state
    names = list_interface_input_elements()

    # Nothing stored: leave every component as it is.
    if len(source) == 0:
        return [gr.update() for _ in names]

    return [source[name] if name in source else gr.update() for name in names]
|
|
|
|
def save_settings(state, preset, extensions_list, show_controls, theme_state):
    """
    Serialize the current UI settings to a YAML string.

    Starts from a deep copy of ``shared.settings`` and overlays every key of
    ``state`` that is also a settings key, except a few per-character and
    template fields. It then records the preset, the selected prompts and
    character, the enabled extensions, the controls visibility and the theme.
    The ``params`` of each listed extension are stored as
    ``"<extension>-<param>"``. Finally, every entry equal to its value in
    ``shared.default_settings`` is dropped, so only non-default settings remain.

    Args:
        state: interface-state dict (as built by gather_interface_values).
        preset: name of the selected generation preset.
        extensions_list: names of the extensions to save; each must be an
            attribute of the ``extensions`` package exposing ``.script``.
        show_controls: controls-visibility flag, stored as-is.
        theme_state: 'dark' selects the dark theme; anything else is light.

    Returns:
        str: YAML text with keys in insertion order and no line wrapping.
    """
    output = copy.deepcopy(shared.settings)
    exclude = ['name2', 'greeting', 'context', 'turn_template']
    for k in state:
        if k in shared.settings and k not in exclude:
            output[k] = state[k]

    output['preset'] = preset
    output['prompt-default'] = state['prompt_menu-default']
    output['prompt-notebook'] = state['prompt_menu-notebook']
    output['character'] = state['character_menu']
    output['default_extensions'] = extensions_list
    output['seed'] = int(output['seed'])
    output['show_controls'] = show_controls
    output['dark_theme'] = theme_state == 'dark'

    # Save the current values of each enabled extension's params.
    for extension_name in extensions_list:
        extension = getattr(extensions, extension_name).script
        if hasattr(extension, 'params'):
            params = getattr(extension, 'params')
            for param in params:
                _id = f"{extension_name}-{param}"
                # Compare against the default stored under the same key that
                # is written (`_id`). Looking up the bare `param` compared the
                # value to an unrelated global setting of the same name (e.g.
                # 'seed') and could silently drop it.
                if _id not in shared.default_settings or params[param] != shared.default_settings[_id]:
                    output[_id] = params[param]

    # Drop settings that still have their default value (order is preserved).
    defaults = shared.default_settings
    output = {
        key: value
        for key, value in output.items()
        if not (key in defaults and value == defaults[key])
    }

    return yaml.dump(output, sort_keys=False, width=float("inf"))
|
|
|
|
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_class, interactive=True):
    """
    Create a refresh button that re-runs a method and updates a component.

    On click, ``refresh_method()`` is called. Then ``refreshed_args`` (called
    first if it is callable) supplies keyword arguments for ``gr.update``, which
    is applied to ``refresh_component``.

    Copied from https://github.com/AUTOMATIC1111/stable-diffusion-webui
    """
    def refresh():
        refresh_method()
        update_kwargs = refreshed_args() if callable(refreshed_args) else refreshed_args
        return gr.update(**(update_kwargs or {}))

    def on_click():
        # Dict keys can never be lists (lists are unhashable), so the earlier
        # `tuple(v) if type(k) is list else v` comprehension never converted
        # anything; it was a plain shallow copy.
        # NOTE(review): if list *values* were meant to become tuples, that has
        # never happened; confirm intent before changing it.
        return dict(refresh())

    button = gr.Button(refresh_symbol, elem_classes=elem_class, interactive=interactive)
    button.click(fn=on_click, inputs=[], outputs=[refresh_component])
    return button
|
|