first commit

This commit is contained in:
Jingfan Ke 2024-07-08 14:15:34 +08:00
commit 8bbb96fe44
8 changed files with 303 additions and 0 deletions

7
.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
*.wav
__pycache__
api/tts/chatgpt_api_config.py
dependencies/*

0
README.md Normal file
View File

147
api/tts/main.py Normal file
View File

@ -0,0 +1,147 @@
from flask import Flask, request, send_file, jsonify
import requests
import os
import uuid
from datetime import datetime
from pydub import AudioSegment
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import json
from io import BytesIO
from chatgpt_api_config import chatgpt_apis
app = Flask(__name__)
tts_servers = [
'http://127.0.0.1:9995/tts',
'http://127.0.0.1:9996/tts'
]
tts_server_index = 0
executor = ThreadPoolExecutor(max_workers=len(tts_servers))
zh_punc = {'', '', '', '\n'}
en_punc = {'.', '?', '!', '\n'}
def merge_audio_files(base_audio, increment):
"""将多段语音拼接"""
base_audio += increment
return base_audio
def call_tts_api(server_url, response_text, language, audio):
"""调用ChatTTS API回答转语音"""
response = requests.post(
server_url,
data={
"text": response_text,
'language': language
},
files={'audio': open(audio, 'rb')}
)
if response.status_code == 200:
audio_segment = AudioSegment.from_file(file=BytesIO(response.content), format='wav')
return audio_segment
else:
print(f"Error: {response.json()['error']}")
return None
def generate_response_stream(transcription):
"""调用ChatGPT API回答问题"""
for index, chatgpt_api in enumerate(chatgpt_apis):
url = chatgpt_api['url']
api_key = chatgpt_api['key']
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
data = {
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": transcription}
],
"temperature": 0.7,
"stream": True
}
response = requests.post(url, headers=headers, json=data, stream=True)
print(f"ChatGPT API {index} Response Status Code: {response.status_code}")
if response.status_code == 200:
return response
return None
@app.route('/tts', methods=['POST'])
def tts():
global tts_server_index
unique_id = str(uuid.uuid4())
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
os.makedirs('temp', exist_ok=True)
input_audio_filename = f"input_{timestamp}_{unique_id}.wav"
input_audio_path = os.path.join('temp', input_audio_filename)
output_audio_filename = f"output_{timestamp}_{unique_id}.wav"
output_audio_path = os.path.join('temp', output_audio_filename)
base_audio = AudioSegment.silent(duration=0) # 初始化一个空音频段
collected_chunks = []
collected_messages = ['']
futures = []
audio_queue = Queue()
language = request.form['language']
response_stream = generate_response_stream(request.form['text'])
if response_stream == None:
return jsonify({"error": "Something wrong with ChatGPT API."}), 502
speaker_file = request.files['audio']
speaker_file.save(input_audio_path)
try:
for chunk in response_stream.iter_lines():
if chunk:
decoded_line = chunk.decode('utf-8')
if decoded_line.startswith('data: '):
content = decoded_line[6:]
if content.strip() == '[DONE]':
break
response_json = json.loads(content)
collected_chunks.append(response_json)
chunk_message = response_json['choices'][0]['delta']
collected_messages[-1] += chunk_message.get('content', '')
if len(collected_messages[-1]) > 0 and collected_messages[-1][-1] in (zh_punc if language == 'chinese' else en_punc):
partial_text = collected_messages[-1]
if partial_text:
print(f"{partial_text}", end="")
server_url = tts_servers[tts_server_index % len(tts_servers)]
tts_server_index += 1
future = executor.submit(call_tts_api, server_url, partial_text, language, input_audio_path)
futures.append((partial_text, future))
collected_messages.append("")
# 处理所有 future 并按顺序添加到队列中
for partial_text, future in futures:
audio_data = future.result()
if audio_data:
audio_queue.put((partial_text, audio_data))
# 拼接音频文件
while not audio_queue.empty():
_, audio_segment = audio_queue.get()
base_audio = merge_audio_files(base_audio, audio_segment)
# 将最终的音频文件保存到硬盘
base_audio.export(output_audio_path, format='wav')
print("\n")
# 返回生成的回答音频
return send_file(output_audio_path, as_attachment=True, download_name='response.wav')
finally:
if os.path.exists(input_audio_path):
os.remove(input_audio_path)
if os.path.exists(output_audio_path):
os.remove(output_audio_path)
if __name__ == '__main__':
app.run()

4
api/tts/run_tts.sh Normal file
View File

@ -0,0 +1,4 @@
#!/bin/bash
FLASK_APP=main.py FLASK_ENV=development flask run \
-h 0.0.0.0 \
-p 9992

58
api/wenet/main.py Normal file
View File

@ -0,0 +1,58 @@
from flask import Flask, request, jsonify
import wenet
import os
import uuid
from datetime import datetime
app = Flask(__name__)
# 加载wenet模型
wenet_model_cn = wenet.load_model('chinese', device='cuda')
wenet_model_en = wenet.load_model('english', device='cuda')
def transcribe_audio(audio_path, language):
"""Transcribe audio file to text using wenet."""
if language == 'chinese':
result = wenet_model_cn.transcribe(audio_path)['text']
else:
result = wenet_model_en.transcribe(audio_path)['text']
result = result.replace("", " ")
print(result)
return result
@app.route('/transcribe', methods=['POST'])
def transcribe():
if 'audio' not in request.files or 'language' not in request.form:
return jsonify({"error": "Audio file and language must be provided"}), 400
audio_file = request.files['audio']
language = request.form['language']
if language not in ['chinese', 'english']:
return jsonify({"error": "Unsupported language"}), 400
# 设置缓存音频文件地址
unique_id = str(uuid.uuid4())
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
os.makedirs('temp', exist_ok=True)
input_audio_filename = f"input_{timestamp}_{unique_id}.wav"
input_audio_path = os.path.join('temp', input_audio_filename)
audio_file.save(input_audio_path)
try:
# 使用wenet音频转文本
response_text = transcribe_audio(input_audio_path, language)
if language == "chinese":
response_text = response_text.replace("", "")
response_text = response_text.replace("*", "")
else:
response_text = response_text.replace(":", ",")
response_text = response_text.replace("*", "")
return jsonify({"text": response_text})
finally:
# 清理缓存音频文件
if os.path.exists(input_audio_path):
os.remove(input_audio_path)
if __name__ == '__main__':
app.run()

5
api/wenet/run_wenet.sh Normal file
View File

@ -0,0 +1,5 @@
export FLASK_APP=main.py
export FLASK_ENV=development
flask run \
-h 0.0.0.0 \
-p 9991

66
api/xtts/main.py Normal file
View File

@ -0,0 +1,66 @@
from flask import Flask, request, jsonify, send_file
import uuid
from datetime import datetime
from TTS.api import TTS
import os
app = Flask(__name__)
device = os.getenv('APP_DEVICE', 'cpu') # 使用环境变量获取设备
lang2short = {'english': 'en', 'chinese': 'zh-cn'}
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=('cuda' in device)).to(device)
def generate_wav(response_text, speaker_wav, language, output_file_path):
tts.tts_to_file(
text=response_text,
speaker_wav=speaker_wav,
language=lang2short[language],
file_path=output_file_path
)
@app.route('/tts', methods=['POST'])
def generate():
if 'audio' not in request.files or 'language' not in request.form or 'text' not in request.form:
return jsonify({"error": "Speaker audio file, text and language must be provided"}), 400
speaker_wav = request.files['audio']
language = request.form['language']
text = request.form['text']
if language not in ['chinese', 'english']:
return jsonify({"error": "Unsupported language"}), 400
# 设置缓存音频文件地址
unique_id = str(uuid.uuid4())
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
os.makedirs('temp', exist_ok=True)
input_audio_filename = f"speaker_{timestamp}_{unique_id}.wav"
input_audio_path = os.path.join('temp', input_audio_filename)
output_audio_filename = f"output_{timestamp}_{unique_id}.wav"
output_audio_path = os.path.join('temp', output_audio_filename)
speaker_wav.save(input_audio_path)
try:
# 生成音频数据
generate_wav(text, input_audio_path, language, output_audio_path)
return send_file(
output_audio_path,
mimetype='audio/wav',
as_attachment=True,
download_name='generated_audio.wav'
)
finally:
# 清理缓存音频文件
if os.path.exists(input_audio_path):
os.remove(input_audio_path)
if os.path.exists(output_audio_path):
os.remove(output_audio_path)
if __name__ == '__main__':
app.run()

16
api/xtts/run_xtts.sh Normal file
View File

@ -0,0 +1,16 @@
export FLASK_APP=main.py
export FLASK_ENV=development
# Define the ports to run the application on
ports=(9995 9996)
devices=('cuda' 'cuda')
# Loop through each port and start the application
for i in "${!ports[@]}"; do
port=${ports[$i]}
device=${devices[$i]}
echo "Starting server on port $port with device $device"
APP_DEVICE=$device FLASK_APP=main.py FLASK_ENV=development flask run --port $port --host '0.0.0.0' &
done
wait