//===-- SPIRVAsmPrinter.cpp - SPIR-V LLVM assembly writer ------*- C++ -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a printer that converts from our internal representation
// of machine-dependent LLVM code to the SPIR-V assembly language.
//
//===----------------------------------------------------------------------===//

#include "MCTargetDesc/SPIRVInstPrinter.h"
#include "SPIRV.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVMCInstLower.h"
#include "SPIRVModuleAnalysis.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "TargetInfo/SPIRVTargetInfo.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineConstantPool.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCAssembler.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCObjectStreamer.h"
#include "llvm/MC/MCSPIRVObjectWriter.h"
#include "llvm/MC/MCStreamer.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

#define DEBUG_TYPE "asm-printer"

namespace {
/// AsmPrinter for the SPIR-V target.
///
/// Module-level sections are written once (see outputModuleSections()), either
/// from the first emitFunctionHeader() call or, failing that, from
/// emitEndOfAsmFile(). Function bodies are then printed instruction by
/// instruction, with OpLabel and OpFunctionEnd inserted by this class.
class SPIRVAsmPrinter : public AsmPrinter {
  // Number of OpLabel instructions emitted by emitOpLabel(); it feeds the
  // Bound estimate computed in emitEndOfAsmFile().
  unsigned NLabels = 0;
  // Basic blocks for which emitOpLabel() has already emitted an OpLabel.
  SmallPtrSet<const MachineBasicBlock *, 8> LabeledMBB;

public:
  explicit SPIRVAsmPrinter(TargetMachine &TM,
                           std::unique_ptr<MCStreamer> Streamer)
      : AsmPrinter(TM, std::move(Streamer), ID), ModuleSectionsEmitted(false),
        ST(nullptr), TII(nullptr), MAI(nullptr) {}
  // Pass identification.
  static char ID;
  // Set once outputModuleSections() has run; reset in doInitialization().
  bool ModuleSectionsEmitted;
  // Subtarget and instruction info currently in use. They are assigned in
  // outputModuleSections(), emitFunctionHeader() and emitEndOfAsmFile().
  const SPIRVSubtarget *ST;
  const SPIRVInstrInfo *TII;

  StringRef getPassName() const override { return "SPIRV Assembly Printer"; }
  // Prints a single machine operand in textual form.
  void printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O);
  // Inline-asm operand hook; returns true (error) for any operand modifier.
  bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
                       const char *ExtraCode, raw_ostream &O) override;

  // Helpers that lower and stream instructions or whole module sections.
  void outputMCInst(MCInst &Inst);
  void outputInstruction(const MachineInstr *MI);
  void outputModuleSection(SPIRV::ModuleSectionType MSType);
  void outputGlobalRequirements();
  void outputEntryPoints();
  void outputDebugSourceAndStrings(const Module &M);
  void outputOpExtInstImports(const Module &M);
  void outputOpMemoryModel();
  void outputOpFunctionEnd();
  void outputExtFuncDecls();
  void outputExecutionModeFromMDNode(MCRegister Reg, MDNode *Node,
                                     SPIRV::ExecutionMode::ExecutionMode EM,
                                     unsigned ExpectMDOps, int64_t DefVal);
  void outputExecutionModeFromNumthreadsAttribute(
      const MCRegister &Reg, const Attribute &Attr,
      SPIRV::ExecutionMode::ExecutionMode EM);
  void outputExecutionModeFromEnableMaximalReconvergenceAttr(
      const MCRegister &Reg, const SPIRVSubtarget &ST);
  void outputExecutionMode(const Module &M);
  void outputAnnotations(const Module &M);
  void outputModuleSections();
  void outputFPFastMathDefaultInfo();
  // True when the current function carries the SPIRV_BACKEND_SERVICE_FUN_NAME
  // attribute; emitOpLabel(), emitFunctionBodyEnd() and the verbose function
  // banner are suppressed for such functions.
  bool isHidden() {
    return MF->getFunction()
        .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
        .isValid();
  }

  // AsmPrinter hooks. The ones with empty bodies intentionally emit nothing.
  void emitInstruction(const MachineInstr *MI) override;
  void emitFunctionEntryLabel() override {}
  void emitFunctionHeader() override;
  void emitFunctionBodyStart() override {}
  void emitFunctionBodyEnd() override;
  void emitBasicBlockStart(const MachineBasicBlock &MBB) override;
  void emitBasicBlockEnd(const MachineBasicBlock &MBB) override {}
  void emitGlobalVariable(const GlobalVariable *GV) override {}
  void emitOpLabel(const MachineBasicBlock &MBB);
  void emitEndOfAsmFile(Module &M) override;
  bool doInitialization(Module &M) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override;
  // Module analysis results; points at SPIRVModuleAnalysis::MAI once
  // outputModuleSections() has run.
  SPIRV::ModuleAnalysisInfo *MAI;

protected:
  // Renames selected intrinsic globals after emission (see the definition).
  void cleanUp(Module &M);
};
} // namespace

// Declares the analyses this printer depends on: SPIRVModuleAnalysis is both
// required and preserved, on top of whatever the generic AsmPrinter requests.
void SPIRVAsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<SPIRVModuleAnalysis>();
  AU.addPreserved<SPIRVModuleAnalysis>();
  AsmPrinter::getAnalysisUsage(AU);
}

