Comments (4)

ac-alpha commented on May 29, 2024

@hdubey you can find the weights for RawNet2 here.

You can use this script for quick testing and for getting the embeddings. Make sure that this model definition is in the directory from which you run the script.

from tqdm import tqdm
from collections import OrderedDict

import os
import argparse
import json
import numpy as np
import glob
import pickle

import torch
import torch.nn as nn
from torch.utils import data

from dataloader import *
from model_RawNet2 import RawNet2
from parser import get_args
from trainer import *
from utils import *
from model_RawNet2_original_code import *
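from pydub import AudioSegment
import soundfile as sf  # sf.read is used below; imported explicitly instead of relying on the star imports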

load_model_dir = "Pre-trained_model/rawnet2_best_weights.pt"
test_wav_path1 = "/root/host/ml-speaker-verification/data/vox1/vox1_test_wav/id10270/5r0dWxy17C8/00001.wav"
test_wav_path2 = "/root/host/ml-speaker-verification/data/vox1/vox1_test_wav/id10278/d6WJf6TOoIQ/00001.wav"

test_wav_path3 = "/root/host/ml-speaker-verification/data/vox2/vox2_test_m4a/id00017/01dfn2spqyE/00001.m4a"
test_wav_path4 = "/root/host/ml-speaker-verification/data/vox2/vox2_test_m4a/id00017/8_a6O3vdlU0/00021.m4a"

def cos_sim(a,b) :
    return np.dot(a,b) / (np.linalg.norm(a)*np.linalg.norm(b))

def read_wav_and_get_clip_tensor(test_wav_path, nb_samp, window_size, wav_file = True):
    
    if not wav_file:
        X = AudioSegment.from_file(test_wav_path)
        X = X.get_array_of_samples()
        X = np.array(X)
    else:
        X, _ = sf.read(test_wav_path)
    X = X.astype(np.float64)
    X = _normalize_scale(X).astype(np.float32)
    X = X.reshape(1,-1)
    
    nb_time = X.shape[1]
    list_X = []
    if nb_time < nb_samp:
        nb_dup = int(nb_samp / nb_time) + 1
        list_X.append(np.tile(X, (1, nb_dup))[:, :nb_samp][0])
    elif nb_time > nb_samp:
        step = nb_samp - window_size
        iteration = int( (nb_time - window_size) / step ) + 1
        for i in range(iteration):
            if i == 0:
                list_X.append(X[:, :nb_samp][0])
            elif i < iteration - 1:
                list_X.append(X[:, i*step : i*step + nb_samp][0])
            else:
                list_X.append(X[:, -nb_samp:][0])
    else :
        list_X.append(X[0])
    return torch.from_numpy(np.asarray(list_X))

def get_embedding_from_clip_tensor(clip_tensor, model, device):
    model.eval()
    
    with torch.set_grad_enabled(False):
        #1st, extract speaker embeddings.
        l_embeddings = []
        l_code = []
        mbatch = clip_tensor
        mbatch = mbatch.unsqueeze(1)
#         print("Batch size = {}".format(mbatch.size()))
        for batch in mbatch:
            batch = batch.to(device)
            code = model(x = batch, is_test=True)
#             print("Code size = {}".format(code.size()))
            l_code.extend(code.cpu().numpy())
        embedding = np.mean(l_code, axis=0)
#         print("Embedding shape = {}".format(embedding.shape))
        return embedding

def _normalize_scale(x):
    '''
    Normalize sample scale like SincNet.
    '''
    return x/np.max(np.abs(x))

def main_test():
    #parse arguments
    args = get_args()
    
    wav_path = args.wav_path
    save_path = args.sav_path
    direc_level = args.direc_level
    wav_file = args.wav_file == 1
    
    ## Number of speakers in VoxCeleb2 dataset. 
    ## Not used in computing embeddings but should still be there. 
    ## Do not comment this.
    args.model['nb_classes'] = 6112 

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    #device setting
    cuda = torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')
    print('Device: {}'.format(device))
    
    model = RawNet(args.model, device).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(load_model_dir, map_location=device))
    nb_params = sum([param.view(-1).size()[0] for param in model.parameters()])
    nb_samp = args.model["nb_samp"]
    window_size = args.window_size
    print('nb_params: {}'.format(nb_params))
    
    X1 = read_wav_and_get_clip_tensor(test_wav_path3, nb_samp, window_size, wav_file)
    emb_X1 = get_embedding_from_clip_tensor(X1, model, device)
    
    X2 = read_wav_and_get_clip_tensor(test_wav_path4, nb_samp, window_size, wav_file)
    emb_X2 = get_embedding_from_clip_tensor(X2, model, device)
    
    sim_score = cos_sim(emb_X1, emb_X2)
    print("Similarity = {}".format(sim_score))

if __name__ == '__main__':
    main_test()
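
The clip-splitting logic in `read_wav_and_get_clip_tensor` can be hard to follow from the NumPy slicing alone. Below is a minimal sketch of the same three branches using plain Python lists, so the clip boundaries are easy to inspect. The function name `split_into_clips` is hypothetical and not part of the repository; the input validation is also an addition. Note that `step = nb_samp - window_size`, so in this script `window_size` acts as the overlap between consecutive clips.

```python
def split_into_clips(samples, nb_samp, window_size):
    """Split a 1-D sequence into clips of exactly nb_samp samples.

    Hypothetical helper mirroring the three branches of
    read_wav_and_get_clip_tensor with plain lists instead of NumPy arrays.
    """
    if not 0 <= window_size < nb_samp:
        raise ValueError("window_size must satisfy 0 <= window_size < nb_samp")
    samples = list(samples)
    nb_time = len(samples)
    if nb_time == 0:
        raise ValueError("samples must not be empty")

    # Short input: repeat the signal until it covers nb_samp, then truncate.
    if nb_time < nb_samp:
        nb_dup = nb_samp // nb_time + 1
        return [(samples * nb_dup)[:nb_samp]]

    # Exact length: a single clip.
    if nb_time == nb_samp:
        return [samples]

    # Long input: slide by `step`; the final clip is always the last nb_samp samples.
    step = nb_samp - window_size
    iteration = (nb_time - window_size) // step + 1
    clips = []
    for i in range(iteration):
        if i < iteration - 1:
            clips.append(samples[i * step : i * step + nb_samp])
        else:
            clips.append(samples[-nb_samp:])
    return clips


clips_long = split_into_clips(range(11), nb_samp=4, window_size=1)
clips_short = split_into_clips([1, 2, 3], nb_samp=7, window_size=2)
```

Because the last clip is always taken from the end of the signal, it can overlap the previous clip by more than `window_size`, or even repeat it when the signal length lines up exactly with the step. Each clip then yields one code from the model, and the script averages those codes into a single embedding.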

from rawnet.

Jungjee commented on May 29, 2024

@ac-alpha thanks for the reply :)
I'll close this

hdubey commented on May 29, 2024

@ac-alpha thanks. Using the above script with the provided model leads to the following errors. Is it RawNet, RawNet2, or Rawnet2_modified?

RuntimeError: Error(s) in loading state_dict for RawNet:
Unexpected key(s) in state_dict: "block2.0.conv_downsample.weight", "block2.0.conv_downsample.bias".
size mismatch for block2.0.bn1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.bn1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.bn1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.bn1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.conv1.weight: copying a param with shape torch.Size([256, 128, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3]).
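
Errors like these usually mean the checkpoint was saved from a different model definition than the one being instantiated. One way to narrow this down is to compare parameter names and shapes before calling `load_state_dict`. The sketch below is an assumption-laden illustration, not code from the repository: `diff_state_dicts` is a hypothetical helper that accepts any name-to-tensor mapping whose values expose `.shape`, and the demo uses stand-in objects with shapes copied from the traceback above (the `conv_downsample` shape is a placeholder, since the traceback does not report it).

```python
from types import SimpleNamespace


def diff_state_dicts(checkpoint, model_state):
    """Compare two name -> tensor mappings by key and by shape.

    Hypothetical helper; works with any values that have a `.shape` attribute.
    """
    ckpt_keys, model_keys = set(checkpoint), set(model_state)
    report = {
        "unexpected": sorted(ckpt_keys - model_keys),  # only in the checkpoint
        "missing": sorted(model_keys - ckpt_keys),     # only in the model
        "shape_mismatch": [],
    }
    for name in sorted(ckpt_keys & model_keys):
        ckpt_shape = tuple(checkpoint[name].shape)
        model_shape = tuple(model_state[name].shape)
        if ckpt_shape != model_shape:
            report["shape_mismatch"].append((name, ckpt_shape, model_shape))
    return report


def _t(*shape):
    # Stand-in for a tensor: only the shape matters for this comparison.
    return SimpleNamespace(shape=shape)


checkpoint = {
    "block2.0.conv_downsample.weight": _t(1),  # placeholder shape, not from the traceback
    "block2.0.bn1.weight": _t(128),
    "block2.0.conv1.weight": _t(256, 128, 3),
}
model_state = {
    "block2.0.bn1.weight": _t(256),
    "block2.0.conv1.weight": _t(256, 256, 3),
}
report = diff_state_dicts(checkpoint, model_state)
```

With PyTorch, the same function could be called as `diff_state_dicts(torch.load(load_model_dir, map_location="cpu"), model.state_dict())`, assuming the checkpoint file stores a plain state dict as the script above expects.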

Jungjee commented on May 29, 2024

Closing this now, as I have uploaded RawNet3 and a script to extract a speaker embedding from any 16 kHz, 16-bit mono utterance.
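
Before running an extraction script with this input requirement, it can help to confirm that a file actually matches it. The following is a minimal sketch using only Python's standard `wave` module. It assumes "16k" means a 16 kHz sample rate and that the input is an uncompressed WAV file; `is_16k_16bit_mono` is a hypothetical name, not part of the RawNet repository.

```python
import os
import tempfile
import wave


def is_16k_16bit_mono(path):
    """Return True if the WAV file is 16 kHz, 16-bit (2 bytes per sample), mono."""
    with wave.open(path, "rb") as wav:
        return (
            wav.getframerate() == 16000
            and wav.getsampwidth() == 2
            and wav.getnchannels() == 1
        )


def _write_silence(path, rate):
    # Write 10 ms of 16-bit mono silence at the given sample rate.
    with wave.open(path, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)
        wav.setframerate(rate)
        wav.writeframes(b"\x00\x00" * (rate // 100))


with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "good.wav")
    bad = os.path.join(tmp, "bad.wav")
    _write_silence(good, 16000)
    _write_silence(bad, 8000)
    ok = is_16k_16bit_mono(good)
    not_ok = is_16k_16bit_mono(bad)
```

Files in other formats, such as the VoxCeleb2 `.m4a` files used in the earlier script, would need to be decoded and resampled first; this check does not do that.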
