Shixiang Wang

>上士闻道
勤而行之

用purrr实现迭代

王诗翔 · 2018-10-04

分类: r  
标签: r   purrr   loop  

函数有3个好处:

除了函数,减少重复代码的另一种工具是迭代,它的作用在于可以对多个输入执行同一种处理,比如对多个列或多个数据集进行同样的操作。

迭代方式主要有两种:

准备工作

purrr是tidyverse的核心r包之一,提供了一些更加强大的编程工具。

library(tidyverse)
#> -- Attaching packages ------------------------------------------------------------ tidyverse 1.3.0 --
#> v ggplot2 3.3.2     v purrr   0.3.4
#> v tibble  3.0.3     v dplyr   1.0.0
#> v tidyr   1.1.0     v stringr 1.4.0
#> v readr   1.3.1     v forcats 0.5.0
#> -- Conflicts --------------------------------------------------------------- tidyverse_conflicts() --
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag()    masks stats::lag()

for循环与函数式编程

因为R是一门函数式编程语言,我们可以先将for循环包装在函数中,然后再调用函数,而不是使用for循环,因此for循环在R中不像在其他编程语言中那么重要。

为了说明函数式编程,我们先利用下面简单的数据框进行一些思考:

df = tibble(
    a = rnorm(10),
    b = rnorm(10),
    c = rnorm(10),
    d = rnorm(10)
)

如果想要计算每列的均值,我们使用for循环完成任务:

output = vector("double", length(df))
for (i in seq_along(df)) {
    output[[i]] = mean(df[[i]])
}
output
#> [1]  0.103 -0.315  0.716 -0.402

然后我们可能意识到需要频繁地计算每列的均值,因此将代码提取出来,转换为一个函数:

col_mean = function(df) {
    output = vector("double", length(df))
    for ( i in seq_along(df)) {
        output[i] = mean(df[[i]])
    }
    
    output
}

然后我们觉得可能还需要这样计算每列的中位数和标准差,因此复制粘贴了col_mean(),并使用相应的median()sd()函数替换了mean()函数:

col_median = function(df) {
    output = vector("double", length(df))
    for ( i in seq_along(df)) {
        output[i] = median(df[[i]])
    }
    
    output
}
col_sd = function(df) {
    output = vector("double", length(df))
    for ( i in seq_along(df)) {
        output[i] = sd(df[[i]])
    }
    
    output
}

(有时候我还真这么干的。)

哎呀,我们又复制粘贴了2次代码,因此是不是该思考下如何扩展一个代码让它同时发挥几个函数的功能呢?这段代码的大部分是一个for循环,而且如果不仔细很难看出3个函数有什么差别。

通过添加支持函数到每列的参数,我们可以使用同一个函数解决3个问题:

col_summary = function(df, fun){
    out = vector("double", length(df))
    for (i in seq_along(df)) {
        out[i] = fun(df[[i]])
    }
    out
}
col_summary(df, median)
#> [1]  0.425 -0.471  0.670 -0.170
col_summary(df, mean)
#> [1]  0.103 -0.315  0.716 -0.402

将函数作为参数传入另一个函数的做法是一种非常强大的功能,我们需要花些时间理解这种思想,但绝对是值得的。接下来我们将学习和使用purrr包,它提供的函数可以替代很多常见的for循环应用。R基础包中的apply应用函数族也可以完成类似的任务,但purrr包的函数更一致,也更容易学习。

使用purrr函数替代for循环的目的是将常见的列表问题分解为独立的几部分

映射函数

先对向量进行循环,然后对其每一个元素进行一番处理,最后保存结果。这种模式太普遍了,因而purrr包提供了一个函数族替我们完成这种操作。每种类型的输出都有一个相应的函数:

每个函数都使用一个向量(注意列表可以作为递归向量看待)作为输入,并对向量的每个元素应用一个函数,然后返回和输入向量同样长度的一个新向量。向量的类型由映射函数的后缀决定。

使用map()函数族的优势不是速度,而是简洁:它可以让我们的代码更易编写,也更易阅读。

下面是进行上一节一样的操作:

library(purrr)
map_dbl(df, mean)
#>      a      b      c      d 
#>  0.103 -0.315  0.716 -0.402
map_dbl(df, median)
#>      a      b      c      d 
#>  0.425 -0.471  0.670 -0.170
map_dbl(df, sd)
#>    a    b    c    d 
#> 1.21 1.11 1.01 1.21

**与for循环相比,映射函数的重点在于需要执行的操作(即mean()median()sd()),而不是在所有元素中循环所需的跟踪记录以及保存结果。使用管道时这一点尤为突出:

df %>% map_dbl(mean)
#>      a      b      c      d 
#>  0.103 -0.315  0.716 -0.402
df %>% map_dbl(median)
#>      a      b      c      d 
#>  0.425 -0.471  0.670 -0.170
df %>% map_dbl(sd)
#>    a    b    c    d 
#> 1.21 1.11 1.01 1.21

