| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | import os |
| | import sys |
| | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) |
| | sys.path.append('{}/../..'.format(ROOT_DIR)) |
| | sys.path.append('{}/../../third_party/AcademiCodec'.format(ROOT_DIR)) |
| | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) |
| | from concurrent import futures |
| | import argparse |
| | import cosyvoice_pb2 |
| | import cosyvoice_pb2_grpc |
| | import logging |
| | logging.getLogger('matplotlib').setLevel(logging.WARNING) |
| | import grpc |
| | import torch |
| | import numpy as np |
| | from cosyvoice.cli.cosyvoice import CosyVoice |
| |
|
# Configure the root logger for the whole server process.
# DEBUG level keeps per-request tracing visible; matplotlib's own logger is
# silenced separately at import time so it does not flood this output.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s %(levelname)s %(message)s',
)
| |
|
class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
    """gRPC servicer exposing the CosyVoice TTS inference modes.

    One model instance is loaded at construction time and shared by all
    requests handled by the server's thread pool.
    """

    def __init__(self, args):
        # Load the model once up front; args.model_dir is a local path or a
        # modelscope repo id (resolved by CosyVoice itself).
        self.cosyvoice = CosyVoice(args.model_dir)
        logging.info('grpc service initialized')

    @staticmethod
    def _prompt_speech_from_bytes(prompt_audio):
        """Decode raw 16-bit little-endian PCM bytes into a float tensor.

        Returns a (1, n_samples) float32 tensor scaled to [-1, 1).
        np.frombuffer yields a read-only view over the request bytes, so it
        is copied via np.array before torch.from_numpy (which needs a
        writable buffer).
        """
        pcm = np.array(np.frombuffer(prompt_audio, dtype=np.int16))
        return torch.from_numpy(pcm).unsqueeze(dim=0).float() / (2 ** 15)

    def Inference(self, request, context):
        """Dispatch on the request oneof and synthesize speech.

        Returns a Response whose tts_audio field holds the synthesized
        waveform re-encoded as 16-bit PCM bytes.
        """
        if request.HasField('sft_request'):
            logging.info('get sft inference request')
            model_output = self.cosyvoice.inference_sft(
                request.sft_request.tts_text,
                request.sft_request.spk_id)
        elif request.HasField('zero_shot_request'):
            logging.info('get zero_shot inference request')
            prompt_speech_16k = self._prompt_speech_from_bytes(
                request.zero_shot_request.prompt_audio)
            model_output = self.cosyvoice.inference_zero_shot(
                request.zero_shot_request.tts_text,
                request.zero_shot_request.prompt_text,
                prompt_speech_16k)
        elif request.HasField('cross_lingual_request'):
            logging.info('get cross_lingual inference request')
            prompt_speech_16k = self._prompt_speech_from_bytes(
                request.cross_lingual_request.prompt_audio)
            model_output = self.cosyvoice.inference_cross_lingual(
                request.cross_lingual_request.tts_text,
                prompt_speech_16k)
        else:
            # NOTE(review): fall-through assumes instruct_request is set;
            # an empty request would produce empty-string inputs here.
            logging.info('get instruct inference request')
            model_output = self.cosyvoice.inference_instruct(
                request.instruct_request.tts_text,
                request.instruct_request.spk_id,
                request.instruct_request.instruct_text)

        logging.info('send inference response')
        response = cosyvoice_pb2.Response()
        # Re-quantize the float waveform back to int16 PCM for the wire.
        response.tts_audio = (model_output['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
        return response
| |
|
def main():
    """Build, start, and block on the CosyVoice gRPC server.

    Reads the module-level ``args`` namespace (set in the ``__main__``
    guard) for the listen port and concurrency limit.
    """
    # Cap queued RPCs at the same size as the worker pool so excess
    # requests are rejected instead of piling up.
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=args.max_conc),
        maximum_concurrent_rpcs=args.max_conc)
    cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(
        CosyVoiceServiceImpl(args), server)
    server.add_insecure_port('0.0.0.0:{}'.format(args.port))
    server.start()
    logging.info("server listening on 0.0.0.0:{}".format(args.port))
    server.wait_for_termination()
| |
|
| |
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--port',
                        type=int,
                        default=50000)
    parser.add_argument('--max_conc',
                        type=int,
                        default=4,
                        help='thread pool size and max concurrent rpcs')
    # Fix: the original combined required=True with a default, which argparse
    # resolves by ignoring the default entirely — the documented fallback was
    # unreachable. Dropping required=True makes the default effective while
    # keeping every existing invocation (--model_dir ...) valid.
    parser.add_argument('--model_dir',
                        type=str,
                        default='speech_tts/CosyVoice-300M',
                        help='local path or modelscope repo id')
    args = parser.parse_args()
    main()
| |
|