-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild_unified_db_voice.py
More file actions
67 lines (52 loc) · 2.16 KB
/
build_unified_db_voice.py
File metadata and controls
67 lines (52 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# build_unified_db_voice.py
import argparse
from pathlib import Path
from tqdm import tqdm
from mapping_reader import load_id_name_map
from voice_model import ECAPATDNNModel
from db_utils import load_db, save_db, index_teachers, ensure_teacher, finalize_db
# Extensions (lowercase, with leading dot) that collect_audio accepts by default.
AUD_EXTS = {".wav"}


def collect_audio(folder: Path, exts=None) -> "list[Path]":
    """Return all audio files under *folder*, searched recursively, sorted.

    Args:
        folder: Directory to scan. A plain ``str`` path is also accepted.
        exts: Optional iterable of extensions (with leading dot) to accept.
            Matching is case-insensitive. Defaults to ``AUD_EXTS``.

    Returns:
        A sorted list of ``Path`` objects for regular files whose suffix is in
        *exts*. Empty when *folder* is missing or is not a directory.
    """
    folder = Path(folder)
    if not folder.is_dir():
        return []
    # Normalise once so callers may pass ".WAV" or ".wav" interchangeably.
    wanted = {e.lower() for e in (AUD_EXTS if exts is None else exts)}
    return sorted(
        p for p in folder.rglob("*") if p.is_file() and p.suffix.lower() in wanted
    )
def main():
    """Build or update the unified teacher DB with voice embeddings.

    For each (id, name) pair loaded from ``--map``, audio files under
    ``<dataset>/<id>/audio`` are embedded with ``ECAPATDNNModel.embed_file``.
    The resulting vectors are stored as plain lists in
    ``teacher["voice_embeddings"]``. The DB read from ``--out`` is then
    finalized and written back to the same path.

    In ``replace`` mode, a teacher's existing voice embeddings are cleared
    before re-embedding. In ``append`` mode, new ones are added to the
    existing list. Teachers with fewer than ``--min_audios`` files are left
    unchanged.

    A file that fails to embed is skipped, as before. Each failure is now
    reported, and a total is printed at the end, instead of being silently
    ignored.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", default="dataset", help="dataset root")
    ap.add_argument("--map", default="Teachers.csv", help="Teachers.csv (id,name)")
    ap.add_argument("--out", default="db/teachers.json",
                    help="unified DB path (loaded, updated, then overwritten)")
    ap.add_argument("--mode", choices=["replace", "append"], default="replace",
                    help="replace: clear existing voice embeddings first; "
                         "append: add to them")
    ap.add_argument("--min_audios", type=int, default=1,
                    help="skip teachers with fewer audio files than this")
    ap.add_argument("--ffmpeg", default="ffmpeg", help="ffmpeg path or 'ffmpeg'")
    ap.add_argument("--device", default=None,
                    help="device passed through to ECAPATDNNModel")
    args = ap.parse_args()

    ds = Path(args.dataset)
    id_name = load_id_name_map(args.map)
    db = load_db(args.out)
    idx = index_teachers(db)
    model = ECAPATDNNModel(device=args.device, ffmpeg_path=args.ffmpeg)

    failures = 0
    for tid, name in tqdm(id_name.items(), desc="Teachers"):
        # Called before the min_audios check, so every mapped id goes through
        # ensure_teacher even if it ends up with no voice data.
        teacher = ensure_teacher(idx, tid, name)
        # str() guards against a mapping that yields non-string ids.
        audios = collect_audio(ds / str(tid) / "audio")
        if len(audios) < args.min_audios:
            continue

        if args.mode == "replace":
            teacher["voice_embeddings"] = []
        # In append mode, the key may be absent for a newly ensured teacher.
        embeddings = teacher.setdefault("voice_embeddings", [])
        for a in audios:
            try:
                vec = model.embed_file(str(a)).tolist()
            except Exception as exc:  # best-effort: skip unreadable files
                failures += 1
                # tqdm.write keeps the progress bar intact.
                tqdm.write(f"[warn] teacher {tid}: failed to embed {a}: {exc}")
                continue
            embeddings.append(vec)

        meta = teacher.setdefault("meta", {})
        meta["num_audios_used"] = len(embeddings)
        # Teachers with no face data yet should not raise KeyError here.
        meta["num_images_used"] = len(teacher.get("face_embeddings", []))

    db = finalize_db(db, idx)
    save_db(args.out, db)
    print(f"Saved unified DB (voice): {args.out} | teachers={len(db['teachers'])}")
    if failures:
        print(f"Warning: {failures} audio file(s) could not be embedded and were skipped")


if __name__ == "__main__":
    main()