I agree with all of this!
It will very likely live in rsample. The timeline depends on how soon that grouped_df check in dplyr gets removed or altered so that this works out of the box.
from rsample.
@clauswilke, I've spent a bit of time thinking about this, and I think I have a nice solution (one that we can't actually use yet unless dplyr changes). It implements that "virtual groups" idea by modifying the attribute on the data frame that controls the grouping.
It requires a modification to dplyr's C++ code, but not a big one. If you want to play with it, you can install my branch:
https://github.com/DavisVaughan/dplyr/tree/feature/virtual-bootstrap-groups
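For context on which attribute is being modified: in dplyr 0.8 and later, a grouped data frame keeps its grouping metadata as a tibble in the "groups" attribute, with one row per group and a .rows list-column of integer row indices (dplyr's public accessor for the same structure is group_data()). A minimal sketch of inspecting it with a released dplyr, no branch required; this only illustrates the structure that bootstrapify() rewrites, and says nothing about how the branch changes dplyr internally.

```r
library(dplyr)

iris_g <- iris %>%
  as_tibble() %>%
  group_by(Species)

# Grouping metadata: one row per group plus a `.rows` list-column
# holding the integer row indices that belong to each group
groups <- attr(iris_g, "groups")
groups

# Each species occupies 50 rows of iris
lengths(groups$.rows)
```

A virtual bootstrap only has to replace each group's `.rows` entry with several resampled index vectors; the data columns themselves stay as they are.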
It's also insanely fast. bench::mark() shows some impressive results compared to your solution (no offense), and it's more memory efficient!
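The memory saving follows from what a virtual bootstrap stores: each draw is just an integer vector of row positions into the original data, so no resampled copy of the data frame is ever built. A base-R sketch of that idea for a single group; the names draws and boot_means are illustrative only and are not part of the branch.

```r
set.seed(2018)

x <- iris$Petal.Length[iris$Species == "setosa"]  # 50 observations
n_draws <- 5

# Each draw is an integer vector of positions, sampled with replacement
draws <- replicate(
  n_draws,
  sample.int(length(x), length(x), replace = TRUE),
  simplify = FALSE
)

# Summaries index into `x`; only a temporary 50-element vector exists per
# draw, instead of n_draws full copies of the data
boot_means <- vapply(draws, function(i) mean(x[i]), numeric(1))
boot_means
```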
library(dplyr)
library(broom)
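bootstrapify <- function(.data, .n) {
  # Grouping metadata: one row per group, with a `.rows` list-column of row indices
  group_df <- attr(.data, "groups")
  group_rows <- group_df[[".rows"]]
  # For every existing group, draw .n resamples (with replacement) of its row indices.
  # Index into .x instead of calling sample(.x, ...), which samples from 1:.x
  # when a group has a single row.
  new_row_indices <- purrr::map(group_rows, ~{
    tibble(
      .virtual_straps = paste0("id_", seq_len(.n)),
      .rows = replicate(.n, .x[sample.int(length(.x), replace = TRUE)], simplify = FALSE)
    )
  })
  # Expand each group into .n virtual groups; the data rows themselves are untouched
  group_df[[".rows"]] <- new_row_indices
  group_df <- tidyr::unnest(group_df, .rows)
  attr(.data, "groups") <- group_df
  .data
}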
iris_g <- iris %>%
as_tibble() %>%
group_by(Species)
iris_g %>%
bootstrapify(5)
#> # A tibble: 150 x 5
#> # Groups: Species, .virtual_straps [15]
#> Sepal.Length Sepal.Width Petal.Length Petal.Width Species
#> <dbl> <dbl> <dbl> <dbl> <fct>
#> 1 5.1 3.5 1.4 0.2 setosa
#> 2 4.9 3 1.4 0.2 setosa
#> 3 4.7 3.2 1.3 0.2 setosa
#> 4 4.6 3.1 1.5 0.2 setosa
#> 5 5 3.6 1.4 0.2 setosa
#> 6 5.4 3.9 1.7 0.4 setosa
#> 7 4.6 3.4 1.4 0.3 setosa
#> 8 5 3.4 1.5 0.2 setosa
#> 9 4.4 2.9 1.4 0.2 setosa
#> 10 4.9 3.1 1.5 0.1 setosa
#> # ... with 140 more rows
iris_g %>%
bootstrapify(5) %>%
summarise(x = mean(Petal.Length))
#> # A tibble: 15 x 3
#> # Groups: Species [3]
#> Species .virtual_straps x
#> <fct> <chr> <dbl>
#> 1 setosa id_1 1.44
#> 2 setosa id_2 1.45
#> 3 setosa id_3 1.46
#> 4 setosa id_4 1.48
#> 5 setosa id_5 1.47
#> 6 versicolor id_1 4.31
#> 7 versicolor id_2 4.31
#> 8 versicolor id_3 4.28
#> 9 versicolor id_4 4.38
#> 10 versicolor id_5 4.20
#> 11 virginica id_1 5.56
#> 12 virginica id_2 5.52
#> 13 virginica id_3 5.60
#> 14 virginica id_4 5.55
#> 15 virginica id_5 5.47
iris_g %>%
bootstrapify(5) %>%
do(tidy(lm(Sepal.Length ~ Petal.Length, data = .)))
#> # A tibble: 30 x 7
#> # Groups: Species, .virtual_straps [15]
#> Species .virtual_straps term estimate std.error statistic p.value
#> <fct> <chr> <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 setosa id_1 (Interce… 4.28 0.391 10.9 1.20e-14
#> 2 setosa id_1 Petal.Le… 0.432 0.273 1.58 1.20e- 1
#> 3 setosa id_2 (Interce… 4.08 0.507 8.05 1.83e-10
#> 4 setosa id_2 Petal.Le… 0.624 0.340 1.84 7.27e- 2
#> 5 setosa id_3 (Interce… 3.36 0.368 9.14 4.47e-12
#> 6 setosa id_3 Petal.Le… 1.14 0.246 4.65 2.62e- 5
#> 7 setosa id_4 (Interce… 3.16 0.373 8.47 4.33e-11
#> 8 setosa id_4 Petal.Le… 1.21 0.262 4.62 2.94e- 5
#> 9 setosa id_5 (Interce… 4.47 0.433 10.3 9.14e-14
#> 10 setosa id_5 Petal.Le… 0.393 0.292 1.35 1.85e- 1
#> # ... with 20 more rows
# no mutating allowed please
iris_g %>%
bootstrapify(5) %>%
mutate(x = 4)
#> Error: Column `.virtual_straps` is unknown
Created on 2018-09-18 by the reprex package (v0.2.0).
About 10x faster with 50 bootstraps here, and less than half the memory:
bench::mark(
bootstrap_summarize(
iris_g,
ndraws = 50,
mean_sepal_length = mean(Sepal.Length),
mean_petal_length = mean(Petal.Length)
),
iterations = 100
)
#> Warning: Some expressions had a GC in every iteration; so filtering is
#> disabled.
#> # A tibble: 1 x 10
#> expression min mean median max `itr/sec` mem_alloc n_gc n_itr
#> <chr> <bch:> <bch:> <bch:> <bch:> <dbl> <bch:byt> <dbl> <int>
#> 1 bootstrap… 35.7ms 42.5ms 41ms 94.2ms 23.6 1.54MB 121 100
#> # ... with 1 more variable: total_time <bch:tm>
bench::mark(
bootstrapify(iris_g, 50) %>%
summarise(
mean_sepal_length = mean(Sepal.Length),
mean_petal_length = mean(Petal.Length)
),
iterations = 100
)
#> # A tibble: 1 x 10
#> expression min mean median max `itr/sec` mem_alloc n_gc n_itr
#> <chr> <bch:> <bch:> <bch:> <bch:> <dbl> <bch:byt> <dbl> <int>
#> 1 bootstrap… 3.58ms 4.05ms 4.01ms 4.62ms 247. 685KB 10 90
#> # ... with 1 more variable: total_time <bch:tm>
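The headline ratios can be recomputed from the two bench::mark() outputs above. This sketch simply copies those printed numbers; it assumes bench reports bytes in binary units (1 MB = 1024 KB). With decimal units the memory ratio would be about 0.44 instead of 0.43, still under one half.

```r
# Values copied from the two bench::mark() outputs above
mean_ms   <- c(bootstrap_summarize = 42.5, bootstrapify = 4.05)
median_ms <- c(bootstrap_summarize = 41,   bootstrapify = 4.01)
mem_kb    <- c(bootstrap_summarize = 1.54 * 1024, bootstrapify = 685)  # assumes binary units

speedup_mean   <- mean_ms[["bootstrap_summarize"]] / mean_ms[["bootstrapify"]]
speedup_median <- median_ms[["bootstrap_summarize"]] / median_ms[["bootstrapify"]]
mem_ratio      <- mem_kb[["bootstrapify"]] / mem_kb[["bootstrap_summarize"]]

round(c(speedup_mean = speedup_mean,
        speedup_median = speedup_median,
        mem_ratio = mem_ratio), 2)
```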
This looks great! And no offense taken, I spent maybe 20 min. on my version.
I made a few modifications (see below). I think that, most importantly, the column that indicates the different virtual bootstraps has to have a good name (and maybe needs to be configurable), since it'll be user facing. I also wrote a function that can convert the virtual bootstraps into actual data. This is useful for gganimate.
Which package could this go into? rsample?
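Making the draw column name configurable relies on the `!!name := value` idiom, which tibble() supports through rlang's tidy evaluation; the `.key` argument of bootstrapify() uses the same construct. A minimal sketch; the helper draw_ids() is hypothetical and exists only to isolate that one step.

```r
library(tibble)

# Hypothetical helper: a one-column tibble of draw ids whose column name
# is supplied by the caller as a string
draw_ids <- function(.n, .key = ".draw") {
  tibble(!!.key := factor(seq_len(.n)))
}

draw_ids(3)                       # column named ".draw"
draw_ids(3, .key = "bootstrap")   # column named "bootstrap"
```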
library(tidyverse)
library(broom)
library(rlang)
#>
#> Attaching package: 'rlang'
#> The following objects are masked from 'package:purrr':
#>
#> %@%, %||%, as_function, flatten, flatten_chr, flatten_dbl,
#> flatten_int, flatten_lgl, invoke, list_along, modify, prepend,
#> rep_along, splice
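bootstrapify <- function(.data, .n, .key = ".draw") {
  # Grouping metadata: one row per group, with a `.rows` list-column of row indices
  group_df <- attr(.data, "groups")
  group_rows <- group_df[[".rows"]]
  # For every group, draw .n resamples of its row indices; the draw column is
  # named by the user-supplied .key. Index into .x instead of calling
  # sample(.x, ...), which samples from 1:.x when a group has a single row.
  new_row_indices <- purrr::map(group_rows, ~{
    tibble(
      !!.key := factor(seq_len(.n)),
      .rows = replicate(.n, .x[sample.int(length(.x), replace = TRUE)], simplify = FALSE)
    )
  })
  # Expand each group into .n virtual groups; the data rows themselves are untouched
  group_df[[".rows"]] <- new_row_indices
  group_df <- tidyr::unnest(group_df, .rows)
  attr(.data, "groups") <- group_df
  .data
}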
realize <- function(.data, .key = ".draw", .row_key = NULL) {
  # Verify that the table is grouped by .key; if not, return it unchanged
  groups <- attr(.data, "groups")
  if (is.null(groups) || is.null(groups[[.key]])) {
    return(.data)
  }
  # Materialize each virtual group as real rows, tagged with its draw id
  .data <- do(.data, data = .) %>% select(!!.key, data) %>% unnest()
  # Optionally add a running 1..n row counter (useful for gganimate)
  if (!is.null(.row_key)) {
    .data[[.row_key]] <- seq_len(nrow(.data))
  }
  .data
}
iris_g <- iris %>%
as_tibble() %>%
group_by(Species)
iris_small <- iris_g %>%
do(.[1:3,] )
iris_small %>%
bootstrapify(5)
#> # A tibble: 9 x 5
#> # Groups: Species, .draw [15]
#> Sepal.Length Sepal.Width Petal.Length Petal.Width Species
#> <dbl> <dbl> <dbl> <dbl> <fct>
#> 1 5.10 3.50 1.40 0.200 setosa
#> 2 4.90 3.00 1.40 0.200 setosa
#> 3 4.70 3.20 1.30 0.200 setosa
#> 4 7.00 3.20 4.70 1.40 versicolor
#> 5 6.40 3.20 4.50 1.50 versicolor
#> 6 6.90 3.10 4.90 1.50 versicolor
#> 7 6.30 3.30 6.00 2.50 virginica
#> 8 5.80 2.70 5.10 1.90 virginica
#> 9 7.10 3.00 5.90 2.10 virginica
iris_small %>% bootstrapify(5) %>% realize()
#> # A tibble: 45 x 6
#> .draw Sepal.Length Sepal.Width Petal.Length Petal.Width Species
#> <fct> <dbl> <dbl> <dbl> <dbl> <fct>
#> 1 1 4.90 3.00 1.40 0.200 setosa
#> 2 1 4.90 3.00 1.40 0.200 setosa
#> 3 1 5.10 3.50 1.40 0.200 setosa
#> 4 2 4.90 3.00 1.40 0.200 setosa
#> 5 2 4.90 3.00 1.40 0.200 setosa
#> 6 2 4.70 3.20 1.30 0.200 setosa
#> 7 3 5.10 3.50 1.40 0.200 setosa
#> 8 3 4.70 3.20 1.30 0.200 setosa
#> 9 3 4.90 3.00 1.40 0.200 setosa
#> 10 4 5.10 3.50 1.40 0.200 setosa
#> # ... with 35 more rows
iris_small %>% bootstrapify(5) %>% realize(.row_key = ".row") # useful for gganimate
#> # A tibble: 45 x 7
#> .draw Sepal.Length Sepal.Width Petal.Length Petal.Width Species .row
#> <fct> <dbl> <dbl> <dbl> <dbl> <fct> <int>
#> 1 1 5.10 3.50 1.40 0.200 setosa 1
#> 2 1 4.90 3.00 1.40 0.200 setosa 2
#> 3 1 4.90 3.00 1.40 0.200 setosa 3
#> 4 2 4.90 3.00 1.40 0.200 setosa 4
#> 5 2 5.10 3.50 1.40 0.200 setosa 5
#> 6 2 4.70 3.20 1.30 0.200 setosa 6
#> 7 3 5.10 3.50 1.40 0.200 setosa 7
#> 8 3 5.10 3.50 1.40 0.200 setosa 8
#> 9 3 4.90 3.00 1.40 0.200 setosa 9
#> 10 4 5.10 3.50 1.40 0.200 setosa 10
#> # ... with 35 more rows
iris_g %>%
bootstrapify(5) %>%
summarise(x = mean(Petal.Length))
#> # A tibble: 15 x 3
#> # Groups: Species [3]
#> Species .draw x
#> <fct> <fct> <dbl>
#> 1 setosa 1 1.50
#> 2 setosa 2 1.46
#> 3 setosa 3 1.49
#> 4 setosa 4 1.44
#> 5 setosa 5 1.46
#> 6 versicolor 1 4.36
#> 7 versicolor 2 4.18
#> 8 versicolor 3 4.28
#> 9 versicolor 4 4.31
#> 10 versicolor 5 4.33
#> 11 virginica 1 5.63
#> 12 virginica 2 5.59
#> 13 virginica 3 5.62
#> 14 virginica 4 5.52
#> 15 virginica 5 5.54
iris_g %>%
bootstrapify(5) %>%
do(tidy(lm(Sepal.Length ~ Petal.Length, data = .)))
#> # A tibble: 30 x 7
#> # Groups: Species, .draw [15]
#> Species .draw term estimate std.error statistic p.value
#> <fct> <fct> <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 setosa 1 (Intercept) 4.39 0.425 10.3 8.26e-14
#> 2 setosa 1 Petal.Length 0.388 0.294 1.32 1.94e- 1
#> 3 setosa 2 (Intercept) 3.40 0.408 8.35 6.60e-11
#> 4 setosa 2 Petal.Length 1.07 0.271 3.96 2.49e- 4
#> 5 setosa 3 (Intercept) 4.12 0.404 10.2 1.37e-13
#> 6 setosa 3 Petal.Length 0.585 0.283 2.07 4.43e- 2
#> 7 setosa 4 (Intercept) 5.19 0.448 11.6 1.64e-15
#> 8 setosa 4 Petal.Length -0.140 0.310 -0.451 6.54e- 1
#> 9 setosa 5 (Intercept) 4.19 0.420 9.98 2.71e-13
#> 10 setosa 5 Petal.Length 0.541 0.290 1.87 6.77e- 2
#> # ... with 20 more rows
iris_g %>%
bootstrapify(5) %>%
mutate(x = 4)
#> Error: Column `.draw` is unknown
Created on 2018-09-18 by the reprex package (v0.2.0).
I think that, most importantly, the column that indicates the different virtual bootstraps has to have a good name (and maybe needs to be configurable)
I agree! And I like this solution and the .key argument.
I also wrote a function that can convert the virtual bootstraps into actual data
I actually had this as well. Hadley suggested just using collect() for this behavior.
I've made an S3 class on top of a grouped data frame called bootstrapped_df so that we can dispatch correctly.
Would it be useful to return the actual row id when you call collect()? I see you also do .row_key that does 1:n, but this would be a bit different, so the user can track down where a bootstrapped row came from.
I would imagine this living in rsample, but not yet sure how it would integrate with everything else.
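Because collect() is an S3 generic in dplyr, giving the object its own class is enough for collect() to materialize the draws. A minimal sketch of that dispatch pattern under simplifying assumptions: the constructor bootstrapped() and the attributes draws and draw_key are hypothetical, and unlike strapgod's bootstrapped_df this version sits on an ungrouped tibble rather than on a grouped data frame.

```r
library(dplyr)

# Hypothetical constructor (not strapgod's code): keep the original rows once
# and record each bootstrap draw as an integer vector of row indices
bootstrapped <- function(data, .n, .key = ".draw") {
  n <- nrow(data)
  attr(data, "draws") <- replicate(.n, sample.int(n, n, replace = TRUE),
                                   simplify = FALSE)
  attr(data, "draw_key") <- .key
  class(data) <- c("bootstrapped_df", class(data))
  data
}

# S3 method for dplyr's collect() generic: turn virtual draws into real rows
collect.bootstrapped_df <- function(x, ...) {
  draws <- attr(x, "draws")
  key <- attr(x, "draw_key")

  # Drop the subclass and its bookkeeping attributes to get a plain tibble
  plain <- x
  attr(plain, "draws") <- NULL
  attr(plain, "draw_key") <- NULL
  class(plain) <- setdiff(class(x), "bootstrapped_df")

  # Stack the resampled rows and put the draw id in the first column
  out <- plain[unlist(draws), , drop = FALSE]
  out[[key]] <- rep(seq_along(draws), lengths(draws))
  out[c(key, setdiff(names(out), key))]
}

set.seed(1)
small <- tibble(id = 1:4, value = c(2.5, 3.1, 4.7, 5.0))
boot <- bootstrapped(small, .n = 3)
collect(boot)
```

Keeping the draw indices in the object (rather than copies of the rows) means the original row id requested above is available for free at collect() time.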
Here is a temp repo I've started for this work.
https://github.com/DavisVaughan/strapgod
I believe I have set up the Remotes field in the DESCRIPTION file correctly, so an install_github() should just work.
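The Remotes entry itself is not shown in this thread. As a sketch only, assuming it points at the dplyr branch linked earlier, such a field uses the `user/repo@ref` form that remotes understands, and install_github() then fetches that dependency from GitHub instead of CRAN. The snippet parses a hypothetical DESCRIPTION excerpt with base R's read.dcf() to show the field; the install call is left as a comment because it needs network access.

```r
# Hypothetical DESCRIPTION excerpt (DCF fields must start at column 1)
desc_text <- "Package: strapgod
Imports: dplyr
Remotes: DavisVaughan/dplyr@feature/virtual-bootstrap-groups"

con <- textConnection(desc_text)
fields <- read.dcf(con)
close(con)

fields[, "Remotes"]

# With such a field in place, one call installs the package together with
# the dplyr branch listed under Remotes (not run here):
# remotes::install_github("DavisVaughan/strapgod")
```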
Would it be useful to return the actual row id when you call collect()? I see you also do .row_key that does 1:n but this would be a bit different so the user can track down where a bootstrapped row came from.
I think there are three indices that are potentially useful in different scenarios:
- The index of the bootstrap
- The index of the row id in the original data frame
- An index counting rows from 1 to nrow() in the final data frame. This is primarily useful for gganimate, and should definitely not be added by default
These could be turned on or off by setting the respective key variables to either NULL or a value (as I did in my example), or alternatively via separate logical parameters. The problem with setting the key variables to NULL is that if the default is NULL, there won't be a standardized naming scheme: everybody will use their own names for the index columns. Not sure if this is something to worry about or not.
One more point: I think the index of the bootstrap should be returned by default. The other two can be off by default.
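The three indices and the proposed defaults (bootstrap index on, the other two off) can be sketched as NULL-able key arguments. The function realize_draws() and its arguments .orig_key and .row_key are hypothetical, not strapgod's API; row_sets is assumed to be a list holding one integer vector of original row indices per draw.

```r
# Hypothetical materializer. `row_sets` holds one integer vector of original
# row indices per bootstrap draw.
realize_draws <- function(data, row_sets,
                          .key = ".draw",     # bootstrap index: on by default
                          .orig_key = NULL,   # original row id: off by default
                          .row_key = NULL) {  # 1..nrow() counter: off by default
  rows <- unlist(row_sets, use.names = FALSE)
  out <- data[rows, , drop = FALSE]
  rownames(out) <- NULL

  if (!is.null(.key)) {
    out[[.key]] <- rep(seq_along(row_sets), lengths(row_sets))
  }
  if (!is.null(.orig_key)) {
    out[[.orig_key]] <- rows
  }
  if (!is.null(.row_key)) {
    out[[.row_key]] <- seq_len(nrow(out))
  }
  out
}

set.seed(3)
df <- data.frame(x = c(10, 20, 30))
row_sets <- replicate(2, sample.int(3, 3, replace = TRUE), simplify = FALSE)

realize_draws(df, row_sets)
realize_draws(df, row_sets, .orig_key = ".orig_row", .row_key = ".row")
```

Using string defaults rather than NULL for the keys that are on addresses the naming concern raised above: everyone gets the same column name unless they opt out.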
With strapgod all set now, we can close this issue.
Thanks for your discussion!
This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex https://reprex.tidyverse.org) and link to this issue.