Skip to content

File statistics.h

File List > include > simtix > statistics.h

Go to the documentation of this file

#pragma once

#include <fmt/format.h>
#include <fmt/ranges.h>
#include <toml++/toml.h>

#include <cstdint>
#include <deque>
#include <memory>
#include <numeric>
#include <stack>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

// Shorthand for constructing a stat member: STAT(x, desc, unit, ...) expands
// to `x(this, "x", desc, unit, ...)`. It passes the enclosing object (`this`)
// as the first (parent) constructor argument and the member's own identifier,
// stringized, as the stat name. Extra arguments (e.g. a Vector's size) are
// forwarded after `unit`. The GNU `##__VA_ARGS__` extension drops the trailing
// comma when there are none.
// NOTE(review): presumably used in a Group subclass's member-initializer list
// so that `this` converts to Group* — confirm against call sites.
#define STAT(name, desc, unit, ...) name(this, #name, desc, unit, ##__VA_ARGS__)

namespace simtix {
namespace stat {

class StatBase;

/// A named node of the statistics tree.
///
/// Child groups are held by shared_ptr (shared ownership). Child stats are
/// held as raw `const StatBase *`, which Group never deletes, so the stats
/// must outlive the group.
class Group {
 public:
  explicit Group(const std::string &name) : name_{name} {}

  /// The name given at construction.
  const std::string &name() const { return name_; }

  // Registration and rendering are defined out of line.
  // NOTE(review): the meaning of the bool results (e.g. false on a duplicate
  // name?) is not visible here — TODO confirm in the .cpp.
  bool AddStat(const StatBase *s);
  bool AddStatGroup(const std::shared_ptr<Group> &g);
  toml::table Tabularize() const;

 protected:
  std::string name_;
  std::vector<const StatBase *> child_stats_;
  std::vector<std::shared_ptr<Group>> child_groups_;
};

/// Abstract base of every statistic: a name, a description and a unit, plus
/// a Tabularize() hook that renders the stat as a TOML table.
class StatBase {
 public:
  explicit StatBase(Group *parent, const std::string &name,
                    const std::string &desc, const std::string &unit);

  // StatBase has a pure virtual member, so it is used polymorphically; a
  // virtual destructor makes destruction through a StatBase* well-defined.
  virtual ~StatBase() = default;

  // Declared explicitly because the user-declared destructor above would
  // otherwise suppress the implicit move operations; all stay member-wise.
  StatBase(const StatBase &) = default;
  StatBase(StatBase &&) = default;
  StatBase &operator=(const StatBase &) = default;
  StatBase &operator=(StatBase &&) = default;

  const std::string &name() const { return name_; }
  const std::string &desc() const { return desc_; }
  const std::string &unit() const { return unit_; }

  /// Renders this stat as a TOML table (implemented by each concrete stat).
  virtual toml::table Tabularize() const = 0;

 protected:
  Group *parent_;  // may be nullptr (e.g. constants and vector reductions)
  std::string name_;
  std::string desc_;
  std::string unit_;
};

/// Binary arithmetic operators that can appear in a PostfixExpr.
enum class Operation {
  kAddition = 0,
  kSubtraction,
  kMultiplication,
  kDivision
};

class ScalarBase;

/// Arithmetic expression over scalar stats, stored in postfix order.
///
/// Each token is one of:
///   - an Operation, applied to the two values preceding it;
///   - a pointer to a non-constant ScalarBase. It is read only when
///     Evaluate() runs and is not owned, so the stat must outlive the
///     expression;
///   - a double, holding the value of a constant operand.
///
/// The binary operators are const. `lhs op rhs` returns a new expression
/// "lhs rhs op" and leaves `lhs` unchanged, so a named sub-expression can be
/// reused, including as its own right-hand operand.
class PostfixExpr {
 public:
  PostfixExpr() = default;

  // Copy
  PostfixExpr(const PostfixExpr &other) : expr_(other.expr_) {}