// Finishes the output. The module-level sections are normally written by the
// first emitFunctionHeader() call; a module without functions never gets
// there, so they are written here instead. Afterwards the SPIR-V version and
// the id bound are handed to the object writer (when one exists) and the
// module is cleaned up.
void SPIRVAsmPrinter::emitEndOfAsmFile(Module &M) {
  if (!ModuleSectionsEmitted) {
    outputModuleSections();
    ModuleSectionsEmitted = true;
  }

  ST = static_cast<const SPIRVTargetMachine &>(TM).getSubtargetImpl();
  const VersionTuple Version = ST->getSPIRVVersion();
  // The bound is an approximation derived from the maximum used register
  // number plus the number of OpLabels generated by this printer.
  const unsigned IdBound = 2 * (ST->getBound() + 1) + NLabels;
  if (MCAssembler *Assembler = OutStreamer->getAssemblerPtr()) {
    const uint32_t MajorVer = Version.getMajor();
    const uint32_t MinorVer = Version.getMinor().value_or(0);
    auto &Writer = static_cast<SPIRVObjectWriter &>(Assembler->getWriter());
    Writer.setBuildVersion(MajorVer, MinorVer, IdBound);
  }

  cleanUp(M);
}

// Post-emission cleanup of the Module, performed once its content no longer
// matters to the printer.
void SPIRVAsmPrinter::cleanUp(Module &M) {
  // The verifier disallows uses of intrinsic global variables, so strip the
  // reserved names from the ones that are present.
  const StringRef ReservedNames[] = {"llvm.global_ctors", "llvm.global_dtors",
                                     "llvm.used"};
  for (StringRef Name : ReservedNames) {
    GlobalVariable *Var = M.getNamedGlobal(Name);
    if (!Var)
      continue;
    Var->setName("");
  }
}

void SPIRVAsmPrinter::emitFunctionHeader() {
  if (ModuleSectionsEmitted == false) {
    outputModuleSections();
    ModuleSectionsEmitted = true;
  }
  // Get the subtarget from the current MachineFunction.
  ST = &MF->getSubtarget<SPIRVSubtarget>();
  TII = ST->getInstrInfo();
  const Function &F = MF->getFunction();

  if (isVerbose() && !isHidden()) {
    OutStreamer->getCommentOS()
        << "-- Begin function "
        << GlobalValue::dropLLVMManglingEscape(F.getName()) << '\n';
  }

  auto Section = getObjFileLowering().SectionForGlobal(&F, TM);
  MF->setSection(Section);
}

void SPIRVAsmPrinter::outputOpFunctionEnd() {
  MCInst FunctionEndInst;
  FunctionEndInst.setOpcode(SPIRV::OpFunctionEnd);
  outputMCInst(FunctionEndInst);
}

// Closes the function body with OpFunctionEnd, except for hidden (internal
// service) functions, which produce no output.
void SPIRVAsmPrinter::emitFunctionBodyEnd() {
  if (isHidden())
    return;
  outputOpFunctionEnd();
}

void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
  // Do not emit anything if it's an internal service function.
  if (isHidden())
    return;

  MCInst LabelInst;
  LabelInst.setOpcode(SPIRV::OpLabel);
  LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB)));
  outputMCInst(LabelInst);
  ++NLabels;
  LabeledMBB.insert(&MBB);
}

// Emits the OpLabel that opens MBB. The front block of a function is special:
// its label must follow OpFunction/OpFunctionParameter, so it is emitted
// later from emitInstruction() rather than here.
void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
  // Empty blocks get no label.
  if (MBB.empty())
    return;

  const bool IsFrontBlock = MBB.getNumber() == MF->front().getNumber();
  if (!IsFrontBlock) {
    emitOpLabel(MBB);
    return;
  }

  // The front block must contain OpFunction; when it does, labeling is
  // deferred until the function header instructions have been printed.
  for (const MachineInstr &Instr : MBB) {
    if (Instr.getOpcode() == SPIRV::OpFunction)
      return;
  }
  // TODO: this case should be checked by the verifier.
  report_fatal_error("OpFunction is expected in the front MBB of MF");
}

void SPIRVAsmPrinter::printOperand(const MachineInstr *MI, int OpNum,
                                   raw_ostream &O) {
  const MachineOperand &MO = MI->getOperand(OpNum);

  switch (MO.getType()) {
  case MachineOperand::MO_Register:
    O << SPIRVInstPrinter::getRegisterName(MO.getReg());
    break;

  case MachineOperand::MO_Immediate:
    O << MO.getImm();
    break;

  case MachineOperand::MO_FPImmediate:
    O << MO.getFPImm();
    break;

  case MachineOperand::MO_MachineBasicBlock:
    O << *MO.getMBB()->getSymbol();
    break;

  case MachineOperand::MO_GlobalAddress:
    O << *getSymbol(MO.getGlobal());
    break;

  case MachineOperand::MO_BlockAddress: {
    MCSymbol *BA = GetBlockAddressSymbol(MO.getBlockAddress());
    O << BA->getName();
    break;
  }

  case MachineOperand::MO_ExternalSymbol:
    O << *GetExternalSymbolSymbol(MO.getSymbolName());
    break;

  case MachineOperand::MO_JumpTableIndex:
  case MachineOperand::MO_ConstantPoolIndex:
  default:
    llvm_unreachable("<unknown operand type>");
  }
}

