
Comments (7)

ablaom commented on May 20, 2024

Registry update pending: JuliaRegistries/General#99656.

The warning is expected. According to the input_scitype, only tabular input is allowed:

input_scitype = MMI.Table(Union{MMI.Continuous,MMI.Missing})

It looks like a matrix is also acceptable, so maybe BetaML could extend its input_scitype declaration to suppress the warning in future, although this certainly isn't always done in MLJ interfaces. You can always wrap your matrix as a table: Tables.table(matrix).
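A minimal sketch of that wrapping, assuming Tables.jl is installed. The random matrix is only a stand-in for the data in the report, and the variable names are illustrative:

```julia
using Tables

# Plain matrix, like the rand(100, 10) passed to machine(...) in the report
X_matrix = rand(100, 10)

# Wrap it as a table; columns get default names :Column1, :Column2, ...
X_table = Tables.table(X_matrix)

# X_table implements the Tables.jl interface, so it is tabular input
println(Tables.istable(X_table))
```

With MLJ and BetaML loaded, X_table can then be passed as machine(model2, X_table) in place of the raw matrix, which should satisfy the table-only input_scitype.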

An extended declaration would look like this:

input_scitype = Union{
    MMI.Table(MMI.Table(Union{MMI.Continuous,MMI.Missing}),
    AbstractMatrix{<:Union{MMI.Continuous,MMI.Missing}},
}

from betaml.jl.

sylvaticus commented on May 20, 2024

Thank you. I have recently refactored the model names; I am going to look at it.

sylvaticus commented on May 20, 2024

Hmm... I can reproduce it. It seems to work, but it produces a warning if the model is loaded directly from BetaML:

(temp) pkg> activate --temp

julia> using MLJ
 │ Package MLJ not found, but a package named MLJ is available from a registry. 
 │ Install package?
 │   (jl_TgwIzK) pkg> add MLJ 
 └ (y/n/o) [y]: y

[...]

julia> model2 = BetaML.Bmlj.GaussianMixtureClusterer()
GaussianMixtureClusterer(
  n_classes = 3, 
  initial_probmixtures = Float64[], 
  mixtures = BetaML.GMM.DiagonalGaussian{Float64}[BetaML.GMM.DiagonalGaussian{Float64}(nothing, nothing), BetaML.GMM.DiagonalGaussian{Float64}(nothing, nothing), BetaML.GMM.DiagonalGaussian{Float64}(nothing, nothing)], 
  tol = 1.0e-6, 
  minimum_variance = 0.05, 
  minimum_covariance = 0.0, 
  initialisation_strategy = "kmeans", 
  maximum_iterations = 9223372036854775807, 
  rng = Random._GLOBAL_RNG())

julia> modelMachine = machine(model2, rand(100, 10))
┌ Warning: The number and/or types of data arguments do not match what the specified model
│ supports. Suppress this type check by specifying `scitype_check_level=0`.
│ 
│ Run `@doc BetaML.GaussianMixtureClusterer` to learn more about your model's requirements.
│ 
│ Commonly, but non exclusively, supervised models are constructed using the syntax
│ `machine(model, X, y)` or `machine(model, X, y, w)` while most other models are
│ constructed with `machine(model, X)`.  Here `X` are features, `y` a target, and `w`
│ sample or class weights.
│ 
│ In general, data in `machine(model, data...)` is expected to satisfy
│ 
│     scitype(data) <: MLJ.fit_data_scitype(model)
│ 
│ In the present case:
│ 
│ scitype(data) = Tuple{AbstractMatrix{Continuous}}
│ 
│ fit_data_scitype(model) = Tuple{Table{<:AbstractVector{<:Union{Missing, Continuous}}}}
└ @ MLJBase ~/.julia/packages/MLJBase/mIaqI/src/machines.jl:231
untrained Machine; caches model-specific representations of data
  model: GaussianMixtureClusterer(n_classes = 3, )
  args: 
    1:	Source @867 ⏎ AbstractMatrix{Continuous}


julia> fit!(modelMachine)
[ Info: Training machine(GaussianMixtureClusterer(n_classes = 3, ), ).
Iter. 1:	Var. of the post  8.609516054119508 	  Log-likelihood -164.88840388654208
trained Machine; caches model-specific representations of data
  model: GaussianMixtureClusterer(n_classes = 3, )
  args: 
    1:	Source @867 ⏎ AbstractMatrix{Continuous}

julia> classes_est = predict(modelMachine,rand(10,10))
10-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{3}, Int64, UInt32, Float64}:
 UnivariateFinite{Multiclass{3}}(1=>0.853, 2=>0.00419, 3=>0.142)
 UnivariateFinite{Multiclass{3}}(1=>0.951, 2=>0.0433, 3=>0.00529)
 UnivariateFinite{Multiclass{3}}(1=>0.101, 2=>0.287, 3=>0.613)
 UnivariateFinite{Multiclass{3}}(1=>0.224, 2=>0.69, 3=>0.0857)
 UnivariateFinite{Multiclass{3}}(1=>0.0488, 2=>0.0195, 3=>0.932)
 UnivariateFinite{Multiclass{3}}(1=>0.656, 2=>0.129, 3=>0.215)
 UnivariateFinite{Multiclass{3}}(1=>0.386, 2=>0.46, 3=>0.154)
 UnivariateFinite{Multiclass{3}}(1=>0.888, 2=>0.0113, 3=>0.101)
 UnivariateFinite{Multiclass{3}}(1=>0.617, 2=>0.238, 3=>0.145)
 UnivariateFinite{Multiclass{3}}(1=>0.621, 2=>0.0495, 3=>0.329)

This relates to a mismatch between the scientific type of the input data (rand(100,10)) and the one declared in the model's MLJ interface (input_scitype = MMI.Table(Union{MMI.Continuous,MMI.Missing}), with MMI = MLJModelInterface). However, I haven't changed anything there; perhaps MLJ has changed some definition of scientific types, or it has become stricter?
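The check quoted in the warning, scitype(data) <: fit_data_scitype(model), is ordinary Julia subtyping. The sketch below replays it using hypothetical stand-in types: Continuous and Table here are local placeholders mimicking the names printed in the warning, not the real ScientificTypes definitions. It also assumes that a table with all-Continuous columns gets the scitype Table{AbstractVector{Continuous}}:

```julia
# Hypothetical stand-ins for the ScientificTypes names shown in the warning
abstract type Continuous end
abstract type Table{K} end

# The two types as printed in the warning message
data_scitype = Tuple{AbstractMatrix{Continuous}}
fit_data_scitype = Tuple{Table{<:AbstractVector{<:Union{Missing,Continuous}}}}

# A matrix is not a Table, so the check fails and MLJ emits the warning
matrix_ok = data_scitype <: fit_data_scitype

# A table whose columns are all Continuous passes the same check
table_ok = Tuple{Table{AbstractVector{Continuous}}} <: fit_data_scitype

println((matrix_ok, table_ok))
```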

@ablaom, what are your thoughts here?

sylvaticus commented on May 20, 2024

On second thought, I think the scientific type is a separate issue. Here the problem is that MLJ.@load is still loading the old model from the GMM submodule, not the one that moved to the Bmlj submodule in BetaML v0.11. I think @ablaom just needs to update the MLJ registry.
In the meantime, you can load the model directly from BetaML as a workaround.

sylvaticus commented on May 20, 2024

@ablaom, why is there a nested reference to MMI.Table in the first entry of the Union you propose (i.e. MMI.Table(MMI.Table(...)))?

ablaom commented on May 20, 2024

Sorry, my mistake.

input_scitype = Union{
    MMI.Table(Union{MMI.Continuous,MMI.Missing}),
    AbstractMatrix{<:Union{MMI.Continuous,MMI.Missing}},
}
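To see why the extended declaration accepts a plain matrix, and why the matrix entry needs the <: wildcard, here is the same kind of subtype check with hypothetical stand-in types (again local placeholders for the ScientificTypes names, not the real definitions):

```julia
# Hypothetical stand-ins for the ScientificTypes names
abstract type Continuous end
abstract type Table{K} end

# Stand-in for the corrected input_scitype: a Union of table and matrix forms
extended = Union{
    Table{<:AbstractVector{<:Union{Continuous,Missing}}},
    AbstractMatrix{<:Union{Continuous,Missing}},
}

# scitype(rand(100, 10)) as reported in the warning
matrix_scitype = AbstractMatrix{Continuous}

# Accepted: the matrix scitype matches the second member of the Union
accepted = matrix_scitype <: extended

# Julia type parameters are invariant, so without the `<:` the matrix is rejected
strict = matrix_scitype <: AbstractMatrix{Union{Continuous,Missing}}

println((accepted, strict))
```

The wildcard form lets both all-Continuous matrices and matrices that actually contain missing values pass, whereas the exact-parameter form would only match a matrix whose element scitype is literally the Union.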

sylvaticus commented on May 20, 2024

Tested now; it should be fine (and I have explicitly added AbstractMatrix to the input_scitype).