  PostfixExpr &operator=(const PostfixExpr &other) {
    expr_ = other.expr_;
    return *this;
  }

  // Move: takes the tokens of `other`, which is left as an empty expression.
  PostfixExpr(PostfixExpr &&other) {
    expr_.swap(other.expr_);
    other.expr_.clear();
  }

  PostfixExpr &operator=(PostfixExpr &&other) noexcept {
    if (this != &other) {
      expr_.swap(other.expr_);
      other.expr_.clear();  // drop our previous tokens, now held by `other`
    }
    return *this;
  }

  // Combine with a scalar operand.
  PostfixExpr operator+(const ScalarBase &s) const {
    return AppendScalar(s, Operation::kAddition);
  }
  PostfixExpr operator-(const ScalarBase &s) const {
    return AppendScalar(s, Operation::kSubtraction);
  }
  PostfixExpr operator*(const ScalarBase &s) const {
    return AppendScalar(s, Operation::kMultiplication);
  }
  PostfixExpr operator/(const ScalarBase &s) const {
    return AppendScalar(s, Operation::kDivision);
  }

  // Combine with another expression (which may be this same object).
  PostfixExpr operator+(const PostfixExpr &f) const {
    return AppendExpr(f, Operation::kAddition);
  }
  PostfixExpr operator-(const PostfixExpr &f) const {
    return AppendExpr(f, Operation::kSubtraction);
  }
  PostfixExpr operator*(const PostfixExpr &f) const {
    return AppendExpr(f, Operation::kMultiplication);
  }
  PostfixExpr operator/(const PostfixExpr &f) const {
    return AppendExpr(f, Operation::kDivision);
  }

  /// Evaluates the expression using `Storage` arithmetic. Returns the value
  /// together with a parenthesised textual form of the expression.
  /// Defined out of line.
  template <class Storage>
  std::tuple<Storage, std::string> Evaluate() const;

 protected:
  explicit PostfixExpr(const ScalarBase *opd) { PushScalar(*opd); }

  PostfixExpr(const ScalarBase *opd0, const ScalarBase *opd1, Operation op) {
    PushScalar(*opd0);
    PushScalar(*opd1);
    expr_.push_back(op);
  }

  PostfixExpr(const ScalarBase *opd0, const PostfixExpr &opd1, Operation op) {
    PushScalar(*opd0);
    PushExpr(opd1);
    expr_.push_back(op);
  }

  std::deque<std::variant<Operation, const ScalarBase *, double>> expr_;

  // Appends all tokens of `e`. `e` must not be *this: inserting a range
  // taken from the destination container is undefined behaviour.
  void PushExpr(const PostfixExpr &e) {
    expr_.insert(expr_.end(), e.expr_.begin(), e.expr_.end());
  }

  void PushScalar(const ScalarBase &s);

  friend class ScalarBase;

 private:
  // Returns a copy of this expression followed by operand `s` and `op`.
  PostfixExpr AppendScalar(const ScalarBase &s, Operation op) const {
    PostfixExpr result(*this);
    result.PushScalar(s);
    result.expr_.push_back(op);
    return result;
  }

  // Returns a copy of this expression followed by the tokens of `e` and `op`.
  // Safe when `e` aliases *this: tokens are inserted into `result`, never
  // into the container being read.
  PostfixExpr AppendExpr(const PostfixExpr &e, Operation op) const {
    PostfixExpr result(*this);
    result.PushExpr(e);
    result.expr_.push_back(op);
    return result;
  }
};

/// Base of every scalar-valued stat that can be an operand in a PostfixExpr.
///
/// An instance is either a named stat (parent/name/desc/unit constructor) or
/// an anonymous numeric constant (double constructor). The arithmetic
/// operators only record operands; the value is read later through the
/// virtual conversion operators when the expression is evaluated.
class ScalarBase : public StatBase {
 public:
  /// Wraps a numeric constant, named after its value and with no parent
  /// group. Deliberately implicit so numbers can appear in formulas, e.g.
  /// `stat / 2.0`.
  ScalarBase(double constant)  // NOLINT(runtime/explicit)
      : StatBase(nullptr, fmt::to_string(constant), "constant", "N/A"),
        constant_(constant),
        is_constant_(true) {}
  explicit ScalarBase(Group *parent, const std::string &name,
                      const std::string &desc, const std::string &unit)
      : StatBase(parent, name, desc, unit), constant_(0), is_constant_(false) {}