// Inline-asm operand printer. Returns true to signal an error when an operand
// modifier is present, otherwise prints the operand and returns false.
bool SPIRVAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
                                      const char *ExtraCode, raw_ostream &O) {
  const bool HasModifier = ExtraCode && ExtraCode[0] != '\0';
  if (!HasModifier) {
    printOperand(MI, OpNo, O);
    return false;
  }
  // Invalid instruction - SPIR-V does not have special modifiers
  return true;
}

// Returns true for instructions TII classifies as header instructions and for
// OpFunction / OpFunctionParameter.
static bool isFuncOrHeaderInstr(const MachineInstr *MI,
                                const SPIRVInstrInfo *TII) {
  if (TII->isHeaderInstr(*MI))
    return true;
  switch (MI->getOpcode()) {
  case SPIRV::OpFunction:
  case SPIRV::OpFunctionParameter:
    return true;
  default:
    return false;
  }
}

// Streams Inst using the subtarget info held by the MC context.
void SPIRVAsmPrinter::outputMCInst(MCInst &Inst) {
  const MCSubtargetInfo &STI = *OutContext.getSubtargetInfo();
  OutStreamer->emitInstruction(Inst, STI);
}

void SPIRVAsmPrinter::outputInstruction(const MachineInstr *MI) {
  SPIRVMCInstLower MCInstLowering;
  MCInst TmpInst;
  MCInstLowering.lower(MI, TmpInst, MAI);
  outputMCInst(TmpInst);
}

void SPIRVAsmPrinter::emitInstruction(const MachineInstr *MI) {
  SPIRV_MC::verifyInstructionPredicates(MI->getOpcode(),
                                        getSubtargetInfo().getFeatureBits());

  if (!MAI->getSkipEmission(MI))
    outputInstruction(MI);

  // Output OpLabel after OpFunction and OpFunctionParameter in the first MBB.
  const MachineInstr *NextMI = MI->getNextNode();
  if (!LabeledMBB.contains(MI->getParent()) && isFuncOrHeaderInstr(MI, TII) &&
      (!NextMI || !isFuncOrHeaderInstr(NextMI, TII))) {
    assert(MI->getParent()->getNumber() == MF->front().getNumber() &&
           "OpFunction is not in the front MBB of MF");
    emitOpLabel(*MI->getParent());
  }
}

// Prints, in order, every instruction the module analysis collected for the
// given module section.
void SPIRVAsmPrinter::outputModuleSection(SPIRV::ModuleSectionType MSType) {
  const auto &SectionInstrs = MAI->getMSInstrs(MSType);
  for (const MachineInstr *Instr : SectionInstrs) {
    outputInstruction(Instr);
  }
}

void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) {
  // Output OpSourceExtensions.
  for (auto &Str : MAI->SrcExt) {
    MCInst Inst;
    Inst.setOpcode(SPIRV::OpSourceExtension);
    addStringImm(Str.first(), Inst);
    outputMCInst(Inst);
  }
  // Output OpString.
  outputModuleSection(SPIRV::MB_DebugStrings);
  // Output OpSource.
  MCInst Inst;
  Inst.setOpcode(SPIRV::OpSource);
  Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->SrcLang)));
  Inst.addOperand(
      MCOperand::createImm(static_cast<unsigned>(MAI->SrcLangVersion)));
  outputMCInst(Inst);
}

void SPIRVAsmPrinter::outputOpExtInstImports(const Module &M) {
  for (auto &CU : MAI->ExtInstSetMap) {
    unsigned Set = CU.first;
    MCRegister Reg = CU.second;
    MCInst Inst;
    Inst.setOpcode(SPIRV::OpExtInstImport);
    Inst.addOperand(MCOperand::createReg(Reg));
    addStringImm(getExtInstSetName(
                     static_cast<SPIRV::InstructionSet::InstructionSet>(Set)),
                 Inst);
    outputMCInst(Inst);
  }
}

void SPIRVAsmPrinter::outputOpMemoryModel() {
  MCInst Inst;
  Inst.setOpcode(SPIRV::OpMemoryModel);
  Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->Addr)));
  Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->Mem)));
  outputMCInst(Inst);
}

// Before the OpEntryPoints' output, we need to add the entry point's
// interfaces. The interface is a list of IDs of global OpVariable instructions.
// These declare the set of global variables from a module that form
// the interface of this entry point.
void SPIRVAsmPrinter::outputEntryPoints() {
  // Find all OpVariable IDs with required StorageClass.
  DenseSet<MCRegister> InterfaceIDs;
  for (const MachineInstr *MI : MAI->GlobalVarList) {
    assert(MI->getOpcode() == SPIRV::OpVariable);
    auto SC = static_cast<SPIRV::StorageClass::StorageClass>(
        MI->getOperand(2).getImm());
    // Before version 1.4, the interface's storage classes are limited to
    // the Input and Output storage classes. Starting with version 1.4,
    // the interface's storage classes are all storage classes used in
    // declaring all global variables referenced by the entry point call tree.
    if (ST->isAtLeastSPIRVVer(VersionTuple(1, 4)) ||
        SC == SPIRV::StorageClass::Input || SC == SPIRV::StorageClass::Output) {
      const MachineFunction *MF = MI->getMF();
      MCRegister Reg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
      InterfaceIDs.insert(Reg);
    }
  }

  // Output OpEntryPoints adding interface args to all of them.
  for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_EntryPoints)) {
    SPIRVMCInstLower MCInstLowering;
    MCInst TmpInst;
    MCInstLowering.lower(MI, TmpInst, MAI);
    for (MCRegister Reg : InterfaceIDs) {
      assert(Reg.isValid());
      TmpInst.addOperand(MCOperand::createReg(Reg));
    }
    outputMCInst(TmpInst);
  }
}

