Comments (7)

JonghwanMun commented on June 18, 2024

Simply apply softmax to the logits to get probabilities, then threshold the value at index 1 of the last dimension at 0.5 to obtain the binary prediction.

For example,

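import torch.nn.functional as F

logits = model(input)              # raw two-class scores
probs = F.softmax(logits, dim=-1)  # normalize over the class dimension
preds = probs[:, 1] > 0.5          # True where class 1 is predicted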
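
For a two-class output, thresholding the class-1 probability at 0.5 should give the same decision as comparing the two logits directly (barring exact ties). Here is a minimal sketch of that, written in plain Python so it runs without PyTorch; the `softmax` helper and the sample logits are made up for illustration and are not part of bassl:

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for three samples: [class 0, class 1]
logits = [[2.0, -1.0], [0.3, 1.2], [-0.5, 0.1]]

probs = [softmax(row) for row in logits]
preds = [p[1] > 0.5 for p in probs]                  # threshold the class-1 probability
argmax_preds = [row[1] > row[0] for row in logits]   # same decision taken from the logits

print(preds)
print(preds == argmax_preds)
```

The threshold form is still handy if you later want a cutoff other than 0.5.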

from bassl.

BrunoSader commented on June 18, 2024

Thank you.
I have some more questions, if you don't mind.
The model I am using is the trained BaSSL 40-epoch model; here is how I load it:

    cfg = init_hydra_config(mode="extract_shot")
    apply_random_seed(cfg)
    cfg = load_pretrained_config(cfg)

    # init model
    cfg, model = init_model(cfg)

    # init trainer
    cfg, trainer = init_trainer(cfg)

Is this right?
Also, I don't understand what I am supposed to give it as input.
Do I just create a dataloader of tensors, one for each image in my movie?
Thank you very much for your help 😄

JonghwanMun commented on June 18, 2024

To load a BaSSL 40-epoch scene segmentation model for inference, you need to replace load_pretrained_config with a load_finetuned_config function, for example:

import json
import os

import easydict


def load_finetuned_config(cfg):
    ckpt_root = cfg.CKPT_PATH
    load_from = cfg.LOAD_FROM

    # read the configuration saved alongside the finetuned checkpoint
    with open(os.path.join(ckpt_root, load_from, "config.json"), "r") as fopen:
        finetuned_cfg = json.load(fopen)
        finetuned_cfg = easydict.EasyDict(finetuned_cfg)

    # override configuration of pre-trained model
    cfg.MODEL = finetuned_cfg.MODEL
    cfg.PRETRAINED_LOAD_FROM = finetuned_cfg.PRETRAINED_LOAD_FROM

    cfg.TRAIN.USE_SINGLE_KEYFRAME = False
    cfg.MODEL.contextual_relation_network.params.trn.pooling_method = "center"

    # override neighbor size of an input sequence of shots
    sampling = finetuned_cfg.LOSS.sampling_method.name
    nsize = finetuned_cfg.LOSS.sampling_method.params[sampling]["neighbor_size"]
    cfg.LOSS.sampling_method.params["sbd"]["neighbor_size"] = nsize

    return cfg

Then you also need to set the LOAD_FROM option to point to the finetuned model.
It may be the same as the EXPR_NAME used during the finetuning stage.
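
To make explicit which fields load_finetuned_config reads from config.json, here is a small sketch that writes a stand-in file and performs the same lookups with plain dicts. Every value in it (the directory layout under a temp folder, "my_finetune_run", "placeholder_pretrained_run", the sampling name "sbd", the neighbor size 8) is a made-up placeholder, not the actual bassl configuration:

```python
import json
import os
import tempfile

# Placeholder config containing only the keys that load_finetuned_config accesses.
fake_cfg = {
    "MODEL": {"contextual_relation_network": {"params": {"trn": {}}}},
    "PRETRAINED_LOAD_FROM": "placeholder_pretrained_run",
    "LOSS": {
        "sampling_method": {
            "name": "sbd",
            "params": {"sbd": {"neighbor_size": 8}},
        }
    },
}

ckpt_root = tempfile.mkdtemp()   # stands in for cfg.CKPT_PATH
load_from = "my_finetune_run"    # stands in for cfg.LOAD_FROM
os.makedirs(os.path.join(ckpt_root, load_from))
config_path = os.path.join(ckpt_root, load_from, "config.json")
with open(config_path, "w") as fopen:
    json.dump(fake_cfg, fopen)

# Same lookups as in the function, using plain dicts instead of EasyDict.
with open(config_path, "r") as fopen:
    finetuned_cfg = json.load(fopen)
sampling = finetuned_cfg["LOSS"]["sampling_method"]["name"]
nsize = finetuned_cfg["LOSS"]["sampling_method"]["params"][sampling]["neighbor_size"]
print(sampling, nsize)
```

If config.json is missing at CKPT_PATH/LOAD_FROM, the open() call in the function will fail, so that path is the first thing to check.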

As for the input, our algorithm operates on shots.
You first need to divide the movie into a series of shots and extract three key-frames for each shot (see http://docs.movienet.site/movie-toolbox/tools/shot_detector).
Then you need to feed the three key-frames of each shot to the network as input.
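
As a rough sketch of that input layout: group the key-frame paths by shot, then build one window of neighboring shots around each shot. The file-naming scheme, the helper names, and the edge handling (clamping indices at the start and end of the movie) are assumptions for illustration only; check the repository's dataset code for the exact behavior.

```python
def keyframe_paths(shot_id, num_keyframes=3):
    """Hypothetical naming scheme, e.g. shot_0003_img_1.jpg."""
    return [f"shot_{shot_id:04d}_img_{k}.jpg" for k in range(num_keyframes)]


def shot_window(center, num_shots, neighbor_size):
    """Indices of 2 * neighbor_size + 1 shots centered on `center`.

    Out-of-range indices are clamped to the first/last shot (an assumption).
    """
    return [
        min(max(i, 0), num_shots - 1)
        for i in range(center - neighbor_size, center + neighbor_size + 1)
    ]


num_shots = 5
shots = [keyframe_paths(i) for i in range(num_shots)]

# One input per shot: a window of shots, each carrying its three key-frames.
windows = [
    [shots[j] for j in shot_window(i, num_shots, neighbor_size=2)]
    for i in range(num_shots)
]
print(shot_window(0, num_shots, neighbor_size=2))
```

In practice the paths would be replaced by loaded and preprocessed image tensors, and neighbor_size would be the value read from the finetuned config.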

LFavano commented on June 18, 2024

Hello, I would also be interested in more details on how to run the code for inference starting from a fine-tuned model. I tried following @JonghwanMun's suggestions but couldn't come up with working code.

Is it correct to init the cfg this way, and would "finetune" be the correct mode here?

cfg = init_hydra_config(mode="finetune")
apply_random_seed(cfg)
cfg = load_finetuned_config(cfg)

About the data, I have two questions:

  • Can the init_data_loader util function be used for the input that trainer.predict() expects to receive?
  • If using the data loader is not possible, what code would generate the right input for the network? I have extracted the key frames for each shot, but from running model(data) it seems that the expected shape is [64,3,7,7]. Is this the expected behavior?

Thank you

barry2025 commented on June 18, 2024

Hello, I see FinetuningWrapper.load_from_checkpoint in main_utils.py, but I cannot find the implementation of load_from_checkpoint in finetune_wrapper.py. I wonder how it works. Thanks.

JonghwanMun commented on June 18, 2024

@barry2025
load_from_checkpoint() is a method inherited from LightningModule in PyTorch Lightning.
It initializes the parameters from the checkpoint given by checkpoint_path when constructing the FinetuningWrapper instance.
Please refer to the PyTorch Lightning documentation for more details.
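
As a toy analogy in plain Python (this is not Lightning's actual implementation, and `BaseModule`, `ToyWrapper`, and the checkpoint dict are hypothetical stand-ins), the sketch below shows how a class method defined on a base class can build an instance of a subclass that never defines that method itself:

```python
class BaseModule:
    """Stand-in for a framework base class (not Lightning itself)."""

    @classmethod
    def load_from_checkpoint(cls, checkpoint):
        # `cls` is whichever subclass the method was called on.
        obj = cls(**checkpoint["hyper_parameters"])
        obj.state = dict(checkpoint["state_dict"])
        return obj


class ToyWrapper(BaseModule):
    """Defines no load_from_checkpoint of its own."""

    def __init__(self, hidden_size):
        self.hidden_size = hidden_size
        self.state = {}


# A dict stands in for a checkpoint file here.
ckpt = {"hyper_parameters": {"hidden_size": 4}, "state_dict": {"w": [0.1, 0.2]}}
wrapper = ToyWrapper.load_from_checkpoint(ckpt)
print(type(wrapper).__name__, wrapper.hidden_size)
```

This is why searching finetune_wrapper.py for the method turns up nothing: the definition lives in the parent class.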

barry2025 commented on June 18, 2024

Thanks! I have never used PyTorch Lightning before; I'll give it a try.