  // Current value. The base implementation returns the wrapped constant;
  // stats that hold data override these.
  virtual explicit operator int64_t() const {
    return static_cast<int64_t>(constant_);
  }
  virtual explicit operator double() const {
    return static_cast<double>(constant_);
  }

  /// Converts this operand into a single-token expression.
  operator PostfixExpr() const { return PostfixExpr(this); }

  bool is_constant() const { return is_constant_; }

  // Expression builders. They are const because building a formula never
  // modifies its operands; this also lets them work on const stats/refs.
  PostfixExpr operator+(const ScalarBase &rhs) const {
    return PostfixExpr(this, &rhs, Operation::kAddition);
  }
  PostfixExpr operator-(const ScalarBase &rhs) const {
    return PostfixExpr(this, &rhs, Operation::kSubtraction);
  }
  PostfixExpr operator*(const ScalarBase &rhs) const {
    return PostfixExpr(this, &rhs, Operation::kMultiplication);
  }
  PostfixExpr operator/(const ScalarBase &rhs) const {
    return PostfixExpr(this, &rhs, Operation::kDivision);
  }

  PostfixExpr operator+(const PostfixExpr &rhs) const {
    return PostfixExpr(this, rhs, Operation::kAddition);
  }
  PostfixExpr operator-(const PostfixExpr &rhs) const {
    return PostfixExpr(this, rhs, Operation::kSubtraction);
  }
  PostfixExpr operator*(const PostfixExpr &rhs) const {
    return PostfixExpr(this, rhs, Operation::kMultiplication);
  }
  PostfixExpr operator/(const PostfixExpr &rhs) const {
    return PostfixExpr(this, rhs, Operation::kDivision);
  }

 private:
  // Constants, and subclasses that do not override this, render as an
  // empty table.
  toml::table Tabularize() const override { return toml::table{}; }

  double constant_;   // value of a constant operand; 0 for named stats
  bool is_constant_;  // true only when built from the double constructor
};

/// A scalar stat holding one value of type T that is updated in place.
template <typename T>
class Scalar : public ScalarBase {
 public:
  using Storage = T;
  explicit Scalar(Group *parent, const std::string &name,
                  const std::string &desc, const std::string &unit)
      : ScalarBase(parent, name, desc, unit) {}

  // In-place updates. Increment/decrement return a copy of the new (prefix)
  // or previous (postfix) value rather than a reference.
  Storage operator++() { return ++storage_; }
  Storage operator--() { return --storage_; }
  Storage operator++(int) { return storage_++; }
  Storage operator--(int) { return storage_--; }
  void operator=(const Storage &s) { storage_ = s; }
  void operator+=(const Storage &s) { storage_ += s; }
  void operator-=(const Storage &s) { storage_ -= s; }

  explicit operator int64_t() const override {
    return static_cast<int64_t>(storage_);
  }
  explicit operator double() const override {
    return static_cast<double>(storage_);
  }

  /// Renders {val, unit, desc}.
  toml::table Tabularize() const override {
    return toml::table{
        {"val", storage_},
        {"unit", unit()},
        {"desc", desc()},
    };
  }

 protected:
  // Value-initialized (zero for arithmetic T), replacing the former
  // assignment in the constructor body.
  Storage storage_{};
};

/// A derived stat whose value is a PostfixExpr over other scalar stats. The
/// expression is evaluated in T::Storage arithmetic each time the stat is
/// tabularized.
template <class T>
class Formula : public StatBase {
 public:
  using Storage = typename T::Storage;
  explicit Formula(Group *parent, const std::string &name,
                   const std::string &desc, const std::string &unit)
      : StatBase(parent, name, desc, unit) {}

