commit 8bbb96fe44dc1329179a5de76a86851f1d9cd1b7
Author: kejingfan
Date:   Mon Jul 8 14:15:34 2024 +0800

    first commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e985309
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+*.wav
+
+__pycache__
+
+api/tts/chatgpt_api_config.py
+
+dependencies/*
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e69de29
diff --git a/api/tts/main.py b/api/tts/main.py
new file mode 100644
index 0000000..8fa45ee
--- /dev/null
+++ b/api/tts/main.py
@@ -0,0 +1,191 @@
+from flask import Flask, request, send_file, jsonify
+import requests
+import os
+import uuid
+from datetime import datetime
+from pydub import AudioSegment
+from concurrent.futures import ThreadPoolExecutor
+from queue import Queue
+import json
+from io import BytesIO
+from chatgpt_api_config import chatgpt_apis
+
+app = Flask(__name__)
+
+tts_servers = [
+    'http://127.0.0.1:9995/tts',
+    'http://127.0.0.1:9996/tts'
+]
+tts_server_index = 0
+executor = ThreadPoolExecutor(max_workers=len(tts_servers))
+
+# Sentence-ending characters that trigger a TTS request for the text so far.
+# The Chinese set accepts both full-width and ASCII question/exclamation marks.
+zh_punc = {'。', '?', '!', '?', '!', '\n'}
+en_punc = {'.', '?', '!', '\n'}
+
+def merge_audio_files(base_audio, increment):
+    """Append ``increment`` to ``base_audio`` and return the combined audio."""
+    base_audio += increment
+    return base_audio
+
+def call_tts_api(server_url, response_text, language, audio):
+    """Ask a TTS server to synthesize ``response_text`` as speech.
+
+    ``audio`` is the path of the speaker reference recording that is uploaded
+    with the request. Returns an AudioSegment on success, or None when the
+    server answers with a non-200 status.
+    """
+    # Open the reference audio in a context manager so the handle is closed
+    # deterministically once the upload has finished.
+    with open(audio, 'rb') as speaker_audio:
+        response = requests.post(
+            server_url,
+            data={
+                "text": response_text,
+                'language': language
+            },
+            files={'audio': speaker_audio}
+        )
+    if response.status_code == 200:
+        audio_segment = AudioSegment.from_file(file=BytesIO(response.content), format='wav')
+        return audio_segment
+    else:
+        # The error body is not guaranteed to be JSON carrying an 'error' key.
+        try:
+            error_message = response.json()['error']
+        except (ValueError, KeyError, TypeError):
+            error_message = response.text
+        print(f"Error: {error_message}")
+        return None
+
+
+def generate_response_stream(transcription):
+    """Send ``transcription`` to the configured ChatGPT APIs as a streaming chat.
+
+    The endpoints in ``chatgpt_apis`` are tried in order. Returns the first
+    response with status 200 (opened with ``stream=True``), or None when no
+    endpoint succeeds.
+    """
+    for index, chatgpt_api in enumerate(chatgpt_apis):
+        url = chatgpt_api['url']
+        api_key = chatgpt_api['key']
+
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json"
+        }
+        data = {
+            "model": "gpt-3.5-turbo",
+            "messages": [
+                {"role": "user", "content": transcription}
+            ],
+            "temperature": 0.7,
+            "stream": True
+        }
+        try:
+            response = requests.post(url, headers=headers, json=data, stream=True)
+        except requests.RequestException as error:
+            # An unreachable endpoint must not prevent trying the next one.
+            print(f"ChatGPT API {index} Request Failed: {error}")
+            continue
+        print(f"ChatGPT API {index} Response Status Code: {response.status_code}")
+        if response.status_code == 200:
+            return response
+        # Release the connection held by the rejected streaming response.
+        response.close()
+    return None
+
+
+def _submit_tts(partial_text, language, speaker_path):
+    """Queue ``partial_text`` on the next TTS server (round-robin) and return its future."""
+    global tts_server_index
+    server_url = tts_servers[tts_server_index % len(tts_servers)]
+    tts_server_index += 1
+    return executor.submit(call_tts_api, server_url, partial_text, language, speaker_path)
+
+
+@app.route('/tts', methods=['POST'])
+def tts():
+    """Answer the posted text with ChatGPT and return the answer as speech.
+
+    Reads the form fields 'text' and 'language' and the uploaded 'audio' file
+    (speaker reference). The streamed answer is cut at sentence punctuation,
+    every piece is synthesized on a TTS server, and the audio pieces are
+    joined in text order and returned as a WAV attachment. Responds with 502
+    when no ChatGPT endpoint returns a usable stream.
+    """
+    unique_id = str(uuid.uuid4())
+    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
+    os.makedirs('temp', exist_ok=True)
+    input_audio_filename = f"input_{timestamp}_{unique_id}.wav"
+    input_audio_path = os.path.join('temp', input_audio_filename)
+    output_audio_filename = f"output_{timestamp}_{unique_id}.wav"
+    output_audio_path = os.path.join('temp', output_audio_filename)
+
+    base_audio = AudioSegment.silent(duration=0)  # start from an empty audio segment
+
+    pending_text = ''
+    futures = []
+
+    language = request.form['language']
+    punctuation = zh_punc if language == 'chinese' else en_punc
+    response_stream = generate_response_stream(request.form['text'])
+    if response_stream is None:
+        return jsonify({"error": "Something wrong with ChatGPT API."}), 502
+
+    try:
+        speaker_file = request.files['audio']
+        speaker_file.save(input_audio_path)
+
+        for chunk in response_stream.iter_lines():
+            if not chunk:
+                continue
+            decoded_line = chunk.decode('utf-8')
+            if not decoded_line.startswith('data: '):
+                continue
+            content = decoded_line[6:]
+            if content.strip() == '[DONE]':
+                break
+            response_json = json.loads(content)
+            choices = response_json.get('choices')
+            if not choices:
+                # Skip events without choices; there is no text to add.
+                continue
+            # 'content' can be missing or null in a delta.
+            pending_text += choices[0].get('delta', {}).get('content') or ''
+
+            if pending_text and pending_text[-1] in punctuation:
+                if pending_text.strip():
+                    print(f"{pending_text}", end="")
+                    futures.append(_submit_tts(pending_text, language, input_audio_path))
+                pending_text = ''
+
+        # Flush the tail of the answer: text that does not end with sentence
+        # punctuation would otherwise never be synthesized.
+        if pending_text.strip():
+            print(f"{pending_text}", end="")
+            futures.append(_submit_tts(pending_text, language, input_audio_path))
+
+        # Collect results in submission order so the audio follows the text order.
+        for future in futures:
+            audio_data = future.result()
+            if audio_data is not None:
+                base_audio = merge_audio_files(base_audio, audio_data)
+
+        # Write the combined audio to disk so it can be sent as a file.
+        base_audio.export(output_audio_path, format='wav')
+        print("\n")
+
+        # Return the synthesized answer.
+        return send_file(output_audio_path, as_attachment=True, download_name='response.wav')
+    finally:
+        response_stream.close()
+        if os.path.exists(input_audio_path):
+            os.remove(input_audio_path)
+        if os.path.exists(output_audio_path):
+            os.remove(output_audio_path)
+
+
+if __name__ == '__main__':
+    app.run()
diff --git a/api/tts/run_tts.sh b/api/tts/run_tts.sh
new file mode 100644
index 0000000..5883f41
--- /dev/null
+++ b/api/tts/run_tts.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+FLASK_APP=main.py FLASK_ENV=development flask run \
+    -h 0.0.0.0 \
+    -p 9992
\ No newline at end of file
diff --git a/api/wenet/main.py b/api/wenet/main.py
new file mode 100644
index 0000000..a648c1d
--- /dev/null
+++ b/api/wenet/main.py
@@ -0,0 +1,56 @@
+from flask import Flask, request, jsonify
+import wenet
+import os
+import uuid
+from datetime import datetime
+
+app = Flask(__name__)
+
+# Load the wenet models once at import time (one per supported language).
+wenet_model_cn = wenet.load_model('chinese', device='cuda')
+wenet_model_en = wenet.load_model('english', device='cuda')
+
+def transcribe_audio(audio_path, language):
+    """Transcribe audio file to text using wenet."""
+    if language == 'chinese':
+        result = wenet_model_cn.transcribe(audio_path)['text']
+    else:
+        result = wenet_model_en.transcribe(audio_path)['text']
+    result = result.replace("▁", " ")
+    print(result)
+    return result
+
+@app.route('/transcribe', methods=['POST'])
+def transcribe():
+    """Transcribe the uploaded 'audio' file in the given 'language' and return JSON text."""
+    if 'audio' not in request.files or 'language' not in request.form:
+        return jsonify({"error": "Audio file and language must be provided"}), 400
+
+    audio_file = request.files['audio']
+    language = request.form['language']
+
+    if language not in ['chinese', 'english']:
+        return jsonify({"error": "Unsupported language"}), 400
+
+    # Build a unique path for the temporary copy of the uploaded audio.
+    unique_id = str(uuid.uuid4())
+    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
+    os.makedirs('temp', exist_ok=True)
+    input_audio_filename = f"input_{timestamp}_{unique_id}.wav"
+    input_audio_path = os.path.join('temp', input_audio_filename)
+    audio_file.save(input_audio_path)
+
+    try:
+        # Speech-to-text with wenet.
+        response_text = transcribe_audio(input_audio_path, language)
+        # Both languages get the same clean-up: colons become commas and
+        # asterisks are dropped.
+        response_text = response_text.replace(":", ",").replace("*", "")
+        return jsonify({"text": response_text})
+    finally:
+        # Remove the temporary audio file.
+        if os.path.exists(input_audio_path):
+            os.remove(input_audio_path)
+
+if __name__ == '__main__':
+    app.run()
diff --git a/api/wenet/run_wenet.sh b/api/wenet/run_wenet.sh
new file mode 100644
index 0000000..941943f
--- /dev/null
+++ b/api/wenet/run_wenet.sh
@@ -0,0 +1,5 @@
+export FLASK_APP=main.py
+export FLASK_ENV=development
+flask run \
+    -h 0.0.0.0 \
+    -p 9991
\ No newline at end of file
diff --git a/api/xtts/main.py b/api/xtts/main.py
new file mode 100644
index 0000000..c4b5e6a
--- /dev/null
+++ b/api/xtts/main.py
@@ -0,0 +1,68 @@
+from flask import Flask, request, jsonify, send_file
+import uuid
+from datetime import datetime
+from TTS.api import TTS
+import os
+
+
+app = Flask(__name__)
+device = os.getenv('APP_DEVICE', 'cpu')  # device name comes from the environment
+
+lang2short = {'english': 'en', 'chinese': 'zh-cn'}
+
+tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=('cuda' in device)).to(device)
+
+
+def generate_wav(response_text, speaker_wav, language, output_file_path):
+    """Write speech for ``response_text`` to ``output_file_path``, passing ``speaker_wav`` as the speaker reference."""
+    tts.tts_to_file(
+        text=response_text,
+        speaker_wav=speaker_wav,
+        language=lang2short[language],
+        file_path=output_file_path
+    )
+
+
+@app.route('/tts', methods=['POST'])
+def generate():
+    """Synthesize the posted 'text' in 'language' with the uploaded 'audio' and return a WAV file."""
+    if 'audio' not in request.files or 'language' not in request.form or 'text' not in request.form:
+        return jsonify({"error": "Speaker audio file, text and language must be provided"}), 400
+
+    speaker_wav = request.files['audio']
+    language = request.form['language']
+    text = request.form['text']
+
+    if language not in ['chinese', 'english']:
+        return jsonify({"error": "Unsupported language"}), 400
+
+    # Build unique paths for the temporary speaker and output audio files.
+    unique_id = str(uuid.uuid4())
+    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
+    os.makedirs('temp', exist_ok=True)
+    input_audio_filename = f"speaker_{timestamp}_{unique_id}.wav"
+    input_audio_path = os.path.join('temp', input_audio_filename)
+    output_audio_filename = f"output_{timestamp}_{unique_id}.wav"
+    output_audio_path = os.path.join('temp', output_audio_filename)
+    speaker_wav.save(input_audio_path)
+
+    try:
+        # Generate the audio.
+        generate_wav(text, input_audio_path, language, output_audio_path)
+
+        return send_file(
+            output_audio_path,
+            mimetype='audio/wav',
+            as_attachment=True,
+            download_name='generated_audio.wav'
+        )
+    finally:
+        # Remove the temporary audio files.
+        if os.path.exists(input_audio_path):
+            os.remove(input_audio_path)
+        if os.path.exists(output_audio_path):
+            os.remove(output_audio_path)
+
+
+if __name__ == '__main__':
+    app.run()
diff --git a/api/xtts/run_xtts.sh b/api/xtts/run_xtts.sh
new file mode 100644
index 0000000..5159262
--- /dev/null
+++ b/api/xtts/run_xtts.sh
@@ -0,0 +1,16 @@
+export FLASK_APP=main.py
+export FLASK_ENV=development
+
+# Define the ports to run the application on
+ports=(9995 9996)
+devices=('cuda' 'cuda')
+
+# Loop through each port and start the application
+for i in "${!ports[@]}"; do
+    port=${ports[$i]}
+    device=${devices[$i]}
+    echo "Starting server on port $port with device $device"
+    APP_DEVICE=$device FLASK_APP=main.py FLASK_ENV=development flask run --port $port --host '0.0.0.0' &
+done
+
+wait