Skip to content

File warp.h

File List > projects > simtix > src > simtix > sm > warp.h

Go to the documentation of this file

#pragma once

#include <simtix/mem.h>
#include <simtix/sm.h>

#include <array>
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "sm/arbitrator/base.h"
#include "sm/thread.h"

// NOTE(review): this using-declaration sits at global scope in a header, so
// every file that includes it also sees `MemoryInterface` as an unqualified
// name. No declaration visible in this header refers to it -- confirm whether
// other code depends on it before moving or removing it.
using simtix::mem::MemoryInterface;

namespace simtix {

// Forward declarations: this header only uses these types through pointers
// (BaseSMImpl, BaseLoadStoreUnit) or references (Instr), so their full
// definitions are not needed here.
class BaseSMImpl;
class BaseLoadStoreUnit;
class Instr;

class Warp {
 public:
  // Warp Status
  enum class Status { kInvalid = 0, kReady, kRunning, kBarrier, kException };
  Warp(BaseSMImpl *sm, uint32_t wid, const ArchParam &p = kDefaultArchParam);
  // deconstructor
  ~Warp() = default;

  // wid getter
  uint32_t wid() const { return wid_; }
  // status getter
  Status status() const { return status_; }
  // Ready Thread Vector getter
  const std::vector<std::vector<std::vector<bool>>> &rtv() const {
    return rtv_;
  }
  // wpc getter
  uint64_t wpc() const { return wpc_; }
  uint64_t sswpc() const { return sswpc_; }
  // num_threads getter
  uint32_t num_threads() const { return num_threads_; }
  uint32_t num_valid_threads() const { return num_valid_threads_; }

  const std::vector<Thread *> &active_threads(uint32_t sswid) {
    return active_threads_[sswid];
  }
  uint8_t bid() const { return bid_; }
  uint8_t wc() const { return wc_; }

  const std::vector<bool> &active_thread_mask(uint32_t sswid) {
    return active_thread_mask_[sswid];
  }

  void swap_active_thread_mask() {
    active_thread_mask_[0].swap(active_thread_mask_[1]);
  }

  void set_active_thread_mask(uint32_t sswid, uint32_t tid, bool is_active) {
    active_thread_mask_[sswid][tid] = is_active;
  }

  bool ScoreboardClean() const;

  const uint8_t &issue_tswid(uint32_t sswid) { return issue_tswid_[sswid]; }

  bool is_last_tsw(uint32_t sswid) { return issue_tswid_[sswid] == 0; }

  uint32_t num_tsws(uint32_t sswid) { return active_tswid_[sswid]; }

  const std::vector<Thread *> &tsw_active_threads(uint32_t sswid,
                                                  uint8_t *tswid);

  // Spatial Sub-Warp
  bool is_diverged() const { return is_diverged_; }

  // thread getter
  Thread *thread(uint32_t tid) const { return threads_[tid].get(); }

  BaseLoadStoreUnit *lsu();
  BaseArbitrator *arbitrator();

  // Getter for minstret, needed by Instr
  uint64_t minstret() const { return minstret_; }

  // status setter
  void set_status(Status status) { status_ = status; }

  void set_num_active_tsws(uint32_t num_active_tsws) {
    num_active_tsws_ = num_active_tsws;
  }

  // Setter for exception related CSRs
  void set_mcause(uint64_t mcause) { mcause_ = mcause; }
  void set_mepc(uint64_t mepc) { mepc_ = mepc; }
  void set_mtval(uint64_t mtval) { mtval_ = mtval; }

  // Setter for barrier's attributes
  void set_bid(uint8_t bid) { bid_ = bid; }
  void set_wc(uint8_t wc) { wc_ = wc; }

  void ComputeRTV();

  void ArbitratePC();

  bool CompactActiveThreads(uint32_t sswid);

  void Initialize(uint32_t num_valid_threads, uint64_t pc,
                  const std::vector<std::vector<uint32_t>> &local_id);

  int32_t ReadCSR(uint16_t addr, uint64_t *data);

  void NotifyIssue(std::optional<uint32_t> ssw);

  void NotifyCommit(const Instr &instr);

  void Reset();

  void ResetScoreboard();

  void ReserveRegister(uint8_t index, uint8_t tswid, uint32_t sswid);

  void ReleaseRegister(uint8_t index, uint8_t tswid, uint32_t sswid);

  void ReservePC();

  void ReleasePC();

  bool RegisterAvailable(uint8_t index, uint32_t sswid);

  bool HasUnresolvedBranch();

  bool HasPendingCommit();

  bool HasUnissuedTSW(uint32_t sswid);

  std::string ActiveThreadMaskPattern(uint32_t sswid);

 protected:
  void ResetActiveThreadMask() {
    active_thread_mask_[0].assign(num_threads_, 0);
    active_thread_mask_[1].assign(num_threads_, 0);
  }
  void UpdateCommittedInstrStat(const Instr &instr);

  // Barrier
  uint8_t bid_ = 0;
  uint8_t wc_ = 0;

 private:
  BaseSMImpl *const sm_;

  uint32_t wid_;
  Status status_;
  const uint32_t num_threads_;
  const uint32_t num_warps_per_warpgroup_;
  const uint32_t max_priority_level_;
  const uint32_t max_priority_f_level_;

  std::vector<std::unique_ptr<Thread>> threads_;
  std::vector<std::vector<std::vector<bool>>> rtv_;
  uint64_t wpc_;

  uint32_t num_valid_threads_ = 0;

  std::array<std::vector<bool>, 2> active_thread_mask_;
  std::array<std::vector<Thread *>, 2> active_threads_;

  // Temporal Sub-Warp
  std::array<uint8_t, 2> active_tswid_ = {0, 0};  // the last active tswid
  std::array<uint8_t, 2> issue_tswid_ = {0, 0};

  std::array<std::vector<bool>, 2> tsw_commit_mask_;

  uint32_t num_active_tsws_ = 1;

  uint64_t num_pending_commits_ = 0;

  // Spatial Sub-Warp
  uint64_t sswpc_;

  bool is_diverged_ = false;

  // Per-warp CSRs
  uint64_t mepc_ = 0;
  uint64_t mcause_ = 0;
  uint64_t mtval_ = 0;
  uint64_t minstret_ = 0;

  // scoreboard
  std::array<std::array<bool, 32>, 2> reg_busy_;
  bool pc_busy_;
};

}  // namespace simtix