// Create global OpCapability instructions for the required capabilities.
void SPIRVAsmPrinter::outputGlobalRequirements() {
  // Abort here if not all requirements can be satisfied.
  MAI->Reqs.checkSatisfiable(*ST);

  for (const auto &Cap : MAI->Reqs.getMinimalCapabilities()) {
    MCInst Inst;
    Inst.setOpcode(SPIRV::OpCapability);
    Inst.addOperand(MCOperand::createImm(Cap));
    outputMCInst(Inst);
  }

  // Generate the final OpExtensions with strings instead of enums.
  for (const auto &Ext : MAI->Reqs.getExtensions()) {
    MCInst Inst;
    Inst.setOpcode(SPIRV::OpExtension);
    addStringImm(getSymbolicOperandMnemonic(
                     SPIRV::OperandCategory::ExtensionOperand, Ext),
                 Inst);
    outputMCInst(Inst);
  }
  // TODO add a pseudo instr for version number.
}

void SPIRVAsmPrinter::outputExtFuncDecls() {
  // Insert OpFunctionEnd after each declaration.
  auto I = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).begin(),
       E = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).end();
  for (; I != E; ++I) {
    outputInstruction(*I);
    if ((I + 1) == E || (*(I + 1))->getOpcode() == SPIRV::OpFunction)
      outputOpFunctionEnd();
  }
}

// Encodes an LLVM type as the operand of the VecTypeHint execution mode.
// Scalars map to: i8 -> 0, i16 -> 1, i32 -> 2, i64 -> 3, half -> 4,
// float -> 5, double -> 6. A fixed vector is encoded as its element count
// shifted left by 16, combined with the code of its element type. Any other
// type is unreachable.
static unsigned encodeVecTypeHint(Type *Ty) {
  if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
    const unsigned NumElts = VecTy->getNumElements();
    const unsigned EltCode = encodeVecTypeHint(VecTy->getElementType());
    return (NumElts << 16) | EltCode;
  }

  if (auto *IntTy = dyn_cast<IntegerType>(Ty)) {
    switch (IntTy->getIntegerBitWidth()) {
    case 8:
      return 0;
    case 16:
      return 1;
    case 32:
      return 2;
    case 64:
      return 3;
    default:
      llvm_unreachable("invalid integer type");
    }
  }

  if (Ty->isHalfTy())
    return 4;
  if (Ty->isFloatTy())
    return 5;
  if (Ty->isDoubleTy())
    return 6;

  llvm_unreachable("invalid type");
}

// Appends the operands of MDN to Inst: integer constants become immediates and
// functions become the register the module analysis assigned to them. Operands
// of any other kind are silently skipped.
static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst,
                             SPIRV::ModuleAnalysisInfo *MAI) {
  for (const MDOperand &Operand : MDN->operands()) {
    auto *Wrapped = dyn_cast<ConstantAsMetadata>(Operand);
    if (!Wrapped)
      continue;

    Constant *Value = Wrapped->getValue();
    if (auto *IntConst = dyn_cast<ConstantInt>(Value)) {
      Inst.addOperand(MCOperand::createImm(IntConst->getZExtValue()));
      continue;
    }
    if (auto *Fn = dyn_cast<Function>(Value)) {
      MCRegister FnReg = MAI->getFuncReg(Fn);
      assert(FnReg.isValid());
      Inst.addOperand(MCOperand::createReg(FnReg));
    }
  }
}

// Emits "OpExecutionMode Reg EM <operands of Node>". When Node has fewer than
// ExpectMDOps operands, the missing ones are filled with DefVal; passing
// ExpectMDOps == 0 disables the padding.
void SPIRVAsmPrinter::outputExecutionModeFromMDNode(
    MCRegister Reg, MDNode *Node, SPIRV::ExecutionMode::ExecutionMode EM,
    unsigned ExpectMDOps, int64_t DefVal) {
  MCInst ModeInst;
  ModeInst.setOpcode(SPIRV::OpExecutionMode);
  ModeInst.addOperand(MCOperand::createReg(Reg));
  ModeInst.addOperand(MCOperand::createImm(static_cast<unsigned>(EM)));
  addOpsFromMDNode(Node, ModeInst, MAI);

  // reqd_work_group_size and work_group_size_hint require 3 operands; pad
  // with the default value when the metadata supplies fewer.
  for (unsigned Filled = Node->getNumOperands(); Filled < ExpectMDOps;
       ++Filled) {
    ModeInst.addOperand(MCOperand::createImm(DefVal));
  }

  outputMCInst(ModeInst);
}

