| from functools import wraps |
| from flask import ( |
| Flask, |
| jsonify, |
| request, |
| Response, |
| render_template_string, |
| abort, |
| send_from_directory, |
| send_file, |
| ) |
| from flask_cors import CORS |
| from flask_compress import Compress |
| import markdown |
| import argparse |
| from transformers import AutoTokenizer, AutoProcessor, pipeline |
| from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM |
| from transformers import BlipForConditionalGeneration |
| import unicodedata |
| import torch |
| import time |
| import os |
| import gc |
| import sys |
| import secrets |
| from PIL import Image |
| import base64 |
| from io import BytesIO |
| from random import randint |
| import webuiapi |
| import hashlib |
| from constants import * |
| from colorama import Fore, Style, init as colorama_init |
|
|
# Enable ANSI colour handling for the coloured console messages below.
colorama_init()

# On interpreters older than Python 3.11, print a recommendation and pause for
# two seconds before continuing start-up.
if sys.version_info < (3, 11):
    print(f"{Fore.BLUE}{Style.BRIGHT}Python 3.11 or newer is recommended to run this program.{Style.RESET_ALL}")
    time.sleep(2)
|
|
class SplitArgs(argparse.Action):
    """argparse action that turns a comma-separated string into a list.

    Single and double quote characters are removed, whitespace around each
    entry is trimmed, and empty entries are dropped. For example,
    ``"caption, summarize,"`` becomes ``["caption", "summarize"]`` and an empty
    value becomes ``[]``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        unquoted = values.replace('"', "").replace("'", "")
        # Trim each entry so "a, b" matches module names exactly, and drop
        # empties so "--opt=" or a trailing comma do not produce '' entries.
        entries = [entry.strip() for entry in unquoted.split(",")]
        setattr(namespace, self.dest, [entry for entry in entries if entry])
|
|
| |
# Directory containing this script. Paths derived from it do not depend on the
# current working directory.
parent_dir = os.path.dirname(os.path.abspath(__file__))
# Directory where Silero TTS speaker preview samples are stored.
SILERO_SAMPLES_PATH = os.path.join(parent_dir, "tts_samples")
# Sentence used to generate the speaker preview samples. It is passed to
# tts_service.update_sample_text() when the silero-tts module starts, so it must
# be text rather than a filesystem path.
SILERO_SAMPLE_TEXT = "The quick brown fox jumps over the lazy dog"

os.makedirs(SILERO_SAMPLES_PATH, exist_ok=True)
|
|
| |
parser = argparse.ArgumentParser(
    prog="SillyTavern Extras", description="Web API for transformers models"
)
parser.add_argument(
    "--port", type=int, help="Specify the port on which the application is hosted"
)
parser.add_argument(
    "--listen", action="store_true", help="Host the app on the local network"
)
parser.add_argument(
    "--share", action="store_true", help="Share the app on CloudFlare tunnel"
)
# --cpu, --cuda and --mps all write the same "cpu" flag; it defaults to True,
# so accelerators are used only when --cuda or --mps clears it.
parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
parser.add_argument("--cuda", action="store_false", dest="cpu", help="Run the models on the GPU")
parser.add_argument("--cuda-device", help="Specify the CUDA device to use")
parser.add_argument("--mps", "--apple", "--m1", "--m2", action="store_false", dest="cpu", help="Run the models on Apple Silicon")
parser.set_defaults(cpu=True)
parser.add_argument("--summarization-model", help="Load a custom summarization model")
parser.add_argument(
    "--classification-model", help="Load a custom text classification model"
)
parser.add_argument("--captioning-model", help="Load a custom captioning model")
parser.add_argument("--embedding-model", help="Load a custom text embedding model")
parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db')
parser.add_argument('--chroma-persist', help="ChromaDB persistence", default=True, action=argparse.BooleanOptionalAction)
parser.add_argument(
    "--secure", action="store_true", help="Enforces the use of an API key"
)

# Local and remote Stable Diffusion options are plain argument groups, used for
# --help layout only. They were previously nested inside a mutually exclusive
# group. argparse never supported that nesting: it enforced no exclusivity,
# hid these options from --help, and has been deprecated since Python 3.11.
local_sd = parser.add_argument_group("sd-local")
local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")
local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU", action="store_true")

remote_sd = parser.add_argument_group("sd-remote")
remote_sd.add_argument(
    "--sd-remote", action="store_true", help="Use a remote backend for SD"
)
remote_sd.add_argument(
    "--sd-remote-host", type=str, help="Specify the host of the remote SD backend"
)
remote_sd.add_argument(
    "--sd-remote-port", type=int, help="Specify the port of the remote SD backend"
)
remote_sd.add_argument(
    "--sd-remote-ssl", action="store_true", help="Use SSL for the remote SD backend"
)
remote_sd.add_argument(
    "--sd-remote-auth",
    type=str,
    help="Specify the username:password for the remote SD backend (if required)",
)

parser.add_argument(
    "--enable-modules",
    action=SplitArgs,
    default=[],
    help="Override a list of enabled modules",
)

args = parser.parse_args()
| |
# Network binding. The host is fixed to all interfaces, so --listen has no extra
# effect here. The port honours --port and otherwise falls back to 7860.
port = args.port if args.port else 7860
host = "0.0.0.0"

# Model identifiers: command-line overrides win over the defaults from constants.
summarization_model = (
    args.summarization_model
    if args.summarization_model
    else DEFAULT_SUMMARIZATION_MODEL
)
classification_model = (
    args.classification_model
    if args.classification_model
    else DEFAULT_CLASSIFICATION_MODEL
)
captioning_model = (
    args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
)
embedding_model = (
    args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
)

# Stable Diffusion runs locally only when --sd-model is given; otherwise the
# remote backend is used (the --sd-remote flag itself is not consulted).
sd_use_remote = False if args.sd_model else True
sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
sd_remote_host = args.sd_remote_host if args.sd_remote_host else DEFAULT_REMOTE_SD_HOST
sd_remote_port = args.sd_remote_port if args.sd_remote_port else DEFAULT_REMOTE_SD_PORT
sd_remote_ssl = args.sd_remote_ssl
sd_remote_auth = args.sd_remote_auth

# Enabled module names. Later start-up code may remove or add entries
# (e.g. when the SD backend is unreachable).
modules = (
    args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else []
)

if len(modules) == 0:
    print(
        f"{Fore.RED}{Style.BRIGHT}You did not select any modules to run! Choose them by adding an --enable-modules option"
    )
    print(f"Example: --enable-modules=caption,summarize{Style.RESET_ALL}")
|
|
| |
# Choose the torch device used by the models. args.cpu defaults to True; only
# --cuda / --mps clear it.
# - With args.cpu set: use the CPU.
# - Otherwise: the configured CUDA device if CUDA is available, else Apple MPS
#   if available, else the CPU.
cuda_device = args.cuda_device or DEFAULT_CUDA_DEVICE
if args.cpu:
    device_string = "cpu"
elif torch.cuda.is_available():
    device_string = cuda_device
elif torch.backends.mps.is_available():
    device_string = "mps"
else:
    device_string = "cpu"
device = torch.device(device_string)
# Half precision only on the CUDA device; full precision everywhere else.
torch_dtype = torch.float16 if device_string == cuda_device else torch.float32

# Warn when an accelerator was requested but the backend is unavailable.
if not args.cpu and not torch.cuda.is_available():
    print(f"{Fore.YELLOW}{Style.BRIGHT}torch-cuda is not supported on this device.{Style.RESET_ALL}")
if not args.cpu and not torch.backends.mps.is_available():
    print(f"{Fore.YELLOW}{Style.BRIGHT}torch-mps is not supported on this device.{Style.RESET_ALL}")

print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
|
|
# Load the transformers models only for modules enabled on the command line.

if "caption" in modules:
    print("Initializing an image captioning model...")
    captioning_processor = AutoProcessor.from_pretrained(captioning_model)
    # Model names containing "blip" use the dedicated BLIP class; anything
    # else is loaded through AutoModelForCausalLM.
    if "blip" not in captioning_model:
        captioning_transformer = AutoModelForCausalLM.from_pretrained(
            captioning_model, torch_dtype=torch_dtype
        ).to(device)
    else:
        captioning_transformer = BlipForConditionalGeneration.from_pretrained(
            captioning_model, torch_dtype=torch_dtype
        ).to(device)

if "summarize" in modules:
    print("Initializing a text summarization model...")
    summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
    summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
        summarization_model, torch_dtype=torch_dtype
    )
    summarization_transformer = summarization_transformer.to(device)

if "classify" in modules:
    print("Initializing a sentiment classification pipeline...")
    # top_k=None asks the pipeline for scores of every label.
    classification_pipe = pipeline(
        "text-classification",
        model=classification_model,
        device=device,
        torch_dtype=torch_dtype,
        top_k=None,
    )
|
|
if "sd" in modules and not sd_use_remote:
    from diffusers import StableDiffusionPipeline
    from diffusers import EulerAncestralDiscreteScheduler

    print("Initializing Stable Diffusion pipeline...")
    # The SD pipeline picks its own device independently of --cpu. --sd-cpu
    # forces the CPU; otherwise prefer CUDA, then Apple MPS, then the CPU.
    if args.sd_cpu:
        sd_device_string = "cpu"
    elif torch.cuda.is_available():
        sd_device_string = cuda_device
    elif torch.backends.mps.is_available():
        sd_device_string = "mps"
    else:
        sd_device_string = "cpu"
    sd_device = torch.device(sd_device_string)
    # Half precision only on the CUDA device.
    sd_torch_dtype = torch.float32 if sd_device_string != cuda_device else torch.float16
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
    ).to(sd_device)
    # Replace the safety checker with a pass-through that returns the images unchanged.
    sd_pipe.safety_checker = lambda images, clip_input: (images, False)
    sd_pipe.enable_attention_slicing()
    sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        sd_pipe.scheduler.config
    )
elif "sd" in modules and sd_use_remote:
    print("Initializing Stable Diffusion connection")
    try:
        sd_remote = webuiapi.WebUIApi(
            host=sd_remote_host, port=sd_remote_port, use_https=sd_remote_ssl
        )
        if sd_remote_auth:
            # Split on the first colon only so the password may contain ':'.
            username, password = sd_remote_auth.split(":", 1)
            sd_remote.set_auth(username, password)
        sd_remote.util_wait_for_ready()
    except Exception as e:
        # Any connection/auth failure disables the module instead of aborting start-up.
        print(
            f"{Fore.RED}{Style.BRIGHT}Could not connect to remote SD backend at http{'s' if sd_remote_ssl else ''}://{sd_remote_host}:{sd_remote_port}! Disabling SD module...{Style.RESET_ALL}"
        )
        print(e)
        modules.remove("sd")
|
|
# The legacy "tts" module name is mapped onto "silero-tts".
if "tts" in modules:
    print("tts module is deprecated. Please use silero-tts instead.")
    modules.remove("tts")
    # Avoid listing silero-tts twice when both names were requested.
    if "silero-tts" not in modules:
        modules.append("silero-tts")

if "silero-tts" in modules:
    os.makedirs(SILERO_SAMPLES_PATH, exist_ok=True)
    print("Initializing Silero TTS server")
    from silero_api_server import tts

    tts_service = tts.SileroTtsService(SILERO_SAMPLES_PATH)
    # Generate speaker preview samples only when the samples directory is empty.
    if len(os.listdir(SILERO_SAMPLES_PATH)) == 0:
        print("Generating Silero TTS samples...")
        tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
        tts_service.generate_samples()

if "edge-tts" in modules:
    print("Initializing Edge TTS client")
    import tts_edge as edge
|
|
|
|
if "chromadb" in modules:
    print("Initializing ChromaDB")
    import chromadb
    import posthog
    from chromadb.config import Settings
    from sentence_transformers import SentenceTransformer

    # Replace posthog.capture with a no-op so no telemetry events are captured.
    posthog.capture = lambda *_args, **_kwargs: None

    if args.chroma_host is None:
        # Local ChromaDB: persisted to --chroma-folder unless --no-chroma-persist.
        if args.chroma_persist:
            chromadb_client = chromadb.PersistentClient(path=args.chroma_folder, settings=Settings(anonymized_telemetry=False))
            print(f"ChromaDB is running in-memory with persistence. Persistence is stored in {args.chroma_folder}. Can be cleared by deleting the folder or purging db.")
        else:
            chromadb_client = chromadb.EphemeralClient(Settings(anonymized_telemetry=False))
            print("ChromaDB is running in-memory without persistence.")
    else:
        # Remote ChromaDB over HTTP.
        chroma_port = (
            args.chroma_port if args.chroma_port else DEFAULT_CHROMA_PORT
        )
        chromadb_client = chromadb.HttpClient(host=args.chroma_host, port=chroma_port, settings=Settings(anonymized_telemetry=False))
        print(f"ChromaDB is remotely configured at {args.chroma_host}:{chroma_port}")

    # Embedding function handed to collections: encodes with SentenceTransformer
    # and converts the result to plain Python lists via .tolist().
    chromadb_embedder = SentenceTransformer(embedding_model, device=device_string)
    chromadb_embed_fn = lambda *args, **kwargs: chromadb_embedder.encode(*args, **kwargs).tolist()

    # Connectivity check only: a failure is reported but does not stop start-up.
    try:
        chromadb_client.heartbeat()
        print("Successfully pinged ChromaDB! Your client is successfully connected.")
    except Exception as e:
        print("Could not ping ChromaDB! If you are running remotely, please check your host and port!")
        print(e)
|
|
| |
# Flask application, wrapped with flask_cors (CORS headers) and flask_compress
# (response compression).
app = Flask(__name__)
CORS(app)
Compress(app)
# Upper bound on request body size: 100 MiB.
app.config.update(MAX_CONTENT_LENGTH=100 * 1024 * 1024)
|
|
|
|
def require_module(name):
    """Build a view decorator that answers 403 unless module *name* is enabled.

    The ``modules`` list is consulted on every request, not when the decorator
    is applied, so modules disabled during start-up are respected.
    """

    def decorator(view):
        @wraps(view)
        def guarded(*args, **kwargs):
            if name in modules:
                return view(*args, **kwargs)
            abort(403, "Module is disabled by config")

        return guarded

    return decorator
|
|
|
|
| |
def classify_text(text: str) -> list:
    """Classify *text* and return all label/score dicts, highest score first.

    Input longer than the model's ``max_position_embeddings`` is truncated.
    """
    limit = classification_pipe.model.config.max_position_embeddings
    scores = classification_pipe(text, truncation=True, max_length=limit)[0]
    scores.sort(key=lambda item: item["score"], reverse=True)
    return scores
|
|
|
|
def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
    """Return a caption for *raw_image* produced by the loaded captioning model.

    Args:
        raw_image: PIL image; it is converted to RGB before preprocessing.
        max_new_tokens: Upper bound on the number of generated caption tokens.
    """
    rgb_image = raw_image.convert("RGB")
    model_inputs = captioning_processor(rgb_image, return_tensors="pt")
    # Move the processed tensors to the model's device and dtype.
    model_inputs = model_inputs.to(device, torch_dtype)
    generated = captioning_transformer.generate(**model_inputs, max_new_tokens=max_new_tokens)
    return captioning_processor.decode(generated[0], skip_special_tokens=True)
|
|
|
|
def summarize_chunks(text: str, params: dict) -> str:
    """Summarize *text*, recursively halving it when the model rejects its length.

    ``summarize`` may raise ``IndexError`` when the input is too long for the
    model. On that error the text is split in the middle, the ``max_length`` and
    ``min_length`` limits are halved, and each half is summarized on its own. The
    partial summaries are joined with a single space so sentences from the two
    halves do not run together.

    Args:
        text: Text to summarize.
        params: Generation parameters understood by ``summarize``; a copy is
            modified, never the caller's dict.

    Returns:
        The summary (possibly several partial summaries joined by spaces).

    Raises:
        IndexError: when *text* is one character or shorter and still rejected,
            instead of recursing forever.
    """
    try:
        return summarize(text, params)
    except IndexError:
        if len(text) <= 1:
            raise
        print(
            "Sequence length too large for model, cutting text in half and calling again"
        )
        half_params = params.copy()
        # int() mirrors summarize(), which also accepts numeric strings/floats here.
        half_params["max_length"] = int(half_params["max_length"]) // 2
        half_params["min_length"] = int(half_params["min_length"]) // 2
        middle = len(text) // 2
        first = summarize_chunks(text[:middle], half_params)
        second = summarize_chunks(text[middle:], half_params)
        return " ".join(part for part in (first, second) if part)
|
|
|
|
def summarize(text: str, params: dict) -> str:
    """Summarize *text* with the loaded seq2seq model.

    Args:
        text: Text to summarize.
        params: Generation settings. Reads ``max_length``, ``min_length``,
            ``repetition_penalty``, ``temperature``, ``length_penalty`` (numbers or
            numeric strings) and ``bad_words`` (strings the output must not contain).

    Returns:
        The decoded summary, NFKC-normalised with whitespace collapsed.
    """
    inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
    # Token count of the single input sequence. Reading the tensor shape works
    # for both fast and slow tokenizers; inputs[0] only exists for fast ones.
    token_count = inputs["input_ids"].shape[-1]

    bad_words_ids = [
        summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
        for bad_word in params["bad_words"]
    ]
    # Drop words that tokenize to nothing; transformers rejects empty id lists.
    bad_words_ids = [ids for ids in bad_words_ids if ids]
    summary_ids = summarization_transformer.generate(
        inputs["input_ids"],
        num_beams=2,
        max_new_tokens=max(token_count, int(params["max_length"])),
        min_new_tokens=min(token_count, int(params["min_length"])),
        repetition_penalty=float(params["repetition_penalty"]),
        temperature=float(params["temperature"]),
        length_penalty=float(params["length_penalty"]),
        # An empty bad_words_ids list is rejected by transformers; None disables the filter.
        bad_words_ids=bad_words_ids or None,
    )
    summary = summarization_tokenizer.batch_decode(
        summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    return normalize_string(summary)
|
|
|
|
def normalize_string(input: str) -> str:
    """Return *input* in Unicode NFKC form with whitespace tidied.

    Leading and trailing whitespace is removed, and every internal run of
    whitespace is replaced by a single space.
    """
    words = unicodedata.normalize("NFKC", input).split()
    return " ".join(words)
|
|
|
|
def generate_image(data: dict) -> Image:
    """Render an image for the request parameters in *data*.

    The prompt is ``prompt_prefix`` + space + ``prompt``, normalised with
    ``normalize_string``. Without a remote backend, the local diffusers pipeline
    is used; it ignores ``sampler``, ``restore_faces`` and ``enable_hr``.
    Otherwise the remote backend's txt2img is called. The result is also saved
    to ``./debug.png`` in the working directory.
    """
    prompt = normalize_string(f'{data["prompt_prefix"]} {data["prompt"]}')

    if not sd_use_remote:
        result = sd_pipe(
            prompt=prompt,
            negative_prompt=data["negative_prompt"],
            num_inference_steps=data["steps"],
            guidance_scale=data["scale"],
            width=data["width"],
            height=data["height"],
        )
        image = result.images[0]
    else:
        result = sd_remote.txt2img(
            prompt=prompt,
            negative_prompt=data["negative_prompt"],
            sampler_name=data["sampler"],
            steps=data["steps"],
            cfg_scale=data["scale"],
            width=data["width"],
            height=data["height"],
            restore_faces=data["restore_faces"],
            enable_hr=data["enable_hr"],
            save_images=True,
            send_images=True,
            do_not_save_grid=False,
            do_not_save_samples=False,
        )
        image = result.image

    image.save("./debug.png")
    return image
|
|
|
|
def image_to_base64(image: Image.Image, quality: int = 75) -> str:
    """Encode *image* as a base64 JPEG string.

    Args:
        image: Any PIL image. It is converted to RGB first, because JPEG cannot
            store alpha or palette modes (e.g. ``RGBA``, ``P``).
        quality: JPEG quality passed to Pillow.

    Returns:
        The JPEG bytes, base64-encoded and decoded to an ASCII ``str``.
    """
    buffer = BytesIO()
    # Image.convert returns a new image; the caller's image is left untouched.
    rgb_image = image.convert("RGB")
    rgb_image.save(buffer, format="JPEG", quality=quality)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
|
|
|
# View functions exempt from the API-key check performed in before_request.
ignore_auth = []
# Expected API key, read from the "password" environment variable (None if unset).
api_key = os.environ.get("password")


def is_authorize_ignored(request):
    """Return True if the view function handling *request* is in ``ignore_auth``."""
    handler = app.view_functions.get(request.endpoint)
    return handler is not None and handler in ignore_auth
|
|
@app.before_request
def before_request():
    """Record the request start time and enforce the API key.

    A request continues to its view when any of these holds:
    - it is a CORS pre-flight (``OPTIONS``);
    - its endpoint is listed in ``ignore_auth``;
    - its ``Authorization`` header carries a token equal to ``api_key``.

    Every other request gets a JSON 401 response. When ``api_key`` is unset or
    empty, no token can match, so the check fails closed.

    Returns:
        ``None`` to continue to the view, or the 401 response to stop the request.
    """
    # Read by after_request to compute X-Request-Duration.
    request.start_time = time.time()

    try:
        # OPTIONS is exempt so CORS pre-flight requests are not rejected.
        if request.method == "OPTIONS" or is_authorize_ignored(request):
            return None

        # request.authorization is None without an Authorization header, and
        # may lack a token for some schemes; both count as "no key".
        token = getattr(request.authorization, "token", None)
        # compare_digest avoids leaking key contents through response timing.
        if api_key and token and secrets.compare_digest(token.encode(), api_key.encode()):
            return None

        # The submitted header is deliberately not logged: it may contain a
        # mistyped real password.
        print(f"WARNING: Unauthorized API key access from {request.remote_addr}")
    except Exception as e:
        print(f"API key check error: {e}")

    response = jsonify({'error': '401: Invalid API key'})
    response.status_code = 401
    return response
|
|
|
|
@app.after_request
def after_request(response):
    """Add an X-Request-Duration header: seconds since before_request ran."""
    started = request.start_time
    elapsed_seconds = time.time() - started
    response.headers.set("X-Request-Duration", str(elapsed_seconds))
    return response
|
|
|
|
@app.route("/", methods=["GET"])
def index():
    """Serve ./README.md (relative to the working directory) rendered as HTML.

    The Markdown output is returned directly rather than through Jinja
    (``render_template_string``). This way README text containing ``{{`` or
    ``{%`` is shown literally instead of being evaluated or failing to parse.
    """
    with open("./README.md", "r", encoding="utf8") as f:
        content = f.read()
    return markdown.markdown(content, extensions=["tables"])
|
|
|
|
@app.route("/api/extensions", methods=["GET"])
def get_extensions():
    """Return a single placeholder extension announcing that the Extensions API is unsupported."""
    notice = {
        "name": "not-supported",
        "metadata": {
            "display_name": """<span style="white-space:break-spaces;">Extensions serving using Extensions API is no longer supported. Please update the mod from: <a href="https://github.com/Cohee1207/SillyTavern">https://github.com/Cohee1207/SillyTavern</a></span>""",
            "requires": [],
            "assets": [],
        },
    }
    return jsonify({"extensions": [notice]})
|
|
|
|
@app.route("/api/caption", methods=["POST"])
@require_module("caption")
def api_caption():
    """Caption a base64-encoded image.

    Expects JSON ``{"image": <base64 string>}``. The image is converted to RGB
    and shrunk in place to fit within 512x512 before captioning.

    Returns:
        JSON ``{"caption": str, "thumbnail": <base64 JPEG of the resized image>}``.
        Responds 400 when ``image`` is missing, not a string, or not decodable.
    """
    data = request.get_json()

    if "image" not in data or not isinstance(data["image"], str):
        abort(400, '"image" is required')

    try:
        image = Image.open(BytesIO(base64.b64decode(data["image"])))
        # Image.open is lazy; convert() forces decoding, so corrupt or truncated
        # data fails inside this try block.
        image = image.convert("RGB")
    except (ValueError, OSError) as e:
        # Bad base64 raises binascii.Error (a ValueError). Pillow raises
        # UnidentifiedImageError / truncation errors, which are OSErrors.
        print(e)
        abort(400, '"image" could not be decoded as an image')

    image.thumbnail((512, 512))
    caption = caption_image(image)
    thumbnail = image_to_base64(image)
    print("Caption:", caption, sep="\n")
    gc.collect()
    return jsonify({"caption": caption, "thumbnail": thumbnail})
|
|
|
|
@app.route("/api/summarize", methods=["POST"])
@require_module("summarize")
def api_summarize():
    """Summarize ``text`` from the JSON body.

    An optional ``params`` dict overrides entries of DEFAULT_SUMMARIZE_PARAMS.
    Returns JSON ``{"summary": str}``.
    """
    data = request.get_json()

    if "text" not in data or not isinstance(data["text"], str):
        abort(400, '"text" is required')

    text = data["text"]
    params = DEFAULT_SUMMARIZE_PARAMS.copy()
    overrides = data.get("params")
    if isinstance(overrides, dict):
        params.update(overrides)

    print("Summary input:", text, sep="\n")
    summary = summarize_chunks(text, params)
    print("Summary output:", summary, sep="\n")
    gc.collect()
    return jsonify({"summary": summary})
|
|
|
|
@app.route("/api/classify", methods=["POST"])
@require_module("classify")
def api_classify():
    """Classify ``text`` from the JSON body; returns ``{"classification": [...]}``."""
    data = request.get_json()

    if "text" not in data or not isinstance(data["text"], str):
        abort(400, '"text" is required')

    text = data["text"]
    print("Classification input:", text, sep="\n")
    result = classify_text(text)
    print("Classification output:", result, sep="\n")
    gc.collect()
    return jsonify({"classification": result})
|
|
|
|
@app.route("/api/classify/labels", methods=["GET"])
@require_module("classify")
def api_classify_labels():
    """List the classifier's label names, found by classifying an empty string."""
    return jsonify({"labels": [entry["label"] for entry in classify_text("")]})
|
|
|
|
@app.route("/api/image", methods=["POST"])
@require_module("sd")
def api_image():
    """Generate a Stable Diffusion image for the JSON request body.

    ``prompt`` (str) is required. The other fields fall back to the defaults
    below when they are missing or of the wrong type. Numeric defaults accept
    any int or float; because bool is an int subclass, the boolean defaults also
    accept numbers. Returns ``{"image": <base64 JPEG>}``. A RuntimeError during
    generation is reported as 400.
    """
    data = request.get_json()

    if "prompt" not in data or not isinstance(data["prompt"], str):
        abort(400, '"prompt" is required')

    defaults = {
        "steps": 30,
        "scale": 6,
        "sampler": "DDIM",
        "width": 512,
        "height": 512,
        "restore_faces": False,
        "enable_hr": False,
        "prompt_prefix": PROMPT_PREFIX,
        "negative_prompt": NEGATIVE_PROMPT,
    }
    for key, fallback in defaults.items():
        if isinstance(fallback, (int, float)):
            allowed = (int, float)
        else:
            allowed = type(fallback)
        if key not in data or not isinstance(data[key], allowed):
            data[key] = fallback

    try:
        print("SD inputs:", data, sep="\n")
        encoded = image_to_base64(generate_image(data), quality=90)
        return jsonify({"image": encoded})
    except RuntimeError as e:
        abort(400, str(e))
|
|
|
|
@app.route("/api/image/model", methods=["POST"])
@require_module("sd")
def api_image_model_set():
    """Switch the remote SD backend to the checkpoint named by ``model``.

    Local SD cannot switch models and answers 400. Returns the model names
    reported by the backend before and after the switch.
    """
    data = request.get_json()

    if not sd_use_remote:
        abort(400, "Changing model for local sd is not supported.")
    if "model" not in data or not isinstance(data["model"], str):
        abort(400, '"model" is required')

    previous = sd_remote.util_get_current_model()
    sd_remote.util_set_model(data["model"], find_closest=False)
    # Wait for the backend to report ready before reading the model back.
    sd_remote.util_wait_for_ready()
    current = sd_remote.util_get_current_model()

    return jsonify({"previous_model": previous, "current_model": current})
|
|
|
|
@app.route("/api/image/model", methods=["GET"])
@require_module("sd")
def api_image_model_get():
    """Report the active SD model: the remote backend's current one, or the local sd_model."""
    current = sd_remote.util_get_current_model() if sd_use_remote else sd_model
    return jsonify({"model": current})
|
|
|
|
@app.route("/api/image/models", methods=["GET"])
@require_module("sd")
def api_image_models():
    """List selectable SD models: the remote backend's list, or just the local sd_model."""
    available = sd_remote.util_get_model_names() if sd_use_remote else [sd_model]
    return jsonify({"models": available})
|
|
|
|
@app.route("/api/image/samplers", methods=["GET"])
@require_module("sd")
def api_image_samplers():
    """List sampler names: those reported by the remote backend, else only "Euler a"."""
    if not sd_use_remote:
        return jsonify({"samplers": ["Euler a"]})
    names = [entry["name"] for entry in sd_remote.get_samplers()]
    return jsonify({"samplers": names})
|
|
|
|
@app.route("/api/modules", methods=["GET"])
def get_modules():
    """Return the list of currently enabled module names."""
    return jsonify(modules=modules)
|
|
|
|
@app.route("/api/tts/speakers", methods=["GET"])
@require_module("silero-tts")
def tts_speakers():
    """List Silero speakers, each with a preview URL served by tts_play_sample."""
    speakers = tts_service.get_speakers()
    base_url = str(request.url_root)
    voices = []
    for speaker in speakers:
        voices.append(
            {
                "name": speaker,
                "voice_id": speaker,
                "preview_url": f"{base_url}api/tts/sample/{speaker}",
            }
        )
    return jsonify(voices)
|
|
| |
@app.route("/api/tts/generate", methods=["POST"])
@require_module("silero-tts")
def tts_generate():
    """Synthesize ``text`` with Silero ``speaker`` and return it as a WAV file.

    Asterisks are removed from the text before synthesis. The generated file is
    moved next to this script (keeping its basename) and sent from there.
    Any failure is printed and answered with 500 (description = speaker name).
    """
    voice = request.get_json()
    if "text" not in voice or not isinstance(voice["text"], str):
        abort(400, '"text" is required')
    if "speaker" not in voice or not isinstance(voice["speaker"], str):
        abort(400, '"speaker" is required')

    voice["text"] = voice["text"].replace("*", "")
    try:
        # Remove any 'test.wav' in the working directory before generating.
        if os.path.exists('test.wav'):
            os.remove('test.wav')

        audio = tts_service.generate(voice["speaker"], voice["text"])
        audio_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.path.basename(audio))

        # os.replace overwrites an existing destination on every platform,
        # whereas os.rename raises FileExistsError on Windows.
        os.replace(audio, audio_file_path)
        return send_file(audio_file_path, mimetype="audio/x-wav")
    except Exception as e:
        print(e)
        abort(500, voice["speaker"])
