Shixiang Wang

>上士闻道
勤而行之

无影腿快不过组合拳?均值计算哪家强

王诗翔 · 2020-11-27

分类: r  
标签: r   计算效率  

昨天我在公众号推文中提了一个非常有意思的问题:mean()sum() / length() 哪一个更快?

我在知识星球看到有朋友已经测试过了,发现后者更快,为什么呢?

性能测试

我们先通过基准测试来比对两种方法的计算效率:

# 生成一组随机数
x <- rnorm(1e6)
# 测试
microbenchmark::microbenchmark(
  mean = mean(x),
  comb = sum(x) / length(x)
)
#> Unit: microseconds
#>  expr  min   lq mean median   uq  max neval cld
#>  mean 1787 1888 1976   1958 2036 2488   100   b
#>  comb  919  973 1062   1014 1057 2752   100  a

从 100 万数据的均值计算来看,组合拳差不多比 mean() 快 1倍。再问一句,为什么呢?

性能探究

想要深入理解它们的性能差异,我们有必要了解 3 个函数的结构:

# 均值
mean
#> function (x, ...) 
#> UseMethod("mean")
#> <bytecode: 0x7f9db263c438>
#> <environment: namespace:base>
# 求和
sum
#> function (..., na.rm = FALSE)  .Primitive("sum")
# 取长度
length
#> function (x)  .Primitive("length")

不难发现 mean() 是一个泛型函数,而后两者都是一类 .Primitive 的元素,我们来了解一下它是什么。

R语言中有些函数是通过接口 .Primitive() 直接调用的 C 语言代码,而不是用 R 语言代码编写的。这些函数被称元函数(Primitive functions)。元函数仅在R基础包base中出现。因为元函数用底层语言写成,所以他们通常计算效率更高。但是也因为他们用C语言而不是用R语言写成。他们的行为方式也可能与 R 语言的其他函数不一样。

引自 R 语言中的函数

这就正常了,C 语言毕竟是性能之王。

microbenchmark::microbenchmark(
  mean = .Internal(mean(x)),
  comb = sum(x) / length(x)
)
#> Unit: microseconds
#>  expr  min   lq mean median   uq  max neval cld
#>  mean 1796 1899 2079   1954 2087 3497   100   b
#>  comb  927  992 1096   1045 1125 1932   100  a

我们最后再看看 R 分派用来计算数值均值的函数:

mean.default
#> function (x, trim = 0, na.rm = FALSE, ...) 
#> {
#>     if (!is.numeric(x) && !is.complex(x) && !is.logical(x)) {
#>         warning("argument is not numeric or logical: returning NA")
#>         return(NA_real_)
#>     }
#>     if (na.rm) 
#>         x <- x[!is.na(x)]
#>     if (!is.numeric(trim) || length(trim) != 1L) 
#>         stop("'trim' must be numeric of length one")
#>     n <- length(x)
#>     if (trim > 0 && n) {
#>         if (is.complex(x)) 
#>             stop("trimmed means are not defined for complex data")
#>         if (anyNA(x)) 
#>             return(NA_real_)
#>         if (trim >= 0.5) 
#>             return(stats::median(x, na.rm = FALSE))
#>         lo <- floor(n * trim) + 1
#>         hi <- n + 1 - lo
#>         x <- sort.int(x, partial = unique(c(lo, hi)))[lo:hi]
#>     }
#>     .Internal(mean(x))
#> }
#> <bytecode: 0x7f9dbb0fe550>
#> <environment: namespace:base>

我们可以看到有很多的条件判断,而最后一句是计算的核心代码,只保留它能加速吗?

microbenchmark::microbenchmark(
  mean = mean(x),
  mean_internal = .Internal(mean(x)),
  comb = sum(x) / length(x)
)
#> Unit: microseconds
#>           expr  min   lq mean median   uq  max neval cld
#>           mean 2109 3051 3115   3128 3267 3502   100   b
#>  mean_internal 1947 3072 3120   3164 3282 3524   100   b
#>           comb  993 1565 1580   1629 1699 1764   100  a

从结果来看,只执行最后一句并没有性能提升,反而速度有所下降。从代码中查看可以看出在调用最后一句计算代码之前进行过排序操作,显然这些 R 代码是有比较大的意义的。

话说这里的 .Internal 又是什么?查文档。

.Internal performs a call to an internal code which is built in to the R interpreter.

Only true R wizards should even consider using this function, and only R developers can add to the list of internal functions.

再看代码:

.Internal
#> function (call)  .Primitive(".Internal")

原来是个 C 写的函数。

小结

综上,组合拳求数值的均值要快一倍。

细心的读者可能会问为啥 R 不默认用组合拳求均值?

前面已经提到 mean() 是泛型函数,它支持多个数据类型的操作,使用更加广泛:

.S3methods("mean")
#>  [1] mean.Date        mean.default     mean.difftime    mean.POSIXct    
#>  [5] mean.POSIXlt     mean.quosure*    mean.vctrs_vctr* mean.yearmon*   
#>  [9] mean.yearqtr*    mean.zoo*       
#> see '?methods' for accessing help and source code