void SPIRVAsmPrinter::outputExecutionModeFromNumthreadsAttribute(
    const MCRegister &Reg, const Attribute &Attr,
    SPIRV::ExecutionMode::ExecutionMode EM) {
  assert(Attr.isValid() && "Function called with an invalid attribute.");

  MCInst Inst;
  Inst.setOpcode(SPIRV::OpExecutionMode);
  Inst.addOperand(MCOperand::createReg(Reg));
  Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(EM)));

  SmallVector<StringRef> NumThreads;
  Attr.getValueAsString().split(NumThreads, ',');
  assert(NumThreads.size() == 3 && "invalid numthreads");
  for (uint32_t i = 0; i < 3; ++i) {
    uint32_t V;
    [[maybe_unused]] bool Result = NumThreads[i].getAsInteger(10, V);
    assert(!Result && "Failed to parse numthreads");
    Inst.addOperand(MCOperand::createImm(V));
  }

  outputMCInst(Inst);
}

void SPIRVAsmPrinter::outputExecutionModeFromEnableMaximalReconvergenceAttr(
    const MCRegister &Reg, const SPIRVSubtarget &ST) {
  assert(ST.canUseExtension(SPIRV::Extension::SPV_KHR_maximal_reconvergence) &&
         "Function called when SPV_KHR_maximal_reconvergence is not enabled.");

  MCInst Inst;
  Inst.setOpcode(SPIRV::OpExecutionMode);
  Inst.addOperand(MCOperand::createReg(Reg));
  unsigned EM =
      static_cast<unsigned>(SPIRV::ExecutionMode::MaximallyReconvergesKHR);
  Inst.addOperand(MCOperand::createImm(EM));
  outputMCInst(Inst);
}

