| | from fastapi import HTTPException |
| | from pydantic import BaseModel |
| |
|
| | from modules import refiner |
| | from modules.api import utils as api_utils |
| | from modules.api.Api import APIManager |
| | from modules.normalization import text_normalize |
| |
|
| |
|
class RefineTextRequest(BaseModel):
    """Request body accepted by the ``/v1/prompt/refine`` endpoint.

    Apart from ``normalize``, every field is forwarded unchanged as a keyword
    argument of the same name to ``refiner.refine_text``.
    """

    # Text to refine; run through ``text_normalize`` first when ``normalize`` is true.
    text: str
    # Forwarded as ``prompt``. NOTE(review): the default looks like a string of
    # bracketed control tags -- confirm their meaning against the refiner module.
    prompt: str = "[oral_2][laugh_0][break_6]"
    # Forwarded as ``seed``. NOTE(review): -1 presumably means "no fixed seed" -- confirm.
    seed: int = -1
    # Generation parameters, passed through to the refiner as-is.
    top_P: float = 0.7
    top_K: int = 20
    temperature: float = 0.7
    repetition_penalty: float = 1.0
    max_new_token: int = 384
    # When true, ``text`` is passed through ``text_normalize`` before refinement.
    normalize: bool = True
| |
|
| |
|
async def refiner_prompt_post(request: RefineTextRequest):
    """
    Refine the submitted text and return the result.

    The text is first passed through ``text_normalize`` when
    ``request.normalize`` is true, then handed to ``refiner.refine_text``
    together with the prompt and generation parameters from the request.

    Returns:
        A dict of the form ``{"message": "ok", "data": <refined text>}``.

    Raises:
        HTTPException: An ``HTTPException`` raised while processing is logged
            and re-raised unchanged; any other exception is logged and
            converted to a status-500 ``HTTPException`` whose ``detail`` is
            ``str(exc)``.
    """
    import logging

    try:
        text = text_normalize(request.text) if request.normalize else request.text

        refined_text = refiner.refine_text(
            text=text,
            prompt=request.prompt,
            seed=request.seed,
            top_P=request.top_P,
            top_K=request.top_K,
            temperature=request.temperature,
            repetition_penalty=request.repetition_penalty,
            max_new_token=request.max_new_token,
        )
    except HTTPException as e:
        # Already an HTTP-level error: log it and let it through untouched.
        # A bare ``raise`` keeps the original traceback intact.
        logging.exception(e)
        raise
    except Exception as e:
        logging.exception(e)
        # Chain explicitly so the root cause stays attached to the 500 error.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "ok", "data": refined_text}
| |
|
| |
|
def setup(api_manager: APIManager):
    """Register the text-refinement route on ``api_manager``.

    ``refiner_prompt_post`` is bound to ``POST /v1/prompt/refine`` with
    ``api_utils.BaseResponse`` as the declared response model.
    """
    # ``post`` returns a decorator; apply it to the handler explicitly.
    register_route = api_manager.post(
        "/v1/prompt/refine",
        response_model=api_utils.BaseResponse,
    )
    register_route(refiner_prompt_post)
| |
|