map_*()col_summary()具有以下几点区别: * 所有的purrr函数都是用C实现的,这让它们的速度非常快,但牺牲了一些可读性。 * 第二个参数可以是一个公式、一个字符向量或一个整型向量。 * map_*()使用....f传递一些附加参数,供每次调用时使用 * 映射函数还保留名称 ### 快捷方式 对于第二个参数.f,我们可以使用几种快捷方式来减少输入量。比如我们现在想对某个数据集中的每一个分组都拟合一个线性模型,下面示例将mtcars数据集拆分为3个部分(按照气缸值分类),并对每个部分拟合一个线性模型:

models = mtcars %>% 
    split(.$cyl) %>% 
    map(function(df) lm(mpg ~ wt, data = df))

因为在R中创建匿名函数的语法比较复杂,所以purrr提供了一种更方便的快捷方式——单侧公式:

models = mtcars %>% 
    split(.$cyl) %>% 
    map(~lm(mpg ~ wt, data = .))

上面.作为一个代词:它表示当前列表元素(与for循环中用i表示当前索引是一样的)。 当检查多个模型时,有时候我们需要提取像R方这样的摘要统计量,要想完成这个任务,我们需要先运行summary()函数,然后提取结果中的r.squared:

models %>% 
    map(summary) %>% 
    map_dbl(~.$r.squared)
#>     4     6     8 
#> 0.509 0.465 0.423

因为提取命名成分操作非常普遍,所以purrr提供了一种更简单的快捷方式:使用字符串。

models %>% 
    map(summary) %>% 
    map_dbl("r.squared")
#>     4     6     8 
#> 0.509 0.465 0.423

对操作失败的处理

当使用映射函数重复多次操作时,某次操作失败的概率大大增加。这个时候我们会收到一条错误信息,但得不到任何结果。这让人很恼火!我们怎么保证不会出现一条鱼腥了一锅汤? safely()是一个修饰函数(副词),它接收一个函数(动词),对其进行修改并返回修改后的函数。这样,修改后的函数就不会抛出错误,相反,它总是返回由下面两个元素组成的列表: * result - 原始结果。如果出现错误,那么它就是NULL * error - 错误对象。如果操作成功,那么它就是NULL 下面用log()函数进行说明:

safe_log = safely(log)
str(safe_log(10))
#> List of 2
#>  $ result: num 2.3
#>  $ error : NULL
str(safe_log("a"))
#> List of 2
#>  $ result: NULL
#>  $ error :List of 2
#>   ..$ message: chr "non-numeric argument to mathematical function"
#>   ..$ call   : language .Primitive("log")(x, base)
#>   ..- attr(*, "class")= chr [1:3] "simpleError" "error" "condition"

safely()函数也可以与map()共同使用:

x = list(1, 10, "a")
y = x %>% map(safely(log))
str(y)
#> List of 3
#>  $ :List of 2
#>   ..$ result: num 0
#>   ..$ error : NULL
#>  $ :List of 2
#>   ..$ result: num 2.3
#>   ..$ error : NULL
#>  $ :List of 2
#>   ..$ result: NULL
#>   ..$ error :List of 2
#>   .. ..$ message: chr "non-numeric argument to mathematical function"
#>   .. ..$ call   : language .Primitive("log")(x, base)
#>   .. ..- attr(*, "class")= chr [1:3] "simpleError" "error" "condition"

如果将以上结果转换为2个列表,一个列表包含所有错误对象,另一个列表包含所有原始结果,那么处理起来就会更容易。我们可以使用purrr::transpose()函数轻松完成该任务

y = y %>% transpose()
str(y)
#> List of 2
#>  $ result:List of 3
#>   ..$ : num 0
#>   ..$ : num 2.3
#>   ..$ : NULL
#>  $ error :List of 3
#>   ..$ : NULL
#>   ..$ : NULL
#>   ..$ :List of 2
#>   .. ..$ message: chr "non-numeric argument to mathematical function"
#>   .. ..$ call   : language .Primitive("log")(x, base)
#>   .. ..- attr(*, "class")= chr [1:3] "simpleError" "error" "condition"

我们可以自行决定如何处理错误对象,一般来说,我们应该检查一下y中错误对象所对应的x值,或者使用y中的正常结果进行一些处理:

is_ok = y$error %>% map_lgl(is_null)
x[!is_ok]
#> [[1]]
#> [1] "a"
y$result[is_ok] %>% flatten_dbl()
#> [1] 0.0 2.3

purrr还提供了两个有用的修饰函数: * 与safely()类似,possibly()函数总是会成功返回。它比safely()还要简单一些,因为可以设定出现错误时返回一个默认值:

x = list(1, 10, "a")
x %>% map_dbl(possibly(log, NA_real_))
#> [1] 0.0 2.3  NA
x = list(1, -1)
x %>% map(quietly(log)) %>% str()
#> List of 2
#>  $ :List of 4
#>   ..$ result  : num 0
#>   ..$ output  : chr ""
#>   ..$ warnings: chr(0) 
#>   ..$ messages: chr(0) 
#>  $ :List of 4
#>   ..$ result  : num NaN
#>   ..$ output  : chr ""
#>   ..$ warnings: chr "NaNs produced"
#>   ..$ messages: chr(0)
x %>% map(safely(log)) %>% str()
#> Warning in .Primitive("log")(x, base): NaNs produced
#> List of 2
#>  $ :List of 2
#>   ..$ result: num 0
#>   ..$ error : NULL
#>  $ :List of 2
#>   ..$ result: num NaN
#>   ..$ error : NULL

多参数映射

前面我们提到的映射函数都是对单个输入进行映射,但有时候我们需要多个相关输入同步迭代,这就是map2()pmap()函数的用武之地。 例如我们想模拟几个均值不同的随机正态分布,我们可以使用map完成这个任务:

mu = list(5, 10, -3)
mu %>% 
    map(rnorm, n = 5) %>% 
    str()
#> List of 3
#>  $ : num [1:5] 5.39 5.84 5.18 5.75 4.48
#>  $ : num [1:5] 10.03 9.89 8.87 11.75 11.21
#>  $ : num [1:5] -5.28 -4.2 -2.43 -1.16 -2.97

如果我们想让标准差也不同,一种方法是使用均值向量和标准差向量的索引进行迭代:

sigma = list(1, 5, 10)
seq_along(mu) %>% 
    map(~rnorm(5, mu[[.]], sigma[[.]])) %>% 
    str()
#> List of 3
#>  $ : num [1:5] 5.43 3.79 5.62 5.76 7.18
#>  $ : num [1:5] 9.9 12.68 7.09 12.07 9.07
#>  $ : num [1:5] -5.02 -2.57 -1.72 14.5 -11.31

但这种方式比较难理解,我们使用map2()进行同步迭代:

map2(mu, sigma, rnorm, n = 5) %>% str()
#> List of 3
#>  $ : num [1:5] 4.99 3.69 5.29 4.79 3.78
#>  $ : num [1:5] 13.3 9.53 13.04 3.88 6.19
#>  $ : num [1:5] -17.15 -3.18 -3.7 -6.63 -5.78

注意这里每次调用时值发生变换的参数要放在映射函数前面,值不变的参数要放在映射函数后面。 和map()函数一样,map2()函数也是对for循环的包装:

map2 = function(x, y, f, ...){
    out = vector("list", length(x))
    for (i in seq_along(x)) {
        out[[i]] = f(x[[i]], y[[i]], ...)
    }
    out
}

(实际的map2()并不是这样的,此处是给出R实现的一种思想) 根据这个函数,我们可以涉及map3()map4()等等,但这样实在无聊。purrr提供了pmap()函数,它可以将列表作为参数。如果我们想要生成均值、标准差和样本数都不同的正态分布,可以使用:

n = list(1, 3, 5)
args1 = list(n, mu, sigma)
args1 %>% 
    pmap(rnorm) %>% 
    str()
#> List of 3
#>  $ : num 4.39
#>  $ : num [1:3] 1.05 22.01 10.26
#>  $ : num [1:5] -9.26 -7.14 4.41 3.01 7.69

如果没有为列表元素命名,那么pmap()在调用函数时会按照位置匹配。这样做容易出错而且可读性差,因此最后使用命名参数:

args2 = list(mean = mu, sd = sigma, n = n)
args2 %>% 
    pmap(rnorm) %>% 
    str()
#> List of 3
#>  $ : num 5.6
#>  $ : num [1:3] 12.61 4.43 11.71
#>  $ : num [1:5] -4.73 -1.41 -6.44 1.37 -15.55

这样更加安全。

因为长度都相同,所以将各个参数保存在一个数据框中:

params = tibble::tribble(
    ~mean, ~sd, ~n,
    5, 1, 1,
    10, 5, 3,
    -3, 10, 5
)
params %>% 
    pmap(rnorm)
#> [[1]]
#> [1] 3.44
#> 
#> [[2]]
#> [1] 13.23  7.33 17.82
#> 
#> [[3]]
#> [1]  -8.034 -11.547 -12.785   6.243  -0.937

调用不同的函数

还有一种更复杂的情况:不但传给函数的参数不同,甚至函数本身也是不同的。

f = c("runif", "rnorm", "rpois")
param = list(
    list(min = -1, max = 1),
    list(sd = 5),
    list(lambda = 10)
)

