vstat
Loading...
Searching...
No Matches
bivariate.hpp
1// SPDX-License-Identifier: MIT
2// SPDX-FileCopyrightText: Copyright 2020-2024 Heal Research
3
4#ifndef VSTAT_BIVARIATE_HPP
5#define VSTAT_BIVARIATE_HPP
6
7#include "combine.hpp"
8
9namespace VSTAT_NAMESPACE {
13template <typename T>
15 static auto load_state(T sx, T sy, T sw, T sxx, T syy, T sxy) noexcept -> bivariate_accumulator<T> // NOLINT
16 {
18 acc.sum_w = sw;
19 acc.sum_w_old = sw;
20 acc.sum_x = sx;
21 acc.sum_y = sy;
22 acc.sum_xx = sxx;
23 acc.sum_yy = syy;
24 acc.sum_xy = sxy;
25 return acc;
26 }
27
28 static auto load_state(std::tuple<T, T, T, T, T, T> state) noexcept -> bivariate_accumulator<T>
29 {
30 auto [sx, sy, sw, sxx, syy, sxy] = state;
31 return load_state(sx, sy, sw, sxx, syy, sxy);
32 }
33
34 inline void operator()(T x, T y) noexcept
35 {
36 T dx = x * sum_w - sum_x;
37 T dy = y * sum_w - sum_y;
38
39 sum_w += 1;
40
41 T f = 1. / (sum_w * sum_w_old);
42 sum_xx += f * dx * dx;
43 sum_yy += f * dy * dy;
44 sum_xy += f * dx * dy;
45
46 sum_x += x;
47 sum_y += y;
48
49 sum_w_old = sum_w;
50 }
51
52 inline void operator()(T x, T y, T w) noexcept // NOLINT
53 {
54 T dx = x * sum_w - sum_x;
55 T dy = y * sum_w - sum_y;
56
57 sum_x += x * w;
58 sum_y += y * w;
59 sum_w += w;
60
61 T f = w / (sum_w * sum_w_old);
62 sum_xx += f * dx * dx;
63 sum_yy += f * dy * dy;
64 sum_xy += f * dx * dy;
65
66 sum_w_old = sum_w;
67 }
68
69 template <typename U>
70 requires eve::simd_value<T> && eve::simd_compatible_ptr<U, T>
71 inline void operator()(U const* x, U const* y) noexcept
72 {
73 (*this)(T{x}, T{y});
74 }
75
76 template <typename U>
77 requires eve::simd_value<T> && eve::simd_compatible_ptr<U, T>
78 inline void operator()(U const* x, U const* y, U const* w) noexcept
79 {
80 (*this)(T{x}, T{y}, T{w});
81 }
82
83 // performs a reduction on the vector types and returns the sums and the squared residuals sums
84 auto stats() const noexcept -> std::tuple<double, double, double, double, double, double>
85 {
86 if constexpr (std::is_floating_point_v<T>) {
87 return { sum_w, sum_x, sum_y, sum_xx, sum_yy, sum_xy };
88 } else {
89 auto [sxx, syy, sxy] = combine(sum_w, sum_x, sum_y, sum_xx, sum_yy, sum_xy);
90 return { eve::reduce(sum_w), eve::reduce(sum_x), eve::reduce(sum_y), sxx, syy, sxy };
91 }
92 }
93
94private:
95 // total weight of all observations (the observation count when only unweighted updates are used)
96 T sum_w{0};
// mirror of sum_w, refreshed at the end of every update and read by the next update as the
// "previous total weight" divisor; initialised to 1 (not 0) so the first update into an empty
// accumulator does not divide by zero
97 T sum_w_old{1};
98 // running (weighted) sums of x and y -- not means; the means are these divided by sum_w
99 T sum_x{0};
100 T sum_y{0};
101 // residual sums: squared residuals of x (sum_xx) and of y (sum_yy), and the sum of
// x/y residual cross products (sum_xy)
102 T sum_xx{0};
103 T sum_yy{0};
104 T sum_xy{0};
105};
106
// number of observations (total weight) reported by the accumulator
111 double count;
// sums of the x and y values
112 double sum_x;
113 double sum_y;
// sums of squared residuals ("ssr") of x and y about their respective means
114 double ssr_x;
115 double ssr_y;
// sum of x/y residual cross products as produced by bivariate_accumulator
// (not the sum of raw x*y products)
116 double sum_xy;
// means: sum / count
117 double mean_x;
118 double mean_y;
// population variances: ssr / count
119 double variance_x;
120 double variance_y;
// sample variances: ssr / (count - 1)
121 double sample_variance_x;
122 double sample_variance_y;
// Pearson correlation coefficient of x and y
123 double correlation;
// population covariance: sum_xy / count
124 double covariance;
// sample covariance: sum_xy / (count - 1)
125 double sample_covariance;
126
127 template <typename T>
128 explicit bivariate_statistics(T accumulator)
129 {
130 auto [sw, sx, sy, sxx, syy, sxy] = accumulator.stats();
131 count = sw;
132 sum_x = sx;
133 sum_y = sy;
134 ssr_x = sxx;
135 ssr_y = syy;
136 sum_xy = sxy;
137 mean_x = sx / sw;
138 mean_y = sy / sw;
139 variance_x = sxx / sw;
140 variance_y = syy / sw;
141 sample_variance_x = sxx / (sw - 1);
142 sample_variance_y = syy / (sw - 1);
143
144 if (!(sxx > 0 && syy > 0)) {
145 correlation = static_cast<double>(sxx == syy);
146 } else {
147 correlation = sxy / std::sqrt(sxx * syy);
148 }
149
150 covariance = sxy / sw;
151 sample_covariance = sxy / (sw - 1);
152 }
153};
154
155inline auto operator<<(std::ostream& os, bivariate_statistics const& stats) -> std::ostream&
156{
157 os << "count: \t" << stats.count
158 << "\nsum_x: \t" << stats.sum_x
159 << "\nssr_x: \t" << stats.ssr_x
160 << "\nmean_x: \t" << stats.mean_x
161 << "\nvariance_x: \t" << stats.variance_x
162 << "\nsample variance_x:\t" << stats.sample_variance_x
163 << "\nsum_y: \t" << stats.sum_y
164 << "\nssr_y: \t" << stats.ssr_y
165 << "\nmean_y: \t" << stats.mean_y
166 << "\nvariance_y: \t" << stats.variance_y
167 << "\nsample variance_y:\t" << stats.sample_variance_y
168 << "\ncorrelation: \t" << stats.correlation
169 << "\ncovariance: \t" << stats.covariance
170 << "\nsample covariance:\t" << stats.sample_covariance
171 << "\n";
172 return os;
173}
174} // namespace VSTAT_NAMESPACE
175
176#endif
Bivariate accumulator object.
Definition bivariate.hpp:14
Bivariate statistics.
Definition bivariate.hpp:110