first commit
This commit is contained in:
commit
8bbb96fe44
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
*.wav
|
||||
|
||||
__pycache__
|
||||
|
||||
api/tts/chatgpt_api_config.py
|
||||
|
||||
dependencies/*
|
147
api/tts/main.py
Normal file
147
api/tts/main.py
Normal file
@ -0,0 +1,147 @@
|
||||
from flask import Flask, request, send_file, jsonify
|
||||
import requests
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pydub import AudioSegment
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from queue import Queue
|
||||
import json
|
||||
from io import BytesIO
|
||||
from chatgpt_api_config import chatgpt_apis
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
tts_servers = [
|
||||
'http://127.0.0.1:9995/tts',
|
||||
'http://127.0.0.1:9996/tts'
|
||||
]
|
||||
tts_server_index = 0
|
||||
executor = ThreadPoolExecutor(max_workers=len(tts_servers))
|
||||
|
||||
zh_punc = {'。', '?', '!', '\n'}
|
||||
en_punc = {'.', '?', '!', '\n'}
|
||||
|
||||
def merge_audio_files(base_audio, increment):
|
||||
"""将多段语音拼接"""
|
||||
base_audio += increment
|
||||
return base_audio
|
||||
|
||||
def call_tts_api(server_url, response_text, language, audio):
|
||||
"""调用ChatTTS API,回答转语音"""
|
||||
response = requests.post(
|
||||
server_url,
|
||||
data={
|
||||
"text": response_text,
|
||||
'language': language
|
||||
},
|
||||
files={'audio': open(audio, 'rb')}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
audio_segment = AudioSegment.from_file(file=BytesIO(response.content), format='wav')
|
||||
return audio_segment
|
||||
else:
|
||||
print(f"Error: {response.json()['error']}")
|
||||
return None
|
||||
|
||||
|
||||
def generate_response_stream(transcription):
|
||||
"""调用ChatGPT API,回答问题"""
|
||||
for index, chatgpt_api in enumerate(chatgpt_apis):
|
||||
url = chatgpt_api['url']
|
||||
api_key = chatgpt_api['key']
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": transcription}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"stream": True
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
print(f"ChatGPT API {index} Response Status Code: {response.status_code}")
|
||||
if response.status_code == 200:
|
||||
return response
|
||||
return None
|
||||
|
||||
|
||||
@app.route('/tts', methods=['POST'])
|
||||
def tts():
|
||||
global tts_server_index
|
||||
|
||||
unique_id = str(uuid.uuid4())
|
||||
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
|
||||
os.makedirs('temp', exist_ok=True)
|
||||
input_audio_filename = f"input_{timestamp}_{unique_id}.wav"
|
||||
input_audio_path = os.path.join('temp', input_audio_filename)
|
||||
output_audio_filename = f"output_{timestamp}_{unique_id}.wav"
|
||||
output_audio_path = os.path.join('temp', output_audio_filename)
|
||||
|
||||
base_audio = AudioSegment.silent(duration=0) # 初始化一个空音频段
|
||||
|
||||
collected_chunks = []
|
||||
collected_messages = ['']
|
||||
futures = []
|
||||
audio_queue = Queue()
|
||||
|
||||
language = request.form['language']
|
||||
response_stream = generate_response_stream(request.form['text'])
|
||||
if response_stream == None:
|
||||
return jsonify({"error": "Something wrong with ChatGPT API."}), 502
|
||||
speaker_file = request.files['audio']
|
||||
speaker_file.save(input_audio_path)
|
||||
|
||||
try:
|
||||
for chunk in response_stream.iter_lines():
|
||||
if chunk:
|
||||
decoded_line = chunk.decode('utf-8')
|
||||
if decoded_line.startswith('data: '):
|
||||
content = decoded_line[6:]
|
||||
if content.strip() == '[DONE]':
|
||||
break
|
||||
response_json = json.loads(content)
|
||||
collected_chunks.append(response_json)
|
||||
chunk_message = response_json['choices'][0]['delta']
|
||||
collected_messages[-1] += chunk_message.get('content', '')
|
||||
|
||||
if len(collected_messages[-1]) > 0 and collected_messages[-1][-1] in (zh_punc if language == 'chinese' else en_punc):
|
||||
partial_text = collected_messages[-1]
|
||||
if partial_text:
|
||||
print(f"{partial_text}", end="")
|
||||
server_url = tts_servers[tts_server_index % len(tts_servers)]
|
||||
tts_server_index += 1
|
||||
future = executor.submit(call_tts_api, server_url, partial_text, language, input_audio_path)
|
||||
futures.append((partial_text, future))
|
||||
collected_messages.append("")
|
||||
|
||||
# 处理所有 future 并按顺序添加到队列中
|
||||
for partial_text, future in futures:
|
||||
audio_data = future.result()
|
||||
if audio_data:
|
||||
audio_queue.put((partial_text, audio_data))
|
||||
|
||||
# 拼接音频文件
|
||||
while not audio_queue.empty():
|
||||
_, audio_segment = audio_queue.get()
|
||||
base_audio = merge_audio_files(base_audio, audio_segment)
|
||||
|
||||
# 将最终的音频文件保存到硬盘
|
||||
base_audio.export(output_audio_path, format='wav')
|
||||
print("\n")
|
||||
|
||||
# 返回生成的回答音频
|
||||
return send_file(output_audio_path, as_attachment=True, download_name='response.wav')
|
||||
finally:
|
||||
if os.path.exists(input_audio_path):
|
||||
os.remove(input_audio_path)
|
||||
if os.path.exists(output_audio_path):
|
||||
os.remove(output_audio_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run()
|
4
api/tts/run_tts.sh
Normal file
4
api/tts/run_tts.sh
Normal file
@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
FLASK_APP=main.py FLASK_ENV=development flask run \
|
||||
-h 0.0.0.0 \
|
||||
-p 9992
|
58
api/wenet/main.py
Normal file
58
api/wenet/main.py
Normal file
@ -0,0 +1,58 @@
|
||||
from flask import Flask, request, jsonify
|
||||
import wenet
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# 加载wenet模型
|
||||
wenet_model_cn = wenet.load_model('chinese', device='cuda')
|
||||
wenet_model_en = wenet.load_model('english', device='cuda')
|
||||
|
||||
def transcribe_audio(audio_path, language):
|
||||
"""Transcribe audio file to text using wenet."""
|
||||
if language == 'chinese':
|
||||
result = wenet_model_cn.transcribe(audio_path)['text']
|
||||
else:
|
||||
result = wenet_model_en.transcribe(audio_path)['text']
|
||||
result = result.replace("▁", " ")
|
||||
print(result)
|
||||
return result
|
||||
|
||||
@app.route('/transcribe', methods=['POST'])
|
||||
def transcribe():
|
||||
if 'audio' not in request.files or 'language' not in request.form:
|
||||
return jsonify({"error": "Audio file and language must be provided"}), 400
|
||||
|
||||
audio_file = request.files['audio']
|
||||
language = request.form['language']
|
||||
|
||||
if language not in ['chinese', 'english']:
|
||||
return jsonify({"error": "Unsupported language"}), 400
|
||||
|
||||
# 设置缓存音频文件地址
|
||||
unique_id = str(uuid.uuid4())
|
||||
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
|
||||
os.makedirs('temp', exist_ok=True)
|
||||
input_audio_filename = f"input_{timestamp}_{unique_id}.wav"
|
||||
input_audio_path = os.path.join('temp', input_audio_filename)
|
||||
audio_file.save(input_audio_path)
|
||||
|
||||
try:
|
||||
# 使用wenet,音频转文本
|
||||
response_text = transcribe_audio(input_audio_path, language)
|
||||
if language == "chinese":
|
||||
response_text = response_text.replace(":", ",")
|
||||
response_text = response_text.replace("*", "")
|
||||
else:
|
||||
response_text = response_text.replace(":", ",")
|
||||
response_text = response_text.replace("*", "")
|
||||
return jsonify({"text": response_text})
|
||||
finally:
|
||||
# 清理缓存音频文件
|
||||
if os.path.exists(input_audio_path):
|
||||
os.remove(input_audio_path)
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run()
|
5
api/wenet/run_wenet.sh
Normal file
5
api/wenet/run_wenet.sh
Normal file
@ -0,0 +1,5 @@
|
||||
export FLASK_APP=main.py
|
||||
export FLASK_ENV=development
|
||||
flask run \
|
||||
-h 0.0.0.0 \
|
||||
-p 9991
|
66
api/xtts/main.py
Normal file
66
api/xtts/main.py
Normal file
@ -0,0 +1,66 @@
|
||||
from flask import Flask, request, jsonify, send_file
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from TTS.api import TTS
|
||||
import os
|
||||
|
||||
|
||||
app = Flask(__name__)
|
||||
device = os.getenv('APP_DEVICE', 'cpu') # 使用环境变量获取设备
|
||||
|
||||
lang2short = {'english': 'en', 'chinese': 'zh-cn'}
|
||||
|
||||
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=('cuda' in device)).to(device)
|
||||
|
||||
|
||||
def generate_wav(response_text, speaker_wav, language, output_file_path):
|
||||
tts.tts_to_file(
|
||||
text=response_text,
|
||||
speaker_wav=speaker_wav,
|
||||
language=lang2short[language],
|
||||
file_path=output_file_path
|
||||
)
|
||||
|
||||
|
||||
@app.route('/tts', methods=['POST'])
|
||||
def generate():
|
||||
if 'audio' not in request.files or 'language' not in request.form or 'text' not in request.form:
|
||||
return jsonify({"error": "Speaker audio file, text and language must be provided"}), 400
|
||||
|
||||
speaker_wav = request.files['audio']
|
||||
language = request.form['language']
|
||||
text = request.form['text']
|
||||
|
||||
if language not in ['chinese', 'english']:
|
||||
return jsonify({"error": "Unsupported language"}), 400
|
||||
|
||||
# 设置缓存音频文件地址
|
||||
unique_id = str(uuid.uuid4())
|
||||
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
|
||||
os.makedirs('temp', exist_ok=True)
|
||||
input_audio_filename = f"speaker_{timestamp}_{unique_id}.wav"
|
||||
input_audio_path = os.path.join('temp', input_audio_filename)
|
||||
output_audio_filename = f"output_{timestamp}_{unique_id}.wav"
|
||||
output_audio_path = os.path.join('temp', output_audio_filename)
|
||||
speaker_wav.save(input_audio_path)
|
||||
|
||||
try:
|
||||
# 生成音频数据
|
||||
generate_wav(text, input_audio_path, language, output_audio_path)
|
||||
|
||||
return send_file(
|
||||
output_audio_path,
|
||||
mimetype='audio/wav',
|
||||
as_attachment=True,
|
||||
download_name='generated_audio.wav'
|
||||
)
|
||||
finally:
|
||||
# 清理缓存音频文件
|
||||
if os.path.exists(input_audio_path):
|
||||
os.remove(input_audio_path)
|
||||
if os.path.exists(output_audio_path):
|
||||
os.remove(output_audio_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run()
|
16
api/xtts/run_xtts.sh
Normal file
16
api/xtts/run_xtts.sh
Normal file
@ -0,0 +1,16 @@
|
||||
export FLASK_APP=main.py
|
||||
export FLASK_ENV=development
|
||||
|
||||
# Define the ports to run the application on
|
||||
ports=(9995 9996)
|
||||
devices=('cuda' 'cuda')
|
||||
|
||||
# Loop through each port and start the application
|
||||
for i in "${!ports[@]}"; do
|
||||
port=${ports[$i]}
|
||||
device=${devices[$i]}
|
||||
echo "Starting server on port $port with device $device"
|
||||
APP_DEVICE=$device FLASK_APP=main.py FLASK_ENV=development flask run --port $port --host '0.0.0.0' &
|
||||
done
|
||||
|
||||
wait
|
Loading…
x
Reference in New Issue
Block a user