vstat
Loading...
Searching...
No Matches
combine.hpp
1// SPDX-License-Identifier: MIT
2// SPDX-FileCopyrightText: Copyright 2020-2024 Heal Research
3
4#ifndef VSTAT_COMBINE_HPP
5#define VSTAT_COMBINE_HPP
6
7#include <array>
8#include <cstddef>
9#include <cmath>
10#include <tuple>
11#include <eve/wide.hpp>
12#include <eve/module/core.hpp>
13
14#include "util.hpp"
15
16namespace VSTAT_NAMESPACE {
17
18namespace detail {
19 template<typename T>
20 requires eve::simd_value<T>
21 inline auto unpack(T v) -> auto
22 {
23 return [&]<std::size_t ...I>(std::index_sequence<I...>){
24 return std::array{ v.get(I) ... };
25 }(std::make_index_sequence<T::size()>{});
26 }
27} // namespace detail
28
29// The code below is based on:
30// Schubert et al. - Numerically Stable Parallel Computation of (Co-)Variance, p. 4, eq. 22-26
31// https://dbs.ifi.uni-heidelberg.de/files/Team/eschubert/publications/SSDBM18-covariance-authorcopy.pdf
32// merge covariance from individual data partitions A,B
33template<typename T>
34requires eve::simd_value<T> && (T::size() >= 2)
35inline auto combine(T sum_w, T sum_x, T sum_xx) -> double
36{
37 if constexpr (T::size() == 2) {
38 auto [n0, n1] = detail::unpack(sum_w);
39 auto [s0, s1] = detail::unpack(sum_x);
40
41 // this happens when we call stats on an accumulator that
42 // hasn't yet processed any values and the weight is zero
43 double f = 1. / (n0 * n1 * (n0 + n1));
44 if (!std::isfinite(f)) { f = 0; }
45 return eve::reduce(sum_xx) + f * eve::sqr(n1 * s0 - n0 * s1); // eq. 22
46 } else {
47 auto [sum_w0, sum_w1] = sum_w.slice();
48 auto [sum_x0, sum_x1] = sum_x.slice();
49 auto [sum_xx0, sum_xx1] = sum_xx.slice();
50
51 double n0 = eve::reduce(sum_w0);
52 double n1 = eve::reduce(sum_w1);
53
54 double s0 = eve::reduce(sum_x0);
55 double s1 = eve::reduce(sum_x1);
56
57 double q0 = combine(sum_w0, sum_x0, sum_xx0);
58 double q1 = combine(sum_w1, sum_x1, sum_xx1);
59
60 double f = 1. / (n0 * n1 * (n0 + n1));
61 if (!std::isfinite(f)) { f = 0; }
62 return q0 + q1 + f * eve::sqr(n1 * s0 - n0 * s1); // eq. 22
63 }
64}
65
66template<typename T>
67requires eve::simd_value<T> && (T::size() >= 2)
68inline auto combine(T sum_w, T sum_x, T sum_y, T sum_xx, T sum_yy, T sum_xy) -> std::tuple<double, double, double> // NOLINT
69{
70 if constexpr (T::size() == 2) {
71 auto [n0, n1] = detail::unpack(sum_w);
72 auto [sx0, sx1] = detail::unpack(sum_x);
73 auto [sy0, sy1] = detail::unpack(sum_y);
74 auto [sxx0, sxx1] = detail::unpack(sum_xx);
75 auto [syy0, syy1] = detail::unpack(sum_yy);
76 auto [sxy0, sxy1] = detail::unpack(sum_xy);
77
78 double f = 1. / (n0 * n1 * (n0 + n1));
79 if (!std::isfinite(f)) { f = 0; }
80
81 double sx = n1 * sx0 - n0 * sx1;
82 double sy = n1 * sy0 - n0 * sy1;
83 double sxx = sxx0 + sxx1 + f * sx * sx;
84 double syy = syy0 + syy1 + f * sy * sy;
85 double sxy = sxy0 + sxy1 + f * sx * sy;
86
87 return { sxx, syy, sxy };
88 } else {
89 auto [sum_w0, sum_w1] = sum_w.slice();
90 auto [sum_x0, sum_x1] = sum_x.slice();
91 auto [sum_y0, sum_y1] = sum_y.slice();
92 auto [sum_xx0, sum_xx1] = sum_xx.slice();
93 auto [sum_yy0, sum_yy1] = sum_yy.slice();
94 auto [sum_xy0, sum_xy1] = sum_xy.slice();
95
96 auto [sxx0, syy0, sxy0] = combine(sum_w0, sum_x0, sum_y0, sum_xx0, sum_yy0, sum_xy0);
97 auto [sxx1, syy1, sxy1] = combine(sum_w1, sum_x1, sum_y1, sum_xx1, sum_yy1, sum_xy1);
98
99 double n0 = eve::reduce(sum_w0);
100 double n1 = eve::reduce(sum_w1);
101 double sx0 = eve::reduce(sum_x0);
102 double sx1 = eve::reduce(sum_x1);
103 double sy0 = eve::reduce(sum_y0);
104 double sy1 = eve::reduce(sum_y1);
105
106 double f = 1. / (n0 * n1 * (n0 + n1));
107 if (!std::isfinite(f)) { f = 0; }
108
109 double sx = n1 * sx0 - n0 * sx1;
110 double sy = n1 * sy0 - n0 * sy1;
111 double sxx = sxx0 + sxx1 + f * sx * sx;
112 double syy = syy0 + syy1 + f * sy * sy;
113 double sxy = sxy0 + sxy1 + f * sx * sy;
114
115 return { sxx, syy, sxy };
116 }
117}
118} // namespace VSTAT_NAMESPACE
119
120#endif