// Emits all execution-mode declarations of the module, in two parts:
//  1. One OpExecutionMode per operand of the "spirv.ExecutionMode" named
//     metadata (if present), followed by outputFPFastMathDefaultInfo().
//  2. For every defined entry-point function, the execution modes derived
//     from its attributes and metadata ("hlsl.shader", "reqd_work_group_size",
//     "hlsl.numthreads", "enable-maximal-reconvergence",
//     "work_group_size_hint", "intel_reqd_sub_group_size",
//     "max_work_group_size", "vec_type_hint"), plus a default
//     contraction-off setting for kernels that have neither
//     "spirv.ExecutionMode" nor "opencl.enable.FP_CONTRACT" metadata.
void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
  NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
  if (Node) {
    for (unsigned i = 0; i < Node->getNumOperands(); i++) {
      // If SPV_KHR_float_controls2 is enabled and we find any of the
      // FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve
      // execution modes, skip it here; those are emitted by
      // outputFPFastMathDefaultInfo() below.
      if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
        // Operand 1 of each metadata entry holds the execution mode value.
        const auto EM =
            cast<ConstantInt>(
                cast<ConstantAsMetadata>((Node->getOperand(i))->getOperand(1))
                    ->getValue())
                ->getZExtValue();
        if (EM == SPIRV::ExecutionMode::FPFastMathDefault ||
            EM == SPIRV::ExecutionMode::ContractionOff ||
            EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve)
          continue;
      }

      // All operands of the instruction come straight from the metadata node.
      MCInst Inst;
      Inst.setOpcode(SPIRV::OpExecutionMode);
      addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI);
      outputMCInst(Inst);
    }
    outputFPFastMathDefaultInfo();
  }
  for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
    const Function &F = *FI;
    // Only operands of OpEntryPoint instructions are allowed to be
    // <Entry Point> operands of OpExecutionMode
    if (F.isDeclaration() || !isEntryPoint(F))
      continue;
    MCRegister FReg = MAI->getFuncReg(&F);
    assert(FReg.isValid());

    if (Attribute Attr = F.getFnAttribute("hlsl.shader"); Attr.isValid()) {
      // SPIR-V common validation: Fragment requires OriginUpperLeft or
      // OriginLowerLeft.
      // VUID-StandaloneSpirv-OriginLowerLeft-04653: Fragment must declare
      // OriginUpperLeft.
      if (Attr.getValueAsString() == "pixel") {
        MCInst Inst;
        Inst.setOpcode(SPIRV::OpExecutionMode);
        Inst.addOperand(MCOperand::createReg(FReg));
        unsigned EM =
            static_cast<unsigned>(SPIRV::ExecutionMode::OriginUpperLeft);
        Inst.addOperand(MCOperand::createImm(EM));
        outputMCInst(Inst);
      }
    }
    // LocalSize takes 3 operands; missing ones default to 1.
    if (MDNode *Node = F.getMetadata("reqd_work_group_size"))
      outputExecutionModeFromMDNode(FReg, Node, SPIRV::ExecutionMode::LocalSize,
                                    3, 1);
    if (Attribute Attr = F.getFnAttribute("hlsl.numthreads"); Attr.isValid())
      outputExecutionModeFromNumthreadsAttribute(
          FReg, Attr, SPIRV::ExecutionMode::LocalSize);
    if (Attribute Attr = F.getFnAttribute("enable-maximal-reconvergence");
        Attr.getValueAsBool()) {
      outputExecutionModeFromEnableMaximalReconvergenceAttr(FReg, *ST);
    }
    // LocalSizeHint takes 3 operands; missing ones default to 1.
    if (MDNode *Node = F.getMetadata("work_group_size_hint"))
      outputExecutionModeFromMDNode(FReg, Node,
                                    SPIRV::ExecutionMode::LocalSizeHint, 3, 1);
    // SubgroupSize operands are taken from the metadata as is (no padding).
    if (MDNode *Node = F.getMetadata("intel_reqd_sub_group_size"))
      outputExecutionModeFromMDNode(FReg, Node,
                                    SPIRV::ExecutionMode::SubgroupSize, 0, 0);
    // MaxWorkgroupSizeINTEL is only emitted when the
    // SPV_INTEL_kernel_attributes extension can be used.
    if (MDNode *Node = F.getMetadata("max_work_group_size")) {
      if (ST->canUseExtension(SPIRV::Extension::SPV_INTEL_kernel_attributes))
        outputExecutionModeFromMDNode(
            FReg, Node, SPIRV::ExecutionMode::MaxWorkgroupSizeINTEL, 3, 1);
    }
    if (MDNode *Node = F.getMetadata("vec_type_hint")) {
      MCInst Inst;
      Inst.setOpcode(SPIRV::OpExecutionMode);
      Inst.addOperand(MCOperand::createReg(FReg));
      unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::VecTypeHint);
      Inst.addOperand(MCOperand::createImm(EM));
      // The hinted type is operand 0 of the metadata node.
      unsigned TypeCode = encodeVecTypeHint(getMDOperandAsType(Node, 0));
      Inst.addOperand(MCOperand::createImm(TypeCode));
      outputMCInst(Inst);
    }
    if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") &&
        !M.getNamedMetadata("opencl.enable.FP_CONTRACT")) {
      if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
        // When SPV_KHR_float_controls2 is enabled, ContractionOff is
        // deprecated. We need to use FPFastMathDefault with the appropriate
        // flags instead. Since FPFastMathDefault takes a target type, we need
        // to emit it for each floating-point type that exists in the module
        // to match the effect of ContractionOff. As of now, there are 3 FP
        // types: fp16, fp32 and fp64.

        // We only end up here because there is no "spirv.ExecutionMode"
        // metadata, which means there is no FPFastMathDefault either.
        // Therefore, we only need to make sure AllowContract is set to 0,
        // like all the other flags. We still need to emit the OpExecutionMode
        // instruction, otherwise it's up to the client API to define the
        // flags. Therefore, we need to find the constant with 0 value.

        // Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of
        // type int32 with 0 value to represent the FP Fast Math Mode.
        std::vector<const MachineInstr *> SPIRVFloatTypes;
        const MachineInstr *ConstZeroInt32 = nullptr;
        for (const MachineInstr *MI :
             MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
          unsigned OpCode = MI->getOpcode();

          // Collect the SPIRV type if it's a float.
          if (OpCode == SPIRV::OpTypeFloat) {
            // Skip if the target type is not fp16, fp32, fp64.
            const unsigned OpTypeFloatSize = MI->getOperand(1).getImm();
            if (OpTypeFloatSize != 16 && OpTypeFloatSize != 32 &&
                OpTypeFloatSize != 64) {
              continue;
            }
            SPIRVFloatTypes.push_back(MI);
            continue;
          }

          if (OpCode == SPIRV::OpConstantNull) {
            // Check if the constant is int32, if not skip it.
            const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
            MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg());
            bool IsInt32Ty = TypeMI &&
                             TypeMI->getOpcode() == SPIRV::OpTypeInt &&
                             TypeMI->getOperand(1).getImm() == 32;
            if (IsInt32Ty)
              ConstZeroInt32 = MI;
          }
        }

        // Emit one "OpExecutionModeId FReg FPFastMathDefault <type> <zero>"
        // per collected floating-point type. The last operand is the id of
        // the int32 null constant found above, i.e. no fast-math flag set.
        for (const MachineInstr *MI : SPIRVFloatTypes) {
          MCInst Inst;
          Inst.setOpcode(SPIRV::OpExecutionModeId);
          Inst.addOperand(MCOperand::createReg(FReg));
          unsigned EM =
              static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault);
          Inst.addOperand(MCOperand::createImm(EM));
          const MachineFunction *MF = MI->getMF();
          MCRegister TypeReg =
              MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
          Inst.addOperand(MCOperand::createReg(TypeReg));
          assert(ConstZeroInt32 && "There should be a constant zero.");
          MCRegister ConstReg = MAI->getRegisterAlias(
              ConstZeroInt32->getMF(), ConstZeroInt32->getOperand(0).getReg());
          Inst.addOperand(MCOperand::createReg(ConstReg));
          outputMCInst(Inst);
        }
      } else {
        // Without SPV_KHR_float_controls2, plain ContractionOff is emitted.
        MCInst Inst;
        Inst.setOpcode(SPIRV::OpExecutionMode);
        Inst.addOperand(MCOperand::createReg(FReg));
        unsigned EM =
            static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
        Inst.addOperand(MCOperand::createImm(EM));
        outputMCInst(Inst);
      }
    }
  }
}

