vstat
Loading...
Searching...
No Matches
univariate.hpp
1// SPDX-License-Identifier: MIT
2// SPDX-FileCopyrightText: Copyright 2020-2024 Heal Research
3
4#ifndef VSTAT_UNIVARIATE_HPP
5#define VSTAT_UNIVARIATE_HPP
6
7#include "combine.hpp"
8
9namespace VSTAT_NAMESPACE {
13template <typename T>
15 static auto load_state(T sw, T sx, T sxx) noexcept -> univariate_accumulator<T>
16 {
18 acc.sum_w = sw;
19 acc.sum_w_old = sw;
20 acc.sum_x = sx;
21 acc.sum_xx = sxx;
22 return acc;
23 }
24
25 static auto load_state(std::tuple<T, T, T> state) noexcept -> univariate_accumulator<T>
26 {
27 auto [sw, sx, sxx] = state;
28 return load_state(sw, sx, sxx);
29 }
30
31 inline void operator()(T x) noexcept
32 {
33 T dx = sum_w * x - sum_x;
34 sum_x += x;
35 sum_w += 1;
36 sum_xx += dx * dx / (sum_w * sum_w_old);
37 sum_w_old = sum_w;
38 }
39
40 inline void operator()(T x, T w) noexcept
41 {
42 x *= w;
43 T dx = sum_w * x - sum_x * w;
44 sum_x += x;
45 sum_w += w;
46 sum_xx += dx * dx / (w * sum_w * sum_w_old);
47 sum_w_old = sum_w;
48 }
49
50 template<typename U>
51 requires eve::simd_value<T> && eve::simd_compatible_ptr<U, T>
52 inline void operator()(U const* x) noexcept
53 {
54 (*this)(T{x});
55 }
56
57 template<typename U>
58 requires eve::simd_value<T> && eve::simd_compatible_ptr<U, T>
59 inline void operator()(U const* x, U const* w) noexcept
60 {
61 (*this)(T{x}, T{w});
62 }
63
64 // performs the reductions and returns { sum_w, sum_x, sum_xx }
65 [[nodiscard]] auto stats() const noexcept -> std::tuple<double, double, double>
66 {
67 if constexpr (std::is_floating_point_v<T>) {
68 return { sum_w, sum_x, sum_xx };
69 } else {
70 return { eve::reduce(sum_w), eve::reduce(sum_x), combine(sum_w, sum_x, sum_xx) };
71 }
72 }
73
// Accumulator state; written by operator() and load_state, read by stats().
74private:
// Sum of the weights seen so far (the unweighted update adds 1 per call).
75 T sum_w{0};
// Value sum_w had after the previous update. Starts at 1 rather than 0 because
// the updates divide by sum_w * sum_w_old.
76 T sum_w_old{1};
// Sum of the values seen so far (each value multiplied by its weight in the
// weighted update).
77 T sum_x{0};
// Accumulated squared-deviation term; reported as the third element of stats().
78 T sum_xx{0};
79};
80
// Result fields; all of them are assigned by the constructor from the
// accumulator's stats() tuple { sw, sx, sxx }.
// sw
85 double count;
// sx
86 double sum;
// sxx
87 double ssr;
// sx / sw
88 double mean;
// sxx / sw
89 double variance;
// sxx / (sw - 1)
90 double sample_variance;
91
92 template <typename T>
93 explicit univariate_statistics(T const& accumulator)
94 {
95 auto [sw, sx, sxx] = accumulator.stats();
96 count = sw;
97 sum = sx;
98 ssr = sxx;
99 mean = sx / sw;
100 variance = sxx / sw;
101 sample_variance = sxx / (sw - 1);
102 }
103};
104
105inline auto operator<<(std::ostream& os, univariate_statistics const& stats) -> std::ostream&
106{
107 os << "count: \t" << stats.count
108 << "\nsum: \t" << stats.sum
109 << "\nssr: \t" << stats.ssr
110 << "\nmean: \t" << stats.mean
111 << "\nvariance: \t" << stats.variance
112 << "\nsample variance:\t" << stats.sample_variance
113 << "\n";
114 return os;
115}
116
117} // namespace VSTAT_NAMESPACE
118
119#endif
Univariate accumulator object.
Definition univariate.hpp:14
Univariate statistics.
Definition univariate.hpp:84