Comments (2)
Hi, thanks!
You can use passt_hear21 like this:
import torch
from hear21passt.base import load_model, get_scene_embeddings, get_timestamp_embeddings

# Loading the weights (mapped to CPU first so the checkpoint loads on any machine)
p = "output/esc50/_None/checkpoints/epoch=4-step=2669.ckpt"
ckpt = torch.load(p, map_location="cpu")
net_statedict = {k[len("net."):]: v for k, v in ckpt['state_dict'].items() if k.startswith("net.")}  # main weights
net_swa = {k[len("net_swa."):]: v for k, v in ckpt['state_dict'].items() if k.startswith("net_swa.")}  # swa weights

# getting the model
model = load_model(mode="logits")
model.net.load_state_dict(net_statedict)  # loading the fine-tuned weights
model = model.cuda().eval()

# example: a batch of 3 five-second clips at 32 kHz, moved to the same device as the model
wave_example = torch.ones((3, 32000 * 5)) * 0.5
with torch.no_grad():
    logits = model(wave_example.cuda())
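The two dict comprehensions in the snippet above keep only the checkpoint entries under one prefix and strip that prefix, so the remaining keys match the parameter names of `model.net`. Below is a minimal sketch of the same key mapping as a reusable function, assuming the Lightning checkpoint stores the main weights under `net.` and the SWA weights under `net_swa.` (as those comprehensions imply). `split_lightning_state_dict` and the example keys are hypothetical, and plain strings stand in for tensors so the sketch runs without PyTorch.

```python
def split_lightning_state_dict(state_dict, prefixes=("net.", "net_swa.")):
    """Group checkpoint entries by module prefix and strip the prefix from each key.

    Hypothetical helper mirroring the k[len("net."):] / k[len("net_swa."):]
    comprehensions; keys matching none of the prefixes are dropped.
    """
    groups = {p: {} for p in prefixes}
    for key, value in state_dict.items():
        for p in prefixes:
            if key.startswith(p):
                groups[p][key[len(p):]] = value
                break  # a key belongs to at most one group
    return groups


# Made-up checkpoint contents; strings stand in for weight tensors.
fake_ckpt = {"state_dict": {
    "net.head.weight": "w_main",
    "net_swa.head.weight": "w_swa",
    "mel.window": "hann",
}}
groups = split_lightning_state_dict(fake_ckpt["state_dict"])
print(groups["net."])      # {'head.weight': 'w_main'}
print(groups["net_swa."])  # {'head.weight': 'w_swa'}
```

With a real checkpoint the call would be `split_lightning_state_dict(ckpt["state_dict"])["net."]`. Slicing with `len(p)` instead of a hard-coded `4` keeps the stripping correct if the prefix ever changes.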
from passt.
Thank you so much for the reply.
This is my first time writing an inference script in PyTorch, so it was a great help to me.
I'm sharing my inference script here in case someone wants to use it.
# References
# 1) https://github.com/YuanGongND/ast/blob/master/egs/audioset/inference.py
# 2) https://github.com/kkoutini/passt_hear21
import csv
import argparse

import numpy as np
import torch
import torchaudio
from hear21passt.base import load_model


def load_label(label_csv):
    with open(label_csv, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        lines = list(reader)
    labels = []
    ids = []  # Each label has a unique id such as "/m/068hy"
    for row in lines[1:]:  # skip the header row
        ids.append(row[1])
        labels.append(row[2])
    return labels


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Example: python inference.py '
                    '--audio_path ESC-50-master/audio_32k/1-5996-A-6.wav '
                    '--model_path checkpoints/epoch=2-step=4799.ckpt')
    parser.add_argument("--model_path", required=True, type=str,
                        help="the trained model checkpoint you want to test")
    parser.add_argument("--audio_path", required=True, type=str,
                        help="the audio you want to classify (32 kHz sample rate expected)")
    args = parser.parse_args()
    label_csv = './esc50/esc_class_labels_indices.csv'  # labels and indices for the ESC-50 data
    device = torch.device("cuda:0")

    # 1. load the audio file
    waveform, sr = torchaudio.load(args.audio_path)  # shape: (channels, samples)
    waveform = waveform.mean(dim=0, keepdim=True)  # mix down to mono -> a batch of one clip
    if sr != 32000:  # the model expects 32 kHz input
        waveform = torchaudio.functional.resample(waveform, sr, 32000)

    # 2. load the checkpoint
    checkpoint_path = args.model_path
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    net_statedict = {k[len("net."):]: v for k, v in ckpt['state_dict'].items() if k.startswith("net.")}  # main weights
    net_swa = {k[len("net_swa."):]: v for k, v in ckpt['state_dict'].items() if k.startswith("net_swa.")}  # swa weights

    # 3. build the model and load the fine-tuned weights
    passt_model = load_model(mode="logits")
    passt_model.net.load_state_dict(net_statedict)
    passt_model = passt_model.to(device).eval()
    print(f'[*INFO] load checkpoint: {checkpoint_path}')

    with torch.no_grad():
        output = passt_model(waveform.to(device))
        # sigmoid as in the AST reference; ESC-50 is single-label, so
        # torch.softmax(output, dim=-1) gives the same ranking as a normalized distribution
        output = torch.sigmoid(output)
    result_output = output.cpu().numpy()[0]

    # 4. map the posterior probabilities to labels
    labels = load_label(label_csv)
    sorted_indexes = np.argsort(result_output)[::-1]

    # Print the top-5 predictions
    print('[*INFO] predict results:')
    for k in range(5):
        idx = sorted_indexes[k]
        print('{}: {:.4f}'.format(labels[idx], result_output[idx]))
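Step 4 of the script reads the label CSV and sorts the scores to print the best classes. Here is a minimal, self-contained sketch of that mapping, assuming the CSV has an AudioSet-style header `index,mid,display_name` with rows in index order (the column positions 1 and 2 that `load_label` reads suggest this layout, but the actual ESC-50 file may differ). `read_display_names`, `top_k_labels`, the three-class CSV and the scores are all made up for illustration and are not real model output.

```python
import csv
import io

import numpy as np


def read_display_names(csv_file):
    """Return the display names in row order, skipping the header.

    Hypothetical helper; assumes a header row "index,mid,display_name".
    """
    reader = csv.DictReader(csv_file)
    return [row["display_name"] for row in reader]


def top_k_labels(scores, labels, k=5):
    """Pair the k highest scores with their labels, best first."""
    scores = np.asarray(scores)
    k = min(k, len(scores))  # avoid asking for more classes than exist
    order = np.argsort(scores)[::-1][:k]
    return [(labels[i], float(scores[i])) for i in order]


# Made-up example data: three classes, not real ESC-50 labels or predictions.
example_csv = io.StringIO(
    "index,mid,display_name\n"
    "0,/m/dog,Dog\n"
    "1,/m/rain,Rain\n"
    "2,/m/clock,Clock tick\n"
)
labels = read_display_names(example_csv)
scores = [0.10, 0.85, 0.30]
for name, p in top_k_labels(scores, labels, k=2):
    print(f"{name}: {p:.4f}")
# Rain: 0.8500
# Clock tick: 0.3000
```

Reading columns by header name rather than by position keeps the lookup working if the CSV gains or reorders columns, and clamping `k` avoids an index error when fewer than five classes are present.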
Related Issues (20)
- kaggle
- Fine tuning on novel dataset
- Is it possible to use this project directly for a code example for instrument recognition?
- mismatch version of pytorch-lighting and sarced
- Installation issues
- The loop in the diagram
- RuntimeError: The size of tensor a (2055) must match the size of tensor b (99) at non-singleton dimension 3
- is `config.dyn_norm` enabled?
- Is it possible to install the passt with python=3.6?
- ImportError: cannot import name 'F1' from 'torchmetrics' (/app/anaconda3/lib/python3.7/site-packages/torchmetrics/__init__.py)
- FSD50K - validating on eval data
- Pretrained models config
- OpenMic fine-tuned model?
- Could not solve for environment specs
- setup.py
- I have a problem. why convert wav to mp3?
- difference of fine-tuning the pretrained models
- Inference Issue
- Getting started with a custom dataset