
Comments (2)

kkoutini commented on May 24, 2024

Hi, thanks!
You can use passt_hear21 like this:

import torch

# Loading the weights
p = "output/esc50/_None/checkpoints/epoch=4-step=2669.ckpt"
ckpt = torch.load(p, map_location="cpu")
net_statedict = {k[4:]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net.")}  # main weights
net_swa = {k[len("net_swa."):]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net_swa.")}  # swa weights (not used below)

# getting the model
from hear21passt.base import load_model, get_scene_embeddings, get_timestamp_embeddings

model = load_model(mode="logits").cuda()
model.net.load_state_dict(net_statedict)  # loading the fine-tuned weights
model.eval()

# example: 3 clips of 5 s at 32 kHz, moved to the same device as the model
wave_example = (torch.ones((3, 32000 * 5)) * 0.5).cuda()
with torch.no_grad():
    logits = model(wave_example)
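
To make the key filtering above easier to follow, here is a minimal sketch of the same prefix-stripping idea on a plain dictionary. The helper name `strip_prefix` and the key names in `fake_state_dict` are made up for illustration; a real checkpoint would hold tensors rather than integers.

```python
def strip_prefix(state_dict, prefix):
    """Keep only the keys that start with `prefix` and drop the prefix from them."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


# Hypothetical checkpoint keys (assumption); real values would be tensors.
fake_state_dict = {
    "net.patch_embed.proj.weight": 1,
    "net.head.1.weight": 2,
    "net_swa.head.1.weight": 3,
    "mel.window": 4,
}

main_weights = strip_prefix(fake_state_dict, "net.")
swa_weights = strip_prefix(fake_state_dict, "net_swa.")

print(sorted(main_weights))
print(sorted(swa_weights))
```

Because the prefix `"net."` ends with a dot, keys under `"net_swa."` do not match it, so the two sets of weights stay separate.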


myatmyintzuthin commented on May 24, 2024

Thank you so much for the reply.
This is my first time writing an inference script in PyTorch, so it was a great help.
I am sharing my inference script here in case someone wants to use it.

# References 
# 1) https://github.com/YuanGongND/ast/blob/master/egs/audioset/inference.py
# 2) https://github.com/kkoutini/passt_hear21

import csv
import argparse
import numpy as np
import torch
import torchaudio
from hear21passt.base import load_model


def load_label(label_csv):
    """Return the class names from the label CSV, skipping the header row."""
    with open(label_csv, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        lines = list(reader)
    # Column 1 holds the unique id (such as "/m/068hy"); column 2 holds the readable label.
    return [line[2] for line in lines[1:]]


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Example usage: '
                                                 'python inference --audio_path ESC-50-master/audio_32k/1-5996-A-6.wav '
                                                 '--model_path checkpoints/epoch=2-step=4799.ckpt')

    parser.add_argument("--model_path", required=True, type=str,
                        help="the trained model you want to test")
    parser.add_argument("--audio_path", required=True, type=str,
                        help='the audio you want to predict, sample rate 32k.')

    args = parser.parse_args()

    label_csv = './esc50/esc_class_labels_indices.csv'  # label and indices for ESC-50 data

    # 1. load audio file
    audio_path = args.audio_path
    # waveform has shape (channels, samples); a mono file acts as a batch of one clip
    waveform, _ = torchaudio.load(audio_path)

    # 2. load checkpoint
    checkpoint_path = args.model_path
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    print(f'[*INFO] load checkpoint: {checkpoint_path}')
    net_statedict = {k[4:]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net.")}  # main weights
    net_swa = {k[len("net_swa."):]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net_swa.")}  # swa weights (not used below)

    # 3. loading the fine-tuned weights
    device = torch.device("cuda:0")
    passt_model = load_model(mode="logits")
    passt_model.net.load_state_dict(net_statedict)
    passt_model = passt_model.to(device).eval()
    waveform = waveform.to(device)

    with torch.no_grad():
        output = passt_model(waveform)
        output = torch.sigmoid(output)
    result_output = output.cpu().numpy()[0]  # scores for the first (only) clip

    # 4. map the post-prob to label
    labels = load_label(label_csv)
    sorted_indexes = np.argsort(result_output)[::-1]  # highest score first

    # Print audio tagging top probabilities
    print('[*INFO] predict results:')
    for k in range(5):
        print('{}: {:.4f}'.format(labels[sorted_indexes[k]],
                                  result_output[sorted_indexes[k]]))
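
A side note on the scoring step, as a sketch rather than a recommendation: ESC-50 assigns one label per clip, so if the model was fine-tuned with a cross-entropy loss (an assumption, not confirmed in this thread), a softmax over the logits gives scores that sum to 1, whereas the sigmoid used above scores each class independently. Both are monotonic, so the ranking of classes is the same either way. The logits and label names below are invented for illustration, and the example uses only the standard library.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(scores, labels, k=5):
    """Return the k (label, score) pairs with the highest scores."""
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Hypothetical logits and label names, for illustration only.
example_logits = [0.2, 3.1, -1.0, 1.5]
example_labels = ["dog", "rain", "clock_tick", "siren"]

probs = softmax(example_logits)
for label, p in top_k(probs, example_labels, k=3):
    print(f"{label}: {p:.4f}")
```

With real model output, `example_logits` would be replaced by the first row of the logits tensor converted to a list, and `example_labels` by the list returned from `load_label`.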
