Skip to content

File instr_op_fused.h

File List > projects > simtix > src > simtix > sm > instr_op_fused.h

Go to the documentation of this file

#pragma once

#include <simtix/riscv.h>

#include <algorithm>
#include <vector>

#include "common/softfloat_util.h"
#include "sm/instr.h"
#include "sm/warp.h"

extern "C" {
#include "softfloat.h"
}

namespace simtix {
class InstrOpFused : public Instr {
 public:
  using Op = void (InstrOpFused::*)();

  void Decode() override;
  void Issue() override;
  void OperandCollect() override;
  void Execute() override;
  void Commit() override;
  void Reset() override;

  void Assign(const Instr *) override;
  bool CanIssue() const override;
  bool CanExecute() const override {
    return rs1_data_ready_ && rs2_data_ready_ && rs3_data_ready_;
  };
  bool CanCommit() const override { return executed_; }
  bool CanRetire() const override { return committed_; }

 protected:
  InstrOpFused(Warp *warp, uint32_t iword, uint64_t wpc);

  void Reinitialize(Warp *warp, uint32_t iword, uint64_t wpc) override;

  std::vector<uint64_t> rs1_data_;
  std::vector<uint64_t> rs2_data_;
  std::vector<uint64_t> rs3_data_;
  std::vector<uint64_t> rd_data_;
  std::vector<uint8_t> fp_exception_;
  std::vector<uint32_t> fp_fmt_;

  // Used to determine where the operands come from/store to.
  uint8_t rs3_;
  uint32_t rm_;
  uint32_t fmt_;

  Op op_;

  bool rs1_data_ready_ = false;
  bool rs2_data_ready_ = false;
  bool rs3_data_ready_ = false;
  bool executed_ = false;
  bool committed_ = false;

 private:
  inline bool has_fatal_exception() {
    return std::any_of(fp_exception_.begin(), fp_exception_.end(),
                       [](uint8_t flag) {
                         return flag & (riscv::DZ | riscv::NV);
                       });
  }
  void nop_() {}
  // FMADD
  void fmadd_s_() {
    for (Thread *t : active_threads_) {
      uint32_t tid = t->tid();
      softfloat_init(rm_);
      auto res =
          f32_mulAdd(to_float32(rs1_data_[tid]), to_float32(rs2_data_[tid]),
                     to_float32(rs3_data_[tid]));
      rd_data_[tid] = from_float32(res);
      fp_exception_[tid] = softfloat_exceptionFlags;
    }
  }
  void fmadd_d_() {
    for (Thread *t : active_threads_) {
      uint32_t tid = t->tid();
      softfloat_init(rm_);
      auto res =
          f64_mulAdd(to_float64(rs1_data_[tid]), to_float64(rs2_data_[tid]),
                     to_float64(rs3_data_[tid]));
      rd_data_[tid] = from_float64(res);
      fp_exception_[tid] = softfloat_exceptionFlags;
    }
  }
  // FMSUB
  void fmsub_s_() {
    for (Thread *t : active_threads_) {
      uint32_t tid = t->tid();
      softfloat_init(rm_);
      auto neg_rs3 = rs3_data_[tid] ^ F32_SIGN;
      auto res = f32_mulAdd(to_float32(rs1_data_[tid]),
                            to_float32(rs2_data_[tid]), to_float32(neg_rs3));
      rd_data_[tid] = from_float32(res);
      fp_exception_[tid] = softfloat_exceptionFlags;
    }
  }
  void fmsub_d_() {
    for (Thread *t : active_threads_) {
      uint32_t tid = t->tid();
      softfloat_init(rm_);
      auto neg_rs3 = rs3_data_[tid] ^ F64_SIGN;
      auto res = f64_mulAdd(to_float64(rs1_data_[tid]),
                            to_float64(rs2_data_[tid]), to_float64(neg_rs3));
      rd_data_[tid] = from_float64(res);
      fp_exception_[tid] = softfloat_exceptionFlags;
    }
  }
  // FNMADD
  void fnmadd_s_() {
    for (Thread *t : active_threads_) {
      uint32_t tid = t->tid();
      softfloat_init(rm_);
      auto neg_rs1 = rs1_data_[tid] ^ F32_SIGN;
      auto neg_rs3 = rs3_data_[tid] ^ F32_SIGN;
      auto res = f32_mulAdd(to_float32(neg_rs1), to_float32(rs2_data_[tid]),
                            to_float32(neg_rs3));
      rd_data_[tid] = from_float32(res);
      fp_exception_[tid] = softfloat_exceptionFlags;
    }
  }
  void fnmadd_d_() {
    for (Thread *t : active_threads_) {
      uint32_t tid = t->tid();
      softfloat_init(rm_);
      auto neg_rs1 = rs1_data_[tid] ^ F64_SIGN;
      auto neg_rs3 = rs3_data_[tid] ^ F64_SIGN;
      auto res = f64_mulAdd(to_float64(neg_rs1), to_float64(rs2_data_[tid]),
                            to_float64(neg_rs3));
      rd_data_[tid] = from_float64(res);
      fp_exception_[tid] = softfloat_exceptionFlags;
    }
  }
  // FNMSUB
  void fnmsub_s_() {
    for (Thread *t : active_threads_) {
      uint32_t tid = t->tid();
      softfloat_init(rm_);
      auto neg_rs1 = rs1_data_[tid] ^ F32_SIGN;
      auto res = f32_mulAdd(to_float32(neg_rs1), to_float32(rs2_data_[tid]),
                            to_float32(rs3_data_[tid]));
      rd_data_[tid] = from_float32(res);
      fp_exception_[tid] = softfloat_exceptionFlags;
    }
  }
  void fnmsub_d_() {
    for (Thread *t : active_threads_) {
      uint32_t tid = t->tid();
      softfloat_init(rm_);
      auto neg_rs1 = rs1_data_[tid] ^ F64_SIGN;
      auto res = f64_mulAdd(to_float64(neg_rs1), to_float64(rs2_data_[tid]),
                            to_float64(rs3_data_[tid]));
      rd_data_[tid] = from_float64(res);
      fp_exception_[tid] = softfloat_exceptionFlags;
    }
  }
  // util
  const Op sgnj_op[2][4] = {
      {&InstrOpFused::fmadd_s_, &InstrOpFused::fmsub_s_,
       &InstrOpFused::fnmadd_s_, &InstrOpFused::fnmsub_s_},
      {&InstrOpFused::fmadd_d_, &InstrOpFused::fmsub_d_,
       &InstrOpFused::fnmadd_d_, &InstrOpFused::fnmsub_d_}};
  const char *sgnj_name[2][4] = {
      {"fmadd.s", "fmsub.s", "fnmadd.s", "fnmsub.s"},
      {"fmadd.d", "fmsub.d", "fnmadd.d", "fnmsub.d"}};

  template <class InstrType>
  friend class InstrPool;
};
}  // namespace simtix