// Emits the annotation section: first the instructions collected by the
// module analysis, then one "OpDecorate <function> UserSemantic <string>" per
// entry of the llvm.global.annotations special global variable. Only
// annotations on functions known to the module analysis are supported;
// anything else is a fatal error.
void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
  outputModuleSection(SPIRV::MB_Annotations);

  for (const GlobalVariable &Global : M.globals()) {
    if (Global.getName() != "llvm.global.annotations")
      continue;

    const auto *Entries = cast<ConstantArray>(Global.getOperand(0));
    for (Value *Entry : Entries->operands()) {
      auto *Record = cast<ConstantStruct>(Entry);

      // Field 0 of the record points to the annotated value.
      Value *AnnotatedVar = Record->getOperand(0)->stripPointerCasts();
      auto *AnnotatedFn = dyn_cast<Function>(AnnotatedVar);
      if (!AnnotatedFn)
        report_fatal_error("Unsupported value in llvm.global.annotations");

      MCRegister FnReg = MAI->getFuncReg(AnnotatedFn);
      if (!FnReg.isValid()) {
        std::string DiagMsg;
        raw_string_ostream OS(DiagMsg);
        AnnotatedVar->print(OS);
        DiagMsg = "Unknown function in llvm.global.annotations: " + DiagMsg;
        report_fatal_error(DiagMsg.c_str());
      }

      // Field 1 of the record points to the global holding the annotation
      // string.
      auto *StringGlobal =
          cast<GlobalVariable>(Record->getOperand(1)->stripPointerCasts());
      StringRef Annotation;
      [[maybe_unused]] bool Success =
          getConstantStringInfo(StringGlobal, Annotation);
      assert(Success && "Failed to get annotation string");

      const auto Decoration =
          static_cast<unsigned>(SPIRV::Decoration::UserSemantic);
      MCInst Decorate;
      Decorate.setOpcode(SPIRV::OpDecorate);
      Decorate.addOperand(MCOperand::createReg(FnReg));
      Decorate.addOperand(MCOperand::createImm(Decoration));
      addStringImm(Annotation, Decorate);
      outputMCInst(Decorate);
    }
  }
}

// Emits an "OpExecutionModeId <function> FPFastMathDefault <type> <flags>"
// instruction for every (function, OpTypeFloat) pair recorded in
// MAI->FPFastMathDefaultInfoMap, unless none of the related execution modes
// was requested for that pair.
//
// The <flags> operand is the id of a 32-bit integer constant, so the
// MB_TypeConstVars section is scanned first to map each int32 constant value
// to the instruction defining it. Contradictory requests are fatal errors:
//   * ContractionOff together with the AllowContract flag;
//   * SignedZeroInfNanPreserve together with an explicit FPFastMathDefault
//     that enables any of NotNaN/NotInf/NSZ.
// A fatal error is also reported when no constant holds the needed flag value.
void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() {
  // Collect the SPIRVTypes that are OpTypeFloat and the constants of type
  // int32, that might be used as FP Fast Math Mode.
  std::vector<const MachineInstr *> SPIRVFloatTypes;
  // Hashtable to associate immediate values with the constant holding them.
  std::unordered_map<int, const MachineInstr *> ConstMap;
  for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
    // Skip if the instruction is not OpTypeFloat or OpConstant.
    unsigned OpCode = MI->getOpcode();
    if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantI &&
        OpCode != SPIRV::OpConstantNull)
      continue;

    // Collect the SPIRV type if it's a float.
    if (OpCode == SPIRV::OpTypeFloat) {
      SPIRVFloatTypes.push_back(MI);
    } else {
      // Check if the constant is int32, if not skip it.
      const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
      MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg());
      if (!TypeMI || TypeMI->getOpcode() != SPIRV::OpTypeInt ||
          TypeMI->getOperand(1).getImm() != 32)
        continue;

      // OpConstantI carries its value as operand 2; OpConstantNull stands for
      // the value 0.
      if (OpCode == SPIRV::OpConstantI)
        ConstMap[MI->getOperand(2).getImm()] = MI;
      else
        ConstMap[0] = MI;
    }
  }

  for (const auto &[Func, FPFastMathDefaultInfoVec] :
       MAI->FPFastMathDefaultInfoMap) {
    if (FPFastMathDefaultInfoVec.empty())
      continue;

    for (const MachineInstr *MI : SPIRVFloatTypes) {
      // Operand 1 of OpTypeFloat is the bit width; it selects the entry of
      // the per-function info vector.
      unsigned OpTypeFloatSize = MI->getOperand(1).getImm();
      unsigned Index = SPIRV::FPFastMathDefaultInfoVector::
          computeFPFastMathDefaultInfoVecIndex(OpTypeFloatSize);
      assert(Index < FPFastMathDefaultInfoVec.size() &&
             "Index out of bounds for FPFastMathDefaultInfoVec");
      const auto &FPFastMathDefaultInfo = FPFastMathDefaultInfoVec[Index];
      assert(FPFastMathDefaultInfo.Ty &&
             "Expected target type for FPFastMathDefaultInfo");
      assert(FPFastMathDefaultInfo.Ty->getScalarSizeInBits() ==
                 OpTypeFloatSize &&
             "Mismatched float type size");
      MCInst Inst;
      Inst.setOpcode(SPIRV::OpExecutionModeId);
      MCRegister FuncReg = MAI->getFuncReg(Func);
      assert(FuncReg.isValid());
      Inst.addOperand(MCOperand::createReg(FuncReg));
      Inst.addOperand(
          MCOperand::createImm(SPIRV::ExecutionMode::FPFastMathDefault));
      MCRegister TypeReg =
          MAI->getRegisterAlias(MI->getMF(), MI->getOperand(0).getReg());
      Inst.addOperand(MCOperand::createReg(TypeReg));
      unsigned Flags = FPFastMathDefaultInfo.FastMathFlags;
      if (FPFastMathDefaultInfo.ContractionOff &&
          (Flags & SPIRV::FPFastMathMode::AllowContract))
        report_fatal_error(
            "Conflicting FPFastMathFlags: ContractionOff and AllowContract");

      // SignedZeroInfNanPreserve contradicts an explicit FPFastMathDefault
      // that enables any of the NotNaN/NotInf/NSZ flags.
      const unsigned NoNaNInfOrSignedZero = SPIRV::FPFastMathMode::NotNaN |
                                            SPIRV::FPFastMathMode::NotInf |
                                            SPIRV::FPFastMathMode::NSZ;
      if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
          FPFastMathDefaultInfo.FPFastMathDefault &&
          (Flags & NoNaNInfOrSignedZero))
        report_fatal_error("Conflicting FPFastMathFlags: "
                           "SignedZeroInfNanPreserve but at least one of "
                           "NotNaN/NotInf/NSZ is enabled.");

      // Don't emit if none of the execution modes was used.
      if (Flags == SPIRV::FPFastMathMode::None &&
          !FPFastMathDefaultInfo.ContractionOff &&
          !FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
          !FPFastMathDefaultInfo.FPFastMathDefault)
        continue;

      // Retrieve the constant instruction for the immediate value.
      auto It = ConstMap.find(Flags);
      if (It == ConstMap.end())
        report_fatal_error("Expected constant instruction for FP Fast Math "
                           "Mode operand of FPFastMathDefault execution mode.");
      const MachineInstr *ConstMI = It->second;
      MCRegister ConstReg = MAI->getRegisterAlias(
          ConstMI->getMF(), ConstMI->getOperand(0).getReg());
      Inst.addOperand(MCOperand::createReg(ConstReg));
      outputMCInst(Inst);
    }
  }
}