|
|
|
|
@app.route("/api/tts/sample/<speaker>", methods=["GET"])
@require_module("silero-tts")
def tts_play_sample(speaker: str):
    """Serve ``<speaker>.wav`` from SILERO_SAMPLES_PATH via Flask's send_from_directory."""
    sample_name = f"{speaker}.wav"
    return send_from_directory(SILERO_SAMPLES_PATH, sample_name)
|
|
|
|
@app.route("/api/edge-tts/list", methods=["GET"])
@require_module("edge-tts")
def edge_tts_list():
    """Return the voices reported by the Edge TTS client module."""
    return jsonify(edge.get_voices())
|
|
|
|
@app.route("/api/edge-tts/generate", methods=["POST"])
@require_module("edge-tts")
def edge_tts_generate():
    """Synthesize ``text`` with Edge TTS ``voice`` and return MPEG audio.

    ``rate`` is used only when it is an int; otherwise 0. Asterisks are removed
    from the text first. Failures are printed and answered with 500.
    """
    data = request.get_json()
    if "text" not in data or not isinstance(data["text"], str):
        abort(400, '"text" is required')
    if "voice" not in data or not isinstance(data["voice"], str):
        abort(400, '"voice" is required')
    rate = data["rate"] if isinstance(data.get("rate"), int) else 0

    cleaned = data["text"].replace("*", "")
    data["text"] = cleaned
    try:
        audio = edge.generate_audio(text=cleaned, voice=data["voice"], rate=rate)
        return Response(audio, mimetype="audio/mpeg")
    except Exception as e:
        print(e)
        abort(500, data["voice"])
