first commit
This commit is contained in:
commit
8bbb96fe44
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
*.wav
|
||||||
|
|
||||||
|
__pycache__
|
||||||
|
|
||||||
|
api/tts/chatgpt_api_config.py
|
||||||
|
|
||||||
|
dependencies/*
|
147
api/tts/main.py
Normal file
147
api/tts/main.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
from flask import Flask, request, send_file, jsonify
|
||||||
|
import requests
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from pydub import AudioSegment
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from queue import Queue
|
||||||
|
import json
|
||||||
|
from io import BytesIO
|
||||||
|
from chatgpt_api_config import chatgpt_apis
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
tts_servers = [
|
||||||
|
'http://127.0.0.1:9995/tts',
|
||||||
|
'http://127.0.0.1:9996/tts'
|
||||||
|
]
|
||||||
|
tts_server_index = 0
|
||||||
|
executor = ThreadPoolExecutor(max_workers=len(tts_servers))
|
||||||
|
|
||||||
|
zh_punc = {'。', '?', '!', '\n'}
|
||||||
|
en_punc = {'.', '?', '!', '\n'}
|
||||||
|
|
||||||
|
def merge_audio_files(base_audio, increment):
|
||||||
|
"""将多段语音拼接"""
|
||||||
|
base_audio += increment
|
||||||
|
return base_audio
|
||||||
|
|
||||||
|
def call_tts_api(server_url, response_text, language, audio):
|
||||||
|
"""调用ChatTTS API,回答转语音"""
|
||||||
|
response = requests.post(
|
||||||
|
server_url,
|
||||||
|
data={
|
||||||
|
"text": response_text,
|
||||||
|
'language': language
|
||||||
|
},
|
||||||
|
files={'audio': open(audio, 'rb')}
|
||||||
|
)
|
||||||
|
if response.status_code == 200:
|
||||||
|
audio_segment = AudioSegment.from_file(file=BytesIO(response.content), format='wav')
|
||||||
|
return audio_segment
|
||||||
|
else:
|
||||||
|
print(f"Error: {response.json()['error']}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def generate_response_stream(transcription):
|
||||||
|
"""调用ChatGPT API,回答问题"""
|
||||||
|
for index, chatgpt_api in enumerate(chatgpt_apis):
|
||||||
|
url = chatgpt_api['url']
|
||||||
|
api_key = chatgpt_api['key']
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
data = {
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": transcription}
|
||||||
|
],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"stream": True
|
||||||
|
}
|
||||||
|
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||||
|
print(f"ChatGPT API {index} Response Status Code: {response.status_code}")
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/tts', methods=['POST'])
|
||||||
|
def tts():
|
||||||
|
global tts_server_index
|
||||||
|
|
||||||
|
unique_id = str(uuid.uuid4())
|
||||||
|
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
|
||||||
|
os.makedirs('temp', exist_ok=True)
|
||||||
|
input_audio_filename = f"input_{timestamp}_{unique_id}.wav"
|
||||||
|
input_audio_path = os.path.join('temp', input_audio_filename)
|
||||||
|
output_audio_filename = f"output_{timestamp}_{unique_id}.wav"
|
||||||
|
output_audio_path = os.path.join('temp', output_audio_filename)
|
||||||
|
|
||||||
|
base_audio = AudioSegment.silent(duration=0) # 初始化一个空音频段
|
||||||
|
|
||||||
|
collected_chunks = []
|
||||||
|
collected_messages = ['']
|
||||||
|
futures = []
|
||||||
|
audio_queue = Queue()
|
||||||
|
|
||||||
|
language = request.form['language']
|
||||||
|
response_stream = generate_response_stream(request.form['text'])
|
||||||
|
if response_stream == None:
|
||||||
|
return jsonify({"error": "Something wrong with ChatGPT API."}), 502
|
||||||
|
speaker_file = request.files['audio']
|
||||||
|
speaker_file.save(input_audio_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for chunk in response_stream.iter_lines():
|
||||||
|
if chunk:
|
||||||
|
decoded_line = chunk.decode('utf-8')
|
||||||
|
if decoded_line.startswith('data: '):
|
||||||
|
content = decoded_line[6:]
|
||||||
|
if content.strip() == '[DONE]':
|
||||||
|
break
|
||||||
|
response_json = json.loads(content)
|
||||||
|
collected_chunks.append(response_json)
|
||||||
|
chunk_message = response_json['choices'][0]['delta']
|
||||||
|
collected_messages[-1] += chunk_message.get('content', '')
|
||||||
|
|
||||||
|
if len(collected_messages[-1]) > 0 and collected_messages[-1][-1] in (zh_punc if language == 'chinese' else en_punc):
|
||||||
|
partial_text = collected_messages[-1]
|
||||||
|
if partial_text:
|
||||||
|
print(f"{partial_text}", end="")
|
||||||
|
server_url = tts_servers[tts_server_index % len(tts_servers)]
|
||||||
|
tts_server_index += 1
|
||||||
|
future = executor.submit(call_tts_api, server_url, partial_text, language, input_audio_path)
|
||||||
|
futures.append((partial_text, future))
|
||||||
|
collected_messages.append("")
|
||||||
|
|
||||||
|
# 处理所有 future 并按顺序添加到队列中
|
||||||
|
for partial_text, future in futures:
|
||||||
|
audio_data = future.result()
|
||||||
|
if audio_data:
|
||||||
|
audio_queue.put((partial_text, audio_data))
|
||||||
|
|
||||||
|
# 拼接音频文件
|
||||||
|
while not audio_queue.empty():
|
||||||
|
_, audio_segment = audio_queue.get()
|
||||||
|
base_audio = merge_audio_files(base_audio, audio_segment)
|
||||||
|
|
||||||
|
# 将最终的音频文件保存到硬盘
|
||||||
|
base_audio.export(output_audio_path, format='wav')
|
||||||
|
print("\n")
|
||||||
|
|
||||||
|
# 返回生成的回答音频
|
||||||
|
return send_file(output_audio_path, as_attachment=True, download_name='response.wav')
|
||||||
|
finally:
|
||||||
|
if os.path.exists(input_audio_path):
|
||||||
|
os.remove(input_audio_path)
|
||||||
|
if os.path.exists(output_audio_path):
|
||||||
|
os.remove(output_audio_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run()
|
4
api/tts/run_tts.sh
Normal file
4
api/tts/run_tts.sh
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
FLASK_APP=main.py FLASK_ENV=development flask run \
|
||||||
|
-h 0.0.0.0 \
|
||||||
|
-p 9992
|
58
api/wenet/main.py
Normal file
58
api/wenet/main.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from flask import Flask, request, jsonify
|
||||||
|
import wenet
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
# 加载wenet模型
|
||||||
|
wenet_model_cn = wenet.load_model('chinese', device='cuda')
|
||||||
|
wenet_model_en = wenet.load_model('english', device='cuda')
|
||||||
|
|
||||||
|
def transcribe_audio(audio_path, language):
|
||||||
|
"""Transcribe audio file to text using wenet."""
|
||||||
|
if language == 'chinese':
|
||||||
|
result = wenet_model_cn.transcribe(audio_path)['text']
|
||||||
|
else:
|
||||||
|
result = wenet_model_en.transcribe(audio_path)['text']
|
||||||
|
result = result.replace("▁", " ")
|
||||||
|
print(result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@app.route('/transcribe', methods=['POST'])
|
||||||
|
def transcribe():
|
||||||
|
if 'audio' not in request.files or 'language' not in request.form:
|
||||||
|
return jsonify({"error": "Audio file and language must be provided"}), 400
|
||||||
|
|
||||||
|
audio_file = request.files['audio']
|
||||||
|
language = request.form['language']
|
||||||
|
|
||||||
|
if language not in ['chinese', 'english']:
|
||||||
|
return jsonify({"error": "Unsupported language"}), 400
|
||||||
|
|
||||||
|
# 设置缓存音频文件地址
|
||||||
|
unique_id = str(uuid.uuid4())
|
||||||
|
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
|
||||||
|
os.makedirs('temp', exist_ok=True)
|
||||||
|
input_audio_filename = f"input_{timestamp}_{unique_id}.wav"
|
||||||
|
input_audio_path = os.path.join('temp', input_audio_filename)
|
||||||
|
audio_file.save(input_audio_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用wenet,音频转文本
|
||||||
|
response_text = transcribe_audio(input_audio_path, language)
|
||||||
|
if language == "chinese":
|
||||||
|
response_text = response_text.replace(":", ",")
|
||||||
|
response_text = response_text.replace("*", "")
|
||||||
|
else:
|
||||||
|
response_text = response_text.replace(":", ",")
|
||||||
|
response_text = response_text.replace("*", "")
|
||||||
|
return jsonify({"text": response_text})
|
||||||
|
finally:
|
||||||
|
# 清理缓存音频文件
|
||||||
|
if os.path.exists(input_audio_path):
|
||||||
|
os.remove(input_audio_path)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run()
|
5
api/wenet/run_wenet.sh
Normal file
5
api/wenet/run_wenet.sh
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
export FLASK_APP=main.py
|
||||||
|
export FLASK_ENV=development
|
||||||
|
flask run \
|
||||||
|
-h 0.0.0.0 \
|
||||||
|
-p 9991
|
66
api/xtts/main.py
Normal file
66
api/xtts/main.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
from flask import Flask, request, jsonify, send_file
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from TTS.api import TTS
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
device = os.getenv('APP_DEVICE', 'cpu') # 使用环境变量获取设备
|
||||||
|
|
||||||
|
lang2short = {'english': 'en', 'chinese': 'zh-cn'}
|
||||||
|
|
||||||
|
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=('cuda' in device)).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_wav(response_text, speaker_wav, language, output_file_path):
|
||||||
|
tts.tts_to_file(
|
||||||
|
text=response_text,
|
||||||
|
speaker_wav=speaker_wav,
|
||||||
|
language=lang2short[language],
|
||||||
|
file_path=output_file_path
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/tts', methods=['POST'])
|
||||||
|
def generate():
|
||||||
|
if 'audio' not in request.files or 'language' not in request.form or 'text' not in request.form:
|
||||||
|
return jsonify({"error": "Speaker audio file, text and language must be provided"}), 400
|
||||||
|
|
||||||
|
speaker_wav = request.files['audio']
|
||||||
|
language = request.form['language']
|
||||||
|
text = request.form['text']
|
||||||
|
|
||||||
|
if language not in ['chinese', 'english']:
|
||||||
|
return jsonify({"error": "Unsupported language"}), 400
|
||||||
|
|
||||||
|
# 设置缓存音频文件地址
|
||||||
|
unique_id = str(uuid.uuid4())
|
||||||
|
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
|
||||||
|
os.makedirs('temp', exist_ok=True)
|
||||||
|
input_audio_filename = f"speaker_{timestamp}_{unique_id}.wav"
|
||||||
|
input_audio_path = os.path.join('temp', input_audio_filename)
|
||||||
|
output_audio_filename = f"output_{timestamp}_{unique_id}.wav"
|
||||||
|
output_audio_path = os.path.join('temp', output_audio_filename)
|
||||||
|
speaker_wav.save(input_audio_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 生成音频数据
|
||||||
|
generate_wav(text, input_audio_path, language, output_audio_path)
|
||||||
|
|
||||||
|
return send_file(
|
||||||
|
output_audio_path,
|
||||||
|
mimetype='audio/wav',
|
||||||
|
as_attachment=True,
|
||||||
|
download_name='generated_audio.wav'
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# 清理缓存音频文件
|
||||||
|
if os.path.exists(input_audio_path):
|
||||||
|
os.remove(input_audio_path)
|
||||||
|
if os.path.exists(output_audio_path):
|
||||||
|
os.remove(output_audio_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run()
|
16
api/xtts/run_xtts.sh
Normal file
16
api/xtts/run_xtts.sh
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
export FLASK_APP=main.py
|
||||||
|
export FLASK_ENV=development
|
||||||
|
|
||||||
|
# Define the ports to run the application on
|
||||||
|
ports=(9995 9996)
|
||||||
|
devices=('cuda' 'cuda')
|
||||||
|
|
||||||
|
# Loop through each port and start the application
|
||||||
|
for i in "${!ports[@]}"; do
|
||||||
|
port=${ports[$i]}
|
||||||
|
device=${devices[$i]}
|
||||||
|
echo "Starting server on port $port with device $device"
|
||||||
|
APP_DEVICE=$device FLASK_APP=main.py FLASK_ENV=development flask run --port $port --host '0.0.0.0' &
|
||||||
|
done
|
||||||
|
|
||||||
|
wait
|
Loading…
x
Reference in New Issue
Block a user