为了处理这种情况,我们使用invoke_map()函数:

invoke_map(f, param, n = 5) %>% str()
#> List of 3
#>  $ : num [1:5] -0.73 -0.352 0.377 -0.442 0.294
#>  $ : num [1:5] -2.22 -2.87 3.42 5.6 8.97
#>  $ : int [1:5] 13 5 12 8 7

第1个参数是一个函数列表或包含函数名称的字符串向量。第2个参数是列表的一个列表,给出了要传给各个函数的不同参数。随后的参数要传给每个函数

我们使用tribble()让参数配对更容易:

sim = tibble::tribble(
    ~f, ~params,
    "runif", list(min = -1, max = 1),
    "rnorm", list(sd = 5),
    "rpois", list(lambda = 10)
)
sim %>% 
    dplyr::mutate(sim = invoke_map(f, params, n = 10))
#> # A tibble: 3 x 3
#>   f     params           sim       
#>   <chr> <list>           <list>    
#> 1 runif <named list [2]> <dbl [10]>
#> 2 rnorm <named list [1]> <dbl [10]>
#> 3 rpois <named list [1]> <int [10]>

游走函数

当使用函数的目的是向屏幕提供输出或将文件保存到磁盘——重要的是操作过程而不是返回值,我们应该使用游走函数,而不是映射函数。

下面是一个示例:

x = list(1, "a", 3)
x %>% 
    walk(print)
#> [1] 1
#> [1] "a"
#> [1] 3

一般来说,walk()函数不如walk2()pwalk()实用。例如有一个图形列表和一个文件名向量,那么我们就可以使用pwalk()将每个文件保存到相应的磁盘位置:

library(ggplot2)
plots = mtcars %>% 
    split(.$cyl) %>% 
    map(~ggplot(., aes(mpg, wt)) + geom_point())
paths = stringr::str_c(names(plots), ".pdf")
pwalk(list(paths, plots), ggsave, path = tempdir())
#> Saving 7 x 5 in image
#> Saving 7 x 5 in image
#> Saving 7 x 5 in image

我们来查看一下是不是建立好了:

dir(tempdir())
#> [1] "4.pdf" "6.pdf" "8.pdf"

for循环的其他模式

purrr还提供了其他一些函数,虽然这些函数的使用率低,但了解还是有必要的。本节就是对它们进行简单介绍

预测函数

一些函数可以与返回TRUEFALSE的预测函数一同使用。

keep()discard()函数可以分别保留输入中预测值为TRUEFALSE的元素(在数据框中就是指列):

iris %>% 
    keep(is.factor) %>% 
    str()
#> 'data.frame':    150 obs. of  1 variable:
#>  $ Species: Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
iris %>% 
    discard(is.factor) %>% 
    str()
#> 'data.frame':    150 obs. of  4 variables:
#>  $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
#>  $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
#>  $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#>  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...

some()every()函数分别用来确定预测值是否对某个元素为真以及是否对所有元素为真:

x = list(1:5, letters, list(10))
x %>% 
    some(is_character)
#> [1] TRUE
x %>% 
    every(is_vector)
#> [1] TRUE

detect()可以找出预测值为真的第一个元素,detect_index()可以返回该元素的索引。

x = sample(10)
x
#>  [1]  9  2 10  3  5  4  1  7  8  6
x %>% 
    detect(~ . >5)
#> [1] 9
x %>% 
    detect_index(~ . >5)
#> [1] 1

head_while()tail_while()分别从向量的开头和结尾找出预测值为真的元素:

x %>% 
    head_while(~ . > 5)
#> [1] 9
x %>% 
    tail_while(~ . > 5)
#> [1] 7 8 6

归约和累计

操作一个复杂的列表,有时候我们想要不断合并两个预算两个元素(基础函数Reduce干的事情)。

dfs = list(
    age = tibble(name = "John", age = 30),
    sex = tibble(name = c("John", "Mary"), sex = c("M", "F")),
    trt = tibble(name = "Mary", treatment = "A")
)
dfs %>% reduce(full_join)
#> Joining, by = "name"
#> Joining, by = "name"
#> # A tibble: 2 x 4
#>   name    age sex   treatment
#>   <chr> <dbl> <chr> <chr>    
#> 1 John     30 M     <NA>     
#> 2 Mary     NA F     A

这里我们使用reduce结合dplyr中的full_join()将它们轻松合并为一个数据框。

reduce()函数使用一个“二元函数”(即两个基本输入),将其不断应用于一个列表,直到最后只剩下一个元素。

累计函数与归约函数类似,但会保留中间结果,比如下面求取累计和:

x = sample(10)
x
#>  [1]  3  9 10  7  4  2  5  8  6  1
x %>% accumulate(`+`)
#>  [1]  3 12 22 29 33 35 40 48 54 55