  /// Sets the expression. Taken by value so that temporaries (`f = a / b`)
  /// are moved in, and named expressions (`f = expr`) are accepted as a copy.
  Formula &operator=(PostfixExpr other) {
    expr_ = std::move(other);
    return *this;
  }

  /// Renders {val, unit, desc}, with the expression text appended to desc.
  toml::table Tabularize() const override {
    auto [val, expr] = expr_.Evaluate<Storage>();
    return toml::table{
        {"val", val},
        {"unit", unit()},
        {"desc", fmt::format("{}; {}", desc(), expr)},
    };
  }

 protected:
  PostfixExpr expr_;
};

// Scalar stats with signed 64-bit integer and double-precision storage.
using Integer = Scalar<std::int64_t>;
using Real = Scalar<double>;

/// A fixed-size array stat with T::Storage elements. It also exposes two
/// derived scalars, sum() and average(), that can be used as formula operands.
///
/// NOTE(review): sum_ and average_ hold a reference to the Vector that
/// constructed them, so a copy of a Vector would report the original's data
/// through its reductions. Copying looks unintended — consider deleting it.
template <typename T>
class Vector : public StatBase {
 public:
  using Storage = typename T::Storage;
  explicit Vector(Group *parent, const std::string &name,
                  const std::string &desc, const std::string &unit,
                  uint64_t size)
      : StatBase(parent, name, desc, unit),
        // The reductions only store a reference here, so initializing them
        // before storage_ does not touch unconstructed data.
        sum_(*this),
        average_(*this),
        storage_(size, Storage()) {}

  // Unchecked element access (no bounds check, as in std::vector).
  Storage &operator[](uint64_t index) { return storage_[index]; }
  const Storage &operator[](uint64_t index) const { return storage_[index]; }

  std::vector<Storage> &data() { return storage_; }
  const std::vector<Storage> &data() const { return storage_; }

  // Formula operands that read the current elements each time they are
  // converted.
  ScalarBase &sum() { return sum_; }
  ScalarBase &average() { return average_; }

  /// Renders {val: [elements...], unit, desc}.
  toml::table Tabularize() const override {
    toml::array arr;
    for (auto e : storage_) {
      arr.push_back(e);  // cppcheck-suppress useStlAlgorithm
    }
    return toml::table{
        {"val", arr},
        {"unit", unit()},
        {"desc", desc()},
    };
  }

 protected:
  // Common base of the reductions: an unparented ScalarBase named
  // "<vector name>.<type>" that reads the owning vector's elements.
  class Reduction : public ScalarBase {
   public:
    Reduction(const Vector &vector, std::string_view type)
        : ScalarBase(nullptr, fmt::format("{}.{}", vector.name(), type), "",
                     ""),
          vector_(vector) {}

   protected:
    const Vector &vector_;
  };

  // Sum of all elements; std::reduce yields Storage{} (0) for an empty vector.
  class Sum : public Reduction {
   public:
    explicit Sum(const Vector &parent) : Reduction(parent, "sum") {}

    explicit operator int64_t() const override {
      const auto &vec_data = Reduction::vector_.data();
      return static_cast<int64_t>(
          std::reduce(vec_data.begin(), vec_data.end()));
    }
    explicit operator double() const override {
      const auto &vec_data = Reduction::vector_.data();
      return static_cast<double>(std::reduce(vec_data.begin(), vec_data.end()));
    }
  } sum_;

  // Arithmetic mean of the elements.
  class Average : public Reduction {
   public:
    explicit Average(const Vector &parent) : Reduction(parent, "average") {}