// Writes every module-level section, in the order required by the SPIR-V
// logical layout of a module. Also initializes ST and TII from the target
// machine's subtarget and points MAI at the module analysis results, so it
// must run before any other output helper of this class is used.
void SPIRVAsmPrinter::outputModuleSections() {
  const Module *M = MMI->getModule();
  // Get the global subtarget to output module-level info.
  ST = static_cast<const SPIRVTargetMachine &>(TM).getSubtargetImpl();
  TII = ST->getInstrInfo();
  MAI = &SPIRVModuleAnalysis::MAI;
  assert(ST && TII && MAI && M && "Module analysis is required");
  // Output instructions according to the Logical Layout of a Module:
  // 1,2. All OpCapability instructions, then optional OpExtension
  // instructions.
  outputGlobalRequirements();
  // 3. Optional OpExtInstImport instructions.
  outputOpExtInstImports(*M);
  // 4. The single required OpMemoryModel instruction.
  outputOpMemoryModel();
  // 5. All entry point declarations, using OpEntryPoint.
  outputEntryPoints();
  // 6. Execution-mode declarations, using OpExecutionMode or
  // OpExecutionModeId.
  outputExecutionMode(*M);
  // 7a. Debug: all OpString, OpSourceExtension, OpSource, and
  // OpSourceContinued, without forward references.
  outputDebugSourceAndStrings(*M);
  // 7b. Debug: all OpName and all OpMemberName.
  outputModuleSection(SPIRV::MB_DebugNames);
  // 7c. Debug: all OpModuleProcessed instructions.
  outputModuleSection(SPIRV::MB_DebugModuleProcessed);
  // SPV_INTEL_memory_access_aliasing instructions are placed here, ahead of
  // section 8 ("All annotation instructions").
  outputModuleSection(SPIRV::MB_AliasingInsts);
  // 8. All annotation instructions (all decorations).
  outputAnnotations(*M);
  // 9. All type declarations (OpTypeXXX instructions), all constant
  // instructions, and all global variable declarations. This section is
  // the first section to allow use of: OpLine and OpNoLine debug information;
  // non-semantic instructions with OpExtInst.
  outputModuleSection(SPIRV::MB_TypeConstVars);
  // 10. All global NonSemantic.Shader.DebugInfo.100 instructions.
  outputModuleSection(SPIRV::MB_NonSemanticGlobalDI);
  // 11. All function declarations (functions without a body).
  outputExtFuncDecls();
  // 12. All function definitions (functions with a body).
  // This is done in regular function output.
}

// Per-module initialization: clears the "module sections emitted" flag so the
// sections are written again for this module, then defers to the base class
// and returns its result.
bool SPIRVAsmPrinter::doInitialization(Module &M) {
  ModuleSectionsEmitted = false;
  // We need to call the parent's one explicitly.
  return AsmPrinter::doInitialization(M);
}

// Pass identification; the constructor passes ID on to the AsmPrinter base.
char SPIRVAsmPrinter::ID = 0;

// Pass registration under the "spirv-asm-printer" argument.
INITIALIZE_PASS(SPIRVAsmPrinter, "spirv-asm-printer", "SPIRV Assembly Printer",
                false, false)

// Force static initialization: registers SPIRVAsmPrinter as the asm printer
// of the 32-bit, 64-bit and logical SPIR-V targets.
extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void
LLVMInitializeSPIRVAsmPrinter() {
  RegisterAsmPrinter<SPIRVAsmPrinter> RegisterSPIRV32(getTheSPIRV32Target());
  RegisterAsmPrinter<SPIRVAsmPrinter> RegisterSPIRV64(getTheSPIRV64Target());
  RegisterAsmPrinter<SPIRVAsmPrinter> RegisterLogical(
      getTheSPIRVLogicalTarget());
}
