@hdubey you can find the weights for RawNet2 here.
You can use this script for quick testing and for extracting embeddings. Make sure this model definition is in the directory from which you run the script.
```python
from tqdm import tqdm
from collections import OrderedDict
import os
import argparse
import json
import numpy as np
import glob
import pickle
import soundfile as sf  # used below to read .wav files
import torch
import torch.nn as nn
from torch.utils import data
from dataloader import *
from model_RawNet2 import RawNet2
from parser import get_args
from trainer import *
from utils import *
from model_RawNet2_original_code import *  # expected to define RawNet, used below
from pydub import AudioSegment

load_model_dir = "Pre-trained_model/rawnet2_best_weights.pt"
test_wav_path1 = "/root/host/ml-speaker-verification/data/vox1/vox1_test_wav/id10270/5r0dWxy17C8/00001.wav"
test_wav_path2 = "/root/host/ml-speaker-verification/data/vox1/vox1_test_wav/id10278/d6WJf6TOoIQ/00001.wav"
test_wav_path3 = "/root/host/ml-speaker-verification/data/vox2/vox2_test_m4a/id00017/01dfn2spqyE/00001.m4a"
test_wav_path4 = "/root/host/ml-speaker-verification/data/vox2/vox2_test_m4a/id00017/8_a6O3vdlU0/00021.m4a"


def cos_sim(a, b):
    # Cosine similarity between two 1-D embeddings.
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


def read_wav_and_get_clip_tensor(test_wav_path, nb_samp, window_size, wav_file=True):
    # Non-WAV inputs (e.g. .m4a) are decoded with pydub; .wav files with soundfile.
    if not wav_file:
        X = AudioSegment.from_file(test_wav_path)
        X = np.array(X.get_array_of_samples())
    else:
        X, _ = sf.read(test_wav_path)
    X = X.astype(np.float64)
    X = _normalize_scale(X).astype(np.float32)
    X = X.reshape(1, -1)

    nb_time = X.shape[1]
    list_X = []
    if nb_time < nb_samp:
        # Too short: repeat the signal until it covers nb_samp samples.
        nb_dup = int(nb_samp / nb_time) + 1
        list_X.append(np.tile(X, (1, nb_dup))[:, :nb_samp][0])
    elif nb_time > nb_samp:
        # Too long: cut clips of nb_samp samples that overlap by window_size samples.
        step = nb_samp - window_size
        iteration = int((nb_time - window_size) / step) + 1
        for i in range(iteration):
            if i == 0:
                list_X.append(X[:, :nb_samp][0])
            elif i < iteration - 1:
                list_X.append(X[:, i * step : i * step + nb_samp][0])
            else:
                # The last clip is aligned to the end of the signal.
                list_X.append(X[:, -nb_samp:][0])
    else:
        list_X.append(X[0])
    return torch.from_numpy(np.asarray(list_X))


def get_embedding_from_clip_tensor(clip_tensor, model, device):
    model.eval()
    with torch.no_grad():
        # Extract one speaker code per clip, then average them into one embedding.
        l_code = []
        mbatch = clip_tensor.unsqueeze(1)  # (nb_clips, 1, nb_samp)
        for batch in mbatch:  # each batch has shape (1, nb_samp)
            batch = batch.to(device)
            code = model(x=batch, is_test=True)
            l_code.extend(code.cpu().numpy())
        embedding = np.mean(l_code, axis=0)
    return embedding


def _normalize_scale(x):
    '''
    Normalize the sample scale as in SincNet (divide by the peak absolute amplitude).
    '''
    return x / np.max(np.abs(x))


def main_test():
    # Parse arguments.
    args = get_args()
    wav_path = args.wav_path
    save_path = args.sav_path
    direc_level = args.direc_level
    wav_file = args.wav_file == 1  # use 0 for the .m4a test files below (read via pydub)
    ## Number of speakers in VoxCeleb2 dataset.
    ## Not used in computing embeddings but should still be there.
    ## Do not comment this.
    args.model['nb_classes'] = 6112
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Device setting.
    cuda = torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')
    print('Device: {}'.format(device))

    model = RawNet(args.model, device).to(device)
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(load_model_dir, map_location=device))
    nb_params = sum(p.numel() for p in model.parameters())
    print('nb_params: {}'.format(nb_params))

    nb_samp = args.model["nb_samp"]
    window_size = args.window_size

    X1 = read_wav_and_get_clip_tensor(test_wav_path3, nb_samp, window_size, wav_file)
    emb_X1 = get_embedding_from_clip_tensor(X1, model, device)
    X2 = read_wav_and_get_clip_tensor(test_wav_path4, nb_samp, window_size, wav_file)
    emb_X2 = get_embedding_from_clip_tensor(X2, model, device)
    sim_score = cos_sim(emb_X1, emb_X2)
    print("Similarity = {}".format(sim_score))


if __name__ == '__main__':
    main_test()
```
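To check the clipping and scoring logic without the model or any audio files, the sketch below re-implements, in plain NumPy, the splitting rule used by `read_wav_and_get_clip_tensor` (tile short inputs, cut long inputs into clips overlapping by `window_size` samples, end-align the last clip) and the average-then-cosine scoring. `split_into_clips`, `embed`, and the random `weight` matrix are hypothetical stand-ins for illustration; they are not part of the RawNet code.

```python
import numpy as np


def split_into_clips(x, nb_samp, window_size):
    """Split a 1-D signal into clips of nb_samp samples (same rule as the script)."""
    n = len(x)
    if n < nb_samp:
        # Short input: repeat it until it covers nb_samp samples.
        reps = nb_samp // n + 1
        return np.tile(x, reps)[:nb_samp][np.newaxis, :]
    if n == nb_samp:
        return x[np.newaxis, :]
    # Long input: consecutive clips overlap by window_size samples,
    # and the last clip is aligned to the end of the signal.
    step = nb_samp - window_size
    count = (n - window_size) // step + 1
    starts = [i * step for i in range(count - 1)] + [n - nb_samp]
    return np.stack([x[s:s + nb_samp] for s in starts])


def embed(clips, weight):
    """Average per-clip codes into one utterance-level embedding.

    `weight` is a hypothetical linear stand-in for the network.
    """
    codes = clips @ weight.T  # (nb_clips, emb_dim)
    return codes.mean(axis=0)


def cos_sim(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    weight = rng.standard_normal((16, 400))
    clips1 = split_into_clips(rng.standard_normal(1100), nb_samp=400, window_size=100)
    clips2 = split_into_clips(rng.standard_normal(150), nb_samp=400, window_size=100)
    print(clips1.shape, clips2.shape)  # (4, 400) (1, 400)
    print(cos_sim(embed(clips1, weight), embed(clips2, weight)))
```

Averaging the per-clip codes mirrors what `get_embedding_from_clip_tensor` does; cosine similarity then compares only the direction of the two embeddings, so their overall scale does not affect the score.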
from rawnet.
@ac-alpha thanks for the reply :)
I'll close this
@ac-alpha thanks. Using the above script with the provided model leads to the following error. Which model is it: RawNet, RawNet2, or RawNet2_modified?
RuntimeError: Error(s) in loading state_dict for RawNet:
Unexpected key(s) in state_dict: "block2.0.conv_downsample.weight", "block2.0.conv_downsample.bias".
size mismatch for block2.0.bn1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.bn1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.bn1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.bn1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.conv1.weight: copying a param with shape torch.Size([256, 128, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3]).
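Errors like this mean the checkpoint was saved from a network whose layer layout differs from the one being built: here the checkpoint's `block2.0` has 128 input channels plus a `conv_downsample` layer, while the current model's `block2.0` expects 256 input channels. That suggests a different model definition or configuration (for example, different per-block filter counts) from the one used in training. A minimal sketch of a hypothetical helper, `diff_state_dict`, that lists every missing key, unexpected key, and shape mismatch at once before calling `load_state_dict`:

```python
import torch
import torch.nn as nn


def diff_state_dict(model: nn.Module, checkpoint: dict):
    """Compare a checkpoint's tensors with a model's parameters and buffers.

    Returns (missing, unexpected, mismatched), where mismatched maps each key
    to (checkpoint_shape, model_shape).
    """
    expected = model.state_dict()
    missing = sorted(k for k in expected if k not in checkpoint)
    unexpected = sorted(k for k in checkpoint if k not in expected)
    mismatched = {
        k: (tuple(checkpoint[k].shape), tuple(expected[k].shape))
        for k in expected
        if k in checkpoint and checkpoint[k].shape != expected[k].shape
    }
    return missing, unexpected, mismatched


if __name__ == "__main__":
    # Toy stand-ins: a "saved" block taking 128 channels vs a "current" one taking 256.
    saved = nn.Sequential(nn.BatchNorm1d(128), nn.Conv1d(128, 256, 3))
    current = nn.Sequential(nn.BatchNorm1d(256), nn.Conv1d(256, 256, 3))
    missing, unexpected, mismatched = diff_state_dict(current, saved.state_dict())
    for name, (ckpt_shape, model_shape) in mismatched.items():
        print(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")
```

With real weights, the checkpoint would come from something like `torch.load("Pre-trained_model/rawnet2_best_weights.pt", map_location="cpu")`, and each candidate model class and configuration can be checked until all three lists are empty.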
Closing this now, as I have uploaded RawNet3 and a script to extract a speaker embedding from any 16 kHz, 16-bit mono utterance.
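Since that script is described as taking 16 kHz, 16-bit mono utterances, it may help to check a file before running extraction. The sketch below uses only Python's standard `wave` module; `check_wav_format` is a hypothetical helper, and its default targets are taken from the comment above rather than from the RawNet3 script itself. `wave` reads PCM WAV files only, and files it cannot parse raise `wave.Error`.

```python
import wave


def check_wav_format(path, rate=16000, sampwidth_bytes=2, channels=1):
    """Return a list of format problems; an empty list means the file matches."""
    problems = []
    with wave.open(path, "rb") as wav:
        if wav.getframerate() != rate:
            problems.append(f"sample rate {wav.getframerate()} Hz, expected {rate} Hz")
        if wav.getsampwidth() != sampwidth_bytes:
            problems.append(
                f"{8 * wav.getsampwidth()}-bit samples, expected {8 * sampwidth_bytes}-bit"
            )
        if wav.getnchannels() != channels:
            problems.append(f"{wav.getnchannels()} channels, expected {channels}")
    return problems
```

For example, `check_wav_format("utterance.wav")` (a placeholder path) returns `[]` for a matching file; anything else needs converting, e.g. resampling or downmixing, before extraction.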