Skip to content

File instr_buffer.h

File List > pipelined > instr_buffer.h

Go to the documentation of this file

#pragma once

#include <simtix/param.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "sim/queue.h"
#include "sm/instr_ptr.h"
#include "sm/pipelined/pipelined.h"

namespace simtix {

namespace pipelined {

class InstrBuffer {
 public:
  struct InstrStream {
    explicit InstrStream(std::size_t capacity)
        : is_diverged(false),
          fifos{sim::SizedQueue<InstrPtr>(capacity / 2),
                sim::SizedQueue<InstrPtr>(capacity / 2)} {}

    bool is_diverged = false;           // Indicates if the stream has diverged.
    std::array<bool, 2> valid{};        // Valid flags.
    std::array<uint64_t, 2> next_pc{};  // Next PC.
    std::array<uint64_t, 2> fetch_pc{};  // Fetch PC.

    std::array<sim::SizedQueue<InstrPtr>, 2> fifos;
  };

  explicit InstrBuffer(
      PipelinedSMImpl *sm, const ArchParam &p = kDefaultArchParam,
      const PipelinedSM::Param &pp = PipelinedSM::kDefaultParam);
  virtual ~InstrBuffer() = default;

  // When doing PipelinedSM::Process, use this method to initialize the
  // instruction buffer for each valid warp.

  void Enable(uint32_t wid, uint64_t next_pc);

  void Disable(uint32_t wid);

  void Reset();

  bool CanEnq(const InstrPtr &instr) const;

  void Enq(InstrPtr instr);

  InstrPtr Deq(uint32_t wid, std::optional<uint32_t> ssw);

  void NotifyCtrlFlowChange(uint32_t wid, uint64_t wpc,
                            std::optional<uint64_t> sswpc);

  void set_fetch_pc(uint32_t wid, std::optional<uint32_t> ssw,
                    uint64_t fetch_pc) const {
    uint32_t sswid = ssw.value_or(0);
    buf_[wid]->fetch_pc[sswid] = fetch_pc;
  }

  template <class TFunc>
  void ForEachValidInstrStream(TFunc cb) {
    for (Warp *warp : sm_->valid_leader_warps_) {
      uint32_t wid = warp->wid();
      uint32_t begin = static_cast<uint32_t>(is_diverged(wid));
      uint32_t end = static_cast<uint32_t>(is_diverged(wid)) * 2;
      for (uint32_t i = begin; i <= end; ++i) {
        auto ssw = kInstrStreamVariants[i];
        if (valid(wid, ssw)) cb(wid, ssw);
      }
    }
  }

  const InstrPtr &front(uint32_t wid, std::optional<uint32_t> ssw) const {
    assert((ssw && is_diverged(wid)) || (!ssw && !is_diverged(wid)));
    uint32_t sswid = ssw.value_or(0);
    return buf_[wid]->fifos[sswid].front();
  }

  bool is_diverged(uint32_t wid) const { return buf_[wid]->is_diverged; }

  std::size_t capacity(uint32_t wid) const {
    if (is_diverged(wid)) {
      return capacity_ / 2;
    }
    return capacity_;
  }

  bool valid(uint32_t wid, std::optional<uint32_t> ssw) const {
    if (ssw) {
      return is_diverged(wid) && buf_[wid]->valid[*ssw];
    } else {
      return !is_diverged(wid) && buf_[wid]->valid[0];
    }
  }

  bool full(uint32_t wid, std::optional<uint32_t> ssw) const {
    assert((ssw && is_diverged(wid)) || (!ssw && !is_diverged(wid)));
    if (ssw) {
      return buf_[wid]->fifos[*ssw].full();
    } else {
      return buf_[wid]->fifos[0].full() && buf_[wid]->fifos[1].full();
    }
  }

  bool empty(uint32_t wid, std::optional<uint32_t> ssw) const {
    assert((ssw && is_diverged(wid)) || (!ssw && !is_diverged(wid)));
    if (ssw) {
      return buf_[wid]->fifos[*ssw].empty();
    } else {
      return buf_[wid]->fifos[0].empty() && buf_[wid]->fifos[1].empty();
    }
  }

  virtual bool capacious(uint32_t wid, std::optional<uint32_t> ssw) const {
    uint32_t upcoming = (fetch_pc(wid, ssw) - next_pc(wid, ssw)) / 4;
    return upcoming + size(wid, ssw) + kFetchWidth <= capacity(wid);
  }

  std::size_t size(uint32_t wid, std::optional<uint32_t> ssw) const {
    assert((ssw && is_diverged(wid)) || (!ssw && !is_diverged(wid)));
    if (ssw) {
      return buf_[wid]->fifos[*ssw].size();
    } else {
      return buf_[wid]->fifos[0].size() + buf_[wid]->fifos[1].size();
    }
  }

  uint64_t next_pc(uint32_t wid, std::optional<uint32_t> ssw) const {
    assert((ssw && is_diverged(wid)) || (!ssw && !is_diverged(wid)));
    uint32_t sswid = ssw.value_or(0);
    return buf_[wid]->next_pc[sswid];
  }

  virtual uint64_t fetch_pc(uint32_t wid, std::optional<uint32_t> ssw) const {
    assert((ssw && is_diverged(wid)) || (!ssw && !is_diverged(wid)));
    uint32_t sswid = ssw.value_or(0);
    return buf_[wid]->fetch_pc[sswid];
  }

 protected:
  void RedirectInstrStream(uint32_t wid, uint64_t fetch_pc,
                           std::optional<uint32_t> ssw);

  void FlushFIFO(sim::SizedQueue<InstrPtr> *fifo) {
    for (size_t i = 0; i < fifo->size(); ++i) {
      InstrPtr instr = std::move((*fifo)[i]);
      instr.set_flushed();
    }
    fifo->clear();
  }

  void UpdateFetchPC(InstrStream *is, std::optional<uint32_t> ssw) {
    uint32_t sswid = ssw.value_or(0);
    is->fetch_pc[sswid] = std::max(is->next_pc[sswid], is->fetch_pc[sswid]);
  }

  bool CanEnqInstrToStream(uint32_t wid, std::optional<uint32_t> ssw,
                           uint64_t wpc) const {
    return valid(wid, ssw) && !full(wid, ssw) && next_pc(wid, ssw) == wpc;
  }

  void EnqInstrToStream(uint32_t wid, std::optional<uint32_t> ssw,
                        InstrPtr instr);

  inline static constexpr std::array<std::optional<uint32_t>, 3>
      kInstrStreamVariants{std::nullopt, 0, 1};  // no-diverge, ssw0, ssw1

  const uint32_t kFetchWidth;

  std::size_t capacity_;
  std::vector<std::unique_ptr<InstrStream>> buf_;
  PipelinedSMImpl *sm_;
};

}  // namespace pipelined

}  // namespace simtix