Spaces: Build error

chenjianfei committed
Commit · 1f32878
1 Parent(s): c0a45d8

init
Browse files

- .gitattributes +1 -0
- .gitignore +37 -0
- Dockerfile +30 -0
- OpenVoice/.gitignore +13 -0
- OpenVoice/LICENSE +7 -0
- OpenVoice/README.md +70 -0
- OpenVoice/docs/QA.md +39 -0
- OpenVoice/docs/USAGE.md +83 -0
- OpenVoice/openvoice/__init__.py +0 -0
- OpenVoice/openvoice/api.py +202 -0
- OpenVoice/openvoice/attentions.py +465 -0
- OpenVoice/openvoice/commons.py +160 -0
- OpenVoice/openvoice/mel_processing.py +183 -0
- OpenVoice/openvoice/models.py +499 -0
- OpenVoice/openvoice/modules.py +598 -0
- OpenVoice/openvoice/openvoice_app.py +275 -0
- OpenVoice/openvoice/se_extractor.py +153 -0
- OpenVoice/openvoice/text/__init__.py +79 -0
- OpenVoice/openvoice/text/cleaners.py +16 -0
- OpenVoice/openvoice/text/english.py +188 -0
- OpenVoice/openvoice/text/mandarin.py +326 -0
- OpenVoice/openvoice/text/symbols.py +88 -0
- OpenVoice/openvoice/transforms.py +209 -0
- OpenVoice/openvoice/utils.py +194 -0
- OpenVoice/resources/framework-ipa.png +3 -0
- OpenVoice/resources/huggingface.png +3 -0
- OpenVoice/resources/lepton-hd.png +3 -0
- OpenVoice/resources/myshell-hd.png +3 -0
- OpenVoice/resources/tts-guide.png +3 -0
- OpenVoice/resources/voice-clone-guide.png +3 -0
- OpenVoice/setup.py +45 -0
- app.py +548 -0
- config.py +114 -0
- knowledge_base.py +245 -0
- rag/kb/周杰伦/周杰伦全部歌曲.md +341 -0
- rag/kb/周杰伦/周杰伦全部歌词.md +0 -0
- rag/kb/周杰伦/周杰伦基本资料.md +153 -0
- tts_api.py +274 -0
- utils.py +197 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,37 @@
# general things to ignore
.DS_Store
build/
build_contrib/
dist/
.cache/
*.egg-info/
*.egg
*.py[cod]
__pycache__/
*.so
*~

# IDE
.vscode/

# misc
checkpoints/
test_waves/
reconstructed/
.python-version
ruff.log
/configs/inuse/
*.wav
*.ogg
*.mp3
demo_dir/*
*.pt
*.json
*.txt
*.ipynb
asset/
*.csv
*.xlsx
*.jpg
*.log
*.pth
Dockerfile
ADDED
@@ -0,0 +1,30 @@
# Stage 1: build the OpenVoice environment (pinned to numpy==1.22.0)
FROM python:3.9-slim AS openvoice-builder

# Install system dependencies
RUN apt-get update && apt-get install -y \
    ffmpeg libsndfile1 git \
    && rm -rf /var/lib/apt/lists/*

# Force numpy==1.22.0 and build OpenVoice
WORKDIR /app
COPY OpenVoice/ ./OpenVoice
RUN pip install numpy==1.22.0 && \
    pip install -e ./OpenVoice


# Stage 2: build the main application environment (allows numpy>=2.0.0)
FROM python:3.9-slim AS main-app

# Install main application dependencies (isolated from the OpenVoice environment)
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the built OpenVoice package and compatible libraries from stage 1
COPY --from=openvoice-builder /usr/local/lib/python3.9/site-packages /usr/local/lib/python3.9/site-packages
COPY --from=openvoice-builder /app/OpenVoice /app/OpenVoice

# Expose the port and start the app
EXPOSE 7860
CMD ["python", "app.py"]
OpenVoice/.gitignore
ADDED
@@ -0,0 +1,13 @@
__pycache__/
.ipynb_checkpoints/
processed
outputs
outputs_v2
checkpoints
checkpoints_v2
trash
examples*
.env
build
*.egg-info/
*.zip
OpenVoice/LICENSE
ADDED
@@ -0,0 +1,7 @@
Copyright 2024 MyShell.ai

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
OpenVoice/README.md
ADDED
@@ -0,0 +1,70 @@
<div align="center">
  <div> </div>
  <img src="resources/openvoicelogo.jpg" width="400"/>

[Paper](https://arxiv.org/abs/2312.01479) |
[Website](https://research.myshell.ai/open-voice) <br> <br>
<a href="https://trendshift.io/repositories/6161" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6161" alt="myshell-ai%2FOpenVoice | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>

## Introduction

### OpenVoice V1

As we detailed in our [paper](https://arxiv.org/abs/2312.01479) and [website](https://research.myshell.ai/open-voice), the advantages of OpenVoice are three-fold:

**1. Accurate Tone Color Cloning.**
OpenVoice can accurately clone the reference tone color and generate speech in multiple languages and accents.

**2. Flexible Voice Style Control.**
OpenVoice enables granular control over voice styles, such as emotion and accent, as well as other style parameters including rhythm, pauses, and intonation.

**3. Zero-shot Cross-lingual Voice Cloning.**
Neither the language of the generated speech nor the language of the reference speech needs to be present in the massive-speaker multi-lingual training dataset.

### OpenVoice V2

In April 2024, we released OpenVoice V2, which includes all features of V1 and adds:

**1. Better Audio Quality.**
OpenVoice V2 adopts a different training strategy that delivers better audio quality.

**2. Native Multi-lingual Support.**
English, Spanish, French, Chinese, Japanese and Korean are natively supported in OpenVoice V2.

**3. Free Commercial Use.**
Starting from April 2024, both V2 and V1 are released under the MIT License. Free for commercial use.

[Video](https://github.com/myshell-ai/OpenVoice/assets/40556743/3cba936f-82bf-476c-9e52-09f0f417bb2f)

OpenVoice has been powering the instant voice cloning capability of [myshell.ai](https://app.myshell.ai/explore) since May 2023. As of Nov 2023, the voice cloning model had been used tens of millions of times by users worldwide and witnessed explosive user growth on the platform.

## Main Contributors

- [Zengyi Qin](https://www.qinzy.tech) at MIT
- [Wenliang Zhao](https://wl-zhao.github.io) at Tsinghua University
- [Xumin Yu](https://yuxumin.github.io) at Tsinghua University
- [Ethan Sun](https://twitter.com/ethan_myshell) at MyShell

## How to Use
Please see [usage](docs/USAGE.md) for detailed instructions.

## Common Issues

Please see [QA](docs/QA.md) for common questions and answers. We will regularly update the question and answer list.

## Citation
```
@article{qin2023openvoice,
  title={OpenVoice: Versatile Instant Voice Cloning},
  author={Qin, Zengyi and Zhao, Wenliang and Yu, Xumin and Sun, Xin},
  journal={arXiv preprint arXiv:2312.01479},
  year={2023}
}
```

## License
OpenVoice V1 and V2 are MIT licensed. Free for both commercial and research use.

## Acknowledgements
This implementation is based on several excellent projects: [TTS](https://github.com/coqui-ai/TTS), [VITS](https://github.com/jaywalnut310/vits), and [VITS2](https://github.com/daniilrobnikov/vits2). Thanks for their awesome work!
OpenVoice/docs/QA.md
ADDED
@@ -0,0 +1,39 @@
# Common Questions and Answers

## General Comments

**OpenVoice is a Technology, not a Product**

Although it works on a majority of voices if used correctly, please do not expect it to work perfectly on every case, as it takes a lot of engineering effort to translate a technology into a stable product. The targeted users of this technology are developers and researchers, not end users. End users expect a perfect product. However, we are confident in saying that OpenVoice is the state-of-the-art among source-available voice cloning technologies.

The contribution of OpenVoice is a versatile instant voice cloning technical approach, not a ready-to-use perfect voice cloning product. However, we firmly believe that by releasing OpenVoice, we can accelerate the open research community's progress on instant voice cloning, and someday in the future free voice cloning methods will be as good as commercial ones.

## Issues with Voice Quality

**Accent and Emotion of the Generated Voice is not Similar to the Reference Voice**

First of all, OpenVoice only clones the tone color of the reference speaker. It does NOT clone the accent or emotion. The accent and emotion are controlled by the base speaker TTS model, not cloned by the tone color converter (please refer to our [paper](https://arxiv.org/pdf/2312.01479.pdf) for technical details). If users want to change the accent or emotion of the output, they need a base speaker model with that accent. OpenVoice provides sufficient flexibility for users to integrate their own base speaker model into the framework by simply replacing the base speaker we provide.

**Bad Audio Quality of the Generated Speech**

Please check the following:
- Is your reference audio clean enough, without any background noise? You can find some high-quality reference speech [here](https://aiartes.com/voiceai).
- Is your audio too short?
- Does your audio contain speech from more than one person?
- Does the reference audio contain long blank sections?
- Did you name the reference audio the same name you used before but forget to delete the `processed` folder?

## Issues with Languages

**Support of Other Languages**

For multi-lingual and cross-lingual usage, please refer to [`demo_part2.ipynb`](https://github.com/myshell-ai/OpenVoice/blob/main/demo_part2.ipynb). OpenVoice supports any language as long as you have a base speaker in that language. The OpenVoice team has already done the most difficult part (tone color converter training) for you. The base speaker TTS model is relatively easy to train, and multiple existing open-source repositories support it. If you don't want to train one yourself, simply use the OpenAI TTS model as the base speaker.

## Issues with Installation
**Error Related to Silero**

When calling `get_vad_segments` from `se_extractor.py`, there should be a message like this:
```
Downloading: "https://github.com/snakers4/silero-vad/zipball/master" to /home/user/.cache/torch/hub/master.zip
```
The download will fail if your machine cannot access GitHub. Please download the zip from "https://github.com/snakers4/silero-vad/zipball/master" manually and unzip it to `/home/user/.cache/torch/hub/snakers4_silero-vad_master`. You can also see [this issue](https://github.com/myshell-ai/OpenVoice/issues/57) for solutions for other versions of Silero.
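To make the base-speaker / tone-color-converter split described in the QA above concrete, here is a minimal, hedged sketch of plugging in your own base speaker: any base TTS output saved as a wav can be re-colored by the `ToneColorConverter` added in this commit (`OpenVoice/openvoice/api.py`). The checkpoint and audio paths below are placeholders, not files shipped in this commit.

```python
# Hedged sketch: re-color the output of ANY base speaker TTS with a reference voice.
# All paths are placeholders; point them at your extracted V1 checkpoints and audio files.
import torch
from openvoice.api import ToneColorConverter

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
converter = ToneColorConverter('checkpoints/converter/config.json', device=device)
converter.load_ckpt('checkpoints/converter/checkpoint.pth')

# Tone color ("speaker embedding") of the base TTS output and of the voice to clone.
src_se = converter.extract_se('base_tts_output.wav')
tgt_se = converter.extract_se('reference_speaker.wav')

# Accent and emotion stay as produced by the base TTS; only the tone color is swapped.
converter.convert(
    audio_src_path='base_tts_output.wav',
    src_se=src_se,
    tgt_se=tgt_se,
    output_path='cloned_output.wav',
)
```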
OpenVoice/docs/USAGE.md
ADDED
@@ -0,0 +1,83 @@
# Usage

## Table of Contents

- [Quick Use](#quick-use): directly use OpenVoice without installation.
- [Linux Install](#linux-install): for researchers and developers only.
  - [V1](#openvoice-v1)
  - [V2](#openvoice-v2)
- [Install on Other Platforms](#install-on-other-platforms): unofficial installation guides contributed by the community.

## Quick Use

The input speech audio of OpenVoice can be in **Any Language**. OpenVoice can clone the voice in that speech audio and use the voice to speak in multiple languages. For quick use, we recommend trying the already deployed services:

- [British English](https://app.myshell.ai/widget/vYjqae)
- [American English](https://app.myshell.ai/widget/nEFFJf)
- [Indian English](https://app.myshell.ai/widget/V3iYze)
- [Australian English](https://app.myshell.ai/widget/fM7JVf)
- [Spanish](https://app.myshell.ai/widget/NNFFVz)
- [French](https://app.myshell.ai/widget/z2uyUz)
- [Chinese](https://app.myshell.ai/widget/fU7nUz)
- [Japanese](https://app.myshell.ai/widget/IfIB3u)
- [Korean](https://app.myshell.ai/widget/q6ZjIn)

## Minimal Demo

For users who want to quickly try OpenVoice and do not require high quality or stability, click any of the following links:

<div align="center">
<a href="https://app.myshell.ai/bot/z6Bvua/1702636181"><img src="../resources/myshell-hd.png" height="28"></a>

<a href="https://huggingface.co/spaces/myshell-ai/OpenVoice"><img src="../resources/huggingface.png" height="32"></a>
</div>

## Linux Install

This section is only for developers and researchers who are familiar with Linux, Python, and PyTorch. Clone this repo and run

```
conda create -n openvoice python=3.9
conda activate openvoice
git clone git@github.com:myshell-ai/OpenVoice.git
cd OpenVoice
pip install -e .
```

The installation above is the same whether you are using V1 or V2.

### OpenVoice V1

Download the checkpoint from [here](https://myshell-public-repo-host.s3.amazonaws.com/openvoice/checkpoints_1226.zip) and extract it to the `checkpoints` folder.

**1. Flexible Voice Style Control.**
Please see [`demo_part1.ipynb`](../demo_part1.ipynb) for an example of how OpenVoice enables flexible style control over the cloned voice.

**2. Cross-Lingual Voice Cloning.**
Please see [`demo_part2.ipynb`](../demo_part2.ipynb) for an example with languages seen or unseen in the MSML training set.

**3. Gradio Demo.** We provide a minimalist local gradio demo here. We strongly recommend that users look into `demo_part1.ipynb`, `demo_part2.ipynb` and the [QnA](QA.md) if they run into issues with the gradio demo. Launch a local gradio demo with `python -m openvoice_app --share`.

### OpenVoice V2

Download the checkpoint from [here](https://myshell-public-repo-host.s3.amazonaws.com/openvoice/checkpoints_v2_0417.zip) and extract it to the `checkpoints_v2` folder.

Install [MeloTTS](https://github.com/myshell-ai/MeloTTS):
```
pip install git+https://github.com/myshell-ai/MeloTTS.git
python -m unidic download
```

**Demo Usage.** Please see [`demo_part3.ipynb`](../demo_part3.ipynb) for example usage of OpenVoice V2. It natively supports English, Spanish, French, Chinese, Japanese, and Korean.


## Install on Other Platforms

This section provides unofficial installation guides by open-source contributors in the community:

- Windows
  - [Guide](https://github.com/Alienpups/OpenVoice/blob/main/docs/USAGE_WINDOWS.md) by [@Alienpups](https://github.com/Alienpups)
  - You are welcome to contribute if you have a better installation guide. We will list you here.
- Docker
  - [Guide](https://github.com/StevenJSCF/OpenVoice/blob/update-docs/docs/DF_USAGE.md) by [@StevenJSCF](https://github.com/StevenJSCF)
  - You are welcome to contribute if you have a better installation guide. We will list you here.
OpenVoice/openvoice/__init__.py
ADDED
File without changes
OpenVoice/openvoice/api.py
ADDED
@@ -0,0 +1,202 @@
import torch
import numpy as np
import re
import soundfile
from openvoice import utils
from openvoice import commons
import os
import librosa
from openvoice.text import text_to_sequence
from openvoice.mel_processing import spectrogram_torch
from openvoice.models import SynthesizerTrn


class OpenVoiceBaseClass(object):
    def __init__(self,
                 config_path,
                 device='cuda:0'):
        if 'cuda' in device:
            assert torch.cuda.is_available()

        hps = utils.get_hparams_from_file(config_path)

        model = SynthesizerTrn(
            len(getattr(hps, 'symbols', [])),
            hps.data.filter_length // 2 + 1,
            n_speakers=hps.data.n_speakers,
            **hps.model,
        ).to(device)

        model.eval()
        self.model = model
        self.hps = hps
        self.device = device

    def load_ckpt(self, ckpt_path):
        checkpoint_dict = torch.load(ckpt_path, map_location=torch.device(self.device))
        a, b = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
        print("Loaded checkpoint '{}'".format(ckpt_path))
        print('missing/unexpected keys:', a, b)


class BaseSpeakerTTS(OpenVoiceBaseClass):
    language_marks = {
        "english": "EN",
        "chinese": "ZH",
    }

    @staticmethod
    def get_text(text, hps, is_symbol):
        text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        text_norm = torch.LongTensor(text_norm)
        return text_norm

    @staticmethod
    def audio_numpy_concat(segment_data_list, sr, speed=1.):
        audio_segments = []
        for segment_data in segment_data_list:
            audio_segments += segment_data.reshape(-1).tolist()
            audio_segments += [0] * int((sr * 0.05)/speed)
        audio_segments = np.array(audio_segments).astype(np.float32)
        return audio_segments

    @staticmethod
    def split_sentences_into_pieces(text, language_str):
        texts = utils.split_sentence(text, language_str=language_str)
        print(" > Text splitted to sentences.")
        print('\n'.join(texts))
        print(" > ===========================")
        return texts

    def tts(self, text, output_path, speaker, language='English', speed=1.0):
        mark = self.language_marks.get(language.lower(), None)
        assert mark is not None, f"language {language} is not supported"

        texts = self.split_sentences_into_pieces(text, mark)

        audio_list = []
        for t in texts:
            t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
            t = f'[{mark}]{t}[{mark}]'
            stn_tst = self.get_text(t, self.hps, False)
            device = self.device
            speaker_id = self.hps.speakers[speaker]
            with torch.no_grad():
                x_tst = stn_tst.unsqueeze(0).to(device)
                x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
                sid = torch.LongTensor([speaker_id]).to(device)
                audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6,
                                         length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
            audio_list.append(audio)
        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)

        if output_path is None:
            return audio
        else:
            soundfile.write(output_path, audio, self.hps.data.sampling_rate)


class ToneColorConverter(OpenVoiceBaseClass):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if kwargs.get('enable_watermark', True):
            import wavmark
            self.watermark_model = wavmark.load_model().to(self.device)
        else:
            self.watermark_model = None
        self.version = getattr(self.hps, '_version_', "v1")

    def extract_se(self, ref_wav_list, se_save_path=None):
        if isinstance(ref_wav_list, str):
            ref_wav_list = [ref_wav_list]

        device = self.device
        hps = self.hps
        gs = []

        for fname in ref_wav_list:
            audio_ref, sr = librosa.load(fname, sr=hps.data.sampling_rate)
            y = torch.FloatTensor(audio_ref)
            y = y.to(device)
            y = y.unsqueeze(0)
            y = spectrogram_torch(y, hps.data.filter_length,
                                  hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
                                  center=False).to(device)
            with torch.no_grad():
                g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
                gs.append(g.detach())
        gs = torch.stack(gs).mean(0)

        if se_save_path is not None:
            os.makedirs(os.path.dirname(se_save_path), exist_ok=True)
            torch.save(gs.cpu(), se_save_path)

        return gs

    def convert(self, audio_src_path, src_se, tgt_se, output_path=None, tau=0.3, message="default"):
        hps = self.hps
        # load audio
        audio, sample_rate = librosa.load(audio_src_path, sr=hps.data.sampling_rate)
        audio = torch.tensor(audio).float()

        with torch.no_grad():
            y = torch.FloatTensor(audio).to(self.device)
            y = y.unsqueeze(0)
            spec = spectrogram_torch(y, hps.data.filter_length,
                                     hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
                                     center=False).to(self.device)
            spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device)
            audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][
                0, 0].data.cpu().float().numpy()
            audio = self.add_watermark(audio, message)
            if output_path is None:
                return audio
            else:
                soundfile.write(output_path, audio, hps.data.sampling_rate)

    def add_watermark(self, audio, message):
        if self.watermark_model is None:
            return audio
        device = self.device
        bits = utils.string_to_bits(message).reshape(-1)
        n_repeat = len(bits) // 32

        K = 16000
        coeff = 2
        for n in range(n_repeat):
            trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
            if len(trunck) != K:
                print('Audio too short, fail to add watermark')
                break
            message_npy = bits[n * 32: (n + 1) * 32]

            with torch.no_grad():
                signal = torch.FloatTensor(trunck).to(device)[None]
                message_tensor = torch.FloatTensor(message_npy).to(device)[None]
                signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor)
                signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze()
            audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy
        return audio

    def detect_watermark(self, audio, n_repeat):
        bits = []
        K = 16000
        coeff = 2
        for n in range(n_repeat):
            trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
            if len(trunck) != K:
                print('Audio too short, fail to detect watermark')
                return 'Fail'
            with torch.no_grad():
                signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0)
                message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze()
            bits.append(message_decoded_npy)
        bits = np.stack(bits).reshape(-1, 8)
        message = utils.bits_to_string(bits)
        return message
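The two classes above are the core of the V1 inference API. A minimal end-to-end sketch follows; the checkpoint layout and the speaker name 'default' are assumptions about an extracted V1 checkpoint folder, not files included in this commit.

```python
# Hedged V1 pipeline sketch using the classes defined in api.py above.
import torch
from openvoice.api import BaseSpeakerTTS, ToneColorConverter

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# 1) Base speaker TTS: text -> speech in the base speaker's voice.
tts = BaseSpeakerTTS('checkpoints/base_speakers/EN/config.json', device=device)
tts.load_ckpt('checkpoints/base_speakers/EN/checkpoint.pth')
tts.tts("This is a short test sentence.", 'tmp.wav', speaker='default', language='English')

# 2) Tone color converter: swap the base voice's tone color for the reference speaker's.
converter = ToneColorConverter('checkpoints/converter/config.json', device=device)
converter.load_ckpt('checkpoints/converter/checkpoint.pth')
src_se = converter.extract_se('tmp.wav')                 # tone color of the base output
tgt_se = converter.extract_se('reference_speaker.wav')   # tone color to clone
converter.convert('tmp.wav', src_se=src_se, tgt_se=tgt_se, output_path='output.wav')
```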
OpenVoice/openvoice/attentions.py
ADDED
@@ -0,0 +1,465 @@
import math
import torch
from torch import nn
from torch.nn import functional as F

from openvoice import commons
import logging

logger = logging.getLogger(__name__)


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        # if isflow:
        #     cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
        #     self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
        #     self.cond_layer = weight_norm(cond_layer, name='weight')
        #     self.gin_channels = 256
        self.cond_layer_idx = self.n_layers
        if "gin_channels" in kwargs:
            self.gin_channels = kwargs["gin_channels"]
            if self.gin_channels != 0:
                self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
                # vits2 says 3rd block, so idx is 2 by default
                self.cond_layer_idx = (
                    kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
                )
                # logging.debug(self.gin_channels, self.cond_layer_idx)
                assert (
                    self.cond_layer_idx < self.n_layers
                ), "cond_layer_idx should be less than n_layers"
        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()

        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, g=None):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            if i == self.cond_layer_idx and g is not None:
                g = self.spk_emb_linear(g.transpose(1, 2))
                g = g.transpose(1, 2)
                x = x + g
                x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    proximal_bias=proximal_bias,
                    proximal_init=proximal_init,
                )
            )
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                    causal=True,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
            device=x.device, dtype=x.dtype
        )
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
            length: an integer scalar.
        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        if causal:
            self.padding = self._causal_padding
        else:
            self.padding = self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(self.padding(x * x_mask))
        return x * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x
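As a quick sanity check on the relative-position transformer `Encoder` above, the following hedged snippet runs a random batch through it and confirms the output keeps the input shape. The hyperparameters (192 hidden channels, 2 heads, 6 layers) are typical VITS-style values and are not taken from any config in this commit.

```python
import torch
from openvoice.attentions import Encoder

# Illustrative hyperparameters only; real values come from the model config.
enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2,
              n_layers=6, kernel_size=3, p_dropout=0.1)
x = torch.randn(2, 192, 50)    # [batch, channels, frames]
x_mask = torch.ones(2, 1, 50)  # all frames valid
out = enc(x, x_mask)
print(out.shape)               # torch.Size([2, 192, 50])
```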
OpenVoice/openvoice/commons.py
ADDED
@@ -0,0 +1,160 @@
import math
import torch
from torch.nn import functional as F


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def convert_pad_shape(pad_shape):
    layer = pad_shape[::-1]
    pad_shape = [item for sublist in layer for item in sublist]
    return pad_shape


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    layer = pad_shape[::-1]
    pad_shape = [item for sublist in layer for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm
OpenVoice/openvoice/mel_processing.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.utils.data
|
3 |
+
from librosa.filters import mel as librosa_mel_fn
|
4 |
+
|
5 |
+
MAX_WAV_VALUE = 32768.0
|
6 |
+
|
7 |
+
|
8 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
9 |
+
"""
|
10 |
+
PARAMS
|
11 |
+
------
|
12 |
+
C: compression factor
|
13 |
+
"""
|
14 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
15 |
+
|
16 |
+
|
17 |
+
def dynamic_range_decompression_torch(x, C=1):
|
18 |
+
"""
|
19 |
+
PARAMS
|
20 |
+
------
|
21 |
+
C: compression factor used to compress
|
22 |
+
"""
|
23 |
+
return torch.exp(x) / C
|
24 |
+
|
25 |
+
|
26 |
+
def spectral_normalize_torch(magnitudes):
|
27 |
+
output = dynamic_range_compression_torch(magnitudes)
|
28 |
+
return output
|
29 |
+
|
30 |
+
|
31 |
+
def spectral_de_normalize_torch(magnitudes):
|
32 |
+
output = dynamic_range_decompression_torch(magnitudes)
|
33 |
+
return output
|
34 |
+
|
35 |
+
|
36 |
+
mel_basis = {}
|
37 |
+
hann_window = {}
|
38 |
+
|
39 |
+
|
40 |
+
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
41 |
+
if torch.min(y) < -1.1:
|
42 |
+
print("min value is ", torch.min(y))
|
43 |
+
if torch.max(y) > 1.1:
|
44 |
+
print("max value is ", torch.max(y))
|
45 |
+
|
46 |
+
global hann_window
|
47 |
+
dtype_device = str(y.dtype) + "_" + str(y.device)
|
48 |
+
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
49 |
+
if wnsize_dtype_device not in hann_window:
|
50 |
+
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
|
51 |
+
dtype=y.dtype, device=y.device
|
52 |
+
)
|
53 |
+
|
54 |
+
y = torch.nn.functional.pad(
|
55 |
+
y.unsqueeze(1),
|
56 |
+
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
57 |
+
mode="reflect",
|
58 |
+
)
|
59 |
+
y = y.squeeze(1)
|
60 |
+
|
61 |
+
spec = torch.stft(
|
62 |
+
y,
|
63 |
+
n_fft,
|
64 |
+
hop_length=hop_size,
|
65 |
+
win_length=win_size,
|
66 |
+
window=hann_window[wnsize_dtype_device],
|
67 |
+
center=center,
|
68 |
+
pad_mode="reflect",
|
69 |
+
normalized=False,
|
70 |
+
onesided=True,
|
71 |
+
return_complex=False,
|
72 |
+
)
|
73 |
+
|
74 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
75 |
+
return spec
|
76 |
+
|
77 |
+
|
78 |
+
def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
79 |
+
# if torch.min(y) < -1.:
|
80 |
+
# print('min value is ', torch.min(y))
|
81 |
+
# if torch.max(y) > 1.:
|
82 |
+
# print('max value is ', torch.max(y))
|
83 |
+
|
84 |
+
global hann_window
|
85 |
+
dtype_device = str(y.dtype) + '_' + str(y.device)
|
86 |
+
wnsize_dtype_device = str(win_size) + '_' + dtype_device
|
87 |
+
if wnsize_dtype_device not in hann_window:
|
88 |
+
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
89 |
+
|
90 |
+
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
|
91 |
+
|
92 |
+
# ******************** original ************************#
|
93 |
+
# y = y.squeeze(1)
|
94 |
+
# spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
|
95 |
+
# center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
|
96 |
+
|
97 |
+
# ******************** ConvSTFT ************************#
|
98 |
+
freq_cutoff = n_fft // 2 + 1
|
99 |
+
fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft)))
|
100 |
+
forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1])
|
101 |
+
forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float()
|
102 |
+
|
103 |
+
import torch.nn.functional as F
|
104 |
+
|
105 |
+
# if center:
|
106 |
+
# signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1)
|
107 |
+
assert center is False
|
108 |
+
|
109 |
+
forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride = hop_size)
|
110 |
+
spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim = -1)
|
111 |
+
|
112 |
+
|
113 |
+
# ******************** Verification ************************#
|
114 |
+
spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
|
115 |
+
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
|
116 |
+
assert torch.allclose(spec1, spec2, atol=1e-4)
|
117 |
+
|
118 |
+
spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6)
|
119 |
+
return spec
|
120 |
+
|
121 |
+
|
122 |
+
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
|
123 |
+
global mel_basis
|
124 |
+
dtype_device = str(spec.dtype) + "_" + str(spec.device)
|
125 |
+
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
126 |
+
if fmax_dtype_device not in mel_basis:
|
127 |
+
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
128 |
+
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
|
129 |
+
dtype=spec.dtype, device=spec.device
|
130 |
+
)
|
131 |
+
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
132 |
+
spec = spectral_normalize_torch(spec)
|
133 |
+
return spec
|
134 |
+
|
135 |
+
|
136 |
+
def mel_spectrogram_torch(
|
137 |
+
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
|
138 |
+
):
|
139 |
+
if torch.min(y) < -1.0:
|
140 |
+
print("min value is ", torch.min(y))
|
141 |
+
if torch.max(y) > 1.0:
|
142 |
+
print("max value is ", torch.max(y))
|
143 |
+
|
144 |
+
global mel_basis, hann_window
|
145 |
+
dtype_device = str(y.dtype) + "_" + str(y.device)
|
146 |
+
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
147 |
+
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
148 |
+
if fmax_dtype_device not in mel_basis:
|
149 |
+
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
150 |
+
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
|
151 |
+
dtype=y.dtype, device=y.device
|
152 |
+
)
|
153 |
+
if wnsize_dtype_device not in hann_window:
|
154 |
+
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
|
155 |
+
dtype=y.dtype, device=y.device
|
156 |
+
)
|
157 |
+
|
158 |
+
y = torch.nn.functional.pad(
|
159 |
+
y.unsqueeze(1),
|
160 |
+
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
161 |
+
mode="reflect",
|
162 |
+
)
|
163 |
+
y = y.squeeze(1)
|
164 |
+
|
165 |
+
spec = torch.stft(
|
166 |
+
y,
|
167 |
+
n_fft,
|
168 |
+
hop_length=hop_size,
|
169 |
+
win_length=win_size,
|
170 |
+
window=hann_window[wnsize_dtype_device],
|
171 |
+
center=center,
|
172 |
+
pad_mode="reflect",
|
173 |
+
normalized=False,
|
174 |
+
onesided=True,
|
175 |
+
return_complex=False,
|
176 |
+
)
|
177 |
+
|
178 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
179 |
+
|
180 |
+
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
181 |
+
spec = spectral_normalize_torch(spec)
|
182 |
+
|
183 |
+
return spec
|
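A minimal usage sketch for the helpers above; the FFT, hop, and mel settings are illustrative assumptions, not values read from this commit's config files.

```python
# Hypothetical front-end usage of mel_processing.py (parameter values are assumptions).
import torch
from openvoice.mel_processing import spectrogram_torch, spec_to_mel_torch

y = torch.randn(1, 22050).clamp(-1.0, 1.0)   # one second of fake audio, kept in [-1, 1]

# Linear magnitude spectrogram (assumed VITS-style settings).
spec = spectrogram_torch(y, n_fft=1024, sampling_rate=22050, hop_size=256, win_size=1024, center=False)

# Project onto a mel basis and apply the log compression defined above.
mel = spec_to_mel_torch(spec, n_fft=1024, num_mels=80, sampling_rate=22050, fmin=0, fmax=None)
print(spec.shape, mel.shape)                  # e.g. torch.Size([1, 513, T]) and torch.Size([1, 80, T])
```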
OpenVoice/openvoice/models.py
ADDED
@@ -0,0 +1,499 @@
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
from openvoice import commons
|
7 |
+
from openvoice import modules
|
8 |
+
from openvoice import attentions
|
9 |
+
|
10 |
+
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
11 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
12 |
+
|
13 |
+
from openvoice.commons import init_weights, get_padding
|
14 |
+
|
15 |
+
|
16 |
+
class TextEncoder(nn.Module):
|
17 |
+
def __init__(self,
|
18 |
+
n_vocab,
|
19 |
+
out_channels,
|
20 |
+
hidden_channels,
|
21 |
+
filter_channels,
|
22 |
+
n_heads,
|
23 |
+
n_layers,
|
24 |
+
kernel_size,
|
25 |
+
p_dropout):
|
26 |
+
super().__init__()
|
27 |
+
self.n_vocab = n_vocab
|
28 |
+
self.out_channels = out_channels
|
29 |
+
self.hidden_channels = hidden_channels
|
30 |
+
self.filter_channels = filter_channels
|
31 |
+
self.n_heads = n_heads
|
32 |
+
self.n_layers = n_layers
|
33 |
+
self.kernel_size = kernel_size
|
34 |
+
self.p_dropout = p_dropout
|
35 |
+
|
36 |
+
self.emb = nn.Embedding(n_vocab, hidden_channels)
|
37 |
+
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
38 |
+
|
39 |
+
self.encoder = attentions.Encoder(
|
40 |
+
hidden_channels,
|
41 |
+
filter_channels,
|
42 |
+
n_heads,
|
43 |
+
n_layers,
|
44 |
+
kernel_size,
|
45 |
+
p_dropout)
|
46 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
47 |
+
|
48 |
+
def forward(self, x, x_lengths):
|
49 |
+
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
|
50 |
+
x = torch.transpose(x, 1, -1) # [b, h, t]
|
51 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
52 |
+
|
53 |
+
x = self.encoder(x * x_mask, x_mask)
|
54 |
+
stats = self.proj(x) * x_mask
|
55 |
+
|
56 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
57 |
+
return x, m, logs, x_mask
|
58 |
+
|
59 |
+
|
60 |
+
class DurationPredictor(nn.Module):
|
61 |
+
def __init__(
|
62 |
+
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
|
63 |
+
):
|
64 |
+
super().__init__()
|
65 |
+
|
66 |
+
self.in_channels = in_channels
|
67 |
+
self.filter_channels = filter_channels
|
68 |
+
self.kernel_size = kernel_size
|
69 |
+
self.p_dropout = p_dropout
|
70 |
+
self.gin_channels = gin_channels
|
71 |
+
|
72 |
+
self.drop = nn.Dropout(p_dropout)
|
73 |
+
self.conv_1 = nn.Conv1d(
|
74 |
+
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
75 |
+
)
|
76 |
+
self.norm_1 = modules.LayerNorm(filter_channels)
|
77 |
+
self.conv_2 = nn.Conv1d(
|
78 |
+
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
79 |
+
)
|
80 |
+
self.norm_2 = modules.LayerNorm(filter_channels)
|
81 |
+
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
82 |
+
|
83 |
+
if gin_channels != 0:
|
84 |
+
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
85 |
+
|
86 |
+
def forward(self, x, x_mask, g=None):
|
87 |
+
x = torch.detach(x)
|
88 |
+
if g is not None:
|
89 |
+
g = torch.detach(g)
|
90 |
+
x = x + self.cond(g)
|
91 |
+
x = self.conv_1(x * x_mask)
|
92 |
+
x = torch.relu(x)
|
93 |
+
x = self.norm_1(x)
|
94 |
+
x = self.drop(x)
|
95 |
+
x = self.conv_2(x * x_mask)
|
96 |
+
x = torch.relu(x)
|
97 |
+
x = self.norm_2(x)
|
98 |
+
x = self.drop(x)
|
99 |
+
x = self.proj(x * x_mask)
|
100 |
+
return x * x_mask
|
101 |
+
|
102 |
+
class StochasticDurationPredictor(nn.Module):
|
103 |
+
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
|
104 |
+
super().__init__()
|
105 |
+
filter_channels = in_channels  # NOTE: this override should be removed in a future version.
|
106 |
+
self.in_channels = in_channels
|
107 |
+
self.filter_channels = filter_channels
|
108 |
+
self.kernel_size = kernel_size
|
109 |
+
self.p_dropout = p_dropout
|
110 |
+
self.n_flows = n_flows
|
111 |
+
self.gin_channels = gin_channels
|
112 |
+
|
113 |
+
self.log_flow = modules.Log()
|
114 |
+
self.flows = nn.ModuleList()
|
115 |
+
self.flows.append(modules.ElementwiseAffine(2))
|
116 |
+
for i in range(n_flows):
|
117 |
+
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
118 |
+
self.flows.append(modules.Flip())
|
119 |
+
|
120 |
+
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
121 |
+
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
122 |
+
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
123 |
+
self.post_flows = nn.ModuleList()
|
124 |
+
self.post_flows.append(modules.ElementwiseAffine(2))
|
125 |
+
for i in range(4):
|
126 |
+
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
127 |
+
self.post_flows.append(modules.Flip())
|
128 |
+
|
129 |
+
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
130 |
+
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
131 |
+
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
132 |
+
if gin_channels != 0:
|
133 |
+
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
134 |
+
|
135 |
+
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
|
136 |
+
x = torch.detach(x)
|
137 |
+
x = self.pre(x)
|
138 |
+
if g is not None:
|
139 |
+
g = torch.detach(g)
|
140 |
+
x = x + self.cond(g)
|
141 |
+
x = self.convs(x, x_mask)
|
142 |
+
x = self.proj(x) * x_mask
|
143 |
+
|
144 |
+
if not reverse:
|
145 |
+
flows = self.flows
|
146 |
+
assert w is not None
|
147 |
+
|
148 |
+
logdet_tot_q = 0
|
149 |
+
h_w = self.post_pre(w)
|
150 |
+
h_w = self.post_convs(h_w, x_mask)
|
151 |
+
h_w = self.post_proj(h_w) * x_mask
|
152 |
+
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
153 |
+
z_q = e_q
|
154 |
+
for flow in self.post_flows:
|
155 |
+
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
156 |
+
logdet_tot_q += logdet_q
|
157 |
+
z_u, z1 = torch.split(z_q, [1, 1], 1)
|
158 |
+
u = torch.sigmoid(z_u) * x_mask
|
159 |
+
z0 = (w - u) * x_mask
|
160 |
+
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
|
161 |
+
logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
|
162 |
+
|
163 |
+
logdet_tot = 0
|
164 |
+
z0, logdet = self.log_flow(z0, x_mask)
|
165 |
+
logdet_tot += logdet
|
166 |
+
z = torch.cat([z0, z1], 1)
|
167 |
+
for flow in flows:
|
168 |
+
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
169 |
+
logdet_tot = logdet_tot + logdet
|
170 |
+
nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
|
171 |
+
return nll + logq # [b]
|
172 |
+
else:
|
173 |
+
flows = list(reversed(self.flows))
|
174 |
+
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
175 |
+
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
|
176 |
+
for flow in flows:
|
177 |
+
z = flow(z, x_mask, g=x, reverse=reverse)
|
178 |
+
z0, z1 = torch.split(z, [1, 1], 1)
|
179 |
+
logw = z0
|
180 |
+
return logw
|
181 |
+
|
182 |
+
class PosteriorEncoder(nn.Module):
|
183 |
+
def __init__(
|
184 |
+
self,
|
185 |
+
in_channels,
|
186 |
+
out_channels,
|
187 |
+
hidden_channels,
|
188 |
+
kernel_size,
|
189 |
+
dilation_rate,
|
190 |
+
n_layers,
|
191 |
+
gin_channels=0,
|
192 |
+
):
|
193 |
+
super().__init__()
|
194 |
+
self.in_channels = in_channels
|
195 |
+
self.out_channels = out_channels
|
196 |
+
self.hidden_channels = hidden_channels
|
197 |
+
self.kernel_size = kernel_size
|
198 |
+
self.dilation_rate = dilation_rate
|
199 |
+
self.n_layers = n_layers
|
200 |
+
self.gin_channels = gin_channels
|
201 |
+
|
202 |
+
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
203 |
+
self.enc = modules.WN(
|
204 |
+
hidden_channels,
|
205 |
+
kernel_size,
|
206 |
+
dilation_rate,
|
207 |
+
n_layers,
|
208 |
+
gin_channels=gin_channels,
|
209 |
+
)
|
210 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
211 |
+
|
212 |
+
def forward(self, x, x_lengths, g=None, tau=1.0):
|
213 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
214 |
+
x.dtype
|
215 |
+
)
|
216 |
+
x = self.pre(x) * x_mask
|
217 |
+
x = self.enc(x, x_mask, g=g)
|
218 |
+
stats = self.proj(x) * x_mask
|
219 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
220 |
+
z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask
|
221 |
+
return z, m, logs, x_mask
|
222 |
+
|
223 |
+
|
224 |
+
class Generator(torch.nn.Module):
|
225 |
+
def __init__(
|
226 |
+
self,
|
227 |
+
initial_channel,
|
228 |
+
resblock,
|
229 |
+
resblock_kernel_sizes,
|
230 |
+
resblock_dilation_sizes,
|
231 |
+
upsample_rates,
|
232 |
+
upsample_initial_channel,
|
233 |
+
upsample_kernel_sizes,
|
234 |
+
gin_channels=0,
|
235 |
+
):
|
236 |
+
super(Generator, self).__init__()
|
237 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
238 |
+
self.num_upsamples = len(upsample_rates)
|
239 |
+
self.conv_pre = Conv1d(
|
240 |
+
initial_channel, upsample_initial_channel, 7, 1, padding=3
|
241 |
+
)
|
242 |
+
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
243 |
+
|
244 |
+
self.ups = nn.ModuleList()
|
245 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
246 |
+
self.ups.append(
|
247 |
+
weight_norm(
|
248 |
+
ConvTranspose1d(
|
249 |
+
upsample_initial_channel // (2**i),
|
250 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
251 |
+
k,
|
252 |
+
u,
|
253 |
+
padding=(k - u) // 2,
|
254 |
+
)
|
255 |
+
)
|
256 |
+
)
|
257 |
+
|
258 |
+
self.resblocks = nn.ModuleList()
|
259 |
+
for i in range(len(self.ups)):
|
260 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
261 |
+
for j, (k, d) in enumerate(
|
262 |
+
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
263 |
+
):
|
264 |
+
self.resblocks.append(resblock(ch, k, d))
|
265 |
+
|
266 |
+
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
267 |
+
self.ups.apply(init_weights)
|
268 |
+
|
269 |
+
if gin_channels != 0:
|
270 |
+
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
271 |
+
|
272 |
+
def forward(self, x, g=None):
|
273 |
+
x = self.conv_pre(x)
|
274 |
+
if g is not None:
|
275 |
+
x = x + self.cond(g)
|
276 |
+
|
277 |
+
for i in range(self.num_upsamples):
|
278 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
279 |
+
x = self.ups[i](x)
|
280 |
+
xs = None
|
281 |
+
for j in range(self.num_kernels):
|
282 |
+
if xs is None:
|
283 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
284 |
+
else:
|
285 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
286 |
+
x = xs / self.num_kernels
|
287 |
+
x = F.leaky_relu(x)
|
288 |
+
x = self.conv_post(x)
|
289 |
+
x = torch.tanh(x)
|
290 |
+
|
291 |
+
return x
|
292 |
+
|
293 |
+
def remove_weight_norm(self):
|
294 |
+
print("Removing weight norm...")
|
295 |
+
for layer in self.ups:
|
296 |
+
remove_weight_norm(layer)
|
297 |
+
for layer in self.resblocks:
|
298 |
+
layer.remove_weight_norm()
|
299 |
+
|
300 |
+
|
301 |
+
class ReferenceEncoder(nn.Module):
|
302 |
+
"""
|
303 |
+
inputs --- [N, Ty/r, n_mels*r] mels
|
304 |
+
outputs --- [N, ref_enc_gru_size]
|
305 |
+
"""
|
306 |
+
|
307 |
+
def __init__(self, spec_channels, gin_channels=0, layernorm=True):
|
308 |
+
super().__init__()
|
309 |
+
self.spec_channels = spec_channels
|
310 |
+
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
311 |
+
K = len(ref_enc_filters)
|
312 |
+
filters = [1] + ref_enc_filters
|
313 |
+
convs = [
|
314 |
+
weight_norm(
|
315 |
+
nn.Conv2d(
|
316 |
+
in_channels=filters[i],
|
317 |
+
out_channels=filters[i + 1],
|
318 |
+
kernel_size=(3, 3),
|
319 |
+
stride=(2, 2),
|
320 |
+
padding=(1, 1),
|
321 |
+
)
|
322 |
+
)
|
323 |
+
for i in range(K)
|
324 |
+
]
|
325 |
+
self.convs = nn.ModuleList(convs)
|
326 |
+
|
327 |
+
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
|
328 |
+
self.gru = nn.GRU(
|
329 |
+
input_size=ref_enc_filters[-1] * out_channels,
|
330 |
+
hidden_size=256 // 2,
|
331 |
+
batch_first=True,
|
332 |
+
)
|
333 |
+
self.proj = nn.Linear(128, gin_channels)
|
334 |
+
if layernorm:
|
335 |
+
self.layernorm = nn.LayerNorm(self.spec_channels)
|
336 |
+
else:
|
337 |
+
self.layernorm = None
|
338 |
+
|
339 |
+
def forward(self, inputs, mask=None):
|
340 |
+
N = inputs.size(0)
|
341 |
+
|
342 |
+
out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
|
343 |
+
if self.layernorm is not None:
|
344 |
+
out = self.layernorm(out)
|
345 |
+
|
346 |
+
for conv in self.convs:
|
347 |
+
out = conv(out)
|
348 |
+
# out = wn(out)
|
349 |
+
out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
|
350 |
+
|
351 |
+
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
|
352 |
+
T = out.size(1)
|
353 |
+
N = out.size(0)
|
354 |
+
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
|
355 |
+
|
356 |
+
self.gru.flatten_parameters()
|
357 |
+
memory, out = self.gru(out) # out --- [1, N, 128]
|
358 |
+
|
359 |
+
return self.proj(out.squeeze(0))
|
360 |
+
|
361 |
+
def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
|
362 |
+
for i in range(n_convs):
|
363 |
+
L = (L - kernel_size + 2 * pad) // stride + 1
|
364 |
+
return L
|
365 |
+
|
366 |
+
|
367 |
+
class ResidualCouplingBlock(nn.Module):
|
368 |
+
def __init__(self,
|
369 |
+
channels,
|
370 |
+
hidden_channels,
|
371 |
+
kernel_size,
|
372 |
+
dilation_rate,
|
373 |
+
n_layers,
|
374 |
+
n_flows=4,
|
375 |
+
gin_channels=0):
|
376 |
+
super().__init__()
|
377 |
+
self.channels = channels
|
378 |
+
self.hidden_channels = hidden_channels
|
379 |
+
self.kernel_size = kernel_size
|
380 |
+
self.dilation_rate = dilation_rate
|
381 |
+
self.n_layers = n_layers
|
382 |
+
self.n_flows = n_flows
|
383 |
+
self.gin_channels = gin_channels
|
384 |
+
|
385 |
+
self.flows = nn.ModuleList()
|
386 |
+
for i in range(n_flows):
|
387 |
+
self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
|
388 |
+
self.flows.append(modules.Flip())
|
389 |
+
|
390 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
391 |
+
if not reverse:
|
392 |
+
for flow in self.flows:
|
393 |
+
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
394 |
+
else:
|
395 |
+
for flow in reversed(self.flows):
|
396 |
+
x = flow(x, x_mask, g=g, reverse=reverse)
|
397 |
+
return x
|
398 |
+
|
399 |
+
class SynthesizerTrn(nn.Module):
|
400 |
+
"""
|
401 |
+
Synthesizer for Training
|
402 |
+
"""
|
403 |
+
|
404 |
+
def __init__(
|
405 |
+
self,
|
406 |
+
n_vocab,
|
407 |
+
spec_channels,
|
408 |
+
inter_channels,
|
409 |
+
hidden_channels,
|
410 |
+
filter_channels,
|
411 |
+
n_heads,
|
412 |
+
n_layers,
|
413 |
+
kernel_size,
|
414 |
+
p_dropout,
|
415 |
+
resblock,
|
416 |
+
resblock_kernel_sizes,
|
417 |
+
resblock_dilation_sizes,
|
418 |
+
upsample_rates,
|
419 |
+
upsample_initial_channel,
|
420 |
+
upsample_kernel_sizes,
|
421 |
+
n_speakers=256,
|
422 |
+
gin_channels=256,
|
423 |
+
zero_g=False,
|
424 |
+
**kwargs
|
425 |
+
):
|
426 |
+
super().__init__()
|
427 |
+
|
428 |
+
self.dec = Generator(
|
429 |
+
inter_channels,
|
430 |
+
resblock,
|
431 |
+
resblock_kernel_sizes,
|
432 |
+
resblock_dilation_sizes,
|
433 |
+
upsample_rates,
|
434 |
+
upsample_initial_channel,
|
435 |
+
upsample_kernel_sizes,
|
436 |
+
gin_channels=gin_channels,
|
437 |
+
)
|
438 |
+
self.enc_q = PosteriorEncoder(
|
439 |
+
spec_channels,
|
440 |
+
inter_channels,
|
441 |
+
hidden_channels,
|
442 |
+
5,
|
443 |
+
1,
|
444 |
+
16,
|
445 |
+
gin_channels=gin_channels,
|
446 |
+
)
|
447 |
+
|
448 |
+
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
449 |
+
|
450 |
+
self.n_speakers = n_speakers
|
451 |
+
if n_speakers == 0:
|
452 |
+
self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
|
453 |
+
else:
|
454 |
+
self.enc_p = TextEncoder(n_vocab,
|
455 |
+
inter_channels,
|
456 |
+
hidden_channels,
|
457 |
+
filter_channels,
|
458 |
+
n_heads,
|
459 |
+
n_layers,
|
460 |
+
kernel_size,
|
461 |
+
p_dropout)
|
462 |
+
self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
|
463 |
+
self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
|
464 |
+
self.emb_g = nn.Embedding(n_speakers, gin_channels)
|
465 |
+
self.zero_g = zero_g
|
466 |
+
|
467 |
+
def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., sdp_ratio=0.2, max_len=None):
|
468 |
+
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
|
469 |
+
if self.n_speakers > 0:
|
470 |
+
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
471 |
+
else:
|
472 |
+
g = None
|
473 |
+
|
474 |
+
logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * sdp_ratio \
|
475 |
+
+ self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
|
476 |
+
|
477 |
+
w = torch.exp(logw) * x_mask * length_scale
|
478 |
+
w_ceil = torch.ceil(w)
|
479 |
+
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
480 |
+
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
|
481 |
+
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
482 |
+
attn = commons.generate_path(w_ceil, attn_mask)
|
483 |
+
|
484 |
+
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
485 |
+
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
486 |
+
|
487 |
+
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
488 |
+
z = self.flow(z_p, y_mask, g=g, reverse=True)
|
489 |
+
o = self.dec((z * y_mask)[:,:,:max_len], g=g)
|
490 |
+
return o, attn, y_mask, (z, z_p, m_p, logs_p)
|
491 |
+
|
492 |
+
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
|
493 |
+
g_src = sid_src
|
494 |
+
g_tgt = sid_tgt
|
495 |
+
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=tau)
|
496 |
+
z_p = self.flow(z, y_mask, g=g_src)
|
497 |
+
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
|
498 |
+
o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt))
|
499 |
+
return o_hat, y_mask, (z, z_p, z_hat)
|
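For orientation, the sketch below drives `SynthesizerTrn.voice_conversion` directly. Every value in it is a placeholder assumption; in the actual pipeline the hyperparameters come from the converter's `config.json` and the speaker embeddings from `se_extractor.get_se` or the bundled `.pth` files.

```python
# Hypothetical direct use of SynthesizerTrn.voice_conversion (all values below are assumptions).
import torch
from openvoice.models import SynthesizerTrn
from openvoice.mel_processing import spectrogram_torch

model = SynthesizerTrn(
    n_vocab=0,                      # unused on the converter path (n_speakers == 0 skips the text encoder)
    spec_channels=513,              # n_fft // 2 + 1 for an assumed n_fft of 1024
    inter_channels=192,
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
    resblock="1",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[8, 8, 2, 2],    # product matches the assumed hop size of 256
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4],
    n_speakers=0,                   # selects the ReferenceEncoder branch in __init__
    gin_channels=256,
).eval()

wav = torch.randn(1, 22050).clamp(-1.0, 1.0)
spec = spectrogram_torch(wav, 1024, 22050, 256, 1024, center=False)
spec_lengths = torch.LongTensor([spec.size(-1)])

# Source/target tone-color embeddings normally come from se_extractor.get_se(); random here.
g_src = torch.randn(1, 256, 1)
g_tgt = torch.randn(1, 256, 1)
with torch.no_grad():
    audio, y_mask, _ = model.voice_conversion(spec, spec_lengths, sid_src=g_src, sid_tgt=g_tgt, tau=0.3)
print(audio.shape)                  # [1, 1, frames * 256]; noise here, since the weights are untrained
```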
OpenVoice/openvoice/modules.py
ADDED
@@ -0,0 +1,598 @@
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
from torch.nn import Conv1d
|
7 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
8 |
+
|
9 |
+
from openvoice import commons
|
10 |
+
from openvoice.commons import init_weights, get_padding
|
11 |
+
from openvoice.transforms import piecewise_rational_quadratic_transform
|
12 |
+
from openvoice.attentions import Encoder
|
13 |
+
|
14 |
+
LRELU_SLOPE = 0.1
|
15 |
+
|
16 |
+
|
17 |
+
class LayerNorm(nn.Module):
|
18 |
+
def __init__(self, channels, eps=1e-5):
|
19 |
+
super().__init__()
|
20 |
+
self.channels = channels
|
21 |
+
self.eps = eps
|
22 |
+
|
23 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
24 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = x.transpose(1, -1)
|
28 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
29 |
+
return x.transpose(1, -1)
|
30 |
+
|
31 |
+
|
32 |
+
class ConvReluNorm(nn.Module):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
in_channels,
|
36 |
+
hidden_channels,
|
37 |
+
out_channels,
|
38 |
+
kernel_size,
|
39 |
+
n_layers,
|
40 |
+
p_dropout,
|
41 |
+
):
|
42 |
+
super().__init__()
|
43 |
+
self.in_channels = in_channels
|
44 |
+
self.hidden_channels = hidden_channels
|
45 |
+
self.out_channels = out_channels
|
46 |
+
self.kernel_size = kernel_size
|
47 |
+
self.n_layers = n_layers
|
48 |
+
self.p_dropout = p_dropout
|
49 |
+
assert n_layers > 1, "Number of layers should be larger than 1."
|
50 |
+
|
51 |
+
self.conv_layers = nn.ModuleList()
|
52 |
+
self.norm_layers = nn.ModuleList()
|
53 |
+
self.conv_layers.append(
|
54 |
+
nn.Conv1d(
|
55 |
+
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
|
56 |
+
)
|
57 |
+
)
|
58 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
59 |
+
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
|
60 |
+
for _ in range(n_layers - 1):
|
61 |
+
self.conv_layers.append(
|
62 |
+
nn.Conv1d(
|
63 |
+
hidden_channels,
|
64 |
+
hidden_channels,
|
65 |
+
kernel_size,
|
66 |
+
padding=kernel_size // 2,
|
67 |
+
)
|
68 |
+
)
|
69 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
70 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
71 |
+
self.proj.weight.data.zero_()
|
72 |
+
self.proj.bias.data.zero_()
|
73 |
+
|
74 |
+
def forward(self, x, x_mask):
|
75 |
+
x_org = x
|
76 |
+
for i in range(self.n_layers):
|
77 |
+
x = self.conv_layers[i](x * x_mask)
|
78 |
+
x = self.norm_layers[i](x)
|
79 |
+
x = self.relu_drop(x)
|
80 |
+
x = x_org + self.proj(x)
|
81 |
+
return x * x_mask
|
82 |
+
|
83 |
+
|
84 |
+
class DDSConv(nn.Module):
|
85 |
+
"""
|
86 |
+
Dilated and Depth-Separable Convolution
|
87 |
+
"""
|
88 |
+
|
89 |
+
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
|
90 |
+
super().__init__()
|
91 |
+
self.channels = channels
|
92 |
+
self.kernel_size = kernel_size
|
93 |
+
self.n_layers = n_layers
|
94 |
+
self.p_dropout = p_dropout
|
95 |
+
|
96 |
+
self.drop = nn.Dropout(p_dropout)
|
97 |
+
self.convs_sep = nn.ModuleList()
|
98 |
+
self.convs_1x1 = nn.ModuleList()
|
99 |
+
self.norms_1 = nn.ModuleList()
|
100 |
+
self.norms_2 = nn.ModuleList()
|
101 |
+
for i in range(n_layers):
|
102 |
+
dilation = kernel_size**i
|
103 |
+
padding = (kernel_size * dilation - dilation) // 2
|
104 |
+
self.convs_sep.append(
|
105 |
+
nn.Conv1d(
|
106 |
+
channels,
|
107 |
+
channels,
|
108 |
+
kernel_size,
|
109 |
+
groups=channels,
|
110 |
+
dilation=dilation,
|
111 |
+
padding=padding,
|
112 |
+
)
|
113 |
+
)
|
114 |
+
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
|
115 |
+
self.norms_1.append(LayerNorm(channels))
|
116 |
+
self.norms_2.append(LayerNorm(channels))
|
117 |
+
|
118 |
+
def forward(self, x, x_mask, g=None):
|
119 |
+
if g is not None:
|
120 |
+
x = x + g
|
121 |
+
for i in range(self.n_layers):
|
122 |
+
y = self.convs_sep[i](x * x_mask)
|
123 |
+
y = self.norms_1[i](y)
|
124 |
+
y = F.gelu(y)
|
125 |
+
y = self.convs_1x1[i](y)
|
126 |
+
y = self.norms_2[i](y)
|
127 |
+
y = F.gelu(y)
|
128 |
+
y = self.drop(y)
|
129 |
+
x = x + y
|
130 |
+
return x * x_mask
|
131 |
+
|
132 |
+
|
133 |
+
class WN(torch.nn.Module):
|
134 |
+
def __init__(
|
135 |
+
self,
|
136 |
+
hidden_channels,
|
137 |
+
kernel_size,
|
138 |
+
dilation_rate,
|
139 |
+
n_layers,
|
140 |
+
gin_channels=0,
|
141 |
+
p_dropout=0,
|
142 |
+
):
|
143 |
+
super(WN, self).__init__()
|
144 |
+
assert kernel_size % 2 == 1
|
145 |
+
self.hidden_channels = hidden_channels
|
146 |
+
self.kernel_size = (kernel_size,)
|
147 |
+
self.dilation_rate = dilation_rate
|
148 |
+
self.n_layers = n_layers
|
149 |
+
self.gin_channels = gin_channels
|
150 |
+
self.p_dropout = p_dropout
|
151 |
+
|
152 |
+
self.in_layers = torch.nn.ModuleList()
|
153 |
+
self.res_skip_layers = torch.nn.ModuleList()
|
154 |
+
self.drop = nn.Dropout(p_dropout)
|
155 |
+
|
156 |
+
if gin_channels != 0:
|
157 |
+
cond_layer = torch.nn.Conv1d(
|
158 |
+
gin_channels, 2 * hidden_channels * n_layers, 1
|
159 |
+
)
|
160 |
+
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
|
161 |
+
|
162 |
+
for i in range(n_layers):
|
163 |
+
dilation = dilation_rate**i
|
164 |
+
padding = int((kernel_size * dilation - dilation) / 2)
|
165 |
+
in_layer = torch.nn.Conv1d(
|
166 |
+
hidden_channels,
|
167 |
+
2 * hidden_channels,
|
168 |
+
kernel_size,
|
169 |
+
dilation=dilation,
|
170 |
+
padding=padding,
|
171 |
+
)
|
172 |
+
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
|
173 |
+
self.in_layers.append(in_layer)
|
174 |
+
|
175 |
+
# last one is not necessary
|
176 |
+
if i < n_layers - 1:
|
177 |
+
res_skip_channels = 2 * hidden_channels
|
178 |
+
else:
|
179 |
+
res_skip_channels = hidden_channels
|
180 |
+
|
181 |
+
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
182 |
+
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
|
183 |
+
self.res_skip_layers.append(res_skip_layer)
|
184 |
+
|
185 |
+
def forward(self, x, x_mask, g=None, **kwargs):
|
186 |
+
output = torch.zeros_like(x)
|
187 |
+
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
188 |
+
|
189 |
+
if g is not None:
|
190 |
+
g = self.cond_layer(g)
|
191 |
+
|
192 |
+
for i in range(self.n_layers):
|
193 |
+
x_in = self.in_layers[i](x)
|
194 |
+
if g is not None:
|
195 |
+
cond_offset = i * 2 * self.hidden_channels
|
196 |
+
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
197 |
+
else:
|
198 |
+
g_l = torch.zeros_like(x_in)
|
199 |
+
|
200 |
+
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
|
201 |
+
acts = self.drop(acts)
|
202 |
+
|
203 |
+
res_skip_acts = self.res_skip_layers[i](acts)
|
204 |
+
if i < self.n_layers - 1:
|
205 |
+
res_acts = res_skip_acts[:, : self.hidden_channels, :]
|
206 |
+
x = (x + res_acts) * x_mask
|
207 |
+
output = output + res_skip_acts[:, self.hidden_channels :, :]
|
208 |
+
else:
|
209 |
+
output = output + res_skip_acts
|
210 |
+
return output * x_mask
|
211 |
+
|
212 |
+
def remove_weight_norm(self):
|
213 |
+
if self.gin_channels != 0:
|
214 |
+
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
215 |
+
for l in self.in_layers:
|
216 |
+
torch.nn.utils.remove_weight_norm(l)
|
217 |
+
for l in self.res_skip_layers:
|
218 |
+
torch.nn.utils.remove_weight_norm(l)
|
219 |
+
|
220 |
+
|
221 |
+
class ResBlock1(torch.nn.Module):
|
222 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
223 |
+
super(ResBlock1, self).__init__()
|
224 |
+
self.convs1 = nn.ModuleList(
|
225 |
+
[
|
226 |
+
weight_norm(
|
227 |
+
Conv1d(
|
228 |
+
channels,
|
229 |
+
channels,
|
230 |
+
kernel_size,
|
231 |
+
1,
|
232 |
+
dilation=dilation[0],
|
233 |
+
padding=get_padding(kernel_size, dilation[0]),
|
234 |
+
)
|
235 |
+
),
|
236 |
+
weight_norm(
|
237 |
+
Conv1d(
|
238 |
+
channels,
|
239 |
+
channels,
|
240 |
+
kernel_size,
|
241 |
+
1,
|
242 |
+
dilation=dilation[1],
|
243 |
+
padding=get_padding(kernel_size, dilation[1]),
|
244 |
+
)
|
245 |
+
),
|
246 |
+
weight_norm(
|
247 |
+
Conv1d(
|
248 |
+
channels,
|
249 |
+
channels,
|
250 |
+
kernel_size,
|
251 |
+
1,
|
252 |
+
dilation=dilation[2],
|
253 |
+
padding=get_padding(kernel_size, dilation[2]),
|
254 |
+
)
|
255 |
+
),
|
256 |
+
]
|
257 |
+
)
|
258 |
+
self.convs1.apply(init_weights)
|
259 |
+
|
260 |
+
self.convs2 = nn.ModuleList(
|
261 |
+
[
|
262 |
+
weight_norm(
|
263 |
+
Conv1d(
|
264 |
+
channels,
|
265 |
+
channels,
|
266 |
+
kernel_size,
|
267 |
+
1,
|
268 |
+
dilation=1,
|
269 |
+
padding=get_padding(kernel_size, 1),
|
270 |
+
)
|
271 |
+
),
|
272 |
+
weight_norm(
|
273 |
+
Conv1d(
|
274 |
+
channels,
|
275 |
+
channels,
|
276 |
+
kernel_size,
|
277 |
+
1,
|
278 |
+
dilation=1,
|
279 |
+
padding=get_padding(kernel_size, 1),
|
280 |
+
)
|
281 |
+
),
|
282 |
+
weight_norm(
|
283 |
+
Conv1d(
|
284 |
+
channels,
|
285 |
+
channels,
|
286 |
+
kernel_size,
|
287 |
+
1,
|
288 |
+
dilation=1,
|
289 |
+
padding=get_padding(kernel_size, 1),
|
290 |
+
)
|
291 |
+
),
|
292 |
+
]
|
293 |
+
)
|
294 |
+
self.convs2.apply(init_weights)
|
295 |
+
|
296 |
+
def forward(self, x, x_mask=None):
|
297 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
298 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
299 |
+
if x_mask is not None:
|
300 |
+
xt = xt * x_mask
|
301 |
+
xt = c1(xt)
|
302 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
303 |
+
if x_mask is not None:
|
304 |
+
xt = xt * x_mask
|
305 |
+
xt = c2(xt)
|
306 |
+
x = xt + x
|
307 |
+
if x_mask is not None:
|
308 |
+
x = x * x_mask
|
309 |
+
return x
|
310 |
+
|
311 |
+
def remove_weight_norm(self):
|
312 |
+
for l in self.convs1:
|
313 |
+
remove_weight_norm(l)
|
314 |
+
for l in self.convs2:
|
315 |
+
remove_weight_norm(l)
|
316 |
+
|
317 |
+
|
318 |
+
class ResBlock2(torch.nn.Module):
|
319 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
320 |
+
super(ResBlock2, self).__init__()
|
321 |
+
self.convs = nn.ModuleList(
|
322 |
+
[
|
323 |
+
weight_norm(
|
324 |
+
Conv1d(
|
325 |
+
channels,
|
326 |
+
channels,
|
327 |
+
kernel_size,
|
328 |
+
1,
|
329 |
+
dilation=dilation[0],
|
330 |
+
padding=get_padding(kernel_size, dilation[0]),
|
331 |
+
)
|
332 |
+
),
|
333 |
+
weight_norm(
|
334 |
+
Conv1d(
|
335 |
+
channels,
|
336 |
+
channels,
|
337 |
+
kernel_size,
|
338 |
+
1,
|
339 |
+
dilation=dilation[1],
|
340 |
+
padding=get_padding(kernel_size, dilation[1]),
|
341 |
+
)
|
342 |
+
),
|
343 |
+
]
|
344 |
+
)
|
345 |
+
self.convs.apply(init_weights)
|
346 |
+
|
347 |
+
def forward(self, x, x_mask=None):
|
348 |
+
for c in self.convs:
|
349 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
350 |
+
if x_mask is not None:
|
351 |
+
xt = xt * x_mask
|
352 |
+
xt = c(xt)
|
353 |
+
x = xt + x
|
354 |
+
if x_mask is not None:
|
355 |
+
x = x * x_mask
|
356 |
+
return x
|
357 |
+
|
358 |
+
def remove_weight_norm(self):
|
359 |
+
for l in self.convs:
|
360 |
+
remove_weight_norm(l)
|
361 |
+
|
362 |
+
|
363 |
+
class Log(nn.Module):
|
364 |
+
def forward(self, x, x_mask, reverse=False, **kwargs):
|
365 |
+
if not reverse:
|
366 |
+
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
|
367 |
+
logdet = torch.sum(-y, [1, 2])
|
368 |
+
return y, logdet
|
369 |
+
else:
|
370 |
+
x = torch.exp(x) * x_mask
|
371 |
+
return x
|
372 |
+
|
373 |
+
|
374 |
+
class Flip(nn.Module):
|
375 |
+
def forward(self, x, *args, reverse=False, **kwargs):
|
376 |
+
x = torch.flip(x, [1])
|
377 |
+
if not reverse:
|
378 |
+
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
|
379 |
+
return x, logdet
|
380 |
+
else:
|
381 |
+
return x
|
382 |
+
|
383 |
+
|
384 |
+
class ElementwiseAffine(nn.Module):
|
385 |
+
def __init__(self, channels):
|
386 |
+
super().__init__()
|
387 |
+
self.channels = channels
|
388 |
+
self.m = nn.Parameter(torch.zeros(channels, 1))
|
389 |
+
self.logs = nn.Parameter(torch.zeros(channels, 1))
|
390 |
+
|
391 |
+
def forward(self, x, x_mask, reverse=False, **kwargs):
|
392 |
+
if not reverse:
|
393 |
+
y = self.m + torch.exp(self.logs) * x
|
394 |
+
y = y * x_mask
|
395 |
+
logdet = torch.sum(self.logs * x_mask, [1, 2])
|
396 |
+
return y, logdet
|
397 |
+
else:
|
398 |
+
x = (x - self.m) * torch.exp(-self.logs) * x_mask
|
399 |
+
return x
|
400 |
+
|
401 |
+
|
402 |
+
class ResidualCouplingLayer(nn.Module):
|
403 |
+
def __init__(
|
404 |
+
self,
|
405 |
+
channels,
|
406 |
+
hidden_channels,
|
407 |
+
kernel_size,
|
408 |
+
dilation_rate,
|
409 |
+
n_layers,
|
410 |
+
p_dropout=0,
|
411 |
+
gin_channels=0,
|
412 |
+
mean_only=False,
|
413 |
+
):
|
414 |
+
assert channels % 2 == 0, "channels should be divisible by 2"
|
415 |
+
super().__init__()
|
416 |
+
self.channels = channels
|
417 |
+
self.hidden_channels = hidden_channels
|
418 |
+
self.kernel_size = kernel_size
|
419 |
+
self.dilation_rate = dilation_rate
|
420 |
+
self.n_layers = n_layers
|
421 |
+
self.half_channels = channels // 2
|
422 |
+
self.mean_only = mean_only
|
423 |
+
|
424 |
+
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
425 |
+
self.enc = WN(
|
426 |
+
hidden_channels,
|
427 |
+
kernel_size,
|
428 |
+
dilation_rate,
|
429 |
+
n_layers,
|
430 |
+
p_dropout=p_dropout,
|
431 |
+
gin_channels=gin_channels,
|
432 |
+
)
|
433 |
+
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
434 |
+
self.post.weight.data.zero_()
|
435 |
+
self.post.bias.data.zero_()
|
436 |
+
|
437 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
438 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
439 |
+
h = self.pre(x0) * x_mask
|
440 |
+
h = self.enc(h, x_mask, g=g)
|
441 |
+
stats = self.post(h) * x_mask
|
442 |
+
if not self.mean_only:
|
443 |
+
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
444 |
+
else:
|
445 |
+
m = stats
|
446 |
+
logs = torch.zeros_like(m)
|
447 |
+
|
448 |
+
if not reverse:
|
449 |
+
x1 = m + x1 * torch.exp(logs) * x_mask
|
450 |
+
x = torch.cat([x0, x1], 1)
|
451 |
+
logdet = torch.sum(logs, [1, 2])
|
452 |
+
return x, logdet
|
453 |
+
else:
|
454 |
+
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
455 |
+
x = torch.cat([x0, x1], 1)
|
456 |
+
return x
|
457 |
+
|
458 |
+
|
459 |
+
class ConvFlow(nn.Module):
|
460 |
+
def __init__(
|
461 |
+
self,
|
462 |
+
in_channels,
|
463 |
+
filter_channels,
|
464 |
+
kernel_size,
|
465 |
+
n_layers,
|
466 |
+
num_bins=10,
|
467 |
+
tail_bound=5.0,
|
468 |
+
):
|
469 |
+
super().__init__()
|
470 |
+
self.in_channels = in_channels
|
471 |
+
self.filter_channels = filter_channels
|
472 |
+
self.kernel_size = kernel_size
|
473 |
+
self.n_layers = n_layers
|
474 |
+
self.num_bins = num_bins
|
475 |
+
self.tail_bound = tail_bound
|
476 |
+
self.half_channels = in_channels // 2
|
477 |
+
|
478 |
+
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
|
479 |
+
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
|
480 |
+
self.proj = nn.Conv1d(
|
481 |
+
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
|
482 |
+
)
|
483 |
+
self.proj.weight.data.zero_()
|
484 |
+
self.proj.bias.data.zero_()
|
485 |
+
|
486 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
487 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
488 |
+
h = self.pre(x0)
|
489 |
+
h = self.convs(h, x_mask, g=g)
|
490 |
+
h = self.proj(h) * x_mask
|
491 |
+
|
492 |
+
b, c, t = x0.shape
|
493 |
+
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
|
494 |
+
|
495 |
+
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
|
496 |
+
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
|
497 |
+
self.filter_channels
|
498 |
+
)
|
499 |
+
unnormalized_derivatives = h[..., 2 * self.num_bins :]
|
500 |
+
|
501 |
+
x1, logabsdet = piecewise_rational_quadratic_transform(
|
502 |
+
x1,
|
503 |
+
unnormalized_widths,
|
504 |
+
unnormalized_heights,
|
505 |
+
unnormalized_derivatives,
|
506 |
+
inverse=reverse,
|
507 |
+
tails="linear",
|
508 |
+
tail_bound=self.tail_bound,
|
509 |
+
)
|
510 |
+
|
511 |
+
x = torch.cat([x0, x1], 1) * x_mask
|
512 |
+
logdet = torch.sum(logabsdet * x_mask, [1, 2])
|
513 |
+
if not reverse:
|
514 |
+
return x, logdet
|
515 |
+
else:
|
516 |
+
return x
|
517 |
+
|
518 |
+
|
519 |
+
class TransformerCouplingLayer(nn.Module):
|
520 |
+
def __init__(
|
521 |
+
self,
|
522 |
+
channels,
|
523 |
+
hidden_channels,
|
524 |
+
kernel_size,
|
525 |
+
n_layers,
|
526 |
+
n_heads,
|
527 |
+
p_dropout=0,
|
528 |
+
filter_channels=0,
|
529 |
+
mean_only=False,
|
530 |
+
wn_sharing_parameter=None,
|
531 |
+
gin_channels=0,
|
532 |
+
):
|
533 |
+
assert n_layers == 3, n_layers
|
534 |
+
assert channels % 2 == 0, "channels should be divisible by 2"
|
535 |
+
super().__init__()
|
536 |
+
self.channels = channels
|
537 |
+
self.hidden_channels = hidden_channels
|
538 |
+
self.kernel_size = kernel_size
|
539 |
+
self.n_layers = n_layers
|
540 |
+
self.half_channels = channels // 2
|
541 |
+
self.mean_only = mean_only
|
542 |
+
|
543 |
+
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
544 |
+
self.enc = (
|
545 |
+
Encoder(
|
546 |
+
hidden_channels,
|
547 |
+
filter_channels,
|
548 |
+
n_heads,
|
549 |
+
n_layers,
|
550 |
+
kernel_size,
|
551 |
+
p_dropout,
|
552 |
+
isflow=True,
|
553 |
+
gin_channels=gin_channels,
|
554 |
+
)
|
555 |
+
if wn_sharing_parameter is None
|
556 |
+
else wn_sharing_parameter
|
557 |
+
)
|
558 |
+
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
559 |
+
self.post.weight.data.zero_()
|
560 |
+
self.post.bias.data.zero_()
|
561 |
+
|
562 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
563 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
564 |
+
h = self.pre(x0) * x_mask
|
565 |
+
h = self.enc(h, x_mask, g=g)
|
566 |
+
stats = self.post(h) * x_mask
|
567 |
+
if not self.mean_only:
|
568 |
+
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
569 |
+
else:
|
570 |
+
m = stats
|
571 |
+
logs = torch.zeros_like(m)
|
572 |
+
|
573 |
+
if not reverse:
|
574 |
+
x1 = m + x1 * torch.exp(logs) * x_mask
|
575 |
+
x = torch.cat([x0, x1], 1)
|
576 |
+
logdet = torch.sum(logs, [1, 2])
|
577 |
+
return x, logdet
|
578 |
+
else:
|
579 |
+
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
580 |
+
x = torch.cat([x0, x1], 1)
|
581 |
+
return x
|
582 |
+
|
583 |
+
x1, logabsdet = piecewise_rational_quadratic_transform(
|
584 |
+
x1,
|
585 |
+
unnormalized_widths,
|
586 |
+
unnormalized_heights,
|
587 |
+
unnormalized_derivatives,
|
588 |
+
inverse=reverse,
|
589 |
+
tails="linear",
|
590 |
+
tail_bound=self.tail_bound,
|
591 |
+
)
|
592 |
+
|
593 |
+
x = torch.cat([x0, x1], 1) * x_mask
|
594 |
+
logdet = torch.sum(logabsdet * x_mask, [1, 2])
|
595 |
+
if not reverse:
|
596 |
+
return x, logdet
|
597 |
+
else:
|
598 |
+
return x
|
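The coupling, flip, and affine layers above are normalizing-flow steps, so a forward pass followed by a reverse pass should reproduce the input. A small self-check sketch (shapes and settings are arbitrary assumptions):

```python
# Minimal invertibility check for the flow modules above (shapes are assumptions).
import torch
from openvoice.modules import ResidualCouplingLayer

torch.manual_seed(0)
layer = ResidualCouplingLayer(
    channels=192, hidden_channels=192, kernel_size=5, dilation_rate=1, n_layers=4, mean_only=True
).eval()

x = torch.randn(2, 192, 50)        # (batch, channels, frames)
x_mask = torch.ones(2, 1, 50)      # no padding in this toy example

with torch.no_grad():
    y, logdet = layer(x, x_mask)             # forward pass returns the transform and its log-determinant
    x_rec = layer(y, x_mask, reverse=True)   # reverse pass undoes it

print(torch.allclose(x, x_rec, atol=1e-5))   # expected: True
```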
OpenVoice/openvoice/openvoice_app.py
ADDED
@@ -0,0 +1,275 @@
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import argparse
|
4 |
+
import gradio as gr
|
5 |
+
from zipfile import ZipFile
|
6 |
+
import langid
|
7 |
+
from openvoice import se_extractor
|
8 |
+
from openvoice.api import BaseSpeakerTTS, ToneColorConverter
|
9 |
+
|
10 |
+
parser = argparse.ArgumentParser()
|
11 |
+
parser.add_argument("--share", action='store_true', default=False, help="make link public")
|
12 |
+
args = parser.parse_args()
|
13 |
+
|
14 |
+
en_ckpt_base = 'checkpoints/base_speakers/EN'
|
15 |
+
zh_ckpt_base = 'checkpoints/base_speakers/ZH'
|
16 |
+
ckpt_converter = 'checkpoints/converter'
|
17 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
18 |
+
output_dir = 'outputs'
|
19 |
+
os.makedirs(output_dir, exist_ok=True)
|
20 |
+
|
21 |
+
# load models
|
22 |
+
en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
|
23 |
+
en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
|
24 |
+
zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
|
25 |
+
zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
|
26 |
+
tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
|
27 |
+
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
|
28 |
+
|
29 |
+
# load speaker embeddings
|
30 |
+
en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
|
31 |
+
en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device)
|
32 |
+
zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)
|
33 |
+
|
34 |
+
# This online demo mainly supports English and Chinese
|
35 |
+
supported_languages = ['zh', 'en']
|
36 |
+
|
37 |
+
def predict(prompt, style, audio_file_pth, agree):
|
38 |
+
# initialize a empty info
|
39 |
+
text_hint = ''
|
40 |
+
# agree with the terms
|
41 |
+
if not agree:
|
42 |
+
text_hint += '[ERROR] Please accept the Terms & Condition!\n'
|
43 |
+
gr.Warning("Please accept the Terms & Condition!")
|
44 |
+
return (
|
45 |
+
text_hint,
|
46 |
+
None,
|
47 |
+
None,
|
48 |
+
)
|
49 |
+
|
50 |
+
# first detect the input language
|
51 |
+
language_predicted = langid.classify(prompt)[0].strip()
|
52 |
+
print(f"Detected language:{language_predicted}")
|
53 |
+
|
54 |
+
if language_predicted not in supported_languages:
|
55 |
+
text_hint += f"[ERROR] The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}\n"
|
56 |
+
gr.Warning(
|
57 |
+
f"The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}"
|
58 |
+
)
|
59 |
+
|
60 |
+
return (
|
61 |
+
text_hint,
|
62 |
+
None,
|
63 |
+
None,
|
64 |
+
)
|
65 |
+
|
66 |
+
if language_predicted == "zh":
|
67 |
+
tts_model = zh_base_speaker_tts
|
68 |
+
source_se = zh_source_se
|
69 |
+
language = 'Chinese'
|
70 |
+
if style not in ['default']:
|
71 |
+
text_hint += f"[ERROR] The style {style} is not supported for Chinese, which should be in ['default']\n"
|
72 |
+
gr.Warning(f"The style {style} is not supported for Chinese, which should be in ['default']")
|
73 |
+
return (
|
74 |
+
text_hint,
|
75 |
+
None,
|
76 |
+
None,
|
77 |
+
)
|
78 |
+
|
79 |
+
else:
|
80 |
+
tts_model = en_base_speaker_tts
|
81 |
+
if style == 'default':
|
82 |
+
source_se = en_source_default_se
|
83 |
+
else:
|
84 |
+
source_se = en_source_style_se
|
85 |
+
language = 'English'
|
86 |
+
if style not in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']:
|
87 |
+
text_hint += f"[ERROR] The style {style} is not supported for English, which should be in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']\n"
|
88 |
+
gr.Warning(f"The style {style} is not supported for English, which should be in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']")
|
89 |
+
return (
|
90 |
+
text_hint,
|
91 |
+
None,
|
92 |
+
None,
|
93 |
+
)
|
94 |
+
|
95 |
+
speaker_wav = audio_file_pth
|
96 |
+
|
97 |
+
if len(prompt) < 2:
|
98 |
+
text_hint += f"[ERROR] Please give a longer prompt text \n"
|
99 |
+
gr.Warning("Please give a longer prompt text")
|
100 |
+
return (
|
101 |
+
text_hint,
|
102 |
+
None,
|
103 |
+
None,
|
104 |
+
)
|
105 |
+
if len(prompt) > 200:
|
106 |
+
text_hint += f"[ERROR] Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo and try for your usage \n"
|
107 |
+
gr.Warning(
|
108 |
+
"Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo for your usage"
|
109 |
+
)
|
110 |
+
return (
|
111 |
+
text_hint,
|
112 |
+
None,
|
113 |
+
None,
|
114 |
+
)
|
115 |
+
|
116 |
+
# note: diffusion_conditioning is not used on hifigan (default mode); it will be empty but still needs to be passed to model.inference
|
117 |
+
try:
|
118 |
+
target_se, audio_name = se_extractor.get_se(speaker_wav, tone_color_converter, target_dir='processed', vad=True)
|
119 |
+
except Exception as e:
|
120 |
+
text_hint += f"[ERROR] Get target tone color error {str(e)} \n"
|
121 |
+
gr.Warning(
|
122 |
+
"[ERROR] Get target tone color error {str(e)} \n"
|
123 |
+
)
|
124 |
+
return (
|
125 |
+
text_hint,
|
126 |
+
None,
|
127 |
+
None,
|
128 |
+
)
|
129 |
+
|
130 |
+
src_path = f'{output_dir}/tmp.wav'
|
131 |
+
tts_model.tts(prompt, src_path, speaker=style, language=language)
|
132 |
+
|
133 |
+
save_path = f'{output_dir}/output.wav'
|
134 |
+
# Run the tone color converter
|
135 |
+
encode_message = "@MyShell"
|
136 |
+
tone_color_converter.convert(
|
137 |
+
audio_src_path=src_path,
|
138 |
+
src_se=source_se,
|
139 |
+
tgt_se=target_se,
|
140 |
+
output_path=save_path,
|
141 |
+
message=encode_message)
|
142 |
+
|
143 |
+
text_hint += f'''Get response successfully \n'''
|
144 |
+
|
145 |
+
return (
|
146 |
+
text_hint,
|
147 |
+
save_path,
|
148 |
+
speaker_wav,
|
149 |
+
)
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
title = "MyShell OpenVoice"
|
154 |
+
|
155 |
+
description = """
|
156 |
+
We introduce OpenVoice, a versatile instant voice cloning approach that requires only a short audio clip from the reference speaker to replicate their voice and generate speech in multiple languages. OpenVoice enables granular control over voice styles, including emotion, accent, rhythm, pauses, and intonation, in addition to replicating the tone color of the reference speaker. OpenVoice also achieves zero-shot cross-lingual voice cloning for languages not included in the massive-speaker training set.
|
157 |
+
"""
|
158 |
+
|
159 |
+
markdown_table = """
|
160 |
+
<div align="center" style="margin-bottom: 10px;">
|
161 |
+
|
162 |
+
| | | |
|
163 |
+
| :-----------: | :-----------: | :-----------: |
|
164 |
+
| **OpenSource Repo** | **Project Page** | **Join the Community** |
|
165 |
+
| <div style='text-align: center;'><a style="display:inline-block,align:center" href='https://github.com/myshell-ai/OpenVoice'><img src='https://img.shields.io/github/stars/myshell-ai/OpenVoice?style=social' /></a></div> | [OpenVoice](https://research.myshell.ai/open-voice) | [](https://discord.gg/myshell) |
|
166 |
+
|
167 |
+
</div>
|
168 |
+
"""
|
169 |
+
|
170 |
+
markdown_table_v2 = """
|
171 |
+
<div align="center" style="margin-bottom: 2px;">
|
172 |
+
|
173 |
+
| | | | |
|
174 |
+
| :-----------: | :-----------: | :-----------: | :-----------: |
|
175 |
+
| **OpenSource Repo** | <div style='text-align: center;'><a style="display:inline-block,align:center" href='https://github.com/myshell-ai/OpenVoice'><img src='https://img.shields.io/github/stars/myshell-ai/OpenVoice?style=social' /></a></div> | **Project Page** | [OpenVoice](https://research.myshell.ai/open-voice) |
|
176 |
+
|
177 |
+
| | |
|
178 |
+
| :-----------: | :-----------: |
|
179 |
+
**Join the Community** | [](https://discord.gg/myshell) |
|
180 |
+
|
181 |
+
</div>
|
182 |
+
"""
|
183 |
+
content = """
|
184 |
+
<div>
|
185 |
+
<strong>If the generated voice does not sound like the reference voice, please refer to <a href='https://github.com/myshell-ai/OpenVoice/blob/main/docs/QA.md'>this QnA</a>.</strong> <strong>For multi-lingual & cross-lingual examples, please refer to <a href='https://github.com/myshell-ai/OpenVoice/blob/main/demo_part2.ipynb'>this jupyter notebook</a>.</strong>
|
186 |
+
This online demo mainly supports <strong>English</strong>. The <em>default</em> style also supports <strong>Chinese</strong>. But OpenVoice can adapt to any other language as long as a base speaker is provided.
|
187 |
+
</div>
|
188 |
+
"""
|
189 |
+
wrapped_markdown_content = f"<div style='border: 1px solid #000; padding: 10px;'>{content}</div>"
|
190 |
+
|
191 |
+
|
192 |
+
examples = [
|
193 |
+
[
|
194 |
+
"今天天气真好,我们一起出去吃饭吧。",
|
195 |
+
'default',
|
196 |
+
"resources/demo_speaker1.mp3",
|
197 |
+
True,
|
198 |
+
],[
|
199 |
+
"This audio is generated by open voice with a half-performance model.",
|
200 |
+
'whispering',
|
201 |
+
"resources/demo_speaker2.mp3",
|
202 |
+
True,
|
203 |
+
],
|
204 |
+
[
|
205 |
+
"He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.",
|
206 |
+
'sad',
|
207 |
+
"resources/demo_speaker0.mp3",
|
208 |
+
True,
|
209 |
+
],
|
210 |
+
]
|
211 |
+
|
212 |
+
with gr.Blocks(analytics_enabled=False) as demo:
|
213 |
+
|
214 |
+
with gr.Row():
|
215 |
+
with gr.Column():
|
216 |
+
with gr.Row():
|
217 |
+
gr.Markdown(
|
218 |
+
"""
|
219 |
+
## <img src="https://huggingface.co/spaces/myshell-ai/OpenVoice/raw/main/logo.jpg" height="40"/>
|
220 |
+
"""
|
221 |
+
)
|
222 |
+
with gr.Row():
|
223 |
+
gr.Markdown(markdown_table_v2)
|
224 |
+
with gr.Row():
|
225 |
+
gr.Markdown(description)
|
226 |
+
with gr.Column():
|
227 |
+
gr.Video('https://github.com/myshell-ai/OpenVoice/assets/40556743/3cba936f-82bf-476c-9e52-09f0f417bb2f', autoplay=True)
|
228 |
+
|
229 |
+
with gr.Row():
|
230 |
+
gr.HTML(wrapped_markdown_content)
|
231 |
+
|
232 |
+
with gr.Row():
|
233 |
+
with gr.Column():
|
234 |
+
input_text_gr = gr.Textbox(
|
235 |
+
label="Text Prompt",
|
236 |
+
info="One or two sentences at a time is better. Up to 200 text characters.",
|
237 |
+
value="He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.",
|
238 |
+
)
|
239 |
+
style_gr = gr.Dropdown(
|
240 |
+
label="Style",
|
241 |
+
info="Select a style of output audio for the synthesised speech. (Chinese only support 'default' now)",
|
242 |
+
choices=['default', 'whispering', 'cheerful', 'terrified', 'angry', 'sad', 'friendly'],
|
243 |
+
max_choices=1,
|
244 |
+
value="default",
|
245 |
+
)
|
246 |
+
ref_gr = gr.Audio(
|
247 |
+
label="Reference Audio",
|
248 |
+
info="Click on the ✎ button to upload your own target speaker audio",
|
249 |
+
type="filepath",
|
250 |
+
value="resources/demo_speaker2.mp3",
|
251 |
+
)
|
252 |
+
tos_gr = gr.Checkbox(
|
253 |
+
label="Agree",
|
254 |
+
value=False,
|
255 |
+
info="I agree to the terms of the cc-by-nc-4.0 license-: https://github.com/myshell-ai/OpenVoice/blob/main/LICENSE",
|
256 |
+
)
|
257 |
+
|
258 |
+
tts_button = gr.Button("Send", elem_id="send-btn", visible=True)
|
259 |
+
|
260 |
+
|
261 |
+
with gr.Column():
|
262 |
+
out_text_gr = gr.Text(label="Info")
|
263 |
+
audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
|
264 |
+
ref_audio_gr = gr.Audio(label="Reference Audio Used")
|
265 |
+
|
266 |
+
gr.Examples(examples,
|
267 |
+
label="Examples",
|
268 |
+
inputs=[input_text_gr, style_gr, ref_gr, tos_gr],
|
269 |
+
outputs=[out_text_gr, audio_gr, ref_audio_gr],
|
270 |
+
fn=predict,
|
271 |
+
cache_examples=False,)
|
272 |
+
tts_button.click(predict, [input_text_gr, style_gr, ref_gr, tos_gr], outputs=[out_text_gr, audio_gr, ref_audio_gr])
|
273 |
+
|
274 |
+
demo.queue()
|
275 |
+
demo.launch(debug=True, show_api=True, share=args.share)
|
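The block above wires the same `predict` callback to both the Examples widget and the Send button, returning an info string, the synthesised audio, and the reference audio that was used. A minimal sketch of that wiring pattern with a stubbed callback (component names and the stub itself are illustrative, not part of this repository):

```python
import gradio as gr

def predict_stub(prompt, style, audio_file_pth, agree):
    # Stand-in for the real OpenVoice pipeline; returns (info, synthesized_audio, reference_audio).
    if not agree:
        return "Please accept the license terms first.", None, None
    return f"Would synthesize {prompt[:30]!r} in style {style!r}", None, audio_file_pth

with gr.Blocks(analytics_enabled=False) as sketch:
    text_in = gr.Textbox(label="Text Prompt")
    style_in = gr.Dropdown(choices=["default", "whispering", "sad"], value="default", label="Style")
    ref_in = gr.Audio(type="filepath", label="Reference Audio")
    agree_in = gr.Checkbox(label="Agree", value=False)
    send = gr.Button("Send")
    info_out = gr.Text(label="Info")
    audio_out = gr.Audio(label="Synthesised Audio", autoplay=True)
    ref_out = gr.Audio(label="Reference Audio Used")
    send.click(predict_stub, [text_in, style_in, ref_in, agree_in], [info_out, audio_out, ref_out])

# sketch.queue(); sketch.launch()
```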
OpenVoice/openvoice/se_extractor.py
ADDED
@@ -0,0 +1,153 @@
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import torch
|
4 |
+
import hashlib
|
5 |
+
import librosa
|
6 |
+
import base64
|
7 |
+
from glob import glob
|
8 |
+
import numpy as np
|
9 |
+
from pydub import AudioSegment
|
10 |
+
from faster_whisper import WhisperModel
|
11 |
+
import hashlib
|
12 |
+
import base64
|
13 |
+
import librosa
|
14 |
+
from whisper_timestamped.transcribe import get_audio_tensor, get_vad_segments
|
15 |
+
|
16 |
+
model_size = "medium"
|
17 |
+
# Run on GPU with FP16
|
18 |
+
model = None
|
19 |
+
def split_audio_whisper(audio_path, audio_name, target_dir='processed'):
|
20 |
+
global model
|
21 |
+
if model is None:
|
22 |
+
model = WhisperModel(model_size, device="cuda", compute_type="float16")
|
23 |
+
audio = AudioSegment.from_file(audio_path)
|
24 |
+
max_len = len(audio)
|
25 |
+
|
26 |
+
target_folder = os.path.join(target_dir, audio_name)
|
27 |
+
|
28 |
+
segments, info = model.transcribe(audio_path, beam_size=5, word_timestamps=True)
|
29 |
+
segments = list(segments)
|
30 |
+
|
31 |
+
# create directory
|
32 |
+
os.makedirs(target_folder, exist_ok=True)
|
33 |
+
wavs_folder = os.path.join(target_folder, 'wavs')
|
34 |
+
os.makedirs(wavs_folder, exist_ok=True)
|
35 |
+
|
36 |
+
# segments
|
37 |
+
s_ind = 0
|
38 |
+
start_time = None
|
39 |
+
|
40 |
+
for k, w in enumerate(segments):
|
41 |
+
# process with the time
|
42 |
+
if k == 0:
|
43 |
+
start_time = max(0, w.start)
|
44 |
+
|
45 |
+
end_time = w.end
|
46 |
+
|
47 |
+
# calculate confidence
|
48 |
+
if len(w.words) > 0:
|
49 |
+
confidence = sum([s.probability for s in w.words]) / len(w.words)
|
50 |
+
else:
|
51 |
+
confidence = 0.
|
52 |
+
# clean text
|
53 |
+
text = w.text.replace('...', '')
|
54 |
+
|
55 |
+
# left 0.08s for each audios
|
56 |
+
audio_seg = audio[int( start_time * 1000) : min(max_len, int(end_time * 1000) + 80)]
|
57 |
+
|
58 |
+
# segment file name
|
59 |
+
fname = f"{audio_name}_seg{s_ind}.wav"
|
60 |
+
|
61 |
+
# filter out the segment shorter than 1.5s and longer than 20s
|
62 |
+
save = audio_seg.duration_seconds > 1.5 and \
|
63 |
+
audio_seg.duration_seconds < 20. and \
|
64 |
+
len(text) >= 2 and len(text) < 200
|
65 |
+
|
66 |
+
if save:
|
67 |
+
output_file = os.path.join(wavs_folder, fname)
|
68 |
+
audio_seg.export(output_file, format='wav')
|
69 |
+
|
70 |
+
if k < len(segments) - 1:
|
71 |
+
start_time = max(0, segments[k+1].start - 0.08)
|
72 |
+
|
73 |
+
s_ind = s_ind + 1
|
74 |
+
return wavs_folder
|
75 |
+
|
76 |
+
|
77 |
+
def split_audio_vad(audio_path, audio_name, target_dir, split_seconds=10.0):
|
78 |
+
SAMPLE_RATE = 16000
|
79 |
+
audio_vad = get_audio_tensor(audio_path)
|
80 |
+
segments = get_vad_segments(
|
81 |
+
audio_vad,
|
82 |
+
output_sample=True,
|
83 |
+
min_speech_duration=0.1,
|
84 |
+
min_silence_duration=1,
|
85 |
+
method="silero",
|
86 |
+
)
|
87 |
+
segments = [(seg["start"], seg["end"]) for seg in segments]
|
88 |
+
segments = [(float(s) / SAMPLE_RATE, float(e) / SAMPLE_RATE) for s,e in segments]
|
89 |
+
print(segments)
|
90 |
+
audio_active = AudioSegment.silent(duration=0)
|
91 |
+
audio = AudioSegment.from_file(audio_path)
|
92 |
+
|
93 |
+
for start_time, end_time in segments:
|
94 |
+
audio_active += audio[int( start_time * 1000) : int(end_time * 1000)]
|
95 |
+
|
96 |
+
audio_dur = audio_active.duration_seconds
|
97 |
+
print(f'after vad: dur = {audio_dur}')
|
98 |
+
target_folder = os.path.join(target_dir, audio_name)
|
99 |
+
wavs_folder = os.path.join(target_folder, 'wavs')
|
100 |
+
os.makedirs(wavs_folder, exist_ok=True)
|
101 |
+
start_time = 0.
|
102 |
+
count = 0
|
103 |
+
num_splits = int(np.round(audio_dur / split_seconds))
|
104 |
+
assert num_splits > 0, 'input audio is too short'
|
105 |
+
interval = audio_dur / num_splits
|
106 |
+
|
107 |
+
for i in range(num_splits):
|
108 |
+
end_time = min(start_time + interval, audio_dur)
|
109 |
+
if i == num_splits - 1:
|
110 |
+
end_time = audio_dur
|
111 |
+
output_file = f"{wavs_folder}/{audio_name}_seg{count}.wav"
|
112 |
+
audio_seg = audio_active[int(start_time * 1000): int(end_time * 1000)]
|
113 |
+
audio_seg.export(output_file, format='wav')
|
114 |
+
start_time = end_time
|
115 |
+
count += 1
|
116 |
+
return wavs_folder
|
117 |
+
|
118 |
+
def hash_numpy_array(audio_path):
|
119 |
+
array, _ = librosa.load(audio_path, sr=None, mono=True)
|
120 |
+
# Convert the array to bytes
|
121 |
+
array_bytes = array.tobytes()
|
122 |
+
# Calculate the hash of the array bytes
|
123 |
+
hash_object = hashlib.sha256(array_bytes)
|
124 |
+
hash_value = hash_object.digest()
|
125 |
+
# Convert the hash value to base64
|
126 |
+
base64_value = base64.b64encode(hash_value)
|
127 |
+
return base64_value.decode('utf-8')[:16].replace('/', '_^')
|
128 |
+
|
129 |
+
def get_se(audio_path, vc_model, target_dir='processed', vad=True):
|
130 |
+
device = vc_model.device
|
131 |
+
version = vc_model.version
|
132 |
+
print("OpenVoice version:", version)
|
133 |
+
|
134 |
+
audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{version}_{hash_numpy_array(audio_path)}"
|
135 |
+
se_path = os.path.join(target_dir, audio_name, 'se.pth')
|
136 |
+
|
137 |
+
# if os.path.isfile(se_path):
|
138 |
+
# se = torch.load(se_path).to(device)
|
139 |
+
# return se, audio_name
|
140 |
+
# if os.path.isdir(audio_path):
|
141 |
+
# wavs_folder = audio_path
|
142 |
+
|
143 |
+
if vad:
|
144 |
+
wavs_folder = split_audio_vad(audio_path, target_dir=target_dir, audio_name=audio_name)
|
145 |
+
else:
|
146 |
+
wavs_folder = split_audio_whisper(audio_path, target_dir=target_dir, audio_name=audio_name)
|
147 |
+
|
148 |
+
audio_segs = glob(f'{wavs_folder}/*.wav')
|
149 |
+
if len(audio_segs) == 0:
|
150 |
+
raise NotImplementedError('No audio segments found!')
|
151 |
+
|
152 |
+
return vc_model.extract_se(audio_segs, se_save_path=se_path), audio_name
|
153 |
+
|
OpenVoice/openvoice/text/__init__.py
ADDED
@@ -0,0 +1,79 @@
1 |
+
""" from https://github.com/keithito/tacotron """
|
2 |
+
from openvoice.text import cleaners
|
3 |
+
from openvoice.text.symbols import symbols
|
4 |
+
|
5 |
+
|
6 |
+
# Mappings from symbol to numeric ID and vice versa:
|
7 |
+
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
8 |
+
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
|
9 |
+
|
10 |
+
|
11 |
+
def text_to_sequence(text, symbols, cleaner_names):
|
12 |
+
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
13 |
+
Args:
|
14 |
+
text: string to convert to a sequence
|
15 |
+
cleaner_names: names of the cleaner functions to run the text through
|
16 |
+
Returns:
|
17 |
+
List of integers corresponding to the symbols in the text
|
18 |
+
'''
|
19 |
+
sequence = []
|
20 |
+
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
21 |
+
clean_text = _clean_text(text, cleaner_names)
|
22 |
+
print(clean_text)
|
23 |
+
print(f" length:{len(clean_text)}")
|
24 |
+
for symbol in clean_text:
|
25 |
+
if symbol not in symbol_to_id.keys():
|
26 |
+
continue
|
27 |
+
symbol_id = symbol_to_id[symbol]
|
28 |
+
sequence += [symbol_id]
|
29 |
+
print(f" length:{len(sequence)}")
|
30 |
+
return sequence
|
31 |
+
|
32 |
+
|
33 |
+
def cleaned_text_to_sequence(cleaned_text, symbols):
|
34 |
+
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
35 |
+
Args:
|
36 |
+
text: string to convert to a sequence
|
37 |
+
Returns:
|
38 |
+
List of integers corresponding to the symbols in the text
|
39 |
+
'''
|
40 |
+
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
41 |
+
sequence = [symbol_to_id[symbol] for symbol in cleaned_text if symbol in symbol_to_id.keys()]
|
42 |
+
return sequence
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
from openvoice.text.symbols import language_tone_start_map
|
47 |
+
def cleaned_text_to_sequence_vits2(cleaned_text, tones, language, symbols, languages):
|
48 |
+
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
49 |
+
Args:
|
50 |
+
text: string to convert to a sequence
|
51 |
+
Returns:
|
52 |
+
List of integers corresponding to the symbols in the text
|
53 |
+
"""
|
54 |
+
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
55 |
+
language_id_map = {s: i for i, s in enumerate(languages)}
|
56 |
+
phones = [symbol_to_id[symbol] for symbol in cleaned_text]
|
57 |
+
tone_start = language_tone_start_map[language]
|
58 |
+
tones = [i + tone_start for i in tones]
|
59 |
+
lang_id = language_id_map[language]
|
60 |
+
lang_ids = [lang_id for i in phones]
|
61 |
+
return phones, tones, lang_ids
|
62 |
+
|
63 |
+
|
64 |
+
def sequence_to_text(sequence):
|
65 |
+
'''Converts a sequence of IDs back to a string'''
|
66 |
+
result = ''
|
67 |
+
for symbol_id in sequence:
|
68 |
+
s = _id_to_symbol[symbol_id]
|
69 |
+
result += s
|
70 |
+
return result
|
71 |
+
|
72 |
+
|
73 |
+
def _clean_text(text, cleaner_names):
|
74 |
+
for name in cleaner_names:
|
75 |
+
cleaner = getattr(cleaners, name)
|
76 |
+
if not cleaner:
|
77 |
+
raise Exception('Unknown cleaner: %s' % name)
|
78 |
+
text = cleaner(text)
|
79 |
+
return text
|
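text_to_sequence first runs the named cleaners over the raw text and then maps each remaining character to its symbol ID, silently dropping characters that are not in the symbol table. A small sketch:

```python
from openvoice.text import text_to_sequence, sequence_to_text
from openvoice.text.symbols import symbols

# cjke_cleaners2 expects language-tagged spans such as "[EN]...[EN]" or "[ZH]...[ZH]"
seq = text_to_sequence("[EN]Hello world.[EN]", symbols, ["cjke_cleaners2"])
print(seq)                    # list of symbol IDs
print(sequence_to_text(seq))  # reconstructs the cleaned IPA string (out-of-vocabulary characters are gone)
```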
OpenVoice/openvoice/text/cleaners.py
ADDED
@@ -0,0 +1,16 @@
1 | +import re
2 | +from openvoice.text.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2
3 | +from openvoice.text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2
4 | +
5 | +def cjke_cleaners2(text):
6 | +    text = re.sub(r'\[ZH\](.*?)\[ZH\]',
7 | +                  lambda x: chinese_to_ipa(x.group(1))+' ', text)
8 | +    text = re.sub(r'\[JA\](.*?)\[JA\]',
9 | +                  lambda x: japanese_to_ipa2(x.group(1))+' ', text)
10 | +    text = re.sub(r'\[KO\](.*?)\[KO\]',
11 | +                  lambda x: korean_to_ipa(x.group(1))+' ', text)
12 | +    text = re.sub(r'\[EN\](.*?)\[EN\]',
13 | +                  lambda x: english_to_ipa2(x.group(1))+' ', text)
14 | +    text = re.sub(r'\s+$', '', text)
15 | +    text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
16 | +    return text
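cjke_cleaners2 converts language-tagged spans into IPA-style phoneme strings and appends a final period when the text does not already end in punctuation. Note that, as committed, only the `[ZH]` and `[EN]` branches can actually run: `japanese_to_ipa2` and `korean_to_ipa` are referenced but never imported here, so `[JA]` or `[KO]` input would raise a NameError. A quick check:

```python
from openvoice.text.cleaners import cjke_cleaners2

print(cjke_cleaners2("[EN]OpenVoice clones voices instantly[EN]"))  # English -> IPA2 string, '.' appended
print(cjke_cleaners2("[ZH]今天天气真好[ZH]"))                         # Chinese -> IPA string with tone arrows

# cjke_cleaners2("[JA]こんにちは[JA]")  # would raise NameError: japanese_to_ipa2 is not imported in this file
```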
OpenVoice/openvoice/text/english.py
ADDED
@@ -0,0 +1,188 @@
1 |
+
""" from https://github.com/keithito/tacotron """
|
2 |
+
|
3 |
+
'''
|
4 |
+
Cleaners are transformations that run over the input text at both training and eval time.
|
5 |
+
|
6 |
+
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
7 |
+
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
8 |
+
1. "english_cleaners" for English text
|
9 |
+
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
10 |
+
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
11 |
+
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
12 |
+
the symbols in symbols.py to match your data).
|
13 |
+
'''
|
14 |
+
|
15 |
+
|
16 |
+
# Regular expression matching whitespace:
|
17 |
+
|
18 |
+
|
19 |
+
import re
|
20 |
+
import inflect
|
21 |
+
from unidecode import unidecode
|
22 |
+
import eng_to_ipa as ipa
|
23 |
+
_inflect = inflect.engine()
|
24 |
+
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
25 |
+
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
26 |
+
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
|
27 |
+
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
|
28 |
+
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
29 |
+
_number_re = re.compile(r'[0-9]+')
|
30 |
+
|
31 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
32 |
+
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
|
33 |
+
('mrs', 'misess'),
|
34 |
+
('mr', 'mister'),
|
35 |
+
('dr', 'doctor'),
|
36 |
+
('st', 'saint'),
|
37 |
+
('co', 'company'),
|
38 |
+
('jr', 'junior'),
|
39 |
+
('maj', 'major'),
|
40 |
+
('gen', 'general'),
|
41 |
+
('drs', 'doctors'),
|
42 |
+
('rev', 'reverend'),
|
43 |
+
('lt', 'lieutenant'),
|
44 |
+
('hon', 'honorable'),
|
45 |
+
('sgt', 'sergeant'),
|
46 |
+
('capt', 'captain'),
|
47 |
+
('esq', 'esquire'),
|
48 |
+
('ltd', 'limited'),
|
49 |
+
('col', 'colonel'),
|
50 |
+
('ft', 'fort'),
|
51 |
+
]]
|
52 |
+
|
53 |
+
|
54 |
+
# List of (ipa, lazy ipa) pairs:
|
55 |
+
_lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
|
56 |
+
('r', 'ɹ'),
|
57 |
+
('æ', 'e'),
|
58 |
+
('ɑ', 'a'),
|
59 |
+
('ɔ', 'o'),
|
60 |
+
('ð', 'z'),
|
61 |
+
('θ', 's'),
|
62 |
+
('ɛ', 'e'),
|
63 |
+
('ɪ', 'i'),
|
64 |
+
('ʊ', 'u'),
|
65 |
+
('ʒ', 'ʥ'),
|
66 |
+
('ʤ', 'ʥ'),
|
67 |
+
('ˈ', '↓'),
|
68 |
+
]]
|
69 |
+
|
70 |
+
# List of (ipa, lazy ipa2) pairs:
|
71 |
+
_lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
|
72 |
+
('r', 'ɹ'),
|
73 |
+
('ð', 'z'),
|
74 |
+
('θ', 's'),
|
75 |
+
('ʒ', 'ʑ'),
|
76 |
+
('ʤ', 'dʑ'),
|
77 |
+
('ˈ', '↓'),
|
78 |
+
]]
|
79 |
+
|
80 |
+
# List of (ipa, ipa2) pairs
|
81 |
+
_ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
|
82 |
+
('r', 'ɹ'),
|
83 |
+
('ʤ', 'dʒ'),
|
84 |
+
('ʧ', 'tʃ')
|
85 |
+
]]
|
86 |
+
|
87 |
+
|
88 |
+
def expand_abbreviations(text):
|
89 |
+
for regex, replacement in _abbreviations:
|
90 |
+
text = re.sub(regex, replacement, text)
|
91 |
+
return text
|
92 |
+
|
93 |
+
|
94 |
+
def collapse_whitespace(text):
|
95 |
+
return re.sub(r'\s+', ' ', text)
|
96 |
+
|
97 |
+
|
98 |
+
def _remove_commas(m):
|
99 |
+
return m.group(1).replace(',', '')
|
100 |
+
|
101 |
+
|
102 |
+
def _expand_decimal_point(m):
|
103 |
+
return m.group(1).replace('.', ' point ')
|
104 |
+
|
105 |
+
|
106 |
+
def _expand_dollars(m):
|
107 |
+
match = m.group(1)
|
108 |
+
parts = match.split('.')
|
109 |
+
if len(parts) > 2:
|
110 |
+
return match + ' dollars' # Unexpected format
|
111 |
+
dollars = int(parts[0]) if parts[0] else 0
|
112 |
+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
113 |
+
if dollars and cents:
|
114 |
+
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
115 |
+
cent_unit = 'cent' if cents == 1 else 'cents'
|
116 |
+
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
117 |
+
elif dollars:
|
118 |
+
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
119 |
+
return '%s %s' % (dollars, dollar_unit)
|
120 |
+
elif cents:
|
121 |
+
cent_unit = 'cent' if cents == 1 else 'cents'
|
122 |
+
return '%s %s' % (cents, cent_unit)
|
123 |
+
else:
|
124 |
+
return 'zero dollars'
|
125 |
+
|
126 |
+
|
127 |
+
def _expand_ordinal(m):
|
128 |
+
return _inflect.number_to_words(m.group(0))
|
129 |
+
|
130 |
+
|
131 |
+
def _expand_number(m):
|
132 |
+
num = int(m.group(0))
|
133 |
+
if num > 1000 and num < 3000:
|
134 |
+
if num == 2000:
|
135 |
+
return 'two thousand'
|
136 |
+
elif num > 2000 and num < 2010:
|
137 |
+
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
138 |
+
elif num % 100 == 0:
|
139 |
+
return _inflect.number_to_words(num // 100) + ' hundred'
|
140 |
+
else:
|
141 |
+
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
142 |
+
else:
|
143 |
+
return _inflect.number_to_words(num, andword='')
|
144 |
+
|
145 |
+
|
146 |
+
def normalize_numbers(text):
|
147 |
+
text = re.sub(_comma_number_re, _remove_commas, text)
|
148 |
+
text = re.sub(_pounds_re, r'\1 pounds', text)
|
149 |
+
text = re.sub(_dollars_re, _expand_dollars, text)
|
150 |
+
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
151 |
+
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
152 |
+
text = re.sub(_number_re, _expand_number, text)
|
153 |
+
return text
|
154 |
+
|
155 |
+
|
156 |
+
def mark_dark_l(text):
|
157 |
+
return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
|
158 |
+
|
159 |
+
|
160 |
+
def english_to_ipa(text):
|
161 |
+
text = unidecode(text).lower()
|
162 |
+
text = expand_abbreviations(text)
|
163 |
+
text = normalize_numbers(text)
|
164 |
+
phonemes = ipa.convert(text)
|
165 |
+
phonemes = collapse_whitespace(phonemes)
|
166 |
+
return phonemes
|
167 |
+
|
168 |
+
|
169 |
+
def english_to_lazy_ipa(text):
|
170 |
+
text = english_to_ipa(text)
|
171 |
+
for regex, replacement in _lazy_ipa:
|
172 |
+
text = re.sub(regex, replacement, text)
|
173 |
+
return text
|
174 |
+
|
175 |
+
|
176 |
+
def english_to_ipa2(text):
|
177 |
+
text = english_to_ipa(text)
|
178 |
+
text = mark_dark_l(text)
|
179 |
+
for regex, replacement in _ipa_to_ipa2:
|
180 |
+
text = re.sub(regex, replacement, text)
|
181 |
+
return text.replace('...', '…')
|
182 |
+
|
183 |
+
|
184 |
+
def english_to_lazy_ipa2(text):
|
185 |
+
text = english_to_ipa(text)
|
186 |
+
for regex, replacement in _lazy_ipa2:
|
187 |
+
text = re.sub(regex, replacement, text)
|
188 |
+
return text
|
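The English front end lowercases and transliterates the text with unidecode, expands abbreviations, normalizes numbers and currency, converts to IPA with eng_to_ipa, and finally applies the "lazy"/IPA2 substitution tables. The number normalization alone is easy to sanity-check:

```python
from openvoice.text.english import normalize_numbers, english_to_ipa2

print(normalize_numbers("The $2.50 fee applies to 1,000 users on May 3rd"))
# -> roughly "The 2 dollars, 50 cents fee applies to one thousand users on May third"

print(english_to_ipa2("Hello world"))  # IPA string with dark-l marking and ʤ -> dʒ style substitutions
```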
OpenVoice/openvoice/text/mandarin.py
ADDED
@@ -0,0 +1,326 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import re
|
4 |
+
from pypinyin import lazy_pinyin, BOPOMOFO
|
5 |
+
import jieba
|
6 |
+
import cn2an
|
7 |
+
import logging
|
8 |
+
|
9 |
+
|
10 |
+
# List of (Latin alphabet, bopomofo) pairs:
|
11 |
+
_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
12 |
+
('a', 'ㄟˉ'),
|
13 |
+
('b', 'ㄅㄧˋ'),
|
14 |
+
('c', 'ㄙㄧˉ'),
|
15 |
+
('d', 'ㄉㄧˋ'),
|
16 |
+
('e', 'ㄧˋ'),
|
17 |
+
('f', 'ㄝˊㄈㄨˋ'),
|
18 |
+
('g', 'ㄐㄧˋ'),
|
19 |
+
('h', 'ㄝˇㄑㄩˋ'),
|
20 |
+
('i', 'ㄞˋ'),
|
21 |
+
('j', 'ㄐㄟˋ'),
|
22 |
+
('k', 'ㄎㄟˋ'),
|
23 |
+
('l', 'ㄝˊㄛˋ'),
|
24 |
+
('m', 'ㄝˊㄇㄨˋ'),
|
25 |
+
('n', 'ㄣˉ'),
|
26 |
+
('o', 'ㄡˉ'),
|
27 |
+
('p', 'ㄆㄧˉ'),
|
28 |
+
('q', 'ㄎㄧㄡˉ'),
|
29 |
+
('r', 'ㄚˋ'),
|
30 |
+
('s', 'ㄝˊㄙˋ'),
|
31 |
+
('t', 'ㄊㄧˋ'),
|
32 |
+
('u', 'ㄧㄡˉ'),
|
33 |
+
('v', 'ㄨㄧˉ'),
|
34 |
+
('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
|
35 |
+
('x', 'ㄝˉㄎㄨˋㄙˋ'),
|
36 |
+
('y', 'ㄨㄞˋ'),
|
37 |
+
('z', 'ㄗㄟˋ')
|
38 |
+
]]
|
39 |
+
|
40 |
+
# List of (bopomofo, romaji) pairs:
|
41 |
+
_bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [
|
42 |
+
('ㄅㄛ', 'p⁼wo'),
|
43 |
+
('ㄆㄛ', 'pʰwo'),
|
44 |
+
('ㄇㄛ', 'mwo'),
|
45 |
+
('ㄈㄛ', 'fwo'),
|
46 |
+
('ㄅ', 'p⁼'),
|
47 |
+
('ㄆ', 'pʰ'),
|
48 |
+
('ㄇ', 'm'),
|
49 |
+
('ㄈ', 'f'),
|
50 |
+
('ㄉ', 't⁼'),
|
51 |
+
('ㄊ', 'tʰ'),
|
52 |
+
('ㄋ', 'n'),
|
53 |
+
('ㄌ', 'l'),
|
54 |
+
('ㄍ', 'k⁼'),
|
55 |
+
('ㄎ', 'kʰ'),
|
56 |
+
('ㄏ', 'h'),
|
57 |
+
('ㄐ', 'ʧ⁼'),
|
58 |
+
('ㄑ', 'ʧʰ'),
|
59 |
+
('ㄒ', 'ʃ'),
|
60 |
+
('ㄓ', 'ʦ`⁼'),
|
61 |
+
('ㄔ', 'ʦ`ʰ'),
|
62 |
+
('ㄕ', 's`'),
|
63 |
+
('ㄖ', 'ɹ`'),
|
64 |
+
('ㄗ', 'ʦ⁼'),
|
65 |
+
('ㄘ', 'ʦʰ'),
|
66 |
+
('ㄙ', 's'),
|
67 |
+
('ㄚ', 'a'),
|
68 |
+
('ㄛ', 'o'),
|
69 |
+
('ㄜ', 'ə'),
|
70 |
+
('ㄝ', 'e'),
|
71 |
+
('ㄞ', 'ai'),
|
72 |
+
('ㄟ', 'ei'),
|
73 |
+
('ㄠ', 'au'),
|
74 |
+
('ㄡ', 'ou'),
|
75 |
+
('ㄧㄢ', 'yeNN'),
|
76 |
+
('ㄢ', 'aNN'),
|
77 |
+
('ㄧㄣ', 'iNN'),
|
78 |
+
('ㄣ', 'əNN'),
|
79 |
+
('ㄤ', 'aNg'),
|
80 |
+
('ㄧㄥ', 'iNg'),
|
81 |
+
('ㄨㄥ', 'uNg'),
|
82 |
+
('ㄩㄥ', 'yuNg'),
|
83 |
+
('ㄥ', 'əNg'),
|
84 |
+
('ㄦ', 'əɻ'),
|
85 |
+
('ㄧ', 'i'),
|
86 |
+
('ㄨ', 'u'),
|
87 |
+
('ㄩ', 'ɥ'),
|
88 |
+
('ˉ', '→'),
|
89 |
+
('ˊ', '↑'),
|
90 |
+
('ˇ', '↓↑'),
|
91 |
+
('ˋ', '↓'),
|
92 |
+
('˙', ''),
|
93 |
+
(',', ','),
|
94 |
+
('。', '.'),
|
95 |
+
('!', '!'),
|
96 |
+
('?', '?'),
|
97 |
+
('—', '-')
|
98 |
+
]]
|
99 |
+
|
100 |
+
# List of (romaji, ipa) pairs:
|
101 |
+
_romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
102 |
+
('ʃy', 'ʃ'),
|
103 |
+
('ʧʰy', 'ʧʰ'),
|
104 |
+
('ʧ⁼y', 'ʧ⁼'),
|
105 |
+
('NN', 'n'),
|
106 |
+
('Ng', 'ŋ'),
|
107 |
+
('y', 'j'),
|
108 |
+
('h', 'x')
|
109 |
+
]]
|
110 |
+
|
111 |
+
# List of (bopomofo, ipa) pairs:
|
112 |
+
_bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
|
113 |
+
('ㄅㄛ', 'p⁼wo'),
|
114 |
+
('ㄆㄛ', 'pʰwo'),
|
115 |
+
('ㄇㄛ', 'mwo'),
|
116 |
+
('ㄈㄛ', 'fwo'),
|
117 |
+
('ㄅ', 'p⁼'),
|
118 |
+
('ㄆ', 'pʰ'),
|
119 |
+
('ㄇ', 'm'),
|
120 |
+
('ㄈ', 'f'),
|
121 |
+
('ㄉ', 't⁼'),
|
122 |
+
('ㄊ', 'tʰ'),
|
123 |
+
('ㄋ', 'n'),
|
124 |
+
('ㄌ', 'l'),
|
125 |
+
('ㄍ', 'k⁼'),
|
126 |
+
('ㄎ', 'kʰ'),
|
127 |
+
('ㄏ', 'x'),
|
128 |
+
('ㄐ', 'tʃ⁼'),
|
129 |
+
('ㄑ', 'tʃʰ'),
|
130 |
+
('ㄒ', 'ʃ'),
|
131 |
+
('ㄓ', 'ts`⁼'),
|
132 |
+
('ㄔ', 'ts`ʰ'),
|
133 |
+
('ㄕ', 's`'),
|
134 |
+
('ㄖ', 'ɹ`'),
|
135 |
+
('ㄗ', 'ts⁼'),
|
136 |
+
('ㄘ', 'tsʰ'),
|
137 |
+
('ㄙ', 's'),
|
138 |
+
('ㄚ', 'a'),
|
139 |
+
('ㄛ', 'o'),
|
140 |
+
('ㄜ', 'ə'),
|
141 |
+
('ㄝ', 'ɛ'),
|
142 |
+
('ㄞ', 'aɪ'),
|
143 |
+
('ㄟ', 'eɪ'),
|
144 |
+
('ㄠ', 'ɑʊ'),
|
145 |
+
('ㄡ', 'oʊ'),
|
146 |
+
('ㄧㄢ', 'jɛn'),
|
147 |
+
('ㄩㄢ', 'ɥæn'),
|
148 |
+
('ㄢ', 'an'),
|
149 |
+
('ㄧㄣ', 'in'),
|
150 |
+
('ㄩㄣ', 'ɥn'),
|
151 |
+
('ㄣ', 'ən'),
|
152 |
+
('ㄤ', 'ɑŋ'),
|
153 |
+
('ㄧㄥ', 'iŋ'),
|
154 |
+
('ㄨㄥ', 'ʊŋ'),
|
155 |
+
('ㄩㄥ', 'jʊŋ'),
|
156 |
+
('ㄥ', 'əŋ'),
|
157 |
+
('ㄦ', 'əɻ'),
|
158 |
+
('ㄧ', 'i'),
|
159 |
+
('ㄨ', 'u'),
|
160 |
+
('ㄩ', 'ɥ'),
|
161 |
+
('ˉ', '→'),
|
162 |
+
('ˊ', '↑'),
|
163 |
+
('ˇ', '↓↑'),
|
164 |
+
('ˋ', '↓'),
|
165 |
+
('˙', ''),
|
166 |
+
(',', ','),
|
167 |
+
('。', '.'),
|
168 |
+
('!', '!'),
|
169 |
+
('?', '?'),
|
170 |
+
('—', '-')
|
171 |
+
]]
|
172 |
+
|
173 |
+
# List of (bopomofo, ipa2) pairs:
|
174 |
+
_bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
|
175 |
+
('ㄅㄛ', 'pwo'),
|
176 |
+
('ㄆㄛ', 'pʰwo'),
|
177 |
+
('ㄇㄛ', 'mwo'),
|
178 |
+
('ㄈㄛ', 'fwo'),
|
179 |
+
('ㄅ', 'p'),
|
180 |
+
('ㄆ', 'pʰ'),
|
181 |
+
('ㄇ', 'm'),
|
182 |
+
('ㄈ', 'f'),
|
183 |
+
('ㄉ', 't'),
|
184 |
+
('ㄊ', 'tʰ'),
|
185 |
+
('ㄋ', 'n'),
|
186 |
+
('ㄌ', 'l'),
|
187 |
+
('ㄍ', 'k'),
|
188 |
+
('ㄎ', 'kʰ'),
|
189 |
+
('ㄏ', 'h'),
|
190 |
+
('ㄐ', 'tɕ'),
|
191 |
+
('ㄑ', 'tɕʰ'),
|
192 |
+
('ㄒ', 'ɕ'),
|
193 |
+
('ㄓ', 'tʂ'),
|
194 |
+
('ㄔ', 'tʂʰ'),
|
195 |
+
('ㄕ', 'ʂ'),
|
196 |
+
('ㄖ', 'ɻ'),
|
197 |
+
('ㄗ', 'ts'),
|
198 |
+
('ㄘ', 'tsʰ'),
|
199 |
+
('ㄙ', 's'),
|
200 |
+
('ㄚ', 'a'),
|
201 |
+
('ㄛ', 'o'),
|
202 |
+
('ㄜ', 'ɤ'),
|
203 |
+
('ㄝ', 'ɛ'),
|
204 |
+
('ㄞ', 'aɪ'),
|
205 |
+
('ㄟ', 'eɪ'),
|
206 |
+
('ㄠ', 'ɑʊ'),
|
207 |
+
('ㄡ', 'oʊ'),
|
208 |
+
('ㄧㄢ', 'jɛn'),
|
209 |
+
('ㄩㄢ', 'yæn'),
|
210 |
+
('ㄢ', 'an'),
|
211 |
+
('ㄧㄣ', 'in'),
|
212 |
+
('ㄩㄣ', 'yn'),
|
213 |
+
('ㄣ', 'ən'),
|
214 |
+
('ㄤ', 'ɑŋ'),
|
215 |
+
('ㄧㄥ', 'iŋ'),
|
216 |
+
('ㄨㄥ', 'ʊŋ'),
|
217 |
+
('ㄩㄥ', 'jʊŋ'),
|
218 |
+
('ㄥ', 'ɤŋ'),
|
219 |
+
('ㄦ', 'əɻ'),
|
220 |
+
('ㄧ', 'i'),
|
221 |
+
('ㄨ', 'u'),
|
222 |
+
('ㄩ', 'y'),
|
223 |
+
('ˉ', '˥'),
|
224 |
+
('ˊ', '˧˥'),
|
225 |
+
('ˇ', '˨˩˦'),
|
226 |
+
('ˋ', '˥˩'),
|
227 |
+
('˙', ''),
|
228 |
+
(',', ','),
|
229 |
+
('。', '.'),
|
230 |
+
('!', '!'),
|
231 |
+
('?', '?'),
|
232 |
+
('—', '-')
|
233 |
+
]]
|
234 |
+
|
235 |
+
|
236 |
+
def number_to_chinese(text):
|
237 |
+
numbers = re.findall(r'\d+(?:\.?\d+)?', text)
|
238 |
+
for number in numbers:
|
239 |
+
text = text.replace(number, cn2an.an2cn(number), 1)
|
240 |
+
return text
|
241 |
+
|
242 |
+
|
243 |
+
def chinese_to_bopomofo(text):
|
244 |
+
text = text.replace('、', ',').replace(';', ',').replace(':', ',')
|
245 |
+
words = jieba.lcut(text, cut_all=False)
|
246 |
+
text = ''
|
247 |
+
for word in words:
|
248 |
+
bopomofos = lazy_pinyin(word, BOPOMOFO)
|
249 |
+
if not re.search('[\u4e00-\u9fff]', word):
|
250 |
+
text += word
|
251 |
+
continue
|
252 |
+
for i in range(len(bopomofos)):
|
253 |
+
bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i])
|
254 |
+
if text != '':
|
255 |
+
text += ' '
|
256 |
+
text += ''.join(bopomofos)
|
257 |
+
return text
|
258 |
+
|
259 |
+
|
260 |
+
def latin_to_bopomofo(text):
|
261 |
+
for regex, replacement in _latin_to_bopomofo:
|
262 |
+
text = re.sub(regex, replacement, text)
|
263 |
+
return text
|
264 |
+
|
265 |
+
|
266 |
+
def bopomofo_to_romaji(text):
|
267 |
+
for regex, replacement in _bopomofo_to_romaji:
|
268 |
+
text = re.sub(regex, replacement, text)
|
269 |
+
return text
|
270 |
+
|
271 |
+
|
272 |
+
def bopomofo_to_ipa(text):
|
273 |
+
for regex, replacement in _bopomofo_to_ipa:
|
274 |
+
text = re.sub(regex, replacement, text)
|
275 |
+
return text
|
276 |
+
|
277 |
+
|
278 |
+
def bopomofo_to_ipa2(text):
|
279 |
+
for regex, replacement in _bopomofo_to_ipa2:
|
280 |
+
text = re.sub(regex, replacement, text)
|
281 |
+
return text
|
282 |
+
|
283 |
+
|
284 |
+
def chinese_to_romaji(text):
|
285 |
+
text = number_to_chinese(text)
|
286 |
+
text = chinese_to_bopomofo(text)
|
287 |
+
text = latin_to_bopomofo(text)
|
288 |
+
text = bopomofo_to_romaji(text)
|
289 |
+
text = re.sub('i([aoe])', r'y\1', text)
|
290 |
+
text = re.sub('u([aoəe])', r'w\1', text)
|
291 |
+
text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
|
292 |
+
r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
|
293 |
+
text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
|
294 |
+
return text
|
295 |
+
|
296 |
+
|
297 |
+
def chinese_to_lazy_ipa(text):
|
298 |
+
text = chinese_to_romaji(text)
|
299 |
+
for regex, replacement in _romaji_to_ipa:
|
300 |
+
text = re.sub(regex, replacement, text)
|
301 |
+
return text
|
302 |
+
|
303 |
+
|
304 |
+
def chinese_to_ipa(text):
|
305 |
+
text = number_to_chinese(text)
|
306 |
+
text = chinese_to_bopomofo(text)
|
307 |
+
text = latin_to_bopomofo(text)
|
308 |
+
text = bopomofo_to_ipa(text)
|
309 |
+
text = re.sub('i([aoe])', r'j\1', text)
|
310 |
+
text = re.sub('u([aoəe])', r'w\1', text)
|
311 |
+
text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
|
312 |
+
r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
|
313 |
+
text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
|
314 |
+
return text
|
315 |
+
|
316 |
+
|
317 |
+
def chinese_to_ipa2(text):
|
318 |
+
text = number_to_chinese(text)
|
319 |
+
text = chinese_to_bopomofo(text)
|
320 |
+
text = latin_to_bopomofo(text)
|
321 |
+
text = bopomofo_to_ipa2(text)
|
322 |
+
text = re.sub(r'i([aoe])', r'j\1', text)
|
323 |
+
text = re.sub(r'u([aoəe])', r'w\1', text)
|
324 |
+
text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text)
|
325 |
+
text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text)
|
326 |
+
return text
|
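The Mandarin front end chains several stages: Arabic numerals are rewritten as Chinese numerals with cn2an, jieba segments the text and pypinyin produces bopomofo, stray Latin letters get bopomofo spellings, and the bopomofo is finally mapped to an IPA-like representation with tone arrows (chinese_to_ipa) or tone letters (chinese_to_ipa2). A short sketch of the intermediate stages:

```python
from openvoice.text.mandarin import number_to_chinese, chinese_to_bopomofo, chinese_to_ipa

print(number_to_chinese("今天是2024年"))  # digits rewritten via cn2an, e.g. 二千零二十四
print(chinese_to_bopomofo("天气真好"))    # bopomofo with explicit first-tone marks (ˉ)
print(chinese_to_ipa("今天天气真好"))      # IPA-like string with tone arrows (→ ↑ ↓↑ ↓)
```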
OpenVoice/openvoice/text/symbols.py
ADDED
@@ -0,0 +1,88 @@
1 |
+
'''
|
2 |
+
Defines the set of symbols used in text input to the model.
|
3 |
+
'''
|
4 |
+
|
5 |
+
# japanese_cleaners
|
6 |
+
# _pad = '_'
|
7 |
+
# _punctuation = ',.!?-'
|
8 |
+
# _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
|
9 |
+
|
10 |
+
|
11 |
+
'''# japanese_cleaners2
|
12 |
+
_pad = '_'
|
13 |
+
_punctuation = ',.!?-~…'
|
14 |
+
_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
|
15 |
+
'''
|
16 |
+
|
17 |
+
|
18 |
+
'''# korean_cleaners
|
19 |
+
_pad = '_'
|
20 |
+
_punctuation = ',.!?…~'
|
21 |
+
_letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
|
22 |
+
'''
|
23 |
+
|
24 |
+
'''# chinese_cleaners
|
25 |
+
_pad = '_'
|
26 |
+
_punctuation = ',。!?—…'
|
27 |
+
_letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
|
28 |
+
'''
|
29 |
+
|
30 |
+
# # zh_ja_mixture_cleaners
|
31 |
+
# _pad = '_'
|
32 |
+
# _punctuation = ',.!?-~…'
|
33 |
+
# _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
|
34 |
+
|
35 |
+
|
36 |
+
'''# sanskrit_cleaners
|
37 |
+
_pad = '_'
|
38 |
+
_punctuation = '।'
|
39 |
+
_letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
|
40 |
+
'''
|
41 |
+
|
42 |
+
'''# cjks_cleaners
|
43 |
+
_pad = '_'
|
44 |
+
_punctuation = ',.!?-~…'
|
45 |
+
_letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
|
46 |
+
'''
|
47 |
+
|
48 |
+
'''# thai_cleaners
|
49 |
+
_pad = '_'
|
50 |
+
_punctuation = '.!? '
|
51 |
+
_letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
|
52 |
+
'''
|
53 |
+
|
54 |
+
# # cjke_cleaners2
|
55 |
+
_pad = '_'
|
56 |
+
_punctuation = ',.!?-~…'
|
57 |
+
_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
|
58 |
+
|
59 |
+
|
60 |
+
'''# shanghainese_cleaners
|
61 |
+
_pad = '_'
|
62 |
+
_punctuation = ',.!?…'
|
63 |
+
_letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
|
64 |
+
'''
|
65 |
+
|
66 |
+
'''# chinese_dialect_cleaners
|
67 |
+
_pad = '_'
|
68 |
+
_punctuation = ',.!?~…─'
|
69 |
+
_letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
|
70 |
+
'''
|
71 |
+
|
72 |
+
# Export all symbols:
|
73 |
+
symbols = [_pad] + list(_punctuation) + list(_letters)
|
74 |
+
|
75 |
+
# Special symbol ids
|
76 |
+
SPACE_ID = symbols.index(" ")
|
77 |
+
|
78 |
+
num_ja_tones = 1
|
79 |
+
num_kr_tones = 1
|
80 |
+
num_zh_tones = 6
|
81 |
+
num_en_tones = 4
|
82 |
+
|
83 |
+
language_tone_start_map = {
|
84 |
+
"ZH": 0,
|
85 |
+
"JP": num_zh_tones,
|
86 |
+
"EN": num_zh_tones + num_ja_tones,
|
87 |
+
'KR': num_zh_tones + num_ja_tones + num_en_tones,
|
88 |
+
}
|
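symbols.py exports the cjke_cleaners2 alphabet (pad, punctuation, letters) plus a tone-offset map that packs each language's tone IDs into one shared range for the VITS2-style path. The offsets follow directly from the per-language tone counts:

```python
from openvoice.text.symbols import (
    symbols, language_tone_start_map, num_zh_tones, num_ja_tones
)

print(len(symbols))  # 1 pad + punctuation + letters of the cjke_cleaners2 alphabet
assert language_tone_start_map["EN"] == num_zh_tones + num_ja_tones  # 6 + 1 = 7
# English tone IDs therefore live in [7, 7 + num_en_tones) of the shared tone space
```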
OpenVoice/openvoice/transforms.py
ADDED
@@ -0,0 +1,209 @@
1 |
+
import torch
|
2 |
+
from torch.nn import functional as F
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
DEFAULT_MIN_BIN_WIDTH = 1e-3
|
8 |
+
DEFAULT_MIN_BIN_HEIGHT = 1e-3
|
9 |
+
DEFAULT_MIN_DERIVATIVE = 1e-3
|
10 |
+
|
11 |
+
|
12 |
+
def piecewise_rational_quadratic_transform(
|
13 |
+
inputs,
|
14 |
+
unnormalized_widths,
|
15 |
+
unnormalized_heights,
|
16 |
+
unnormalized_derivatives,
|
17 |
+
inverse=False,
|
18 |
+
tails=None,
|
19 |
+
tail_bound=1.0,
|
20 |
+
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
21 |
+
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
22 |
+
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
23 |
+
):
|
24 |
+
if tails is None:
|
25 |
+
spline_fn = rational_quadratic_spline
|
26 |
+
spline_kwargs = {}
|
27 |
+
else:
|
28 |
+
spline_fn = unconstrained_rational_quadratic_spline
|
29 |
+
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
|
30 |
+
|
31 |
+
outputs, logabsdet = spline_fn(
|
32 |
+
inputs=inputs,
|
33 |
+
unnormalized_widths=unnormalized_widths,
|
34 |
+
unnormalized_heights=unnormalized_heights,
|
35 |
+
unnormalized_derivatives=unnormalized_derivatives,
|
36 |
+
inverse=inverse,
|
37 |
+
min_bin_width=min_bin_width,
|
38 |
+
min_bin_height=min_bin_height,
|
39 |
+
min_derivative=min_derivative,
|
40 |
+
**spline_kwargs
|
41 |
+
)
|
42 |
+
return outputs, logabsdet
|
43 |
+
|
44 |
+
|
45 |
+
def searchsorted(bin_locations, inputs, eps=1e-6):
|
46 |
+
bin_locations[..., -1] += eps
|
47 |
+
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
|
48 |
+
|
49 |
+
|
50 |
+
def unconstrained_rational_quadratic_spline(
|
51 |
+
inputs,
|
52 |
+
unnormalized_widths,
|
53 |
+
unnormalized_heights,
|
54 |
+
unnormalized_derivatives,
|
55 |
+
inverse=False,
|
56 |
+
tails="linear",
|
57 |
+
tail_bound=1.0,
|
58 |
+
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
59 |
+
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
60 |
+
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
61 |
+
):
|
62 |
+
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
|
63 |
+
outside_interval_mask = ~inside_interval_mask
|
64 |
+
|
65 |
+
outputs = torch.zeros_like(inputs)
|
66 |
+
logabsdet = torch.zeros_like(inputs)
|
67 |
+
|
68 |
+
if tails == "linear":
|
69 |
+
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
|
70 |
+
constant = np.log(np.exp(1 - min_derivative) - 1)
|
71 |
+
unnormalized_derivatives[..., 0] = constant
|
72 |
+
unnormalized_derivatives[..., -1] = constant
|
73 |
+
|
74 |
+
outputs[outside_interval_mask] = inputs[outside_interval_mask]
|
75 |
+
logabsdet[outside_interval_mask] = 0
|
76 |
+
else:
|
77 |
+
raise RuntimeError("{} tails are not implemented.".format(tails))
|
78 |
+
|
79 |
+
(
|
80 |
+
outputs[inside_interval_mask],
|
81 |
+
logabsdet[inside_interval_mask],
|
82 |
+
) = rational_quadratic_spline(
|
83 |
+
inputs=inputs[inside_interval_mask],
|
84 |
+
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
|
85 |
+
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
|
86 |
+
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
|
87 |
+
inverse=inverse,
|
88 |
+
left=-tail_bound,
|
89 |
+
right=tail_bound,
|
90 |
+
bottom=-tail_bound,
|
91 |
+
top=tail_bound,
|
92 |
+
min_bin_width=min_bin_width,
|
93 |
+
min_bin_height=min_bin_height,
|
94 |
+
min_derivative=min_derivative,
|
95 |
+
)
|
96 |
+
|
97 |
+
return outputs, logabsdet
|
98 |
+
|
99 |
+
|
100 |
+
def rational_quadratic_spline(
|
101 |
+
inputs,
|
102 |
+
unnormalized_widths,
|
103 |
+
unnormalized_heights,
|
104 |
+
unnormalized_derivatives,
|
105 |
+
inverse=False,
|
106 |
+
left=0.0,
|
107 |
+
right=1.0,
|
108 |
+
bottom=0.0,
|
109 |
+
top=1.0,
|
110 |
+
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
111 |
+
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
112 |
+
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
113 |
+
):
|
114 |
+
if torch.min(inputs) < left or torch.max(inputs) > right:
|
115 |
+
raise ValueError("Input to a transform is not within its domain")
|
116 |
+
|
117 |
+
num_bins = unnormalized_widths.shape[-1]
|
118 |
+
|
119 |
+
if min_bin_width * num_bins > 1.0:
|
120 |
+
raise ValueError("Minimal bin width too large for the number of bins")
|
121 |
+
if min_bin_height * num_bins > 1.0:
|
122 |
+
raise ValueError("Minimal bin height too large for the number of bins")
|
123 |
+
|
124 |
+
widths = F.softmax(unnormalized_widths, dim=-1)
|
125 |
+
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
|
126 |
+
cumwidths = torch.cumsum(widths, dim=-1)
|
127 |
+
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
|
128 |
+
cumwidths = (right - left) * cumwidths + left
|
129 |
+
cumwidths[..., 0] = left
|
130 |
+
cumwidths[..., -1] = right
|
131 |
+
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
|
132 |
+
|
133 |
+
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
|
134 |
+
|
135 |
+
heights = F.softmax(unnormalized_heights, dim=-1)
|
136 |
+
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
|
137 |
+
cumheights = torch.cumsum(heights, dim=-1)
|
138 |
+
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
|
139 |
+
cumheights = (top - bottom) * cumheights + bottom
|
140 |
+
cumheights[..., 0] = bottom
|
141 |
+
cumheights[..., -1] = top
|
142 |
+
heights = cumheights[..., 1:] - cumheights[..., :-1]
|
143 |
+
|
144 |
+
if inverse:
|
145 |
+
bin_idx = searchsorted(cumheights, inputs)[..., None]
|
146 |
+
else:
|
147 |
+
bin_idx = searchsorted(cumwidths, inputs)[..., None]
|
148 |
+
|
149 |
+
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
|
150 |
+
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
|
151 |
+
|
152 |
+
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
|
153 |
+
delta = heights / widths
|
154 |
+
input_delta = delta.gather(-1, bin_idx)[..., 0]
|
155 |
+
|
156 |
+
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
|
157 |
+
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
|
158 |
+
|
159 |
+
input_heights = heights.gather(-1, bin_idx)[..., 0]
|
160 |
+
|
161 |
+
if inverse:
|
162 |
+
a = (inputs - input_cumheights) * (
|
163 |
+
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
164 |
+
) + input_heights * (input_delta - input_derivatives)
|
165 |
+
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
|
166 |
+
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
167 |
+
)
|
168 |
+
c = -input_delta * (inputs - input_cumheights)
|
169 |
+
|
170 |
+
discriminant = b.pow(2) - 4 * a * c
|
171 |
+
assert (discriminant >= 0).all()
|
172 |
+
|
173 |
+
root = (2 * c) / (-b - torch.sqrt(discriminant))
|
174 |
+
outputs = root * input_bin_widths + input_cumwidths
|
175 |
+
|
176 |
+
theta_one_minus_theta = root * (1 - root)
|
177 |
+
denominator = input_delta + (
|
178 |
+
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
179 |
+
* theta_one_minus_theta
|
180 |
+
)
|
181 |
+
derivative_numerator = input_delta.pow(2) * (
|
182 |
+
input_derivatives_plus_one * root.pow(2)
|
183 |
+
+ 2 * input_delta * theta_one_minus_theta
|
184 |
+
+ input_derivatives * (1 - root).pow(2)
|
185 |
+
)
|
186 |
+
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
187 |
+
|
188 |
+
return outputs, -logabsdet
|
189 |
+
else:
|
190 |
+
theta = (inputs - input_cumwidths) / input_bin_widths
|
191 |
+
theta_one_minus_theta = theta * (1 - theta)
|
192 |
+
|
193 |
+
numerator = input_heights * (
|
194 |
+
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
|
195 |
+
)
|
196 |
+
denominator = input_delta + (
|
197 |
+
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
198 |
+
* theta_one_minus_theta
|
199 |
+
)
|
200 |
+
outputs = input_cumheights + numerator / denominator
|
201 |
+
|
202 |
+
derivative_numerator = input_delta.pow(2) * (
|
203 |
+
input_derivatives_plus_one * theta.pow(2)
|
204 |
+
+ 2 * input_delta * theta_one_minus_theta
|
205 |
+
+ input_derivatives * (1 - theta).pow(2)
|
206 |
+
)
|
207 |
+
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
208 |
+
|
209 |
+
return outputs, logabsdet
|
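piecewise_rational_quadratic_transform is the invertible monotonic spline used by the flow/coupling layers: with tails="linear" anything outside ±tail_bound passes through unchanged, and the inverse call undoes the forward call with log-determinants that cancel. A minimal sketch with random parameters (the shapes are illustrative; the derivative tensor has num_bins - 1 entries because it is padded internally):

```python
import torch
from openvoice.transforms import piecewise_rational_quadratic_transform

num_bins = 10
x = torch.randn(4, 16)                 # arbitrary inputs; linear tails handle values outside the interval
w = torch.randn(4, 16, num_bins)       # unnormalized bin widths
h = torch.randn(4, 16, num_bins)       # unnormalized bin heights
d = torch.randn(4, 16, num_bins - 1)   # unnormalized derivatives at the interior knots

y, logabsdet = piecewise_rational_quadratic_transform(
    x, w, h, d, inverse=False, tails="linear", tail_bound=5.0)
x_rec, inv_logabsdet = piecewise_rational_quadratic_transform(
    y, w, h, d, inverse=True, tails="linear", tail_bound=5.0)

print(torch.allclose(x, x_rec, atol=1e-4))                   # forward/inverse round trip
print(torch.allclose(logabsdet, -inv_logabsdet, atol=1e-4))  # log-determinants cancel
```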
OpenVoice/openvoice/utils.py
ADDED
@@ -0,0 +1,194 @@
1 |
+
import re
|
2 |
+
import json
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def get_hparams_from_file(config_path):
|
7 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
8 |
+
data = f.read()
|
9 |
+
config = json.loads(data)
|
10 |
+
|
11 |
+
hparams = HParams(**config)
|
12 |
+
return hparams
|
13 |
+
|
14 |
+
class HParams:
|
15 |
+
def __init__(self, **kwargs):
|
16 |
+
for k, v in kwargs.items():
|
17 |
+
if type(v) == dict:
|
18 |
+
v = HParams(**v)
|
19 |
+
self[k] = v
|
20 |
+
|
21 |
+
def keys(self):
|
22 |
+
return self.__dict__.keys()
|
23 |
+
|
24 |
+
def items(self):
|
25 |
+
return self.__dict__.items()
|
26 |
+
|
27 |
+
def values(self):
|
28 |
+
return self.__dict__.values()
|
29 |
+
|
30 |
+
def __len__(self):
|
31 |
+
return len(self.__dict__)
|
32 |
+
|
33 |
+
def __getitem__(self, key):
|
34 |
+
return getattr(self, key)
|
35 |
+
|
36 |
+
def __setitem__(self, key, value):
|
37 |
+
return setattr(self, key, value)
|
38 |
+
|
39 |
+
def __contains__(self, key):
|
40 |
+
return key in self.__dict__
|
41 |
+
|
42 |
+
def __repr__(self):
|
43 |
+
return self.__dict__.__repr__()
|
44 |
+
|
45 |
+
|
46 |
+
def string_to_bits(string, pad_len=8):
|
47 |
+
# Convert each character to its ASCII value
|
48 |
+
ascii_values = [ord(char) for char in string]
|
49 |
+
|
50 |
+
# Convert ASCII values to binary representation
|
51 |
+
binary_values = [bin(value)[2:].zfill(8) for value in ascii_values]
|
52 |
+
|
53 |
+
# Convert binary strings to integer arrays
|
54 |
+
bit_arrays = [[int(bit) for bit in binary] for binary in binary_values]
|
55 |
+
|
56 |
+
# Convert list of arrays to NumPy array
|
57 |
+
numpy_array = np.array(bit_arrays)
|
58 |
+
numpy_array_full = np.zeros((pad_len, 8), dtype=numpy_array.dtype)
|
59 |
+
numpy_array_full[:, 2] = 1
|
60 |
+
max_len = min(pad_len, len(numpy_array))
|
61 |
+
numpy_array_full[:max_len] = numpy_array[:max_len]
|
62 |
+
return numpy_array_full
|
63 |
+
|
64 |
+
|
65 |
+
def bits_to_string(bits_array):
|
66 |
+
# Convert each row of the array to a binary string
|
67 |
+
binary_values = [''.join(str(bit) for bit in row) for row in bits_array]
|
68 |
+
|
69 |
+
# Convert binary strings to ASCII values
|
70 |
+
ascii_values = [int(binary, 2) for binary in binary_values]
|
71 |
+
|
72 |
+
# Convert ASCII values to characters
|
73 |
+
output_string = ''.join(chr(value) for value in ascii_values)
|
74 |
+
|
75 |
+
return output_string
|
76 |
+
|
77 |
+
|
78 |
+
def split_sentence(text, min_len=10, language_str='[EN]'):
|
79 |
+
if language_str in ['EN']:
|
80 |
+
sentences = split_sentences_latin(text, min_len=min_len)
|
81 |
+
else:
|
82 |
+
sentences = split_sentences_zh(text, min_len=min_len)
|
83 |
+
return sentences
|
84 |
+
|
85 |
+
def split_sentences_latin(text, min_len=10):
|
86 |
+
"""Split Long sentences into list of short ones
|
87 |
+
|
88 |
+
Args:
|
89 |
+
str: Input sentences.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
List[str]: list of output sentences.
|
93 |
+
"""
|
94 |
+
# deal with dirty sentences
|
95 |
+
text = re.sub('[。!?;]', '.', text)
|
96 |
+
text = re.sub('[,]', ',', text)
|
97 |
+
text = re.sub('[“”]', '"', text)
|
98 |
+
text = re.sub('[‘’]', "'", text)
|
99 |
+
text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
|
100 |
+
text = re.sub('[\n\t ]+', ' ', text)
|
101 |
+
text = re.sub('([,.!?;])', r'\1 $#!', text)
|
102 |
+
# split
|
103 |
+
sentences = [s.strip() for s in text.split('$#!')]
|
104 |
+
if len(sentences[-1]) == 0: del sentences[-1]
|
105 |
+
|
106 |
+
new_sentences = []
|
107 |
+
new_sent = []
|
108 |
+
count_len = 0
|
109 |
+
for ind, sent in enumerate(sentences):
|
110 |
+
# print(sent)
|
111 |
+
new_sent.append(sent)
|
112 |
+
count_len += len(sent.split(" "))
|
113 |
+
if count_len > min_len or ind == len(sentences) - 1:
|
114 |
+
count_len = 0
|
115 |
+
new_sentences.append(' '.join(new_sent))
|
116 |
+
new_sent = []
|
117 |
+
return merge_short_sentences_latin(new_sentences)
|
118 |
+
|
119 |
+
|
120 |
+
def merge_short_sentences_latin(sens):
|
121 |
+
"""Avoid short sentences by merging them with the following sentence.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
List[str]: list of input sentences.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
List[str]: list of output sentences.
|
128 |
+
"""
|
129 |
+
sens_out = []
|
130 |
+
for s in sens:
|
131 |
+
# If the previous sentence is too short, merge them with
|
132 |
+
# the current sentence.
|
133 |
+
if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
|
134 |
+
sens_out[-1] = sens_out[-1] + " " + s
|
135 |
+
else:
|
136 |
+
sens_out.append(s)
|
137 |
+
try:
|
138 |
+
if len(sens_out[-1].split(" ")) <= 2:
|
139 |
+
sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
|
140 |
+
sens_out.pop(-1)
|
141 |
+
except:
|
142 |
+
pass
|
143 |
+
return sens_out
|
144 |
+
|
145 |
+
def split_sentences_zh(text, min_len=10):
|
146 |
+
text = re.sub('[。!?;]', '.', text)
|
147 |
+
text = re.sub('[,]', ',', text)
|
148 |
+
# 将文本中的换行符、空格和制表符替换为空格
|
149 |
+
text = re.sub('[\n\t ]+', ' ', text)
|
150 |
+
# 在标点符号后添加一个空格
|
151 |
+
text = re.sub('([,.!?;])', r'\1 $#!', text)
|
152 |
+
# 分隔句子并去除前后空格
|
153 |
+
# sentences = [s.strip() for s in re.split('(。|!|?|;)', text)]
|
154 |
+
sentences = [s.strip() for s in text.split('$#!')]
|
155 |
+
if len(sentences[-1]) == 0: del sentences[-1]
|
156 |
+
|
157 |
+
new_sentences = []
|
158 |
+
new_sent = []
|
159 |
+
count_len = 0
|
160 |
+
for ind, sent in enumerate(sentences):
|
161 |
+
new_sent.append(sent)
|
162 |
+
count_len += len(sent)
|
163 |
+
if count_len > min_len or ind == len(sentences) - 1:
|
164 |
+
count_len = 0
|
165 |
+
new_sentences.append(' '.join(new_sent))
|
166 |
+
new_sent = []
|
167 |
+
return merge_short_sentences_zh(new_sentences)
|
168 |
+
|
169 |
+
|
170 |
+
def merge_short_sentences_zh(sens):
|
171 |
+
# return sens
|
172 |
+
"""Avoid short sentences by merging them with the following sentence.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
List[str]: list of input sentences.
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
List[str]: list of output sentences.
|
179 |
+
"""
|
180 |
+
sens_out = []
|
181 |
+
for s in sens:
|
182 |
+
# If the previous sentense is too short, merge them with
|
183 |
+
# the current sentence.
|
184 |
+
if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
|
185 |
+
sens_out[-1] = sens_out[-1] + " " + s
|
186 |
+
else:
|
187 |
+
sens_out.append(s)
|
188 |
+
try:
|
189 |
+
if len(sens_out[-1]) <= 2:
|
190 |
+
sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
|
191 |
+
sens_out.pop(-1)
|
192 |
+
except:
|
193 |
+
pass
|
194 |
+
return sens_out
|
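split_sentence normalizes punctuation, splits at sentence marks, and then greedily regroups the pieces until each chunk has at least min_len words (Latin path) or characters (Chinese path), merging any trailing fragment. Note the check is `language_str in ['EN']`, so the default value '[EN]' does not actually select the Latin path; callers need to pass 'EN' explicitly. For example:

```python
from openvoice.utils import split_sentence

en = ("He hoped there would be stew for dinner, turnips and carrots and bruised potatoes "
      "and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.")
print(split_sentence(en, min_len=10, language_str='EN'))  # chunks of >= 10 words, ending at punctuation

zh = "今天天气真好,我们一起出去吃饭吧。"
print(split_sentence(zh, min_len=10, language_str='ZH'))  # character-count based splitting
```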
OpenVoice/resources/framework-ipa.png
ADDED (binary image, stored via Git LFS)
OpenVoice/resources/huggingface.png
ADDED (binary image, stored via Git LFS)
OpenVoice/resources/lepton-hd.png
ADDED (binary image, stored via Git LFS)
OpenVoice/resources/myshell-hd.png
ADDED (binary image, stored via Git LFS)
OpenVoice/resources/tts-guide.png
ADDED (binary image, stored via Git LFS)
OpenVoice/resources/voice-clone-guide.png
ADDED (binary image, stored via Git LFS)
OpenVoice/setup.py
ADDED
@@ -0,0 +1,45 @@
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
|
4 |
+
setup(name='MyShell-OpenVoice',
|
5 |
+
version='0.0.0',
|
6 |
+
description='Instant voice cloning by MyShell.',
|
7 |
+
long_description=open('README.md').read().strip(),
|
8 |
+
long_description_content_type='text/markdown',
|
9 |
+
keywords=[
|
10 |
+
'text-to-speech',
|
11 |
+
'tts',
|
12 |
+
'voice-clone',
|
13 |
+
'zero-shot-tts'
|
14 |
+
],
|
15 |
+
url='https://github.com/myshell-ai/OpenVoice',
|
16 |
+
project_urls={
|
17 |
+
'Documentation': 'https://github.com/myshell-ai/OpenVoice/blob/main/docs/USAGE.md',
|
18 |
+
'Changes': 'https://github.com/myshell-ai/OpenVoice/releases',
|
19 |
+
'Code': 'https://github.com/myshell-ai/OpenVoice',
|
20 |
+
'Issue tracker': 'https://github.com/myshell-ai/OpenVoice/issues',
|
21 |
+
},
|
22 |
+
author='MyShell',
|
23 |
+
author_email='ethan@myshell.ai',
|
24 |
+
license='MIT License',
|
25 |
+
packages=find_packages(),
|
26 |
+
|
27 |
+
python_requires='>=3.9',
|
28 |
+
install_requires=[
|
29 |
+
'librosa==0.9.1',
|
30 |
+
'faster-whisper==0.9.0',
|
31 |
+
'pydub==0.25.1',
|
32 |
+
'wavmark==0.0.3',
|
33 |
+
'numpy==1.22.0',
|
34 |
+
'eng_to_ipa==0.0.2',
|
35 |
+
'inflect==7.0.0',
|
36 |
+
'unidecode==1.3.7',
|
37 |
+
'whisper-timestamped==1.14.2',
|
38 |
+
'pypinyin==0.50.0',
|
39 |
+
'cn2an==0.5.22',
|
40 |
+
'jieba==0.42.1',
|
41 |
+
'gradio==3.48.0',
|
42 |
+
'langid==1.1.6'
|
43 |
+
],
|
44 |
+
zip_safe=False
|
45 |
+
)
|
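setup.py pins exact dependency versions (numpy 1.22.0, gradio 3.48.0, faster-whisper 0.9.0, ...) and packages the `openvoice` modules found by find_packages(). After an editable install of this directory, the modules added in this commit resolve under that namespace; the class names below are taken from openvoice/api.py and shown only for illustration:

```python
# assumes a prior `pip install -e ./OpenVoice` (editable install of the MyShell-OpenVoice package above)
from openvoice import utils, se_extractor
from openvoice.text import text_to_sequence
from openvoice.api import BaseSpeakerTTS, ToneColorConverter  # assumed names from openvoice/api.py

print(utils.split_sentence("A quick import smoke test. It should split here.", min_len=3, language_str='EN'))
```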
app.py
ADDED
@@ -0,0 +1,548 @@
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import json
|
4 |
+
from tts_api import TTSapi, DEFAULT_TTS_MODEL_NAME
|
5 |
+
from config import *
|
6 |
+
from utils import *
|
7 |
+
from knowledge_base import LocalRAG, CosPlayer
|
8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
9 |
+
|
10 |
+
def handle_retry(history, thinking_history, config, section_state, retry_data: gr.RetryData):
|
11 |
+
# 获取用户之前的消息
|
12 |
+
previous_message = history[retry_data.index]['content']
|
13 |
+
# 清除后续的回复和思考过程
|
14 |
+
new_history = history[:retry_data.index]
|
15 |
+
section_state['chat_history'] = section_state['chat_history'][:retry_data.index + 1]
|
16 |
+
|
17 |
+
try:
|
18 |
+
items = thinking_history.split('\n==================\n')
|
19 |
+
if len(items) > 2:
|
20 |
+
new_thinking_history = '\n==================\n'.join(items[:-2])
|
21 |
+
else:
|
22 |
+
new_thinking_history = ''
|
23 |
+
|
24 |
+
items = section_state['thinking_history'].split('\n==================\n')
|
25 |
+
if len(items) > 2:
|
26 |
+
section_state['thinking_history'] = '\n==================\n'.join(items[:-2])
|
27 |
+
else:
|
28 |
+
section_state['thinking_history'] = ''
|
29 |
+
except Exception as e:
|
30 |
+
print('-----------------------------------')
|
31 |
+
print(e)
|
32 |
+
print('-----------------------------------')
|
33 |
+
print('思考过程发生异常,重置为空')
|
34 |
+
section_state['thinking_history'] = ''
|
35 |
+
new_thinking_history = ''
|
36 |
+
# 重新生成回复
|
37 |
+
return predict(previous_message, new_history, new_thinking_history, config, section_state)
|
38 |
+
|
39 |
+
|
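handle_retry rolls the conversation back to the retried message: it truncates the visible history and the per-session chat_history, drops the last two blocks from the accumulated thinking log (whose turns are joined with a fixed delimiter), and re-runs predict. The truncation logic in isolation:

```python
# Minimal illustration of the thinking-history truncation used in handle_retry above.
DELIM = '\n==================\n'
thinking_history = DELIM.join(['turn-1 thoughts', 'turn-2 thoughts', 'turn-3 thoughts'])

items = thinking_history.split(DELIM)
new_thinking_history = DELIM.join(items[:-2]) if len(items) > 2 else ''
print(new_thinking_history)  # -> 'turn-1 thoughts'
```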
40 |
+
def predict(message, chat_history, thinking_history, config, section_state):
|
41 |
+
global local_rag, TTS_LOADED, LLM_LOADED, synthesiser, core_llm, core_tokenizer
|
42 |
+
print(f"当前模式:{config['mode_selected']}")
|
43 |
+
print(f'角色扮演描述:{config["character_description"]}')
|
44 |
+
print(f"写入角色设定方式:{config['character_setting_mode']}")
|
45 |
+
print(f"选中LLM:{config['llm_model']}")
|
46 |
+
print(f"是否使用RAG本地知识库:{config['kb_on']}")
|
47 |
+
print(f"选中知识库:{config['current_knowledge_base']}")
|
48 |
+
print(f"是否联网搜索:{config['net_on']}")
|
49 |
+
print(f"选中TTS模型:{config['tts_model']}")
|
50 |
+
print(f"是否合成语音:{config['tts_on']}")
|
51 |
+
print(f"参考音频路径:{config['ref_audio']}")
|
52 |
+
print(f"参考音频文本:{config['ref_audio_transcribe']}")
|
53 |
+
|
54 |
+
context = ''
|
55 |
+
net_search_res = []
|
56 |
+
docs = []
|
57 |
+
if config['kb_on'] and len(config['current_knowledge_base']) > 0:
|
58 |
+
# 检索相似文档
|
59 |
+
doc_and_scores = local_rag.vector_db.similarity_search(message, k=local_rag.rag_top_k)
|
60 |
+
# doc_and_scores = list(filter(lambda x: x[1] <= 0.4, doc_and_scores))
|
61 |
+
if len(doc_and_scores) > 0:
|
62 |
+
docs, scores = list(zip(*doc_and_scores))
|
63 |
+
docs, scores = list(docs), list(scores)
|
64 |
+
context_local = "【本地知识库】" + "\n".join([concate_metadata(d.metadata) + d.page_content for d in docs])
|
65 |
+
context = context + context_local
|
66 |
+
if config['net_on']:
|
67 |
+
# 检索相似文档
|
68 |
+
ret = web_search(message, max_results=MAX_RESULTS)
|
69 |
+
net_search_res = parse_net_search(ret)
|
70 |
+
context_net = "\n【网络搜索结果】" + ''.join(net_search_res)
|
71 |
+
context = context + context_net
|
72 |
+
|
73 |
+
if config['character_description']:
|
74 |
+
if config['character_setting_mode'] == 'by system':
|
75 |
+
if len(section_state['chat_history']) == 0 or section_state['chat_history'][0]['role'] != 'system':
|
76 |
+
section_state['chat_history'].insert(0, {"role": "system", "content": config["character_description"]})
|
77 |
+
elif config['character_setting_mode'] == 'by prompt':
|
78 |
+
if len(section_state['chat_history']) > 0 and section_state['chat_history'][0]['role'] == 'system':
|
79 |
+
section_state['chat_history'].pop(0)
|
80 |
+
context = f'【系统核心设定】:{config["character_description"]}\n' if config["character_description"] else '' + context
|
81 |
+
else:
|
82 |
+
raise ValueError(f"未知的角色设定模式:{config['character_setting_mode']}")
|
83 |
+
|
84 |
+
if len(context) > 0:
|
85 |
+
prompt = f"""请充分理解以下上下文信息,并结合当前及历史对话产生回复':\n
|
86 |
+
上下文:{context}
|
87 |
+
用户当前输入:{message}
|
88 |
+
回复:
|
89 |
+
"""
|
90 |
+
input_message = section_state["chat_history"] + [{"role": "user", "content": prompt}]
|
91 |
+
else:
|
92 |
+
input_message = section_state["chat_history"] + [{"role": "user", "content": message}]
|
93 |
+
|
94 |
+
# 关闭Qwen3系列默认的思考模式
|
95 |
+
if config['llm_model'].startswith('Qwen3'):
|
96 |
+
input_message[-1]['content'] += '/no_think'
|
97 |
+
# input_message[-1]['content'] += '/no_think'
|
98 |
+
|
99 |
+
# 添加用户消息到历史
|
100 |
+
section_state["chat_history"].append({"role": "user", "content": message})
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
try:
|
105 |
+
# 调用模型
|
106 |
+
if not LLM_LOADED:
|
107 |
+
core_llm = AutoModelForCausalLM.from_pretrained(
|
108 |
+
config['llm_model'],
|
109 |
+
torch_dtype="auto",
|
110 |
+
device_map="auto"
|
111 |
+
)
|
112 |
+
core_tokenizer = AutoTokenizer.from_pretrained(config['llm_model'])
|
113 |
+
LLM_LOADED = True
|
114 |
+
token_cnt = count_tokens_local(input_message, core_tokenizer)
|
115 |
+
if token_cnt >= MAX_MODEL_CTX:
|
116 |
+
gr.Warning("当前对话已经超出模型上下文长度,请开启新会话...")
|
117 |
+
text = core_tokenizer.apply_chat_template(
|
118 |
+
input_message,
|
119 |
+
tokenize=False,
|
120 |
+
add_generation_prompt=True,
|
121 |
+
enable_thinking=False
|
122 |
+
)
|
123 |
+
model_inputs = core_tokenizer([text], return_tensors="pt").to(core_llm.device)
|
124 |
+
# conduct text completion
|
125 |
+
generated_ids = core_llm.generate(
|
126 |
+
**model_inputs,
|
127 |
+
max_new_tokens=32768
|
128 |
+
)
|
129 |
+
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
|
130 |
+
|
131 |
+
# parsing thinking content
|
132 |
+
# try:
|
133 |
+
# # rindex finding 151668 (</think>)
|
134 |
+
# index = len(output_ids) - output_ids[::-1].index(151668)
|
135 |
+
# except ValueError:
|
136 |
+
# index = 0
|
137 |
+
index = 0
|
138 |
+
# thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
|
139 |
+
thinking = None
|
140 |
+
response_content = core_tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
|
141 |
+
print('回复:', response_content)
|
142 |
+
# 更新对话历史
|
143 |
+
chat_history.append({'role': 'user', 'content': message})
|
144 |
+
if len(context) > 0:
|
145 |
+
# 构建带折叠结构的消息
|
146 |
+
formatted_response = f"""
|
147 |
+
<details class="rag-details">
|
148 |
+
<summary style='cursor: pointer; color: #666;'>
|
149 |
+
🔍 检索完成✅(共{len(docs)+len(net_search_res)}条)
|
150 |
+
</summary>
|
151 |
+
<div style='margin:10px 0;padding:10px;background:#f5f5f5;border-radius:8px;'>
|
152 |
+
{
|
153 |
+
"<br>".join(
|
154 |
+
["<br>".join(wash_up_content(content if isinstance(content, str) else (content.page_content, scores[idx])))
|
155 |
+
for idx, content in enumerate(docs + net_search_res)]
|
156 |
+
)
|
157 |
+
}
|
158 |
+
</div>
|
159 |
+
</details>
|
160 |
+
<div style="margin-top: 10px;">{response_content}</div> <!-- 增加顶部间距容器 -->
|
161 |
+
"""
|
162 |
+
chat_history.append({'role': 'assistant', 'content': formatted_response})
|
163 |
+
else:
|
164 |
+
chat_history.append({'role': 'assistant', 'content': response_content})
|
165 |
+
|
166 |
+
thinking_history += f"User: {message}\nThinking: {thinking}" + '\n==================\n'
|
167 |
+
# 添加助手响应到历史
|
168 |
+
section_state["chat_history"].append({"role": "assistant", "content": response_content})
|
169 |
+
section_state["thinking_history"] += f"User: {message}\nThinking: {thinking}" + '\n==================\n'
|
170 |
+
|
171 |
+
if (not config['tts_on']) or len(response_content) == 0:
|
172 |
+
audio_output = np.array([0], dtype=np.int16)
|
173 |
+
if len(response_content) == 0:
|
174 |
+
print("LLM 回复为空,无法合成语音")
|
175 |
+
else:
|
176 |
+
if not TTS_LOADED:
|
177 |
+
print('TTS模型首次加载...')
|
178 |
+
gr.Info("初次加载TTS模型,请稍候..", duration=63)
|
179 |
+
synthesiser = TTSapi(model_name=config['tts_model'])
|
180 |
+
TTS_LOADED = True
|
181 |
+
print('加载完毕...')
|
182 |
+
# 检查当前模型是否是所选
|
183 |
+
if config['tts_model'] != synthesiser.model_name:
|
184 |
+
print(f'当前TTS模型{synthesiser.model_name}非所选,重新加载')
|
185 |
+
synthesiser.reload(model_name=config['tts_model'])
|
186 |
+
|
187 |
+
# 如果提供了参考音频,则需把参考音频的文本加在response_content前面作为前缀
|
188 |
+
if config['ref_audio']:
|
189 |
+
prompt_text = config['ref_audio_transcribe']
|
190 |
+
if prompt_text is None:
|
191 |
+
# prompt_text = ...
|
192 |
+
raise NotImplementedError('暂时必须提供文本') # TODO:考虑后续加入ASR模型
|
193 |
+
response_content = prompt_text + response_content
|
194 |
+
|
195 |
+
audio_output = synthesiser.forward(response_content, speech_prompt=config['ref_audio'])
|
196 |
+
|
197 |
+
except Exception as e:
|
198 |
+
print('!!!!!!!!')
|
199 |
+
print(e)
|
200 |
+
print('!!!!!!!!')
|
201 |
+
error_msg = f"Error: {str(e)}"
|
202 |
+
chat_history.append({'role': 'assistant', 'content': error_msg})  # chatbot 为 messages 类型,需使用 dict 格式
|
203 |
+
thinking_history += f"Error occurred: {str(e)}" + '\n'
|
204 |
+
|
205 |
+
return "", chat_history, thinking_history, (synthesiser.sr if synthesiser else 16000, audio_output)
|
206 |
+
|
207 |
+
|
208 |
+
def init_model(init_llm=True, init_rag=False, init_tts=False):
|
209 |
+
if init_llm:
|
210 |
+
print(f'正在加载LLM:{DEFAULT_MODEL_NAME}...')
|
211 |
+
core_llm = AutoModelForCausalLM.from_pretrained(
|
212 |
+
DEFAULT_MODEL_NAME,
|
213 |
+
torch_dtype="auto",
|
214 |
+
device_map="auto"
|
215 |
+
)
|
216 |
+
print('device:', core_llm.device)
|
217 |
+
core_tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL_NAME)
|
218 |
+
LLM_LOADED = True
|
219 |
+
else:
|
220 |
+
core_llm, core_tokenizer = None, None
|
221 |
+
LLM_LOADED = False
|
222 |
+
|
223 |
+
if init_rag:
|
224 |
+
gr.Info("正在加载知识库,请稍候...")
|
225 |
+
local_rag = LocalRAG(rag_top_k=RAG_TOP_K)
|
226 |
+
else:
|
227 |
+
local_rag = None
|
228 |
+
|
229 |
+
if init_tts:
|
230 |
+
print(f'正在加载TTS模型:{DEFAULT_TTS_MODEL_NAME}...')
|
231 |
+
synthesiser = TTSapi()
|
232 |
+
TTS_LOADED = True
|
233 |
+
else:
|
234 |
+
synthesiser = None
|
235 |
+
TTS_LOADED = False
|
236 |
+
return local_rag, synthesiser, core_llm, core_tokenizer, TTS_LOADED, LLM_LOADED
|
237 |
+
|
238 |
+
|
239 |
+
if __name__ == "__main__":
|
240 |
+
import time
|
241 |
+
st = time.time()
|
242 |
+
print('********************模型加载中************************')
|
243 |
+
local_rag, synthesiser, core_llm, core_tokenizer, TTS_LOADED, LLM_LOADED = init_model()
|
244 |
+
print('********************模型加载完成************************')
|
245 |
+
print('耗时:',time.time() - st)
|
246 |
+
|
247 |
+
state = {}
|
248 |
+
resp, state = log_in(0, state)
|
249 |
+
cosplayer = CosPlayer(description_file=DEFAULT_COSPLAY_SETTING)
|
250 |
+
print("===== 初始化开始 =====")
|
251 |
+
with gr.Blocks(css=CSS, title="LLM Chat Demo", theme=gr.themes.Soft(font=["sans-serif", "Arial"])) as demo:
|
252 |
+
gr.Markdown("""
|
253 |
+
# LLM Chat Demo
|
254 |
+
## 用法介绍
|
255 |
+
### 用户登录
|
256 |
+
* 输入用户名,点击Log In按钮。首次登录会自动创建用户目录,聊天记录会保存在该目录下;如不登录,默认使用公共目录'0'
|
257 |
+
### 模型选择
|
258 |
+
目前支持Qwen、Deepseek-R1蒸馏系列等部分模型,可下拉菜单选择
|
259 |
+
### 高级设置
|
260 |
+
* 模式选择:可以选择角色扮演模式/普通模式
|
261 |
+
* 角色设定选择:支持加载不同角色设定文件
|
262 |
+
* 角色配置方式:
|
263 |
+
* by system: 角色设定将作为system prompt存在于输入首部
|
264 |
+
* by prompt: 角色设定每次被添加到当前上下文中
|
265 |
+
* 知识库配置: 支持自由选择、组合知识库
|
266 |
+
""")
|
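To make the two character-setting modes described above concrete, here is a minimal sketch (with a made-up description and user message) of the message lists that `predict()` ends up sending to the LLM in each mode:

```python
# Minimal sketch of the two character-setting modes described above.
# The description text and user message are made up for illustration.
character_description = "你是周杰伦,请用第一人称、口语化地回答问题。"
user_message = "你是谁?"

# "by system": the description is pinned once at the head of the history
messages_by_system = [
    {"role": "system", "content": character_description},
    {"role": "user", "content": user_message},
]

# "by prompt": the description is folded into the current turn's context instead
messages_by_prompt = [
    {"role": "user",
     "content": f"【系统核心设定】:{character_description}\n用户当前输入:{user_message}"},
]
```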
267 |
+
section_state = gr.State(value=state) # 创建会话状态对象
|
268 |
+
with gr.Row():
|
269 |
+
uid_input = gr.Textbox(label="Type Your UID:")
|
270 |
+
response = gr.Textbox(label='', value=resp)
|
271 |
+
login_button = gr.Button("Log In")
|
272 |
+
llm_select = gr.Dropdown(label= "模型选择", choices=AVALIABLE_MODELS, value=DEFAULT_MODEL_NAME, visible=True)
|
273 |
+
|
274 |
+
gr.Markdown("## 高级设置")
|
275 |
+
with gr.Accordion("点击展开折叠", open=False, visible=True):
|
276 |
+
mode_select = gr.Radio(label='模式选择', choices=SUPPORT_MODES, value=DEFAULT_MODE)
|
277 |
+
coser_select = gr.Dropdown(label= "角色设定选择", choices=cosplayer.get_all_characters(), value=DEFAULT_COSPLAY_SETTING, visible=True)
|
278 |
+
coser_setting = gr.Radio(label='角色配置方式', choices=CHARACTER_SETTING_MODES, value=DEFAULT_C_SETTING_MODE, visible=True)
|
279 |
+
kb_select = gr.Dropdown(label= "知识库配置", choices=AVALIABLE_KNOWLEDGE_BASE, value=None, visible=True, multiselect=True)
|
280 |
+
|
281 |
+
with gr.Row():
|
282 |
+
# 页面左侧
|
283 |
+
with gr.Column(scale=3):
|
284 |
+
chatbot = gr.Chatbot(label="对话记录", height=500, show_copy_button=True, type='messages')
|
285 |
+
with gr.Row():
|
286 |
+
msg = gr.Textbox(label="输入消息", placeholder="请输入您的问题...", scale=7)
|
287 |
+
with gr.Column(scale=1, min_width=15):
|
288 |
+
with gr.Row():
|
289 |
+
rag_switch = gr.Checkbox(label="本地RAG", value=False, info="")
|
290 |
+
net_switch = gr.Checkbox(label="联网搜索", value=False, info="")
|
291 |
+
|
292 |
+
submit_btn = gr.Button("发送", variant="primary", min_width=15)#, , elem_classes=['custom-btn']
|
293 |
+
with gr.Row():
|
294 |
+
gr.Examples(
|
295 |
+
examples=[[example] for example in EXAMPLES],
|
296 |
+
inputs=msg,
|
297 |
+
outputs=chatbot,
|
298 |
+
fn=predict,
|
299 |
+
visible=True,
|
300 |
+
cache_examples=False
|
301 |
+
)
|
302 |
+
with gr.Row():
|
303 |
+
save_btn = gr.Button("保存对话")
|
304 |
+
clear_btn = gr.Button("清空对话")
|
305 |
+
chat_history_select = gr.Dropdown(label='加载历史对话', choices=state['available_history'], visible=True, interactive=True)
|
306 |
+
|
307 |
+
# 页面右侧
|
308 |
+
with gr.Column(scale=2):
|
309 |
+
thinking_display = gr.TextArea(label="思考过程",interactive=False,
|
310 |
+
placeholder="模型思考过程将在此显���..."
|
311 |
+
)
|
312 |
+
tts_switch = gr.Checkbox(label="TTS开关", value=False, info="Check me to hear voice")
|
313 |
+
with gr.Tabs() as audio_tabs:
|
314 |
+
# 选项卡1:音频播放
|
315 |
+
with gr.Tab("音频输出", id="audio_output"):
|
316 |
+
audio_player = gr.Audio(
|
317 |
+
label="听听我声音~",
|
318 |
+
type="numpy",
|
319 |
+
interactive=False
|
320 |
+
)
|
321 |
+
|
322 |
+
# 选项卡2:TTS配置
|
323 |
+
with gr.Tab("TTS配置", id="tts_config"):
|
324 |
+
# TTS模型选择
|
325 |
+
tts_model = gr.Dropdown(
|
326 |
+
label="选择TTS模型",
|
327 |
+
choices=["ChatTTS", "HKUSTAudio/LLasa-1B", "HKUSTAudio/LLasa-3B", "HKUSTAudio/LLasa-8B"],
|
328 |
+
value=DEFAULT_TTS_MODEL_NAME,
|
329 |
+
interactive=True
|
330 |
+
)
|
331 |
+
|
332 |
+
# 参考音频上传
|
333 |
+
ref_audio = gr.Audio(
|
334 |
+
label="上传参考音频",
|
335 |
+
type="filepath",
|
336 |
+
interactive=True
|
337 |
+
)
|
338 |
+
ref_audio_transcribe = gr.Textbox(label="参考音频对应文本", visible=True)
|
339 |
+
|
340 |
+
|
341 |
+
# ================= 状态管理 =================
|
342 |
+
current_config = gr.State({
|
343 |
+
"llm_model": DEFAULT_MODEL_NAME,
|
344 |
+
"tts_model": DEFAULT_TTS_MODEL_NAME,
|
345 |
+
"tts_on": False,
|
346 |
+
"kb_on": False,
|
347 |
+
"net_on": False,
|
348 |
+
"ref_audio": None,
|
349 |
+
"ref_audio_transcribe": None,
|
350 |
+
"mode_selected": DEFAULT_MODE,
|
351 |
+
"character_description": cosplayer.get_core_setting(),
|
352 |
+
"character_setting_mode": DEFAULT_C_SETTING_MODE,
|
353 |
+
"current_knowledge_base": AVALIABLE_KNOWLEDGE_BASE[0]
|
354 |
+
})
|
355 |
+
|
356 |
+
# 事件处理
|
357 |
+
login_button.click(log_in, inputs=[uid_input, section_state], outputs=[response, section_state])
|
358 |
+
|
359 |
+
gr.on(triggers=[llm_select.change, tts_model.change, ref_audio.change,
|
360 |
+
ref_audio_transcribe.change, tts_switch.select, rag_switch.select, net_switch.select,
|
361 |
+
mode_select.change],
|
362 |
+
fn=lambda model1, model2, audio, text, tts_on, kb_on, net_on, mode, character_setting, kb_select: {"llm_model": model1, "tts_model": model2, "ref_audio": audio,
|
363 |
+
"ref_audio_transcribe": text, "tts_on": tts_on, "kb_on": kb_on, 'net_on': net_on,
|
364 |
+
"mode_selected": mode, "character_description": None if mode == '普通模式' else cosplayer.get_core_setting(),
|
365 |
+
"character_setting_mode": character_setting, "current_knowledge_base": kb_select},
|
366 |
+
inputs=[llm_select, tts_model, ref_audio, ref_audio_transcribe, tts_switch, rag_switch, net_switch, mode_select, coser_setting, kb_select],
|
367 |
+
outputs=current_config
|
368 |
+
)
|
369 |
+
msg.submit(
|
370 |
+
predict,
|
371 |
+
[msg, chatbot, thinking_display, current_config, section_state],
|
372 |
+
[msg, chatbot, thinking_display, audio_player],
|
373 |
+
queue=False
|
374 |
+
)
|
375 |
+
chatbot.retry(fn=handle_retry,
|
376 |
+
inputs=[chatbot, thinking_display, current_config, section_state],
|
377 |
+
outputs=[msg, chatbot, thinking_display, audio_player])
|
378 |
+
|
379 |
+
submit_btn.click(
|
380 |
+
predict,
|
381 |
+
[msg, chatbot, thinking_display, current_config, section_state],
|
382 |
+
[msg, chatbot, thinking_display, audio_player],
|
383 |
+
queue=False
|
384 |
+
)
|
385 |
+
|
386 |
+
def save_chat(state):
|
387 |
+
from datetime import datetime
|
388 |
+
now = datetime.now().strftime('%Y%m%d_%H%M%S')
|
389 |
+
with open(state['user_dir'] / f'chat_history_{now}.json', 'w', encoding='utf-8') as file:
|
390 |
+
json.dump(state["chat_history"], file, ensure_ascii=False, indent=4)
|
391 |
+
with open(state['user_dir'] / f'thinking_history_{now}.txt', 'w') as file:
|
392 |
+
if isinstance(state["thinking_history"], list):
|
393 |
+
for item in state["thinking_history"]:
|
394 |
+
file.write(item + '\n')
|
395 |
+
else:
|
396 |
+
file.write(state["thinking_history"])
|
397 |
+
|
398 |
+
gr.Info("聊天记录已保存!")
|
399 |
+
state['available_history'].append(f'chat_history_{now}')
|
400 |
+
return state
|
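For reference, the chat_history_<timestamp>.json written above is just the messages list serialized as-is; a made-up example is shown below (the exact user directory depends on log_in in utils.py, assumed to live under demo_dir/user/<uid>):

```python
# Shape of a saved chat_history_<timestamp>.json (contents made up)
example_saved_history = [
    {"role": "system", "content": "<角色设定文本>"},
    {"role": "user", "content": "你是谁?"},
    {"role": "assistant", "content": "我是周杰伦,很高兴认识你。"},
]
```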
401 |
+
|
402 |
+
def clear_chat(state):
|
403 |
+
state["chat_history"] = []
|
404 |
+
state["thinking_history"] = []
|
405 |
+
prologue = cosplayer.get_prologue()
|
406 |
+
if prologue:
|
407 |
+
state['chat_history'].append({'role': 'assistant', 'content': prologue})
|
408 |
+
chatbot = [{'role': 'assistant', 'content': prologue}]
|
409 |
+
else:
|
410 |
+
chatbot = []
|
411 |
+
return chatbot, [], state
|
412 |
+
|
413 |
+
def load_chat(state, chat_file):
|
414 |
+
# NOTE: 加载历史聊天记录。一般在对话开始之前加载,如果本次对话已经开始,本操作会覆盖当前会话内容
|
415 |
+
if chat_file:
|
416 |
+
think_file = chat_file.replace("chat_", "thinking_")
|
417 |
+
chat_file_path = state['user_dir'] / (chat_file + '.json')
|
418 |
+
think_file_path = state['user_dir'] / (think_file + '.txt')
|
419 |
+
|
420 |
+
if not chat_file_path.exists():
|
421 |
+
gr.Warning(f'聊天记录文件:{chat_file}.json不存在, 加载失败')
|
422 |
+
return [], '', state
|
423 |
+
|
424 |
+
with open(chat_file_path, 'r', encoding='utf-8') as f:
|
425 |
+
content = json.load(f)
|
426 |
+
state['chat_history'] = content
|
427 |
+
|
428 |
+
think = ''
|
429 |
+
if think_file_path.exists():
|
430 |
+
with open(think_file_path, 'r') as f:
|
431 |
+
think = f.read()
|
432 |
+
|
433 |
+
state['thinking_history'] = think
|
434 |
+
|
435 |
+
# 转换成chatbot可以识别的格式
|
436 |
+
# bot_content = parse_chat_history(content)
|
437 |
+
# 指定chatbot类型为message后,无需解析
|
438 |
+
bot_content = content
|
439 |
+
return bot_content, think, state
|
440 |
+
|
441 |
+
return [], '', state
|
442 |
+
|
443 |
+
|
444 |
+
def update_history(state):
|
445 |
+
return gr.update(choices=state['available_history'])
|
446 |
+
|
447 |
+
def update_visible(mode):
|
448 |
+
if mode != '普通模式':
|
449 |
+
gr.Warning("当前为角色扮演模式,请确认已配置好该角色的知识库...")
|
450 |
+
|
451 |
+
return gr.update(visible=True), gr.update(visible=True)
|
452 |
+
return gr.update(visible=False), gr.update(visible=False)
|
453 |
+
|
454 |
+
def update_cosplay(cos_select, config, chatbot, think_display, state):
|
455 |
+
cosplayer.update(cos_select)
|
456 |
+
config['character_description'] = cosplayer.get_core_setting()
|
457 |
+
# 角色设定发生改变后,自动保存当前聊天记录,之后清空历史记录
|
458 |
+
if len(state['chat_history']) > 1:
|
459 |
+
state = save_chat(state)
|
460 |
+
gr.Warning("我的角色已更换,对话已重置。请检查知识库是否需要更新...")
|
461 |
+
chatbot, think_display, state = clear_chat(state)
|
462 |
+
return gr.update(value=cos_select), config, chatbot, think_display, state
|
463 |
+
|
464 |
+
def update_character_setting_mode(coser_setting, config):
|
465 |
+
config['character_setting_mode'] = coser_setting
|
466 |
+
return gr.update(value=coser_setting), config
|
467 |
+
|
468 |
+
def update_knowledge_base(knowledge_base, config):
|
469 |
+
global local_rag
|
470 |
+
config['current_knowledge_base'] = knowledge_base
|
471 |
+
if len(knowledge_base) == 0:
|
472 |
+
gr.Warning("当前未选中任何知识库,本地RAG将失效。请确认...")
|
473 |
+
else:
|
474 |
+
if local_rag is None:
|
475 |
+
gr.Info("初次加载知识库,请稍候...")
|
476 |
+
local_rag = LocalRAG(rag_top_k=RAG_TOP_K, doc_dir=knowledge_base)
|
477 |
+
gr.Info("知识库加载完成!")
|
478 |
+
else:
|
479 |
+
gr.Info("重新加载知识库,请稍候...")
|
480 |
+
local_rag.reload_knowledge_base(knowledge_base)
|
481 |
+
gr.Info("知识库加载完成!")
|
482 |
+
return gr.update(value=knowledge_base), config
|
483 |
+
|
484 |
+
def init_kb(rag_on, kb_select, config):
|
485 |
+
global local_rag
|
486 |
+
if rag_on:
|
487 |
+
# 初始化本地知识库
|
488 |
+
if config['mode_selected'] == "角色扮演":
|
489 |
+
gr.Warning("当前为角色扮演模式,请确认已配置好该角色的知识库...")
|
490 |
+
|
491 |
+
if local_rag is None:
|
492 |
+
gr.Info("初次加载知识库,请稍候...")
|
493 |
+
local_rag = LocalRAG(rag_top_k=RAG_TOP_K, doc_dir=kb_select)
|
494 |
+
gr.Info("知识库加载完成!")
|
495 |
+
return gr.update(value=rag_on)
|
496 |
+
|
497 |
+
# 选择非普通模式时(角色扮演),会展示可控选择的角色设定列表
|
498 |
+
mode_select.change(update_visible,
|
499 |
+
inputs=mode_select,
|
500 |
+
outputs=[coser_select, coser_setting])
|
501 |
+
|
502 |
+
coser_select.change(update_cosplay,
|
503 |
+
inputs=[coser_select, current_config, chatbot, thinking_display, section_state],
|
504 |
+
outputs=[coser_select, current_config, chatbot, thinking_display, section_state])
|
505 |
+
# TODO: 根据角色变化动态展示示例
|
506 |
+
# coser_select.change(update_examples,
|
507 |
+
# inputs=[coser_select],
|
508 |
+
# outputs=[examples_show])
|
509 |
+
|
510 |
+
coser_setting.change(update_character_setting_mode,
|
511 |
+
inputs=[coser_setting, current_config],
|
512 |
+
outputs=[coser_setting, current_config])
|
513 |
+
|
514 |
+
kb_select.change(update_knowledge_base,
|
515 |
+
inputs=[kb_select, current_config],
|
516 |
+
outputs=[kb_select, current_config])
|
517 |
+
|
518 |
+
# 勾选本地知识库时,若为角色扮演模式,提醒用户设置知识库目录
|
519 |
+
rag_switch.select(init_kb, inputs=[rag_switch, kb_select, current_config], outputs=rag_switch)
|
520 |
+
|
521 |
+
clear_btn.click(
|
522 |
+
clear_chat,
|
523 |
+
inputs=section_state,
|
524 |
+
outputs=[chatbot, thinking_display, section_state],
|
525 |
+
queue=False
|
526 |
+
)
|
527 |
+
|
528 |
+
save_btn.click(
|
529 |
+
save_chat,
|
530 |
+
inputs=section_state,
|
531 |
+
outputs=section_state,
|
532 |
+
queue=False
|
533 |
+
).then(
|
534 |
+
fn=update_history,
|
535 |
+
inputs=section_state,
|
536 |
+
outputs=chat_history_select
|
537 |
+
)
|
538 |
+
|
539 |
+
chat_history_select.change(load_chat,
|
540 |
+
inputs=[section_state, chat_history_select],
|
541 |
+
outputs=[chatbot, thinking_display, section_state])
|
542 |
+
|
543 |
+
section_state.change(update_history,
|
544 |
+
inputs=section_state,
|
545 |
+
outputs=chat_history_select)
|
546 |
+
print("===== 初始化完成 =====")
|
547 |
+
demo.launch(share=False)
|
548 |
+
|
config.py
ADDED
@@ -0,0 +1,114 @@
1 |
+
from pathlib import Path
|
2 |
+
import os
|
3 |
+
|
4 |
+
DEFAULT_MODEL_NAME = "Qwen/Qwen3-0.6B"
|
5 |
+
DEFAULT_MODE = "角色扮演"
|
6 |
+
DEFAULT_C_SETTING_MODE = "by system"
|
7 |
+
DEFAULT_COSPLAY_SETTING = 'rag/characters/周杰伦.txt'
|
8 |
+
AVALIABLE_MODELS = [
|
9 |
+
"Qwen/Qwen3-8B",
|
10 |
+
"Qwen/Qwen3-4B",
|
11 |
+
"Qwen/Qwen3-1.7B",
|
12 |
+
"Qwen/Qwen3-0.6B"
|
13 |
+
]
|
14 |
+
BASE_MODEL_TABLE = {"qwen7B_jaychou_f16": "qwen2.5:7b-instruct", "qwen0.5B_jaychou13": "qwen2.5:0.5b-instruct",
|
15 |
+
"qwen14B_jaychou_q8_newdata_add_template": "qwen2.5:14b-instruct",
|
16 |
+
"qwen2.5_32B_jaychou": "qwen2.5:32b-instruct",
|
17 |
+
"qwen2.5_0.5B_jaychou_lora": "qwen2.5:0.5b-instruct",
|
18 |
+
# "qwen2.5_32B_jaychou_tq1": "qwen2.5:32b-instruct"
|
19 |
+
}
|
20 |
+
|
21 |
+
AVALIABLE_KNOWLEDGE_BASE = [
|
22 |
+
"rag/kb/BIGOLIVE及公司介绍",
|
23 |
+
"rag/kb/主播A的直播间对话数据",
|
24 |
+
"rag/kb/周杰伦",
|
25 |
+
"rag/kb/狼人杀"
|
26 |
+
]
|
27 |
+
SUPPORT_MODES = [
|
28 |
+
"角色扮演",
|
29 |
+
"普通模式",
|
30 |
+
]
|
31 |
+
CHARACTER_SETTING_MODES = [
|
32 |
+
"by system",
|
33 |
+
"by prompt"
|
34 |
+
]
|
35 |
+
|
36 |
+
EXAMPLES_changkong = [
|
37 |
+
"""
|
38 |
+
[0:00:00]:[0:00:04] 對呀 我肯定沒有回
|
39 |
+
[0:00:04]:[0:00:08] 真的 今天我兒時頭 我感覺頭髮摘
|
40 |
+
[0:00:08]:[0:00:10] 而且我的頭髮越來越爽了
|
41 |
+
[0:00:10]:[0:00:13] 我頭髮越來越爽了
|
42 |
+
[0:00:13]:[0:00:15] 真的越來越少
|
43 |
+
[0:00:15]:[0:00:17] 好煩呀 我經常脫頭髮
|
44 |
+
""",
|
45 |
+
"""
|
46 |
+
[0:01:44]:[0:01:46] 我咋覺得這個茶克斯那麼熟呢
|
47 |
+
[0:01:47]:[0:01:49] 哦 我想起來了
|
48 |
+
【观众536644926】
|
49 |
+
好像看过
|
50 |
+
|
51 |
+
[0:01:50]:[0:01:51] 上次跟我一起打BK那個
|
52 |
+
[0:01:52]:[0:01:53] 對面的那個
|
53 |
+
【观众1887407561】
|
54 |
+
嗨😂
|
55 |
+
|
56 |
+
[0:01:54]:[0:01:55] 是不是你
|
57 |
+
[0:01:55]:[0:01:56] 肯定是
|
58 |
+
[0:01:57]:[0:01:58] 我有點想起來了
|
59 |
+
[0:01:59]:[0:02:00] 因為他們老是在叫妳的名字
|
60 |
+
[0:02:01]:[0:02:02] 好像看過
|
61 |
+
[0:02:03]:[0:02:04] 哪一把BK來的
|
62 |
+
[0:02:08]:[0:02:10] 嗨 我記得你了
|
63 |
+
[0:02:11]:[0:02:12] 那個 讓我
|
64 |
+
[0:02:13]:[0:02:14] 讓我 讓我 讓我
|
65 |
+
[0:02:14]:[0:02:15] 做那個懲罰
|
66 |
+
[0:02:15]:[0:02:17] 你知道嗎 我差點進去醫院了
|
67 |
+
""",
|
68 |
+
|
69 |
+
]
|
70 |
+
EXAMPLES_zhubo_clone = ["""最近好烦啊"""]
|
71 |
+
EXAMPLES_langren = ["""玩家角色分配为一个预言家,一个女巫,一个猎人,三个平民,三个狼人。游戏开始,请开始主持。"""]
|
72 |
+
EXAMPLES_jaychou_clone = [
|
73 |
+
"""你是谁?""",
|
74 |
+
"""不忙的时候你会做些什么?""",
|
75 |
+
"""你创作《素颜》这首歌的时候,背后有什么故事吗;这首歌里面有哪句歌词是你特别喜欢的""",
|
76 |
+
"""你的香味一直徘徊,比我知道,秘密躺在我怀抱,还有没有人知道。 这是你的哪首歌?""",
|
77 |
+
]
|
78 |
+
EXAMPLES = EXAMPLES_jaychou_clone
|
79 |
+
# 本地RAG
|
80 |
+
RAG_TOP_K = 5
|
81 |
+
|
82 |
+
# 联网
|
83 |
+
MAX_RESULTS= 3
|
84 |
+
|
85 |
+
# 目录信息
|
86 |
+
BASE_DIR = Path("demo_dir")
|
87 |
+
TEMP_DIR = BASE_DIR / "tmp"
|
88 |
+
USER_DIR = BASE_DIR / "user"
|
89 |
+
TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
90 |
+
USER_DIR.mkdir(parents=True, exist_ok=True)
|
91 |
+
os.environ['GRADIO_TEMP_DIR'] = str(TEMP_DIR)
|
92 |
+
|
93 |
+
CSS = """
|
94 |
+
.rag-details:not([open]) > div {
|
95 |
+
display: none !important; /* 强制折叠状态 */
|
96 |
+
white-space: pre-wrap; /* 保留换行符 */
|
97 |
+
}
|
98 |
+
.rag-details[open] summary::after {
|
99 |
+
content: "▼";
|
100 |
+
float: right;
|
101 |
+
}
|
102 |
+
"""
|
103 |
+
|
104 |
+
# LLM最大上下文长度
|
105 |
+
MAX_MODEL_CTX = 32768
|
106 |
+
|
107 |
+
# 知识库embedding模型
|
108 |
+
AVALIABLE_EMBEDDING_MODELS = [
|
109 |
+
"BAAI/bge-large-zh-v1.5",
|
110 |
+
"sentence-transformers/all-MiniLM-L12-v2",
|
111 |
+
"sentence-transformers/all-MiniLM-L6-v2",
|
112 |
+
"jinaai/jina-embeddings-v2-base-zh",
|
113 |
+
]
|
114 |
+
DEFAULT_EMBEDDING_MODEL = "BAAI/bge-large-zh-v1.5"
|
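Taken together, the path constants above imply roughly the following layout; the rag/kb/周杰伦 files are part of this commit, while rag/characters and rag/prologues are assumptions based on DEFAULT_COSPLAY_SETTING and the CosPlayer class in knowledge_base.py below:

```python
# Rough on-disk layout implied by the constants above (partly assumed)
expected_paths = [
    "rag/characters/周杰伦.txt",        # DEFAULT_COSPLAY_SETTING (role description)
    "rag/prologues/周杰伦.txt",         # optional prologue, path derived by CosPlayer
    "rag/kb/周杰伦/周杰伦基本资料.md",    # knowledge-base documents (this commit)
    "rag/vector_db/",                   # FAISS index cache
    "demo_dir/tmp/",                    # GRADIO_TEMP_DIR
    "demo_dir/user/",                   # per-UID chat history
]
```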
knowledge_base.py
ADDED
@@ -0,0 +1,245 @@
1 |
+
from langchain_community.document_loaders import DirectoryLoader, JSONLoader, UnstructuredMarkdownLoader
|
2 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter, MarkdownTextSplitter, MarkdownHeaderTextSplitter
|
3 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
4 |
+
from langchain_community.vectorstores import FAISS
|
5 |
+
from pathlib import Path
|
6 |
+
from transformers import AutoModel, AutoTokenizer
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import numpy as np
|
10 |
+
import config as cfg
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
class LocalRAG:
|
15 |
+
def __init__(self,
|
16 |
+
rag_top_k=3,
|
17 |
+
doc_dir="rag/kb/BIGOLIVE及公司介绍/", # 默认加载这个,若选择角色扮演模式,可根据角色选择
|
18 |
+
vector_db_path="rag/vector_db/",
|
19 |
+
embed_model=cfg.DEFAULT_EMBEDDING_MODEL
|
20 |
+
):
|
21 |
+
self.rag_top_k = rag_top_k
|
22 |
+
self.doc_dir = doc_dir # 本地知识库的文档目录
|
23 |
+
self.vector_db_path = vector_db_path # 向量数据库存储路径
|
24 |
+
self.embed_model = embed_model
|
25 |
+
self.build_vector_db()
|
26 |
+
|
27 |
+
def build_vector_db(self):
|
28 |
+
# 加载文档(支持PDF、TXT、DOCX)
|
29 |
+
if isinstance(self.doc_dir, list):
|
30 |
+
general_docs = []
|
31 |
+
json_docs = []
|
32 |
+
md_docs = []
|
33 |
+
for doc_dir in self.doc_dir:
|
34 |
+
# 处理一般文件,txt等
|
35 |
+
loader = DirectoryLoader(doc_dir, glob="**/*.[!json!md]*") # "**/[!.]*"
|
36 |
+
tmp_docs = loader.load()
|
37 |
+
general_docs.extend(tmp_docs)
|
38 |
+
# 额外处理json文件
|
39 |
+
for json_file in Path(doc_dir).rglob("*.json"):
|
40 |
+
loader = JSONLoader(
|
41 |
+
file_path=str(json_file),
|
42 |
+
jq_schema=".[] | {spk: .spk, text: .text}",
|
43 |
+
text_content=False)
|
44 |
+
|
45 |
+
data = loader.load()
|
46 |
+
for iidx in range(len(data)):
|
47 |
+
data[iidx].page_content = bytes(data[iidx].page_content, "utf-8").decode("unicode_escape")
|
48 |
+
json_docs.extend(data)
|
49 |
+
|
50 |
+
# 额外处理md文件
|
51 |
+
headers_to_split_on = [
|
52 |
+
("#", "Header 1"),
|
53 |
+
("##", "Header 2"),
|
54 |
+
("###", "Header 3"),
|
55 |
+
]
|
56 |
+
for md_file in Path(doc_dir).rglob("*.md"):
|
57 |
+
with open(md_file, 'r') as f:
|
58 |
+
content = f.read()
|
59 |
+
|
60 |
+
|
61 |
+
# 定义拆分器,拆分markdown内容
|
62 |
+
markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
|
63 |
+
md_header_splits = markdown_splitter.split_text(content)
|
64 |
+
md_docs.extend(md_header_splits)
|
65 |
+
|
66 |
+
# loader = UnstructuredMarkdownLoader(md_file, mode="elements")
|
67 |
+
# data = loader.load()
|
68 |
+
# docs.extend(data)
|
69 |
+
else:
|
70 |
+
loader = DirectoryLoader(self.doc_dir, glob="**/*.*")
|
71 |
+
general_docs, json_docs, md_docs = loader.load(), [], []  # 与上方多目录分支保持一致的变量命名
|
72 |
+
|
73 |
+
# 文本分块
|
74 |
+
if len(general_docs) > 0:
|
75 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
76 |
+
chunk_size=500,
|
77 |
+
chunk_overlap=50
|
78 |
+
)
|
79 |
+
chunks = text_splitter.split_documents(general_docs) + json_docs + md_docs
|
80 |
+
else:
|
81 |
+
chunks = json_docs + md_docs
|
82 |
+
|
83 |
+
# 生成向量并构建FAISS数据库
|
84 |
+
embeddings = HuggingFaceEmbeddings(model_name=self.embed_model)
|
85 |
+
self.vector_db = FAISS.from_documents(chunks, embeddings)
|
86 |
+
self.vector_db.save_local(self.vector_db_path)
|
87 |
+
|
88 |
+
def reload_knowledge_base(self, target_doc_dir):
|
89 |
+
self.doc_dir = target_doc_dir
|
90 |
+
self.build_vector_db()
|
91 |
+
|
92 |
+
# def reset(self):
|
93 |
+
# self.vector_db = None
|
94 |
+
|
95 |
+
|
96 |
+
class LocalRAG_new:
|
97 |
+
|
98 |
+
def __init__(self,
|
99 |
+
rag_top_k=3,
|
100 |
+
doc_dir="rag/kb/BIGOLIVE及公司介绍/", # 默认加载这个,若选择角色扮演模式,可根据角色选择
|
101 |
+
vector_db_path="rag/vector_db/",
|
102 |
+
embed_model_path="princeton-nlp/sup-simcse-bert-large-uncased",
|
103 |
+
device=torch.device('cuda:2')):
|
104 |
+
self.rag_top_k = rag_top_k
|
105 |
+
self.doc_dir = doc_dir # 本地知识库的文档目录
|
106 |
+
self.kb_name = '_'.join([Path(doc_dir[i]).name for i in range(len(doc_dir))])
|
107 |
+
self.embed_model_name = Path(embed_model_path).name
|
108 |
+
self.vector_db_path = vector_db_path # 向量数据库存储路径
|
109 |
+
self.embed_model = embed_model_path
|
110 |
+
|
111 |
+
self.device = device
|
112 |
+
# 加载分词器和模型
|
113 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.embed_model)
|
114 |
+
self.embed_model = AutoModel.from_pretrained(self.embed_model).to(device)
|
115 |
+
self.vector_db = None
|
116 |
+
self._vector_db = None
|
117 |
+
self.build_vector_db()
|
118 |
+
|
119 |
+
class VectorDB:
|
120 |
+
def __init__(self, rag):
|
121 |
+
self._data = rag._vector_db
|
122 |
+
self.rag = rag
|
123 |
+
|
124 |
+
def similarity_search(self, query, k):
|
125 |
+
# 可能的输入预处理,暂无
|
126 |
+
# query = input_optimize(query)
|
127 |
+
|
128 |
+
# 计算query的embedding并与库中比较
|
129 |
+
with torch.inference_mode():
|
130 |
+
query_token = self.rag.tokenizer(query, padding=True, truncation=False, return_tensors="pt").to(self.rag.device)
|
131 |
+
query_embed = self.rag.embed_model(**query_token)['last_hidden_state'].mean(dim=1)
|
132 |
+
sim_query = F.cosine_similarity(query_embed.repeat(len(self._data['embeds']), 1), self._data['embeds'], dim=1, eps=1e-8)
|
133 |
+
max_ids_query = torch.argsort(sim_query, descending=True)[:self.rag.rag_top_k].cpu().detach().numpy()
|
134 |
+
return list(zip(np.array(self._data['chunks'])[max_ids_query], sim_query[max_ids_query]))
|
135 |
+
|
136 |
+
def build_vector_db(self):
|
137 |
+
# 加载文档(支持PDF、TXT、DOCX)
|
138 |
+
if isinstance(self.doc_dir, list):
|
139 |
+
docs = []
|
140 |
+
for doc_dir in self.doc_dir:
|
141 |
+
loader = DirectoryLoader(doc_dir, glob="**/*.[!json!md]*") # "**/[!.]*"
|
142 |
+
tmp_docs = loader.load()
|
143 |
+
docs.extend(tmp_docs)
|
144 |
+
# # 额外处理json文件
|
145 |
+
# for json_file in Path(doc_dir).rglob("*.json"):
|
146 |
+
# loader = JSONLoader(
|
147 |
+
# file_path=str(json_file),
|
148 |
+
# jq_schema='.messages[].content',
|
149 |
+
# text_content=False)
|
150 |
+
|
151 |
+
# data = loader.load()
|
152 |
+
# 额外处理md文件
|
153 |
+
headers_to_split_on = [
|
154 |
+
("#", "Header 1"),
|
155 |
+
("##", "Header 2"),
|
156 |
+
("###", "Header 3"),
|
157 |
+
]
|
158 |
+
for md_file in Path(doc_dir).rglob("*.md"):
|
159 |
+
with open(md_file, 'r') as f:
|
160 |
+
content = f.read()
|
161 |
+
|
162 |
+
|
163 |
+
# 定义拆分器,拆分markdown内容
|
164 |
+
markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
|
165 |
+
md_header_splits = markdown_splitter.split_text(content)
|
166 |
+
docs.extend(md_header_splits)
|
167 |
+
|
168 |
+
# loader = UnstructuredMarkdownLoader(md_file, mode="elements")
|
169 |
+
# data = loader.load()
|
170 |
+
# docs.extend(data)
|
171 |
+
else:
|
172 |
+
loader = DirectoryLoader(self.doc_dir, glob="**/*.*")
|
173 |
+
docs = loader.load()
|
174 |
+
|
175 |
+
# 文本分块
|
176 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
177 |
+
chunk_size=500,
|
178 |
+
chunk_overlap=50
|
179 |
+
)
|
180 |
+
chunks = text_splitter.split_documents(docs)
|
181 |
+
with torch.inference_mode():
|
182 |
+
chunk_and_embed = []
|
183 |
+
for chunk in chunks:
|
184 |
+
chunk_token = self.tokenizer(chunk.page_content, padding=True, truncation=False, return_tensors="pt").to(self.device)
|
185 |
+
chunk_embed = self.embed_model(**chunk_token)['last_hidden_state'].mean(dim=1)
|
186 |
+
chunk_and_embed.append((chunk, chunk_embed))
|
187 |
+
all_chunks, all_embeds = list(zip(*chunk_and_embed))
|
188 |
+
all_chunks, all_embeds = list(all_chunks), list(all_embeds)
|
189 |
+
all_embeds = torch.cat(all_embeds, dim=0)
|
190 |
+
self._vector_db = {'chunks': all_chunks, 'embeds': all_embeds}
|
191 |
+
self.vector_db = self.VectorDB(self)
|
192 |
+
|
193 |
+
torch.save(self.vector_db, str(Path(self.vector_db_path) / f'{self.kb_name}_{self.embed_model_name}.pt'))
|
194 |
+
|
195 |
+
def reload_knowledge_base(self, target_doc_dir):
|
196 |
+
self.doc_dir = target_doc_dir
|
197 |
+
self.build_vector_db()
|
198 |
+
|
199 |
+
# def reset(self):
|
200 |
+
# self.vector_db = None
|
201 |
+
|
202 |
+
|
203 |
+
class CosPlayer:
|
204 |
+
def __init__(self, description_file):
|
205 |
+
self.update(description_file)
|
206 |
+
|
207 |
+
def update(self, description_file):
|
208 |
+
self.description_file = description_file
|
209 |
+
with open(description_file, 'r') as f:
|
210 |
+
all_lines = f.readlines()
|
211 |
+
self.core_setting = ''.join(all_lines)
|
212 |
+
self.characters_dir = Path(description_file).parent
|
213 |
+
self.prologue_file = self.description_file.replace('/characters/', '/prologues/')
|
214 |
+
if not Path(self.prologue_file).exists():
|
215 |
+
self.prologue_file = None
|
216 |
+
|
217 |
+
def get_all_characters(self):
|
218 |
+
return [str(i) for i in list(self.characters_dir.rglob('*.txt'))]
|
219 |
+
|
220 |
+
def get_core_setting(self):
|
221 |
+
return self.core_setting
|
222 |
+
|
223 |
+
def get_prologue(self):
|
224 |
+
if self.prologue_file:
|
225 |
+
with open(self.prologue_file, 'r') as f:
|
226 |
+
all_lines = f.readlines()
|
227 |
+
return ''.join(all_lines)
|
228 |
+
else:
|
229 |
+
return None
|
230 |
+
|
231 |
+
|
232 |
+
if __name__ == "__main__":
|
233 |
+
rag = LocalRAG()
|
234 |
+
# # rag.build_vector_db()
|
235 |
+
# doc_dir = "rag/debug"
|
236 |
+
# loader = DirectoryLoader(doc_dir, glob="**/*.*")
|
237 |
+
# docs = loader.load()
|
238 |
+
|
239 |
+
# # 文本分块
|
240 |
+
# text_splitter = RecursiveCharacterTextSplitter(
|
241 |
+
# chunk_size=500,
|
242 |
+
# chunk_overlap=50
|
243 |
+
# )
|
244 |
+
# chunks = text_splitter.split_documents(docs)
|
245 |
+
# pass
|
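The jq_schema ".[] | {spk: .spk, text: .text}" used above implies that the dialogue knowledge-base JSON files are arrays of speaker/text records; a minimal sketch of what one file turns into (field values made up, page_content shown approximately):

```python
import json

# Hypothetical record format for a dialogue KB file such as
# rag/kb/主播A的直播间对话数据/*.json; only spk/text are kept by the jq_schema.
records = [
    {"spk": "主播A", "text": "大家晚上好呀"},
    {"spk": "观众1", "text": "主播唱首歌"},
]
# Each element becomes one Document; its page_content is roughly the JSON text below,
# which build_vector_db then decodes with unicode_escape before embedding.
page_contents = [json.dumps(r, ensure_ascii=False) for r in records]
```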
rag/kb/周杰伦/周杰伦全部歌曲.md
ADDED
@@ -0,0 +1,341 @@
1 |
+
# 周杰伦全部歌曲
|
2 |
+
## 十一月的萧邦
|
3 |
+
### 浪漫手机
|
4 |
+
发行时间:2005.11
|
5 |
+
### 漂移
|
6 |
+
发行时间:2005.11
|
7 |
+
### 一路向北
|
8 |
+
发行时间:2005.11
|
9 |
+
### 枫
|
10 |
+
发行时间:2005.11
|
11 |
+
### 黑色毛衣
|
12 |
+
发行时间:2005.11
|
13 |
+
### 麦芽糖
|
14 |
+
发行时间:2005.11
|
15 |
+
### 夜曲
|
16 |
+
发行时间:2005.11
|
17 |
+
### 发如雪
|
18 |
+
发行时间:2005.11
|
19 |
+
### 蓝色风暴
|
20 |
+
发行时间:2005.11
|
21 |
+
### 珊瑚海
|
22 |
+
发行时间:2005.11
|
23 |
+
### 四面楚歌
|
24 |
+
发行时间:2005.11
|
25 |
+
### 逆鳞
|
26 |
+
发行时间:2005.11
|
27 |
+
## 叶惠美
|
28 |
+
### 东风破
|
29 |
+
发行时间:2003.7
|
30 |
+
### 三年二班
|
31 |
+
发行时间:2003.7
|
32 |
+
### 晴天
|
33 |
+
发行时间:2003.7
|
34 |
+
### 你听得到
|
35 |
+
发行时间:2003.7
|
36 |
+
### 同一种调调
|
37 |
+
发行时间:2003.7
|
38 |
+
### 她的睫毛
|
39 |
+
发行时间:2003.7
|
40 |
+
### 以父之名
|
41 |
+
发行时间:2003.7
|
42 |
+
### 爱情悬崖
|
43 |
+
发行时间:2003.7
|
44 |
+
### 懦夫
|
45 |
+
发行时间:2003.7
|
46 |
+
### 梯田
|
47 |
+
发行时间:2003.7
|
48 |
+
### 双刀
|
49 |
+
发行时间:2003.7
|
50 |
+
## 魔杰座
|
51 |
+
### 给我一首歌的时间
|
52 |
+
发行时间:2008.10
|
53 |
+
### 乔克叔叔
|
54 |
+
发行时间:2008.10
|
55 |
+
### 时光机
|
56 |
+
发行时间:2008.10
|
57 |
+
### 说好的幸福呢
|
58 |
+
发行时间:2008.10
|
59 |
+
### 稻香
|
60 |
+
发行时间:2008.10
|
61 |
+
### 龙战骑士
|
62 |
+
发行时间:2008.10
|
63 |
+
### 花海
|
64 |
+
发行时间:2008.10
|
65 |
+
### 蛇舞
|
66 |
+
发行时间:2008.10
|
67 |
+
### 兰亭序
|
68 |
+
发行时间:2008.10
|
69 |
+
### 流浪诗人
|
70 |
+
发行时间:2008.10
|
71 |
+
### 魔术先生
|
72 |
+
发行时间:2008.10
|
73 |
+
## 八度空间
|
74 |
+
### 火车叨位去
|
75 |
+
发行时间:2002.7
|
76 |
+
### 回到过去
|
77 |
+
发行时间:2002.7
|
78 |
+
### 爷爷泡的茶
|
79 |
+
发行时间:2002.7
|
80 |
+
### 半兽人
|
81 |
+
发行时间:2002.7
|
82 |
+
### 暗号
|
83 |
+
发行时间:2002.7
|
84 |
+
### 最后的战役
|
85 |
+
发行时间:2002.7
|
86 |
+
### 龙拳
|
87 |
+
发行时间:2002.7
|
88 |
+
### 米兰的小铁匠
|
89 |
+
发行时间:2002.7
|
90 |
+
### 分裂
|
91 |
+
发行时间:2002.7
|
92 |
+
### 半岛铁盒
|
93 |
+
发行时间:2002.7
|
94 |
+
## 十二新作
|
95 |
+
### 哪里都是你
|
96 |
+
发行时间:2012.12
|
97 |
+
### 公公偏头痛
|
98 |
+
发行时间:2012.12
|
99 |
+
### 手语
|
100 |
+
发行时间:2012.12
|
101 |
+
### 乌克丽丽
|
102 |
+
发行时间:2012.12
|
103 |
+
### 红尘客栈
|
104 |
+
发行时间:2012.12
|
105 |
+
### 梦想启动
|
106 |
+
发行时间:2012.12
|
107 |
+
### 四季列车
|
108 |
+
发行时间:2012.12
|
109 |
+
### 大笨钟
|
110 |
+
发行时间:2012.12
|
111 |
+
### 爱你没差
|
112 |
+
发行时间:2012.12
|
113 |
+
### 傻笑
|
114 |
+
发行时间:2012.12
|
115 |
+
### 明明就
|
116 |
+
发行时间:2012.12
|
117 |
+
### 比较大的大提琴
|
118 |
+
发行时间:2012.12
|
119 |
+
## 七里香
|
120 |
+
### 七里香
|
121 |
+
发行时间:2004.8
|
122 |
+
### 外婆
|
123 |
+
发行时间:2004.8
|
124 |
+
### 困兽之斗
|
125 |
+
发行时间:2004.8
|
126 |
+
### 我的地盘
|
127 |
+
发行时间:2004.8
|
128 |
+
### 借口
|
129 |
+
发行时间:2004.8
|
130 |
+
### 园游会
|
131 |
+
发行时间:2004.8
|
132 |
+
### 止战之殇
|
133 |
+
发行时间:2004.8
|
134 |
+
### 乱舞春秋
|
135 |
+
发行时间:2004.8
|
136 |
+
### 将军
|
137 |
+
发行时间:2004.8
|
138 |
+
### 搁浅
|
139 |
+
发行时间:2004.8
|
140 |
+
## 惊叹号
|
141 |
+
### 惊叹号
|
142 |
+
发行时间:2011.11
|
143 |
+
### 公主病
|
144 |
+
发行时间:2011.11
|
145 |
+
### 琴伤
|
146 |
+
发行时间:2011.11
|
147 |
+
### 周杰伦
|
148 |
+
发行时间:2011.11
|
149 |
+
### 皮影戏
|
150 |
+
发行时间:2011.11
|
151 |
+
### 超跑女神
|
152 |
+
发行时间:2011.11
|
153 |
+
### 水手怕水
|
154 |
+
发行时间:2011.11
|
155 |
+
### 世界未末日
|
156 |
+
发行时间:2011.11
|
157 |
+
### 迷魂曲
|
158 |
+
发行时间:2011.11
|
159 |
+
### 你好吗
|
160 |
+
发行时间:2011.11
|
161 |
+
### 疗伤烧肉粽
|
162 |
+
发行时间:2011.11
|
163 |
+
## 跨时代
|
164 |
+
### 免费教学录影带
|
165 |
+
发行时间:2010.5
|
166 |
+
### 爱的飞行日记
|
167 |
+
发行时间:2010.5
|
168 |
+
### 嘻哈空姐
|
169 |
+
发行时间:2010.5
|
170 |
+
### 超人不会飞
|
171 |
+
发行时间:2010.5
|
172 |
+
### 自导自演
|
173 |
+
发行时间:2010.5
|
174 |
+
### 跨时代
|
175 |
+
发行时间:2010.5
|
176 |
+
### 说了再见
|
177 |
+
发行时间:2010.5
|
178 |
+
### 烟花易冷
|
179 |
+
发行时间:2010.5
|
180 |
+
### 我落泪
|
181 |
+
发行时间:2010.5
|
182 |
+
### 雨下一整晚
|
183 |
+
发行时间:2010.5
|
184 |
+
### 好久不见
|
185 |
+
发行时间:2010.5
|
186 |
+
## 我很忙
|
187 |
+
### 最长的电影
|
188 |
+
发行时间:2007.11
|
189 |
+
### 甜甜的
|
190 |
+
发行时间:2007.11
|
191 |
+
### 青花瓷
|
192 |
+
发行时间:2007.11
|
193 |
+
### 我不配
|
194 |
+
发行时间:2007.11
|
195 |
+
### 牛仔很忙
|
196 |
+
发行时间:2007.11
|
197 |
+
### 无双
|
198 |
+
发行时间:2007.11
|
199 |
+
### 彩虹
|
200 |
+
发行时间:2007.11
|
201 |
+
### 阳光宅男
|
202 |
+
发行时间:2007.11
|
203 |
+
### 蒲公英的约定
|
204 |
+
发行时间:2007.11
|
205 |
+
### 扯
|
206 |
+
发行时间:2007.11
|
207 |
+
## 依然范特西
|
208 |
+
### 听妈妈的话
|
209 |
+
发行时间:2006.9
|
210 |
+
### 菊花台
|
211 |
+
发行时间:2006.9
|
212 |
+
### 退后
|
213 |
+
发行时间:2006.9
|
214 |
+
### 本草纲目
|
215 |
+
发行时间:2006.9
|
216 |
+
### 夜的第七章
|
217 |
+
发行时间:2006.9
|
218 |
+
### 迷迭香
|
219 |
+
发行时间:2006.9
|
220 |
+
### 千里之外
|
221 |
+
发行时间:2006.9
|
222 |
+
### 心雨
|
223 |
+
发行时间:2006.9
|
224 |
+
### 红模仿
|
225 |
+
发行时间:2006.9
|
226 |
+
### 白色风车
|
227 |
+
发行时间:2006.9
|
228 |
+
## Jay
|
229 |
+
### 反方向的钟
|
230 |
+
发行时间:2000.11
|
231 |
+
### 印地安老斑鸠
|
232 |
+
发行时间:2000.11
|
233 |
+
### 完美主义
|
234 |
+
发行时间:2000.11
|
235 |
+
### 黑色幽默
|
236 |
+
发行时间:2000.11
|
237 |
+
### 可爱女人
|
238 |
+
发行时间:2000.11
|
239 |
+
### 伊斯坦堡
|
240 |
+
发行时间:2000.11
|
241 |
+
### 斗牛
|
242 |
+
发行时间:2000.11
|
243 |
+
### 娘子
|
244 |
+
发行时间:2000.11
|
245 |
+
### 龙卷风
|
246 |
+
发行时间:2000.11
|
247 |
+
### 星晴
|
248 |
+
发行时间:2000.11
|
249 |
+
## 非专辑
|
250 |
+
## 范特西
|
251 |
+
### 简单爱
|
252 |
+
发行时间:2001.9
|
253 |
+
### 对不起
|
254 |
+
发行时间:2001.9
|
255 |
+
### 忍者
|
256 |
+
发行时间:2001.9
|
257 |
+
### 上海一九四三
|
258 |
+
发行时间:2001.9
|
259 |
+
### 开不了口
|
260 |
+
发行时间:2001.9
|
261 |
+
### 周杰伦
|
262 |
+
发行时间:2001.9
|
263 |
+
### 威廉古堡
|
264 |
+
发行时间:2001.9
|
265 |
+
### 爱在西元前
|
266 |
+
发行时间:2001.9
|
267 |
+
### 双截棍
|
268 |
+
发行时间:2001.9
|
269 |
+
### 安静
|
270 |
+
发行时间:2001.9
|
271 |
+
## 哎呦,不錯哦
|
272 |
+
### 阳明山
|
273 |
+
发行时间:2014.12
|
274 |
+
### 窃爱
|
275 |
+
发行时间:2014.12
|
276 |
+
### 算什么男人
|
277 |
+
发行时间:2014.12
|
278 |
+
### 天涯过客
|
279 |
+
发行时间:2014.12
|
280 |
+
### 怎么了
|
281 |
+
发行时间:2014.12
|
282 |
+
### 一口气全念对
|
283 |
+
发行时间:2014.12
|
284 |
+
### 我要夏天
|
285 |
+
发行时间:2014.12
|
286 |
+
### 手写的从前
|
287 |
+
发行时间:2014.12
|
288 |
+
### 鞋子特大号
|
289 |
+
发行时间:2014.12
|
290 |
+
### 听爸爸的话
|
291 |
+
发行时间:2014.12
|
292 |
+
### 美人鱼
|
293 |
+
发行时间:2014.12
|
294 |
+
### 听见下雨的声音
|
295 |
+
发行时间:2014.12
|
296 |
+
## 周杰伦的床边故事
|
297 |
+
### 床边故事
|
298 |
+
发行时间:2016.6
|
299 |
+
### 说走就走
|
300 |
+
发行时间:2016.6
|
301 |
+
### 一点点
|
302 |
+
发行时间:2016.6
|
303 |
+
### 前世情人
|
304 |
+
发行时间:2016.6
|
305 |
+
### 英雄
|
306 |
+
发行时间:2016.6
|
307 |
+
### 不该
|
308 |
+
发行时间:2016.6
|
309 |
+
### 土耳其冰淇淋
|
310 |
+
发行时间:2016.6
|
311 |
+
### 告白气球
|
312 |
+
发行时间:2016.6
|
313 |
+
### Now You See Me
|
314 |
+
发行时间:2016.6
|
315 |
+
### 爱情废柴
|
316 |
+
发行时间:2016.6
|
317 |
+
## 最伟大的作品
|
318 |
+
### Intro
|
319 |
+
发行时间:2022.7
|
320 |
+
### 最伟大的作品
|
321 |
+
发行时间:2022.7
|
322 |
+
### 还在流浪
|
323 |
+
发行时间:2022.7
|
324 |
+
### 说好不哭
|
325 |
+
发行时间:2022.7
|
326 |
+
### 红颜如霜
|
327 |
+
发行时间:2022.7
|
328 |
+
### 不爱我就拉倒
|
329 |
+
发行时间:2022.7
|
330 |
+
### Mojito
|
331 |
+
发行时间:2022.7
|
332 |
+
### 错过的烟火
|
333 |
+
发行时间:2022.7
|
334 |
+
### 等你下课
|
335 |
+
发行时间:2022.7
|
336 |
+
### 粉色海洋
|
337 |
+
发行时间:2022.7
|
338 |
+
### 倒影
|
339 |
+
发行时间:2022.7
|
340 |
+
### 我是如此相信
|
341 |
+
发行时间:2022.7
|
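When LocalRAG ingests this file, MarkdownHeaderTextSplitter turns each header section into its own chunk and records the header path in metadata (which the app then prepends via concate_metadata); a small self-contained sketch:

```python
from langchain.text_splitter import MarkdownHeaderTextSplitter

headers_to_split_on = [("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")]
splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)

sample = "# 周杰伦全部歌曲\n## 十一月的萧邦\n### 夜曲\n发行时间:2005.11\n"
chunks = splitter.split_text(sample)
# chunks[0].page_content -> "发行时间:2005.11"
# chunks[0].metadata     -> {"Header 1": "周杰伦全部歌曲", "Header 2": "十一月的萧邦", "Header 3": "夜曲"}
```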
rag/kb/周杰伦/周杰伦全部歌词.md
ADDED
The diff for this file is too large to render.
See raw diff
|
|
rag/kb/周杰伦/周杰伦基本资料.md
ADDED
@@ -0,0 +1,153 @@
1 |
+
# 周杰伦
|
2 |
+
周杰伦(1979年1月18日—),台湾创作男歌手、演员、词曲作家及制作人,被誉为“中文流行天王”,对华语乐坛及全球华人社群具有深远影响。
|
3 |
+
|
4 |
+
## 基本信息
|
5 |
+
| 项目 | 内容 |
|
6 |
+
|--------------|----------------------------------------------------------------------|
|
7 |
+
| **英文名** | Jay Chou |
|
8 |
+
| **昵称** | 周董、小公举[[1]] |
|
9 |
+
| **出生** | 1979年1月18日(46岁),台湾台北县林口乡(现新北市林口区)[[2]] |
|
10 |
+
| **居住地** | 台湾台北市大安区 |
|
11 |
+
| **职业** | 歌手、词曲作家、音乐制作人、演员、导演、企业家 |
|
12 |
+
| **语言** | 国语、台语、英语 |
|
13 |
+
| **母校** | 淡江中学 |
|
14 |
+
| **宗教信仰** | 基督新教[[3]] |
|
15 |
+
| **配偶** | 昆凌(2015年结婚)[[4]] |
|
16 |
+
| **儿女** | 2女1子 |
|
17 |
+
| **出道日期** | 2000年11月6日,出道作品《杰伦》 |
|
18 |
+
|
19 |
+
## 早年经历
|
20 |
+
1. **家庭背景**
|
21 |
+
- 父亲周耀中为生物教师,母亲叶惠美为美术教师,14岁父母离异后随母亲生活[[14][15][16][17]]。
|
22 |
+
- 澄清《爸,我回来了》并非指涉父母家暴,而是社会现象的感慨[[18]]。
|
23 |
+
|
24 |
+
2. **教育与成长**
|
25 |
+
- 3岁学钢琴,国中热爱篮球,高中就读淡江中学音乐科,奠定音乐基础[[19][20][21]]。
|
26 |
+
- 大学联考落榜,因僵直性脊椎炎免服兵役[[22][23]]。
|
27 |
+
|
28 |
+
3. **音乐启蒙**
|
29 |
+
- 受张学友、肖邦、李恕权、史帝夫·汪达影响,创作风格融合古典与流行[[24][25]]。
|
30 |
+
|
31 |
+
## 音乐事业
|
32 |
+
### 重要阶段
|
33 |
+
#### 1997年—2001年:起步与突破
|
34 |
+
- 1997年通过《超级新人王》被吴宗宪发掘,开启作曲生涯。
|
35 |
+
- 1998年发表首支创作歌曲《三暝三日》(吴宗宪演唱)。
|
36 |
+
- 1999年与方文山合作《落雨声》(江蕙演唱),奠定合作基础[[21]]。
|
37 |
+
|
38 |
+
#### 2000年—2003年:奠定地位
|
39 |
+
- 2000年首张专辑《Jay》融合R&B、Hip-Hop与中国风,开创“周式曲风”[[21]]。
|
40 |
+
- 2001年《范特西》获金曲奖5项大奖,成为首位获“最佳专辑制作人”的新人[[26][27]]。
|
41 |
+
- 2003年《叶惠美》主打歌《以父之名》引发“周杰伦日”,全球8亿人同步收听[[28]]。
|
42 |
+
|
43 |
+
#### 2004年—2007年:国际影响力
|
44 |
+
- 2004年《七里香》获世界音乐大奖“大中华区最畅销艺人”[[29]]。
|
45 |
+
- 2005年《十一月的肖邦》尝试MV执导,2007年成立杰威尔音乐[[30]]。
|
46 |
+
|
47 |
+
#### 2008年—2012年:跨界与巡演
|
48 |
+
- 2008年《魔杰座》在日本武道馆开唱,成为第二位华人歌手[[31]]。
|
49 |
+
- 2011年与科比合作雪碧广告曲《天地一斗》,进军国际[[33]]。
|
50 |
+
|
51 |
+
#### 2014年—2019年:持续创新
|
52 |
+
- 2016年《周杰伦的床边故事》主打歌《告白气球》创公告牌纪录[[34]]。
|
53 |
+
- 2019年《说好不哭》与五月天阿信合唱,YouTube播放量破千万[[35]]。
|
54 |
+
|
55 |
+
#### 2020年至今:新专辑与合作
|
56 |
+
- 2022年《最伟大的作品》获IFPI全球最畅销专辑冠军[[36][37]]。
|
57 |
+
- 2023年《圣诞星》MV 13小时破140万次播放[[41]]。
|
58 |
+
|
59 |
+
## 电影事业
|
60 |
+
### 代表作品
|
61 |
+
1. **《头文字D》(2005年)**
|
62 |
+
- 首登大银幕,获香港金像奖、金马奖双料最佳新演员[[48][49]]。
|
63 |
+
|
64 |
+
2. **《满城尽带黄金甲》(2006年)**
|
65 |
+
- 饰演二王子元杰,提名香港金像奖最佳男配角[[56][63]]。
|
66 |
+
|
67 |
+
3. **《不能说的·秘密》(2007年)**
|
68 |
+
- 自编自导自演,融合音乐创作,获韩国首映好评[[68]]。
|
69 |
+
|
70 |
+
4. **《青蜂侠》(2011年)**
|
71 |
+
- 好莱坞处女作,演唱片尾曲《双截棍》[[65][67]]。
|
72 |
+
|
73 |
+
5. **《天台爱情》(2013年)**
|
74 |
+
- 第二部自导自演电影,获张艺谋赞赏[[68]]。
|
75 |
+
|
76 |
+
## 副业
|
77 |
+
### 潮流事业
|
78 |
+
- 2006年创立潮流品牌**PHANTACi**,2016年推出电商平台**J Concept星品库**[[70][72]]。
|
79 |
+
|
80 |
+
### 电子竞技
|
81 |
+
- 2016年收购台北暗杀星战队,更名**J Team**,并开设“魔杰电竞馆”[[73]]。
|
82 |
+
|
83 |
+
### 餐饮与投资
|
84 |
+
- 开设“Mr.J藤原豆腐店”“J POT HOTPOT火锅料理”等餐饮品牌[[74]]。
|
85 |
+
- 入股中国数码文化集团,参与耳机品牌**1MORE万魔**[[75]]。
|
86 |
+
|
87 |
+
## 个人生活
|
88 |
+
### 婚姻与家庭
|
89 |
+
- 2015年与昆凌结婚,育有两女一子,三次婚礼分别在英国、台湾、澳大利亚举办[[83][84][85]]。
|
90 |
+
- 夫妻均为基督徒,婚礼于英国塞尔比修道院举行[[89][92]]。
|
91 |
+
|
92 |
+
### 绯闻与争议
|
93 |
+
- 早年与蔡依林、侯佩岑等传绯闻,均否认[[75][76][77]]。
|
94 |
+
- 与昆凌恋情于2014年公开,称其“善良、贴心”[[78][80]]。
|
95 |
+
|
96 |
+
## 荣誉与奖项
|
97 |
+
### 重要荣誉
|
98 |
+
- 2003年《时代》杂志亚洲版封面人物[[21]]。
|
99 |
+
- 2009年CNN“亚洲最具影响力25人”[[12]]。
|
100 |
+
- 2010年小行星**257248**命名为“周杰伦星”[[38]]。
|
101 |
+
- 2023年《最伟大的作品》全球专辑销量冠军[[108]]。
|
102 |
+
|
103 |
+
### 主要奖项
|
104 |
+
- **金曲奖**:15座(含最佳专辑、最佳作曲人等)[[10][111][112]]。
|
105 |
+
- **世界音乐大奖**:4次“大中华区最畅销艺人”[[29][32]]。
|
106 |
+
- **香港金像奖**:最佳新演员(《头文字D》)[[48]]。
|
107 |
+
|
108 |
+
## 作品列表
|
109 |
+
### 音乐专辑
|
110 |
+
《Jay》(2000年)
|
111 |
+
《范特西》(2001年)
|
112 |
+
《八度空间》(2002年)
|
113 |
+
《叶惠美》(2003年)
|
114 |
+
《七里香》(2004年)
|
115 |
+
《十一月的肖邦》(2005年)
|
116 |
+
《依然范特西》(2006年)
|
117 |
+
《我很忙》(2007年)
|
118 |
+
《魔杰座》(2008年)
|
119 |
+
《跨时代》(2010年)
|
120 |
+
《惊叹号》(2011年)
|
121 |
+
《十二新作》(2012年)
|
122 |
+
《哎呦,不错哦》(2014年)
|
123 |
+
《周杰伦的床边故事》(2016年)
|
124 |
+
《最伟大的作品》(2022年)
|
125 |
+
|
126 |
+
### 书籍作品
|
127 |
+
| 出版日期 | 书名 | 作者 | 出版社 |
|
128 |
+
|------------|---------------------------------------|------------|----------------------|
|
129 |
+
| 2002年 | 《半岛铁盒》 | 方文山 | 华人版图出版社 |
|
130 |
+
| 2004年 | 《D调的华丽》 | 周杰伦 | 华人版图出版社 |
|
131 |
+
| 2007年 | 《不能说的秘密电影创作琴谱》 | 周杰伦 | 华人版图出版社 |
|
132 |
+
|
133 |
+
### 演出作品(电影)
|
134 |
+
| 年份 | 片名 | 角色 | 备注 |
|
135 |
+
|------|----------------------|------------|--------------------------------|
|
136 |
+
| 2005 | 《头文字D》 | 藤原拓海 | 首部主演电影,获双料新人奖 |
|
137 |
+
| 2006 | 《满城尽带黄金甲》 | 元杰 | 国际发行,提名香港金像奖 |
|
138 |
+
| 2007 | 《不能说的·秘密》 | 叶湘伦 | 自编自导自演 |
|
139 |
+
| 2011 | 《青蜂侠》 | 加藤(Kato)| 好莱坞处女作,演唱片尾曲 |
|
140 |
+
|
141 |
+
## 注释与参考资料
|
142 |
+
1. 注1:现今的新北市林口区。
|
143 |
+
2. 注2:淡江中学音乐科为新设科系,因错过华冈艺校报名而就读[[21]]。
|
144 |
+
|
145 |
+
[1] 周杰倫綽號大揭密 被叫「小公舉」是因.... 自由时报. 2016-01-20.
|
146 |
+
[2] 藝人專區 周杰倫 JAY CHOU. 杰威尔音乐有限公司. 2024-12-14.
|
147 |
+
[3] 周杰倫昆凌喜迎小王子 牧師祝福:「多子多孫,生養眾多!」. 基督日报. 2017-06-23.
|
148 |
+
|
149 |
+
|
150 |
+
## 外部链接
|
151 |
+
- [杰威尔音乐官方网站](https://www.jvrmusic.com)
|
152 |
+
- [周杰伦YouTube频道](https://www.youtube.com/user/JayChtV)
|
153 |
+
- [周杰伦Instagram](https://www.instagram.com/jaychou)
|
tts_api.py
ADDED
@@ -0,0 +1,274 @@
1 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
2 |
+
import torch
|
3 |
+
import soundfile as sf
|
4 |
+
from xcodec2.modeling_xcodec2 import XCodec2Model
|
5 |
+
import numpy as np
|
6 |
+
import ChatTTS
|
7 |
+
import re
|
8 |
+
DEFAULT_TTS_MODEL_NAME = "HKUSTAudio/LLasa-1B"
|
9 |
+
DEMO_EXAMPLES = [
|
10 |
+
["太乙真人.wav", "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"],
|
11 |
+
["邓紫棋.wav", "特别大的不同,因为以前在香港是过年的时候,我们可能见到的亲戚都是爸爸那边的亲戚"],
|
12 |
+
["雷军.wav", "这是个好问题,我把来龙去脉给你简单讲,就是这个社会对小米有很多的误解,有很多的误解,呃,也能理解啊,就是小米这个模式呢"],
|
13 |
+
["周杰伦.wav", "但如果你这兴趣可以得到很大的回响,那会更开心"],
|
14 |
+
["Taylor Swift.wav", "It's actually uh, it's a concept record, but it's my first directly autobiographical album in a while because the last album that I put out was, uh, a rework."]
|
15 |
+
]
|
16 |
+
class TTSapi:
|
17 |
+
def __init__(self,
|
18 |
+
model_name=DEFAULT_TTS_MODEL_NAME,
|
19 |
+
codec_model_name="HKUST-Audio/xcodec2",
|
20 |
+
device=torch.device("cuda:0")):
|
21 |
+
|
22 |
+
self.reload(model_name, codec_model_name, device)
|
23 |
+
|
24 |
+
def reload(self,
|
25 |
+
model_name=DEFAULT_TTS_MODEL_NAME,
|
26 |
+
codec_model_name="HKUST-Audio/xcodec2",
|
27 |
+
device=torch.device("cuda:0")):
|
28 |
+
if 'llasa' in model_name.lower():
|
29 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
30 |
+
self.model = AutoModelForCausalLM.from_pretrained(model_name)
|
31 |
+
self.model.eval().to(device)
|
32 |
+
|
33 |
+
self.codec_model = XCodec2Model.from_pretrained(codec_model_name)
|
34 |
+
self.codec_model.eval().to(device)
|
35 |
+
self.device = device
|
36 |
+
self.codec_model_name = codec_model_name
|
37 |
+
self.sr = 16000
|
38 |
+
elif 'chattts' in model_name.lower():
|
39 |
+
self.model = ChatTTS.Chat()
|
40 |
+
self.model.load(compile=False) # Set to True for better performance, but it would significantly reduce the loading speed
|
41 |
+
self.sr = 24000
|
42 |
+
self.punctuation = r'[,,.。??!!~~;;]'
|
43 |
+
else:
|
44 |
+
raise ValueError(f'不支持的TTS模型:{model_name}')
|
45 |
+
|
46 |
+
self.model_name = model_name
|
47 |
+
|
48 |
+
def ids_to_speech_tokens(self, speech_ids):
|
49 |
+
speech_tokens_str = []
|
50 |
+
for speech_id in speech_ids:
|
51 |
+
speech_tokens_str.append(f"<|s_{speech_id}|>")
|
52 |
+
return speech_tokens_str
|
53 |
+
|
54 |
+
def extract_speech_ids(self, speech_tokens_str):
|
55 |
+
speech_ids = []
|
56 |
+
for token_str in speech_tokens_str:
|
57 |
+
if token_str.startswith('<|s_') and token_str.endswith('|>'):
|
58 |
+
num_str = token_str[4:-2]
|
59 |
+
|
60 |
+
num = int(num_str)
|
61 |
+
speech_ids.append(num)
|
62 |
+
else:
|
63 |
+
print(f"Unexpected token: {token_str}")
|
64 |
+
return speech_ids
|
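A tiny round trip of the two helpers above (ids made up): codec ids are wrapped as <|s_N|> strings so the LLM can emit them as ordinary tokens, then parsed back to ints before the codec decodes them.

```python
# Round trip of ids_to_speech_tokens / extract_speech_ids, with made-up codec ids
ids = [12, 345, 6789]
tokens = [f"<|s_{i}|>" for i in ids]      # what ids_to_speech_tokens produces
parsed = [int(t[4:-2]) for t in tokens]   # what extract_speech_ids recovers
assert parsed == ids
```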
65 |
+
|
66 |
+
|
67 |
+
def forward(self, input_text, speech_prompt=None, save_path='wavs/generated/gen.wav'):
|
68 |
+
#TTS start!
|
69 |
+
with torch.no_grad():
|
70 |
+
if 'chattts' in self.model_name.lower():
|
71 |
+
# rand_spk = chat.sample_random_speaker()
|
72 |
+
# print(rand_spk) # save it for later timbre recovery
|
73 |
+
|
74 |
+
# params_infer_code = ChatTTS.Chat.InferCodeParams(
|
75 |
+
# spk_emb = rand_spk, # add sampled speaker
|
76 |
+
# temperature = .3, # using custom temperature
|
77 |
+
# top_P = 0.7, # top P decode
|
78 |
+
# top_K = 20, # top K decode
|
79 |
+
# )
|
80 |
+
break_num = max(min(len(re.split(self.punctuation, input_text)), 7), 2)
|
81 |
+
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
82 |
+
prompt=f'[oral_2][laugh_0][break_{break_num}]',
|
83 |
+
)
|
84 |
+
wavs = self.model.infer([input_text],
|
85 |
+
params_refine_text=params_refine_text,
|
86 |
+
)
|
87 |
+
gen_wav_save = wavs[0]
|
88 |
+
sf.write(save_path, gen_wav_save, 24000)
|
89 |
+
|
90 |
+
else:
|
91 |
+
if speech_prompt:
|
92 |
+
# only 16khz speech support!
|
93 |
+
prompt_wav, sr = sf.read(speech_prompt) # you can find wav in Files
|
94 |
+
prompt_wav = torch.from_numpy(prompt_wav).float().unsqueeze(0)
|
95 |
+
|
96 |
+
# Encode the prompt wav
|
97 |
+
vq_code_prompt = self.codec_model.encode_code(input_waveform=prompt_wav)
|
98 |
+
print("Prompt Vq Code Shape:", vq_code_prompt.shape )
|
99 |
+
|
100 |
+
vq_code_prompt = vq_code_prompt[0,0,:]
|
101 |
+
# Convert int 12345 to token <|s_12345|>
|
102 |
+
speech_ids_prefix = self.ids_to_speech_tokens(vq_code_prompt)
|
103 |
+
else:
|
104 |
+
speech_ids_prefix = ''
|
105 |
+
formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
|
106 |
+
|
107 |
+
# Tokenize the text ( and the speech prefix)
|
108 |
+
chat = [
|
109 |
+
{"role": "user", "content": "Convert the text to speech:" + formatted_text},
|
110 |
+
{"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
|
111 |
+
]
|
112 |
+
|
113 |
+
input_ids = self.tokenizer.apply_chat_template(
|
114 |
+
chat,
|
115 |
+
tokenize=True,
|
116 |
+
return_tensors='pt',
|
117 |
+
continue_final_message=True
|
118 |
+
)
|
119 |
+
input_ids = input_ids.to(self.device)
|
120 |
+
speech_end_id = self.tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
|
121 |
+
|
122 |
+
# Generate the speech autoregressively
|
123 |
+
outputs = self.model.generate(
|
124 |
+
input_ids,
|
125 |
+
max_length=2048, # We trained our model with a max length of 2048
|
126 |
+
eos_token_id= speech_end_id ,
|
127 |
+
do_sample=True,
|
128 |
+
top_p=1, # Adjusts the diversity of generated content
|
129 |
+
temperature=1, # Controls randomness in output
|
130 |
+
)
|
131 |
+
# Extract the speech tokens
|
132 |
+
generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
|
133 |
+
|
134 |
+
speech_tokens = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
135 |
+
|
136 |
+
# Convert token <|s_23456|> to int 23456
|
137 |
+
speech_tokens = self.extract_speech_ids(speech_tokens)
|
138 |
+
|
139 |
+
speech_tokens = torch.tensor(speech_tokens).to(self.device).unsqueeze(0).unsqueeze(0)
|
140 |
+
|
141 |
+
# Decode the speech tokens to speech waveform
|
142 |
+
gen_wav = self.codec_model.decode_code(speech_tokens)
|
143 |
+
|
144 |
+
# if only need the generated part
|
145 |
+
if speech_prompt:
|
146 |
+
gen_wav = gen_wav[:,:,prompt_wav.shape[1]:]
|
147 |
+
|
148 |
+
gen_wav_save = gen_wav[0, 0, :].cpu().numpy()
|
149 |
+
sf.write(save_path, gen_wav_save, 16000)
|
150 |
+
# gen_wav_save = np.clip(gen_wav_save, -1, 1)
|
151 |
+
# gen_wav_save = (gen_wav_save * 32767).astype(np.int16)
|
152 |
+
return gen_wav_save
|
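A minimal usage sketch of forward() with a reference clip, mirroring how app.py prepends the reference transcript before synthesis; the file path is a placeholder and the transcript is taken from DEMO_EXAMPLES above:

```python
# Usage sketch (placeholder path; the reference clip must be 16 kHz for the LLasa path)
tts = TTSapi(model_name="HKUSTAudio/LLasa-1B")
ref_audio = "wavs/ref/周杰伦.wav"
ref_text = "但如果你这兴趣可以得到很大的回响,那会更开心"
wav = tts.forward(ref_text + "大家好,我是周杰伦。", speech_prompt=ref_audio)
# wav is a float numpy array at tts.sr (16000 Hz); only the newly generated part is returned
```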


if __name__ == '__main__':
    # Llasa-8B shows better text understanding ability.

    # input_text = " He shouted, 'Everyone, please gather 'round! Here's the plan: 1) Set-up at 9:15 a.m.; 2) Lunch at 12:00 p.m. (please RSVP!); 3) Playing — e.g., games, music, etc. — from 1:15 to 4:45; and 4) Clean-up at 5 p.m.'"
    # prompt_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"
    # input_text = prompt_text + '嘻嘻,臭宝儿你真可爱,我好喜欢你呀。'
    # save_root = 'wavs/generated/'
    # save_path = save_root + 'test.wav'
    # speech_ref = 'wavs/ref/太乙真人.wav'
    # # speech_ref = None
    # # 帘外雨潺潺,春意阑珊。罗衾不耐五更寒。梦里不知身是客,一晌贪欢。独自莫凭栏,无限江山。别时容易见时难。流水落花春去也,天上人间。
    # llasa_tts = TTSapi()
    # gen = llasa_tts.forward(input_text, speech_prompt=speech_ref, save_path=save_path)
    # print(gen.shape)
    import gradio as gr
    synthesiser = TTSapi()
    TTS_LOADED = True

    def predict(config):
        global TTS_LOADED, synthesiser
        print(f"待合成文本:{config['msg']}")
        print(f"选中TTS模型:{config['tts_model']}")
        print(f"参考音频路径:{config['ref_audio']}")
        print(f"参考音频文本:{config['ref_audio_transcribe']}")
        text = config['msg']
        try:
            if len(text) == 0:
                audio_output = np.array([0], dtype=np.int16)
                print("输入为空,无法合成语音")
            else:
                if not TTS_LOADED:
                    print('TTS模型首次加载...')
                    gr.Info("初次加载TTS模型,请稍候..", duration=63)
                    synthesiser = TTSapi(model_name=config['tts_model'])  # , device="cuda:2"
                    TTS_LOADED = True
                    print('加载完毕...')
                # Check whether the currently loaded model matches the selected one
                if config['tts_model'] != synthesiser.model_name:
                    print(f'当前TTS模型{synthesiser.model_name}非所选,重新加载')
                    synthesiser.reload(model_name=config['tts_model'])

                # If a reference audio is provided, its transcript must be prepended to the text as a prefix
                if config['ref_audio']:
                    prompt_text = config['ref_audio_transcribe']
                    if prompt_text is None:
                        # prompt_text = ...
                        raise NotImplementedError('暂时必须提供文本')  # TODO: consider adding an ASR model later
                    text = prompt_text + text

                audio_output = synthesiser.forward(text, speech_prompt=config['ref_audio'])

        except Exception as e:
            print('!!!!!!!!')
            print(e)
            print('!!!!!!!!')
            # Fall back to a silent one-sample clip so the return below never hits an unset name
            audio_output = np.array([0], dtype=np.int16)

        return (synthesiser.sr if synthesiser else 16000, audio_output)

    with gr.Blocks(title="TTS Demo", theme=gr.themes.Soft(font=["sans-serif", "Arial"])) as demo:
        gr.Markdown("""
        # Personalized TTS Demo
        ## 使用步骤
        * 上传你想要合成的目标说话人的语音,10s左右即可,并在下面填入对应的文本。或直接点击下方示例
        * 输入你想要合成的文字,点击合成语音按钮,稍等片刻即可

        """)
        with gr.Row():
            with gr.Column():
                # TTS model selection
                tts_model = gr.Dropdown(
                    label="选择TTS模型",
                    choices=["ChatTTS", "HKUSTAudio/LLasa-1B", "HKUSTAudio/LLasa-3B", "HKUSTAudio/LLasa-8B"],
                    value=DEFAULT_TTS_MODEL_NAME,
                    interactive=True,
                    visible=False  # hidden for the product demo; model selection is not exposed for now
                )

                # Reference audio upload
                ref_audio = gr.Audio(
                    label="上传参考音频",
                    type="filepath",
                    interactive=True
                )
                ref_audio_transcribe = gr.Textbox(label="参考音频对应文本", visible=True)
                # Example data
                examples = gr.Examples(
                    examples=DEMO_EXAMPLES,
                    inputs=[ref_audio, ref_audio_transcribe],
                    fn=predict
                )

            with gr.Column():
                audio_player = gr.Audio(
                    label="听听我声音~",
                    type="numpy",
                    interactive=False
                )
                msg = gr.Textbox(label="输入文本", placeholder="请输入想要合成的文本")
                submit_btn = gr.Button("合成语音", variant="primary")

        current_config = gr.State({
            "msg": None,
            "tts_model": DEFAULT_TTS_MODEL_NAME,
            "ref_audio": None,
            "ref_audio_transcribe": None
        })
        gr.on(triggers=[msg.change, tts_model.change, ref_audio.change,
                        ref_audio_transcribe.change],
              fn=lambda text, model, audio, ref_text: {"msg": text, "tts_model": model, "ref_audio": audio,
                                                       "ref_audio_transcribe": ref_text},
              inputs=[msg, tts_model, ref_audio, ref_audio_transcribe],
              outputs=current_config
              )
        submit_btn.click(
            predict,
            [current_config],
            [audio_player],
            queue=False
        )

    demo.launch(share=False, server_name='0.0.0.0', server_port=7863, inbrowser=True)

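For reference, a minimal sketch of the input/output contract of the `predict` function above, assuming it is called in the same `__main__` context where it is defined. The config keys mirror the `gr.State` dictionary used by the demo; the reference wav path, transcript, and output filename are placeholders, and "HKUSTAudio/LLasa-1B" is simply one of the Dropdown choices:

    import soundfile as sf

    # Hypothetical smoke test; replace the 16 kHz reference wav and its transcript with real ones.
    config = {
        "msg": "你好,很高兴认识你。",
        "tts_model": "HKUSTAudio/LLasa-1B",
        "ref_audio": "wavs/ref/example_16k.wav",
        "ref_audio_transcribe": "参考音频对应的文本。",
    }
    sr, wav = predict(config)                    # predict returns (sample_rate, numpy waveform)
    sf.write("predict_smoke_test.wav", wav, sr)  # save the result for listening
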
utils.py
ADDED
@@ -0,0 +1,197 @@
import re
import requests
import chardet
import config as cfg
from bs4 import BeautifulSoup
from pathlib import Path
from transformers import AutoTokenizer
from duckduckgo_search import DDGS


def log_in(uid, state):
    state['chat_history'] = []
    state['thinking_history'] = ''
    state['uid'] = uid
    if uid != 0:
        response = f"Your Log In UID: {uid}"
    else:
        response = "You Are Not Logged In Yet, Use Public Directory"
    user_dir = Path(cfg.USER_DIR) / str(uid)
    user_dir.mkdir(parents=True, exist_ok=True)
    state['user_dir'] = user_dir

    # Load previously saved sessions
    state['available_history'] = []
    for json_file in user_dir.rglob("*.json"):
        state['available_history'].append(json_file.stem)

    return response, state

def clean_response(response_content):
    response_content = re.sub(r'\*\*|__', '', response_content)
    response_content = re.sub(r'\\\(|\\\)|\\\[|\\\]', '', response_content)
    response_content = re.sub(r'\\boxed\{([^}]*)\}', r'\1', response_content)
    response_content = re.sub(r'\\\\', '', response_content)
    response_content = re.sub(r'\n\s*\n', '\n\n', response_content)
    response_content = re.sub(r'\s+', ' ', response_content)
    return response_content.strip()

def parse_output(response_content):
    cleaned_content = clean_response(response_content)
    if "<think>" in cleaned_content and "</think>" in cleaned_content:
        split_pattern = r'<think>|</think>'
        parts = re.split(split_pattern, cleaned_content)
        return parts[1], parts[2]
    return None, cleaned_content


def parse_chat_history(chat_history):
    """Parse a saved chat history into the format the chatbot component expects.
    Example chat_history:
    [
        {
            "role": "user",
            "content": "hello"
        },
        {
            "role": "assistant",
            "content": " Hello! How can I assist you today? 😊"
        }
    ]
    Args:
        chat_history (list): saved messages, each a dict with "role" and "content" keys
    """
    from gradio import Warning
    try:
        assert len(chat_history) % 2 == 0
    except AssertionError:
        Warning('历史会话可能有遗失,用户和AI的消息数不匹配,截断最后一条消息...')
        chat_history = chat_history[:-1]

    if len(chat_history) == 0:
        Warning('历史会话为空或无法匹配,加载历史会话失败...')
        return []

    messages = []
    responses = []

    for conversation in chat_history:
        if conversation['role'] == 'user':
            messages.append(conversation['content'])
        elif conversation['role'] == 'assistant':
            responses.append(conversation['content'])

    if len(messages) != len(responses):
        Warning('用户和AI的消息无法匹配,加载历史会话失败...')
        return []

    return list(zip(messages, responses))


def web_search(query: str, max_results: int = 3):
    """Run a web search and extract the key content of each result."""
    try:
        # Fetch the search result links
        with DDGS() as ddgs:
            results = [r for r in ddgs.text(query, max_results=max_results)]

        # Extract the body text of each page
        web_contents = []
        for result in results:
            try:
                response = requests.get(result['href'], timeout=5)
                encoding = chardet.detect(response.content)['encoding']
                if response.encoding != encoding:
                    response.encoding = encoding
                soup = BeautifulSoup(response.text, 'html.parser')
                # Extract the main text content (adjust to the site structure if needed)
                main_content = soup.find('main') or soup.find('article') or soup.body
                web_contents.append({
                    'title': result['title'],
                    'content': main_content.get_text(separator=' ', strip=True)[:1000]  # limit length
                })
            except Exception as e:
                print('该结果搜索异常...', e)
                continue
        return web_contents
    except Exception as e:
        print('网络搜索异常,返回空...', e)
        return []


def parse_net_search(search_res):
    res = []
    for item in search_res:
        if len(item['content']) > 0:
            res.append(f"标题:\n{item['title']}\n内容:\n{item['content']}\n")
    return res


def wash_up_content(doc_score):
    """Split a recalled document into non-empty lines and prefix the first line with a recall marker.

    Args:
        doc_score (str or tuple): the full document text (possibly multi-line, newline-separated),
            optionally paired with its recall score as (doc, score)
    """
    if isinstance(doc_score, tuple):
        doc, score = doc_score
    else:
        doc = doc_score
        score = None
    res = list(filter(lambda x: len(x) > 0, doc.split('\n')))
    prefix = '✅<br>"' if score is None else '✅Recall with score:{:.3f}<br>'.format(score)
    res[0] = prefix + res[0]
    return res


MODEL_HF_MAPPING = {
    "qwen2.5:14b-instruct": "Qwen/Qwen2.5-14B-Instruct",
    "qwen2.5:32b-instruct": "Qwen/Qwen2.5-32B-Instruct",
    "qwen2.5:7b-instruct": "Qwen/Qwen2.5-7B-Instruct",
    "qwen2.5:3b-instruct": "Qwen/Qwen2.5-3B-Instruct",
    "qwen2.5:0.5b-instruct": "Qwen/Qwen2.5-0.5B-Instruct",
    "qwen2.5:0.5b": "Qwen/Qwen2.5-0.5B",
    "qwen2.5:32b": "Qwen/Qwen2.5-32B",
    "qwen3:32b": "Qwen/Qwen3-32B",
    "qwen3:14b": "Qwen/Qwen3-14B",
    "qwen3:4b": "Qwen/Qwen3-4B",
    "qwen3:8b": "Qwen/Qwen3-8B",
    "qwen3:30b-a3b": "Qwen/Qwen3-30B-A3B",
    "qwq": "Qwen/QwQ-32B",
    "deepseek-r1:14b": "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
    "deepseek-r1:7b": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    "deepseek-r1:32b": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
}

def load_tokenizer(model_name):
    hf_model_name = MODEL_HF_MAPPING.get(model_name, model_name)
    return AutoTokenizer.from_pretrained(hf_model_name, use_fast=True)

def messages_to_prompt(messages):
    # Concatenate in the format the model expects (a simple example; adjust to the actual model's chat template)
    prompt = ""
    for msg in messages:
        prompt += f"{msg['role']}: {msg['content']}\n"
    return prompt.strip()

def count_tokens_local(messages, tokenizer):
    prompt = messages_to_prompt(messages)
    return len(tokenizer(prompt, return_tensors=None, truncation=False)["input_ids"])


def concate_metadata(metadata):
    """Join every key-value pair of a Document object's metadata into one string.

    Args:
        metadata (dict): metadata of a Document object
    """
    return '\n'.join([f"{k}: {v}" for k, v in metadata.items()])

if __name__ == "__main__":
    query = "今天几号"
    ret = web_search(query)
    for item in ret:
        print(item)
    ret = parse_net_search(ret)
    print(ret)
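A quick usage sketch for the tokenizer helpers above, assuming the Hugging Face checkpoint can be downloaded; the messages are placeholders, and the count is approximate because messages_to_prompt is a plain role-prefixed concatenation rather than the model's real chat template:

    from utils import load_tokenizer, count_tokens_local

    tokenizer = load_tokenizer("qwen2.5:7b-instruct")  # resolved to Qwen/Qwen2.5-7B-Instruct via MODEL_HF_MAPPING
    messages = [
        {"role": "user", "content": "你好"},
        {"role": "assistant", "content": "你好!有什么可以帮你?"},
    ]
    print(count_tokens_local(messages, tokenizer))     # rough prompt length in tokens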