    // Integer mean, truncated toward zero; 0 for an empty vector instead of
    // an (undefined) integer division by zero. The element count is made
    // signed first: dividing int64_t by size_t would convert the sum to
    // unsigned and turn negative averages into huge positive values.
    explicit operator int64_t() const override {
      const auto &vec_data = Reduction::vector_.data();
      if (vec_data.empty()) {
        return 0;
      }
      const auto total =
          static_cast<int64_t>(std::reduce(vec_data.begin(), vec_data.end()));
      return total / static_cast<int64_t>(vec_data.size());
    }
    // Floating-point mean; an empty vector gives 0.0 / 0 (NaN under IEEE-754).
    explicit operator double() const override {
      const auto &vec_data = Reduction::vector_.data();
      return static_cast<double>(
                 std::reduce(vec_data.begin(), vec_data.end())) /
             vec_data.size();
    }
  } average_;

  std::vector<Storage> storage_;
};

// Appends one operand token. Non-constant stats are referenced by address and
// read at evaluation time. Constants are folded into their double value, so
// no pointer to a (typically temporary) constant operand is retained.
inline void PostfixExpr::PushScalar(const ScalarBase &s) {
  if (!s.is_constant()) {
    expr_.emplace_back(&s);
    return;
  }
  expr_.emplace_back(static_cast<double>(s));
}

template <class Storage>
inline std::tuple<Storage, std::string> PostfixExpr::Evaluate() const {
  auto expr = expr_;  // Copy
  std::stack<Storage> data_stack({Storage()});
  std::stack<std::string> expr_stack({""});
  while (!expr.empty()) {
    if (expr.front().index() == 0) {  // is operator
      Storage opd2 = data_stack.top();
      data_stack.pop();
      Storage opd1 = data_stack.top();
      data_stack.pop();

      std::string expr2 = std::move(expr_stack.top());
      expr_stack.pop();
      std::string expr1 = std::move(expr_stack.top());
      expr_stack.pop();

      Operation operation = std::get<Operation>(expr.front());
      switch (operation) {
        default:
        case Operation::kAddition:
          data_stack.push(opd1 + opd2);
          break;
        case Operation::kSubtraction:
          data_stack.push(opd1 - opd2);
          break;
        case Operation::kMultiplication:
          data_stack.push(opd1 * opd2);
          break;
        case Operation::kDivision:
          data_stack.push(opd1 / opd2);
          break;
      }
      expr_stack.push(fmt::format("({} {} {})", expr1, operation, expr2));
    } else if (expr.front().index() == 1) {
      const ScalarBase *scalar = std::get<const ScalarBase *>(expr.front());
      data_stack.push(static_cast<Storage>(*scalar));
      expr_stack.push(scalar->name());
    } else {
      double data = std::get<double>(expr.front());
      data_stack.push(static_cast<Storage>(data));
      expr_stack.push(fmt::to_string(data));
    }
    expr.pop_front();
  }
  return std::make_tuple(data_stack.top(), expr_stack.top());
}

}  // namespace stat

}  // namespace simtix

namespace fmt {

// fmt support for simtix::stat::Operation. Formats the operator as its
// arithmetic symbol ("+", "-", "*", "/"), or "Unknown" for any other value.
// Parsing and output are delegated to formatter<string_view>, so string format
// specs (width, fill, alignment) apply to the symbol.
template <>
struct formatter<simtix::stat::Operation> : formatter<string_view> {
  template <typename FormatContext>
  auto format(simtix::stat::Operation op,
              FormatContext &ctx) const {  // NOLINT
    string_view name = "Unknown";
    switch (op) {
      case simtix::stat::Operation::kAddition:
        name = "+";
        break;
      case simtix::stat::Operation::kSubtraction:
        name = "-";
        break;
      case simtix::stat::Operation::kMultiplication:
        name = "*";
        break;
      case simtix::stat::Operation::kDivision:
        name = "/";
        break;
    }
    return formatter<string_view>::format(name, ctx);
  }
};

}  // namespace fmt