|
|
|
|
@app.route("/api/chromadb", methods=["POST"])
@require_module("chromadb")
def chromadb_add_messages():
    """Upsert chat messages into the chat's ChromaDB collection.

    Expects JSON ``{"chat_id": str, "messages": [...]}``. Each message provides
    ``id``, ``content``, ``role``, ``date`` and optionally ``meta`` (default
    ``""``). The collection is named ``chat-<md5(chat_id)>`` and is created on
    first use. ``content`` is embedded; ``role``, ``date`` and ``meta`` are
    stored as metadata.

    Returns:
        JSON ``{"count": <number of messages upserted>}``.
    """
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')
    if "messages" not in data or not isinstance(data["messages"], list):
        abort(400, '"messages" is required')

    messages = data["messages"]
    if not messages:
        # Nothing to store. chromadb validates that upsert IDs are a non-empty
        # list, so an empty batch would otherwise fail with a 500.
        return jsonify({"count": 0})

    chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    documents = [m["content"] for m in messages]
    ids = [m["id"] for m in messages]
    metadatas = [
        {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
        for m in messages
    ]

    collection.upsert(
        ids=ids,
        documents=documents,
        metadatas=metadatas,
    )

    return jsonify({"count": len(ids)})
|
|
|
|
@app.route("/api/chromadb/purge", methods=["POST"])
@require_module("chromadb")
def chromadb_purge():
    """Call delete() on the chat's collection (created if absent) and print the prior count."""
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')

    digest = hashlib.md5(data["chat_id"].encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name="chat-" + digest, embedding_function=chromadb_embed_fn
    )

    removed = collection.count()
    collection.delete()
    print("ChromaDB embeddings deleted", removed)
    return 'Ok', 200
|
|
|
|
@app.route("/api/chromadb/query", methods=["POST"])
@require_module("chromadb")
def chromadb_query():
    """Return the chat messages most similar to ``query``.

    Expects JSON ``{"chat_id": str, "query": str, "n_results": int?}``. The
    result count defaults to 1 and is capped at the collection size. An empty
    collection yields ``[]``. Each result carries id, date, role, meta, content
    and distance.
    """
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')
    if "query" not in data or not isinstance(data["query"], str):
        abort(400, '"query" is required')

    n_results = data["n_results"] if isinstance(data.get("n_results"), int) else 1

    digest = hashlib.md5(data["chat_id"].encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{digest}", embedding_function=chromadb_embed_fn
    )

    if collection.count() == 0:
        print(f"Queried empty/missing collection for {repr(data['chat_id'])}.")
        return jsonify([])

    result = collection.query(
        query_texts=[data["query"]],
        n_results=min(collection.count(), n_results),
    )
    # Only one query text was sent, so each field holds a single result list.
    documents = result["documents"][0]
    ids = result["ids"][0]
    metadatas = result["metadatas"][0]
    distances = result["distances"][0]

    messages = []
    for index, message_id in enumerate(ids):
        meta = metadatas[index]
        messages.append(
            {
                "id": message_id,
                "date": meta["date"],
                "role": meta["role"],
                "meta": meta["meta"],
                "content": documents[index],
                "distance": distances[index],
            }
        )

    return jsonify(messages)
|
|
@app.route("/api/chromadb/multiquery", methods=["POST"])
@require_module("chromadb")
def chromadb_multiquery():
    """Query several chats' collections and merge the nearest matches.

    Expects JSON ``{"chat_list": [chat_id, ...], "query": str, "n_results": int?}``.
    - Non-string chat IDs and empty collections are skipped.
    - A chat whose fetch or query raises is skipped after printing the error.
    - Merged results are de-duplicated by ``content`` (first occurrence kept),
      sorted by ascending ``distance`` and truncated to ``n_results`` (default 1).
    """
    data = request.get_json()
    if "chat_list" not in data or not isinstance(data["chat_list"], list):
        abort(400, '"chat_list" is required and should be a list')
    if "query" not in data or not isinstance(data["query"], str):
        abort(400, '"query" is required')

    n_results = data["n_results"] if isinstance(data.get("n_results"), int) else 1

    collected = []
    for chat_id in data["chat_list"]:
        if not isinstance(chat_id, str):
            continue
        try:
            digest = hashlib.md5(chat_id.encode()).hexdigest()
            collection = chromadb_client.get_collection(
                name=f"chat-{digest}", embedding_function=chromadb_embed_fn
            )
            if collection.count() == 0:
                continue

            result = collection.query(
                query_texts=[data["query"]],
                n_results=min(collection.count(), n_results),
            )
            documents = result["documents"][0]
            ids = result["ids"][0]
            metadatas = result["metadatas"][0]
            distances = result["distances"][0]

            # Build this chat's batch fully before merging so a failure midway
            # contributes nothing from this chat.
            batch = []
            for index, message_id in enumerate(ids):
                meta = metadatas[index]
                batch.append(
                    {
                        "id": message_id,
                        "date": meta["date"],
                        "role": meta["role"],
                        "meta": meta["meta"],
                        "content": documents[index],
                        "distance": distances[index],
                    }
                )
            collected.extend(batch)
        except Exception as e:
            print(e)

    seen_contents = set()
    unique = []
    for message in collected:
        if message["content"] in seen_contents:
            continue
        seen_contents.add(message["content"])
        unique.append(message)

    ranked = sorted(unique, key=lambda message: message["distance"])[0:n_results]
    return jsonify(ranked)
|
|
|
|
@app.route("/api/chromadb/export", methods=["POST"])
@require_module("chromadb")
def chromadb_export():
    """Export every entry of a chat's collection, ordered by metadata ``date``.

    Responds 400 when the collection cannot be fetched. Returns
    ``{"chat_id": ..., "content": [{"id", "metadata", "document"}, ...]}``.
    """
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')

    digest = hashlib.md5(data["chat_id"].encode()).hexdigest()
    try:
        collection = chromadb_client.get_collection(
            name=f"chat-{digest}", embedding_function=chromadb_embed_fn
        )
    except Exception as e:
        print(e)
        abort(400, "Chat collection not found in chromadb")

    stored = collection.get()
    documents = stored.get('documents', [])
    ids = stored.get('ids', [])
    metadatas = stored.get('metadatas', [])

    entries = [
        {"id": entry_id, "metadata": metadatas[i], "document": documents[i]}
        for i, entry_id in enumerate(ids)
    ]
    entries.sort(key=lambda entry: entry['metadata']['date'])

    return jsonify({"chat_id": data["chat_id"], "content": entries})
|
|
@app.route("/api/chromadb/import", methods=["POST"])
@require_module("chromadb")
def chromadb_import():
    """Upsert previously exported entries into a chat's collection.

    Expects JSON ``{"chat_id": str, "content": [{"id", "metadata", "document"}, ...]}``,
    the shape produced by /api/chromadb/export. Both fields are validated
    before use, and an empty ``content`` list is a no-op.

    Returns:
        JSON ``{"count": <number of entries upserted>}``.
    """
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')
    if "content" not in data or not isinstance(data["content"], list):
        abort(400, '"content" is required and should be a list')
    content = data["content"]

    if not content:
        # chromadb rejects an upsert with an empty ID list.
        return jsonify({"count": 0})

    chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    documents = [item['document'] for item in content]
    metadatas = [item['metadata'] for item in content]
    ids = [item['id'] for item in content]

    collection.upsert(documents=documents, metadatas=metadatas, ids=ids)
    print(f"Imported {len(ids)} (total {collection.count()}) content entries into {repr(data['chat_id'])}")

    return jsonify({"count": len(ids)})
|
|
|
|
if args.share:
    from flask_cloudflared import _run_cloudflared
    import inspect

    # Count _run_cloudflared's positional-or-keyword parameters. When it takes
    # more than one, a random metrics port is passed as the second argument.
    sig = inspect.signature(_run_cloudflared)
    positional_param_count = sum(
        1
        for param in sig.parameters.values()
        if param.kind == param.POSITIONAL_OR_KEYWORD
    )
    if positional_param_count > 1:
        metrics_port = randint(8100, 9000)
        cloudflare = _run_cloudflared(port, metrics_port)
    else:
        cloudflare = _run_cloudflared(port)
    print("Running on", cloudflare)

# Speaker preview samples can be fetched without an API key.
ignore_auth.append(tts_play_sample)
app.run(host=host, port=port)
|
|