diff --git a/.github/actions/godot-build/action.yml b/.github/actions/godot-build/action.yml index bf29b7e430a..61613d26289 100644 --- a/.github/actions/godot-build/action.yml +++ b/.github/actions/godot-build/action.yml @@ -7,11 +7,14 @@ inputs: tests: description: Unit tests. default: false + required: false platform: description: Target platform. required: false sconsflags: + description: Additional SCons flags. default: "" + required: false scons-cache: description: The SCons cache path. default: "${{ github.workspace }}/.scons-cache/" diff --git a/COPYRIGHT.txt b/COPYRIGHT.txt index 1e32c87c016..a9d6cd7d320 100644 --- a/COPYRIGHT.txt +++ b/COPYRIGHT.txt @@ -475,6 +475,11 @@ Comment: RVO2 Copyright: 2016, University of North Carolina at Chapel Hill License: Apache-2.0 +Files: ./thirdparty/spirv-cross/ +Comment: SPIRV-Cross +Copyright: 2015-2021, Arm Limited +License: Apache-2.0 or Expat + Files: ./thirdparty/spirv-reflect/ Comment: SPIRV-Reflect Copyright: 2017-2022, Google Inc. diff --git a/SConstruct b/SConstruct index 599db4b8deb..ca160c1762b 100644 --- a/SConstruct +++ b/SConstruct @@ -222,6 +222,7 @@ opts.Add(BoolVariable("xaudio2", "Enable the XAudio2 audio driver", False)) opts.Add(BoolVariable("vulkan", "Enable the vulkan rendering driver", True)) opts.Add(BoolVariable("opengl3", "Enable the OpenGL/GLES3 rendering driver", True)) opts.Add(BoolVariable("d3d12", "Enable the Direct3D 12 rendering driver", False)) +opts.Add(BoolVariable("metal", "Enable the Metal rendering driver (Apple arm64 only)", False)) opts.Add(BoolVariable("openxr", "Enable the OpenXR driver", True)) opts.Add(BoolVariable("use_volk", "Use the volk library to load the Vulkan loader dynamically", True)) opts.Add(BoolVariable("disable_exceptions", "Force disabling exception handling code", True)) diff --git a/core/core_bind.cpp b/core/core_bind.cpp index a15085bcde5..36f662b92b9 100644 --- a/core/core_bind.cpp +++ b/core/core_bind.cpp @@ -692,6 +692,7 @@ void OS::_bind_methods() { BIND_ENUM_CONSTANT(RENDERING_DRIVER_VULKAN); BIND_ENUM_CONSTANT(RENDERING_DRIVER_OPENGL3); BIND_ENUM_CONSTANT(RENDERING_DRIVER_D3D12); + BIND_ENUM_CONSTANT(RENDERING_DRIVER_METAL); BIND_ENUM_CONSTANT(SYSTEM_DIR_DESKTOP); BIND_ENUM_CONSTANT(SYSTEM_DIR_DCIM); diff --git a/core/core_bind.h b/core/core_bind.h index 69a359e4ed5..d744da25511 100644 --- a/core/core_bind.h +++ b/core/core_bind.h @@ -132,6 +132,7 @@ public: RENDERING_DRIVER_VULKAN, RENDERING_DRIVER_OPENGL3, RENDERING_DRIVER_D3D12, + RENDERING_DRIVER_METAL, }; PackedByteArray get_entropy(int p_bytes); diff --git a/doc/classes/OS.xml b/doc/classes/OS.xml index bc20ff4e454..9675f5af507 100644 --- a/doc/classes/OS.xml +++ b/doc/classes/OS.xml @@ -802,6 +802,9 @@ The Direct3D 12 rendering driver. + + The Metal rendering driver. + Refers to the Desktop directory path. diff --git a/drivers/SCsub b/drivers/SCsub index e77b96cc87d..521d607740a 100644 --- a/drivers/SCsub +++ b/drivers/SCsub @@ -33,6 +33,8 @@ if env["opengl3"]: SConscript("gl_context/SCsub") SConscript("gles3/SCsub") SConscript("egl/SCsub") +if env["metal"]: + SConscript("metal/SCsub") # Core dependencies SConscript("png/SCsub") diff --git a/drivers/metal/README.md b/drivers/metal/README.md new file mode 100644 index 00000000000..30cfa523609 --- /dev/null +++ b/drivers/metal/README.md @@ -0,0 +1,39 @@ +# Metal Rendering Device + +This document aims to describe the Metal rendering device implementation in Godot. + +## Future work / ideas + +* Use placement heaps +* Explicit hazard tracking +* [MetalFX] upscaling support? 
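+
+As a very rough sketch of the first idea, sub-allocating a texture from a placement heap could look like
+the following (assuming an `id<MTLDevice> device` is at hand; the fixed `offset:0` stands in for a real
+sub-allocator, and none of this is part of the driver yet):
+
+```
+// Query how much heap space the texture requires.
+MTLTextureDescriptor *tex_desc = [MTLTextureDescriptor texture2DDescriptorWithPixelFormat:MTLPixelFormatRGBA8Unorm
+                                                                                     width:256
+                                                                                    height:256
+                                                                                 mipmapped:NO];
+tex_desc.storageMode = MTLStorageModePrivate;
+MTLSizeAndAlign size_align = [device heapTextureSizeAndAlignWithDescriptor:tex_desc];
+
+// Create a placement heap and place the texture at an explicit offset.
+MTLHeapDescriptor *heap_desc = [MTLHeapDescriptor new];
+heap_desc.type = MTLHeapTypePlacement;
+heap_desc.storageMode = MTLStorageModePrivate;
+heap_desc.size = size_align.size; // A real allocator would pool many resources per heap.
+id<MTLHeap> heap = [device newHeapWithDescriptor:heap_desc];
+id<MTLTexture> tex = [heap newTextureWithDescriptor:tex_desc offset:0]; // Offset chosen by the driver's allocator.
+```
+
+Resources sub-allocated this way are not hazard-tracked by Metal, which is presumably why the item sits next to explicit hazard tracking above.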
+ +## Acknowledgments + +The Metal rendering owes a lot to the work of the [MoltenVK] project, which is a Vulkan implementation on top of Metal. +In accordance with the Apache 2.0 license, the following copyright notices have been included where applicable: + +``` +/**************************************************************************/ +/* */ +/* Portions of this code were derived from MoltenVK. */ +/* */ +/* Copyright (c) 2015-2023 The Brenwill Workshop Ltd. */ +/* (http://www.brenwill.com) */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or */ +/* implied. See the License for the specific language governing */ +/* permissions and limitations under the License. */ +/**************************************************************************/ +``` + +[MoltenVK]: https://github.com/KhronosGroup/MoltenVK +[MetalFX]: https://developer.apple.com/documentation/metalfx?language=objc diff --git a/drivers/metal/SCsub b/drivers/metal/SCsub new file mode 100644 index 00000000000..30129b78061 --- /dev/null +++ b/drivers/metal/SCsub @@ -0,0 +1,49 @@ +#!/usr/bin/env python + +Import("env") + +env_metal = env.Clone() + +# Thirdparty source files + +thirdparty_obj = [] + +thirdparty_dir = "#thirdparty/spirv-cross/" +thirdparty_sources = [ + "spirv_cfg.cpp", + "spirv_cross_util.cpp", + "spirv_cross.cpp", + "spirv_parser.cpp", + "spirv_msl.cpp", + "spirv_reflect.cpp", + "spirv_glsl.cpp", + "spirv_cross_parsed_ir.cpp", +] +thirdparty_sources = [thirdparty_dir + file for file in thirdparty_sources] + +env_metal.Prepend(CPPPATH=[thirdparty_dir, thirdparty_dir + "/include"]) + +# Must enable exceptions for SPIRV-Cross; otherwise, it will abort the process on errors. +if "-fno-exceptions" in env_metal["CXXFLAGS"]: + env_metal["CXXFLAGS"].remove("-fno-exceptions") +env_metal.Append(CXXFLAGS=["-fexceptions"]) + +env_thirdparty = env_metal.Clone() +env_thirdparty.disable_warnings() +env_thirdparty.add_source_files(thirdparty_obj, thirdparty_sources) +env_metal.drivers_sources += thirdparty_obj + +# Enable C++20 for the Objective-C++ Metal code, which uses C++20 concepts. +if "-std=gnu++17" in env_metal["CXXFLAGS"]: + env_metal["CXXFLAGS"].remove("-std=gnu++17") +env_metal.Append(CXXFLAGS=["-std=c++20"]) + +# Driver source files + +driver_obj = [] + +env_metal.add_source_files(driver_obj, "*.mm") +env.drivers_sources += driver_obj + +# Needed to force rebuilding the driver files when the thirdparty library is updated. 
+env.Depends(driver_obj, thirdparty_obj) diff --git a/drivers/metal/metal_device_properties.h b/drivers/metal/metal_device_properties.h new file mode 100644 index 00000000000..7467e8ceb49 --- /dev/null +++ b/drivers/metal/metal_device_properties.h @@ -0,0 +1,141 @@ +/**************************************************************************/ +/* metal_device_properties.h */ +/**************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/**************************************************************************/ +/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */ +/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/**************************************************************************/ + +/**************************************************************************/ +/* */ +/* Portions of this code were derived from MoltenVK. */ +/* */ +/* Copyright (c) 2015-2023 The Brenwill Workshop Ltd. */ +/* (http://www.brenwill.com) */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or */ +/* implied. See the License for the specific language governing */ +/* permissions and limitations under the License. */ +/**************************************************************************/ + +#ifndef METAL_DEVICE_PROPERTIES_H +#define METAL_DEVICE_PROPERTIES_H + +#import "servers/rendering/rendering_device.h" + +#import +#import + +/** The buffer index to use for vertex content. 
*/ +const static uint32_t VERT_CONTENT_BUFFER_INDEX = 0; +const static uint32_t MAX_COLOR_ATTACHMENT_COUNT = 8; + +typedef NS_OPTIONS(NSUInteger, SampleCount) { + SampleCount1 = (1UL << 0), + SampleCount2 = (1UL << 1), + SampleCount4 = (1UL << 2), + SampleCount8 = (1UL << 3), + SampleCount16 = (1UL << 4), + SampleCount32 = (1UL << 5), + SampleCount64 = (1UL << 6), +}; + +struct API_AVAILABLE(macos(11.0), ios(14.0)) MetalFeatures { + uint32_t mslVersion; + MTLGPUFamily highestFamily; + MTLLanguageVersion mslVersionEnum; + SampleCount supportedSampleCounts; + long hostMemoryPageSize; + bool layeredRendering; + bool multisampleLayeredRendering; + bool quadPermute; /**< If true, quadgroup permutation functions (vote, ballot, shuffle) are supported in shaders. */ + bool simdPermute; /**< If true, SIMD-group permutation functions (vote, ballot, shuffle) are supported in shaders. */ + bool simdReduction; /**< If true, SIMD-group reduction functions (arithmetic) are supported in shaders. */ + bool tessellationShader; /**< If true, tessellation shaders are supported. */ + bool imageCubeArray; /**< If true, image cube arrays are supported. */ +}; + +struct MetalLimits { + uint64_t maxImageArrayLayers; + uint64_t maxFramebufferHeight; + uint64_t maxFramebufferWidth; + uint64_t maxImageDimension1D; + uint64_t maxImageDimension2D; + uint64_t maxImageDimension3D; + uint64_t maxImageDimensionCube; + uint64_t maxViewportDimensionX; + uint64_t maxViewportDimensionY; + MTLSize maxThreadsPerThreadGroup; + MTLSize maxComputeWorkGroupCount; + uint64_t maxBoundDescriptorSets; + uint64_t maxColorAttachments; + uint64_t maxTexturesPerArgumentBuffer; + uint64_t maxSamplersPerArgumentBuffer; + uint64_t maxBuffersPerArgumentBuffer; + uint64_t maxBufferLength; + uint64_t minUniformBufferOffsetAlignment; + uint64_t maxVertexDescriptorLayoutStride; + uint16_t maxViewports; + uint32_t maxPerStageBufferCount; /**< The total number of per-stage Metal buffers available for shader uniform content and attributes. */ + uint32_t maxPerStageTextureCount; /**< The total number of per-stage Metal textures available for shader uniform content. */ + uint32_t maxPerStageSamplerCount; /**< The total number of per-stage Metal samplers available for shader uniform content. */ + uint32_t maxVertexInputAttributes; + uint32_t maxVertexInputBindings; + uint32_t maxVertexInputBindingStride; + uint32_t maxDrawIndexedIndexValue; + + uint32_t minSubgroupSize; /**< The minimum number of threads in a SIMD-group. */ + uint32_t maxSubgroupSize; /**< The maximum number of threads in a SIMD-group. */ + BitField subgroupSupportedShaderStages; + BitField subgroupSupportedOperations; /**< The subgroup operations supported by the device. 
*/ +}; + +class API_AVAILABLE(macos(11.0), ios(14.0)) MetalDeviceProperties { +private: + void init_features(id p_device); + void init_limits(id p_device); + +public: + MetalFeatures features; + MetalLimits limits; + + SampleCount find_nearest_supported_sample_count(RenderingDevice::TextureSamples p_samples) const; + + MetalDeviceProperties(id p_device); + ~MetalDeviceProperties(); + +private: + static const SampleCount sample_count[RenderingDevice::TextureSamples::TEXTURE_SAMPLES_MAX]; +}; + +#endif // METAL_DEVICE_PROPERTIES_H diff --git a/drivers/metal/metal_device_properties.mm b/drivers/metal/metal_device_properties.mm new file mode 100644 index 00000000000..857fa8c66ee --- /dev/null +++ b/drivers/metal/metal_device_properties.mm @@ -0,0 +1,327 @@ +/**************************************************************************/ +/* metal_device_properties.mm */ +/**************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/**************************************************************************/ +/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */ +/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/**************************************************************************/ + +/**************************************************************************/ +/* */ +/* Portions of this code were derived from MoltenVK. */ +/* */ +/* Copyright (c) 2015-2023 The Brenwill Workshop Ltd. */ +/* (http://www.brenwill.com) */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or */ +/* implied. See the License for the specific language governing */ +/* permissions and limitations under the License. */ +/**************************************************************************/ + +#import "metal_device_properties.h" + +#import +#import +#import + +// Common scaling multipliers. 
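+// (KIBI = 2^10, MEBI = 2^20.)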
+#define KIBI (1024) +#define MEBI (KIBI * KIBI) + +#if (TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED < 140000) || (TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED < 170000) +#define MTLGPUFamilyApple9 (MTLGPUFamily)1009 +#endif + +API_AVAILABLE(macos(11.0), ios(14.0)) +MTLGPUFamily &operator--(MTLGPUFamily &p_family) { + p_family = static_cast(static_cast(p_family) - 1); + if (p_family < MTLGPUFamilyApple1) { + p_family = MTLGPUFamilyApple9; + } + + return p_family; +} + +void MetalDeviceProperties::init_features(id p_device) { + features = {}; + + features.highestFamily = MTLGPUFamilyApple1; + for (MTLGPUFamily family = MTLGPUFamilyApple9; family >= MTLGPUFamilyApple1; --family) { + if ([p_device supportsFamily:family]) { + features.highestFamily = family; + break; + } + } + + features.hostMemoryPageSize = sysconf(_SC_PAGESIZE); + + for (SampleCount sc = SampleCount1; sc <= SampleCount64; sc <<= 1) { + if ([p_device supportsTextureSampleCount:sc]) { + features.supportedSampleCounts |= sc; + } + } + + features.layeredRendering = [p_device supportsFamily:MTLGPUFamilyApple5]; + features.multisampleLayeredRendering = [p_device supportsFamily:MTLGPUFamilyApple7]; + features.tessellationShader = [p_device supportsFamily:MTLGPUFamilyApple3]; + features.imageCubeArray = [p_device supportsFamily:MTLGPUFamilyApple3]; + features.quadPermute = [p_device supportsFamily:MTLGPUFamilyApple4]; + features.simdPermute = [p_device supportsFamily:MTLGPUFamilyApple6]; + features.simdReduction = [p_device supportsFamily:MTLGPUFamilyApple7]; + + MTLCompileOptions *opts = [MTLCompileOptions new]; + features.mslVersionEnum = opts.languageVersion; // By default, Metal uses the most recent language version. + +#define setMSLVersion(m_maj, m_min) \ + features.mslVersion = SPIRV_CROSS_NAMESPACE::CompilerMSL::Options::make_msl_version(m_maj, m_min) + + switch (features.mslVersionEnum) { +#if __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 + case MTLLanguageVersion3_2: + setMSLVersion(3, 2); + break; +#endif +#if __MAC_OS_X_VERSION_MAX_ALLOWED >= 140000 || __IPHONE_OS_VERSION_MAX_ALLOWED >= 170000 + case MTLLanguageVersion3_1: + setMSLVersion(3, 1); + break; +#endif + case MTLLanguageVersion3_0: + setMSLVersion(3, 0); + break; + case MTLLanguageVersion2_4: + setMSLVersion(2, 4); + break; + case MTLLanguageVersion2_3: + setMSLVersion(2, 3); + break; + case MTLLanguageVersion2_2: + setMSLVersion(2, 2); + break; + case MTLLanguageVersion2_1: + setMSLVersion(2, 1); + break; + case MTLLanguageVersion2_0: + setMSLVersion(2, 0); + break; + case MTLLanguageVersion1_2: + setMSLVersion(1, 2); + break; + case MTLLanguageVersion1_1: + setMSLVersion(1, 1); + break; +#if TARGET_OS_IPHONE && !TARGET_OS_MACCATALYST + case MTLLanguageVersion1_0: + setMSLVersion(1, 0); + break; +#endif + } +} + +void MetalDeviceProperties::init_limits(id p_device) { + using std::max; + using std::min; + + // FST: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf + + // FST: Maximum number of layers per 1D texture array, 2D texture array, or 3D texture. + limits.maxImageArrayLayers = 2048; + if ([p_device supportsFamily:MTLGPUFamilyApple3]) { + // FST: Maximum 2D texture width and height. + limits.maxFramebufferWidth = 16384; + limits.maxFramebufferHeight = 16384; + limits.maxViewportDimensionX = 16384; + limits.maxViewportDimensionY = 16384; + // FST: Maximum 1D texture width. + limits.maxImageDimension1D = 16384; + // FST: Maximum 2D texture width and height. 
+ limits.maxImageDimension2D = 16384; + // FST: Maximum cube map texture width and height. + limits.maxImageDimensionCube = 16384; + } else { + // FST: Maximum 2D texture width and height. + limits.maxFramebufferWidth = 8192; + limits.maxFramebufferHeight = 8192; + limits.maxViewportDimensionX = 8192; + limits.maxViewportDimensionY = 8192; + // FST: Maximum 1D texture width. + limits.maxImageDimension1D = 8192; + // FST: Maximum 2D texture width and height. + limits.maxImageDimension2D = 8192; + // FST: Maximum cube map texture width and height. + limits.maxImageDimensionCube = 8192; + } + // FST: Maximum 3D texture width, height, and depth. + limits.maxImageDimension3D = 2048; + + limits.maxThreadsPerThreadGroup = p_device.maxThreadsPerThreadgroup; + // No effective limits. + limits.maxComputeWorkGroupCount = { std::numeric_limits::max(), std::numeric_limits::max(), std::numeric_limits::max() }; + // https://github.com/KhronosGroup/MoltenVK/blob/568cc3acc0e2299931fdaecaaa1fc3ec5b4af281/MoltenVK/MoltenVK/GPUObjects/MVKDevice.h#L85 + limits.maxBoundDescriptorSets = SPIRV_CROSS_NAMESPACE::kMaxArgumentBuffers; + // FST: Maximum number of color render targets per render pass descriptor. + limits.maxColorAttachments = 8; + + // Maximum number of textures the device can access, per stage, from an argument buffer. + if ([p_device supportsFamily:MTLGPUFamilyApple6]) { + limits.maxTexturesPerArgumentBuffer = 1'000'000; + } else if ([p_device supportsFamily:MTLGPUFamilyApple4]) { + limits.maxTexturesPerArgumentBuffer = 96; + } else { + limits.maxTexturesPerArgumentBuffer = 31; + } + + // Maximum number of samplers the device can access, per stage, from an argument buffer. + if ([p_device supportsFamily:MTLGPUFamilyApple6]) { + limits.maxSamplersPerArgumentBuffer = 1024; + } else { + limits.maxSamplersPerArgumentBuffer = 16; + } + + // Maximum number of buffers the device can access, per stage, from an argument buffer. + if ([p_device supportsFamily:MTLGPUFamilyApple6]) { + limits.maxBuffersPerArgumentBuffer = std::numeric_limits::max(); + } else if ([p_device supportsFamily:MTLGPUFamilyApple4]) { + limits.maxBuffersPerArgumentBuffer = 96; + } else { + limits.maxBuffersPerArgumentBuffer = 31; + } + + limits.minSubgroupSize = limits.maxSubgroupSize = 1; + // These values were taken from MoltenVK. 
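+	// That is: subgroup sizes of 4-32 threads when SIMD-group permute is supported, a fixed size of 4 with only quadgroup permute, and 1 otherwise.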
+ if (features.simdPermute) { + limits.minSubgroupSize = 4; + limits.maxSubgroupSize = 32; + } else if (features.quadPermute) { + limits.minSubgroupSize = limits.maxSubgroupSize = 4; + } + + limits.subgroupSupportedShaderStages.set_flag(RDD::ShaderStage::SHADER_STAGE_COMPUTE_BIT); + if (features.tessellationShader) { + limits.subgroupSupportedShaderStages.set_flag(RDD::ShaderStage::SHADER_STAGE_TESSELATION_CONTROL_BIT); + } + limits.subgroupSupportedShaderStages.set_flag(RDD::ShaderStage::SHADER_STAGE_FRAGMENT_BIT); + + limits.subgroupSupportedOperations.set_flag(RD::SubgroupOperations::SUBGROUP_BASIC_BIT); + if (features.simdPermute || features.quadPermute) { + limits.subgroupSupportedOperations.set_flag(RD::SubgroupOperations::SUBGROUP_VOTE_BIT); + limits.subgroupSupportedOperations.set_flag(RD::SubgroupOperations::SUBGROUP_BALLOT_BIT); + limits.subgroupSupportedOperations.set_flag(RD::SubgroupOperations::SUBGROUP_SHUFFLE_BIT); + limits.subgroupSupportedOperations.set_flag(RD::SubgroupOperations::SUBGROUP_SHUFFLE_RELATIVE_BIT); + } + + if (features.simdReduction) { + limits.subgroupSupportedOperations.set_flag(RD::SubgroupOperations::SUBGROUP_ARITHMETIC_BIT); + } + + if (features.quadPermute) { + limits.subgroupSupportedOperations.set_flag(RD::SubgroupOperations::SUBGROUP_QUAD_BIT); + } + + limits.maxBufferLength = p_device.maxBufferLength; + + // FST: Maximum size of vertex descriptor layout stride. + limits.maxVertexDescriptorLayoutStride = std::numeric_limits::max(); + + // Maximum number of viewports. + if ([p_device supportsFamily:MTLGPUFamilyApple5]) { + limits.maxViewports = 16; + } else { + limits.maxViewports = 1; + } + + limits.maxPerStageBufferCount = 31; + limits.maxPerStageSamplerCount = 16; + if ([p_device supportsFamily:MTLGPUFamilyApple6]) { + limits.maxPerStageTextureCount = 128; + } else if ([p_device supportsFamily:MTLGPUFamilyApple4]) { + limits.maxPerStageTextureCount = 96; + } else { + limits.maxPerStageTextureCount = 31; + } + + limits.maxVertexInputAttributes = 31; + limits.maxVertexInputBindings = 31; + limits.maxVertexInputBindingStride = (2 * KIBI); + +#if TARGET_OS_IOS && !TARGET_OS_MACCATALYST + limits.minUniformBufferOffsetAlignment = 64; +#endif + +#if TARGET_OS_OSX + // This is Apple Silicon specific. + limits.minUniformBufferOffsetAlignment = 16; +#endif + + limits.maxDrawIndexedIndexValue = std::numeric_limits::max() - 1; +} + +MetalDeviceProperties::MetalDeviceProperties(id p_device) { + init_features(p_device); + init_limits(p_device); +} + +MetalDeviceProperties::~MetalDeviceProperties() { +} + +SampleCount MetalDeviceProperties::find_nearest_supported_sample_count(RenderingDevice::TextureSamples p_samples) const { + SampleCount supported = features.supportedSampleCounts; + if (supported & sample_count[p_samples]) { + return sample_count[p_samples]; + } + + SampleCount requested_sample_count = sample_count[p_samples]; + // Find the nearest supported sample count. 
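+	// Halve the requested count until a supported one is found (e.g. an unsupported 8x request falls back to 4x, then 2x).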
+ while (requested_sample_count > SampleCount1) { + if (supported & requested_sample_count) { + return requested_sample_count; + } + requested_sample_count = (SampleCount)(requested_sample_count >> 1); + } + + return SampleCount1; +} + +// region static members + +const SampleCount MetalDeviceProperties::sample_count[RenderingDevice::TextureSamples::TEXTURE_SAMPLES_MAX] = { + SampleCount1, + SampleCount2, + SampleCount4, + SampleCount8, + SampleCount16, + SampleCount32, + SampleCount64, +}; + +// endregion diff --git a/drivers/metal/metal_objects.h b/drivers/metal/metal_objects.h new file mode 100644 index 00000000000..70f86f2facd --- /dev/null +++ b/drivers/metal/metal_objects.h @@ -0,0 +1,838 @@ +/**************************************************************************/ +/* metal_objects.h */ +/**************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/**************************************************************************/ +/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */ +/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/**************************************************************************/ + +/**************************************************************************/ +/* */ +/* Portions of this code were derived from MoltenVK. */ +/* */ +/* Copyright (c) 2015-2023 The Brenwill Workshop Ltd. */ +/* (http://www.brenwill.com) */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or */ +/* implied. See the License for the specific language governing */ +/* permissions and limitations under the License. 
*/ +/**************************************************************************/ + +#ifndef METAL_OBJECTS_H +#define METAL_OBJECTS_H + +#import "metal_device_properties.h" +#import "metal_utils.h" +#import "pixel_formats.h" + +#import "servers/rendering/rendering_device_driver.h" + +#import +#import +#import +#import +#import +#import +#import + +// These types can be used in Vector and other containers that use +// pointer operations not supported by ARC. +namespace MTL { +#define MTL_CLASS(name) \ + class name { \ + public: \ + name(id obj = nil) : m_obj(obj) {} \ + operator id() const { return m_obj; } \ + id m_obj; \ + }; + +MTL_CLASS(Texture) + +} //namespace MTL + +enum ShaderStageUsage : uint32_t { + None = 0, + Vertex = RDD::SHADER_STAGE_VERTEX_BIT, + Fragment = RDD::SHADER_STAGE_FRAGMENT_BIT, + TesselationControl = RDD::SHADER_STAGE_TESSELATION_CONTROL_BIT, + TesselationEvaluation = RDD::SHADER_STAGE_TESSELATION_EVALUATION_BIT, + Compute = RDD::SHADER_STAGE_COMPUTE_BIT, +}; + +_FORCE_INLINE_ ShaderStageUsage &operator|=(ShaderStageUsage &p_a, int p_b) { + p_a = ShaderStageUsage(uint32_t(p_a) | uint32_t(p_b)); + return p_a; +} + +enum class MDCommandBufferStateType { + None, + Render, + Compute, + Blit, +}; + +enum class MDPipelineType { + None, + Render, + Compute, +}; + +class MDRenderPass; +class MDPipeline; +class MDRenderPipeline; +class MDComputePipeline; +class MDFrameBuffer; +class RenderingDeviceDriverMetal; +class MDUniformSet; +class MDShader; + +#pragma mark - Resource Factory + +struct ClearAttKey { + const static uint32_t COLOR_COUNT = MAX_COLOR_ATTACHMENT_COUNT; + const static uint32_t DEPTH_INDEX = COLOR_COUNT; + const static uint32_t STENCIL_INDEX = DEPTH_INDEX + 1; + const static uint32_t ATTACHMENT_COUNT = STENCIL_INDEX + 1; + + uint16_t sample_count = 0; + uint16_t pixel_formats[ATTACHMENT_COUNT] = { 0 }; + + _FORCE_INLINE_ void set_color_format(uint32_t p_idx, MTLPixelFormat p_fmt) { pixel_formats[p_idx] = p_fmt; } + _FORCE_INLINE_ void set_depth_format(MTLPixelFormat p_fmt) { pixel_formats[DEPTH_INDEX] = p_fmt; } + _FORCE_INLINE_ void set_stencil_format(MTLPixelFormat p_fmt) { pixel_formats[STENCIL_INDEX] = p_fmt; } + _FORCE_INLINE_ MTLPixelFormat depth_format() const { return (MTLPixelFormat)pixel_formats[DEPTH_INDEX]; } + _FORCE_INLINE_ MTLPixelFormat stencil_format() const { return (MTLPixelFormat)pixel_formats[STENCIL_INDEX]; } + + _FORCE_INLINE_ bool is_enabled(uint32_t p_idx) const { return pixel_formats[p_idx] != 0; } + _FORCE_INLINE_ bool is_depth_enabled() const { return pixel_formats[DEPTH_INDEX] != 0; } + _FORCE_INLINE_ bool is_stencil_enabled() const { return pixel_formats[STENCIL_INDEX] != 0; } + + _FORCE_INLINE_ bool operator==(const ClearAttKey &p_rhs) const { + return memcmp(this, &p_rhs, sizeof(ClearAttKey)) == 0; + } + + uint32_t hash() const { + uint32_t h = hash_murmur3_one_32(sample_count); + h = hash_murmur3_buffer(pixel_formats, ATTACHMENT_COUNT * sizeof(pixel_formats[0]), h); + return h; + } +}; + +class API_AVAILABLE(macos(11.0), ios(14.0)) MDResourceFactory { +private: + RenderingDeviceDriverMetal *device_driver; + + id new_func(NSString *p_source, NSString *p_name, NSError **p_error); + id new_clear_vert_func(ClearAttKey &p_key); + id new_clear_frag_func(ClearAttKey &p_key); + NSString *get_format_type_string(MTLPixelFormat p_fmt); + +public: + id new_clear_pipeline_state(ClearAttKey &p_key, NSError **p_error); + id new_depth_stencil_state(bool p_use_depth, bool p_use_stencil); + + MDResourceFactory(RenderingDeviceDriverMetal 
*p_device_driver) : + device_driver(p_device_driver) {} + ~MDResourceFactory() = default; +}; + +class API_AVAILABLE(macos(11.0), ios(14.0)) MDResourceCache { +private: + typedef HashMap, HashableHasher> HashMap; + std::unique_ptr resource_factory; + HashMap clear_states; + + struct { + id all; + id depth_only; + id stencil_only; + id none; + } clear_depth_stencil_state; + +public: + id get_clear_render_pipeline_state(ClearAttKey &p_key, NSError **p_error); + id get_depth_stencil_state(bool p_use_depth, bool p_use_stencil); + + explicit MDResourceCache(RenderingDeviceDriverMetal *p_device_driver) : + resource_factory(new MDResourceFactory(p_device_driver)) {} + ~MDResourceCache() = default; +}; + +class API_AVAILABLE(macos(11.0), ios(14.0)) MDCommandBuffer { +private: + RenderingDeviceDriverMetal *device_driver = nullptr; + id queue = nil; + id commandBuffer = nil; + + void _end_compute_dispatch(); + void _end_blit(); + +#pragma mark - Render + + void _render_set_dirty_state(); + void _render_bind_uniform_sets(); + + static void _populate_vertices(simd::float4 *p_vertices, Size2i p_fb_size, VectorView p_rects); + static uint32_t _populate_vertices(simd::float4 *p_vertices, uint32_t p_index, Rect2i const &p_rect, Size2i p_fb_size); + void _end_render_pass(); + void _render_clear_render_area(); + +public: + MDCommandBufferStateType type = MDCommandBufferStateType::None; + + struct RenderState { + MDRenderPass *pass = nullptr; + MDFrameBuffer *frameBuffer = nullptr; + MDRenderPipeline *pipeline = nullptr; + LocalVector clear_values; + LocalVector viewports; + LocalVector scissors; + std::optional blend_constants; + uint32_t current_subpass = UINT32_MAX; + Rect2i render_area = {}; + bool is_rendering_entire_area = false; + MTLRenderPassDescriptor *desc = nil; + id encoder = nil; + id __unsafe_unretained index_buffer = nil; // Buffer is owned by RDD. + MTLIndexType index_type = MTLIndexTypeUInt16; + LocalVector __unsafe_unretained> vertex_buffers; + LocalVector vertex_offsets; + // clang-format off + enum DirtyFlag: uint8_t { + DIRTY_NONE = 0b0000'0000, + DIRTY_PIPELINE = 0b0000'0001, //! pipeline state + DIRTY_UNIFORMS = 0b0000'0010, //! uniform sets + DIRTY_DEPTH = 0b0000'0100, //! depth / stenci state + DIRTY_VERTEX = 0b0000'1000, //! vertex buffers + DIRTY_VIEWPORT = 0b0001'0000, //! viewport rectangles + DIRTY_SCISSOR = 0b0010'0000, //! scissor rectangles + DIRTY_BLEND = 0b0100'0000, //! blend state + DIRTY_RASTER = 0b1000'0000, //! encoder state like cull mode + + DIRTY_ALL = 0xff, + }; + // clang-format on + BitField dirty = DIRTY_NONE; + + LocalVector uniform_sets; + // Bit mask of the uniform sets that are dirty, to prevent redundant binding. 
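+		// Bit i is set when uniform set i must be rebound before the next draw.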
+ uint64_t uniform_set_mask = 0; + + _FORCE_INLINE_ void reset() { + pass = nil; + frameBuffer = nil; + pipeline = nil; + current_subpass = UINT32_MAX; + render_area = {}; + is_rendering_entire_area = false; + desc = nil; + encoder = nil; + index_buffer = nil; + index_type = MTLIndexTypeUInt16; + dirty = DIRTY_NONE; + uniform_sets.clear(); + uniform_set_mask = 0; + clear_values.clear(); + viewports.clear(); + scissors.clear(); + blend_constants.reset(); + vertex_buffers.clear(); + vertex_offsets.clear(); + } + + _FORCE_INLINE_ void mark_viewport_dirty() { + if (viewports.is_empty()) { + return; + } + dirty.set_flag(DirtyFlag::DIRTY_VIEWPORT); + } + + _FORCE_INLINE_ void mark_scissors_dirty() { + if (scissors.is_empty()) { + return; + } + dirty.set_flag(DirtyFlag::DIRTY_SCISSOR); + } + + _FORCE_INLINE_ void mark_vertex_dirty() { + if (vertex_buffers.is_empty()) { + return; + } + dirty.set_flag(DirtyFlag::DIRTY_VERTEX); + } + + _FORCE_INLINE_ void mark_uniforms_dirty(std::initializer_list l) { + if (uniform_sets.is_empty()) { + return; + } + for (uint32_t i : l) { + if (i < uniform_sets.size() && uniform_sets[i] != nullptr) { + uniform_set_mask |= 1 << i; + } + } + dirty.set_flag(DirtyFlag::DIRTY_UNIFORMS); + } + + _FORCE_INLINE_ void mark_uniforms_dirty(void) { + if (uniform_sets.is_empty()) { + return; + } + for (uint32_t i = 0; i < uniform_sets.size(); i++) { + if (uniform_sets[i] != nullptr) { + uniform_set_mask |= 1 << i; + } + } + dirty.set_flag(DirtyFlag::DIRTY_UNIFORMS); + } + + MTLScissorRect clip_to_render_area(MTLScissorRect p_rect) const { + uint32_t raLeft = render_area.position.x; + uint32_t raRight = raLeft + render_area.size.width; + uint32_t raBottom = render_area.position.y; + uint32_t raTop = raBottom + render_area.size.height; + + p_rect.x = CLAMP(p_rect.x, raLeft, MAX(raRight - 1, raLeft)); + p_rect.y = CLAMP(p_rect.y, raBottom, MAX(raTop - 1, raBottom)); + p_rect.width = MIN(p_rect.width, raRight - p_rect.x); + p_rect.height = MIN(p_rect.height, raTop - p_rect.y); + + return p_rect; + } + + Rect2i clip_to_render_area(Rect2i p_rect) const { + int32_t raLeft = render_area.position.x; + int32_t raRight = raLeft + render_area.size.width; + int32_t raBottom = render_area.position.y; + int32_t raTop = raBottom + render_area.size.height; + + p_rect.position.x = CLAMP(p_rect.position.x, raLeft, MAX(raRight - 1, raLeft)); + p_rect.position.y = CLAMP(p_rect.position.y, raBottom, MAX(raTop - 1, raBottom)); + p_rect.size.width = MIN(p_rect.size.width, raRight - p_rect.position.x); + p_rect.size.height = MIN(p_rect.size.height, raTop - p_rect.position.y); + + return p_rect; + } + + } render; + + // State specific for a compute pass. + struct { + MDComputePipeline *pipeline = nullptr; + id encoder = nil; + _FORCE_INLINE_ void reset() { + pipeline = nil; + encoder = nil; + } + } compute; + + // State specific to a blit pass. 
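+	// Only the active blit command encoder needs tracking; blit passes carry no extra pipeline state.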
+ struct { + id encoder = nil; + _FORCE_INLINE_ void reset() { + encoder = nil; + } + } blit; + + _FORCE_INLINE_ id get_command_buffer() const { + return commandBuffer; + } + + void begin(); + void commit(); + void end(); + + id blit_command_encoder(); + void encodeRenderCommandEncoderWithDescriptor(MTLRenderPassDescriptor *p_desc, NSString *p_label); + + void bind_pipeline(RDD::PipelineID p_pipeline); + +#pragma mark - Render Commands + + void render_bind_uniform_set(RDD::UniformSetID p_uniform_set, RDD::ShaderID p_shader, uint32_t p_set_index); + void render_clear_attachments(VectorView p_attachment_clears, VectorView p_rects); + void render_set_viewport(VectorView p_viewports); + void render_set_scissor(VectorView p_scissors); + void render_set_blend_constants(const Color &p_constants); + void render_begin_pass(RDD::RenderPassID p_render_pass, + RDD::FramebufferID p_frameBuffer, + RDD::CommandBufferType p_cmd_buffer_type, + const Rect2i &p_rect, + VectorView p_clear_values); + void render_next_subpass(); + void render_draw(uint32_t p_vertex_count, + uint32_t p_instance_count, + uint32_t p_base_vertex, + uint32_t p_first_instance); + void render_bind_vertex_buffers(uint32_t p_binding_count, const RDD::BufferID *p_buffers, const uint64_t *p_offsets); + void render_bind_index_buffer(RDD::BufferID p_buffer, RDD::IndexBufferFormat p_format, uint64_t p_offset); + + void render_draw_indexed(uint32_t p_index_count, + uint32_t p_instance_count, + uint32_t p_first_index, + int32_t p_vertex_offset, + uint32_t p_first_instance); + + void render_draw_indexed_indirect(RDD::BufferID p_indirect_buffer, uint64_t p_offset, uint32_t p_draw_count, uint32_t p_stride); + void render_draw_indexed_indirect_count(RDD::BufferID p_indirect_buffer, uint64_t p_offset, RDD::BufferID p_count_buffer, uint64_t p_count_buffer_offset, uint32_t p_max_draw_count, uint32_t p_stride); + void render_draw_indirect(RDD::BufferID p_indirect_buffer, uint64_t p_offset, uint32_t p_draw_count, uint32_t p_stride); + void render_draw_indirect_count(RDD::BufferID p_indirect_buffer, uint64_t p_offset, RDD::BufferID p_count_buffer, uint64_t p_count_buffer_offset, uint32_t p_max_draw_count, uint32_t p_stride); + + void render_end_pass(); + +#pragma mark - Compute Commands + + void compute_bind_uniform_set(RDD::UniformSetID p_uniform_set, RDD::ShaderID p_shader, uint32_t p_set_index); + void compute_dispatch(uint32_t p_x_groups, uint32_t p_y_groups, uint32_t p_z_groups); + void compute_dispatch_indirect(RDD::BufferID p_indirect_buffer, uint64_t p_offset); + + MDCommandBuffer(id p_queue, RenderingDeviceDriverMetal *p_device_driver) : + device_driver(p_device_driver), queue(p_queue) { + type = MDCommandBufferStateType::None; + } + + MDCommandBuffer() = default; +}; + +#if (TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED < 140000) || (TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED < 170000) +#define MTLBindingAccess MTLArgumentAccess +#define MTLBindingAccessReadOnly MTLArgumentAccessReadOnly +#define MTLBindingAccessReadWrite MTLArgumentAccessReadWrite +#define MTLBindingAccessWriteOnly MTLArgumentAccessWriteOnly +#endif + +struct API_AVAILABLE(macos(11.0), ios(14.0)) BindingInfo { + MTLDataType dataType = MTLDataTypeNone; + uint32_t index = 0; + MTLBindingAccess access = MTLBindingAccessReadOnly; + MTLResourceUsage usage = 0; + MTLTextureType textureType = MTLTextureType2D; + spv::ImageFormat imageFormat = spv::ImageFormatUnknown; + uint32_t arrayLength = 0; + bool isMultisampled = false; + + inline MTLArgumentDescriptor 
*new_argument_descriptor() const { + MTLArgumentDescriptor *desc = MTLArgumentDescriptor.argumentDescriptor; + desc.dataType = dataType; + desc.index = index; + desc.access = access; + desc.textureType = textureType; + desc.arrayLength = arrayLength; + return desc; + } + + size_t serialize_size() const { + return sizeof(uint32_t) * 8 /* 8 uint32_t fields */; + } + + template + void serialize(W &p_writer) const { + p_writer.write((uint32_t)dataType); + p_writer.write(index); + p_writer.write((uint32_t)access); + p_writer.write((uint32_t)usage); + p_writer.write((uint32_t)textureType); + p_writer.write(imageFormat); + p_writer.write(arrayLength); + p_writer.write(isMultisampled); + } + + template + void deserialize(R &p_reader) { + p_reader.read((uint32_t &)dataType); + p_reader.read(index); + p_reader.read((uint32_t &)access); + p_reader.read((uint32_t &)usage); + p_reader.read((uint32_t &)textureType); + p_reader.read((uint32_t &)imageFormat); + p_reader.read(arrayLength); + p_reader.read(isMultisampled); + } +}; + +using RDC = RenderingDeviceCommons; + +typedef API_AVAILABLE(macos(11.0), ios(14.0)) HashMap BindingInfoMap; + +struct API_AVAILABLE(macos(11.0), ios(14.0)) UniformInfo { + uint32_t binding; + ShaderStageUsage active_stages = None; + BindingInfoMap bindings; + BindingInfoMap bindings_secondary; +}; + +struct API_AVAILABLE(macos(11.0), ios(14.0)) UniformSet { + LocalVector uniforms; + uint32_t buffer_size = 0; + HashMap offsets; + HashMap> encoders; +}; + +class API_AVAILABLE(macos(11.0), ios(14.0)) MDShader { +public: + CharString name; + Vector sets; + + virtual void encode_push_constant_data(VectorView p_data, MDCommandBuffer *p_cb) = 0; + + MDShader(CharString p_name, Vector p_sets) : + name(p_name), sets(p_sets) {} + virtual ~MDShader() = default; +}; + +class API_AVAILABLE(macos(11.0), ios(14.0)) MDComputeShader final : public MDShader { +public: + struct { + uint32_t binding = -1; + uint32_t size = 0; + } push_constants; + MTLSize local = {}; + + id kernel; +#if DEV_ENABLED + CharString kernel_source; +#endif + + void encode_push_constant_data(VectorView p_data, MDCommandBuffer *p_cb) final; + + MDComputeShader(CharString p_name, Vector p_sets, id p_kernel); + ~MDComputeShader() override = default; +}; + +class API_AVAILABLE(macos(11.0), ios(14.0)) MDRenderShader final : public MDShader { +public: + struct { + struct { + int32_t binding = -1; + uint32_t size = 0; + } vert; + struct { + int32_t binding = -1; + uint32_t size = 0; + } frag; + } push_constants; + + id vert; + id frag; +#if DEV_ENABLED + CharString vert_source; + CharString frag_source; +#endif + + void encode_push_constant_data(VectorView p_data, MDCommandBuffer *p_cb) final; + + MDRenderShader(CharString p_name, Vector p_sets, id p_vert, id p_frag); + ~MDRenderShader() override = default; +}; + +enum StageResourceUsage : uint32_t { + VertexRead = (MTLResourceUsageRead << RDD::SHADER_STAGE_VERTEX * 2), + VertexWrite = (MTLResourceUsageWrite << RDD::SHADER_STAGE_VERTEX * 2), + FragmentRead = (MTLResourceUsageRead << RDD::SHADER_STAGE_FRAGMENT * 2), + FragmentWrite = (MTLResourceUsageWrite << RDD::SHADER_STAGE_FRAGMENT * 2), + TesselationControlRead = (MTLResourceUsageRead << RDD::SHADER_STAGE_TESSELATION_CONTROL * 2), + TesselationControlWrite = (MTLResourceUsageWrite << RDD::SHADER_STAGE_TESSELATION_CONTROL * 2), + TesselationEvaluationRead = (MTLResourceUsageRead << RDD::SHADER_STAGE_TESSELATION_EVALUATION * 2), + TesselationEvaluationWrite = (MTLResourceUsageWrite << RDD::SHADER_STAGE_TESSELATION_EVALUATION 
* 2), + ComputeRead = (MTLResourceUsageRead << RDD::SHADER_STAGE_COMPUTE * 2), + ComputeWrite = (MTLResourceUsageWrite << RDD::SHADER_STAGE_COMPUTE * 2), +}; + +_FORCE_INLINE_ StageResourceUsage &operator|=(StageResourceUsage &p_a, uint32_t p_b) { + p_a = StageResourceUsage(uint32_t(p_a) | p_b); + return p_a; +} + +_FORCE_INLINE_ StageResourceUsage stage_resource_usage(RDC::ShaderStage p_stage, MTLResourceUsage p_usage) { + return StageResourceUsage(p_usage << (p_stage * 2)); +} + +_FORCE_INLINE_ MTLResourceUsage resource_usage_for_stage(StageResourceUsage p_usage, RDC::ShaderStage p_stage) { + return MTLResourceUsage((p_usage >> (p_stage * 2)) & 0b11); +} + +template <> +struct HashMapComparatorDefault { + static bool compare(const RDD::ShaderID &p_lhs, const RDD::ShaderID &p_rhs) { + return p_lhs.id == p_rhs.id; + } +}; + +struct BoundUniformSet { + id buffer; + HashMap, StageResourceUsage> bound_resources; +}; + +class API_AVAILABLE(macos(11.0), ios(14.0)) MDUniformSet { +public: + uint32_t index; + LocalVector uniforms; + HashMap bound_uniforms; + + BoundUniformSet &boundUniformSetForShader(MDShader *p_shader, id p_device); +}; + +enum class MDAttachmentType : uint8_t { + None = 0, + Color = 1 << 0, + Depth = 1 << 1, + Stencil = 1 << 2, +}; + +_FORCE_INLINE_ MDAttachmentType &operator|=(MDAttachmentType &p_a, MDAttachmentType p_b) { + flags::set(p_a, p_b); + return p_a; +} + +_FORCE_INLINE_ bool operator&(MDAttachmentType p_a, MDAttachmentType p_b) { + return uint8_t(p_a) & uint8_t(p_b); +} + +struct MDSubpass { + uint32_t subpass_index = 0; + LocalVector input_references; + LocalVector color_references; + RDD::AttachmentReference depth_stencil_reference; + LocalVector resolve_references; + + MTLFmtCaps getRequiredFmtCapsForAttachmentAt(uint32_t p_index) const; +}; + +struct API_AVAILABLE(macos(11.0), ios(14.0)) MDAttachment { +private: + uint32_t index = 0; + uint32_t firstUseSubpassIndex = 0; + uint32_t lastUseSubpassIndex = 0; + +public: + MTLPixelFormat format = MTLPixelFormatInvalid; + MDAttachmentType type = MDAttachmentType::None; + MTLLoadAction loadAction = MTLLoadActionDontCare; + MTLStoreAction storeAction = MTLStoreActionDontCare; + MTLLoadAction stencilLoadAction = MTLLoadActionDontCare; + MTLStoreAction stencilStoreAction = MTLStoreActionDontCare; + uint32_t samples = 1; + + /*! + * @brief Returns true if this attachment is first used in the given subpass. + * @param p_subpass + * @return + */ + _FORCE_INLINE_ bool isFirstUseOf(MDSubpass const &p_subpass) const { + return p_subpass.subpass_index == firstUseSubpassIndex; + } + + /*! + * @brief Returns true if this attachment is last used in the given subpass. + * @param p_subpass + * @return + */ + _FORCE_INLINE_ bool isLastUseOf(MDSubpass const &p_subpass) const { + return p_subpass.subpass_index == lastUseSubpassIndex; + } + + void linkToSubpass(MDRenderPass const &p_pass); + + MTLStoreAction getMTLStoreAction(MDSubpass const &p_subpass, + bool p_is_rendering_entire_area, + bool p_has_resolve, + bool p_can_resolve, + bool p_is_stencil) const; + bool configureDescriptor(MTLRenderPassAttachmentDescriptor *p_desc, + PixelFormats &p_pf, + MDSubpass const &p_subpass, + id p_attachment, + bool p_is_rendering_entire_area, + bool p_has_resolve, + bool p_can_resolve, + bool p_is_stencil) const; + /** Returns whether this attachment should be cleared in the subpass. 
*/ + bool shouldClear(MDSubpass const &p_subpass, bool p_is_stencil) const; +}; + +class API_AVAILABLE(macos(11.0), ios(14.0)) MDRenderPass { +public: + Vector attachments; + Vector subpasses; + + uint32_t get_sample_count() const { + return attachments.is_empty() ? 1 : attachments[0].samples; + } + + MDRenderPass(Vector &p_attachments, Vector &p_subpasses); +}; + +class API_AVAILABLE(macos(11.0), ios(14.0)) MDPipeline { +public: + MDPipelineType type; + + explicit MDPipeline(MDPipelineType p_type) : + type(p_type) {} + virtual ~MDPipeline() = default; +}; + +class API_AVAILABLE(macos(11.0), ios(14.0)) MDRenderPipeline final : public MDPipeline { +public: + id state = nil; + id depth_stencil = nil; + uint32_t push_constant_size = 0; + uint32_t push_constant_stages_mask = 0; + SampleCount sample_count = SampleCount1; + + struct { + MTLCullMode cull_mode = MTLCullModeNone; + MTLTriangleFillMode fill_mode = MTLTriangleFillModeFill; + MTLDepthClipMode clip_mode = MTLDepthClipModeClip; + MTLWinding winding = MTLWindingClockwise; + MTLPrimitiveType render_primitive = MTLPrimitiveTypePoint; + + struct { + bool enabled = false; + } depth_test; + + struct { + bool enabled = false; + float depth_bias = 0.0; + float slope_scale = 0.0; + float clamp = 0.0; + _FORCE_INLINE_ void apply(id __unsafe_unretained p_enc) const { + if (!enabled) { + return; + } + [p_enc setDepthBias:depth_bias slopeScale:slope_scale clamp:clamp]; + } + } depth_bias; + + struct { + bool enabled = false; + uint32_t front_reference = 0; + uint32_t back_reference = 0; + _FORCE_INLINE_ void apply(id __unsafe_unretained p_enc) const { + if (!enabled) + return; + [p_enc setStencilFrontReferenceValue:front_reference backReferenceValue:back_reference]; + }; + } stencil; + + struct { + bool enabled = false; + float r = 0.0; + float g = 0.0; + float b = 0.0; + float a = 0.0; + + _FORCE_INLINE_ void apply(id __unsafe_unretained p_enc) const { + //if (!enabled) + // return; + [p_enc setBlendColorRed:r green:g blue:b alpha:a]; + }; + } blend; + + _FORCE_INLINE_ void apply(id __unsafe_unretained p_enc) const { + [p_enc setCullMode:cull_mode]; + [p_enc setTriangleFillMode:fill_mode]; + [p_enc setDepthClipMode:clip_mode]; + [p_enc setFrontFacingWinding:winding]; + depth_bias.apply(p_enc); + stencil.apply(p_enc); + blend.apply(p_enc); + } + + } raster_state; + + MDRenderShader *shader = nil; + + MDRenderPipeline() : + MDPipeline(MDPipelineType::Render) {} + ~MDRenderPipeline() final = default; +}; + +class API_AVAILABLE(macos(11.0), ios(14.0)) MDComputePipeline final : public MDPipeline { +public: + id state = nil; + struct { + MTLSize local = {}; + } compute_state; + + MDComputeShader *shader = nil; + + explicit MDComputePipeline(id p_state) : + MDPipeline(MDPipelineType::Compute), state(p_state) {} + ~MDComputePipeline() final = default; +}; + +class API_AVAILABLE(macos(11.0), ios(14.0)) MDFrameBuffer { +public: + Vector textures; + Size2i size; + MDFrameBuffer(Vector p_textures, Size2i p_size) : + textures(p_textures), size(p_size) {} + MDFrameBuffer() {} + + virtual ~MDFrameBuffer() = default; +}; + +// These functions are used to convert between Objective-C objects and +// the RIDs used by Godot, respecting automatic reference counting. +namespace rid { + +// Converts an Objective-C object to a pointer, and incrementing the +// reference count. 
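+// Balanced by release() below, which uses __bridge_transfer to hand ownership back to ARC.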
+_FORCE_INLINE_ +void *owned(id p_id) { + return (__bridge_retained void *)p_id; +} + +#define MAKE_ID(FROM, TO) \ + _FORCE_INLINE_ TO make(FROM p_obj) { return TO(owned(p_obj)); } + +MAKE_ID(id, RDD::TextureID) +MAKE_ID(id, RDD::BufferID) +MAKE_ID(id, RDD::SamplerID) +MAKE_ID(MTLVertexDescriptor *, RDD::VertexFormatID) +MAKE_ID(id, RDD::CommandPoolID) + +// Converts a pointer to an Objective-C object without changing the reference count. +_FORCE_INLINE_ +auto get(RDD::ID p_id) { + return (p_id.id) ? (__bridge ::id)(void *)p_id.id : nil; +} + +// Converts a pointer to an Objective-C object, and decrements the reference count. +_FORCE_INLINE_ +auto release(RDD::ID p_id) { + return (__bridge_transfer ::id)(void *)p_id.id; +} + +} // namespace rid + +#endif // METAL_OBJECTS_H diff --git a/drivers/metal/metal_objects.mm b/drivers/metal/metal_objects.mm new file mode 100644 index 00000000000..3ce00f74a30 --- /dev/null +++ b/drivers/metal/metal_objects.mm @@ -0,0 +1,1380 @@ +/**************************************************************************/ +/* metal_objects.mm */ +/**************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/**************************************************************************/ +/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */ +/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/**************************************************************************/ + +/**************************************************************************/ +/* */ +/* Portions of this code were derived from MoltenVK. */ +/* */ +/* Copyright (c) 2015-2023 The Brenwill Workshop Ltd. */ +/* (http://www.brenwill.com) */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or */ +/* implied. See the License for the specific language governing */ +/* permissions and limitations under the License. 
*/ +/**************************************************************************/ + +#import "metal_objects.h" + +#import "pixel_formats.h" +#import "rendering_device_driver_metal.h" + +void MDCommandBuffer::begin() { + DEV_ASSERT(commandBuffer == nil); + commandBuffer = queue.commandBuffer; +} + +void MDCommandBuffer::end() { + switch (type) { + case MDCommandBufferStateType::None: + return; + case MDCommandBufferStateType::Render: + return render_end_pass(); + case MDCommandBufferStateType::Compute: + return _end_compute_dispatch(); + case MDCommandBufferStateType::Blit: + return _end_blit(); + } +} + +void MDCommandBuffer::commit() { + end(); + [commandBuffer commit]; + commandBuffer = nil; +} + +void MDCommandBuffer::bind_pipeline(RDD::PipelineID p_pipeline) { + MDPipeline *p = (MDPipeline *)(p_pipeline.id); + + // End current encoder if it is a compute encoder or blit encoder, + // as they do not have a defined end boundary in the RDD like render. + if (type == MDCommandBufferStateType::Compute) { + _end_compute_dispatch(); + } else if (type == MDCommandBufferStateType::Blit) { + _end_blit(); + } + + if (p->type == MDPipelineType::Render) { + DEV_ASSERT(type == MDCommandBufferStateType::Render); + MDRenderPipeline *rp = (MDRenderPipeline *)p; + + if (render.encoder == nil) { + // This condition occurs when there are no attachments when calling render_next_subpass() + // and is due to the SUPPORTS_FRAGMENT_SHADER_WITH_ONLY_SIDE_EFFECTS flag. + render.desc.defaultRasterSampleCount = static_cast(rp->sample_count); + +// NOTE(sgc): This is to test rdar://FB13605547 and will be deleted once fix is confirmed. +#if 0 + if (render.pipeline->sample_count == 4) { + static id tex = nil; + static id res_tex = nil; + static dispatch_once_t onceToken; + dispatch_once(&onceToken, ^{ + Size2i sz = render.frameBuffer->size; + MTLTextureDescriptor *td = [MTLTextureDescriptor texture2DDescriptorWithPixelFormat:MTLPixelFormatRGBA8Unorm width:sz.width height:sz.height mipmapped:NO]; + td.textureType = MTLTextureType2DMultisample; + td.storageMode = MTLStorageModeMemoryless; + td.usage = MTLTextureUsageRenderTarget; + td.sampleCount = render.pipeline->sample_count; + tex = [device_driver->get_device() newTextureWithDescriptor:td]; + + td.textureType = MTLTextureType2D; + td.storageMode = MTLStorageModePrivate; + td.usage = MTLTextureUsageShaderWrite; + td.sampleCount = 1; + res_tex = [device_driver->get_device() newTextureWithDescriptor:td]; + }); + render.desc.colorAttachments[0].texture = tex; + render.desc.colorAttachments[0].loadAction = MTLLoadActionClear; + render.desc.colorAttachments[0].storeAction = MTLStoreActionMultisampleResolve; + + render.desc.colorAttachments[0].resolveTexture = res_tex; + } +#endif + render.encoder = [commandBuffer renderCommandEncoderWithDescriptor:render.desc]; + } + + if (render.pipeline != rp) { + render.dirty.set_flag((RenderState::DirtyFlag)(RenderState::DIRTY_PIPELINE | RenderState::DIRTY_RASTER)); + // Mark all uniforms as dirty, as variants of a shader pipeline may have a different entry point ABI, + // due to setting force_active_argument_buffer_resources = true for spirv_cross::CompilerMSL::Options. + // As a result, uniform sets with the same layout will generate redundant binding warnings when + // capturing a Metal frame in Xcode. + // + // If we don't mark as dirty, then some bindings will generate a validation error. 
+ render.mark_uniforms_dirty(); + if (render.pipeline != nullptr && render.pipeline->depth_stencil != rp->depth_stencil) { + render.dirty.set_flag(RenderState::DIRTY_DEPTH); + } + render.pipeline = rp; + } + } else if (p->type == MDPipelineType::Compute) { + DEV_ASSERT(type == MDCommandBufferStateType::None); + type = MDCommandBufferStateType::Compute; + + compute.pipeline = (MDComputePipeline *)p; + compute.encoder = commandBuffer.computeCommandEncoder; + [compute.encoder setComputePipelineState:compute.pipeline->state]; + } +} + +id MDCommandBuffer::blit_command_encoder() { + switch (type) { + case MDCommandBufferStateType::None: + break; + case MDCommandBufferStateType::Render: + render_end_pass(); + break; + case MDCommandBufferStateType::Compute: + _end_compute_dispatch(); + break; + case MDCommandBufferStateType::Blit: + return blit.encoder; + } + + type = MDCommandBufferStateType::Blit; + blit.encoder = commandBuffer.blitCommandEncoder; + return blit.encoder; +} + +void MDCommandBuffer::encodeRenderCommandEncoderWithDescriptor(MTLRenderPassDescriptor *p_desc, NSString *p_label) { + switch (type) { + case MDCommandBufferStateType::None: + break; + case MDCommandBufferStateType::Render: + render_end_pass(); + break; + case MDCommandBufferStateType::Compute: + _end_compute_dispatch(); + break; + case MDCommandBufferStateType::Blit: + _end_blit(); + break; + } + + id enc = [commandBuffer renderCommandEncoderWithDescriptor:p_desc]; + if (p_label != nil) { + [enc pushDebugGroup:p_label]; + [enc popDebugGroup]; + } + [enc endEncoding]; +} + +#pragma mark - Render Commands + +void MDCommandBuffer::render_bind_uniform_set(RDD::UniformSetID p_uniform_set, RDD::ShaderID p_shader, uint32_t p_set_index) { + DEV_ASSERT(type == MDCommandBufferStateType::Render); + + MDUniformSet *set = (MDUniformSet *)(p_uniform_set.id); + if (render.uniform_sets.size() <= set->index) { + uint32_t s = render.uniform_sets.size(); + render.uniform_sets.resize(set->index + 1); + // Set intermediate values to null. 
+ std::fill(&render.uniform_sets[s], &render.uniform_sets[set->index] + 1, nullptr); + } + + if (render.uniform_sets[set->index] != set) { + render.dirty.set_flag(RenderState::DIRTY_UNIFORMS); + render.uniform_set_mask |= 1ULL << set->index; + render.uniform_sets[set->index] = set; + } +} + +void MDCommandBuffer::render_clear_attachments(VectorView p_attachment_clears, VectorView p_rects) { + DEV_ASSERT(type == MDCommandBufferStateType::Render); + + uint32_t vertex_count = p_rects.size() * 6; + + simd::float4 vertices[vertex_count]; + simd::float4 clear_colors[ClearAttKey::ATTACHMENT_COUNT]; + + Size2i size = render.frameBuffer->size; + Rect2i render_area = render.clip_to_render_area({ { 0, 0 }, size }); + size = Size2i(render_area.position.x + render_area.size.width, render_area.position.y + render_area.size.height); + _populate_vertices(vertices, size, p_rects); + + ClearAttKey key; + key.sample_count = render.pass->get_sample_count(); + + float depth_value = 0; + uint32_t stencil_value = 0; + + for (uint32_t i = 0; i < p_attachment_clears.size(); i++) { + RDD::AttachmentClear const &attClear = p_attachment_clears[i]; + uint32_t attachment_index; + if (attClear.aspect.has_flag(RDD::TEXTURE_ASPECT_COLOR_BIT)) { + attachment_index = attClear.color_attachment; + } else { + attachment_index = render.pass->subpasses[render.current_subpass].depth_stencil_reference.attachment; + } + + MDAttachment const &mda = render.pass->attachments[attachment_index]; + if (attClear.aspect.has_flag(RDD::TEXTURE_ASPECT_COLOR_BIT)) { + key.set_color_format(attachment_index, mda.format); + clear_colors[attachment_index] = { + attClear.value.color.r, + attClear.value.color.g, + attClear.value.color.b, + attClear.value.color.a + }; + } + + if (attClear.aspect.has_flag(RDD::TEXTURE_ASPECT_DEPTH_BIT)) { + key.set_depth_format(mda.format); + depth_value = attClear.value.depth; + } + + if (attClear.aspect.has_flag(RDD::TEXTURE_ASPECT_STENCIL_BIT)) { + key.set_stencil_format(mda.format); + stencil_value = attClear.value.stencil; + } + } + clear_colors[ClearAttKey::DEPTH_INDEX] = { + depth_value, + depth_value, + depth_value, + depth_value + }; + + id enc = render.encoder; + + MDResourceCache &cache = device_driver->get_resource_cache(); + + [enc pushDebugGroup:@"ClearAttachments"]; + [enc setRenderPipelineState:cache.get_clear_render_pipeline_state(key, nil)]; + [enc setDepthStencilState:cache.get_depth_stencil_state( + key.is_depth_enabled(), + key.is_stencil_enabled())]; + [enc setStencilReferenceValue:stencil_value]; + [enc setCullMode:MTLCullModeNone]; + [enc setTriangleFillMode:MTLTriangleFillModeFill]; + [enc setDepthBias:0 slopeScale:0 clamp:0]; + [enc setViewport:{ 0, 0, (double)size.width, (double)size.height, 0.0, 1.0 }]; + [enc setScissorRect:{ 0, 0, (NSUInteger)size.width, (NSUInteger)size.height }]; + + [enc setVertexBytes:clear_colors length:sizeof(clear_colors) atIndex:0]; + [enc setFragmentBytes:clear_colors length:sizeof(clear_colors) atIndex:0]; + [enc setVertexBytes:vertices length:vertex_count * sizeof(vertices[0]) atIndex:device_driver->get_metal_buffer_index_for_vertex_attribute_binding(VERT_CONTENT_BUFFER_INDEX)]; + + [enc drawPrimitives:MTLPrimitiveTypeTriangle vertexStart:0 vertexCount:vertex_count]; + [enc popDebugGroup]; + + render.dirty.set_flag((RenderState::DirtyFlag)(RenderState::DIRTY_PIPELINE | RenderState::DIRTY_DEPTH | RenderState::DIRTY_RASTER)); + render.mark_uniforms_dirty({ 0 }); // Mark index 0 dirty, if there is already a binding for index 0. 
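+	// The clear draw bound its own payload at buffer index 0 (setVertexBytes / setFragmentBytes above),
+	// which clobbers any uniform-set argument buffer previously bound at that index.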
+ render.mark_viewport_dirty(); + render.mark_scissors_dirty(); + render.mark_vertex_dirty(); +} + +void MDCommandBuffer::_render_set_dirty_state() { + _render_bind_uniform_sets(); + + if (render.dirty.has_flag(RenderState::DIRTY_PIPELINE)) { + [render.encoder setRenderPipelineState:render.pipeline->state]; + } + + if (render.dirty.has_flag(RenderState::DIRTY_VIEWPORT)) { + [render.encoder setViewports:render.viewports.ptr() count:render.viewports.size()]; + } + + if (render.dirty.has_flag(RenderState::DIRTY_DEPTH)) { + [render.encoder setDepthStencilState:render.pipeline->depth_stencil]; + } + + if (render.dirty.has_flag(RenderState::DIRTY_RASTER)) { + render.pipeline->raster_state.apply(render.encoder); + } + + if (render.dirty.has_flag(RenderState::DIRTY_SCISSOR) && !render.scissors.is_empty()) { + size_t len = render.scissors.size(); + MTLScissorRect rects[len]; + for (size_t i = 0; i < len; i++) { + rects[i] = render.clip_to_render_area(render.scissors[i]); + } + [render.encoder setScissorRects:rects count:len]; + } + + if (render.dirty.has_flag(RenderState::DIRTY_BLEND) && render.blend_constants.has_value()) { + [render.encoder setBlendColorRed:render.blend_constants->r green:render.blend_constants->g blue:render.blend_constants->b alpha:render.blend_constants->a]; + } + + if (render.dirty.has_flag(RenderState::DIRTY_VERTEX)) { + uint32_t p_binding_count = render.vertex_buffers.size(); + uint32_t first = device_driver->get_metal_buffer_index_for_vertex_attribute_binding(p_binding_count - 1); + [render.encoder setVertexBuffers:render.vertex_buffers.ptr() + offsets:render.vertex_offsets.ptr() + withRange:NSMakeRange(first, p_binding_count)]; + } + + render.dirty.clear(); +} + +void MDCommandBuffer::render_set_viewport(VectorView p_viewports) { + render.viewports.resize(p_viewports.size()); + for (uint32_t i = 0; i < p_viewports.size(); i += 1) { + Rect2i const &vp = p_viewports[i]; + render.viewports[i] = { + .originX = static_cast(vp.position.x), + .originY = static_cast(vp.position.y), + .width = static_cast(vp.size.width), + .height = static_cast(vp.size.height), + .znear = 0.0, + .zfar = 1.0, + }; + } + + render.dirty.set_flag(RenderState::DIRTY_VIEWPORT); +} + +void MDCommandBuffer::render_set_scissor(VectorView p_scissors) { + render.scissors.resize(p_scissors.size()); + for (uint32_t i = 0; i < p_scissors.size(); i += 1) { + Rect2i const &vp = p_scissors[i]; + render.scissors[i] = { + .x = static_cast(vp.position.x), + .y = static_cast(vp.position.y), + .width = static_cast(vp.size.width), + .height = static_cast(vp.size.height), + }; + } + + render.dirty.set_flag(RenderState::DIRTY_SCISSOR); +} + +void MDCommandBuffer::render_set_blend_constants(const Color &p_constants) { + DEV_ASSERT(type == MDCommandBufferStateType::Render); + if (render.blend_constants != p_constants) { + render.blend_constants = p_constants; + render.dirty.set_flag(RenderState::DIRTY_BLEND); + } +} + +void MDCommandBuffer::_render_bind_uniform_sets() { + DEV_ASSERT(type == MDCommandBufferStateType::Render); + if (!render.dirty.has_flag(RenderState::DIRTY_UNIFORMS)) { + return; + } + + render.dirty.clear_flag(RenderState::DIRTY_UNIFORMS); + uint64_t set_uniforms = render.uniform_set_mask; + render.uniform_set_mask = 0; + + id enc = render.encoder; + MDRenderShader *shader = render.pipeline->shader; + id device = enc.device; + + while (set_uniforms != 0) { + // Find the index of the next set bit. + int index = __builtin_ctzll(set_uniforms); + // Clear the set bit. 
+ set_uniforms &= ~(1ULL << index); + MDUniformSet *set = render.uniform_sets[index]; + if (set == nullptr || set->index >= (uint32_t)shader->sets.size()) { + continue; + } + UniformSet const &set_info = shader->sets[set->index]; + + BoundUniformSet &bus = set->boundUniformSetForShader(shader, device); + + for (KeyValue, StageResourceUsage> const &keyval : bus.bound_resources) { + MTLResourceUsage usage = resource_usage_for_stage(keyval.value, RDD::ShaderStage::SHADER_STAGE_VERTEX); + if (usage != 0) { + [enc useResource:keyval.key usage:usage stages:MTLRenderStageVertex]; + } + usage = resource_usage_for_stage(keyval.value, RDD::ShaderStage::SHADER_STAGE_FRAGMENT); + if (usage != 0) { + [enc useResource:keyval.key usage:usage stages:MTLRenderStageFragment]; + } + } + + // Set the buffer for the vertex stage. + { + uint32_t const *offset = set_info.offsets.getptr(RDD::SHADER_STAGE_VERTEX); + if (offset) { + [enc setVertexBuffer:bus.buffer offset:*offset atIndex:set->index]; + } + } + // Set the buffer for the fragment stage. + { + uint32_t const *offset = set_info.offsets.getptr(RDD::SHADER_STAGE_FRAGMENT); + if (offset) { + [enc setFragmentBuffer:bus.buffer offset:*offset atIndex:set->index]; + } + } + } +} + +void MDCommandBuffer::_populate_vertices(simd::float4 *p_vertices, Size2i p_fb_size, VectorView p_rects) { + uint32_t idx = 0; + for (uint32_t i = 0; i < p_rects.size(); i++) { + Rect2i const &rect = p_rects[i]; + idx = _populate_vertices(p_vertices, idx, rect, p_fb_size); + } +} + +uint32_t MDCommandBuffer::_populate_vertices(simd::float4 *p_vertices, uint32_t p_index, Rect2i const &p_rect, Size2i p_fb_size) { + // Determine the positions of the four edges of the + // clear rectangle as a fraction of the attachment size. + float leftPos = (float)(p_rect.position.x) / (float)p_fb_size.width; + float rightPos = (float)(p_rect.size.width) / (float)p_fb_size.width + leftPos; + float bottomPos = (float)(p_rect.position.y) / (float)p_fb_size.height; + float topPos = (float)(p_rect.size.height) / (float)p_fb_size.height + bottomPos; + + // Transform to clip-space coordinates, which are bounded by (-1.0 < p < 1.0) in clip-space. + leftPos = (leftPos * 2.0f) - 1.0f; + rightPos = (rightPos * 2.0f) - 1.0f; + bottomPos = (bottomPos * 2.0f) - 1.0f; + topPos = (topPos * 2.0f) - 1.0f; + + simd::float4 vtx; + + uint32_t idx = p_index; + vtx.z = 0.0; + vtx.w = (float)1; + + // Top left vertex - First triangle. + vtx.y = topPos; + vtx.x = leftPos; + p_vertices[idx++] = vtx; + + // Bottom left vertex. + vtx.y = bottomPos; + vtx.x = leftPos; + p_vertices[idx++] = vtx; + + // Bottom right vertex. + vtx.y = bottomPos; + vtx.x = rightPos; + p_vertices[idx++] = vtx; + + // Bottom right vertex - Second triangle. + p_vertices[idx++] = vtx; + + // Top right vertex. + vtx.y = topPos; + vtx.x = rightPos; + p_vertices[idx++] = vtx; + + // Top left vertex. 
+ vtx.y = topPos; + vtx.x = leftPos; + p_vertices[idx++] = vtx; + + return idx; +} + +void MDCommandBuffer::render_begin_pass(RDD::RenderPassID p_render_pass, RDD::FramebufferID p_frameBuffer, RDD::CommandBufferType p_cmd_buffer_type, const Rect2i &p_rect, VectorView p_clear_values) { + DEV_ASSERT(commandBuffer != nil); + end(); + + MDRenderPass *pass = (MDRenderPass *)(p_render_pass.id); + MDFrameBuffer *fb = (MDFrameBuffer *)(p_frameBuffer.id); + + type = MDCommandBufferStateType::Render; + render.pass = pass; + render.current_subpass = UINT32_MAX; + render.render_area = p_rect; + render.clear_values.resize(p_clear_values.size()); + for (uint32_t i = 0; i < p_clear_values.size(); i++) { + render.clear_values[i] = p_clear_values[i]; + } + render.is_rendering_entire_area = (p_rect.position == Point2i(0, 0)) && p_rect.size == fb->size; + render.frameBuffer = fb; + render_next_subpass(); +} + +void MDCommandBuffer::_end_render_pass() { + MDFrameBuffer const &fb_info = *render.frameBuffer; + MDRenderPass const &pass_info = *render.pass; + MDSubpass const &subpass = pass_info.subpasses[render.current_subpass]; + + PixelFormats &pf = device_driver->get_pixel_formats(); + + for (uint32_t i = 0; i < subpass.resolve_references.size(); i++) { + uint32_t color_index = subpass.color_references[i].attachment; + uint32_t resolve_index = subpass.resolve_references[i].attachment; + DEV_ASSERT((color_index == RDD::AttachmentReference::UNUSED) == (resolve_index == RDD::AttachmentReference::UNUSED)); + if (color_index == RDD::AttachmentReference::UNUSED || !fb_info.textures[color_index]) { + continue; + } + + id resolve_tex = fb_info.textures[resolve_index]; + + CRASH_COND_MSG(!flags::all(pf.getCapabilities(resolve_tex.pixelFormat), kMTLFmtCapsResolve), "not implemented: unresolvable texture types"); + // see: https://github.com/KhronosGroup/MoltenVK/blob/d20d13fe2735adb845636a81522df1b9d89c0fba/MoltenVK/MoltenVK/GPUObjects/MVKRenderPass.mm#L407 + } + + [render.encoder endEncoding]; + render.encoder = nil; +} + +void MDCommandBuffer::_render_clear_render_area() { + MDRenderPass const &pass = *render.pass; + MDSubpass const &subpass = pass.subpasses[render.current_subpass]; + + // First determine attachments that should be cleared. 
+ LocalVector clears; + clears.reserve(subpass.color_references.size() + /* possible depth stencil clear */ 1); + + for (uint32_t i = 0; i < subpass.color_references.size(); i++) { + uint32_t idx = subpass.color_references[i].attachment; + if (idx != RDD::AttachmentReference::UNUSED && pass.attachments[idx].shouldClear(subpass, false)) { + clears.push_back({ .aspect = RDD::TEXTURE_ASPECT_COLOR_BIT, .color_attachment = idx, .value = render.clear_values[idx] }); + } + } + uint32_t ds_index = subpass.depth_stencil_reference.attachment; + MDAttachment const &attachment = pass.attachments[ds_index]; + bool shouldClearDepth = (ds_index != RDD::AttachmentReference::UNUSED && attachment.shouldClear(subpass, false)); + bool shouldClearStencil = (ds_index != RDD::AttachmentReference::UNUSED && attachment.shouldClear(subpass, true)); + if (shouldClearDepth || shouldClearStencil) { + BitField bits; + if (shouldClearDepth && attachment.type & MDAttachmentType::Depth) { + bits.set_flag(RDD::TEXTURE_ASPECT_DEPTH_BIT); + } + if (shouldClearStencil && attachment.type & MDAttachmentType::Stencil) { + bits.set_flag(RDD::TEXTURE_ASPECT_STENCIL_BIT); + } + + clears.push_back({ .aspect = bits, .color_attachment = ds_index, .value = render.clear_values[ds_index] }); + } + + if (clears.is_empty()) { + return; + } + + render_clear_attachments(clears, { render.render_area }); +} + +void MDCommandBuffer::render_next_subpass() { + DEV_ASSERT(commandBuffer != nil); + + if (render.current_subpass == UINT32_MAX) { + render.current_subpass = 0; + } else { + _end_render_pass(); + render.current_subpass++; + } + + MDFrameBuffer const &fb = *render.frameBuffer; + MDRenderPass const &pass = *render.pass; + MDSubpass const &subpass = pass.subpasses[render.current_subpass]; + + MTLRenderPassDescriptor *desc = MTLRenderPassDescriptor.renderPassDescriptor; + PixelFormats &pf = device_driver->get_pixel_formats(); + + uint32_t attachmentCount = 0; + for (uint32_t i = 0; i < subpass.color_references.size(); i++) { + uint32_t idx = subpass.color_references[i].attachment; + if (idx == RDD::AttachmentReference::UNUSED) { + continue; + } + + attachmentCount += 1; + MTLRenderPassColorAttachmentDescriptor *ca = desc.colorAttachments[i]; + + uint32_t resolveIdx = subpass.resolve_references.is_empty() ? 
RDD::AttachmentReference::UNUSED : subpass.resolve_references[i].attachment; + bool has_resolve = resolveIdx != RDD::AttachmentReference::UNUSED; + bool can_resolve = true; + if (resolveIdx != RDD::AttachmentReference::UNUSED) { + id resolve_tex = fb.textures[resolveIdx]; + can_resolve = flags::all(pf.getCapabilities(resolve_tex.pixelFormat), kMTLFmtCapsResolve); + if (can_resolve) { + ca.resolveTexture = resolve_tex; + } else { + CRASH_NOW_MSG("unimplemented: using a texture format that is not supported for resolve"); + } + } + + MDAttachment const &attachment = pass.attachments[idx]; + + id tex = fb.textures[idx]; + if ((attachment.type & MDAttachmentType::Color)) { + if (attachment.configureDescriptor(ca, pf, subpass, tex, render.is_rendering_entire_area, has_resolve, can_resolve, false)) { + Color clearColor = render.clear_values[idx].color; + ca.clearColor = MTLClearColorMake(clearColor.r, clearColor.g, clearColor.b, clearColor.a); + } + } + } + + if (subpass.depth_stencil_reference.attachment != RDD::AttachmentReference::UNUSED) { + attachmentCount += 1; + uint32_t idx = subpass.depth_stencil_reference.attachment; + MDAttachment const &attachment = pass.attachments[idx]; + id tex = fb.textures[idx]; + if (attachment.type & MDAttachmentType::Depth) { + MTLRenderPassDepthAttachmentDescriptor *da = desc.depthAttachment; + if (attachment.configureDescriptor(da, pf, subpass, tex, render.is_rendering_entire_area, false, false, false)) { + da.clearDepth = render.clear_values[idx].depth; + } + } + + if (attachment.type & MDAttachmentType::Stencil) { + MTLRenderPassStencilAttachmentDescriptor *sa = desc.stencilAttachment; + if (attachment.configureDescriptor(sa, pf, subpass, tex, render.is_rendering_entire_area, false, false, true)) { + sa.clearStencil = render.clear_values[idx].stencil; + } + } + } + + desc.renderTargetWidth = MAX((NSUInteger)MIN(render.render_area.position.x + render.render_area.size.width, fb.size.width), 1u); + desc.renderTargetHeight = MAX((NSUInteger)MIN(render.render_area.position.y + render.render_area.size.height, fb.size.height), 1u); + + if (attachmentCount == 0) { + // If there are no attachments, delay the creation of the encoder, + // so we can use a matching sample count for the pipeline, by setting + // the defaultRasterSampleCount from the pipeline's sample count. + render.desc = desc; + } else { + render.encoder = [commandBuffer renderCommandEncoderWithDescriptor:desc]; + + if (!render.is_rendering_entire_area) { + _render_clear_render_area(); + } + // With a new encoder, all state is dirty. + render.dirty.set_flag(RenderState::DIRTY_ALL); + } +} + +void MDCommandBuffer::render_draw(uint32_t p_vertex_count, + uint32_t p_instance_count, + uint32_t p_base_vertex, + uint32_t p_first_instance) { + DEV_ASSERT(type == MDCommandBufferStateType::Render); + _render_set_dirty_state(); + + DEV_ASSERT(render.dirty == 0); + + id enc = render.encoder; + + [enc drawPrimitives:render.pipeline->raster_state.render_primitive + vertexStart:p_base_vertex + vertexCount:p_vertex_count + instanceCount:p_instance_count + baseInstance:p_first_instance]; +} + +void MDCommandBuffer::render_bind_vertex_buffers(uint32_t p_binding_count, const RDD::BufferID *p_buffers, const uint64_t *p_offsets) { + DEV_ASSERT(type == MDCommandBufferStateType::Render); + + render.vertex_buffers.resize(p_binding_count); + render.vertex_offsets.resize(p_binding_count); + + // Reverse the buffers, as their bindings are assigned in descending order. 
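+	// Vertex attribute buffers are assigned from the high end of Metal's buffer argument table
+	// (see get_metal_buffer_index_for_vertex_attribute_binding), leaving the low indices free for
+	// uniform-set argument buffers, which are bound at their set index.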
+ for (uint32_t i = 0; i < p_binding_count; i += 1) { + render.vertex_buffers[i] = rid::get(p_buffers[p_binding_count - i - 1]); + render.vertex_offsets[i] = p_offsets[p_binding_count - i - 1]; + } + + if (render.encoder) { + uint32_t first = device_driver->get_metal_buffer_index_for_vertex_attribute_binding(p_binding_count - 1); + [render.encoder setVertexBuffers:render.vertex_buffers.ptr() + offsets:render.vertex_offsets.ptr() + withRange:NSMakeRange(first, p_binding_count)]; + } else { + render.dirty.set_flag(RenderState::DIRTY_VERTEX); + } +} + +void MDCommandBuffer::render_bind_index_buffer(RDD::BufferID p_buffer, RDD::IndexBufferFormat p_format, uint64_t p_offset) { + DEV_ASSERT(type == MDCommandBufferStateType::Render); + + render.index_buffer = rid::get(p_buffer); + render.index_type = p_format == RDD::IndexBufferFormat::INDEX_BUFFER_FORMAT_UINT16 ? MTLIndexTypeUInt16 : MTLIndexTypeUInt32; +} + +void MDCommandBuffer::render_draw_indexed(uint32_t p_index_count, + uint32_t p_instance_count, + uint32_t p_first_index, + int32_t p_vertex_offset, + uint32_t p_first_instance) { + DEV_ASSERT(type == MDCommandBufferStateType::Render); + _render_set_dirty_state(); + + id enc = render.encoder; + + [enc drawIndexedPrimitives:render.pipeline->raster_state.render_primitive + indexCount:p_index_count + indexType:render.index_type + indexBuffer:render.index_buffer + indexBufferOffset:p_vertex_offset + instanceCount:p_instance_count + baseVertex:p_first_index + baseInstance:p_first_instance]; +} + +void MDCommandBuffer::render_draw_indexed_indirect(RDD::BufferID p_indirect_buffer, uint64_t p_offset, uint32_t p_draw_count, uint32_t p_stride) { + DEV_ASSERT(type == MDCommandBufferStateType::Render); + _render_set_dirty_state(); + + id enc = render.encoder; + + id indirect_buffer = rid::get(p_indirect_buffer); + NSUInteger indirect_offset = p_offset; + + for (uint32_t i = 0; i < p_draw_count; i++) { + [enc drawIndexedPrimitives:render.pipeline->raster_state.render_primitive + indexType:render.index_type + indexBuffer:render.index_buffer + indexBufferOffset:0 + indirectBuffer:indirect_buffer + indirectBufferOffset:indirect_offset]; + indirect_offset += p_stride; + } +} + +void MDCommandBuffer::render_draw_indexed_indirect_count(RDD::BufferID p_indirect_buffer, uint64_t p_offset, RDD::BufferID p_count_buffer, uint64_t p_count_buffer_offset, uint32_t p_max_draw_count, uint32_t p_stride) { + ERR_FAIL_MSG("not implemented"); +} + +void MDCommandBuffer::render_draw_indirect(RDD::BufferID p_indirect_buffer, uint64_t p_offset, uint32_t p_draw_count, uint32_t p_stride) { + DEV_ASSERT(type == MDCommandBufferStateType::Render); + _render_set_dirty_state(); + + id enc = render.encoder; + + id indirect_buffer = rid::get(p_indirect_buffer); + NSUInteger indirect_offset = p_offset; + + for (uint32_t i = 0; i < p_draw_count; i++) { + [enc drawPrimitives:render.pipeline->raster_state.render_primitive + indirectBuffer:indirect_buffer + indirectBufferOffset:indirect_offset]; + indirect_offset += p_stride; + } +} + +void MDCommandBuffer::render_draw_indirect_count(RDD::BufferID p_indirect_buffer, uint64_t p_offset, RDD::BufferID p_count_buffer, uint64_t p_count_buffer_offset, uint32_t p_max_draw_count, uint32_t p_stride) { + ERR_FAIL_MSG("not implemented"); +} + +void MDCommandBuffer::render_end_pass() { + DEV_ASSERT(type == MDCommandBufferStateType::Render); + + [render.encoder endEncoding]; + render.reset(); + type = MDCommandBufferStateType::None; +} + +#pragma mark - Compute + +void 
MDCommandBuffer::compute_bind_uniform_set(RDD::UniformSetID p_uniform_set, RDD::ShaderID p_shader, uint32_t p_set_index) { + DEV_ASSERT(type == MDCommandBufferStateType::Compute); + + id enc = compute.encoder; + id device = enc.device; + + MDShader *shader = (MDShader *)(p_shader.id); + UniformSet const &set_info = shader->sets[p_set_index]; + + MDUniformSet *set = (MDUniformSet *)(p_uniform_set.id); + BoundUniformSet &bus = set->boundUniformSetForShader(shader, device); + + for (KeyValue, StageResourceUsage> &keyval : bus.bound_resources) { + MTLResourceUsage usage = resource_usage_for_stage(keyval.value, RDD::ShaderStage::SHADER_STAGE_COMPUTE); + if (usage != 0) { + [enc useResource:keyval.key usage:usage]; + } + } + + uint32_t const *offset = set_info.offsets.getptr(RDD::SHADER_STAGE_COMPUTE); + if (offset) { + [enc setBuffer:bus.buffer offset:*offset atIndex:p_set_index]; + } +} + +void MDCommandBuffer::compute_dispatch(uint32_t p_x_groups, uint32_t p_y_groups, uint32_t p_z_groups) { + DEV_ASSERT(type == MDCommandBufferStateType::Compute); + + MTLRegion region = MTLRegionMake3D(0, 0, 0, p_x_groups, p_y_groups, p_z_groups); + + id enc = compute.encoder; + [enc dispatchThreadgroups:region.size threadsPerThreadgroup:compute.pipeline->compute_state.local]; +} + +void MDCommandBuffer::compute_dispatch_indirect(RDD::BufferID p_indirect_buffer, uint64_t p_offset) { + DEV_ASSERT(type == MDCommandBufferStateType::Compute); + + id indirectBuffer = rid::get(p_indirect_buffer); + + id enc = compute.encoder; + [enc dispatchThreadgroupsWithIndirectBuffer:indirectBuffer indirectBufferOffset:p_offset threadsPerThreadgroup:compute.pipeline->compute_state.local]; +} + +void MDCommandBuffer::_end_compute_dispatch() { + DEV_ASSERT(type == MDCommandBufferStateType::Compute); + + [compute.encoder endEncoding]; + compute.reset(); + type = MDCommandBufferStateType::None; +} + +void MDCommandBuffer::_end_blit() { + DEV_ASSERT(type == MDCommandBufferStateType::Blit); + + [blit.encoder endEncoding]; + blit.reset(); + type = MDCommandBufferStateType::None; +} + +MDComputeShader::MDComputeShader(CharString p_name, Vector p_sets, id p_kernel) : + MDShader(p_name, p_sets), kernel(p_kernel) { +} + +void MDComputeShader::encode_push_constant_data(VectorView p_data, MDCommandBuffer *p_cb) { + DEV_ASSERT(p_cb->type == MDCommandBufferStateType::Compute); + if (push_constants.binding == (uint32_t)-1) { + return; + } + + id enc = p_cb->compute.encoder; + + void const *ptr = p_data.ptr(); + size_t length = p_data.size() * sizeof(uint32_t); + + [enc setBytes:ptr length:length atIndex:push_constants.binding]; +} + +MDRenderShader::MDRenderShader(CharString p_name, Vector p_sets, id _Nonnull p_vert, id _Nonnull p_frag) : + MDShader(p_name, p_sets), vert(p_vert), frag(p_frag) { +} + +void MDRenderShader::encode_push_constant_data(VectorView p_data, MDCommandBuffer *p_cb) { + DEV_ASSERT(p_cb->type == MDCommandBufferStateType::Render); + id enc = p_cb->render.encoder; + + void const *ptr = p_data.ptr(); + size_t length = p_data.size() * sizeof(uint32_t); + + if (push_constants.vert.binding > -1) { + [enc setVertexBytes:ptr length:length atIndex:push_constants.vert.binding]; + } + + if (push_constants.frag.binding > -1) { + [enc setFragmentBytes:ptr length:length atIndex:push_constants.frag.binding]; + } +} + +BoundUniformSet &MDUniformSet::boundUniformSetForShader(MDShader *p_shader, id p_device) { + BoundUniformSet *sus = bound_uniforms.getptr(p_shader); + if (sus != nullptr) { + return *sus; + } + + UniformSet const &set = 
p_shader->sets[index]; + + HashMap, StageResourceUsage> bound_resources; + auto add_usage = [&bound_resources](id __unsafe_unretained res, RDD::ShaderStage stage, MTLResourceUsage usage) { + StageResourceUsage *sru = bound_resources.getptr(res); + if (sru == nullptr) { + bound_resources.insert(res, stage_resource_usage(stage, usage)); + } else { + *sru |= stage_resource_usage(stage, usage); + } + }; + id enc_buffer = nil; + if (set.buffer_size > 0) { + MTLResourceOptions options = MTLResourceStorageModeShared | MTLResourceHazardTrackingModeTracked; + enc_buffer = [p_device newBufferWithLength:set.buffer_size options:options]; + for (KeyValue> const &kv : set.encoders) { + RDD::ShaderStage const stage = kv.key; + ShaderStageUsage const stage_usage = ShaderStageUsage(1 << stage); + id const enc = kv.value; + + [enc setArgumentBuffer:enc_buffer offset:set.offsets[stage]]; + + for (uint32_t i = 0; i < uniforms.size(); i++) { + RDD::BoundUniform const &uniform = uniforms[i]; + UniformInfo ui = set.uniforms[i]; + + BindingInfo *bi = ui.bindings.getptr(stage); + if (bi == nullptr) { + // No binding for this stage. + continue; + } + + if ((ui.active_stages & stage_usage) == 0) { + // Not active for this state, so don't bind anything. + continue; + } + + switch (uniform.type) { + case RDD::UNIFORM_TYPE_SAMPLER: { + size_t count = uniform.ids.size(); + id __unsafe_unretained *objects = ALLOCA_ARRAY(id __unsafe_unretained, count); + for (size_t j = 0; j < count; j += 1) { + objects[j] = rid::get(uniform.ids[j].id); + } + [enc setSamplerStates:objects withRange:NSMakeRange(bi->index, count)]; + } break; + case RDD::UNIFORM_TYPE_SAMPLER_WITH_TEXTURE: { + size_t count = uniform.ids.size() / 2; + id __unsafe_unretained *textures = ALLOCA_ARRAY(id __unsafe_unretained, count); + id __unsafe_unretained *samplers = ALLOCA_ARRAY(id __unsafe_unretained, count); + for (uint32_t j = 0; j < count; j += 1) { + id sampler = rid::get(uniform.ids[j * 2 + 0]); + id texture = rid::get(uniform.ids[j * 2 + 1]); + samplers[j] = sampler; + textures[j] = texture; + add_usage(texture, stage, bi->usage); + } + BindingInfo *sbi = ui.bindings_secondary.getptr(stage); + if (sbi) { + [enc setSamplerStates:samplers withRange:NSMakeRange(sbi->index, count)]; + } + [enc setTextures:textures + withRange:NSMakeRange(bi->index, count)]; + } break; + case RDD::UNIFORM_TYPE_TEXTURE: { + size_t count = uniform.ids.size(); + if (count == 1) { + id obj = rid::get(uniform.ids[0]); + [enc setTexture:obj atIndex:bi->index]; + add_usage(obj, stage, bi->usage); + } else { + id __unsafe_unretained *objects = ALLOCA_ARRAY(id __unsafe_unretained, count); + for (size_t j = 0; j < count; j += 1) { + id obj = rid::get(uniform.ids[j]); + objects[j] = obj; + add_usage(obj, stage, bi->usage); + } + [enc setTextures:objects withRange:NSMakeRange(bi->index, count)]; + } + } break; + case RDD::UNIFORM_TYPE_IMAGE: { + size_t count = uniform.ids.size(); + if (count == 1) { + id obj = rid::get(uniform.ids[0]); + [enc setTexture:obj atIndex:bi->index]; + add_usage(obj, stage, bi->usage); + BindingInfo *sbi = ui.bindings_secondary.getptr(stage); + if (sbi) { + id tex = obj.parentTexture ? 
obj.parentTexture : obj; + id buf = tex.buffer; + if (buf) { + [enc setBuffer:buf offset:tex.bufferOffset atIndex:sbi->index]; + } + } + } else { + id __unsafe_unretained *objects = ALLOCA_ARRAY(id __unsafe_unretained, count); + for (size_t j = 0; j < count; j += 1) { + id obj = rid::get(uniform.ids[j]); + objects[j] = obj; + add_usage(obj, stage, bi->usage); + } + [enc setTextures:objects withRange:NSMakeRange(bi->index, count)]; + } + } break; + case RDD::UNIFORM_TYPE_TEXTURE_BUFFER: { + ERR_PRINT("not implemented: UNIFORM_TYPE_TEXTURE_BUFFER"); + } break; + case RDD::UNIFORM_TYPE_SAMPLER_WITH_TEXTURE_BUFFER: { + ERR_PRINT("not implemented: UNIFORM_TYPE_SAMPLER_WITH_TEXTURE_BUFFER"); + } break; + case RDD::UNIFORM_TYPE_IMAGE_BUFFER: { + CRASH_NOW_MSG("not implemented: UNIFORM_TYPE_IMAGE_BUFFER"); + } break; + case RDD::UNIFORM_TYPE_UNIFORM_BUFFER: { + id buffer = rid::get(uniform.ids[0]); + [enc setBuffer:buffer offset:0 atIndex:bi->index]; + add_usage(buffer, stage, bi->usage); + } break; + case RDD::UNIFORM_TYPE_STORAGE_BUFFER: { + id buffer = rid::get(uniform.ids[0]); + [enc setBuffer:buffer offset:0 atIndex:bi->index]; + add_usage(buffer, stage, bi->usage); + } break; + case RDD::UNIFORM_TYPE_INPUT_ATTACHMENT: { + size_t count = uniform.ids.size(); + if (count == 1) { + id obj = rid::get(uniform.ids[0]); + [enc setTexture:obj atIndex:bi->index]; + add_usage(obj, stage, bi->usage); + } else { + id __unsafe_unretained *objects = ALLOCA_ARRAY(id __unsafe_unretained, count); + for (size_t j = 0; j < count; j += 1) { + id obj = rid::get(uniform.ids[j]); + objects[j] = obj; + add_usage(obj, stage, bi->usage); + } + [enc setTextures:objects withRange:NSMakeRange(bi->index, count)]; + } + } break; + default: { + DEV_ASSERT(false); + } + } + } + } + } + + BoundUniformSet bs = { .buffer = enc_buffer, .bound_resources = bound_resources }; + bound_uniforms.insert(p_shader, bs); + return bound_uniforms.get(p_shader); +} + +MTLFmtCaps MDSubpass::getRequiredFmtCapsForAttachmentAt(uint32_t p_index) const { + MTLFmtCaps caps = kMTLFmtCapsNone; + + for (RDD::AttachmentReference const &ar : input_references) { + if (ar.attachment == p_index) { + flags::set(caps, kMTLFmtCapsRead); + break; + } + } + + for (RDD::AttachmentReference const &ar : color_references) { + if (ar.attachment == p_index) { + flags::set(caps, kMTLFmtCapsColorAtt); + break; + } + } + + for (RDD::AttachmentReference const &ar : resolve_references) { + if (ar.attachment == p_index) { + flags::set(caps, kMTLFmtCapsResolve); + break; + } + } + + if (depth_stencil_reference.attachment == p_index) { + flags::set(caps, kMTLFmtCapsDSAtt); + } + + return caps; +} + +void MDAttachment::linkToSubpass(const MDRenderPass &p_pass) { + firstUseSubpassIndex = UINT32_MAX; + lastUseSubpassIndex = 0; + + for (MDSubpass const &subpass : p_pass.subpasses) { + MTLFmtCaps reqCaps = subpass.getRequiredFmtCapsForAttachmentAt(index); + if (reqCaps) { + firstUseSubpassIndex = MIN(subpass.subpass_index, firstUseSubpassIndex); + lastUseSubpassIndex = MAX(subpass.subpass_index, lastUseSubpassIndex); + } + } +} + +MTLStoreAction MDAttachment::getMTLStoreAction(MDSubpass const &p_subpass, + bool p_is_rendering_entire_area, + bool p_has_resolve, + bool p_can_resolve, + bool p_is_stencil) const { + if (!p_is_rendering_entire_area || !isLastUseOf(p_subpass)) { + return p_has_resolve && p_can_resolve ? MTLStoreActionStoreAndMultisampleResolve : MTLStoreActionStore; + } + + switch (p_is_stencil ? 
stencilStoreAction : storeAction) { + case MTLStoreActionStore: + return p_has_resolve && p_can_resolve ? MTLStoreActionStoreAndMultisampleResolve : MTLStoreActionStore; + case MTLStoreActionDontCare: + return p_has_resolve ? (p_can_resolve ? MTLStoreActionMultisampleResolve : MTLStoreActionStore) : MTLStoreActionDontCare; + + default: + return MTLStoreActionStore; + } +} + +bool MDAttachment::configureDescriptor(MTLRenderPassAttachmentDescriptor *p_desc, + PixelFormats &p_pf, + MDSubpass const &p_subpass, + id p_attachment, + bool p_is_rendering_entire_area, + bool p_has_resolve, + bool p_can_resolve, + bool p_is_stencil) const { + p_desc.texture = p_attachment; + + MTLLoadAction load; + if (!p_is_rendering_entire_area || !isFirstUseOf(p_subpass)) { + load = MTLLoadActionLoad; + } else { + load = p_is_stencil ? stencilLoadAction : loadAction; + } + + p_desc.loadAction = load; + + MTLPixelFormat mtlFmt = p_attachment.pixelFormat; + bool isDepthFormat = p_pf.isDepthFormat(mtlFmt); + bool isStencilFormat = p_pf.isStencilFormat(mtlFmt); + if (isStencilFormat && !p_is_stencil && !isDepthFormat) { + p_desc.storeAction = MTLStoreActionDontCare; + } else { + p_desc.storeAction = getMTLStoreAction(p_subpass, p_is_rendering_entire_area, p_has_resolve, p_can_resolve, p_is_stencil); + } + + return load == MTLLoadActionClear; +} + +bool MDAttachment::shouldClear(const MDSubpass &p_subpass, bool p_is_stencil) const { + // If the subpass is not the first subpass to use this attachment, don't clear this attachment. + if (p_subpass.subpass_index != firstUseSubpassIndex) { + return false; + } + return (p_is_stencil ? stencilLoadAction : loadAction) == MTLLoadActionClear; +} + +MDRenderPass::MDRenderPass(Vector &p_attachments, Vector &p_subpasses) : + attachments(p_attachments), subpasses(p_subpasses) { + for (MDAttachment &att : attachments) { + att.linkToSubpass(*this); + } +} + +#pragma mark - Resource Factory + +id MDResourceFactory::new_func(NSString *p_source, NSString *p_name, NSError **p_error) { + @autoreleasepool { + NSError *err = nil; + MTLCompileOptions *options = [MTLCompileOptions new]; + id device = device_driver->get_device(); + id mtlLib = [device newLibraryWithSource:p_source + options:options + error:&err]; + if (err) { + if (p_error != nil) { + *p_error = err; + } + } + return [mtlLib newFunctionWithName:p_name]; + } +} + +id MDResourceFactory::new_clear_vert_func(ClearAttKey &p_key) { + @autoreleasepool { + NSString *msl = [NSString stringWithFormat:@R"( +#include +using namespace metal; + +typedef struct { + float4 a_position [[attribute(0)]]; +} AttributesPos; + +typedef struct { + float4 colors[9]; +} ClearColorsIn; + +typedef struct { + float4 v_position [[position]]; + uint layer; +} VaryingsPos; + +vertex VaryingsPos vertClear(AttributesPos attributes [[stage_in]], constant ClearColorsIn& ccIn [[buffer(0)]]) { + VaryingsPos varyings; + varyings.v_position = float4(attributes.a_position.x, -attributes.a_position.y, ccIn.colors[%d].r, 1.0); + varyings.layer = uint(attributes.a_position.w); + return varyings; +} +)", + ClearAttKey::DEPTH_INDEX]; + + return new_func(msl, @"vertClear", nil); + } +} + +id MDResourceFactory::new_clear_frag_func(ClearAttKey &p_key) { + @autoreleasepool { + NSMutableString *msl = [NSMutableString stringWithCapacity:2048]; + + [msl appendFormat:@R"( +#include +using namespace metal; + +typedef struct { + float4 v_position [[position]]; +} VaryingsPos; + +typedef struct { + float4 colors[9]; +} ClearColorsIn; + +typedef struct { +)"]; + + for (uint32_t 
caIdx = 0; caIdx < ClearAttKey::COLOR_COUNT; caIdx++) { + if (p_key.is_enabled(caIdx)) { + NSString *typeStr = get_format_type_string((MTLPixelFormat)p_key.pixel_formats[caIdx]); + [msl appendFormat:@" %@4 color%u [[color(%u)]];\n", typeStr, caIdx, caIdx]; + } + } + [msl appendFormat:@R"(} ClearColorsOut; + +fragment ClearColorsOut fragClear(VaryingsPos varyings [[stage_in]], constant ClearColorsIn& ccIn [[buffer(0)]]) { + + ClearColorsOut ccOut; +)"]; + for (uint32_t caIdx = 0; caIdx < ClearAttKey::COLOR_COUNT; caIdx++) { + if (p_key.is_enabled(caIdx)) { + NSString *typeStr = get_format_type_string((MTLPixelFormat)p_key.pixel_formats[caIdx]); + [msl appendFormat:@" ccOut.color%u = %@4(ccIn.colors[%u]);\n", caIdx, typeStr, caIdx]; + } + } + [msl appendString:@R"( return ccOut; +})"]; + + return new_func(msl, @"fragClear", nil); + } +} + +NSString *MDResourceFactory::get_format_type_string(MTLPixelFormat p_fmt) { + switch (device_driver->get_pixel_formats().getFormatType(p_fmt)) { + case MTLFormatType::ColorInt8: + case MTLFormatType::ColorInt16: + return @"short"; + case MTLFormatType::ColorUInt8: + case MTLFormatType::ColorUInt16: + return @"ushort"; + case MTLFormatType::ColorInt32: + return @"int"; + case MTLFormatType::ColorUInt32: + return @"uint"; + case MTLFormatType::ColorHalf: + return @"half"; + case MTLFormatType::ColorFloat: + case MTLFormatType::DepthStencil: + case MTLFormatType::Compressed: + return @"float"; + case MTLFormatType::None: + return @"unexpected_MTLPixelFormatInvalid"; + } +} + +id MDResourceFactory::new_depth_stencil_state(bool p_use_depth, bool p_use_stencil) { + MTLDepthStencilDescriptor *dsDesc = [MTLDepthStencilDescriptor new]; + dsDesc.depthCompareFunction = MTLCompareFunctionAlways; + dsDesc.depthWriteEnabled = p_use_depth; + + if (p_use_stencil) { + MTLStencilDescriptor *sDesc = [MTLStencilDescriptor new]; + sDesc.stencilCompareFunction = MTLCompareFunctionAlways; + sDesc.stencilFailureOperation = MTLStencilOperationReplace; + sDesc.depthFailureOperation = MTLStencilOperationReplace; + sDesc.depthStencilPassOperation = MTLStencilOperationReplace; + + dsDesc.frontFaceStencil = sDesc; + dsDesc.backFaceStencil = sDesc; + } else { + dsDesc.frontFaceStencil = nil; + dsDesc.backFaceStencil = nil; + } + + return [device_driver->get_device() newDepthStencilStateWithDescriptor:dsDesc]; +} + +id MDResourceFactory::new_clear_pipeline_state(ClearAttKey &p_key, NSError **p_error) { + PixelFormats &pixFmts = device_driver->get_pixel_formats(); + + id vtxFunc = new_clear_vert_func(p_key); + id fragFunc = new_clear_frag_func(p_key); + MTLRenderPipelineDescriptor *plDesc = [MTLRenderPipelineDescriptor new]; + plDesc.label = @"ClearRenderAttachments"; + plDesc.vertexFunction = vtxFunc; + plDesc.fragmentFunction = fragFunc; + plDesc.rasterSampleCount = p_key.sample_count; + plDesc.inputPrimitiveTopology = MTLPrimitiveTopologyClassTriangle; + + for (uint32_t caIdx = 0; caIdx < ClearAttKey::COLOR_COUNT; caIdx++) { + MTLRenderPipelineColorAttachmentDescriptor *colorDesc = plDesc.colorAttachments[caIdx]; + colorDesc.pixelFormat = (MTLPixelFormat)p_key.pixel_formats[caIdx]; + colorDesc.writeMask = p_key.is_enabled(caIdx) ? 
MTLColorWriteMaskAll : MTLColorWriteMaskNone; + } + + MTLPixelFormat mtlDepthFormat = p_key.depth_format(); + if (pixFmts.isDepthFormat(mtlDepthFormat)) { + plDesc.depthAttachmentPixelFormat = mtlDepthFormat; + } + + MTLPixelFormat mtlStencilFormat = p_key.stencil_format(); + if (pixFmts.isStencilFormat(mtlStencilFormat)) { + plDesc.stencilAttachmentPixelFormat = mtlStencilFormat; + } + + MTLVertexDescriptor *vtxDesc = plDesc.vertexDescriptor; + + // Vertex attribute descriptors. + MTLVertexAttributeDescriptorArray *vaDescArray = vtxDesc.attributes; + MTLVertexAttributeDescriptor *vaDesc; + NSUInteger vtxBuffIdx = device_driver->get_metal_buffer_index_for_vertex_attribute_binding(VERT_CONTENT_BUFFER_INDEX); + NSUInteger vtxStride = 0; + + // Vertex location. + vaDesc = vaDescArray[0]; + vaDesc.format = MTLVertexFormatFloat4; + vaDesc.bufferIndex = vtxBuffIdx; + vaDesc.offset = vtxStride; + vtxStride += sizeof(simd::float4); + + // Vertex attribute buffer. + MTLVertexBufferLayoutDescriptorArray *vbDescArray = vtxDesc.layouts; + MTLVertexBufferLayoutDescriptor *vbDesc = vbDescArray[vtxBuffIdx]; + vbDesc.stepFunction = MTLVertexStepFunctionPerVertex; + vbDesc.stepRate = 1; + vbDesc.stride = vtxStride; + + return [device_driver->get_device() newRenderPipelineStateWithDescriptor:plDesc error:p_error]; +} + +id MDResourceCache::get_clear_render_pipeline_state(ClearAttKey &p_key, NSError **p_error) { + HashMap::ConstIterator it = clear_states.find(p_key); + if (it != clear_states.end()) { + return it->value; + } + + id state = resource_factory->new_clear_pipeline_state(p_key, p_error); + clear_states[p_key] = state; + return state; +} + +id MDResourceCache::get_depth_stencil_state(bool p_use_depth, bool p_use_stencil) { + id __strong *val; + if (p_use_depth && p_use_stencil) { + val = &clear_depth_stencil_state.all; + } else if (p_use_depth) { + val = &clear_depth_stencil_state.depth_only; + } else if (p_use_stencil) { + val = &clear_depth_stencil_state.stencil_only; + } else { + val = &clear_depth_stencil_state.none; + } + DEV_ASSERT(val != nullptr); + + if (*val == nil) { + *val = resource_factory->new_depth_stencil_state(p_use_depth, p_use_stencil); + } + return *val; +} diff --git a/drivers/metal/metal_utils.h b/drivers/metal/metal_utils.h new file mode 100644 index 00000000000..eed1aad89bc --- /dev/null +++ b/drivers/metal/metal_utils.h @@ -0,0 +1,81 @@ +/**************************************************************************/ +/* metal_utils.h */ +/**************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/**************************************************************************/ +/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */ +/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. 
*/
+/* */
+/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
+/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
+/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
+/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
+/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
+/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
+/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
+/**************************************************************************/
+
+#ifndef METAL_UTILS_H
+#define METAL_UTILS_H
+
+#pragma mark - Boolean flags
+
+namespace flags {
+
+/*! Sets the flags within the value parameter specified by the mask parameter. */
+template <typename Tv, typename Tm>
+void set(Tv &p_value, Tm p_mask) {
+	using T = std::underlying_type_t<Tv>;
+	p_value = static_cast<Tv>(static_cast<T>(p_value) | static_cast<T>(p_mask));
+}
+
+/*! Clears the flags within the value parameter specified by the mask parameter. */
+template <typename Tv, typename Tm>
+void clear(Tv &p_value, Tm p_mask) {
+	using T = std::underlying_type_t<Tv>;
+	p_value = static_cast<Tv>(static_cast<T>(p_value) & ~static_cast<T>(p_mask));
+}
+
+/*! Returns whether the specified value has any of the bits specified in mask set to 1. */
+template <typename Tv, typename Tm>
+static constexpr bool any(Tv p_value, const Tm p_mask) { return ((p_value & p_mask) != 0); }
+
+/*! Returns whether the specified value has all of the bits specified in mask set to 1. */
+template <typename Tv, typename Tm>
+static constexpr bool all(Tv p_value, const Tm p_mask) { return ((p_value & p_mask) == p_mask); }
+
+} //namespace flags
+
+#pragma mark - Alignment and Offsets
+
+static constexpr bool is_power_of_two(uint64_t p_value) {
+	return p_value && ((p_value & (p_value - 1)) == 0);
+}
+
+static constexpr uint64_t round_up_to_alignment(uint64_t p_value, uint64_t p_alignment) {
+	DEV_ASSERT(is_power_of_two(p_alignment));
+
+	if (p_alignment == 0) {
+		return p_value;
+	}
+
+	uint64_t mask = p_alignment - 1;
+	uint64_t aligned_value = (p_value + mask) & ~mask;
+
+	return aligned_value;
+}
+
+#endif // METAL_UTILS_H
diff --git a/drivers/metal/pixel_formats.h b/drivers/metal/pixel_formats.h
new file mode 100644
index 00000000000..167c3d56009
--- /dev/null
+++ b/drivers/metal/pixel_formats.h
@@ -0,0 +1,416 @@
+/**************************************************************************/
+/* pixel_formats.h */
+/**************************************************************************/
+/* This file is part of: */
+/* GODOT ENGINE */
+/* https://godotengine.org */
+/**************************************************************************/
+/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
+/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
+/* */
+/* Permission is hereby granted, free of charge, to any person obtaining */
+/* a copy of this software and associated documentation files (the */
+/* "Software"), to deal in the Software without restriction, including */
+/* without limitation the rights to use, copy, modify, merge, publish, */
+/* distribute, sublicense, and/or sell copies of the Software, and to */
+/* permit persons to whom the Software is furnished to do so, subject to */
+/* the following conditions: */
+/* */
+/* The above copyright notice and this permission notice shall be */
+/* included in all copies or substantial portions of the Software.
*/ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/**************************************************************************/ + +/**************************************************************************/ +/* */ +/* Portions of this code were derived from MoltenVK. */ +/* */ +/* Copyright (c) 2015-2023 The Brenwill Workshop Ltd. */ +/* (http://www.brenwill.com) */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or */ +/* implied. See the License for the specific language governing */ +/* permissions and limitations under the License. */ +/**************************************************************************/ + +#ifndef PIXEL_FORMATS_H +#define PIXEL_FORMATS_H + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + +#import "servers/rendering/rendering_device.h" + +#import + +static const uint32_t _mtlPixelFormatCount = 256; +static const uint32_t _mtlPixelFormatCoreCount = MTLPixelFormatX32_Stencil8 + 2; // The actual last enum value is not available on iOS. +static const uint32_t _mtlVertexFormatCount = MTLVertexFormatHalf + 1; + +#pragma mark - +#pragma mark Metal format capabilities + +typedef enum : uint16_t { + + kMTLFmtCapsNone = 0, + /*! The format can be used in a shader read operation. */ + kMTLFmtCapsRead = (1 << 0), + /*! The format can be used in a shader filter operation during sampling. */ + kMTLFmtCapsFilter = (1 << 1), + /*! The format can be used in a shader write operation. */ + kMTLFmtCapsWrite = (1 << 2), + /*! The format can be used with atomic operations. */ + kMTLFmtCapsAtomic = (1 << 3), + /*! The format can be used as a color attachment. */ + kMTLFmtCapsColorAtt = (1 << 4), + /*! The format can be used as a depth-stencil attachment. */ + kMTLFmtCapsDSAtt = (1 << 5), + /*! The format can be used with blend operations. */ + kMTLFmtCapsBlend = (1 << 6), + /*! The format can be used as a destination for multisample antialias (MSAA) data. */ + kMTLFmtCapsMSAA = (1 << 7), + /*! The format can be used as a resolve attachment. 
*/ + kMTLFmtCapsResolve = (1 << 8), + kMTLFmtCapsVertex = (1 << 9), + + kMTLFmtCapsRF = (kMTLFmtCapsRead | kMTLFmtCapsFilter), + kMTLFmtCapsRC = (kMTLFmtCapsRead | kMTLFmtCapsColorAtt), + kMTLFmtCapsRCB = (kMTLFmtCapsRC | kMTLFmtCapsBlend), + kMTLFmtCapsRCM = (kMTLFmtCapsRC | kMTLFmtCapsMSAA), + kMTLFmtCapsRCMB = (kMTLFmtCapsRCM | kMTLFmtCapsBlend), + kMTLFmtCapsRWC = (kMTLFmtCapsRC | kMTLFmtCapsWrite), + kMTLFmtCapsRWCB = (kMTLFmtCapsRWC | kMTLFmtCapsBlend), + kMTLFmtCapsRWCM = (kMTLFmtCapsRWC | kMTLFmtCapsMSAA), + kMTLFmtCapsRWCMB = (kMTLFmtCapsRWCM | kMTLFmtCapsBlend), + kMTLFmtCapsRFCMRB = (kMTLFmtCapsRCMB | kMTLFmtCapsFilter | kMTLFmtCapsResolve), + kMTLFmtCapsRFWCMB = (kMTLFmtCapsRWCMB | kMTLFmtCapsFilter), + kMTLFmtCapsAll = (kMTLFmtCapsRFWCMB | kMTLFmtCapsResolve), + + kMTLFmtCapsDRM = (kMTLFmtCapsDSAtt | kMTLFmtCapsRead | kMTLFmtCapsMSAA), + kMTLFmtCapsDRFM = (kMTLFmtCapsDRM | kMTLFmtCapsFilter), + kMTLFmtCapsDRMR = (kMTLFmtCapsDRM | kMTLFmtCapsResolve), + kMTLFmtCapsDRFMR = (kMTLFmtCapsDRFM | kMTLFmtCapsResolve), + + kMTLFmtCapsChromaSubsampling = kMTLFmtCapsRF, + kMTLFmtCapsMultiPlanar = kMTLFmtCapsChromaSubsampling, +} MTLFmtCaps; + +inline MTLFmtCaps operator|(MTLFmtCaps p_left, MTLFmtCaps p_right) { + return static_cast(static_cast(p_left) | p_right); +} + +inline MTLFmtCaps &operator|=(MTLFmtCaps &p_left, MTLFmtCaps p_right) { + return (p_left = p_left | p_right); +} + +#pragma mark - +#pragma mark Metal view classes + +enum class MTLViewClass : uint8_t { + None, + Color8, + Color16, + Color32, + Color64, + Color128, + PVRTC_RGB_2BPP, + PVRTC_RGB_4BPP, + PVRTC_RGBA_2BPP, + PVRTC_RGBA_4BPP, + EAC_R11, + EAC_RG11, + EAC_RGBA8, + ETC2_RGB8, + ETC2_RGB8A1, + ASTC_4x4, + ASTC_5x4, + ASTC_5x5, + ASTC_6x5, + ASTC_6x6, + ASTC_8x5, + ASTC_8x6, + ASTC_8x8, + ASTC_10x5, + ASTC_10x6, + ASTC_10x8, + ASTC_10x10, + ASTC_12x10, + ASTC_12x12, + BC1_RGBA, + BC2_RGBA, + BC3_RGBA, + BC4_R, + BC5_RG, + BC6H_RGB, + BC7_RGBA, + Depth24_Stencil8, + Depth32_Stencil8, + BGRA10_XR, + BGR10_XR +}; + +#pragma mark - +#pragma mark Format descriptors + +/** Enumerates the data type of a format. */ +enum class MTLFormatType { + None, /**< Format type is unknown. */ + ColorHalf, /**< A 16-bit floating point color. */ + ColorFloat, /**< A 32-bit floating point color. */ + ColorInt8, /**< A signed 8-bit integer color. */ + ColorUInt8, /**< An unsigned 8-bit integer color. */ + ColorInt16, /**< A signed 16-bit integer color. */ + ColorUInt16, /**< An unsigned 16-bit integer color. */ + ColorInt32, /**< A signed 32-bit integer color. */ + ColorUInt32, /**< An unsigned 32-bit integer color. */ + DepthStencil, /**< A depth and stencil value. */ + Compressed, /**< A block-compressed color. */ +}; + +typedef struct Extent2D { + uint32_t width; + uint32_t height; +} Extent2D; + +/** Describes the properties of a DataFormat, including the corresponding Metal pixel and vertex format. 
*/ +typedef struct DataFormatDesc { + RD::DataFormat dataFormat; + MTLPixelFormat mtlPixelFormat; + MTLPixelFormat mtlPixelFormatSubstitute; + MTLVertexFormat mtlVertexFormat; + MTLVertexFormat mtlVertexFormatSubstitute; + uint8_t chromaSubsamplingPlaneCount; + uint8_t chromaSubsamplingComponentBits; + Extent2D blockTexelSize; + uint32_t bytesPerBlock; + MTLFormatType formatType; + const char *name; + bool hasReportedSubstitution; + + inline double bytesPerTexel() const { return (double)bytesPerBlock / (double)(blockTexelSize.width * blockTexelSize.height); } + + inline bool isSupported() const { return (mtlPixelFormat != MTLPixelFormatInvalid || chromaSubsamplingPlaneCount > 1); } + inline bool isSupportedOrSubstitutable() const { return isSupported() || (mtlPixelFormatSubstitute != MTLPixelFormatInvalid); } + + inline bool vertexIsSupported() const { return (mtlVertexFormat != MTLVertexFormatInvalid); } + inline bool vertexIsSupportedOrSubstitutable() const { return vertexIsSupported() || (mtlVertexFormatSubstitute != MTLVertexFormatInvalid); } +} DataFormatDesc; + +/** Describes the properties of a MTLPixelFormat or MTLVertexFormat. */ +typedef struct MTLFormatDesc { + union { + MTLPixelFormat mtlPixelFormat; + MTLVertexFormat mtlVertexFormat; + }; + RD::DataFormat dataFormat; + MTLFmtCaps mtlFmtCaps; + MTLViewClass mtlViewClass; + MTLPixelFormat mtlPixelFormatLinear; + const char *name = nullptr; + + inline bool isSupported() const { return (mtlPixelFormat != MTLPixelFormatInvalid) && (mtlFmtCaps != kMTLFmtCapsNone); } +} MTLFormatDesc; + +class API_AVAILABLE(macos(11.0), ios(14.0)) PixelFormats { + using DataFormat = RD::DataFormat; + +public: + /** Returns whether the DataFormat is supported by the GPU bound to this instance. */ + bool isSupported(DataFormat p_format); + + /** Returns whether the DataFormat is supported by this implementation, or can be substituted by one that is. */ + bool isSupportedOrSubstitutable(DataFormat p_format); + + /** Returns whether the specified Metal MTLPixelFormat can be used as a depth format. */ + _FORCE_INLINE_ bool isDepthFormat(MTLPixelFormat p_format) { + switch (p_format) { + case MTLPixelFormatDepth32Float: + case MTLPixelFormatDepth16Unorm: + case MTLPixelFormatDepth32Float_Stencil8: +#if TARGET_OS_OSX + case MTLPixelFormatDepth24Unorm_Stencil8: +#endif + return true; + default: + return false; + } + } + + /** Returns whether the specified Metal MTLPixelFormat can be used as a stencil format. */ + _FORCE_INLINE_ bool isStencilFormat(MTLPixelFormat p_format) { + switch (p_format) { + case MTLPixelFormatStencil8: +#if TARGET_OS_OSX + case MTLPixelFormatDepth24Unorm_Stencil8: + case MTLPixelFormatX24_Stencil8: +#endif + case MTLPixelFormatDepth32Float_Stencil8: + case MTLPixelFormatX32_Stencil8: + return true; + default: + return false; + } + } + + /** Returns whether the specified Metal MTLPixelFormat is a PVRTC format. */ + bool isPVRTCFormat(MTLPixelFormat p_format); + + /** Returns the format type corresponding to the specified Godot pixel format, */ + MTLFormatType getFormatType(DataFormat p_format); + + /** Returns the format type corresponding to the specified Metal MTLPixelFormat, */ + MTLFormatType getFormatType(MTLPixelFormat p_formt); + + /** + * Returns the Metal MTLPixelFormat corresponding to the specified Godot pixel + * or returns MTLPixelFormatInvalid if no corresponding MTLPixelFormat exists. 
+ */ + MTLPixelFormat getMTLPixelFormat(DataFormat p_format); + + /** + * Returns the DataFormat corresponding to the specified Metal MTLPixelFormat, + * or returns DATA_FORMAT_MAX if no corresponding DataFormat exists. + */ + DataFormat getDataFormat(MTLPixelFormat p_format); + + /** + * Returns the size, in bytes, of a texel block of the specified Godot pixel. + * For uncompressed formats, the returned value corresponds to the size in bytes of a single texel. + */ + uint32_t getBytesPerBlock(DataFormat p_format); + + /** + * Returns the size, in bytes, of a texel block of the specified Metal format. + * For uncompressed formats, the returned value corresponds to the size in bytes of a single texel. + */ + uint32_t getBytesPerBlock(MTLPixelFormat p_format); + + /** Returns the number of planes of the specified chroma-subsampling (YCbCr) DataFormat */ + uint8_t getChromaSubsamplingPlaneCount(DataFormat p_format); + + /** Returns the number of bits per channel of the specified chroma-subsampling (YCbCr) DataFormat */ + uint8_t getChromaSubsamplingComponentBits(DataFormat p_format); + + /** + * Returns the size, in bytes, of a texel of the specified Godot format. + * The returned value may be fractional for certain compressed formats. + */ + float getBytesPerTexel(DataFormat p_format); + + /** + * Returns the size, in bytes, of a texel of the specified Metal format. + * The returned value may be fractional for certain compressed formats. + */ + float getBytesPerTexel(MTLPixelFormat p_format); + + /** + * Returns the size, in bytes, of a row of texels of the specified Godot pixel format. + * + * For compressed formats, this takes into consideration the compression block size, + * and p_texels_per_row should specify the width in texels, not blocks. The result is rounded + * up if p_texels_per_row is not an integer multiple of the compression block width. + */ + size_t getBytesPerRow(DataFormat p_format, uint32_t p_texels_per_row); + + /** + * Returns the size, in bytes, of a row of texels of the specified Metal format. + * + * For compressed formats, this takes into consideration the compression block size, + * and texelsPerRow should specify the width in texels, not blocks. The result is rounded + * up if texelsPerRow is not an integer multiple of the compression block width. + */ + size_t getBytesPerRow(MTLPixelFormat p_format, uint32_t p_texels_per_row); + + /** + * Returns the size, in bytes, of a texture layer of the specified Godot pixel format. + * + * For compressed formats, this takes into consideration the compression block size, + * and p_texel_rows_per_layer should specify the height in texels, not blocks. The result is + * rounded up if p_texel_rows_per_layer is not an integer multiple of the compression block height. + */ + size_t getBytesPerLayer(DataFormat p_format, size_t p_bytes_per_row, uint32_t p_texel_rows_per_layer); + + /** + * Returns the size, in bytes, of a texture layer of the specified Metal format. + * For compressed formats, this takes into consideration the compression block size, + * and p_texel_rows_per_layer should specify the height in texels, not blocks. The result is + * rounded up if p_texel_rows_per_layer is not an integer multiple of the compression block height. + */ + size_t getBytesPerLayer(MTLPixelFormat p_format, size_t p_bytes_per_row, uint32_t p_texel_rows_per_layer); + + /** Returns the Metal format capabilities supported by the specified Godot format, without substitution. 
*/ + MTLFmtCaps getCapabilities(DataFormat p_format, bool p_extended = false); + + /** Returns the Metal format capabilities supported by the specified Metal format. */ + MTLFmtCaps getCapabilities(MTLPixelFormat p_format, bool p_extended = false); + + /** + * Returns the Metal MTLVertexFormat corresponding to the specified + * DataFormat as used as a vertex attribute format. + */ + MTLVertexFormat getMTLVertexFormat(DataFormat p_format); + +#pragma mark Construction + + explicit PixelFormats(id p_device); + +protected: + id device; + + DataFormatDesc &getDataFormatDesc(DataFormat p_format); + DataFormatDesc &getDataFormatDesc(MTLPixelFormat p_format); + MTLFormatDesc &getMTLPixelFormatDesc(MTLPixelFormat p_format); + MTLFormatDesc &getMTLVertexFormatDesc(MTLVertexFormat p_format); + void initDataFormatCapabilities(); + void initMTLPixelFormatCapabilities(); + void initMTLVertexFormatCapabilities(); + void buildMTLFormatMaps(); + void buildDFFormatMaps(); + void modifyMTLFormatCapabilities(); + void modifyMTLFormatCapabilities(id p_device); + void addMTLPixelFormatCapabilities(id p_device, + MTLFeatureSet p_feature_set, + MTLPixelFormat p_format, + MTLFmtCaps p_caps); + void addMTLPixelFormatCapabilities(id p_device, + MTLGPUFamily p_family, + MTLPixelFormat p_format, + MTLFmtCaps p_caps); + void disableMTLPixelFormatCapabilities(MTLPixelFormat p_format, + MTLFmtCaps p_caps); + void disableAllMTLPixelFormatCapabilities(MTLPixelFormat p_format); + void addMTLVertexFormatCapabilities(id p_device, + MTLFeatureSet p_feature_set, + MTLVertexFormat p_format, + MTLFmtCaps p_caps); + + DataFormatDesc _dataFormatDescriptions[RD::DATA_FORMAT_MAX]; + MTLFormatDesc _mtlPixelFormatDescriptions[_mtlPixelFormatCount]; + MTLFormatDesc _mtlVertexFormatDescriptions[_mtlVertexFormatCount]; + + // Most Metal formats have small values and are mapped by simple lookup array. + // Outliers are mapped by a map. + uint16_t _mtlFormatDescIndicesByMTLPixelFormatsCore[_mtlPixelFormatCoreCount]; + HashMap _mtlFormatDescIndicesByMTLPixelFormatsExt; + + uint16_t _mtlFormatDescIndicesByMTLVertexFormats[_mtlVertexFormatCount]; +}; + +#pragma clang diagnostic pop + +#endif // PIXEL_FORMATS_H diff --git a/drivers/metal/pixel_formats.mm b/drivers/metal/pixel_formats.mm new file mode 100644 index 00000000000..ac737b3f0a0 --- /dev/null +++ b/drivers/metal/pixel_formats.mm @@ -0,0 +1,1298 @@ +/**************************************************************************/ +/* pixel_formats.mm */ +/**************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/**************************************************************************/ +/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */ +/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. 
*/ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/**************************************************************************/ + +/**************************************************************************/ +/* */ +/* Portions of this code were derived from MoltenVK. */ +/* */ +/* Copyright (c) 2015-2023 The Brenwill Workshop Ltd. */ +/* (http://www.brenwill.com) */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or */ +/* implied. See the License for the specific language governing */ +/* permissions and limitations under the License. */ +/**************************************************************************/ + +#import "pixel_formats.h" + +#import "metal_utils.h" + +#if TARGET_OS_IPHONE || TARGET_OS_TV +#if !(__IPHONE_OS_VERSION_MAX_ALLOWED >= 160400) // iOS/tvOS 16.4 +#define MTLPixelFormatBC1_RGBA MTLPixelFormatInvalid +#define MTLPixelFormatBC1_RGBA_sRGB MTLPixelFormatInvalid +#define MTLPixelFormatBC2_RGBA MTLPixelFormatInvalid +#define MTLPixelFormatBC2_RGBA_sRGB MTLPixelFormatInvalid +#define MTLPixelFormatBC3_RGBA MTLPixelFormatInvalid +#define MTLPixelFormatBC3_RGBA_sRGB MTLPixelFormatInvalid +#define MTLPixelFormatBC4_RUnorm MTLPixelFormatInvalid +#define MTLPixelFormatBC4_RSnorm MTLPixelFormatInvalid +#define MTLPixelFormatBC5_RGUnorm MTLPixelFormatInvalid +#define MTLPixelFormatBC5_RGSnorm MTLPixelFormatInvalid +#define MTLPixelFormatBC6H_RGBUfloat MTLPixelFormatInvalid +#define MTLPixelFormatBC6H_RGBFloat MTLPixelFormatInvalid +#define MTLPixelFormatBC7_RGBAUnorm MTLPixelFormatInvalid +#define MTLPixelFormatBC7_RGBAUnorm_sRGB MTLPixelFormatInvalid +#endif + +#define MTLPixelFormatDepth16Unorm_Stencil8 MTLPixelFormatDepth32Float_Stencil8 +#define MTLPixelFormatDepth24Unorm_Stencil8 MTLPixelFormatInvalid +#define MTLPixelFormatX24_Stencil8 MTLPixelFormatInvalid +#endif + +#if TARGET_OS_TV +#define MTLPixelFormatASTC_4x4_HDR MTLPixelFormatInvalid +#define MTLPixelFormatASTC_5x4_HDR MTLPixelFormatInvalid +#define MTLPixelFormatASTC_5x5_HDR MTLPixelFormatInvalid +#define MTLPixelFormatASTC_6x5_HDR MTLPixelFormatInvalid +#define MTLPixelFormatASTC_6x6_HDR MTLPixelFormatInvalid +#define MTLPixelFormatASTC_8x5_HDR MTLPixelFormatInvalid +#define MTLPixelFormatASTC_8x6_HDR MTLPixelFormatInvalid +#define MTLPixelFormatASTC_8x8_HDR MTLPixelFormatInvalid +#define MTLPixelFormatASTC_10x5_HDR MTLPixelFormatInvalid +#define MTLPixelFormatASTC_10x6_HDR MTLPixelFormatInvalid +#define MTLPixelFormatASTC_10x8_HDR MTLPixelFormatInvalid +#define MTLPixelFormatASTC_10x10_HDR MTLPixelFormatInvalid +#define MTLPixelFormatASTC_12x10_HDR MTLPixelFormatInvalid +#define MTLPixelFormatASTC_12x12_HDR MTLPixelFormatInvalid 
+#endif + +#if !((__MAC_OS_X_VERSION_MAX_ALLOWED >= 140000) || (__IPHONE_OS_VERSION_MAX_ALLOWED >= 170000)) // Xcode 15 +#define MTLVertexFormatFloatRG11B10 MTLVertexFormatInvalid +#define MTLVertexFormatFloatRGB9E5 MTLVertexFormatInvalid +#endif + +/** Selects and returns one of the values, based on the platform OS. */ +_FORCE_INLINE_ constexpr MTLFmtCaps select_platform_caps(MTLFmtCaps p_macOS_val, MTLFmtCaps p_iOS_val) { +#if (TARGET_OS_IOS || TARGET_OS_TV) && !TARGET_OS_MACCATALYST + return p_iOS_val; +#elif TARGET_OS_OSX + return p_macOS_val; +#else +#error "unsupported platform" +#endif +} + +template <typename T> +void clear(T *p_val, size_t p_count = 1) { + memset(p_val, 0, sizeof(T) * p_count); +} + +#pragma mark - +#pragma mark PixelFormats + +bool PixelFormats::isSupported(DataFormat p_format) { + return getDataFormatDesc(p_format).isSupported(); +} + +bool PixelFormats::isSupportedOrSubstitutable(DataFormat p_format) { + return getDataFormatDesc(p_format).isSupportedOrSubstitutable(); +} + +bool PixelFormats::isPVRTCFormat(MTLPixelFormat p_format) { + switch (p_format) { + case MTLPixelFormatPVRTC_RGBA_2BPP: + case MTLPixelFormatPVRTC_RGBA_2BPP_sRGB: + case MTLPixelFormatPVRTC_RGBA_4BPP: + case MTLPixelFormatPVRTC_RGBA_4BPP_sRGB: + case MTLPixelFormatPVRTC_RGB_2BPP: + case MTLPixelFormatPVRTC_RGB_2BPP_sRGB: + case MTLPixelFormatPVRTC_RGB_4BPP: + case MTLPixelFormatPVRTC_RGB_4BPP_sRGB: + return true; + default: + return false; + } +} + +MTLFormatType PixelFormats::getFormatType(DataFormat p_format) { + return getDataFormatDesc(p_format).formatType; +} + +MTLFormatType PixelFormats::getFormatType(MTLPixelFormat p_format) { + return getDataFormatDesc(p_format).formatType; +} + +MTLPixelFormat PixelFormats::getMTLPixelFormat(DataFormat p_format) { + DataFormatDesc &dfDesc = getDataFormatDesc(p_format); + MTLPixelFormat mtlPixFmt = dfDesc.mtlPixelFormat; + + // If the MTLPixelFormat is not supported but DataFormat is valid, + // attempt to substitute a different format.
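+ // Multi-planar (chroma-subsampled) formats are never substituted.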
+ if (mtlPixFmt == MTLPixelFormatInvalid && p_format != RD::DATA_FORMAT_MAX && dfDesc.chromaSubsamplingPlaneCount <= 1) { + mtlPixFmt = dfDesc.mtlPixelFormatSubstitute; + } + + return mtlPixFmt; +} + +RD::DataFormat PixelFormats::getDataFormat(MTLPixelFormat p_format) { + return getMTLPixelFormatDesc(p_format).dataFormat; +} + +uint32_t PixelFormats::getBytesPerBlock(DataFormat p_format) { + return getDataFormatDesc(p_format).bytesPerBlock; +} + +uint32_t PixelFormats::getBytesPerBlock(MTLPixelFormat p_format) { + return getDataFormatDesc(p_format).bytesPerBlock; +} + +uint8_t PixelFormats::getChromaSubsamplingPlaneCount(DataFormat p_format) { + return getDataFormatDesc(p_format).chromaSubsamplingPlaneCount; +} + +uint8_t PixelFormats::getChromaSubsamplingComponentBits(DataFormat p_format) { + return getDataFormatDesc(p_format).chromaSubsamplingComponentBits; +} + +float PixelFormats::getBytesPerTexel(DataFormat p_format) { + return getDataFormatDesc(p_format).bytesPerTexel(); +} + +float PixelFormats::getBytesPerTexel(MTLPixelFormat p_format) { + return getDataFormatDesc(p_format).bytesPerTexel(); +} + +size_t PixelFormats::getBytesPerRow(DataFormat p_format, uint32_t p_texels_per_row) { + DataFormatDesc &dfDesc = getDataFormatDesc(p_format); + return Math::division_round_up(p_texels_per_row, dfDesc.blockTexelSize.width) * dfDesc.bytesPerBlock; +} + +size_t PixelFormats::getBytesPerRow(MTLPixelFormat p_format, uint32_t p_texels_per_row) { + DataFormatDesc &dfDesc = getDataFormatDesc(p_format); + return Math::division_round_up(p_texels_per_row, dfDesc.blockTexelSize.width) * dfDesc.bytesPerBlock; +} + +size_t PixelFormats::getBytesPerLayer(DataFormat p_format, size_t p_bytes_per_row, uint32_t p_texel_rows_per_layer) { + return Math::division_round_up(p_texel_rows_per_layer, getDataFormatDesc(p_format).blockTexelSize.height) * p_bytes_per_row; +} + +size_t PixelFormats::getBytesPerLayer(MTLPixelFormat p_format, size_t p_bytes_per_row, uint32_t p_texel_rows_per_layer) { + return Math::division_round_up(p_texel_rows_per_layer, getDataFormatDesc(p_format).blockTexelSize.height) * p_bytes_per_row; +} + +MTLFmtCaps PixelFormats::getCapabilities(DataFormat p_format, bool p_extended) { + return getCapabilities(getDataFormatDesc(p_format).mtlPixelFormat, p_extended); +} + +MTLFmtCaps PixelFormats::getCapabilities(MTLPixelFormat p_format, bool p_extended) { + MTLFormatDesc &mtlDesc = getMTLPixelFormatDesc(p_format); + MTLFmtCaps caps = mtlDesc.mtlFmtCaps; + if (!p_extended || mtlDesc.mtlViewClass == MTLViewClass::None) { + return caps; + } + // Now get caps of all formats in the view class. 
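+ // The extended capabilities are the union of the capabilities of every pixel format that shares this view class.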
+ for (MTLFormatDesc &otherDesc : _mtlPixelFormatDescriptions) { + if (otherDesc.mtlViewClass == mtlDesc.mtlViewClass) { + caps |= otherDesc.mtlFmtCaps; + } + } + return caps; +} + +MTLVertexFormat PixelFormats::getMTLVertexFormat(DataFormat p_format) { + DataFormatDesc &dfDesc = getDataFormatDesc(p_format); + MTLVertexFormat format = dfDesc.mtlVertexFormat; + + if (format == MTLVertexFormatInvalid) { + String errMsg; + errMsg += "DataFormat "; + errMsg += dfDesc.name; + errMsg += " is not supported for vertex buffers on this device."; + + if (dfDesc.vertexIsSupportedOrSubstitutable()) { + format = dfDesc.mtlVertexFormatSubstitute; + + DataFormatDesc &dfDescSubs = getDataFormatDesc(getMTLVertexFormatDesc(format).dataFormat); + errMsg += " Using DataFormat "; + errMsg += dfDescSubs.name; + errMsg += " instead."; + } + WARN_PRINT(errMsg); + } + + return format; +} + +DataFormatDesc &PixelFormats::getDataFormatDesc(DataFormat p_format) { + CRASH_BAD_INDEX_MSG(p_format, RD::DATA_FORMAT_MAX, "Attempting to describe an invalid DataFormat"); + return _dataFormatDescriptions[p_format]; +} + +DataFormatDesc &PixelFormats::getDataFormatDesc(MTLPixelFormat p_format) { + return getDataFormatDesc(getMTLPixelFormatDesc(p_format).dataFormat); +} + +// Return a reference to the Metal format descriptor corresponding to the MTLPixelFormat. +MTLFormatDesc &PixelFormats::getMTLPixelFormatDesc(MTLPixelFormat p_format) { + uint16_t fmtIdx = ((p_format < _mtlPixelFormatCoreCount) + ? _mtlFormatDescIndicesByMTLPixelFormatsCore[p_format] + : _mtlFormatDescIndicesByMTLPixelFormatsExt[p_format]); + return _mtlPixelFormatDescriptions[fmtIdx]; +} + +// Return a reference to the Metal format descriptor corresponding to the MTLVertexFormat. +MTLFormatDesc &PixelFormats::getMTLVertexFormatDesc(MTLVertexFormat p_format) { + uint16_t fmtIdx = (p_format < _mtlVertexFormatCount) ? 
_mtlFormatDescIndicesByMTLVertexFormats[p_format] : 0; + return _mtlVertexFormatDescriptions[fmtIdx]; +} + +PixelFormats::PixelFormats(id p_device) : + device(p_device) { + initMTLPixelFormatCapabilities(); + initMTLVertexFormatCapabilities(); + buildMTLFormatMaps(); + modifyMTLFormatCapabilities(); + + initDataFormatCapabilities(); + buildDFFormatMaps(); +} + +#define addDfFormatDescFull(DATA_FMT, MTL_FMT, MTL_FMT_ALT, MTL_VTX_FMT, MTL_VTX_FMT_ALT, CSPC, CSCB, BLK_W, BLK_H, BLK_BYTE_CNT, MVK_FMT_TYPE) \ + CRASH_BAD_INDEX_MSG(RD::DATA_FORMAT_##DATA_FMT, RD::DATA_FORMAT_MAX, "Attempting to describe too many DataFormats"); \ + _dataFormatDescriptions[RD::DATA_FORMAT_##DATA_FMT] = { RD::DATA_FORMAT_##DATA_FMT, MTLPixelFormat##MTL_FMT, MTLPixelFormat##MTL_FMT_ALT, MTLVertexFormat##MTL_VTX_FMT, MTLVertexFormat##MTL_VTX_FMT_ALT, \ + CSPC, CSCB, { BLK_W, BLK_H }, BLK_BYTE_CNT, MTLFormatType::MVK_FMT_TYPE, "DATA_FORMAT_" #DATA_FMT, false } + +#define addDataFormatDesc(DATA_FMT, MTL_FMT, MTL_FMT_ALT, MTL_VTX_FMT, MTL_VTX_FMT_ALT, BLK_W, BLK_H, BLK_BYTE_CNT, MVK_FMT_TYPE) \ + addDfFormatDescFull(DATA_FMT, MTL_FMT, MTL_FMT_ALT, MTL_VTX_FMT, MTL_VTX_FMT_ALT, 0, 0, BLK_W, BLK_H, BLK_BYTE_CNT, MVK_FMT_TYPE) + +#define addDfFormatDescChromaSubsampling(DATA_FMT, MTL_FMT, CSPC, CSCB, BLK_W, BLK_H, BLK_BYTE_CNT) \ + addDfFormatDescFull(DATA_FMT, MTL_FMT, Invalid, Invalid, Invalid, CSPC, CSCB, BLK_W, BLK_H, BLK_BYTE_CNT, ColorFloat) + +void PixelFormats::initDataFormatCapabilities() { + clear(_dataFormatDescriptions, RD::DATA_FORMAT_MAX); + + addDataFormatDesc(R4G4_UNORM_PACK8, Invalid, Invalid, Invalid, Invalid, 1, 1, 1, ColorFloat); + addDataFormatDesc(R4G4B4A4_UNORM_PACK16, ABGR4Unorm, Invalid, Invalid, Invalid, 1, 1, 2, ColorFloat); + addDataFormatDesc(B4G4R4A4_UNORM_PACK16, Invalid, Invalid, Invalid, Invalid, 1, 1, 2, ColorFloat); + + addDataFormatDesc(R5G6B5_UNORM_PACK16, B5G6R5Unorm, Invalid, Invalid, Invalid, 1, 1, 2, ColorFloat); + addDataFormatDesc(B5G6R5_UNORM_PACK16, Invalid, Invalid, Invalid, Invalid, 1, 1, 2, ColorFloat); + addDataFormatDesc(R5G5B5A1_UNORM_PACK16, A1BGR5Unorm, Invalid, Invalid, Invalid, 1, 1, 2, ColorFloat); + addDataFormatDesc(B5G5R5A1_UNORM_PACK16, Invalid, Invalid, Invalid, Invalid, 1, 1, 2, ColorFloat); + addDataFormatDesc(A1R5G5B5_UNORM_PACK16, BGR5A1Unorm, Invalid, Invalid, Invalid, 1, 1, 2, ColorFloat); + + addDataFormatDesc(R8_UNORM, R8Unorm, Invalid, UCharNormalized, UChar2Normalized, 1, 1, 1, ColorFloat); + addDataFormatDesc(R8_SNORM, R8Snorm, Invalid, CharNormalized, Char2Normalized, 1, 1, 1, ColorFloat); + addDataFormatDesc(R8_USCALED, Invalid, Invalid, UChar, UChar2, 1, 1, 1, ColorFloat); + addDataFormatDesc(R8_SSCALED, Invalid, Invalid, Char, Char2, 1, 1, 1, ColorFloat); + addDataFormatDesc(R8_UINT, R8Uint, Invalid, UChar, UChar2, 1, 1, 1, ColorUInt8); + addDataFormatDesc(R8_SINT, R8Sint, Invalid, Char, Char2, 1, 1, 1, ColorInt8); + addDataFormatDesc(R8_SRGB, R8Unorm_sRGB, Invalid, UCharNormalized, UChar2Normalized, 1, 1, 1, ColorFloat); + + addDataFormatDesc(R8G8_UNORM, RG8Unorm, Invalid, UChar2Normalized, Invalid, 1, 1, 2, ColorFloat); + addDataFormatDesc(R8G8_SNORM, RG8Snorm, Invalid, Char2Normalized, Invalid, 1, 1, 2, ColorFloat); + addDataFormatDesc(R8G8_USCALED, Invalid, Invalid, UChar2, Invalid, 1, 1, 2, ColorFloat); + addDataFormatDesc(R8G8_SSCALED, Invalid, Invalid, Char2, Invalid, 1, 1, 2, ColorFloat); + addDataFormatDesc(R8G8_UINT, RG8Uint, Invalid, UChar2, Invalid, 1, 1, 2, ColorUInt8); + addDataFormatDesc(R8G8_SINT, RG8Sint, Invalid, Char2, Invalid, 
1, 1, 2, ColorInt8); + addDataFormatDesc(R8G8_SRGB, RG8Unorm_sRGB, Invalid, UChar2Normalized, Invalid, 1, 1, 2, ColorFloat); + + addDataFormatDesc(R8G8B8_UNORM, Invalid, Invalid, UChar3Normalized, Invalid, 1, 1, 3, ColorFloat); + addDataFormatDesc(R8G8B8_SNORM, Invalid, Invalid, Char3Normalized, Invalid, 1, 1, 3, ColorFloat); + addDataFormatDesc(R8G8B8_USCALED, Invalid, Invalid, UChar3, Invalid, 1, 1, 3, ColorFloat); + addDataFormatDesc(R8G8B8_SSCALED, Invalid, Invalid, Char3, Invalid, 1, 1, 3, ColorFloat); + addDataFormatDesc(R8G8B8_UINT, Invalid, Invalid, UChar3, Invalid, 1, 1, 3, ColorUInt8); + addDataFormatDesc(R8G8B8_SINT, Invalid, Invalid, Char3, Invalid, 1, 1, 3, ColorInt8); + addDataFormatDesc(R8G8B8_SRGB, Invalid, Invalid, UChar3Normalized, Invalid, 1, 1, 3, ColorFloat); + + addDataFormatDesc(B8G8R8_UNORM, Invalid, Invalid, Invalid, Invalid, 1, 1, 3, ColorFloat); + addDataFormatDesc(B8G8R8_SNORM, Invalid, Invalid, Invalid, Invalid, 1, 1, 3, ColorFloat); + addDataFormatDesc(B8G8R8_USCALED, Invalid, Invalid, Invalid, Invalid, 1, 1, 3, ColorFloat); + addDataFormatDesc(B8G8R8_SSCALED, Invalid, Invalid, Invalid, Invalid, 1, 1, 3, ColorFloat); + addDataFormatDesc(B8G8R8_UINT, Invalid, Invalid, Invalid, Invalid, 1, 1, 3, ColorUInt8); + addDataFormatDesc(B8G8R8_SINT, Invalid, Invalid, Invalid, Invalid, 1, 1, 3, ColorInt8); + addDataFormatDesc(B8G8R8_SRGB, Invalid, Invalid, Invalid, Invalid, 1, 1, 3, ColorFloat); + + addDataFormatDesc(R8G8B8A8_UNORM, RGBA8Unorm, Invalid, UChar4Normalized, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(R8G8B8A8_SNORM, RGBA8Snorm, Invalid, Char4Normalized, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(R8G8B8A8_USCALED, Invalid, Invalid, UChar4, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(R8G8B8A8_SSCALED, Invalid, Invalid, Char4, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(R8G8B8A8_UINT, RGBA8Uint, Invalid, UChar4, Invalid, 1, 1, 4, ColorUInt8); + addDataFormatDesc(R8G8B8A8_SINT, RGBA8Sint, Invalid, Char4, Invalid, 1, 1, 4, ColorInt8); + addDataFormatDesc(R8G8B8A8_SRGB, RGBA8Unorm_sRGB, Invalid, UChar4Normalized, Invalid, 1, 1, 4, ColorFloat); + + addDataFormatDesc(B8G8R8A8_UNORM, BGRA8Unorm, Invalid, UChar4Normalized_BGRA, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(B8G8R8A8_SNORM, Invalid, Invalid, Invalid, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(B8G8R8A8_USCALED, Invalid, Invalid, Invalid, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(B8G8R8A8_SSCALED, Invalid, Invalid, Invalid, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(B8G8R8A8_UINT, Invalid, Invalid, Invalid, Invalid, 1, 1, 4, ColorUInt8); + addDataFormatDesc(B8G8R8A8_SINT, Invalid, Invalid, Invalid, Invalid, 1, 1, 4, ColorInt8); + addDataFormatDesc(B8G8R8A8_SRGB, BGRA8Unorm_sRGB, Invalid, Invalid, Invalid, 1, 1, 4, ColorFloat); + + addDataFormatDesc(A8B8G8R8_UNORM_PACK32, RGBA8Unorm, Invalid, UChar4Normalized, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(A8B8G8R8_SNORM_PACK32, RGBA8Snorm, Invalid, Char4Normalized, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(A8B8G8R8_USCALED_PACK32, Invalid, Invalid, UChar4, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(A8B8G8R8_SSCALED_PACK32, Invalid, Invalid, Char4, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(A8B8G8R8_UINT_PACK32, RGBA8Uint, Invalid, UChar4, Invalid, 1, 1, 4, ColorUInt8); + addDataFormatDesc(A8B8G8R8_SINT_PACK32, RGBA8Sint, Invalid, Char4, Invalid, 1, 1, 4, ColorInt8); + addDataFormatDesc(A8B8G8R8_SRGB_PACK32, RGBA8Unorm_sRGB, Invalid, UChar4Normalized, Invalid, 1, 1, 4, 
ColorFloat); + + addDataFormatDesc(A2R10G10B10_UNORM_PACK32, BGR10A2Unorm, Invalid, Invalid, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(A2R10G10B10_SNORM_PACK32, Invalid, Invalid, Invalid, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(A2R10G10B10_USCALED_PACK32, Invalid, Invalid, Invalid, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(A2R10G10B10_SSCALED_PACK32, Invalid, Invalid, Invalid, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(A2R10G10B10_UINT_PACK32, Invalid, Invalid, Invalid, Invalid, 1, 1, 4, ColorUInt16); + addDataFormatDesc(A2R10G10B10_SINT_PACK32, Invalid, Invalid, Invalid, Invalid, 1, 1, 4, ColorInt16); + + addDataFormatDesc(A2B10G10R10_UNORM_PACK32, RGB10A2Unorm, Invalid, UInt1010102Normalized, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(A2B10G10R10_SNORM_PACK32, Invalid, Invalid, Int1010102Normalized, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(A2B10G10R10_USCALED_PACK32, Invalid, Invalid, Invalid, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(A2B10G10R10_SSCALED_PACK32, Invalid, Invalid, Invalid, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(A2B10G10R10_UINT_PACK32, RGB10A2Uint, Invalid, Invalid, Invalid, 1, 1, 4, ColorUInt16); + addDataFormatDesc(A2B10G10R10_SINT_PACK32, Invalid, Invalid, Invalid, Invalid, 1, 1, 4, ColorInt16); + + addDataFormatDesc(R16_UNORM, R16Unorm, Invalid, UShortNormalized, UShort2Normalized, 1, 1, 2, ColorFloat); + addDataFormatDesc(R16_SNORM, R16Snorm, Invalid, ShortNormalized, Short2Normalized, 1, 1, 2, ColorFloat); + addDataFormatDesc(R16_USCALED, Invalid, Invalid, UShort, UShort2, 1, 1, 2, ColorFloat); + addDataFormatDesc(R16_SSCALED, Invalid, Invalid, Short, Short2, 1, 1, 2, ColorFloat); + addDataFormatDesc(R16_UINT, R16Uint, Invalid, UShort, UShort2, 1, 1, 2, ColorUInt16); + addDataFormatDesc(R16_SINT, R16Sint, Invalid, Short, Short2, 1, 1, 2, ColorInt16); + addDataFormatDesc(R16_SFLOAT, R16Float, Invalid, Half, Half2, 1, 1, 2, ColorFloat); + + addDataFormatDesc(R16G16_UNORM, RG16Unorm, Invalid, UShort2Normalized, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(R16G16_SNORM, RG16Snorm, Invalid, Short2Normalized, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(R16G16_USCALED, Invalid, Invalid, UShort2, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(R16G16_SSCALED, Invalid, Invalid, Short2, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(R16G16_UINT, RG16Uint, Invalid, UShort2, Invalid, 1, 1, 4, ColorUInt16); + addDataFormatDesc(R16G16_SINT, RG16Sint, Invalid, Short2, Invalid, 1, 1, 4, ColorInt16); + addDataFormatDesc(R16G16_SFLOAT, RG16Float, Invalid, Half2, Invalid, 1, 1, 4, ColorFloat); + + addDataFormatDesc(R16G16B16_UNORM, Invalid, Invalid, UShort3Normalized, Invalid, 1, 1, 6, ColorFloat); + addDataFormatDesc(R16G16B16_SNORM, Invalid, Invalid, Short3Normalized, Invalid, 1, 1, 6, ColorFloat); + addDataFormatDesc(R16G16B16_USCALED, Invalid, Invalid, UShort3, Invalid, 1, 1, 6, ColorFloat); + addDataFormatDesc(R16G16B16_SSCALED, Invalid, Invalid, Short3, Invalid, 1, 1, 6, ColorFloat); + addDataFormatDesc(R16G16B16_UINT, Invalid, Invalid, UShort3, Invalid, 1, 1, 6, ColorUInt16); + addDataFormatDesc(R16G16B16_SINT, Invalid, Invalid, Short3, Invalid, 1, 1, 6, ColorInt16); + addDataFormatDesc(R16G16B16_SFLOAT, Invalid, Invalid, Half3, Invalid, 1, 1, 6, ColorFloat); + + addDataFormatDesc(R16G16B16A16_UNORM, RGBA16Unorm, Invalid, UShort4Normalized, Invalid, 1, 1, 8, ColorFloat); + addDataFormatDesc(R16G16B16A16_SNORM, RGBA16Snorm, Invalid, Short4Normalized, Invalid, 1, 1, 8, ColorFloat); + 
addDataFormatDesc(R16G16B16A16_USCALED, Invalid, Invalid, UShort4, Invalid, 1, 1, 8, ColorFloat); + addDataFormatDesc(R16G16B16A16_SSCALED, Invalid, Invalid, Short4, Invalid, 1, 1, 8, ColorFloat); + addDataFormatDesc(R16G16B16A16_UINT, RGBA16Uint, Invalid, UShort4, Invalid, 1, 1, 8, ColorUInt16); + addDataFormatDesc(R16G16B16A16_SINT, RGBA16Sint, Invalid, Short4, Invalid, 1, 1, 8, ColorInt16); + addDataFormatDesc(R16G16B16A16_SFLOAT, RGBA16Float, Invalid, Half4, Invalid, 1, 1, 8, ColorFloat); + + addDataFormatDesc(R32_UINT, R32Uint, Invalid, UInt, Invalid, 1, 1, 4, ColorUInt32); + addDataFormatDesc(R32_SINT, R32Sint, Invalid, Int, Invalid, 1, 1, 4, ColorInt32); + addDataFormatDesc(R32_SFLOAT, R32Float, Invalid, Float, Invalid, 1, 1, 4, ColorFloat); + + addDataFormatDesc(R32G32_UINT, RG32Uint, Invalid, UInt2, Invalid, 1, 1, 8, ColorUInt32); + addDataFormatDesc(R32G32_SINT, RG32Sint, Invalid, Int2, Invalid, 1, 1, 8, ColorInt32); + addDataFormatDesc(R32G32_SFLOAT, RG32Float, Invalid, Float2, Invalid, 1, 1, 8, ColorFloat); + + addDataFormatDesc(R32G32B32_UINT, Invalid, Invalid, UInt3, Invalid, 1, 1, 12, ColorUInt32); + addDataFormatDesc(R32G32B32_SINT, Invalid, Invalid, Int3, Invalid, 1, 1, 12, ColorInt32); + addDataFormatDesc(R32G32B32_SFLOAT, Invalid, Invalid, Float3, Invalid, 1, 1, 12, ColorFloat); + + addDataFormatDesc(R32G32B32A32_UINT, RGBA32Uint, Invalid, UInt4, Invalid, 1, 1, 16, ColorUInt32); + addDataFormatDesc(R32G32B32A32_SINT, RGBA32Sint, Invalid, Int4, Invalid, 1, 1, 16, ColorInt32); + addDataFormatDesc(R32G32B32A32_SFLOAT, RGBA32Float, Invalid, Float4, Invalid, 1, 1, 16, ColorFloat); + + addDataFormatDesc(R64_UINT, Invalid, Invalid, Invalid, Invalid, 1, 1, 8, ColorFloat); + addDataFormatDesc(R64_SINT, Invalid, Invalid, Invalid, Invalid, 1, 1, 8, ColorFloat); + addDataFormatDesc(R64_SFLOAT, Invalid, Invalid, Invalid, Invalid, 1, 1, 8, ColorFloat); + + addDataFormatDesc(R64G64_UINT, Invalid, Invalid, Invalid, Invalid, 1, 1, 16, ColorFloat); + addDataFormatDesc(R64G64_SINT, Invalid, Invalid, Invalid, Invalid, 1, 1, 16, ColorFloat); + addDataFormatDesc(R64G64_SFLOAT, Invalid, Invalid, Invalid, Invalid, 1, 1, 16, ColorFloat); + + addDataFormatDesc(R64G64B64_UINT, Invalid, Invalid, Invalid, Invalid, 1, 1, 24, ColorFloat); + addDataFormatDesc(R64G64B64_SINT, Invalid, Invalid, Invalid, Invalid, 1, 1, 24, ColorFloat); + addDataFormatDesc(R64G64B64_SFLOAT, Invalid, Invalid, Invalid, Invalid, 1, 1, 24, ColorFloat); + + addDataFormatDesc(R64G64B64A64_UINT, Invalid, Invalid, Invalid, Invalid, 1, 1, 32, ColorFloat); + addDataFormatDesc(R64G64B64A64_SINT, Invalid, Invalid, Invalid, Invalid, 1, 1, 32, ColorFloat); + addDataFormatDesc(R64G64B64A64_SFLOAT, Invalid, Invalid, Invalid, Invalid, 1, 1, 32, ColorFloat); + + addDataFormatDesc(B10G11R11_UFLOAT_PACK32, RG11B10Float, Invalid, Invalid, Invalid, 1, 1, 4, ColorFloat); + addDataFormatDesc(E5B9G9R9_UFLOAT_PACK32, RGB9E5Float, Invalid, Invalid, Invalid, 1, 1, 4, ColorFloat); + + addDataFormatDesc(D32_SFLOAT, Depth32Float, Invalid, Invalid, Invalid, 1, 1, 4, DepthStencil); + addDataFormatDesc(D32_SFLOAT_S8_UINT, Depth32Float_Stencil8, Invalid, Invalid, Invalid, 1, 1, 5, DepthStencil); + + addDataFormatDesc(S8_UINT, Stencil8, Invalid, Invalid, Invalid, 1, 1, 1, DepthStencil); + + addDataFormatDesc(D16_UNORM, Depth16Unorm, Depth32Float, Invalid, Invalid, 1, 1, 2, DepthStencil); + addDataFormatDesc(D16_UNORM_S8_UINT, Invalid, Invalid, Invalid, Invalid, 1, 1, 3, DepthStencil); + addDataFormatDesc(D24_UNORM_S8_UINT, Depth24Unorm_Stencil8, 
Depth32Float_Stencil8, Invalid, Invalid, 1, 1, 4, DepthStencil); + + addDataFormatDesc(X8_D24_UNORM_PACK32, Invalid, Depth24Unorm_Stencil8, Invalid, Invalid, 1, 1, 4, DepthStencil); + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunguarded-availability" + + addDataFormatDesc(BC1_RGB_UNORM_BLOCK, BC1_RGBA, Invalid, Invalid, Invalid, 4, 4, 8, Compressed); + addDataFormatDesc(BC1_RGB_SRGB_BLOCK, BC1_RGBA_sRGB, Invalid, Invalid, Invalid, 4, 4, 8, Compressed); + addDataFormatDesc(BC1_RGBA_UNORM_BLOCK, BC1_RGBA, Invalid, Invalid, Invalid, 4, 4, 8, Compressed); + addDataFormatDesc(BC1_RGBA_SRGB_BLOCK, BC1_RGBA_sRGB, Invalid, Invalid, Invalid, 4, 4, 8, Compressed); + + addDataFormatDesc(BC2_UNORM_BLOCK, BC2_RGBA, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + addDataFormatDesc(BC2_SRGB_BLOCK, BC2_RGBA_sRGB, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + + addDataFormatDesc(BC3_UNORM_BLOCK, BC3_RGBA, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + addDataFormatDesc(BC3_SRGB_BLOCK, BC3_RGBA_sRGB, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + + addDataFormatDesc(BC4_UNORM_BLOCK, BC4_RUnorm, Invalid, Invalid, Invalid, 4, 4, 8, Compressed); + addDataFormatDesc(BC4_SNORM_BLOCK, BC4_RSnorm, Invalid, Invalid, Invalid, 4, 4, 8, Compressed); + + addDataFormatDesc(BC5_UNORM_BLOCK, BC5_RGUnorm, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + addDataFormatDesc(BC5_SNORM_BLOCK, BC5_RGSnorm, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + + addDataFormatDesc(BC6H_UFLOAT_BLOCK, BC6H_RGBUfloat, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + addDataFormatDesc(BC6H_SFLOAT_BLOCK, BC6H_RGBFloat, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + + addDataFormatDesc(BC7_UNORM_BLOCK, BC7_RGBAUnorm, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + addDataFormatDesc(BC7_SRGB_BLOCK, BC7_RGBAUnorm_sRGB, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + +#pragma clang diagnostic pop + + addDataFormatDesc(ETC2_R8G8B8_UNORM_BLOCK, ETC2_RGB8, Invalid, Invalid, Invalid, 4, 4, 8, Compressed); + addDataFormatDesc(ETC2_R8G8B8_SRGB_BLOCK, ETC2_RGB8_sRGB, Invalid, Invalid, Invalid, 4, 4, 8, Compressed); + addDataFormatDesc(ETC2_R8G8B8A1_UNORM_BLOCK, ETC2_RGB8A1, Invalid, Invalid, Invalid, 4, 4, 8, Compressed); + addDataFormatDesc(ETC2_R8G8B8A1_SRGB_BLOCK, ETC2_RGB8A1_sRGB, Invalid, Invalid, Invalid, 4, 4, 8, Compressed); + + addDataFormatDesc(ETC2_R8G8B8A8_UNORM_BLOCK, EAC_RGBA8, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + addDataFormatDesc(ETC2_R8G8B8A8_SRGB_BLOCK, EAC_RGBA8_sRGB, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + + addDataFormatDesc(EAC_R11_UNORM_BLOCK, EAC_R11Unorm, Invalid, Invalid, Invalid, 4, 4, 8, Compressed); + addDataFormatDesc(EAC_R11_SNORM_BLOCK, EAC_R11Snorm, Invalid, Invalid, Invalid, 4, 4, 8, Compressed); + + addDataFormatDesc(EAC_R11G11_UNORM_BLOCK, EAC_RG11Unorm, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + addDataFormatDesc(EAC_R11G11_SNORM_BLOCK, EAC_RG11Snorm, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + + addDataFormatDesc(ASTC_4x4_UNORM_BLOCK, ASTC_4x4_LDR, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + addDataFormatDesc(ASTC_4x4_SRGB_BLOCK, ASTC_4x4_sRGB, Invalid, Invalid, Invalid, 4, 4, 16, Compressed); + addDataFormatDesc(ASTC_5x4_UNORM_BLOCK, ASTC_5x4_LDR, Invalid, Invalid, Invalid, 5, 4, 16, Compressed); + addDataFormatDesc(ASTC_5x4_SRGB_BLOCK, ASTC_5x4_sRGB, Invalid, Invalid, Invalid, 5, 4, 16, Compressed); + addDataFormatDesc(ASTC_5x5_UNORM_BLOCK, ASTC_5x5_LDR, Invalid, Invalid, Invalid, 5, 5, 16, 
Compressed); + addDataFormatDesc(ASTC_5x5_SRGB_BLOCK, ASTC_5x5_sRGB, Invalid, Invalid, Invalid, 5, 5, 16, Compressed); + addDataFormatDesc(ASTC_6x5_UNORM_BLOCK, ASTC_6x5_LDR, Invalid, Invalid, Invalid, 6, 5, 16, Compressed); + addDataFormatDesc(ASTC_6x5_SRGB_BLOCK, ASTC_6x5_sRGB, Invalid, Invalid, Invalid, 6, 5, 16, Compressed); + addDataFormatDesc(ASTC_6x6_UNORM_BLOCK, ASTC_6x6_LDR, Invalid, Invalid, Invalid, 6, 6, 16, Compressed); + addDataFormatDesc(ASTC_6x6_SRGB_BLOCK, ASTC_6x6_sRGB, Invalid, Invalid, Invalid, 6, 6, 16, Compressed); + addDataFormatDesc(ASTC_8x5_UNORM_BLOCK, ASTC_8x5_LDR, Invalid, Invalid, Invalid, 8, 5, 16, Compressed); + addDataFormatDesc(ASTC_8x5_SRGB_BLOCK, ASTC_8x5_sRGB, Invalid, Invalid, Invalid, 8, 5, 16, Compressed); + addDataFormatDesc(ASTC_8x6_UNORM_BLOCK, ASTC_8x6_LDR, Invalid, Invalid, Invalid, 8, 6, 16, Compressed); + addDataFormatDesc(ASTC_8x6_SRGB_BLOCK, ASTC_8x6_sRGB, Invalid, Invalid, Invalid, 8, 6, 16, Compressed); + addDataFormatDesc(ASTC_8x8_UNORM_BLOCK, ASTC_8x8_LDR, Invalid, Invalid, Invalid, 8, 8, 16, Compressed); + addDataFormatDesc(ASTC_8x8_SRGB_BLOCK, ASTC_8x8_sRGB, Invalid, Invalid, Invalid, 8, 8, 16, Compressed); + addDataFormatDesc(ASTC_10x5_UNORM_BLOCK, ASTC_10x5_LDR, Invalid, Invalid, Invalid, 10, 5, 16, Compressed); + addDataFormatDesc(ASTC_10x5_SRGB_BLOCK, ASTC_10x5_sRGB, Invalid, Invalid, Invalid, 10, 5, 16, Compressed); + addDataFormatDesc(ASTC_10x6_UNORM_BLOCK, ASTC_10x6_LDR, Invalid, Invalid, Invalid, 10, 6, 16, Compressed); + addDataFormatDesc(ASTC_10x6_SRGB_BLOCK, ASTC_10x6_sRGB, Invalid, Invalid, Invalid, 10, 6, 16, Compressed); + addDataFormatDesc(ASTC_10x8_UNORM_BLOCK, ASTC_10x8_LDR, Invalid, Invalid, Invalid, 10, 8, 16, Compressed); + addDataFormatDesc(ASTC_10x8_SRGB_BLOCK, ASTC_10x8_sRGB, Invalid, Invalid, Invalid, 10, 8, 16, Compressed); + addDataFormatDesc(ASTC_10x10_UNORM_BLOCK, ASTC_10x10_LDR, Invalid, Invalid, Invalid, 10, 10, 16, Compressed); + addDataFormatDesc(ASTC_10x10_SRGB_BLOCK, ASTC_10x10_sRGB, Invalid, Invalid, Invalid, 10, 10, 16, Compressed); + addDataFormatDesc(ASTC_12x10_UNORM_BLOCK, ASTC_12x10_LDR, Invalid, Invalid, Invalid, 12, 10, 16, Compressed); + addDataFormatDesc(ASTC_12x10_SRGB_BLOCK, ASTC_12x10_sRGB, Invalid, Invalid, Invalid, 12, 10, 16, Compressed); + addDataFormatDesc(ASTC_12x12_UNORM_BLOCK, ASTC_12x12_LDR, Invalid, Invalid, Invalid, 12, 12, 16, Compressed); + addDataFormatDesc(ASTC_12x12_SRGB_BLOCK, ASTC_12x12_sRGB, Invalid, Invalid, Invalid, 12, 12, 16, Compressed); + + addDfFormatDescChromaSubsampling(G8B8G8R8_422_UNORM, GBGR422, 1, 8, 2, 1, 4); + addDfFormatDescChromaSubsampling(B8G8R8G8_422_UNORM, BGRG422, 1, 8, 2, 1, 4); + addDfFormatDescChromaSubsampling(G8_B8_R8_3PLANE_420_UNORM, Invalid, 3, 8, 2, 2, 6); + addDfFormatDescChromaSubsampling(G8_B8R8_2PLANE_420_UNORM, Invalid, 2, 8, 2, 2, 6); + addDfFormatDescChromaSubsampling(G8_B8_R8_3PLANE_422_UNORM, Invalid, 3, 8, 2, 1, 4); + addDfFormatDescChromaSubsampling(G8_B8R8_2PLANE_422_UNORM, Invalid, 2, 8, 2, 1, 4); + addDfFormatDescChromaSubsampling(G8_B8_R8_3PLANE_444_UNORM, Invalid, 3, 8, 1, 1, 3); + addDfFormatDescChromaSubsampling(R10X6_UNORM_PACK16, R16Unorm, 0, 10, 1, 1, 2); + addDfFormatDescChromaSubsampling(R10X6G10X6_UNORM_2PACK16, RG16Unorm, 0, 10, 1, 1, 4); + addDfFormatDescChromaSubsampling(R10X6G10X6B10X6A10X6_UNORM_4PACK16, RGBA16Unorm, 0, 10, 1, 1, 8); + addDfFormatDescChromaSubsampling(G10X6B10X6G10X6R10X6_422_UNORM_4PACK16, Invalid, 1, 10, 2, 1, 8); + addDfFormatDescChromaSubsampling(B10X6G10X6R10X6G10X6_422_UNORM_4PACK16, 
Invalid, 1, 10, 2, 1, 8); + addDfFormatDescChromaSubsampling(G10X6_B10X6_R10X6_3PLANE_420_UNORM_3PACK16, Invalid, 3, 10, 2, 2, 12); + addDfFormatDescChromaSubsampling(G10X6_B10X6R10X6_2PLANE_420_UNORM_3PACK16, Invalid, 2, 10, 2, 2, 12); + addDfFormatDescChromaSubsampling(G10X6_B10X6_R10X6_3PLANE_422_UNORM_3PACK16, Invalid, 3, 10, 2, 1, 8); + addDfFormatDescChromaSubsampling(G10X6_B10X6R10X6_2PLANE_422_UNORM_3PACK16, Invalid, 2, 10, 2, 1, 8); + addDfFormatDescChromaSubsampling(G10X6_B10X6_R10X6_3PLANE_444_UNORM_3PACK16, Invalid, 3, 10, 1, 1, 6); + addDfFormatDescChromaSubsampling(R12X4_UNORM_PACK16, R16Unorm, 0, 12, 1, 1, 2); + addDfFormatDescChromaSubsampling(R12X4G12X4_UNORM_2PACK16, RG16Unorm, 0, 12, 1, 1, 4); + addDfFormatDescChromaSubsampling(R12X4G12X4B12X4A12X4_UNORM_4PACK16, RGBA16Unorm, 0, 12, 1, 1, 8); + addDfFormatDescChromaSubsampling(G12X4B12X4G12X4R12X4_422_UNORM_4PACK16, Invalid, 1, 12, 2, 1, 8); + addDfFormatDescChromaSubsampling(B12X4G12X4R12X4G12X4_422_UNORM_4PACK16, Invalid, 1, 12, 2, 1, 8); + addDfFormatDescChromaSubsampling(G12X4_B12X4_R12X4_3PLANE_420_UNORM_3PACK16, Invalid, 3, 12, 2, 2, 12); + addDfFormatDescChromaSubsampling(G12X4_B12X4R12X4_2PLANE_420_UNORM_3PACK16, Invalid, 2, 12, 2, 2, 12); + addDfFormatDescChromaSubsampling(G12X4_B12X4_R12X4_3PLANE_422_UNORM_3PACK16, Invalid, 3, 12, 2, 1, 8); + addDfFormatDescChromaSubsampling(G12X4_B12X4R12X4_2PLANE_422_UNORM_3PACK16, Invalid, 2, 12, 2, 1, 8); + addDfFormatDescChromaSubsampling(G12X4_B12X4_R12X4_3PLANE_444_UNORM_3PACK16, Invalid, 3, 12, 1, 1, 6); + addDfFormatDescChromaSubsampling(G16B16G16R16_422_UNORM, Invalid, 1, 16, 2, 1, 8); + addDfFormatDescChromaSubsampling(B16G16R16G16_422_UNORM, Invalid, 1, 16, 2, 1, 8); + addDfFormatDescChromaSubsampling(G16_B16_R16_3PLANE_420_UNORM, Invalid, 3, 16, 2, 2, 12); + addDfFormatDescChromaSubsampling(G16_B16R16_2PLANE_420_UNORM, Invalid, 2, 16, 2, 2, 12); + addDfFormatDescChromaSubsampling(G16_B16_R16_3PLANE_422_UNORM, Invalid, 3, 16, 2, 1, 8); + addDfFormatDescChromaSubsampling(G16_B16R16_2PLANE_422_UNORM, Invalid, 2, 16, 2, 1, 8); + addDfFormatDescChromaSubsampling(G16_B16_R16_3PLANE_444_UNORM, Invalid, 3, 16, 1, 1, 6); +} + +#define addMTLPixelFormatDescFull(MTL_FMT, VIEW_CLASS, IOS_CAPS, MACOS_CAPS, MTL_FMT_LINEAR) \ + CRASH_BAD_INDEX_MSG(fmtIdx, _mtlPixelFormatCount, "Adding too many pixel formats"); \ + _mtlPixelFormatDescriptions[fmtIdx++] = { .mtlPixelFormat = MTLPixelFormat##MTL_FMT, RD::DATA_FORMAT_MAX, select_platform_caps(kMTLFmtCaps##MACOS_CAPS, kMTLFmtCaps##IOS_CAPS), MTLViewClass::VIEW_CLASS, MTLPixelFormat##MTL_FMT_LINEAR, "MTLPixelFormat" #MTL_FMT } + +#define addMTLPixelFormatDesc(MTL_FMT, VIEW_CLASS, IOS_CAPS, MACOS_CAPS) \ + addMTLPixelFormatDescFull(MTL_FMT, VIEW_CLASS, IOS_CAPS, MACOS_CAPS, MTL_FMT) + +#define addMTLPixelFormatDescSRGB(MTL_FMT, VIEW_CLASS, IOS_CAPS, MACOS_CAPS, MTL_FMT_LINEAR) \ + addMTLPixelFormatDescFull(MTL_FMT, VIEW_CLASS, IOS_CAPS, MACOS_CAPS, MTL_FMT_LINEAR) + +void PixelFormats::initMTLPixelFormatCapabilities() { + clear(_mtlPixelFormatDescriptions, _mtlPixelFormatCount); + + uint32_t fmtIdx = 0; + + // When adding to this list, be sure to ensure _mtlPixelFormatCount is large enough for the format count. + + // MTLPixelFormatInvalid must come first. + addMTLPixelFormatDesc(Invalid, None, None, None); + + // Ordinary 8-bit pixel formats. 
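+ // Each entry lists the view class, then the iOS and macOS capability sets; select_platform_caps() resolves the pair for the current platform at compile time.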
+ addMTLPixelFormatDesc(A8Unorm, Color8, RF, RF); + addMTLPixelFormatDesc(R8Unorm, Color8, All, All); + addMTLPixelFormatDescSRGB(R8Unorm_sRGB, Color8, RFCMRB, None, R8Unorm); + addMTLPixelFormatDesc(R8Snorm, Color8, RFWCMB, All); + addMTLPixelFormatDesc(R8Uint, Color8, RWCM, RWCM); + addMTLPixelFormatDesc(R8Sint, Color8, RWCM, RWCM); + + // Ordinary 16-bit pixel formats. + addMTLPixelFormatDesc(R16Unorm, Color16, RFWCMB, All); + addMTLPixelFormatDesc(R16Snorm, Color16, RFWCMB, All); + addMTLPixelFormatDesc(R16Uint, Color16, RWCM, RWCM); + addMTLPixelFormatDesc(R16Sint, Color16, RWCM, RWCM); + addMTLPixelFormatDesc(R16Float, Color16, All, All); + + addMTLPixelFormatDesc(RG8Unorm, Color16, All, All); + addMTLPixelFormatDescSRGB(RG8Unorm_sRGB, Color16, RFCMRB, None, RG8Unorm); + addMTLPixelFormatDesc(RG8Snorm, Color16, RFWCMB, All); + addMTLPixelFormatDesc(RG8Uint, Color16, RWCM, RWCM); + addMTLPixelFormatDesc(RG8Sint, Color16, RWCM, RWCM); + + // Packed 16-bit pixel formats. + addMTLPixelFormatDesc(B5G6R5Unorm, Color16, RFCMRB, None); + addMTLPixelFormatDesc(A1BGR5Unorm, Color16, RFCMRB, None); + addMTLPixelFormatDesc(ABGR4Unorm, Color16, RFCMRB, None); + addMTLPixelFormatDesc(BGR5A1Unorm, Color16, RFCMRB, None); + + // Ordinary 32-bit pixel formats. + addMTLPixelFormatDesc(R32Uint, Color32, RC, RWCM); + addMTLPixelFormatDesc(R32Sint, Color32, RC, RWCM); + addMTLPixelFormatDesc(R32Float, Color32, RCMB, All); + + addMTLPixelFormatDesc(RG16Unorm, Color32, RFWCMB, All); + addMTLPixelFormatDesc(RG16Snorm, Color32, RFWCMB, All); + addMTLPixelFormatDesc(RG16Uint, Color32, RWCM, RWCM); + addMTLPixelFormatDesc(RG16Sint, Color32, RWCM, RWCM); + addMTLPixelFormatDesc(RG16Float, Color32, All, All); + + addMTLPixelFormatDesc(RGBA8Unorm, Color32, All, All); + addMTLPixelFormatDescSRGB(RGBA8Unorm_sRGB, Color32, RFCMRB, RFCMRB, RGBA8Unorm); + addMTLPixelFormatDesc(RGBA8Snorm, Color32, RFWCMB, All); + addMTLPixelFormatDesc(RGBA8Uint, Color32, RWCM, RWCM); + addMTLPixelFormatDesc(RGBA8Sint, Color32, RWCM, RWCM); + + addMTLPixelFormatDesc(BGRA8Unorm, Color32, All, All); + addMTLPixelFormatDescSRGB(BGRA8Unorm_sRGB, Color32, RFCMRB, RFCMRB, BGRA8Unorm); + + // Packed 32-bit pixel formats. + addMTLPixelFormatDesc(RGB10A2Unorm, Color32, RFCMRB, All); + addMTLPixelFormatDesc(RGB10A2Uint, Color32, RCM, RWCM); + addMTLPixelFormatDesc(RG11B10Float, Color32, RFCMRB, All); + addMTLPixelFormatDesc(RGB9E5Float, Color32, RFCMRB, RF); + + // Ordinary 64-bit pixel formats. + addMTLPixelFormatDesc(RG32Uint, Color64, RC, RWCM); + addMTLPixelFormatDesc(RG32Sint, Color64, RC, RWCM); + addMTLPixelFormatDesc(RG32Float, Color64, RCB, All); + + addMTLPixelFormatDesc(RGBA16Unorm, Color64, RFWCMB, All); + addMTLPixelFormatDesc(RGBA16Snorm, Color64, RFWCMB, All); + addMTLPixelFormatDesc(RGBA16Uint, Color64, RWCM, RWCM); + addMTLPixelFormatDesc(RGBA16Sint, Color64, RWCM, RWCM); + addMTLPixelFormatDesc(RGBA16Float, Color64, All, All); + + // Ordinary 128-bit pixel formats. + addMTLPixelFormatDesc(RGBA32Uint, Color128, RC, RWCM); + addMTLPixelFormatDesc(RGBA32Sint, Color128, RC, RWCM); + addMTLPixelFormatDesc(RGBA32Float, Color128, RC, All); + + // Compressed pixel formats. 
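+ // The caps below are baseline values; modifyMTLFormatCapabilities() later enables ASTC/ETC2/PVRTC on Apple-silicon macOS and clears the BC caps on devices without BC texture compression support.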
+ addMTLPixelFormatDesc(PVRTC_RGBA_2BPP, PVRTC_RGBA_2BPP, RF, None); + addMTLPixelFormatDescSRGB(PVRTC_RGBA_2BPP_sRGB, PVRTC_RGBA_2BPP, RF, None, PVRTC_RGBA_2BPP); + addMTLPixelFormatDesc(PVRTC_RGBA_4BPP, PVRTC_RGBA_4BPP, RF, None); + addMTLPixelFormatDescSRGB(PVRTC_RGBA_4BPP_sRGB, PVRTC_RGBA_4BPP, RF, None, PVRTC_RGBA_4BPP); + + addMTLPixelFormatDesc(ETC2_RGB8, ETC2_RGB8, RF, None); + addMTLPixelFormatDescSRGB(ETC2_RGB8_sRGB, ETC2_RGB8, RF, None, ETC2_RGB8); + addMTLPixelFormatDesc(ETC2_RGB8A1, ETC2_RGB8A1, RF, None); + addMTLPixelFormatDescSRGB(ETC2_RGB8A1_sRGB, ETC2_RGB8A1, RF, None, ETC2_RGB8A1); + addMTLPixelFormatDesc(EAC_RGBA8, EAC_RGBA8, RF, None); + addMTLPixelFormatDescSRGB(EAC_RGBA8_sRGB, EAC_RGBA8, RF, None, EAC_RGBA8); + addMTLPixelFormatDesc(EAC_R11Unorm, EAC_R11, RF, None); + addMTLPixelFormatDesc(EAC_R11Snorm, EAC_R11, RF, None); + addMTLPixelFormatDesc(EAC_RG11Unorm, EAC_RG11, RF, None); + addMTLPixelFormatDesc(EAC_RG11Snorm, EAC_RG11, RF, None); + + addMTLPixelFormatDesc(ASTC_4x4_LDR, ASTC_4x4, None, None); + addMTLPixelFormatDescSRGB(ASTC_4x4_sRGB, ASTC_4x4, None, None, ASTC_4x4_LDR); + addMTLPixelFormatDesc(ASTC_4x4_HDR, ASTC_4x4, None, None); + addMTLPixelFormatDesc(ASTC_5x4_LDR, ASTC_5x4, None, None); + addMTLPixelFormatDescSRGB(ASTC_5x4_sRGB, ASTC_5x4, None, None, ASTC_5x4_LDR); + addMTLPixelFormatDesc(ASTC_5x4_HDR, ASTC_5x4, None, None); + addMTLPixelFormatDesc(ASTC_5x5_LDR, ASTC_5x5, None, None); + addMTLPixelFormatDescSRGB(ASTC_5x5_sRGB, ASTC_5x5, None, None, ASTC_5x5_LDR); + addMTLPixelFormatDesc(ASTC_5x5_HDR, ASTC_5x5, None, None); + addMTLPixelFormatDesc(ASTC_6x5_LDR, ASTC_6x5, None, None); + addMTLPixelFormatDescSRGB(ASTC_6x5_sRGB, ASTC_6x5, None, None, ASTC_6x5_LDR); + addMTLPixelFormatDesc(ASTC_6x5_HDR, ASTC_6x5, None, None); + addMTLPixelFormatDesc(ASTC_6x6_LDR, ASTC_6x6, None, None); + addMTLPixelFormatDescSRGB(ASTC_6x6_sRGB, ASTC_6x6, None, None, ASTC_6x6_LDR); + addMTLPixelFormatDesc(ASTC_6x6_HDR, ASTC_6x6, None, None); + addMTLPixelFormatDesc(ASTC_8x5_LDR, ASTC_8x5, None, None); + addMTLPixelFormatDescSRGB(ASTC_8x5_sRGB, ASTC_8x5, None, None, ASTC_8x5_LDR); + addMTLPixelFormatDesc(ASTC_8x5_HDR, ASTC_8x5, None, None); + addMTLPixelFormatDesc(ASTC_8x6_LDR, ASTC_8x6, None, None); + addMTLPixelFormatDescSRGB(ASTC_8x6_sRGB, ASTC_8x6, None, None, ASTC_8x6_LDR); + addMTLPixelFormatDesc(ASTC_8x6_HDR, ASTC_8x6, None, None); + addMTLPixelFormatDesc(ASTC_8x8_LDR, ASTC_8x8, None, None); + addMTLPixelFormatDescSRGB(ASTC_8x8_sRGB, ASTC_8x8, None, None, ASTC_8x8_LDR); + addMTLPixelFormatDesc(ASTC_8x8_HDR, ASTC_8x8, None, None); + addMTLPixelFormatDesc(ASTC_10x5_LDR, ASTC_10x5, None, None); + addMTLPixelFormatDescSRGB(ASTC_10x5_sRGB, ASTC_10x5, None, None, ASTC_10x5_LDR); + addMTLPixelFormatDesc(ASTC_10x5_HDR, ASTC_10x5, None, None); + addMTLPixelFormatDesc(ASTC_10x6_LDR, ASTC_10x6, None, None); + addMTLPixelFormatDescSRGB(ASTC_10x6_sRGB, ASTC_10x6, None, None, ASTC_10x6_LDR); + addMTLPixelFormatDesc(ASTC_10x6_HDR, ASTC_10x6, None, None); + addMTLPixelFormatDesc(ASTC_10x8_LDR, ASTC_10x8, None, None); + addMTLPixelFormatDescSRGB(ASTC_10x8_sRGB, ASTC_10x8, None, None, ASTC_10x8_LDR); + addMTLPixelFormatDesc(ASTC_10x8_HDR, ASTC_10x8, None, None); + addMTLPixelFormatDesc(ASTC_10x10_LDR, ASTC_10x10, None, None); + addMTLPixelFormatDescSRGB(ASTC_10x10_sRGB, ASTC_10x10, None, None, ASTC_10x10_LDR); + addMTLPixelFormatDesc(ASTC_10x10_HDR, ASTC_10x10, None, None); + addMTLPixelFormatDesc(ASTC_12x10_LDR, ASTC_12x10, None, None); + addMTLPixelFormatDescSRGB(ASTC_12x10_sRGB, 
ASTC_12x10, None, None, ASTC_12x10_LDR); + addMTLPixelFormatDesc(ASTC_12x10_HDR, ASTC_12x10, None, None); + addMTLPixelFormatDesc(ASTC_12x12_LDR, ASTC_12x12, None, None); + addMTLPixelFormatDescSRGB(ASTC_12x12_sRGB, ASTC_12x12, None, None, ASTC_12x12_LDR); + addMTLPixelFormatDesc(ASTC_12x12_HDR, ASTC_12x12, None, None); + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunguarded-availability" + + addMTLPixelFormatDesc(BC1_RGBA, BC1_RGBA, RF, RF); + addMTLPixelFormatDescSRGB(BC1_RGBA_sRGB, BC1_RGBA, RF, RF, BC1_RGBA); + addMTLPixelFormatDesc(BC2_RGBA, BC2_RGBA, RF, RF); + addMTLPixelFormatDescSRGB(BC2_RGBA_sRGB, BC2_RGBA, RF, RF, BC2_RGBA); + addMTLPixelFormatDesc(BC3_RGBA, BC3_RGBA, RF, RF); + addMTLPixelFormatDescSRGB(BC3_RGBA_sRGB, BC3_RGBA, RF, RF, BC3_RGBA); + addMTLPixelFormatDesc(BC4_RUnorm, BC4_R, RF, RF); + addMTLPixelFormatDesc(BC4_RSnorm, BC4_R, RF, RF); + addMTLPixelFormatDesc(BC5_RGUnorm, BC5_RG, RF, RF); + addMTLPixelFormatDesc(BC5_RGSnorm, BC5_RG, RF, RF); + addMTLPixelFormatDesc(BC6H_RGBUfloat, BC6H_RGB, RF, RF); + addMTLPixelFormatDesc(BC6H_RGBFloat, BC6H_RGB, RF, RF); + addMTLPixelFormatDesc(BC7_RGBAUnorm, BC7_RGBA, RF, RF); + addMTLPixelFormatDescSRGB(BC7_RGBAUnorm_sRGB, BC7_RGBA, RF, RF, BC7_RGBAUnorm); + +#pragma clang diagnostic pop + + // YUV pixel formats. + addMTLPixelFormatDesc(GBGR422, None, RF, RF); + addMTLPixelFormatDesc(BGRG422, None, RF, RF); + + // Extended range and wide color pixel formats. + addMTLPixelFormatDesc(BGRA10_XR, BGRA10_XR, None, None); + addMTLPixelFormatDescSRGB(BGRA10_XR_sRGB, BGRA10_XR, None, None, BGRA10_XR); + addMTLPixelFormatDesc(BGR10_XR, BGR10_XR, None, None); + addMTLPixelFormatDescSRGB(BGR10_XR_sRGB, BGR10_XR, None, None, BGR10_XR); + addMTLPixelFormatDesc(BGR10A2Unorm, Color32, None, None); + + // Depth and stencil pixel formats. + addMTLPixelFormatDesc(Depth16Unorm, None, None, None); + addMTLPixelFormatDesc(Depth32Float, None, DRM, DRFMR); + addMTLPixelFormatDesc(Stencil8, None, DRM, DRMR); + addMTLPixelFormatDesc(Depth24Unorm_Stencil8, Depth24_Stencil8, None, None); + addMTLPixelFormatDesc(Depth32Float_Stencil8, Depth32_Stencil8, DRM, DRFMR); + addMTLPixelFormatDesc(X24_Stencil8, Depth24_Stencil8, None, DRMR); + addMTLPixelFormatDesc(X32_Stencil8, Depth32_Stencil8, DRM, DRMR); + + // When adding to this list, be sure to ensure _mtlPixelFormatCount is large enough for the format count. +} + +#define addMTLVertexFormatDesc(MTL_VTX_FMT, IOS_CAPS, MACOS_CAPS) \ + CRASH_BAD_INDEX_MSG(fmtIdx, _mtlVertexFormatCount, "Attempting to describe too many MTLVertexFormats"); \ + _mtlVertexFormatDescriptions[fmtIdx++] = { .mtlVertexFormat = MTLVertexFormat##MTL_VTX_FMT, RD::DATA_FORMAT_MAX, select_platform_caps(kMTLFmtCaps##MACOS_CAPS, kMTLFmtCaps##IOS_CAPS), MTLViewClass::None, MTLPixelFormatInvalid, "MTLVertexFormat" #MTL_VTX_FMT } + +void PixelFormats::initMTLVertexFormatCapabilities() { + clear(_mtlVertexFormatDescriptions, _mtlVertexFormatCount); + + uint32_t fmtIdx = 0; + + // When adding to this list, be sure to ensure _mtlVertexFormatCount is large enough for the format count. + + // MTLVertexFormatInvalid must come first. 
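+ // getMTLVertexFormatDesc() falls back to index 0 for unmapped vertex formats, so the Invalid descriptor must occupy that slot.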
+ addMTLVertexFormatDesc(Invalid, None, None); + + addMTLVertexFormatDesc(UChar2Normalized, Vertex, Vertex); + addMTLVertexFormatDesc(Char2Normalized, Vertex, Vertex); + addMTLVertexFormatDesc(UChar2, Vertex, Vertex); + addMTLVertexFormatDesc(Char2, Vertex, Vertex); + + addMTLVertexFormatDesc(UChar3Normalized, Vertex, Vertex); + addMTLVertexFormatDesc(Char3Normalized, Vertex, Vertex); + addMTLVertexFormatDesc(UChar3, Vertex, Vertex); + addMTLVertexFormatDesc(Char3, Vertex, Vertex); + + addMTLVertexFormatDesc(UChar4Normalized, Vertex, Vertex); + addMTLVertexFormatDesc(Char4Normalized, Vertex, Vertex); + addMTLVertexFormatDesc(UChar4, Vertex, Vertex); + addMTLVertexFormatDesc(Char4, Vertex, Vertex); + + addMTLVertexFormatDesc(UInt1010102Normalized, Vertex, Vertex); + addMTLVertexFormatDesc(Int1010102Normalized, Vertex, Vertex); + + addMTLVertexFormatDesc(UShort2Normalized, Vertex, Vertex); + addMTLVertexFormatDesc(Short2Normalized, Vertex, Vertex); + addMTLVertexFormatDesc(UShort2, Vertex, Vertex); + addMTLVertexFormatDesc(Short2, Vertex, Vertex); + addMTLVertexFormatDesc(Half2, Vertex, Vertex); + + addMTLVertexFormatDesc(UShort3Normalized, Vertex, Vertex); + addMTLVertexFormatDesc(Short3Normalized, Vertex, Vertex); + addMTLVertexFormatDesc(UShort3, Vertex, Vertex); + addMTLVertexFormatDesc(Short3, Vertex, Vertex); + addMTLVertexFormatDesc(Half3, Vertex, Vertex); + + addMTLVertexFormatDesc(UShort4Normalized, Vertex, Vertex); + addMTLVertexFormatDesc(Short4Normalized, Vertex, Vertex); + addMTLVertexFormatDesc(UShort4, Vertex, Vertex); + addMTLVertexFormatDesc(Short4, Vertex, Vertex); + addMTLVertexFormatDesc(Half4, Vertex, Vertex); + + addMTLVertexFormatDesc(UInt, Vertex, Vertex); + addMTLVertexFormatDesc(Int, Vertex, Vertex); + addMTLVertexFormatDesc(Float, Vertex, Vertex); + + addMTLVertexFormatDesc(UInt2, Vertex, Vertex); + addMTLVertexFormatDesc(Int2, Vertex, Vertex); + addMTLVertexFormatDesc(Float2, Vertex, Vertex); + + addMTLVertexFormatDesc(UInt3, Vertex, Vertex); + addMTLVertexFormatDesc(Int3, Vertex, Vertex); + addMTLVertexFormatDesc(Float3, Vertex, Vertex); + + addMTLVertexFormatDesc(UInt4, Vertex, Vertex); + addMTLVertexFormatDesc(Int4, Vertex, Vertex); + addMTLVertexFormatDesc(Float4, Vertex, Vertex); + + addMTLVertexFormatDesc(UCharNormalized, None, None); + addMTLVertexFormatDesc(CharNormalized, None, None); + addMTLVertexFormatDesc(UChar, None, None); + addMTLVertexFormatDesc(Char, None, None); + + addMTLVertexFormatDesc(UShortNormalized, None, None); + addMTLVertexFormatDesc(ShortNormalized, None, None); + addMTLVertexFormatDesc(UShort, None, None); + addMTLVertexFormatDesc(Short, None, None); + addMTLVertexFormatDesc(Half, None, None); + + addMTLVertexFormatDesc(UChar4Normalized_BGRA, None, None); + + // When adding to this list, be sure to ensure _mtlVertexFormatCount is large enough for the format count. +} + +void PixelFormats::buildMTLFormatMaps() { + // Set all MTLPixelFormats and MTLVertexFormats to undefined/invalid. + clear(_mtlFormatDescIndicesByMTLPixelFormatsCore, _mtlPixelFormatCoreCount); + clear(_mtlFormatDescIndicesByMTLVertexFormats, _mtlVertexFormatCount); + + // Build lookup table for MTLPixelFormat specs. + // For most Metal format values, which are small and consecutive, use a simple lookup array. + // For outlier format values, which can be large, use a map. 
+ for (uint32_t fmtIdx = 0; fmtIdx < _mtlPixelFormatCount; fmtIdx++) { + MTLPixelFormat fmt = _mtlPixelFormatDescriptions[fmtIdx].mtlPixelFormat; + if (fmt) { + if (fmt < _mtlPixelFormatCoreCount) { + _mtlFormatDescIndicesByMTLPixelFormatsCore[fmt] = fmtIdx; + } else { + _mtlFormatDescIndicesByMTLPixelFormatsExt[fmt] = fmtIdx; + } + } + } + + // Build lookup table for MTLVertexFormat specs. + for (uint32_t fmtIdx = 0; fmtIdx < _mtlVertexFormatCount; fmtIdx++) { + MTLVertexFormat fmt = _mtlVertexFormatDescriptions[fmtIdx].mtlVertexFormat; + if (fmt) { + _mtlFormatDescIndicesByMTLVertexFormats[fmt] = fmtIdx; + } + } +} + +// If the device supports the feature set, add additional capabilities to a MTLPixelFormat. +void PixelFormats::addMTLPixelFormatCapabilities(id p_device, + MTLFeatureSet p_feature_set, + MTLPixelFormat p_format, + MTLFmtCaps p_caps) { + if ([p_device supportsFeatureSet:p_feature_set]) { + flags::set(getMTLPixelFormatDesc(p_format).mtlFmtCaps, p_caps); + } +} + +// If the device supports the GPU family, add additional capabilities to a MTLPixelFormat. +void PixelFormats::addMTLPixelFormatCapabilities(id p_device, + MTLGPUFamily p_family, + MTLPixelFormat p_format, + MTLFmtCaps p_caps) { + if ([p_device supportsFamily:p_family]) { + flags::set(getMTLPixelFormatDesc(p_format).mtlFmtCaps, p_caps); + } +} + +// Disable capability flags in the Metal pixel format. +void PixelFormats::disableMTLPixelFormatCapabilities(MTLPixelFormat p_format, + MTLFmtCaps p_caps) { + flags::clear(getMTLPixelFormatDesc(p_format).mtlFmtCaps, p_caps); +} + +void PixelFormats::disableAllMTLPixelFormatCapabilities(MTLPixelFormat p_format) { + getMTLPixelFormatDesc(p_format).mtlFmtCaps = kMTLFmtCapsNone; +} + +// If the device supports the feature set, add additional capabilities to a MTLVertexFormat. +void PixelFormats::addMTLVertexFormatCapabilities(id p_device, + MTLFeatureSet p_feature_set, + MTLVertexFormat p_format, + MTLFmtCaps p_caps) { + if ([p_device supportsFeatureSet:p_feature_set]) { + flags::set(getMTLVertexFormatDesc(p_format).mtlFmtCaps, p_caps); + } +} + +void PixelFormats::modifyMTLFormatCapabilities() { + modifyMTLFormatCapabilities(device); +} + +// If the supportsBCTextureCompression query is available, use it. 
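+// When the query is unavailable, BC support is assumed absent and the BC format capabilities are cleared in modifyMTLFormatCapabilities().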
+bool supports_bc_texture_compression(id p_device) { +#if (TARGET_OS_OSX || TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 160400) + if (@available(macOS 11.0, iOS 16.4, *)) { + return p_device.supportsBCTextureCompression; + } +#endif + return false; +} + +#define addFeatSetMTLPixFmtCaps(FEAT_SET, MTL_FMT, CAPS) \ + addMTLPixelFormatCapabilities(p_device, MTLFeatureSet_##FEAT_SET, MTLPixelFormat##MTL_FMT, kMTLFmtCaps##CAPS) + +#define addFeatSetMTLVtxFmtCaps(FEAT_SET, MTL_FMT, CAPS) \ + addMTLVertexFormatCapabilities(p_device, MTLFeatureSet_##FEAT_SET, MTLVertexFormat##MTL_FMT, kMTLFmtCaps##CAPS) + +#define addGPUMTLPixFmtCaps(GPU_FAM, MTL_FMT, CAPS) \ + addMTLPixelFormatCapabilities(p_device, MTLGPUFamily##GPU_FAM, MTLPixelFormat##MTL_FMT, kMTLFmtCaps##CAPS) + +#define disableAllMTLPixFmtCaps(MTL_FMT) \ + disableAllMTLPixelFormatCapabilities(MTLPixelFormat##MTL_FMT) + +#define disableMTLPixFmtCaps(MTL_FMT, CAPS) \ + disableMTLPixelFormatCapabilities(MTLPixelFormat##MTL_FMT, kMTLFmtCaps##CAPS) + +void PixelFormats::modifyMTLFormatCapabilities(id p_device) { + if (!supports_bc_texture_compression(p_device)) { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunguarded-availability" + + disableAllMTLPixFmtCaps(BC1_RGBA); + disableAllMTLPixFmtCaps(BC1_RGBA_sRGB); + disableAllMTLPixFmtCaps(BC2_RGBA); + disableAllMTLPixFmtCaps(BC2_RGBA_sRGB); + disableAllMTLPixFmtCaps(BC3_RGBA); + disableAllMTLPixFmtCaps(BC3_RGBA_sRGB); + disableAllMTLPixFmtCaps(BC4_RUnorm); + disableAllMTLPixFmtCaps(BC4_RSnorm); + disableAllMTLPixFmtCaps(BC5_RGUnorm); + disableAllMTLPixFmtCaps(BC5_RGSnorm); + disableAllMTLPixFmtCaps(BC6H_RGBUfloat); + disableAllMTLPixFmtCaps(BC6H_RGBFloat); + disableAllMTLPixFmtCaps(BC7_RGBAUnorm); + disableAllMTLPixFmtCaps(BC7_RGBAUnorm_sRGB); + +#pragma clang diagnostic pop + } + + if (!p_device.supports32BitMSAA) { + disableMTLPixFmtCaps(R32Uint, MSAA); + disableMTLPixFmtCaps(R32Uint, Resolve); + disableMTLPixFmtCaps(R32Sint, MSAA); + disableMTLPixFmtCaps(R32Sint, Resolve); + disableMTLPixFmtCaps(R32Float, MSAA); + disableMTLPixFmtCaps(R32Float, Resolve); + disableMTLPixFmtCaps(RG32Uint, MSAA); + disableMTLPixFmtCaps(RG32Uint, Resolve); + disableMTLPixFmtCaps(RG32Sint, MSAA); + disableMTLPixFmtCaps(RG32Sint, Resolve); + disableMTLPixFmtCaps(RG32Float, MSAA); + disableMTLPixFmtCaps(RG32Float, Resolve); + disableMTLPixFmtCaps(RGBA32Uint, MSAA); + disableMTLPixFmtCaps(RGBA32Uint, Resolve); + disableMTLPixFmtCaps(RGBA32Sint, MSAA); + disableMTLPixFmtCaps(RGBA32Sint, Resolve); + disableMTLPixFmtCaps(RGBA32Float, MSAA); + disableMTLPixFmtCaps(RGBA32Float, Resolve); + } + + if (!p_device.supports32BitFloatFiltering) { + disableMTLPixFmtCaps(R32Float, Filter); + disableMTLPixFmtCaps(RG32Float, Filter); + disableMTLPixFmtCaps(RGBA32Float, Filter); + } + +#if TARGET_OS_OSX + addGPUMTLPixFmtCaps(Apple1, R32Uint, Atomic); + addGPUMTLPixFmtCaps(Apple1, R32Sint, Atomic); + + if (p_device.isDepth24Stencil8PixelFormatSupported) { + addGPUMTLPixFmtCaps(Apple1, Depth24Unorm_Stencil8, DRFMR); + } + + addFeatSetMTLPixFmtCaps(macOS_GPUFamily1_v2, Depth16Unorm, DRFMR); + + addFeatSetMTLPixFmtCaps(macOS_GPUFamily1_v3, BGR10A2Unorm, RFCMRB); + + addGPUMTLPixFmtCaps(Apple5, R8Unorm_sRGB, All); + + addGPUMTLPixFmtCaps(Apple5, RG8Unorm_sRGB, All); + + addGPUMTLPixFmtCaps(Apple5, B5G6R5Unorm, RFCMRB); + addGPUMTLPixFmtCaps(Apple5, A1BGR5Unorm, RFCMRB); + addGPUMTLPixFmtCaps(Apple5, ABGR4Unorm, RFCMRB); + addGPUMTLPixFmtCaps(Apple5, BGR5A1Unorm, RFCMRB); + + addGPUMTLPixFmtCaps(Apple5, 
RGBA8Unorm_sRGB, All); + addGPUMTLPixFmtCaps(Apple5, BGRA8Unorm_sRGB, All); + + // Blending is actually supported for this format, but format channels cannot be individually write-enabled during blending. + // Disabling blending is the least-intrusive way to handle this in a Godot-friendly way. + addGPUMTLPixFmtCaps(Apple5, RGB9E5Float, All); + disableMTLPixFmtCaps(RGB9E5Float, Blend); + + addGPUMTLPixFmtCaps(Apple5, PVRTC_RGBA_2BPP, RF); + addGPUMTLPixFmtCaps(Apple5, PVRTC_RGBA_2BPP_sRGB, RF); + addGPUMTLPixFmtCaps(Apple5, PVRTC_RGBA_4BPP, RF); + addGPUMTLPixFmtCaps(Apple5, PVRTC_RGBA_4BPP_sRGB, RF); + + addGPUMTLPixFmtCaps(Apple5, ETC2_RGB8, RF); + addGPUMTLPixFmtCaps(Apple5, ETC2_RGB8_sRGB, RF); + addGPUMTLPixFmtCaps(Apple5, ETC2_RGB8A1, RF); + addGPUMTLPixFmtCaps(Apple5, ETC2_RGB8A1_sRGB, RF); + addGPUMTLPixFmtCaps(Apple5, EAC_RGBA8, RF); + addGPUMTLPixFmtCaps(Apple5, EAC_RGBA8_sRGB, RF); + addGPUMTLPixFmtCaps(Apple5, EAC_R11Unorm, RF); + addGPUMTLPixFmtCaps(Apple5, EAC_R11Snorm, RF); + addGPUMTLPixFmtCaps(Apple5, EAC_RG11Unorm, RF); + addGPUMTLPixFmtCaps(Apple5, EAC_RG11Snorm, RF); + + addGPUMTLPixFmtCaps(Apple5, ASTC_4x4_LDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_4x4_sRGB, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_4x4_HDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_5x4_LDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_5x4_sRGB, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_5x4_HDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_5x5_LDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_5x5_sRGB, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_5x5_HDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_6x5_LDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_6x5_sRGB, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_6x5_HDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_6x6_LDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_6x6_sRGB, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_6x6_HDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_8x5_LDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_8x5_sRGB, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_8x5_HDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_8x6_LDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_8x6_sRGB, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_8x6_HDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_8x8_LDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_8x8_sRGB, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_8x8_HDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_10x5_LDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_10x5_sRGB, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_10x5_HDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_10x6_LDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_10x6_sRGB, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_10x6_HDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_10x8_LDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_10x8_sRGB, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_10x8_HDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_10x10_LDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_10x10_sRGB, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_10x10_HDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_12x10_LDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_12x10_sRGB, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_12x10_HDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_12x12_LDR, RF); + addGPUMTLPixFmtCaps(Apple5, ASTC_12x12_sRGB, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_12x12_HDR, RF); + + addGPUMTLPixFmtCaps(Apple5, BGRA10_XR, All); + addGPUMTLPixFmtCaps(Apple5, BGRA10_XR_sRGB, All); + addGPUMTLPixFmtCaps(Apple5, BGR10_XR, All); + addGPUMTLPixFmtCaps(Apple5, BGR10_XR_sRGB, All); + + addFeatSetMTLVtxFmtCaps(macOS_GPUFamily1_v3, UCharNormalized, Vertex); + addFeatSetMTLVtxFmtCaps(macOS_GPUFamily1_v3, CharNormalized, Vertex); + 
addFeatSetMTLVtxFmtCaps(macOS_GPUFamily1_v3, UChar, Vertex); + addFeatSetMTLVtxFmtCaps(macOS_GPUFamily1_v3, Char, Vertex); + addFeatSetMTLVtxFmtCaps(macOS_GPUFamily1_v3, UShortNormalized, Vertex); + addFeatSetMTLVtxFmtCaps(macOS_GPUFamily1_v3, ShortNormalized, Vertex); + addFeatSetMTLVtxFmtCaps(macOS_GPUFamily1_v3, UShort, Vertex); + addFeatSetMTLVtxFmtCaps(macOS_GPUFamily1_v3, Short, Vertex); + addFeatSetMTLVtxFmtCaps(macOS_GPUFamily1_v3, Half, Vertex); + addFeatSetMTLVtxFmtCaps(macOS_GPUFamily1_v3, UChar4Normalized_BGRA, Vertex); +#endif + +#if TARGET_OS_IOS && !TARGET_OS_MACCATALYST + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v3, R8Unorm_sRGB, All); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v1, R8Unorm_sRGB, All); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, R8Snorm, All); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v3, RG8Unorm_sRGB, All); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v1, RG8Unorm_sRGB, All); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, RG8Snorm, All); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily1_v2, R32Uint, RWC); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily1_v2, R32Uint, Atomic); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily1_v2, R32Sint, RWC); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily1_v2, R32Sint, Atomic); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily1_v2, R32Float, RWCMB); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v3, RGBA8Unorm_sRGB, All); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v1, RGBA8Unorm_sRGB, All); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, RGBA8Snorm, All); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v3, BGRA8Unorm_sRGB, All); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v1, BGRA8Unorm_sRGB, All); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v1, RGB10A2Unorm, All); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v1, RGB10A2Uint, RWCM); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v1, RG11B10Float, All); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v1, RGB9E5Float, All); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily1_v2, RG32Uint, RWC); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily1_v2, RG32Sint, RWC); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily1_v2, RG32Float, RWCB); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily1_v2, RGBA32Uint, RWC); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily1_v2, RGBA32Sint, RWC); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily1_v2, RGBA32Float, RWC); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_4x4_LDR, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_4x4_sRGB, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_5x4_LDR, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_5x4_sRGB, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_5x5_LDR, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_5x5_sRGB, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_6x5_LDR, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_6x5_sRGB, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_6x6_LDR, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_6x6_sRGB, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_8x5_LDR, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_8x5_sRGB, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_8x6_LDR, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_8x6_sRGB, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_8x8_LDR, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_8x8_sRGB, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_10x5_LDR, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_10x5_sRGB, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_10x6_LDR, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, 
ASTC_10x6_sRGB, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_10x8_LDR, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_10x8_sRGB, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_10x10_LDR, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_10x10_sRGB, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_12x10_LDR, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_12x10_sRGB, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_12x12_LDR, RF); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily2_v1, ASTC_12x12_sRGB, RF); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v1, Depth32Float, DRMR); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v1, Depth32Float_Stencil8, DRMR); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v1, Stencil8, DRMR); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v2, BGRA10_XR, All); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v2, BGRA10_XR_sRGB, All); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v2, BGR10_XR, All); + addFeatSetMTLPixFmtCaps(iOS_GPUFamily3_v2, BGR10_XR_sRGB, All); + + addFeatSetMTLPixFmtCaps(iOS_GPUFamily1_v4, BGR10A2Unorm, All); + + addGPUMTLPixFmtCaps(Apple6, ASTC_4x4_HDR, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_5x4_HDR, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_5x5_HDR, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_6x5_HDR, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_6x6_HDR, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_8x5_HDR, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_8x6_HDR, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_8x8_HDR, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_10x5_HDR, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_10x6_HDR, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_10x8_HDR, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_10x10_HDR, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_12x10_HDR, RF); + addGPUMTLPixFmtCaps(Apple6, ASTC_12x12_HDR, RF); + + addGPUMTLPixFmtCaps(Apple1, Depth16Unorm, DRFM); + addGPUMTLPixFmtCaps(Apple3, Depth16Unorm, DRFMR); + + // Vertex formats. + addFeatSetMTLVtxFmtCaps(iOS_GPUFamily1_v4, UCharNormalized, Vertex); + addFeatSetMTLVtxFmtCaps(iOS_GPUFamily1_v4, CharNormalized, Vertex); + addFeatSetMTLVtxFmtCaps(iOS_GPUFamily1_v4, UChar, Vertex); + addFeatSetMTLVtxFmtCaps(iOS_GPUFamily1_v4, Char, Vertex); + addFeatSetMTLVtxFmtCaps(iOS_GPUFamily1_v4, UShortNormalized, Vertex); + addFeatSetMTLVtxFmtCaps(iOS_GPUFamily1_v4, ShortNormalized, Vertex); + addFeatSetMTLVtxFmtCaps(iOS_GPUFamily1_v4, UShort, Vertex); + addFeatSetMTLVtxFmtCaps(iOS_GPUFamily1_v4, Short, Vertex); + addFeatSetMTLVtxFmtCaps(iOS_GPUFamily1_v4, Half, Vertex); + addFeatSetMTLVtxFmtCaps(iOS_GPUFamily1_v4, UChar4Normalized_BGRA, Vertex); + +// Disable for iOS simulator last. 
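+// The simulated GPU does not implement every pixel-format capability of real hardware,
+// so anything registered above that the simulator cannot handle is cleared again here.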
+#if TARGET_OS_SIMULATOR
+	if (![p_device supportsFamily:MTLGPUFamilyApple5]) {
+		disableAllMTLPixFmtCaps(R8Unorm_sRGB);
+		disableAllMTLPixFmtCaps(RG8Unorm_sRGB);
+		disableAllMTLPixFmtCaps(B5G6R5Unorm);
+		disableAllMTLPixFmtCaps(A1BGR5Unorm);
+		disableAllMTLPixFmtCaps(ABGR4Unorm);
+		disableAllMTLPixFmtCaps(BGR5A1Unorm);
+
+		disableAllMTLPixFmtCaps(BGRA10_XR);
+		disableAllMTLPixFmtCaps(BGRA10_XR_sRGB);
+		disableAllMTLPixFmtCaps(BGR10_XR);
+		disableAllMTLPixFmtCaps(BGR10_XR_sRGB);
+
+		disableAllMTLPixFmtCaps(GBGR422);
+		disableAllMTLPixFmtCaps(BGRG422);
+
+		disableMTLPixFmtCaps(RGB9E5Float, ColorAtt);
+
+		disableMTLPixFmtCaps(R8Unorm_sRGB, Write);
+		disableMTLPixFmtCaps(RG8Unorm_sRGB, Write);
+		disableMTLPixFmtCaps(RGBA8Unorm_sRGB, Write);
+		disableMTLPixFmtCaps(BGRA8Unorm_sRGB, Write);
+		disableMTLPixFmtCaps(PVRTC_RGBA_2BPP_sRGB, Write);
+		disableMTLPixFmtCaps(PVRTC_RGBA_4BPP_sRGB, Write);
+		disableMTLPixFmtCaps(ETC2_RGB8_sRGB, Write);
+		disableMTLPixFmtCaps(ETC2_RGB8A1_sRGB, Write);
+		disableMTLPixFmtCaps(EAC_RGBA8_sRGB, Write);
+		disableMTLPixFmtCaps(ASTC_4x4_sRGB, Write);
+		disableMTLPixFmtCaps(ASTC_5x4_sRGB, Write);
+		disableMTLPixFmtCaps(ASTC_5x5_sRGB, Write);
+		disableMTLPixFmtCaps(ASTC_6x5_sRGB, Write);
+		disableMTLPixFmtCaps(ASTC_6x6_sRGB, Write);
+		disableMTLPixFmtCaps(ASTC_8x5_sRGB, Write);
+		disableMTLPixFmtCaps(ASTC_8x6_sRGB, Write);
+		disableMTLPixFmtCaps(ASTC_8x8_sRGB, Write);
+		disableMTLPixFmtCaps(ASTC_10x5_sRGB, Write);
+		disableMTLPixFmtCaps(ASTC_10x6_sRGB, Write);
+		disableMTLPixFmtCaps(ASTC_10x8_sRGB, Write);
+		disableMTLPixFmtCaps(ASTC_10x10_sRGB, Write);
+		disableMTLPixFmtCaps(ASTC_12x10_sRGB, Write);
+		disableMTLPixFmtCaps(ASTC_12x12_sRGB, Write);
+	}
+#endif
+#endif
+}
+
+#undef addFeatSetMTLPixFmtCaps
+#undef addGPUMTLPixFmtCaps
+#undef disableMTLPixFmtCaps
+#undef disableAllMTLPixFmtCaps
+#undef addFeatSetMTLVtxFmtCaps
+
+// Populates the DataFormat lookup maps and connects Godot and Metal pixel formats to one another.
+void PixelFormats::buildDFFormatMaps() {
+	// Iterate through the DataFormat descriptions, populate the lookup maps and back pointers,
+	// and validate the Metal formats for the platform and OS.
+	for (uint32_t fmtIdx = 0; fmtIdx < RD::DATA_FORMAT_MAX; fmtIdx++) {
+		DataFormatDesc &dfDesc = _dataFormatDescriptions[fmtIdx];
+		DataFormat dfFmt = dfDesc.dataFormat;
+		if (dfFmt != RD::DATA_FORMAT_MAX) {
+			// Populate the back reference from the Metal formats to the Godot format.
+			// Validate the corresponding Metal formats for the platform, and clear them
+			// in the Godot format if not supported.
+ if (dfDesc.mtlPixelFormat) { + MTLFormatDesc &mtlDesc = getMTLPixelFormatDesc(dfDesc.mtlPixelFormat); + if (mtlDesc.dataFormat == RD::DATA_FORMAT_MAX) { + mtlDesc.dataFormat = dfFmt; + } + if (!mtlDesc.isSupported()) { + dfDesc.mtlPixelFormat = MTLPixelFormatInvalid; + } + } + if (dfDesc.mtlPixelFormatSubstitute) { + MTLFormatDesc &mtlDesc = getMTLPixelFormatDesc(dfDesc.mtlPixelFormatSubstitute); + if (!mtlDesc.isSupported()) { + dfDesc.mtlPixelFormatSubstitute = MTLPixelFormatInvalid; + } + } + if (dfDesc.mtlVertexFormat) { + MTLFormatDesc &mtlDesc = getMTLVertexFormatDesc(dfDesc.mtlVertexFormat); + if (mtlDesc.dataFormat == RD::DATA_FORMAT_MAX) { + mtlDesc.dataFormat = dfFmt; + } + if (!mtlDesc.isSupported()) { + dfDesc.mtlVertexFormat = MTLVertexFormatInvalid; + } + } + if (dfDesc.mtlVertexFormatSubstitute) { + MTLFormatDesc &mtlDesc = getMTLVertexFormatDesc(dfDesc.mtlVertexFormatSubstitute); + if (!mtlDesc.isSupported()) { + dfDesc.mtlVertexFormatSubstitute = MTLVertexFormatInvalid; + } + } + } + } +} diff --git a/drivers/metal/rendering_context_driver_metal.h b/drivers/metal/rendering_context_driver_metal.h new file mode 100644 index 00000000000..0363ab111a2 --- /dev/null +++ b/drivers/metal/rendering_context_driver_metal.h @@ -0,0 +1,206 @@ +/**************************************************************************/ +/* rendering_context_driver_metal.h */ +/**************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/**************************************************************************/ +/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */ +/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/**************************************************************************/ + +#ifndef RENDERING_CONTEXT_DRIVER_METAL_H +#define RENDERING_CONTEXT_DRIVER_METAL_H + +#ifdef METAL_ENABLED + +#import "rendering_device_driver_metal.h" + +#import "servers/rendering/rendering_context_driver.h" + +#import +#import +#import + +@class CAMetalLayer; +@protocol CAMetalDrawable; +class PixelFormats; +class MDResourceCache; + +class API_AVAILABLE(macos(11.0), ios(14.0)) RenderingContextDriverMetal : public RenderingContextDriver { +protected: + id metal_device = nil; + Device device; // There is only one device on Apple Silicon. 
+ +public: + Error initialize() final override; + const Device &device_get(uint32_t p_device_index) const final override; + uint32_t device_get_count() const final override; + bool device_supports_present(uint32_t p_device_index, SurfaceID p_surface) const final override { return true; } + RenderingDeviceDriver *driver_create() final override; + void driver_free(RenderingDeviceDriver *p_driver) final override; + SurfaceID surface_create(const void *p_platform_data) final override; + void surface_set_size(SurfaceID p_surface, uint32_t p_width, uint32_t p_height) final override; + void surface_set_vsync_mode(SurfaceID p_surface, DisplayServer::VSyncMode p_vsync_mode) final override; + DisplayServer::VSyncMode surface_get_vsync_mode(SurfaceID p_surface) const final override; + uint32_t surface_get_width(SurfaceID p_surface) const final override; + uint32_t surface_get_height(SurfaceID p_surface) const final override; + void surface_set_needs_resize(SurfaceID p_surface, bool p_needs_resize) final override; + bool surface_get_needs_resize(SurfaceID p_surface) const final override; + void surface_destroy(SurfaceID p_surface) final override; + bool is_debug_utils_enabled() const final override { return true; } + +#pragma mark - Metal-specific methods + + // Platform-specific data for the Windows embedded in this driver. + struct WindowPlatformData { + CAMetalLayer *__unsafe_unretained layer; + }; + + class Surface { + protected: + id device; + + public: + uint32_t width = 0; + uint32_t height = 0; + DisplayServer::VSyncMode vsync_mode = DisplayServer::VSYNC_ENABLED; + bool needs_resize = false; + + Surface(id p_device) : + device(p_device) {} + virtual ~Surface() = default; + + MTLPixelFormat get_pixel_format() const { return MTLPixelFormatBGRA8Unorm; } + virtual Error resize(uint32_t p_desired_framebuffer_count) = 0; + virtual RDD::FramebufferID acquire_next_frame_buffer() = 0; + virtual void present(MDCommandBuffer *p_cmd_buffer) = 0; + }; + + class SurfaceLayer : public Surface { + CAMetalLayer *__unsafe_unretained layer = nil; + LocalVector frame_buffers; + LocalVector> drawables; + uint32_t rear = -1; + uint32_t front = 0; + uint32_t count = 0; + + public: + SurfaceLayer(CAMetalLayer *p_layer, id p_device) : + Surface(p_device), layer(p_layer) { + layer.allowsNextDrawableTimeout = YES; + layer.framebufferOnly = YES; + layer.opaque = OS::get_singleton()->is_layered_allowed() ? NO : YES; + layer.pixelFormat = get_pixel_format(); + layer.device = p_device; + } + + ~SurfaceLayer() override { + layer = nil; + } + + Error resize(uint32_t p_desired_framebuffer_count) override final { + if (width == 0 || height == 0) { + // Very likely the window is minimized, don't create a swap chain. + return ERR_SKIP; + } + + CGSize drawableSize = CGSizeMake(width, height); + CGSize current = layer.drawableSize; + if (!CGSizeEqualToSize(current, drawableSize)) { + layer.drawableSize = drawableSize; + } + + // Metal supports a maximum of 3 drawables. + p_desired_framebuffer_count = MIN(3U, p_desired_framebuffer_count); + layer.maximumDrawableCount = p_desired_framebuffer_count; + +#if TARGET_OS_OSX + // Display sync is only supported on macOS. 
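+			// CAMetalLayer exposes only a single displaySyncEnabled toggle, so the mailbox
+			// and adaptive modes fall back to standard vsync here.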
+ switch (vsync_mode) { + case DisplayServer::VSYNC_MAILBOX: + case DisplayServer::VSYNC_ADAPTIVE: + case DisplayServer::VSYNC_ENABLED: + layer.displaySyncEnabled = YES; + break; + case DisplayServer::VSYNC_DISABLED: + layer.displaySyncEnabled = NO; + break; + } +#endif + drawables.resize(p_desired_framebuffer_count); + frame_buffers.resize(p_desired_framebuffer_count); + for (uint32_t i = 0; i < p_desired_framebuffer_count; i++) { + // Reserve space for the drawable texture. + frame_buffers[i].textures.resize(1); + } + + return OK; + } + + RDD::FramebufferID acquire_next_frame_buffer() override final { + if (count == frame_buffers.size()) { + return RDD::FramebufferID(); + } + + rear = (rear + 1) % frame_buffers.size(); + count++; + + MDFrameBuffer &frame_buffer = frame_buffers[rear]; + frame_buffer.size = Size2i(width, height); + + id drawable = layer.nextDrawable; + ERR_FAIL_NULL_V_MSG(drawable, RDD::FramebufferID(), "no drawable available"); + drawables[rear] = drawable; + frame_buffer.textures.write[0] = drawable.texture; + + return RDD::FramebufferID(&frame_buffer); + } + + void present(MDCommandBuffer *p_cmd_buffer) override final { + if (count == 0) { + return; + } + + // Release texture and drawable. + frame_buffers[front].textures.write[0] = nil; + id drawable = drawables[front]; + drawables[front] = nil; + + count--; + front = (front + 1) % frame_buffers.size(); + + [p_cmd_buffer->get_command_buffer() presentDrawable:drawable]; + } + }; + + id get_metal_device() const { return metal_device; } + +#pragma mark - Initialization + + RenderingContextDriverMetal(); + ~RenderingContextDriverMetal() override; +}; + +#endif // METAL_ENABLED + +#endif // RENDERING_CONTEXT_DRIVER_METAL_H diff --git a/drivers/metal/rendering_context_driver_metal.mm b/drivers/metal/rendering_context_driver_metal.mm new file mode 100644 index 00000000000..b257d7142ab --- /dev/null +++ b/drivers/metal/rendering_context_driver_metal.mm @@ -0,0 +1,134 @@ +/**************************************************************************/ +/* rendering_context_driver_metal.mm */ +/**************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/**************************************************************************/ +/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */ +/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/**************************************************************************/ + +#import "rendering_context_driver_metal.h" + +@protocol MTLDeviceEx +#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED < 130300 +- (void)setShouldMaximizeConcurrentCompilation:(BOOL)v; +#endif +@end + +RenderingContextDriverMetal::RenderingContextDriverMetal() { +} + +RenderingContextDriverMetal::~RenderingContextDriverMetal() { +} + +Error RenderingContextDriverMetal::initialize() { + metal_device = MTLCreateSystemDefaultDevice(); +#if TARGET_OS_OSX + if (@available(macOS 13.3, *)) { + [id(metal_device) setShouldMaximizeConcurrentCompilation:YES]; + } +#endif + device.type = DEVICE_TYPE_INTEGRATED_GPU; + device.vendor = VENDOR_APPLE; + device.workarounds = Workarounds(); + + MetalDeviceProperties props(metal_device); + int version = (int)props.features.highestFamily - (int)MTLGPUFamilyApple1 + 1; + device.name = vformat("%s (Apple%d)", metal_device.name.UTF8String, version); + + return OK; +} + +const RenderingContextDriver::Device &RenderingContextDriverMetal::device_get(uint32_t p_device_index) const { + DEV_ASSERT(p_device_index < 1); + return device; +} + +uint32_t RenderingContextDriverMetal::device_get_count() const { + return 1; +} + +RenderingDeviceDriver *RenderingContextDriverMetal::driver_create() { + return memnew(RenderingDeviceDriverMetal(this)); +} + +void RenderingContextDriverMetal::driver_free(RenderingDeviceDriver *p_driver) { + memdelete(p_driver); +} + +RenderingContextDriver::SurfaceID RenderingContextDriverMetal::surface_create(const void *p_platform_data) { + const WindowPlatformData *wpd = (const WindowPlatformData *)(p_platform_data); + Surface *surface = memnew(SurfaceLayer(wpd->layer, metal_device)); + + return SurfaceID(surface); +} + +void RenderingContextDriverMetal::surface_set_size(SurfaceID p_surface, uint32_t p_width, uint32_t p_height) { + Surface *surface = (Surface *)(p_surface); + if (surface->width == p_width && surface->height == p_height) { + return; + } + surface->width = p_width; + surface->height = p_height; + surface->needs_resize = true; +} + +void RenderingContextDriverMetal::surface_set_vsync_mode(SurfaceID p_surface, DisplayServer::VSyncMode p_vsync_mode) { + Surface *surface = (Surface *)(p_surface); + if (surface->vsync_mode == p_vsync_mode) { + return; + } + surface->vsync_mode = p_vsync_mode; + surface->needs_resize = true; +} + +DisplayServer::VSyncMode RenderingContextDriverMetal::surface_get_vsync_mode(SurfaceID p_surface) const { + Surface *surface = (Surface *)(p_surface); + return surface->vsync_mode; +} + +uint32_t RenderingContextDriverMetal::surface_get_width(SurfaceID p_surface) const { + Surface *surface = (Surface *)(p_surface); + return surface->width; +} + +uint32_t RenderingContextDriverMetal::surface_get_height(SurfaceID p_surface) const { + Surface *surface = (Surface *)(p_surface); + return surface->height; +} + +void RenderingContextDriverMetal::surface_set_needs_resize(SurfaceID p_surface, bool p_needs_resize) { + Surface *surface = (Surface *)(p_surface); + surface->needs_resize = p_needs_resize; +} + +bool RenderingContextDriverMetal::surface_get_needs_resize(SurfaceID p_surface) const { + Surface *surface = (Surface *)(p_surface); + return 
surface->needs_resize; +} + +void RenderingContextDriverMetal::surface_destroy(SurfaceID p_surface) { + Surface *surface = (Surface *)(p_surface); + memdelete(surface); +} diff --git a/drivers/metal/rendering_device_driver_metal.h b/drivers/metal/rendering_device_driver_metal.h new file mode 100644 index 00000000000..e7efaff8177 --- /dev/null +++ b/drivers/metal/rendering_device_driver_metal.h @@ -0,0 +1,417 @@ +/**************************************************************************/ +/* rendering_device_driver_metal.h */ +/**************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/**************************************************************************/ +/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */ +/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ +/**************************************************************************/ + +#ifndef RENDERING_DEVICE_DRIVER_METAL_H +#define RENDERING_DEVICE_DRIVER_METAL_H + +#import "metal_objects.h" + +#import "servers/rendering/rendering_device_driver.h" + +#import +#import +#import + +#ifdef DEBUG_ENABLED +#ifndef _DEBUG +#define _DEBUG +#endif +#endif + +class RenderingContextDriverMetal; + +class API_AVAILABLE(macos(11.0), ios(14.0)) RenderingDeviceDriverMetal : public RenderingDeviceDriver { + template + using Result = std::variant; + +#pragma mark - Generic + + RenderingContextDriverMetal *context_driver = nullptr; + RenderingContextDriver::Device context_device; + id device = nil; + + uint32_t version_major = 2; + uint32_t version_minor = 0; + MetalDeviceProperties *metal_device_properties = nullptr; + PixelFormats *pixel_formats = nullptr; + std::unique_ptr resource_cache; + + RDD::Capabilities capabilities; + RDD::MultiviewCapabilities multiview_capabilities; + + id archive = nil; + uint32_t archive_count = 0; + + id device_queue = nil; + id device_scope = nil; + + String pipeline_cache_id; + + Error _create_device(); + Error _check_capabilities(); + +public: + Error initialize(uint32_t p_device_index, uint32_t p_frame_count) override final; + +#pragma mark - Memory + +#pragma mark - Buffers + +public: + virtual BufferID buffer_create(uint64_t p_size, BitField p_usage, MemoryAllocationType p_allocation_type) override final; + virtual bool buffer_set_texel_format(BufferID p_buffer, DataFormat p_format) override final; + virtual void buffer_free(BufferID p_buffer) override final; + virtual uint64_t buffer_get_allocation_size(BufferID p_buffer) override final; + virtual uint8_t *buffer_map(BufferID p_buffer) override final; + virtual void buffer_unmap(BufferID p_buffer) override final; + +#pragma mark - Texture + +private: + // Returns true if the texture is a valid linear format. 
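+	// A "linear" texture is CPU-readable and backed by an MTLBuffer, with restrictions
+	// similar to Vulkan's VK_IMAGE_TILING_LINEAR (see the implementation for details).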
+ Result is_valid_linear(TextureFormat const &p_format) const; + void _get_sub_resource(TextureID p_texture, const TextureSubresource &p_subresource, TextureCopyableLayout *r_layout) const; + +public: + virtual TextureID texture_create(const TextureFormat &p_format, const TextureView &p_view) override final; + virtual TextureID texture_create_from_extension(uint64_t p_native_texture, TextureType p_type, DataFormat p_format, uint32_t p_array_layers, bool p_depth_stencil) override final; + virtual TextureID texture_create_shared(TextureID p_original_texture, const TextureView &p_view) override final; + virtual TextureID texture_create_shared_from_slice(TextureID p_original_texture, const TextureView &p_view, TextureSliceType p_slice_type, uint32_t p_layer, uint32_t p_layers, uint32_t p_mipmap, uint32_t p_mipmaps) override final; + virtual void texture_free(TextureID p_texture) override final; + virtual uint64_t texture_get_allocation_size(TextureID p_texture) override final; + virtual void texture_get_copyable_layout(TextureID p_texture, const TextureSubresource &p_subresource, TextureCopyableLayout *r_layout) override final; + virtual uint8_t *texture_map(TextureID p_texture, const TextureSubresource &p_subresource) override final; + virtual void texture_unmap(TextureID p_texture) override final; + virtual BitField texture_get_usages_supported_by_format(DataFormat p_format, bool p_cpu_readable) override final; + virtual bool texture_can_make_shared_with_format(TextureID p_texture, DataFormat p_format, bool &r_raw_reinterpretation) override final; + +#pragma mark - Sampler + +public: + virtual SamplerID sampler_create(const SamplerState &p_state) final override; + virtual void sampler_free(SamplerID p_sampler) final override; + virtual bool sampler_is_format_supported_for_filter(DataFormat p_format, SamplerFilter p_filter) override final; + +#pragma mark - Vertex Array + +private: +public: + virtual VertexFormatID vertex_format_create(VectorView p_vertex_attribs) override final; + virtual void vertex_format_free(VertexFormatID p_vertex_format) override final; + +#pragma mark - Barriers + + virtual void command_pipeline_barrier( + CommandBufferID p_cmd_buffer, + BitField p_src_stages, + BitField p_dst_stages, + VectorView p_memory_barriers, + VectorView p_buffer_barriers, + VectorView p_texture_barriers) override final; + +#pragma mark - Fences + +private: + struct Fence { + dispatch_semaphore_t semaphore; + Fence() : + semaphore(dispatch_semaphore_create(0)) {} + }; + +public: + virtual FenceID fence_create() override final; + virtual Error fence_wait(FenceID p_fence) override final; + virtual void fence_free(FenceID p_fence) override final; + +#pragma mark - Semaphores + +public: + virtual SemaphoreID semaphore_create() override final; + virtual void semaphore_free(SemaphoreID p_semaphore) override final; + +#pragma mark - Commands + // ----- QUEUE FAMILY ----- + + virtual CommandQueueFamilyID command_queue_family_get(BitField p_cmd_queue_family_bits, RenderingContextDriver::SurfaceID p_surface = 0) override final; + + // ----- QUEUE ----- +public: + virtual CommandQueueID command_queue_create(CommandQueueFamilyID p_cmd_queue_family, bool p_identify_as_main_queue = false) override final; + virtual Error command_queue_execute_and_present(CommandQueueID p_cmd_queue, VectorView p_wait_semaphores, VectorView p_cmd_buffers, VectorView p_cmd_semaphores, FenceID p_cmd_fence, VectorView p_swap_chains) override final; + virtual void command_queue_free(CommandQueueID p_cmd_queue) override final; + + 
// ----- POOL ----- + + virtual CommandPoolID command_pool_create(CommandQueueFamilyID p_cmd_queue_family, CommandBufferType p_cmd_buffer_type) override final; + virtual void command_pool_free(CommandPoolID p_cmd_pool) override final; + + // ----- BUFFER ----- + +private: + // Used to maintain references. + Vector command_buffers; + +public: + virtual CommandBufferID command_buffer_create(CommandPoolID p_cmd_pool) override final; + virtual bool command_buffer_begin(CommandBufferID p_cmd_buffer) override final; + virtual bool command_buffer_begin_secondary(CommandBufferID p_cmd_buffer, RenderPassID p_render_pass, uint32_t p_subpass, FramebufferID p_framebuffer) override final; + virtual void command_buffer_end(CommandBufferID p_cmd_buffer) override final; + virtual void command_buffer_execute_secondary(CommandBufferID p_cmd_buffer, VectorView p_secondary_cmd_buffers) override final; + +#pragma mark - Swapchain + +private: + struct SwapChain { + RenderingContextDriver::SurfaceID surface = RenderingContextDriver::SurfaceID(); + RenderPassID render_pass; + RDD::DataFormat data_format = DATA_FORMAT_MAX; + SwapChain() : + render_pass(nullptr) {} + }; + + void _swap_chain_release(SwapChain *p_swap_chain); + void _swap_chain_release_buffers(SwapChain *p_swap_chain); + +public: + virtual SwapChainID swap_chain_create(RenderingContextDriver::SurfaceID p_surface) override final; + virtual Error swap_chain_resize(CommandQueueID p_cmd_queue, SwapChainID p_swap_chain, uint32_t p_desired_framebuffer_count) override final; + virtual FramebufferID swap_chain_acquire_framebuffer(CommandQueueID p_cmd_queue, SwapChainID p_swap_chain, bool &r_resize_required) override final; + virtual RenderPassID swap_chain_get_render_pass(SwapChainID p_swap_chain) override final; + virtual DataFormat swap_chain_get_format(SwapChainID p_swap_chain) override final; + virtual void swap_chain_free(SwapChainID p_swap_chain) override final; + +#pragma mark - Frame Buffer + + virtual FramebufferID framebuffer_create(RenderPassID p_render_pass, VectorView p_attachments, uint32_t p_width, uint32_t p_height) override final; + virtual void framebuffer_free(FramebufferID p_framebuffer) override final; + +#pragma mark - Shader + +private: + // Serialization types need access to private state. 
+ + friend struct ShaderStageData; + friend struct SpecializationConstantData; + friend struct UniformData; + friend struct ShaderBinaryData; + friend struct PushConstantData; + +private: + Error _reflect_spirv16(VectorView p_spirv, ShaderReflection &r_reflection); + +public: + virtual String shader_get_binary_cache_key() override final; + virtual Vector shader_compile_binary_from_spirv(VectorView p_spirv, const String &p_shader_name) override final; + virtual ShaderID shader_create_from_bytecode(const Vector &p_shader_binary, ShaderDescription &r_shader_desc, String &r_name) override final; + virtual void shader_free(ShaderID p_shader) override final; + +#pragma mark - Uniform Set + +public: + virtual UniformSetID uniform_set_create(VectorView p_uniforms, ShaderID p_shader, uint32_t p_set_index) override final; + virtual void uniform_set_free(UniformSetID p_uniform_set) override final; + +#pragma mark - Commands + + virtual void command_uniform_set_prepare_for_use(CommandBufferID p_cmd_buffer, UniformSetID p_uniform_set, ShaderID p_shader, uint32_t p_set_index) override final; + +#pragma mark Transfer + +private: + enum class CopySource { + Buffer, + Texture, + }; + void _copy_texture_buffer(CommandBufferID p_cmd_buffer, + CopySource p_source, + TextureID p_texture, + BufferID p_buffer, + VectorView p_regions); + +public: + virtual void command_clear_buffer(CommandBufferID p_cmd_buffer, BufferID p_buffer, uint64_t p_offset, uint64_t p_size) override final; + virtual void command_copy_buffer(CommandBufferID p_cmd_buffer, BufferID p_src_buffer, BufferID p_dst_buffer, VectorView p_regions) override final; + + virtual void command_copy_texture(CommandBufferID p_cmd_buffer, TextureID p_src_texture, TextureLayout p_src_texture_layout, TextureID p_dst_texture, TextureLayout p_dst_texture_layout, VectorView p_regions) override final; + virtual void command_resolve_texture(CommandBufferID p_cmd_buffer, TextureID p_src_texture, TextureLayout p_src_texture_layout, uint32_t p_src_layer, uint32_t p_src_mipmap, TextureID p_dst_texture, TextureLayout p_dst_texture_layout, uint32_t p_dst_layer, uint32_t p_dst_mipmap) override final; + virtual void command_clear_color_texture(CommandBufferID p_cmd_buffer, TextureID p_texture, TextureLayout p_texture_layout, const Color &p_color, const TextureSubresourceRange &p_subresources) override final; + + virtual void command_copy_buffer_to_texture(CommandBufferID p_cmd_buffer, BufferID p_src_buffer, TextureID p_dst_texture, TextureLayout p_dst_texture_layout, VectorView p_regions) override final; + virtual void command_copy_texture_to_buffer(CommandBufferID p_cmd_buffer, TextureID p_src_texture, TextureLayout p_src_texture_layout, BufferID p_dst_buffer, VectorView p_regions) override final; + +#pragma mark Pipeline + +private: + Result> _create_function(id p_library, NSString *p_name, VectorView &p_specialization_constants); + +public: + virtual void pipeline_free(PipelineID p_pipeline_id) override final; + + // ----- BINDING ----- + + virtual void command_bind_push_constants(CommandBufferID p_cmd_buffer, ShaderID p_shader, uint32_t p_first_index, VectorView p_data) override final; + + // ----- CACHE ----- +private: + String _pipeline_get_cache_path() const; + +public: + virtual bool pipeline_cache_create(const Vector &p_data) override final; + virtual void pipeline_cache_free() override final; + virtual size_t pipeline_cache_query_size() override final; + virtual Vector pipeline_cache_serialize() override final; + +#pragma mark Rendering + + // ----- SUBPASS ----- 
+ + virtual RenderPassID render_pass_create(VectorView p_attachments, VectorView p_subpasses, VectorView p_subpass_dependencies, uint32_t p_view_count) override final; + virtual void render_pass_free(RenderPassID p_render_pass) override final; + + // ----- COMMANDS ----- + +public: + virtual void command_begin_render_pass(CommandBufferID p_cmd_buffer, RenderPassID p_render_pass, FramebufferID p_framebuffer, CommandBufferType p_cmd_buffer_type, const Rect2i &p_rect, VectorView p_clear_values) override final; + virtual void command_end_render_pass(CommandBufferID p_cmd_buffer) override final; + virtual void command_next_render_subpass(CommandBufferID p_cmd_buffer, CommandBufferType p_cmd_buffer_type) override final; + virtual void command_render_set_viewport(CommandBufferID p_cmd_buffer, VectorView p_viewports) override final; + virtual void command_render_set_scissor(CommandBufferID p_cmd_buffer, VectorView p_scissors) override final; + virtual void command_render_clear_attachments(CommandBufferID p_cmd_buffer, VectorView p_attachment_clears, VectorView p_rects) override final; + + // Binding. + virtual void command_bind_render_pipeline(CommandBufferID p_cmd_buffer, PipelineID p_pipeline) override final; + virtual void command_bind_render_uniform_set(CommandBufferID p_cmd_buffer, UniformSetID p_uniform_set, ShaderID p_shader, uint32_t p_set_index) override final; + + // Drawing. + virtual void command_render_draw(CommandBufferID p_cmd_buffer, uint32_t p_vertex_count, uint32_t p_instance_count, uint32_t p_base_vertex, uint32_t p_first_instance) override final; + virtual void command_render_draw_indexed(CommandBufferID p_cmd_buffer, uint32_t p_index_count, uint32_t p_instance_count, uint32_t p_first_index, int32_t p_vertex_offset, uint32_t p_first_instance) override final; + virtual void command_render_draw_indexed_indirect(CommandBufferID p_cmd_buffer, BufferID p_indirect_buffer, uint64_t p_offset, uint32_t p_draw_count, uint32_t p_stride) override final; + virtual void command_render_draw_indexed_indirect_count(CommandBufferID p_cmd_buffer, BufferID p_indirect_buffer, uint64_t p_offset, BufferID p_count_buffer, uint64_t p_count_buffer_offset, uint32_t p_max_draw_count, uint32_t p_stride) override final; + virtual void command_render_draw_indirect(CommandBufferID p_cmd_buffer, BufferID p_indirect_buffer, uint64_t p_offset, uint32_t p_draw_count, uint32_t p_stride) override final; + virtual void command_render_draw_indirect_count(CommandBufferID p_cmd_buffer, BufferID p_indirect_buffer, uint64_t p_offset, BufferID p_count_buffer, uint64_t p_count_buffer_offset, uint32_t p_max_draw_count, uint32_t p_stride) override final; + + // Buffer binding. + virtual void command_render_bind_vertex_buffers(CommandBufferID p_cmd_buffer, uint32_t p_binding_count, const BufferID *p_buffers, const uint64_t *p_offsets) override final; + virtual void command_render_bind_index_buffer(CommandBufferID p_cmd_buffer, BufferID p_buffer, IndexBufferFormat p_format, uint64_t p_offset) override final; + + // Dynamic state. 
+ virtual void command_render_set_blend_constants(CommandBufferID p_cmd_buffer, const Color &p_constants) override final; + virtual void command_render_set_line_width(CommandBufferID p_cmd_buffer, float p_width) override final; + + // ----- PIPELINE ----- + + virtual PipelineID render_pipeline_create( + ShaderID p_shader, + VertexFormatID p_vertex_format, + RenderPrimitive p_render_primitive, + PipelineRasterizationState p_rasterization_state, + PipelineMultisampleState p_multisample_state, + PipelineDepthStencilState p_depth_stencil_state, + PipelineColorBlendState p_blend_state, + VectorView p_color_attachments, + BitField p_dynamic_state, + RenderPassID p_render_pass, + uint32_t p_render_subpass, + VectorView p_specialization_constants) override final; + +#pragma mark - Compute + + // ----- COMMANDS ----- + + // Binding. + virtual void command_bind_compute_pipeline(CommandBufferID p_cmd_buffer, PipelineID p_pipeline) override final; + virtual void command_bind_compute_uniform_set(CommandBufferID p_cmd_buffer, UniformSetID p_uniform_set, ShaderID p_shader, uint32_t p_set_index) override final; + + // Dispatching. + virtual void command_compute_dispatch(CommandBufferID p_cmd_buffer, uint32_t p_x_groups, uint32_t p_y_groups, uint32_t p_z_groups) override final; + virtual void command_compute_dispatch_indirect(CommandBufferID p_cmd_buffer, BufferID p_indirect_buffer, uint64_t p_offset) override final; + + // ----- PIPELINE ----- + + virtual PipelineID compute_pipeline_create(ShaderID p_shader, VectorView p_specialization_constants) override final; + +#pragma mark - Queries + + // ----- TIMESTAMP ----- + + // Basic. + virtual QueryPoolID timestamp_query_pool_create(uint32_t p_query_count) override final; + virtual void timestamp_query_pool_free(QueryPoolID p_pool_id) override final; + virtual void timestamp_query_pool_get_results(QueryPoolID p_pool_id, uint32_t p_query_count, uint64_t *r_results) override final; + virtual uint64_t timestamp_query_result_to_time(uint64_t p_result) override final; + + // Commands. 
+ virtual void command_timestamp_query_pool_reset(CommandBufferID p_cmd_buffer, QueryPoolID p_pool_id, uint32_t p_query_count) override final; + virtual void command_timestamp_write(CommandBufferID p_cmd_buffer, QueryPoolID p_pool_id, uint32_t p_index) override final; + +#pragma mark - Labels + + virtual void command_begin_label(CommandBufferID p_cmd_buffer, const char *p_label_name, const Color &p_color) override final; + virtual void command_end_label(CommandBufferID p_cmd_buffer) override final; + +#pragma mark - Submission + + virtual void begin_segment(uint32_t p_frame_index, uint32_t p_frames_drawn) override final; + virtual void end_segment() override final; + +#pragma mark - Miscellaneous + + virtual void set_object_name(ObjectType p_type, ID p_driver_id, const String &p_name) override final; + virtual uint64_t get_resource_native_handle(DriverResource p_type, ID p_driver_id) override final; + virtual uint64_t get_total_memory_used() override final; + virtual uint64_t limit_get(Limit p_limit) override final; + virtual uint64_t api_trait_get(ApiTrait p_trait) override final; + virtual bool has_feature(Features p_feature) override final; + virtual const MultiviewCapabilities &get_multiview_capabilities() override final; + virtual String get_api_name() const override final { return "Metal"; }; + virtual String get_api_version() const override final; + virtual String get_pipeline_cache_uuid() const override final; + virtual const Capabilities &get_capabilities() const override final; + virtual bool is_composite_alpha_supported(CommandQueueID p_queue) const override final; + + // Metal-specific. + id get_device() const { return device; } + PixelFormats &get_pixel_formats() const { return *pixel_formats; } + MDResourceCache &get_resource_cache() const { return *resource_cache; } + MetalDeviceProperties const &get_device_properties() const { return *metal_device_properties; } + + _FORCE_INLINE_ uint32_t get_metal_buffer_index_for_vertex_attribute_binding(uint32_t p_binding) { + return (metal_device_properties->limits.maxPerStageBufferCount - 1) - p_binding; + } + + size_t get_texel_buffer_alignment_for_format(RDD::DataFormat p_format) const; + size_t get_texel_buffer_alignment_for_format(MTLPixelFormat p_format) const; + + /******************/ + RenderingDeviceDriverMetal(RenderingContextDriverMetal *p_context_driver); + ~RenderingDeviceDriverMetal(); +}; + +#endif // RENDERING_DEVICE_DRIVER_METAL_H diff --git a/drivers/metal/rendering_device_driver_metal.mm b/drivers/metal/rendering_device_driver_metal.mm new file mode 100644 index 00000000000..cde730df9c2 --- /dev/null +++ b/drivers/metal/rendering_device_driver_metal.mm @@ -0,0 +1,3883 @@ +/**************************************************************************/ +/* rendering_device_driver_metal.mm */ +/**************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/**************************************************************************/ +/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */ +/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. 
*/ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/**************************************************************************/ + +/**************************************************************************/ +/* */ +/* Portions of this code were derived from MoltenVK. */ +/* */ +/* Copyright (c) 2015-2023 The Brenwill Workshop Ltd. */ +/* (http://www.brenwill.com) */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or */ +/* implied. See the License for the specific language governing */ +/* permissions and limitations under the License. */ +/**************************************************************************/ + +#import "rendering_device_driver_metal.h" + +#import "pixel_formats.h" +#import "rendering_context_driver_metal.h" + +#import "core/io/compression.h" +#import "core/io/marshalls.h" +#import "core/string/ustring.h" +#import "core/templates/hash_map.h" + +#import +#import +#import +#import + +/*****************/ +/**** GENERIC ****/ +/*****************/ + +// RDD::CompareOperator == VkCompareOp. 
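+// The static_asserts below also confirm that the values match MTLCompareFunction,
+// so compare operators can be passed to Metal without translation.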
+static_assert(ENUM_MEMBERS_EQUAL(RDD::COMPARE_OP_NEVER, MTLCompareFunctionNever)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::COMPARE_OP_LESS, MTLCompareFunctionLess)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::COMPARE_OP_EQUAL, MTLCompareFunctionEqual)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::COMPARE_OP_LESS_OR_EQUAL, MTLCompareFunctionLessEqual)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::COMPARE_OP_GREATER, MTLCompareFunctionGreater)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::COMPARE_OP_NOT_EQUAL, MTLCompareFunctionNotEqual)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::COMPARE_OP_GREATER_OR_EQUAL, MTLCompareFunctionGreaterEqual)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::COMPARE_OP_ALWAYS, MTLCompareFunctionAlways)); + +_FORCE_INLINE_ MTLSize mipmapLevelSizeFromTexture(id p_tex, NSUInteger p_level) { + MTLSize lvlSize; + lvlSize.width = MAX(p_tex.width >> p_level, 1UL); + lvlSize.height = MAX(p_tex.height >> p_level, 1UL); + lvlSize.depth = MAX(p_tex.depth >> p_level, 1UL); + return lvlSize; +} + +_FORCE_INLINE_ MTLSize mipmapLevelSizeFromSize(MTLSize p_size, NSUInteger p_level) { + if (p_level == 0) { + return p_size; + } + + MTLSize lvlSize; + lvlSize.width = MAX(p_size.width >> p_level, 1UL); + lvlSize.height = MAX(p_size.height >> p_level, 1UL); + lvlSize.depth = MAX(p_size.depth >> p_level, 1UL); + return lvlSize; +} + +_FORCE_INLINE_ static bool operator==(MTLSize p_a, MTLSize p_b) { + return p_a.width == p_b.width && p_a.height == p_b.height && p_a.depth == p_b.depth; +} + +/*****************/ +/**** BUFFERS ****/ +/*****************/ + +RDD::BufferID RenderingDeviceDriverMetal::buffer_create(uint64_t p_size, BitField p_usage, MemoryAllocationType p_allocation_type) { + MTLResourceOptions options = MTLResourceHazardTrackingModeTracked; + switch (p_allocation_type) { + case MEMORY_ALLOCATION_TYPE_CPU: + options |= MTLResourceStorageModeShared; + break; + case MEMORY_ALLOCATION_TYPE_GPU: + options |= MTLResourceStorageModePrivate; + break; + } + + id obj = [device newBufferWithLength:p_size options:options]; + ERR_FAIL_NULL_V_MSG(obj, BufferID(), "Can't create buffer of size: " + itos(p_size)); + return rid::make(obj); +} + +bool RenderingDeviceDriverMetal::buffer_set_texel_format(BufferID p_buffer, DataFormat p_format) { + // Nothing to do. + return true; +} + +void RenderingDeviceDriverMetal::buffer_free(BufferID p_buffer) { + rid::release(p_buffer); +} + +uint64_t RenderingDeviceDriverMetal::buffer_get_allocation_size(BufferID p_buffer) { + id obj = rid::get(p_buffer); + return obj.allocatedSize; +} + +uint8_t *RenderingDeviceDriverMetal::buffer_map(BufferID p_buffer) { + id obj = rid::get(p_buffer); + ERR_FAIL_COND_V_MSG(obj.storageMode != MTLStorageModeShared, nullptr, "Unable to map private buffers"); + return (uint8_t *)obj.contents; +} + +void RenderingDeviceDriverMetal::buffer_unmap(BufferID p_buffer) { + // Nothing to do. 
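+	// Mappable buffers use MTLStorageModeShared, which is coherent between CPU and GPU,
+	// so no explicit flush is required when unmapping.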
+}
+
+#pragma mark - Texture
+
+#pragma mark - Format Conversions
+
+static const MTLTextureType TEXTURE_TYPE[RD::TEXTURE_TYPE_MAX] = {
+	MTLTextureType1D,
+	MTLTextureType2D,
+	MTLTextureType3D,
+	MTLTextureTypeCube,
+	MTLTextureType1DArray,
+	MTLTextureType2DArray,
+	MTLTextureTypeCubeArray,
+};
+
+RenderingDeviceDriverMetal::Result<bool> RenderingDeviceDriverMetal::is_valid_linear(TextureFormat const &p_format) const {
+	if (!flags::any(p_format.usage_bits, TEXTURE_USAGE_CPU_READ_BIT)) {
+		return false;
+	}
+
+	PixelFormats &pf = *pixel_formats;
+	MTLFormatType ft = pf.getFormatType(p_format.format);
+
+	// Requesting a linear format, which has further restrictions, similar to Vulkan
+	// when specifying VK_IMAGE_TILING_LINEAR.
+
+	ERR_FAIL_COND_V_MSG(p_format.texture_type != TEXTURE_TYPE_2D, ERR_CANT_CREATE, "Linear (TEXTURE_USAGE_CPU_READ_BIT) textures must be 2D");
+	ERR_FAIL_COND_V_MSG(ft == MTLFormatType::DepthStencil, ERR_CANT_CREATE, "Linear (TEXTURE_USAGE_CPU_READ_BIT) textures must not be a depth/stencil format");
+	ERR_FAIL_COND_V_MSG(ft == MTLFormatType::Compressed, ERR_CANT_CREATE, "Linear (TEXTURE_USAGE_CPU_READ_BIT) textures must not be a compressed format");
+	ERR_FAIL_COND_V_MSG(p_format.mipmaps != 1, ERR_CANT_CREATE, "Linear (TEXTURE_USAGE_CPU_READ_BIT) textures must have 1 mipmap level");
+	ERR_FAIL_COND_V_MSG(p_format.array_layers != 1, ERR_CANT_CREATE, "Linear (TEXTURE_USAGE_CPU_READ_BIT) textures must have 1 array layer");
+	ERR_FAIL_COND_V_MSG(p_format.samples != TEXTURE_SAMPLES_1, ERR_CANT_CREATE, "Linear (TEXTURE_USAGE_CPU_READ_BIT) textures must have 1 sample");
+
+	return true;
+}
+
+RDD::TextureID RenderingDeviceDriverMetal::texture_create(const TextureFormat &p_format, const TextureView &p_view) {
+	MTLTextureDescriptor *desc = [MTLTextureDescriptor new];
+	desc.textureType = TEXTURE_TYPE[p_format.texture_type];
+
+	PixelFormats &formats = *pixel_formats;
+	desc.pixelFormat = formats.getMTLPixelFormat(p_format.format);
+	MTLFmtCaps format_caps = formats.getCapabilities(desc.pixelFormat);
+
+	desc.width = p_format.width;
+	desc.height = p_format.height;
+	desc.depth = p_format.depth;
+	desc.mipmapLevelCount = p_format.mipmaps;
+
+	if (p_format.texture_type == TEXTURE_TYPE_1D_ARRAY ||
+			p_format.texture_type == TEXTURE_TYPE_2D_ARRAY) {
+		desc.arrayLength = p_format.array_layers;
+	} else if (p_format.texture_type == TEXTURE_TYPE_CUBE_ARRAY) {
+		desc.arrayLength = p_format.array_layers / 6;
+	}
+
+	// TODO(sgc): Evaluate lossy texture support (perhaps as a project option?)
+	// https://developer.apple.com/videos/play/tech-talks/10876?time=459
+	// desc.compressionType = MTLTextureCompressionTypeLossy;
+
+	if (p_format.samples > TEXTURE_SAMPLES_1) {
+		SampleCount supported = (*metal_device_properties).find_nearest_supported_sample_count(p_format.samples);
+
+		if (supported > SampleCount1) {
+			bool ok = p_format.texture_type == TEXTURE_TYPE_2D || p_format.texture_type == TEXTURE_TYPE_2D_ARRAY;
+			if (ok) {
+				switch (p_format.texture_type) {
+					case TEXTURE_TYPE_2D:
+						desc.textureType = MTLTextureType2DMultisample;
+						break;
+					case TEXTURE_TYPE_2D_ARRAY:
+						desc.textureType = MTLTextureType2DMultisampleArray;
+						break;
+					default:
+						break;
+				}
+				desc.sampleCount = (NSUInteger)supported;
+				if (p_format.mipmaps > 1) {
+					// For a buffer-backed or multi-sample texture, the value must be 1.
+ WARN_PRINT("mipmaps == 1 for multi-sample textures"); + desc.mipmapLevelCount = 1; + } + } else { + WARN_PRINT("Unsupported multi-sample texture type; disabling multi-sample"); + } + } + } + + static const MTLTextureSwizzle COMPONENT_SWIZZLE[TEXTURE_SWIZZLE_MAX] = { + static_cast(255), // IDENTITY + MTLTextureSwizzleZero, + MTLTextureSwizzleOne, + MTLTextureSwizzleRed, + MTLTextureSwizzleGreen, + MTLTextureSwizzleBlue, + MTLTextureSwizzleAlpha, + }; + + MTLTextureSwizzleChannels swizzle = MTLTextureSwizzleChannelsMake( + p_view.swizzle_r != TEXTURE_SWIZZLE_IDENTITY ? COMPONENT_SWIZZLE[p_view.swizzle_r] : MTLTextureSwizzleRed, + p_view.swizzle_g != TEXTURE_SWIZZLE_IDENTITY ? COMPONENT_SWIZZLE[p_view.swizzle_g] : MTLTextureSwizzleGreen, + p_view.swizzle_b != TEXTURE_SWIZZLE_IDENTITY ? COMPONENT_SWIZZLE[p_view.swizzle_b] : MTLTextureSwizzleBlue, + p_view.swizzle_a != TEXTURE_SWIZZLE_IDENTITY ? COMPONENT_SWIZZLE[p_view.swizzle_a] : MTLTextureSwizzleAlpha); + + // Represents a swizzle operation that is a no-op. + static MTLTextureSwizzleChannels IDENTITY_SWIZZLE = { + .red = MTLTextureSwizzleRed, + .green = MTLTextureSwizzleGreen, + .blue = MTLTextureSwizzleBlue, + .alpha = MTLTextureSwizzleAlpha, + }; + + bool no_swizzle = memcmp(&IDENTITY_SWIZZLE, &swizzle, sizeof(MTLTextureSwizzleChannels)) == 0; + if (!no_swizzle) { + desc.swizzle = swizzle; + } + + // Usage. + MTLResourceOptions options = MTLResourceCPUCacheModeDefaultCache | MTLResourceHazardTrackingModeTracked; + if (p_format.usage_bits & TEXTURE_USAGE_CPU_READ_BIT) { + options |= MTLResourceStorageModeShared; + } else { + options |= MTLResourceStorageModePrivate; + } + desc.resourceOptions = options; + + if (p_format.usage_bits & TEXTURE_USAGE_SAMPLING_BIT) { + desc.usage |= MTLTextureUsageShaderRead; + } + + if (p_format.usage_bits & TEXTURE_USAGE_STORAGE_BIT) { + desc.usage |= MTLTextureUsageShaderWrite; + } + + if (p_format.usage_bits & TEXTURE_USAGE_STORAGE_ATOMIC_BIT) { + desc.usage |= MTLTextureUsageShaderWrite; + } + + bool can_be_attachment = flags::any(format_caps, (kMTLFmtCapsColorAtt | kMTLFmtCapsDSAtt)); + + if (flags::any(p_format.usage_bits, TEXTURE_USAGE_COLOR_ATTACHMENT_BIT | TEXTURE_USAGE_DEPTH_STENCIL_ATTACHMENT_BIT) && + can_be_attachment) { + desc.usage |= MTLTextureUsageRenderTarget; + } + + if (p_format.usage_bits & TEXTURE_USAGE_INPUT_ATTACHMENT_BIT) { + desc.usage |= MTLTextureUsageShaderRead; + } + + if (p_format.usage_bits & TEXTURE_USAGE_VRS_ATTACHMENT_BIT) { + ERR_FAIL_V_MSG(RDD::TextureID(), "unsupported: TEXTURE_USAGE_VRS_ATTACHMENT_BIT"); + } + + if (flags::any(p_format.usage_bits, TEXTURE_USAGE_CAN_UPDATE_BIT | TEXTURE_USAGE_CAN_COPY_TO_BIT) && + can_be_attachment && no_swizzle) { + // Per MoltenVK, can be cleared as a render attachment. + desc.usage |= MTLTextureUsageRenderTarget; + } + if (p_format.usage_bits & TEXTURE_USAGE_CAN_COPY_FROM_BIT) { + // Covered by blits. + } + + // Create texture views with a different component layout. + if (!p_format.shareable_formats.is_empty()) { + desc.usage |= MTLTextureUsagePixelFormatView; + } + + // Allocate memory. + + bool is_linear; + { + Result is_linear_or_err = is_valid_linear(p_format); + ERR_FAIL_COND_V(std::holds_alternative(is_linear_or_err), TextureID()); + is_linear = std::get(is_linear_or_err); + } + + // Check if it is a linear format for atomic operations and therefore needs a buffer, + // as generally Metal does not support atomic operations on textures. 
+ bool needs_buffer = is_linear || (p_format.array_layers == 1 && p_format.mipmaps == 1 && p_format.texture_type == TEXTURE_TYPE_2D && flags::any(p_format.usage_bits, TEXTURE_USAGE_STORAGE_BIT) && (p_format.format == DATA_FORMAT_R32_UINT || p_format.format == DATA_FORMAT_R32_SINT)); + + id obj = nil; + if (needs_buffer) { + // Linear textures are restricted to 2D textures, a single mipmap level and a single array layer. + MTLPixelFormat pixel_format = desc.pixelFormat; + size_t row_alignment = get_texel_buffer_alignment_for_format(p_format.format); + size_t bytes_per_row = formats.getBytesPerRow(pixel_format, p_format.width); + bytes_per_row = round_up_to_alignment(bytes_per_row, row_alignment); + size_t bytes_per_layer = formats.getBytesPerLayer(pixel_format, bytes_per_row, p_format.height); + size_t byte_count = bytes_per_layer * p_format.depth * p_format.array_layers; + + id buf = [device newBufferWithLength:byte_count options:options]; + obj = [buf newTextureWithDescriptor:desc offset:0 bytesPerRow:bytes_per_row]; + } else { + obj = [device newTextureWithDescriptor:desc]; + } + ERR_FAIL_NULL_V_MSG(obj, TextureID(), "Unable to create texture."); + + return rid::make(obj); +} + +RDD::TextureID RenderingDeviceDriverMetal::texture_create_from_extension(uint64_t p_native_texture, TextureType p_type, DataFormat p_format, uint32_t p_array_layers, bool p_depth_stencil) { + ERR_FAIL_V_MSG(RDD::TextureID(), "not implemented"); +} + +RDD::TextureID RenderingDeviceDriverMetal::texture_create_shared(TextureID p_original_texture, const TextureView &p_view) { + id src_texture = rid::get(p_original_texture); + +#if DEV_ENABLED + if (src_texture.sampleCount > 1) { + // TODO(sgc): is it ok to create a shared texture from a multi-sample texture? + WARN_PRINT("Is it safe to create a shared texture from multi-sample texture?"); + } +#endif + + MTLPixelFormat format = pixel_formats->getMTLPixelFormat(p_view.format); + + static const MTLTextureSwizzle component_swizzle[TEXTURE_SWIZZLE_MAX] = { + static_cast(255), // IDENTITY + MTLTextureSwizzleZero, + MTLTextureSwizzleOne, + MTLTextureSwizzleRed, + MTLTextureSwizzleGreen, + MTLTextureSwizzleBlue, + MTLTextureSwizzleAlpha, + }; + +#define SWIZZLE(C, CHAN) (p_view.swizzle_##C != TEXTURE_SWIZZLE_IDENTITY ? 
component_swizzle[p_view.swizzle_##C] : MTLTextureSwizzle##CHAN) + MTLTextureSwizzleChannels swizzle = MTLTextureSwizzleChannelsMake( + SWIZZLE(r, Red), + SWIZZLE(g, Green), + SWIZZLE(b, Blue), + SWIZZLE(a, Alpha)); +#undef SWIZZLE + id obj = [src_texture newTextureViewWithPixelFormat:format + textureType:src_texture.textureType + levels:NSMakeRange(0, src_texture.mipmapLevelCount) + slices:NSMakeRange(0, src_texture.arrayLength) + swizzle:swizzle]; + ERR_FAIL_NULL_V_MSG(obj, TextureID(), "Unable to create shared texture"); + return rid::make(obj); +} + +RDD::TextureID RenderingDeviceDriverMetal::texture_create_shared_from_slice(TextureID p_original_texture, const TextureView &p_view, TextureSliceType p_slice_type, uint32_t p_layer, uint32_t p_layers, uint32_t p_mipmap, uint32_t p_mipmaps) { + id src_texture = rid::get(p_original_texture); + + static const MTLTextureType VIEW_TYPES[] = { + MTLTextureType1D, // MTLTextureType1D + MTLTextureType1D, // MTLTextureType1DArray + MTLTextureType2D, // MTLTextureType2D + MTLTextureType2D, // MTLTextureType2DArray + MTLTextureType2D, // MTLTextureType2DMultisample + MTLTextureType2D, // MTLTextureTypeCube + MTLTextureType2D, // MTLTextureTypeCubeArray + MTLTextureType2D, // MTLTextureType3D + MTLTextureType2D, // MTLTextureType2DMultisampleArray + }; + + MTLTextureType textureType = VIEW_TYPES[src_texture.textureType]; + switch (p_slice_type) { + case TEXTURE_SLICE_2D: { + textureType = MTLTextureType2D; + } break; + case TEXTURE_SLICE_3D: { + textureType = MTLTextureType3D; + } break; + case TEXTURE_SLICE_CUBEMAP: { + textureType = MTLTextureTypeCube; + } break; + case TEXTURE_SLICE_2D_ARRAY: { + textureType = MTLTextureType2DArray; + } break; + case TEXTURE_SLICE_MAX: { + ERR_FAIL_V_MSG(TextureID(), "Invalid texture slice type"); + } break; + } + + MTLPixelFormat format = pixel_formats->getMTLPixelFormat(p_view.format); + + static const MTLTextureSwizzle component_swizzle[TEXTURE_SWIZZLE_MAX] = { + static_cast(255), // IDENTITY + MTLTextureSwizzleZero, + MTLTextureSwizzleOne, + MTLTextureSwizzleRed, + MTLTextureSwizzleGreen, + MTLTextureSwizzleBlue, + MTLTextureSwizzleAlpha, + }; + +#define SWIZZLE(C, CHAN) (p_view.swizzle_##C != TEXTURE_SWIZZLE_IDENTITY ? 
component_swizzle[p_view.swizzle_##C] : MTLTextureSwizzle##CHAN) + MTLTextureSwizzleChannels swizzle = MTLTextureSwizzleChannelsMake( + SWIZZLE(r, Red), + SWIZZLE(g, Green), + SWIZZLE(b, Blue), + SWIZZLE(a, Alpha)); +#undef SWIZZLE + id obj = [src_texture newTextureViewWithPixelFormat:format + textureType:textureType + levels:NSMakeRange(p_mipmap, p_mipmaps) + slices:NSMakeRange(p_layer, p_layers) + swizzle:swizzle]; + ERR_FAIL_NULL_V_MSG(obj, TextureID(), "Unable to create shared texture"); + return rid::make(obj); +} + +void RenderingDeviceDriverMetal::texture_free(TextureID p_texture) { + rid::release(p_texture); +} + +uint64_t RenderingDeviceDriverMetal::texture_get_allocation_size(TextureID p_texture) { + id obj = rid::get(p_texture); + return obj.allocatedSize; +} + +void RenderingDeviceDriverMetal::_get_sub_resource(TextureID p_texture, const TextureSubresource &p_subresource, TextureCopyableLayout *r_layout) const { + id obj = rid::get(p_texture); + + *r_layout = {}; + + PixelFormats &pf = *pixel_formats; + + size_t row_alignment = get_texel_buffer_alignment_for_format(obj.pixelFormat); + size_t offset = 0; + size_t array_layers = obj.arrayLength; + MTLSize size = MTLSizeMake(obj.width, obj.height, obj.depth); + MTLPixelFormat pixel_format = obj.pixelFormat; + + // First skip over the mipmap levels. + for (uint32_t mipLvl = 0; mipLvl < p_subresource.mipmap; mipLvl++) { + MTLSize mip_size = mipmapLevelSizeFromSize(size, mipLvl); + size_t bytes_per_row = pf.getBytesPerRow(pixel_format, mip_size.width); + bytes_per_row = round_up_to_alignment(bytes_per_row, row_alignment); + size_t bytes_per_layer = pf.getBytesPerLayer(pixel_format, bytes_per_row, mip_size.height); + offset += bytes_per_layer * mip_size.depth * array_layers; + } + + // Get current mipmap. 
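+ // Illustrative example (hypothetical numbers): with a 256-byte row alignment and a 100x100
+ // RGBA8 mip (4 bytes per texel), bytes_per_row = align(100 * 4, 256) = 512 and
+ // bytes_per_layer = 512 * 100 = 51200.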
+ MTLSize mip_size = mipmapLevelSizeFromSize(size, p_subresource.mipmap); + size_t bytes_per_row = pf.getBytesPerRow(pixel_format, mip_size.width); + bytes_per_row = round_up_to_alignment(bytes_per_row, row_alignment); + size_t bytes_per_layer = pf.getBytesPerLayer(pixel_format, bytes_per_row, mip_size.height); + r_layout->size = bytes_per_layer * mip_size.depth; + r_layout->offset = offset + (r_layout->size * p_subresource.layer - 1); + r_layout->depth_pitch = bytes_per_layer; + r_layout->row_pitch = bytes_per_row; + r_layout->layer_pitch = r_layout->size * array_layers; +} + +void RenderingDeviceDriverMetal::texture_get_copyable_layout(TextureID p_texture, const TextureSubresource &p_subresource, TextureCopyableLayout *r_layout) { + id obj = rid::get(p_texture); + *r_layout = {}; + + if ((obj.resourceOptions & MTLResourceStorageModePrivate) != 0) { + MTLSize sz = MTLSizeMake(obj.width, obj.height, obj.depth); + + PixelFormats &pf = *pixel_formats; + DataFormat format = pf.getDataFormat(obj.pixelFormat); + if (p_subresource.mipmap > 0) { + r_layout->offset = get_image_format_required_size(format, sz.width, sz.height, sz.depth, p_subresource.mipmap); + } + + sz = mipmapLevelSizeFromSize(sz, p_subresource.mipmap); + + uint32_t bw = 0, bh = 0; + get_compressed_image_format_block_dimensions(format, bw, bh); + uint32_t sbw = 0, sbh = 0; + r_layout->size = get_image_format_required_size(format, sz.width, sz.height, sz.depth, 1, &sbw, &sbh); + r_layout->row_pitch = r_layout->size / ((sbh / bh) * sz.depth); + r_layout->depth_pitch = r_layout->size / sz.depth; + r_layout->layer_pitch = r_layout->size / obj.arrayLength; + } else { + CRASH_NOW_MSG("need to calculate layout for shared texture"); + } +} + +uint8_t *RenderingDeviceDriverMetal::texture_map(TextureID p_texture, const TextureSubresource &p_subresource) { + id obj = rid::get(p_texture); + ERR_FAIL_NULL_V_MSG(obj.buffer, nullptr, "texture is not created from a buffer"); + + TextureCopyableLayout layout; + _get_sub_resource(p_texture, p_subresource, &layout); + return (uint8_t *)(obj.buffer.contents) + layout.offset; + PixelFormats &pf = *pixel_formats; + + size_t row_alignment = get_texel_buffer_alignment_for_format(obj.pixelFormat); + size_t offset = 0; + size_t array_layers = obj.arrayLength; + MTLSize size = MTLSizeMake(obj.width, obj.height, obj.depth); + MTLPixelFormat pixel_format = obj.pixelFormat; + + // First skip over the mipmap levels. + for (uint32_t mipLvl = 0; mipLvl < p_subresource.mipmap; mipLvl++) { + MTLSize mipExtent = mipmapLevelSizeFromSize(size, mipLvl); + size_t bytes_per_row = pf.getBytesPerRow(pixel_format, mipExtent.width); + bytes_per_row = round_up_to_alignment(bytes_per_row, row_alignment); + size_t bytes_per_layer = pf.getBytesPerLayer(pixel_format, bytes_per_row, mipExtent.height); + offset += bytes_per_layer * mipExtent.depth * array_layers; + } + + if (p_subresource.layer > 1) { + // Calculate offset to desired layer. + MTLSize mipExtent = mipmapLevelSizeFromSize(size, p_subresource.mipmap); + size_t bytes_per_row = pf.getBytesPerRow(pixel_format, mipExtent.width); + bytes_per_row = round_up_to_alignment(bytes_per_row, row_alignment); + size_t bytes_per_layer = pf.getBytesPerLayer(pixel_format, bytes_per_row, mipExtent.height); + offset += bytes_per_layer * mipExtent.depth * (p_subresource.layer - 1); + } + + // TODO: Confirm with rendering team that there is no other way Godot may attempt to map a texture with multiple mipmaps or array layers. 
+ + // NOTE: It is not possible to create a buffer-backed texture with mipmaps or array layers, + // as noted in the is_valid_linear function, so the offset calculation SHOULD always be zero. + // Given that, this code should be simplified. + + return (uint8_t *)(obj.buffer.contents) + offset; +} + +void RenderingDeviceDriverMetal::texture_unmap(TextureID p_texture) { + // Nothing to do. +} + +BitField RenderingDeviceDriverMetal::texture_get_usages_supported_by_format(DataFormat p_format, bool p_cpu_readable) { + PixelFormats &pf = *pixel_formats; + if (pf.getMTLPixelFormat(p_format) == MTLPixelFormatInvalid) { + return 0; + } + + MTLFmtCaps caps = pf.getCapabilities(p_format); + + // Everything supported by default makes an all-or-nothing check easier for the caller. + BitField supported = INT64_MAX; + supported.clear_flag(TEXTURE_USAGE_VRS_ATTACHMENT_BIT); // No VRS support for Metal. + + if (!flags::any(caps, kMTLFmtCapsColorAtt)) { + supported.clear_flag(TEXTURE_USAGE_COLOR_ATTACHMENT_BIT); + } + if (!flags::any(caps, kMTLFmtCapsDSAtt)) { + supported.clear_flag(TEXTURE_USAGE_DEPTH_STENCIL_ATTACHMENT_BIT); + } + if (!flags::any(caps, kMTLFmtCapsRead)) { + supported.clear_flag(TEXTURE_USAGE_SAMPLING_BIT); + } + if (!flags::any(caps, kMTLFmtCapsAtomic)) { + supported.clear_flag(TEXTURE_USAGE_STORAGE_ATOMIC_BIT); + } + + return supported; +} + +bool RenderingDeviceDriverMetal::texture_can_make_shared_with_format(TextureID p_texture, DataFormat p_format, bool &r_raw_reinterpretation) { + r_raw_reinterpretation = false; + return true; +} + +#pragma mark - Sampler + +static const MTLCompareFunction COMPARE_OPERATORS[RD::COMPARE_OP_MAX] = { + MTLCompareFunctionNever, + MTLCompareFunctionLess, + MTLCompareFunctionEqual, + MTLCompareFunctionLessEqual, + MTLCompareFunctionGreater, + MTLCompareFunctionNotEqual, + MTLCompareFunctionGreaterEqual, + MTLCompareFunctionAlways, +}; + +static const MTLStencilOperation STENCIL_OPERATIONS[RD::STENCIL_OP_MAX] = { + MTLStencilOperationKeep, + MTLStencilOperationZero, + MTLStencilOperationReplace, + MTLStencilOperationIncrementClamp, + MTLStencilOperationDecrementClamp, + MTLStencilOperationInvert, + MTLStencilOperationIncrementWrap, + MTLStencilOperationDecrementWrap, +}; + +static const MTLBlendFactor BLEND_FACTORS[RD::BLEND_FACTOR_MAX] = { + MTLBlendFactorZero, + MTLBlendFactorOne, + MTLBlendFactorSourceColor, + MTLBlendFactorOneMinusSourceColor, + MTLBlendFactorDestinationColor, + MTLBlendFactorOneMinusDestinationColor, + MTLBlendFactorSourceAlpha, + MTLBlendFactorOneMinusSourceAlpha, + MTLBlendFactorDestinationAlpha, + MTLBlendFactorOneMinusDestinationAlpha, + MTLBlendFactorBlendColor, + MTLBlendFactorOneMinusBlendColor, + MTLBlendFactorBlendAlpha, + MTLBlendFactorOneMinusBlendAlpha, + MTLBlendFactorSourceAlphaSaturated, + MTLBlendFactorSource1Color, + MTLBlendFactorOneMinusSource1Color, + MTLBlendFactorSource1Alpha, + MTLBlendFactorOneMinusSource1Alpha, +}; +static const MTLBlendOperation BLEND_OPERATIONS[RD::BLEND_OP_MAX] = { + MTLBlendOperationAdd, + MTLBlendOperationSubtract, + MTLBlendOperationReverseSubtract, + MTLBlendOperationMin, + MTLBlendOperationMax, +}; + +static const API_AVAILABLE(macos(11.0), ios(14.0)) MTLSamplerAddressMode ADDRESS_MODES[RD::SAMPLER_REPEAT_MODE_MAX] = { + MTLSamplerAddressModeRepeat, + MTLSamplerAddressModeMirrorRepeat, + MTLSamplerAddressModeClampToEdge, + MTLSamplerAddressModeClampToBorderColor, + MTLSamplerAddressModeMirrorClampToEdge, +}; + +static const API_AVAILABLE(macos(11.0), ios(14.0)) MTLSamplerBorderColor 
SAMPLER_BORDER_COLORS[RD::SAMPLER_BORDER_COLOR_MAX] = { + MTLSamplerBorderColorTransparentBlack, + MTLSamplerBorderColorTransparentBlack, + MTLSamplerBorderColorOpaqueBlack, + MTLSamplerBorderColorOpaqueBlack, + MTLSamplerBorderColorOpaqueWhite, + MTLSamplerBorderColorOpaqueWhite, +}; + +RDD::SamplerID RenderingDeviceDriverMetal::sampler_create(const SamplerState &p_state) { + MTLSamplerDescriptor *desc = [MTLSamplerDescriptor new]; + desc.supportArgumentBuffers = YES; + + desc.magFilter = p_state.mag_filter == SAMPLER_FILTER_LINEAR ? MTLSamplerMinMagFilterLinear : MTLSamplerMinMagFilterNearest; + desc.minFilter = p_state.min_filter == SAMPLER_FILTER_LINEAR ? MTLSamplerMinMagFilterLinear : MTLSamplerMinMagFilterNearest; + desc.mipFilter = p_state.mip_filter == SAMPLER_FILTER_LINEAR ? MTLSamplerMipFilterLinear : MTLSamplerMipFilterNearest; + + desc.sAddressMode = ADDRESS_MODES[p_state.repeat_u]; + desc.tAddressMode = ADDRESS_MODES[p_state.repeat_v]; + desc.rAddressMode = ADDRESS_MODES[p_state.repeat_w]; + + if (p_state.use_anisotropy) { + desc.maxAnisotropy = p_state.anisotropy_max; + } + + desc.compareFunction = COMPARE_OPERATORS[p_state.compare_op]; + + desc.lodMinClamp = p_state.min_lod; + desc.lodMaxClamp = p_state.max_lod; + + desc.borderColor = SAMPLER_BORDER_COLORS[p_state.border_color]; + + desc.normalizedCoordinates = !p_state.unnormalized_uvw; + + if (p_state.lod_bias != 0.0) { + WARN_VERBOSE("Metal does not support LOD bias for samplers."); + } + + id obj = [device newSamplerStateWithDescriptor:desc]; + ERR_FAIL_NULL_V_MSG(obj, SamplerID(), "newSamplerStateWithDescriptor failed"); + return rid::make(obj); +} + +void RenderingDeviceDriverMetal::sampler_free(SamplerID p_sampler) { + rid::release(p_sampler); +} + +bool RenderingDeviceDriverMetal::sampler_is_format_supported_for_filter(DataFormat p_format, SamplerFilter p_filter) { + switch (p_filter) { + case SAMPLER_FILTER_NEAREST: + return true; + case SAMPLER_FILTER_LINEAR: { + MTLFmtCaps caps = pixel_formats->getCapabilities(p_format); + return flags::any(caps, kMTLFmtCapsFilter); + } + } +} + +#pragma mark - Vertex Array + +RDD::VertexFormatID RenderingDeviceDriverMetal::vertex_format_create(VectorView p_vertex_attribs) { + MTLVertexDescriptor *desc = MTLVertexDescriptor.vertexDescriptor; + + for (uint32_t i = 0; i < p_vertex_attribs.size(); i++) { + VertexAttribute const &vf = p_vertex_attribs[i]; + + ERR_FAIL_COND_V_MSG(get_format_vertex_size(vf.format) == 0, VertexFormatID(), + "Data format for attachment (" + itos(i) + "), '" + FORMAT_NAMES[vf.format] + "', is not valid for a vertex array."); + + desc.attributes[vf.location].format = pixel_formats->getMTLVertexFormat(vf.format); + desc.attributes[vf.location].offset = vf.offset; + uint32_t idx = get_metal_buffer_index_for_vertex_attribute_binding(i); + desc.attributes[vf.location].bufferIndex = idx; + if (vf.stride == 0) { + desc.layouts[idx].stepFunction = MTLVertexStepFunctionConstant; + desc.layouts[idx].stepRate = 0; + desc.layouts[idx].stride = pixel_formats->getBytesPerBlock(vf.format); + } else { + desc.layouts[idx].stepFunction = vf.frequency == VERTEX_FREQUENCY_VERTEX ? 
MTLVertexStepFunctionPerVertex : MTLVertexStepFunctionPerInstance; + desc.layouts[idx].stepRate = 1; + desc.layouts[idx].stride = vf.stride; + } + } + + return rid::make(desc); +} + +void RenderingDeviceDriverMetal::vertex_format_free(VertexFormatID p_vertex_format) { + rid::release(p_vertex_format); +} + +#pragma mark - Barriers + +void RenderingDeviceDriverMetal::command_pipeline_barrier( + CommandBufferID p_cmd_buffer, + BitField p_src_stages, + BitField p_dst_stages, + VectorView p_memory_barriers, + VectorView p_buffer_barriers, + VectorView p_texture_barriers) { + WARN_PRINT_ONCE("not implemented"); +} + +#pragma mark - Fences + +RDD::FenceID RenderingDeviceDriverMetal::fence_create() { + Fence *fence = memnew(Fence); + return FenceID(fence); +} + +Error RenderingDeviceDriverMetal::fence_wait(FenceID p_fence) { + Fence *fence = (Fence *)(p_fence.id); + + // Wait forever, so this function is infallible. + dispatch_semaphore_wait(fence->semaphore, DISPATCH_TIME_FOREVER); + + return OK; +} + +void RenderingDeviceDriverMetal::fence_free(FenceID p_fence) { + Fence *fence = (Fence *)(p_fence.id); + memdelete(fence); +} + +#pragma mark - Semaphores + +RDD::SemaphoreID RenderingDeviceDriverMetal::semaphore_create() { + // Metal doesn't use semaphores, as their purpose within Godot is to ensure ordering of command buffer execution. + return SemaphoreID(1); +} + +void RenderingDeviceDriverMetal::semaphore_free(SemaphoreID p_semaphore) { +} + +#pragma mark - Queues + +RDD::CommandQueueFamilyID RenderingDeviceDriverMetal::command_queue_family_get(BitField p_cmd_queue_family_bits, RenderingContextDriver::SurfaceID p_surface) { + if (p_cmd_queue_family_bits.has_flag(COMMAND_QUEUE_FAMILY_GRAPHICS_BIT) || (p_surface != 0)) { + return CommandQueueFamilyID(COMMAND_QUEUE_FAMILY_GRAPHICS_BIT); + } else if (p_cmd_queue_family_bits.has_flag(COMMAND_QUEUE_FAMILY_COMPUTE_BIT)) { + return CommandQueueFamilyID(COMMAND_QUEUE_FAMILY_COMPUTE_BIT); + } else if (p_cmd_queue_family_bits.has_flag(COMMAND_QUEUE_FAMILY_TRANSFER_BIT)) { + return CommandQueueFamilyID(COMMAND_QUEUE_FAMILY_TRANSFER_BIT); + } else { + return CommandQueueFamilyID(); + } +} + +RDD::CommandQueueID RenderingDeviceDriverMetal::command_queue_create(CommandQueueFamilyID p_cmd_queue_family, bool p_identify_as_main_queue) { + return CommandQueueID(1); +} + +Error RenderingDeviceDriverMetal::command_queue_execute_and_present(CommandQueueID p_cmd_queue, VectorView, VectorView p_cmd_buffers, VectorView, FenceID p_cmd_fence, VectorView p_swap_chains) { + uint32_t size = p_cmd_buffers.size(); + if (size == 0) { + return OK; + } + + for (uint32_t i = 0; i < size - 1; i++) { + MDCommandBuffer *cmd_buffer = (MDCommandBuffer *)(p_cmd_buffers[i].id); + cmd_buffer->commit(); + } + + // The last command buffer will signal the fence and semaphores. + MDCommandBuffer *cmd_buffer = (MDCommandBuffer *)(p_cmd_buffers[size - 1].id); + Fence *fence = (Fence *)(p_cmd_fence.id); + if (fence != nullptr) { + [cmd_buffer->get_command_buffer() addCompletedHandler:^(id buffer) { + dispatch_semaphore_signal(fence->semaphore); + }]; + } + + for (uint32_t i = 0; i < p_swap_chains.size(); i++) { + SwapChain *swap_chain = (SwapChain *)(p_swap_chains[i].id); + RenderingContextDriverMetal::Surface *metal_surface = (RenderingContextDriverMetal::Surface *)(swap_chain->surface); + metal_surface->present(cmd_buffer); + } + + cmd_buffer->commit(); + + if (p_swap_chains.size() > 0) { + // Used as a signal that we're presenting, so this is the end of a frame. 
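+ // device_scope is expected to be an MTLCaptureScope; ending and re-opening it here makes Metal
+ // GPU captures in Xcode line up with a single presented frame.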
+ [device_scope endScope]; + [device_scope beginScope]; + } + + return OK; +} + +void RenderingDeviceDriverMetal::command_queue_free(CommandQueueID p_cmd_queue) { +} + +#pragma mark - Command Buffers + +// ----- POOL ----- + +RDD::CommandPoolID RenderingDeviceDriverMetal::command_pool_create(CommandQueueFamilyID p_cmd_queue_family, CommandBufferType p_cmd_buffer_type) { + DEV_ASSERT(p_cmd_buffer_type == COMMAND_BUFFER_TYPE_PRIMARY); + return rid::make(device_queue); +} + +void RenderingDeviceDriverMetal::command_pool_free(CommandPoolID p_cmd_pool) { + rid::release(p_cmd_pool); +} + +// ----- BUFFER ----- + +RDD::CommandBufferID RenderingDeviceDriverMetal::command_buffer_create(CommandPoolID p_cmd_pool) { + id queue = rid::get(p_cmd_pool); + MDCommandBuffer *obj = new MDCommandBuffer(queue, this); + command_buffers.push_back(obj); + return CommandBufferID(obj); +} + +bool RenderingDeviceDriverMetal::command_buffer_begin(CommandBufferID p_cmd_buffer) { + MDCommandBuffer *obj = (MDCommandBuffer *)(p_cmd_buffer.id); + obj->begin(); + return true; +} + +bool RenderingDeviceDriverMetal::command_buffer_begin_secondary(CommandBufferID p_cmd_buffer, RenderPassID p_render_pass, uint32_t p_subpass, FramebufferID p_framebuffer) { + ERR_FAIL_V_MSG(false, "not implemented"); +} + +void RenderingDeviceDriverMetal::command_buffer_end(CommandBufferID p_cmd_buffer) { + MDCommandBuffer *obj = (MDCommandBuffer *)(p_cmd_buffer.id); + obj->end(); +} + +void RenderingDeviceDriverMetal::command_buffer_execute_secondary(CommandBufferID p_cmd_buffer, VectorView p_secondary_cmd_buffers) { + ERR_FAIL_MSG("not implemented"); +} + +#pragma mark - Swap Chain + +void RenderingDeviceDriverMetal::_swap_chain_release(SwapChain *p_swap_chain) { + _swap_chain_release_buffers(p_swap_chain); +} + +void RenderingDeviceDriverMetal::_swap_chain_release_buffers(SwapChain *p_swap_chain) { +} + +RDD::SwapChainID RenderingDeviceDriverMetal::swap_chain_create(RenderingContextDriver::SurfaceID p_surface) { + RenderingContextDriverMetal::Surface const *surface = (RenderingContextDriverMetal::Surface *)(p_surface); + + // Create the render pass that will be used to draw to the swap chain's framebuffers. + RDD::Attachment attachment; + attachment.format = pixel_formats->getDataFormat(surface->get_pixel_format()); + attachment.samples = RDD::TEXTURE_SAMPLES_1; + attachment.load_op = RDD::ATTACHMENT_LOAD_OP_CLEAR; + attachment.store_op = RDD::ATTACHMENT_STORE_OP_STORE; + + RDD::Subpass subpass; + RDD::AttachmentReference color_ref; + color_ref.attachment = 0; + color_ref.aspect.set_flag(RDD::TEXTURE_ASPECT_COLOR_BIT); + subpass.color_references.push_back(color_ref); + + RenderPassID render_pass = render_pass_create(attachment, subpass, {}, 1); + ERR_FAIL_COND_V(!render_pass, SwapChainID()); + + // Create the empty swap chain until it is resized. 
+ SwapChain *swap_chain = memnew(SwapChain); + swap_chain->surface = p_surface; + swap_chain->data_format = attachment.format; + swap_chain->render_pass = render_pass; + return SwapChainID(swap_chain); +} + +Error RenderingDeviceDriverMetal::swap_chain_resize(CommandQueueID p_cmd_queue, SwapChainID p_swap_chain, uint32_t p_desired_framebuffer_count) { + DEV_ASSERT(p_cmd_queue.id != 0); + DEV_ASSERT(p_swap_chain.id != 0); + + SwapChain *swap_chain = (SwapChain *)(p_swap_chain.id); + RenderingContextDriverMetal::Surface *surface = (RenderingContextDriverMetal::Surface *)(swap_chain->surface); + surface->resize(p_desired_framebuffer_count); + + // Once everything's been created correctly, indicate the surface no longer needs to be resized. + context_driver->surface_set_needs_resize(swap_chain->surface, false); + + return OK; +} + +RDD::FramebufferID RenderingDeviceDriverMetal::swap_chain_acquire_framebuffer(CommandQueueID p_cmd_queue, SwapChainID p_swap_chain, bool &r_resize_required) { + DEV_ASSERT(p_cmd_queue.id != 0); + DEV_ASSERT(p_swap_chain.id != 0); + + SwapChain *swap_chain = (SwapChain *)(p_swap_chain.id); + if (context_driver->surface_get_needs_resize(swap_chain->surface)) { + r_resize_required = true; + return FramebufferID(); + } + + RenderingContextDriverMetal::Surface *metal_surface = (RenderingContextDriverMetal::Surface *)(swap_chain->surface); + return metal_surface->acquire_next_frame_buffer(); +} + +RDD::RenderPassID RenderingDeviceDriverMetal::swap_chain_get_render_pass(SwapChainID p_swap_chain) { + const SwapChain *swap_chain = (const SwapChain *)(p_swap_chain.id); + return swap_chain->render_pass; +} + +RDD::DataFormat RenderingDeviceDriverMetal::swap_chain_get_format(SwapChainID p_swap_chain) { + const SwapChain *swap_chain = (const SwapChain *)(p_swap_chain.id); + return swap_chain->data_format; +} + +void RenderingDeviceDriverMetal::swap_chain_free(SwapChainID p_swap_chain) { + SwapChain *swap_chain = (SwapChain *)(p_swap_chain.id); + _swap_chain_release(swap_chain); + render_pass_free(swap_chain->render_pass); + memdelete(swap_chain); +} + +#pragma mark - Frame buffer + +RDD::FramebufferID RenderingDeviceDriverMetal::framebuffer_create(RenderPassID p_render_pass, VectorView p_attachments, uint32_t p_width, uint32_t p_height) { + MDRenderPass *pass = (MDRenderPass *)(p_render_pass.id); + + Vector textures; + textures.resize(p_attachments.size()); + + for (uint32_t i = 0; i < p_attachments.size(); i += 1) { + MDAttachment const &a = pass->attachments[i]; + id tex = rid::get(p_attachments[i]); + if (tex == nil) { +#if DEV_ENABLED + WARN_PRINT("Invalid texture for attachment " + itos(i)); +#endif + } + if (a.samples > 1) { + if (tex.sampleCount != a.samples) { +#if DEV_ENABLED + WARN_PRINT("Mismatched sample count for attachment " + itos(i) + "; expected " + itos(a.samples) + ", got " + itos(tex.sampleCount)); +#endif + } + } + textures.write[i] = tex; + } + + MDFrameBuffer *fb = new MDFrameBuffer(textures, Size2i(p_width, p_height)); + return FramebufferID(fb); +} + +void RenderingDeviceDriverMetal::framebuffer_free(FramebufferID p_framebuffer) { + MDFrameBuffer *obj = (MDFrameBuffer *)(p_framebuffer.id); + delete obj; +} + +#pragma mark - Shader + +const uint32_t SHADER_BINARY_VERSION = 1; + +// region Serialization + +class BufWriter; + +template +concept Serializable = requires(T t, BufWriter &p_writer) { + { + t.serialize_size() + } -> std::same_as; + { + t.serialize(p_writer) + } -> std::same_as; +}; + +class BufWriter { + uint8_t *data = nullptr; + uint64_t length 
= 0; // Length of data. + uint64_t pos = 0; + +public: + BufWriter(uint8_t *p_data, uint64_t p_length) : + data(p_data), length(p_length) {} + + template + void write(T const &p_value) { + p_value.serialize(*this); + } + + _FORCE_INLINE_ void write(uint32_t p_value) { + DEV_ASSERT(pos + sizeof(uint32_t) <= length); + pos += encode_uint32(p_value, data + pos); + } + + _FORCE_INLINE_ void write(RD::ShaderStage p_value) { + write((uint32_t)p_value); + } + + _FORCE_INLINE_ void write(bool p_value) { + DEV_ASSERT(pos + sizeof(uint8_t) <= length); + *(data + pos) = p_value ? 1 : 0; + pos += 1; + } + + _FORCE_INLINE_ void write(int p_value) { + write((uint32_t)p_value); + } + + _FORCE_INLINE_ void write(uint64_t p_value) { + DEV_ASSERT(pos + sizeof(uint64_t) <= length); + pos += encode_uint64(p_value, data + pos); + } + + _FORCE_INLINE_ void write(float p_value) { + DEV_ASSERT(pos + sizeof(float) <= length); + pos += encode_float(p_value, data + pos); + } + + _FORCE_INLINE_ void write(double p_value) { + DEV_ASSERT(pos + sizeof(double) <= length); + pos += encode_double(p_value, data + pos); + } + + void write_compressed(CharString const &p_string) { + write(p_string.length()); // Uncompressed size. + + DEV_ASSERT(pos + sizeof(uint32_t) + Compression::get_max_compressed_buffer_size(p_string.length(), Compression::MODE_ZSTD) <= length); + + // Save pointer for compressed size. + uint8_t *dst_size_ptr = data + pos; // Compressed size. + pos += sizeof(uint32_t); + + int dst_size = Compression::compress(data + pos, reinterpret_cast(p_string.ptr()), p_string.length(), Compression::MODE_ZSTD); + encode_uint32(dst_size, dst_size_ptr); + pos += dst_size; + } + + void write(CharString const &p_string) { + write_buffer(reinterpret_cast(p_string.ptr()), p_string.length()); + } + + template + void write(VectorView p_vector) { + write(p_vector.size()); + for (uint32_t i = 0; i < p_vector.size(); i++) { + T const &e = p_vector[i]; + write(e); + } + } + + void write(VectorView p_vector) { + write_buffer(p_vector.ptr(), p_vector.size()); + } + + template + void write(HashMap const &p_map) { + write(p_map.size()); + for (KeyValue const &e : p_map) { + write(e.key); + write(e.value); + } + } + + uint64_t get_pos() const { + return pos; + } + + uint64_t get_length() const { + return length; + } + +private: + void write_buffer(uint8_t const *p_buffer, uint32_t p_length) { + write(p_length); + + DEV_ASSERT(pos + p_length <= length); + memcpy(data + pos, p_buffer, p_length); + pos += p_length; + } +}; + +class BufReader; + +template +concept Deserializable = requires(T t, BufReader &p_reader) { + { + t.serialize_size() + } -> std::same_as; + { + t.deserialize(p_reader) + } -> std::same_as; +}; + +class BufReader { + uint8_t const *data = nullptr; + uint64_t length = 0; + uint64_t pos = 0; + + bool check_length(size_t p_size) { + if (status != Status::OK) + return false; + + if (pos + p_size > length) { + status = Status::SHORT_BUFFER; + return false; + } + return true; + } + +#define CHECK(p_size) \ + if (!check_length(p_size)) \ + return + +public: + enum class Status { + OK, + SHORT_BUFFER, + BAD_COMPRESSION, + }; + + Status status = Status::OK; + + BufReader(uint8_t const *p_data, uint64_t p_length) : + data(p_data), length(p_length) {} + + template + void read(T &p_value) { + p_value.deserialize(*this); + } + + _FORCE_INLINE_ void read(uint32_t &p_val) { + CHECK(sizeof(uint32_t)); + + p_val = decode_uint32(data + pos); + pos += sizeof(uint32_t); + } + + _FORCE_INLINE_ void read(RD::ShaderStage &p_val) { + uint32_t 
val; + read(val); + p_val = (RD::ShaderStage)val; + } + + _FORCE_INLINE_ void read(bool &p_val) { + CHECK(sizeof(uint8_t)); + + p_val = *(data + pos) > 0; + pos += 1; + } + + _FORCE_INLINE_ void read(uint64_t &p_val) { + CHECK(sizeof(uint64_t)); + + p_val = decode_uint64(data + pos); + pos += sizeof(uint64_t); + } + + _FORCE_INLINE_ void read(float &p_val) { + CHECK(sizeof(float)); + + p_val = decode_float(data + pos); + pos += sizeof(float); + } + + _FORCE_INLINE_ void read(double &p_val) { + CHECK(sizeof(double)); + + p_val = decode_double(data + pos); + pos += sizeof(double); + } + + void read(CharString &p_val) { + uint32_t len; + read(len); + CHECK(len); + p_val.resize(len + 1 /* NUL */); + memcpy(p_val.ptrw(), data + pos, len); + p_val.set(len, 0); + pos += len; + } + + void read_compressed(CharString &p_val) { + uint32_t len; + read(len); + uint32_t comp_size; + read(comp_size); + + CHECK(comp_size); + + p_val.resize(len + 1 /* NUL */); + uint32_t bytes = (uint32_t)Compression::decompress(reinterpret_cast(p_val.ptrw()), len, data + pos, comp_size, Compression::MODE_ZSTD); + if (bytes != len) { + status = Status::BAD_COMPRESSION; + return; + } + p_val.set(len, 0); + pos += comp_size; + } + + void read(LocalVector &p_val) { + uint32_t len; + read(len); + CHECK(len); + p_val.resize(len); + memcpy(p_val.ptr(), data + pos, len); + pos += len; + } + + template + void read(LocalVector &p_val) { + uint32_t len; + read(len); + CHECK(len); + p_val.resize(len); + for (uint32_t i = 0; i < len; i++) { + read(p_val[i]); + } + } + + template + void read(HashMap &p_map) { + uint32_t len; + read(len); + CHECK(len); + p_map.reserve(len); + for (uint32_t i = 0; i < len; i++) { + K key; + read(key); + V value; + read(value); + p_map[key] = value; + } + } + +#undef CHECK +}; + +const uint32_t R32UI_ALIGNMENT_CONSTANT_ID = 65535; + +struct ComputeSize { + uint32_t x = 0; + uint32_t y = 0; + uint32_t z = 0; + + size_t serialize_size() const { + return sizeof(uint32_t) * 3; + } + + void serialize(BufWriter &p_writer) const { + p_writer.write(x); + p_writer.write(y); + p_writer.write(z); + } + + void deserialize(BufReader &p_reader) { + p_reader.read(x); + p_reader.read(y); + p_reader.read(z); + } +}; + +struct ShaderStageData { + RD::ShaderStage stage = RD::ShaderStage::SHADER_STAGE_MAX; + CharString entry_point_name; + CharString source; + + size_t serialize_size() const { + int comp_size = Compression::get_max_compressed_buffer_size(source.length(), Compression::MODE_ZSTD); + return sizeof(uint32_t) // Stage. + + sizeof(uint32_t) /* entry_point_name.utf8().length */ + entry_point_name.length() + sizeof(uint32_t) /* uncompressed size */ + sizeof(uint32_t) /* compressed size */ + comp_size; + } + + void serialize(BufWriter &p_writer) const { + p_writer.write((uint32_t)stage); + p_writer.write(entry_point_name); + p_writer.write_compressed(source); + } + + void deserialize(BufReader &p_reader) { + p_reader.read((uint32_t &)stage); + p_reader.read(entry_point_name); + p_reader.read_compressed(source); + } +}; + +struct SpecializationConstantData { + uint32_t constant_id = UINT32_MAX; + RD::PipelineSpecializationConstantType type = RD::PIPELINE_SPECIALIZATION_CONSTANT_TYPE_FLOAT; + ShaderStageUsage stages = ShaderStageUsage::None; + // Specifies the stages the constant is used by Metal. 
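+ // (`stages` records the stages that declare the constant in SPIR-V; `used_stages` is filled in
+ // later with the stages whose generated MSL actually references it.)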
+ ShaderStageUsage used_stages = ShaderStageUsage::None; + uint32_t int_value = UINT32_MAX; + + size_t serialize_size() const { + return sizeof(constant_id) + sizeof(uint32_t) // type + + sizeof(stages) + sizeof(used_stages) // used_stages + + sizeof(int_value); // int_value + } + + void serialize(BufWriter &p_writer) const { + p_writer.write(constant_id); + p_writer.write((uint32_t)type); + p_writer.write(stages); + p_writer.write(used_stages); + p_writer.write(int_value); + } + + void deserialize(BufReader &p_reader) { + p_reader.read(constant_id); + p_reader.read((uint32_t &)type); + p_reader.read((uint32_t &)stages); + p_reader.read((uint32_t &)used_stages); + p_reader.read(int_value); + } +}; + +struct API_AVAILABLE(macos(11.0), ios(14.0)) UniformData { + RD::UniformType type = RD::UniformType::UNIFORM_TYPE_MAX; + uint32_t binding = UINT32_MAX; + bool writable = false; + uint32_t length = UINT32_MAX; + ShaderStageUsage stages = ShaderStageUsage::None; + // Specifies the stages the uniform data is + // used by the Metal shader. + ShaderStageUsage active_stages = ShaderStageUsage::None; + BindingInfoMap bindings; + BindingInfoMap bindings_secondary; + + size_t serialize_size() const { + size_t size = 0; + size += sizeof(uint32_t); // type + size += sizeof(uint32_t); // binding + size += sizeof(uint32_t); // writable + size += sizeof(uint32_t); // length + size += sizeof(uint32_t); // stages + size += sizeof(uint32_t); // active_stages + size += sizeof(uint32_t); // bindings.size() + size += sizeof(uint32_t) * bindings.size(); // Total size of keys. + for (KeyValue const &e : bindings) { + size += e.value.serialize_size(); + } + size += sizeof(uint32_t); // bindings_secondary.size() + size += sizeof(uint32_t) * bindings_secondary.size(); // Total size of keys. 
+ for (KeyValue const &e : bindings_secondary) { + size += e.value.serialize_size(); + } + return size; + } + + void serialize(BufWriter &p_writer) const { + p_writer.write((uint32_t)type); + p_writer.write(binding); + p_writer.write(writable); + p_writer.write(length); + p_writer.write(stages); + p_writer.write(active_stages); + p_writer.write(bindings); + p_writer.write(bindings_secondary); + } + + void deserialize(BufReader &p_reader) { + p_reader.read((uint32_t &)type); + p_reader.read(binding); + p_reader.read(writable); + p_reader.read(length); + p_reader.read((uint32_t &)stages); + p_reader.read((uint32_t &)active_stages); + p_reader.read(bindings); + p_reader.read(bindings_secondary); + } +}; + +struct API_AVAILABLE(macos(11.0), ios(14.0)) UniformSetData { + uint32_t index = UINT32_MAX; + LocalVector uniforms; + + size_t serialize_size() const { + size_t size = 0; + size += sizeof(uint32_t); // index + size += sizeof(uint32_t); // uniforms.size() + for (UniformData const &e : uniforms) { + size += e.serialize_size(); + } + return size; + } + + void serialize(BufWriter &p_writer) const { + p_writer.write(index); + p_writer.write(VectorView(uniforms)); + } + + void deserialize(BufReader &p_reader) { + p_reader.read(index); + p_reader.read(uniforms); + } +}; + +struct PushConstantData { + uint32_t size = UINT32_MAX; + ShaderStageUsage stages = ShaderStageUsage::None; + ShaderStageUsage used_stages = ShaderStageUsage::None; + HashMap msl_binding; + + size_t serialize_size() const { + return sizeof(uint32_t) // size + + sizeof(uint32_t) // stages + + sizeof(uint32_t) // used_stages + + sizeof(uint32_t) // msl_binding.size() + + sizeof(uint32_t) * msl_binding.size() // keys + + sizeof(uint32_t) * msl_binding.size(); // values + } + + void serialize(BufWriter &p_writer) const { + p_writer.write(size); + p_writer.write((uint32_t)stages); + p_writer.write((uint32_t)used_stages); + p_writer.write(msl_binding); + } + + void deserialize(BufReader &p_reader) { + p_reader.read(size); + p_reader.read((uint32_t &)stages); + p_reader.read((uint32_t &)used_stages); + p_reader.read(msl_binding); + } +}; + +struct API_AVAILABLE(macos(11.0), ios(14.0)) ShaderBinaryData { + CharString shader_name; + // The Metal language version specified when compiling SPIR-V to MSL. + // Format is major * 10000 + minor * 100 + patch. 
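+ // For example, MSL 3.1 is stored as 3 * 10000 + 1 * 100 = 30100, which get_msl_version() maps
+ // to MTLLanguageVersion((3 << 0x10) + 1); the patch component is ignored by that mapping.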
+ uint32_t msl_version = UINT32_MAX; + uint32_t vertex_input_mask = UINT32_MAX; + uint32_t fragment_output_mask = UINT32_MAX; + uint32_t spirv_specialization_constants_ids_mask = UINT32_MAX; + uint32_t is_compute = UINT32_MAX; + ComputeSize compute_local_size; + PushConstantData push_constant; + LocalVector stages; + LocalVector constants; + LocalVector uniforms; + + MTLLanguageVersion get_msl_version() const { + uint32_t major = msl_version / 10000; + uint32_t minor = (msl_version / 100) % 100; + return MTLLanguageVersion((major << 0x10) + minor); + } + + size_t serialize_size() const { + size_t size = 0; + size += sizeof(uint32_t) + shader_name.length(); // shader_name + size += sizeof(uint32_t); // msl_version + size += sizeof(uint32_t); // vertex_input_mask + size += sizeof(uint32_t); // fragment_output_mask + size += sizeof(uint32_t); // spirv_specialization_constants_ids_mask + size += sizeof(uint32_t); // is_compute + size += compute_local_size.serialize_size(); // compute_local_size + size += push_constant.serialize_size(); // push_constant + size += sizeof(uint32_t); // stages.size() + for (ShaderStageData const &e : stages) { + size += e.serialize_size(); + } + size += sizeof(uint32_t); // constants.size() + for (SpecializationConstantData const &e : constants) { + size += e.serialize_size(); + } + size += sizeof(uint32_t); // uniforms.size() + for (UniformSetData const &e : uniforms) { + size += e.serialize_size(); + } + return size; + } + + void serialize(BufWriter &p_writer) const { + p_writer.write(shader_name); + p_writer.write(msl_version); + p_writer.write(vertex_input_mask); + p_writer.write(fragment_output_mask); + p_writer.write(spirv_specialization_constants_ids_mask); + p_writer.write(is_compute); + p_writer.write(compute_local_size); + p_writer.write(push_constant); + p_writer.write(VectorView(stages)); + p_writer.write(VectorView(constants)); + p_writer.write(VectorView(uniforms)); + } + + void deserialize(BufReader &p_reader) { + p_reader.read(shader_name); + p_reader.read(msl_version); + p_reader.read(vertex_input_mask); + p_reader.read(fragment_output_mask); + p_reader.read(spirv_specialization_constants_ids_mask); + p_reader.read(is_compute); + p_reader.read(compute_local_size); + p_reader.read(push_constant); + p_reader.read(stages); + p_reader.read(constants); + p_reader.read(uniforms); + } +}; + +// endregion + +String RenderingDeviceDriverMetal::shader_get_binary_cache_key() { + return "Metal-SV" + uitos(SHADER_BINARY_VERSION); +} + +Error RenderingDeviceDriverMetal::_reflect_spirv16(VectorView p_spirv, ShaderReflection &r_reflection) { + using namespace spirv_cross; + using spirv_cross::Resource; + + r_reflection = {}; + + for (uint32_t i = 0; i < p_spirv.size(); i++) { + ShaderStageSPIRVData const &v = p_spirv[i]; + ShaderStage stage = v.shader_stage; + uint32_t const *const ir = reinterpret_cast(v.spirv.ptr()); + size_t word_count = v.spirv.size() / sizeof(uint32_t); + Parser parser(ir, word_count); + try { + parser.parse(); + } catch (CompilerError &e) { + ERR_FAIL_V_MSG(ERR_CANT_CREATE, "Failed to parse IR at stage " + String(SHADER_STAGE_NAMES[stage]) + ": " + e.what()); + } + + ShaderStage stage_flag = (ShaderStage)(1 << p_spirv[i].shader_stage); + + if (p_spirv[i].shader_stage == SHADER_STAGE_COMPUTE) { + r_reflection.is_compute = true; + ERR_FAIL_COND_V_MSG(p_spirv.size() != 1, FAILED, + "Compute shaders can only receive one stage, dedicated to compute."); + } + ERR_FAIL_COND_V_MSG(r_reflection.stages.has_flag(stage_flag), FAILED, + "Stage " + 
String(SHADER_STAGE_NAMES[p_spirv[i].shader_stage]) + " submitted more than once."); + + ParsedIR &pir = parser.get_parsed_ir(); + using BT = SPIRType::BaseType; + + Compiler compiler(std::move(pir)); + + if (r_reflection.is_compute) { + r_reflection.compute_local_size[0] = compiler.get_execution_mode_argument(spv::ExecutionModeLocalSize, 0); + r_reflection.compute_local_size[1] = compiler.get_execution_mode_argument(spv::ExecutionModeLocalSize, 1); + r_reflection.compute_local_size[2] = compiler.get_execution_mode_argument(spv::ExecutionModeLocalSize, 2); + } + + // Parse bindings. + + auto get_decoration = [&compiler](spirv_cross::ID id, spv::Decoration decoration) { + uint32_t res = -1; + if (compiler.has_decoration(id, decoration)) { + res = compiler.get_decoration(id, decoration); + } + return res; + }; + + // Always clearer than a boolean. + enum class Writable { + No, + Maybe, + }; + + // clang-format off + enum { + SPIRV_WORD_SIZE = sizeof(uint32_t), + SPIRV_DATA_ALIGNMENT = 4 * SPIRV_WORD_SIZE, + }; + // clang-format on + + auto process_uniforms = [&r_reflection, &compiler, &get_decoration, stage, stage_flag](SmallVector &resources, Writable writable, std::function uniform_type) { + for (Resource const &res : resources) { + ShaderUniform uniform; + + std::string const &name = compiler.get_name(res.id); + uint32_t set = get_decoration(res.id, spv::DecorationDescriptorSet); + ERR_FAIL_COND_V_MSG(set == (uint32_t)-1, FAILED, "No descriptor set found"); + ERR_FAIL_COND_V_MSG(set >= MAX_UNIFORM_SETS, FAILED, "On shader stage '" + String(SHADER_STAGE_NAMES[stage]) + "', uniform '" + name.c_str() + "' uses a set (" + itos(set) + ") index larger than what is supported (" + itos(MAX_UNIFORM_SETS) + ")."); + + uniform.binding = get_decoration(res.id, spv::DecorationBinding); + ERR_FAIL_COND_V_MSG(uniform.binding == (uint32_t)-1, FAILED, "No binding found"); + + SPIRType const &a_type = compiler.get_type(res.type_id); + uniform.type = uniform_type(a_type); + + // Update length. + switch (a_type.basetype) { + case BT::Struct: { + if (uniform.type == UNIFORM_TYPE_STORAGE_BUFFER) { + // Consistent with spirv_reflect. + uniform.length = 0; + } else { + uniform.length = round_up_to_alignment(compiler.get_declared_struct_size(a_type), SPIRV_DATA_ALIGNMENT); + } + } break; + case BT::Image: + case BT::Sampler: + case BT::SampledImage: { + uniform.length = 1; + for (uint32_t const &a : a_type.array) { + uniform.length *= a; + } + } break; + default: + break; + } + + // Update writable. + if (writable == Writable::Maybe) { + if (a_type.basetype == BT::Struct) { + Bitset flags = compiler.get_buffer_block_flags(res.id); + uniform.writable = !compiler.has_decoration(res.id, spv::DecorationNonWritable) && !flags.get(spv::DecorationNonWritable); + } else if (a_type.basetype == BT::Image) { + if (a_type.image.access == spv::AccessQualifierMax) { + uniform.writable = !compiler.has_decoration(res.id, spv::DecorationNonWritable); + } else { + uniform.writable = a_type.image.access != spv::AccessQualifierReadOnly; + } + } + } + + if (set < (uint32_t)r_reflection.uniform_sets.size()) { + // Check if this already exists. + bool exists = false; + for (uint32_t k = 0; k < r_reflection.uniform_sets[set].size(); k++) { + if (r_reflection.uniform_sets[set][k].binding == uniform.binding) { + // Already exists, verify that it's the same type. 
+ ERR_FAIL_COND_V_MSG(r_reflection.uniform_sets[set][k].type != uniform.type, FAILED, + "On shader stage '" + String(SHADER_STAGE_NAMES[stage]) + "', uniform '" + name.c_str() + "' trying to reuse location for set=" + itos(set) + ", binding=" + itos(uniform.binding) + " with different uniform type."); + + // Also, verify that it's the same size. + ERR_FAIL_COND_V_MSG(r_reflection.uniform_sets[set][k].length != uniform.length, FAILED, + "On shader stage '" + String(SHADER_STAGE_NAMES[stage]) + "', uniform '" + name.c_str() + "' trying to reuse location for set=" + itos(set) + ", binding=" + itos(uniform.binding) + " with different uniform size."); + + // Also, verify that it has the same writability. + ERR_FAIL_COND_V_MSG(r_reflection.uniform_sets[set][k].writable != uniform.writable, FAILED, + "On shader stage '" + String(SHADER_STAGE_NAMES[stage]) + "', uniform '" + name.c_str() + "' trying to reuse location for set=" + itos(set) + ", binding=" + itos(uniform.binding) + " with different writability."); + + // Just append stage mask and continue. + r_reflection.uniform_sets.write[set].write[k].stages.set_flag(stage_flag); + exists = true; + break; + } + } + + if (exists) { + continue; // Merged. + } + } + + uniform.stages.set_flag(stage_flag); + + if (set >= (uint32_t)r_reflection.uniform_sets.size()) { + r_reflection.uniform_sets.resize(set + 1); + } + + r_reflection.uniform_sets.write[set].push_back(uniform); + } + + return OK; + }; + + ShaderResources resources = compiler.get_shader_resources(); + + process_uniforms(resources.uniform_buffers, Writable::No, [](SPIRType const &a_type) { + DEV_ASSERT(a_type.basetype == BT::Struct); + return UNIFORM_TYPE_UNIFORM_BUFFER; + }); + + process_uniforms(resources.storage_buffers, Writable::Maybe, [](SPIRType const &a_type) { + DEV_ASSERT(a_type.basetype == BT::Struct); + return UNIFORM_TYPE_STORAGE_BUFFER; + }); + + process_uniforms(resources.storage_images, Writable::Maybe, [](SPIRType const &a_type) { + DEV_ASSERT(a_type.basetype == BT::Image); + if (a_type.image.dim == spv::DimBuffer) { + return UNIFORM_TYPE_IMAGE_BUFFER; + } else { + return UNIFORM_TYPE_IMAGE; + } + }); + + process_uniforms(resources.sampled_images, Writable::No, [](SPIRType const &a_type) { + DEV_ASSERT(a_type.basetype == BT::SampledImage); + return UNIFORM_TYPE_SAMPLER_WITH_TEXTURE; + }); + + process_uniforms(resources.separate_images, Writable::No, [](SPIRType const &a_type) { + DEV_ASSERT(a_type.basetype == BT::Image); + if (a_type.image.dim == spv::DimBuffer) { + return UNIFORM_TYPE_TEXTURE_BUFFER; + } else { + return UNIFORM_TYPE_TEXTURE; + } + }); + + process_uniforms(resources.separate_samplers, Writable::No, [](SPIRType const &a_type) { + DEV_ASSERT(a_type.basetype == BT::Sampler); + return UNIFORM_TYPE_SAMPLER; + }); + + process_uniforms(resources.subpass_inputs, Writable::No, [](SPIRType const &a_type) { + DEV_ASSERT(a_type.basetype == BT::Image && a_type.image.dim == spv::DimSubpassData); + return UNIFORM_TYPE_INPUT_ATTACHMENT; + }); + + if (!resources.push_constant_buffers.empty()) { + // There can be only one push constant block. 
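+ // (SPIR-V/Vulkan allow at most one push-constant block per stage; the check below additionally
+ // requires its size to match across stages.)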
+ Resource const &res = resources.push_constant_buffers.front(); + + size_t push_constant_size = round_up_to_alignment(compiler.get_declared_struct_size(compiler.get_type(res.base_type_id)), SPIRV_DATA_ALIGNMENT); + ERR_FAIL_COND_V_MSG(r_reflection.push_constant_size && r_reflection.push_constant_size != push_constant_size, FAILED, + "Reflection of SPIR-V shader stage '" + String(SHADER_STAGE_NAMES[p_spirv[i].shader_stage]) + "': Push constant block must be the same across shader stages."); + + r_reflection.push_constant_size = push_constant_size; + r_reflection.push_constant_stages.set_flag(stage_flag); + } + + ERR_FAIL_COND_V_MSG(!resources.atomic_counters.empty(), FAILED, "Atomic counters not supported"); + ERR_FAIL_COND_V_MSG(!resources.acceleration_structures.empty(), FAILED, "Acceleration structures not supported"); + ERR_FAIL_COND_V_MSG(!resources.shader_record_buffers.empty(), FAILED, "Shader record buffers not supported"); + + if (stage == SHADER_STAGE_VERTEX && !resources.stage_inputs.empty()) { + for (Resource const &res : resources.stage_inputs) { + SPIRType a_type = compiler.get_type(res.base_type_id); + uint32_t loc = get_decoration(res.id, spv::DecorationLocation); + if (loc != (uint32_t)-1) { + r_reflection.vertex_input_mask |= 1 << loc; + } + } + } + + if (stage == SHADER_STAGE_FRAGMENT && !resources.stage_outputs.empty()) { + for (Resource const &res : resources.stage_outputs) { + SPIRType a_type = compiler.get_type(res.base_type_id); + uint32_t loc = get_decoration(res.id, spv::DecorationLocation); + uint32_t built_in = spv::BuiltIn(get_decoration(res.id, spv::DecorationBuiltIn)); + if (loc != (uint32_t)-1 && built_in != spv::BuiltInFragDepth) { + r_reflection.fragment_output_mask |= 1 << loc; + } + } + } + + // Specialization constants. 
+ for (SpecializationConstant const &constant : compiler.get_specialization_constants()) { + int32_t existing = -1; + ShaderSpecializationConstant sconst; + SPIRConstant &spc = compiler.get_constant(constant.id); + SPIRType const &spct = compiler.get_type(spc.constant_type); + + sconst.constant_id = constant.constant_id; + sconst.int_value = 0; + + switch (spct.basetype) { + case BT::Boolean: { + sconst.type = PIPELINE_SPECIALIZATION_CONSTANT_TYPE_BOOL; + sconst.bool_value = spc.scalar() != 0; + } break; + case BT::Int: + case BT::UInt: { + sconst.type = PIPELINE_SPECIALIZATION_CONSTANT_TYPE_INT; + sconst.int_value = spc.scalar(); + } break; + case BT::Float: { + sconst.type = PIPELINE_SPECIALIZATION_CONSTANT_TYPE_FLOAT; + sconst.float_value = spc.scalar_f32(); + } break; + default: + ERR_FAIL_V_MSG(FAILED, "Unsupported specialization constant type"); + } + sconst.stages.set_flag(stage_flag); + + for (uint32_t k = 0; k < r_reflection.specialization_constants.size(); k++) { + if (r_reflection.specialization_constants[k].constant_id == sconst.constant_id) { + ERR_FAIL_COND_V_MSG(r_reflection.specialization_constants[k].type != sconst.type, FAILED, "More than one specialization constant used for id (" + itos(sconst.constant_id) + "), but their types differ."); + ERR_FAIL_COND_V_MSG(r_reflection.specialization_constants[k].int_value != sconst.int_value, FAILED, "More than one specialization constant used for id (" + itos(sconst.constant_id) + "), but their default values differ."); + existing = k; + break; + } + } + + if (existing > 0) { + r_reflection.specialization_constants.write[existing].stages.set_flag(stage_flag); + } else { + r_reflection.specialization_constants.push_back(sconst); + } + } + + r_reflection.stages.set_flag(stage_flag); + } + + // Sort all uniform_sets. 
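+ // Presumably sorted by binding so the reflected set layout is deterministic across stages and runs.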
+ for (uint32_t i = 0; i < r_reflection.uniform_sets.size(); i++) { + r_reflection.uniform_sets.write[i].sort(); + } + + return OK; +} + +Vector RenderingDeviceDriverMetal::shader_compile_binary_from_spirv(VectorView p_spirv, const String &p_shader_name) { + using Result = ::Vector; + using namespace spirv_cross; + using spirv_cross::CompilerMSL; + using spirv_cross::Resource; + + ShaderReflection spirv_data; + ERR_FAIL_COND_V(_reflect_spirv16(p_spirv, spirv_data), Result()); + + ShaderBinaryData bin_data{}; + if (!p_shader_name.is_empty()) { + bin_data.shader_name = p_shader_name.utf8(); + } else { + bin_data.shader_name = "unnamed"; + } + + bin_data.vertex_input_mask = spirv_data.vertex_input_mask; + bin_data.fragment_output_mask = spirv_data.fragment_output_mask; + bin_data.compute_local_size = ComputeSize{ + .x = spirv_data.compute_local_size[0], + .y = spirv_data.compute_local_size[1], + .z = spirv_data.compute_local_size[2], + }; + bin_data.is_compute = spirv_data.is_compute; + bin_data.push_constant.size = spirv_data.push_constant_size; + bin_data.push_constant.stages = (ShaderStageUsage)(uint8_t)spirv_data.push_constant_stages; + + for (uint32_t i = 0; i < spirv_data.uniform_sets.size(); i++) { + const ::Vector &spirv_set = spirv_data.uniform_sets[i]; + UniformSetData set{ .index = i }; + for (const ShaderUniform &spirv_uniform : spirv_set) { + UniformData binding{}; + binding.type = spirv_uniform.type; + binding.binding = spirv_uniform.binding; + binding.writable = spirv_uniform.writable; + binding.stages = (ShaderStageUsage)(uint8_t)spirv_uniform.stages; + binding.length = spirv_uniform.length; + set.uniforms.push_back(binding); + } + bin_data.uniforms.push_back(set); + } + + for (const ShaderSpecializationConstant &spirv_sc : spirv_data.specialization_constants) { + SpecializationConstantData spec_constant{}; + spec_constant.type = spirv_sc.type; + spec_constant.constant_id = spirv_sc.constant_id; + spec_constant.int_value = spirv_sc.int_value; + spec_constant.stages = (ShaderStageUsage)(uint8_t)spirv_sc.stages; + bin_data.constants.push_back(spec_constant); + bin_data.spirv_specialization_constants_ids_mask |= (1 << spirv_sc.constant_id); + } + + // Reflection using SPIRV-Cross: + // https://github.com/KhronosGroup/SPIRV-Cross/wiki/Reflection-API-user-guide + + CompilerMSL::Options msl_options{}; + msl_options.set_msl_version(version_major, version_minor); + if (version_major == 3 && version_minor >= 1) { + // TODO(sgc): Restrict to Metal 3.0 for now, until bugs in SPIRV-cross image atomics are resolved. + msl_options.set_msl_version(3, 0); + } + bin_data.msl_version = msl_options.msl_version; +#if TARGET_OS_OSX + msl_options.platform = CompilerMSL::Options::macOS; +#else + msl_options.platform = CompilerMSL::Options::iOS; +#endif + +#if TARGET_OS_IOS + msl_options.ios_use_simdgroup_functions = (*metal_device_properties).features.simdPermute; +#endif + + msl_options.argument_buffers = true; + msl_options.force_active_argument_buffer_resources = true; // Same as MoltenVK when using argument buffers. + // msl_options.pad_argument_buffer_resources = true; // Same as MoltenVK when using argument buffers. + msl_options.texture_buffer_native = true; // Enable texture buffer support. 
+ msl_options.use_framebuffer_fetch_subpasses = false; + msl_options.pad_fragment_output_components = true; + msl_options.r32ui_alignment_constant_id = R32UI_ALIGNMENT_CONSTANT_ID; + msl_options.agx_manual_cube_grad_fixup = true; + + CompilerGLSL::Options options{}; + options.vertex.flip_vert_y = true; +#if DEV_ENABLED + options.emit_line_directives = true; +#endif + + for (uint32_t i = 0; i < p_spirv.size(); i++) { + ShaderStageSPIRVData const &v = p_spirv[i]; + ShaderStage stage = v.shader_stage; + char const *stage_name = SHADER_STAGE_NAMES[stage]; + uint32_t const *const ir = reinterpret_cast(v.spirv.ptr()); + size_t word_count = v.spirv.size() / sizeof(uint32_t); + Parser parser(ir, word_count); + try { + parser.parse(); + } catch (CompilerError &e) { + ERR_FAIL_V_MSG(Result(), "Failed to parse IR at stage " + String(SHADER_STAGE_NAMES[stage]) + ": " + e.what()); + } + + CompilerMSL compiler(std::move(parser.get_parsed_ir())); + compiler.set_msl_options(msl_options); + compiler.set_common_options(options); + + std::unordered_set active = compiler.get_active_interface_variables(); + ShaderResources resources = compiler.get_shader_resources(); + + std::string source = compiler.compile(); + + ERR_FAIL_COND_V_MSG(compiler.get_entry_points_and_stages().size() != 1, Result(), "Expected a single entry point and stage."); + + EntryPoint &entry_point_stage = compiler.get_entry_points_and_stages().front(); + SPIREntryPoint &entry_point = compiler.get_entry_point(entry_point_stage.name, entry_point_stage.execution_model); + + // Process specialization constants. + if (!compiler.get_specialization_constants().empty()) { + for (SpecializationConstant const &constant : compiler.get_specialization_constants()) { + LocalVector::Iterator res = bin_data.constants.begin(); + while (res != bin_data.constants.end()) { + if (res->constant_id == constant.constant_id) { + res->used_stages |= 1 << stage; + break; + } + ++res; + } + if (res == bin_data.constants.end()) { + WARN_PRINT(String(stage_name) + ": unable to find constant_id: " + itos(constant.constant_id)); + } + } + } + + // Process bindings. + + LocalVector &uniform_sets = bin_data.uniforms; + using BT = SPIRType::BaseType; + + // Always clearer than a boolean. + enum class Writable { + No, + Maybe, + }; + + // Returns a std::optional containing the value of the + // decoration, if it exists. 
+ auto get_decoration = [&compiler](spirv_cross::ID id, spv::Decoration decoration) { + uint32_t res = -1; + if (compiler.has_decoration(id, decoration)) { + res = compiler.get_decoration(id, decoration); + } + return res; + }; + + auto descriptor_bindings = [&compiler, &active, &uniform_sets, stage, &get_decoration](SmallVector &resources, Writable writable) { + for (Resource const &res : resources) { + uint32_t dset = get_decoration(res.id, spv::DecorationDescriptorSet); + uint32_t dbin = get_decoration(res.id, spv::DecorationBinding); + UniformData *found = nullptr; + if (dset != (uint32_t)-1 && dbin != (uint32_t)-1 && dset < uniform_sets.size()) { + UniformSetData &set = uniform_sets[dset]; + LocalVector::Iterator pos = set.uniforms.begin(); + while (pos != set.uniforms.end()) { + if (dbin == pos->binding) { + found = &(*pos); + break; + } + ++pos; + } + } + + ERR_FAIL_NULL_V_MSG(found, ERR_CANT_CREATE, "UniformData not found"); + + bool is_active = active.find(res.id) != active.end(); + if (is_active) { + found->active_stages |= 1 << stage; + } + + BindingInfo primary{}; + + SPIRType const &a_type = compiler.get_type(res.type_id); + BT basetype = a_type.basetype; + + switch (basetype) { + case BT::Struct: { + primary.dataType = MTLDataTypePointer; + } break; + + case BT::Image: + case BT::SampledImage: { + primary.dataType = MTLDataTypeTexture; + } break; + + case BT::Sampler: { + primary.dataType = MTLDataTypeSampler; + } break; + + default: { + ERR_FAIL_V_MSG(ERR_CANT_CREATE, "Unexpected BaseType"); + } break; + } + + // Find array length. + if (basetype == BT::Image || basetype == BT::SampledImage) { + primary.arrayLength = 1; + for (uint32_t const &a : a_type.array) { + primary.arrayLength *= a; + } + primary.isMultisampled = a_type.image.ms; + + SPIRType::ImageType const &image = a_type.image; + primary.imageFormat = image.format; + + switch (image.dim) { + case spv::Dim1D: { + if (image.arrayed) { + primary.textureType = MTLTextureType1DArray; + } else { + primary.textureType = MTLTextureType1D; + } + } break; + case spv::DimSubpassData: { + DISPATCH_FALLTHROUGH; + } + case spv::Dim2D: { + if (image.arrayed && image.ms) { + primary.textureType = MTLTextureType2DMultisampleArray; + } else if (image.arrayed) { + primary.textureType = MTLTextureType2DArray; + } else if (image.ms) { + primary.textureType = MTLTextureType2DMultisample; + } else { + primary.textureType = MTLTextureType2D; + } + } break; + case spv::Dim3D: { + primary.textureType = MTLTextureType3D; + } break; + case spv::DimCube: { + if (image.arrayed) { + primary.textureType = MTLTextureTypeCube; + } + } break; + case spv::DimRect: { + } break; + case spv::DimBuffer: { + // VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER + primary.textureType = MTLTextureTypeTextureBuffer; + } break; + case spv::DimMax: { + // Add all enumerations to silence the compiler warning + // and generate future warnings, should a new one be added. + } break; + } + } + + // Update writable. 
+ if (writable == Writable::Maybe) { + if (basetype == BT::Struct) { + Bitset flags = compiler.get_buffer_block_flags(res.id); + if (!flags.get(spv::DecorationNonWritable)) { + if (flags.get(spv::DecorationNonReadable)) { + primary.access = MTLBindingAccessWriteOnly; + } else { + primary.access = MTLBindingAccessReadWrite; + } + } + } else if (basetype == BT::Image) { + switch (a_type.image.access) { + case spv::AccessQualifierWriteOnly: + primary.access = MTLBindingAccessWriteOnly; + break; + case spv::AccessQualifierReadWrite: + primary.access = MTLBindingAccessReadWrite; + break; + case spv::AccessQualifierReadOnly: + break; + case spv::AccessQualifierMax: + DISPATCH_FALLTHROUGH; + default: + if (!compiler.has_decoration(res.id, spv::DecorationNonWritable)) { + if (compiler.has_decoration(res.id, spv::DecorationNonReadable)) { + primary.access = MTLBindingAccessWriteOnly; + } else { + primary.access = MTLBindingAccessReadWrite; + } + } + break; + } + } + } + + switch (primary.access) { + case MTLBindingAccessReadOnly: + primary.usage = MTLResourceUsageRead; + break; + case MTLBindingAccessWriteOnly: + primary.usage = MTLResourceUsageWrite; + break; + case MTLBindingAccessReadWrite: + primary.usage = MTLResourceUsageRead | MTLResourceUsageWrite; + break; + } + + primary.index = compiler.get_automatic_msl_resource_binding(res.id); + + found->bindings[stage] = primary; + + // A sampled image contains two bindings, the primary + // is to the image, and the secondary is to the associated sampler. + if (basetype == BT::SampledImage) { + uint32_t binding = compiler.get_automatic_msl_resource_binding_secondary(res.id); + if (binding != (uint32_t)-1) { + found->bindings_secondary[stage] = BindingInfo{ + .dataType = MTLDataTypeSampler, + .index = binding, + .access = MTLBindingAccessReadOnly, + }; + } + } + + // An image may have a secondary binding if it is used + // for atomic operations. 
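+				// The secondary binding in that case is the buffer SPIRV-Cross uses to back
+				// atomic access to the image, which is why it is exposed as a read-write
+				// pointer rather than a texture.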
+ if (basetype == BT::Image) { + uint32_t binding = compiler.get_automatic_msl_resource_binding_secondary(res.id); + if (binding != (uint32_t)-1) { + found->bindings_secondary[stage] = BindingInfo{ + .dataType = MTLDataTypePointer, + .index = binding, + .access = MTLBindingAccessReadWrite, + }; + } + } + } + return Error::OK; + }; + + if (!resources.uniform_buffers.empty()) { + Error err = descriptor_bindings(resources.uniform_buffers, Writable::No); + ERR_FAIL_COND_V(err != OK, Result()); + } + if (!resources.storage_buffers.empty()) { + Error err = descriptor_bindings(resources.storage_buffers, Writable::Maybe); + ERR_FAIL_COND_V(err != OK, Result()); + } + if (!resources.storage_images.empty()) { + Error err = descriptor_bindings(resources.storage_images, Writable::Maybe); + ERR_FAIL_COND_V(err != OK, Result()); + } + if (!resources.sampled_images.empty()) { + Error err = descriptor_bindings(resources.sampled_images, Writable::No); + ERR_FAIL_COND_V(err != OK, Result()); + } + if (!resources.separate_images.empty()) { + Error err = descriptor_bindings(resources.separate_images, Writable::No); + ERR_FAIL_COND_V(err != OK, Result()); + } + if (!resources.separate_samplers.empty()) { + Error err = descriptor_bindings(resources.separate_samplers, Writable::No); + ERR_FAIL_COND_V(err != OK, Result()); + } + if (!resources.subpass_inputs.empty()) { + Error err = descriptor_bindings(resources.subpass_inputs, Writable::No); + ERR_FAIL_COND_V(err != OK, Result()); + } + + if (!resources.push_constant_buffers.empty()) { + for (Resource const &res : resources.push_constant_buffers) { + uint32_t binding = compiler.get_automatic_msl_resource_binding(res.id); + if (binding != (uint32_t)-1) { + bin_data.push_constant.used_stages |= 1 << stage; + bin_data.push_constant.msl_binding[stage] = binding; + } + } + } + + ERR_FAIL_COND_V_MSG(!resources.atomic_counters.empty(), Result(), "Atomic counters not supported"); + ERR_FAIL_COND_V_MSG(!resources.acceleration_structures.empty(), Result(), "Acceleration structures not supported"); + ERR_FAIL_COND_V_MSG(!resources.shader_record_buffers.empty(), Result(), "Shader record buffers not supported"); + + if (!resources.stage_inputs.empty()) { + for (Resource const &res : resources.stage_inputs) { + uint32_t binding = compiler.get_automatic_msl_resource_binding(res.id); + if (binding != (uint32_t)-1) { + bin_data.vertex_input_mask |= 1 << binding; + } + } + } + + ShaderStageData stage_data; + stage_data.stage = v.shader_stage; + stage_data.entry_point_name = entry_point.name.c_str(); + stage_data.source = source.c_str(); + bin_data.stages.push_back(stage_data); + } + + size_t vec_size = bin_data.serialize_size() + 8; + + ::Vector ret; + ret.resize(vec_size); + BufWriter writer(ret.ptrw(), vec_size); + const uint8_t HEADER[4] = { 'G', 'M', 'S', 'L' }; + writer.write(*(uint32_t *)HEADER); + writer.write(SHADER_BINARY_VERSION); + bin_data.serialize(writer); + ret.resize(writer.get_pos()); + + return ret; +} + +RDD::ShaderID RenderingDeviceDriverMetal::shader_create_from_bytecode(const Vector &p_shader_binary, ShaderDescription &r_shader_desc, String &r_name) { + r_shader_desc = {}; // Driver-agnostic. 
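+	// The binary produced above is laid out as: the 4-byte magic 'G','M','S','L',
+	// a uint32_t SHADER_BINARY_VERSION, then the payload written by
+	// ShaderBinaryData::serialize(). Deserialization below reads it back in the
+	// same order.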
+
+	const uint8_t *binptr = p_shader_binary.ptr();
+	uint32_t binsize = p_shader_binary.size();
+
+	BufReader reader(binptr, binsize);
+	uint8_t header[4];
+	reader.read((uint32_t &)header);
+	ERR_FAIL_COND_V_MSG(memcmp(header, "GMSL", 4) != 0, ShaderID(), "Invalid header");
+	uint32_t version = 0;
+	reader.read(version);
+	ERR_FAIL_COND_V_MSG(version != SHADER_BINARY_VERSION, ShaderID(), "Invalid shader binary version");
+
+	ShaderBinaryData binary_data;
+	binary_data.deserialize(reader);
+	switch (reader.status) {
+		case BufReader::Status::OK:
+			break;
+		case BufReader::Status::BAD_COMPRESSION:
+			ERR_FAIL_V_MSG(ShaderID(), "Invalid compressed data");
+		case BufReader::Status::SHORT_BUFFER:
+			ERR_FAIL_V_MSG(ShaderID(), "Unexpected end of buffer");
+	}
+
+	MTLCompileOptions *options = [MTLCompileOptions new];
+	options.languageVersion = binary_data.get_msl_version();
+	HashMap<ShaderStage, id<MTLLibrary>> libraries;
+	for (ShaderStageData &shader_data : binary_data.stages) {
+		NSString *source = [[NSString alloc] initWithBytesNoCopy:(void *)shader_data.source.ptr()
+														  length:shader_data.source.length()
+														encoding:NSUTF8StringEncoding
+													freeWhenDone:NO];
+		NSError *error = nil;
+		id<MTLLibrary> library = [device newLibraryWithSource:source options:options error:&error];
+		if (error != nil) {
+			print_error(error.localizedDescription.UTF8String);
+			ERR_FAIL_V_MSG(ShaderID(), "failed to compile Metal source");
+		}
+		libraries[shader_data.stage] = library;
+	}
+
+	Vector<UniformSet> uniform_sets;
+	uniform_sets.resize(binary_data.uniforms.size());
+
+	r_shader_desc.uniform_sets.resize(binary_data.uniforms.size());
+
+	// Create sets.
+	for (UniformSetData &uniform_set : binary_data.uniforms) {
+		UniformSet &set = uniform_sets.write[uniform_set.index];
+		set.uniforms.resize(uniform_set.uniforms.size());
+
+		Vector<ShaderUniform> &uset = r_shader_desc.uniform_sets.write[uniform_set.index];
+		uset.resize(uniform_set.uniforms.size());
+
+		for (uint32_t i = 0; i < uniform_set.uniforms.size(); i++) {
+			UniformData &uniform = uniform_set.uniforms[i];
+
+			ShaderUniform su;
+			su.type = uniform.type;
+			su.writable = uniform.writable;
+			su.length = uniform.length;
+			su.binding = uniform.binding;
+			su.stages = uniform.stages;
+			uset.write[i] = su;
+
+			UniformInfo ui;
+			ui.binding = uniform.binding;
+			ui.active_stages = uniform.active_stages;
+			for (KeyValue<ShaderStage, BindingInfo> &kv : uniform.bindings) {
+				ui.bindings.insert(kv.key, kv.value);
+			}
+			for (KeyValue<ShaderStage, BindingInfo> &kv : uniform.bindings_secondary) {
+				ui.bindings_secondary.insert(kv.key, kv.value);
+			}
+			set.uniforms[i] = ui;
+		}
+	}
+	for (UniformSetData &uniform_set : binary_data.uniforms) {
+		UniformSet &set = uniform_sets.write[uniform_set.index];
+
+		// Make encoders.
+		for (ShaderStageData const &stage_data : binary_data.stages) {
+			ShaderStage stage = stage_data.stage;
+			NSMutableArray<MTLArgumentDescriptor *> *descriptors = [NSMutableArray new];
+
+			for (UniformInfo const &uniform : set.uniforms) {
+				BindingInfo const *binding_info = uniform.bindings.getptr(stage);
+				if (binding_info == nullptr)
+					continue;
+
+				[descriptors addObject:binding_info->new_argument_descriptor()];
+				BindingInfo const *secondary_binding_info = uniform.bindings_secondary.getptr(stage);
+				if (secondary_binding_info != nullptr) {
+					[descriptors addObject:secondary_binding_info->new_argument_descriptor()];
+				}
+			}
+
+			if (descriptors.count == 0) {
+				// No bindings.
+				continue;
+			}
+			// Sort by index.
+ [descriptors sortUsingComparator:^NSComparisonResult(MTLArgumentDescriptor *a, MTLArgumentDescriptor *b) { + if (a.index < b.index) { + return NSOrderedAscending; + } else if (a.index > b.index) { + return NSOrderedDescending; + } else { + return NSOrderedSame; + } + }]; + + id enc = [device newArgumentEncoderWithArguments:descriptors]; + set.encoders[stage] = enc; + set.offsets[stage] = set.buffer_size; + set.buffer_size += enc.encodedLength; + } + } + + r_shader_desc.specialization_constants.resize(binary_data.constants.size()); + for (uint32_t i = 0; i < binary_data.constants.size(); i++) { + SpecializationConstantData &c = binary_data.constants[i]; + + ShaderSpecializationConstant sc; + sc.type = c.type; + sc.constant_id = c.constant_id; + sc.int_value = c.int_value; + sc.stages = c.stages; + r_shader_desc.specialization_constants.write[i] = sc; + } + + MDShader *shader = nullptr; + if (binary_data.is_compute) { + MDComputeShader *cs = new MDComputeShader(binary_data.shader_name, uniform_sets, libraries[ShaderStage::SHADER_STAGE_COMPUTE]); + + uint32_t *binding = binary_data.push_constant.msl_binding.getptr(SHADER_STAGE_COMPUTE); + if (binding) { + cs->push_constants.size = binary_data.push_constant.size; + cs->push_constants.binding = *binding; + } + + cs->local = MTLSizeMake(binary_data.compute_local_size.x, binary_data.compute_local_size.y, binary_data.compute_local_size.z); +#if DEV_ENABLED + cs->kernel_source = binary_data.stages[0].source; +#endif + shader = cs; + } else { + MDRenderShader *rs = new MDRenderShader(binary_data.shader_name, uniform_sets, libraries[ShaderStage::SHADER_STAGE_VERTEX], libraries[ShaderStage::SHADER_STAGE_FRAGMENT]); + + uint32_t *vert_binding = binary_data.push_constant.msl_binding.getptr(SHADER_STAGE_VERTEX); + if (vert_binding) { + rs->push_constants.vert.size = binary_data.push_constant.size; + rs->push_constants.vert.binding = *vert_binding; + } + uint32_t *frag_binding = binary_data.push_constant.msl_binding.getptr(SHADER_STAGE_FRAGMENT); + if (frag_binding) { + rs->push_constants.frag.size = binary_data.push_constant.size; + rs->push_constants.frag.binding = *frag_binding; + } + +#if DEV_ENABLED + for (ShaderStageData &stage_data : binary_data.stages) { + if (stage_data.stage == ShaderStage::SHADER_STAGE_VERTEX) { + rs->vert_source = stage_data.source; + } else if (stage_data.stage == ShaderStage::SHADER_STAGE_FRAGMENT) { + rs->frag_source = stage_data.source; + } + } +#endif + shader = rs; + } + + r_shader_desc.vertex_input_mask = binary_data.vertex_input_mask; + r_shader_desc.fragment_output_mask = binary_data.fragment_output_mask; + r_shader_desc.is_compute = binary_data.is_compute; + r_shader_desc.compute_local_size[0] = binary_data.compute_local_size.x; + r_shader_desc.compute_local_size[1] = binary_data.compute_local_size.y; + r_shader_desc.compute_local_size[2] = binary_data.compute_local_size.z; + r_shader_desc.push_constant_size = binary_data.push_constant.size; + + return ShaderID(shader); +} + +void RenderingDeviceDriverMetal::shader_free(ShaderID p_shader) { + MDShader *obj = (MDShader *)p_shader.id; + delete obj; +} + +/*********************/ +/**** UNIFORM SET ****/ +/*********************/ + +RDD::UniformSetID RenderingDeviceDriverMetal::uniform_set_create(VectorView p_uniforms, ShaderID p_shader, uint32_t p_set_index) { + MDUniformSet *set = new MDUniformSet(); + Vector bound_uniforms; + bound_uniforms.resize(p_uniforms.size()); + for (uint32_t i = 0; i < p_uniforms.size(); i += 1) { + bound_uniforms.write[i] = p_uniforms[i]; + 
} + set->uniforms = bound_uniforms; + set->index = p_set_index; + + return UniformSetID(set); +} + +void RenderingDeviceDriverMetal::uniform_set_free(UniformSetID p_uniform_set) { + MDUniformSet *obj = (MDUniformSet *)p_uniform_set.id; + delete obj; +} + +void RenderingDeviceDriverMetal::command_uniform_set_prepare_for_use(CommandBufferID p_cmd_buffer, UniformSetID p_uniform_set, ShaderID p_shader, uint32_t p_set_index) { +} + +#pragma mark - Transfer + +void RenderingDeviceDriverMetal::command_clear_buffer(CommandBufferID p_cmd_buffer, BufferID p_buffer, uint64_t p_offset, uint64_t p_size) { + MDCommandBuffer *cmd = (MDCommandBuffer *)(p_cmd_buffer.id); + id buffer = rid::get(p_buffer); + + id blit = cmd->blit_command_encoder(); + [blit fillBuffer:buffer + range:NSMakeRange(p_offset, p_size) + value:0]; +} + +void RenderingDeviceDriverMetal::command_copy_buffer(CommandBufferID p_cmd_buffer, BufferID p_src_buffer, BufferID p_dst_buffer, VectorView p_regions) { + MDCommandBuffer *cmd = (MDCommandBuffer *)(p_cmd_buffer.id); + id src = rid::get(p_src_buffer); + id dst = rid::get(p_dst_buffer); + + id blit = cmd->blit_command_encoder(); + + for (uint32_t i = 0; i < p_regions.size(); i++) { + BufferCopyRegion region = p_regions[i]; + [blit copyFromBuffer:src + sourceOffset:region.src_offset + toBuffer:dst + destinationOffset:region.dst_offset + size:region.size]; + } +} + +MTLSize MTLSizeFromVector3i(Vector3i p_size) { + return MTLSizeMake(p_size.x, p_size.y, p_size.z); +} + +MTLOrigin MTLOriginFromVector3i(Vector3i p_origin) { + return MTLOriginMake(p_origin.x, p_origin.y, p_origin.z); +} + +// Clamps the size so that the sum of the origin and size do not exceed the maximum size. +static inline MTLSize clampMTLSize(MTLSize p_size, MTLOrigin p_origin, MTLSize p_max_size) { + MTLSize clamped; + clamped.width = MIN(p_size.width, p_max_size.width - p_origin.x); + clamped.height = MIN(p_size.height, p_max_size.height - p_origin.y); + clamped.depth = MIN(p_size.depth, p_max_size.depth - p_origin.z); + return clamped; +} + +void RenderingDeviceDriverMetal::command_copy_texture(CommandBufferID p_cmd_buffer, TextureID p_src_texture, TextureLayout p_src_texture_layout, TextureID p_dst_texture, TextureLayout p_dst_texture_layout, VectorView p_regions) { + MDCommandBuffer *cmd = (MDCommandBuffer *)(p_cmd_buffer.id); + id src = rid::get(p_src_texture); + id dst = rid::get(p_dst_texture); + + id blit = cmd->blit_command_encoder(); + PixelFormats &pf = *pixel_formats; + + MTLPixelFormat src_fmt = src.pixelFormat; + bool src_is_compressed = pf.getFormatType(src_fmt) == MTLFormatType::Compressed; + MTLPixelFormat dst_fmt = dst.pixelFormat; + bool dst_is_compressed = pf.getFormatType(dst_fmt) == MTLFormatType::Compressed; + + // Validate copy. + if (src.sampleCount != dst.sampleCount || pf.getBytesPerBlock(src_fmt) != pf.getBytesPerBlock(dst_fmt)) { + ERR_FAIL_MSG("Cannot copy between incompatible pixel formats, such as formats of different pixel sizes, or between images with different sample counts."); + } + + // If source and destination have different formats and at least one is compressed, a temporary buffer is required. + bool need_tmp_buffer = (src_fmt != dst_fmt) && (src_is_compressed || dst_is_compressed); + if (need_tmp_buffer) { + ERR_FAIL_MSG("not implemented: copy with intermediate buffer"); + } + + if (src_fmt != dst_fmt) { + // Map the source pixel format to the dst through a texture view on the source texture. 
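+		// The formats were validated above to have identical bytes per block and
+		// sample counts, so the source can be reinterpreted through a texture view
+		// in the destination format for the copy.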
+ src = [src newTextureViewWithPixelFormat:dst_fmt]; + } + + for (uint32_t i = 0; i < p_regions.size(); i++) { + TextureCopyRegion region = p_regions[i]; + + MTLSize extent = MTLSizeFromVector3i(region.size); + + // If copies can be performed using direct texture-texture copying, do so. + uint32_t src_level = region.src_subresources.mipmap; + uint32_t src_base_layer = region.src_subresources.base_layer; + MTLSize src_extent = mipmapLevelSizeFromTexture(src, src_level); + uint32_t dst_level = region.dst_subresources.mipmap; + uint32_t dst_base_layer = region.dst_subresources.base_layer; + MTLSize dst_extent = mipmapLevelSizeFromTexture(dst, dst_level); + + // All layers may be copied at once, if the extent completely covers both images. + if (src_extent == extent && dst_extent == extent) { + [blit copyFromTexture:src + sourceSlice:src_base_layer + sourceLevel:src_level + toTexture:dst + destinationSlice:dst_base_layer + destinationLevel:dst_level + sliceCount:region.src_subresources.layer_count + levelCount:1]; + } else { + MTLOrigin src_origin = MTLOriginFromVector3i(region.src_offset); + MTLSize src_size = clampMTLSize(extent, src_origin, src_extent); + uint32_t layer_count = 0; + if ((src.textureType == MTLTextureType3D) != (dst.textureType == MTLTextureType3D)) { + // In the case, the number of layers to copy is in extent.depth. Use that value, + // then clamp the depth, so we don't try to copy more than Metal will allow. + layer_count = extent.depth; + src_size.depth = 1; + } else { + layer_count = region.src_subresources.layer_count; + } + MTLOrigin dst_origin = MTLOriginFromVector3i(region.dst_offset); + + for (uint32_t layer = 0; layer < layer_count; layer++) { + // We can copy between a 3D and a 2D image easily. Just copy between + // one slice of the 2D image and one plane of the 3D image at a time. 
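+				// Three cases below: matching dimensionality copies slice-to-slice; a 3D
+				// source into a 2D destination offsets the source depth plane per layer;
+				// a 2D source into a 3D destination offsets the destination depth plane
+				// per layer.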
+ if ((src.textureType == MTLTextureType3D) == (dst.textureType == MTLTextureType3D)) { + [blit copyFromTexture:src + sourceSlice:src_base_layer + layer + sourceLevel:src_level + sourceOrigin:src_origin + sourceSize:src_size + toTexture:dst + destinationSlice:dst_base_layer + layer + destinationLevel:dst_level + destinationOrigin:dst_origin]; + } else if (src.textureType == MTLTextureType3D) { + [blit copyFromTexture:src + sourceSlice:src_base_layer + sourceLevel:src_level + sourceOrigin:MTLOriginMake(src_origin.x, src_origin.y, src_origin.z + layer) + sourceSize:src_size + toTexture:dst + destinationSlice:dst_base_layer + layer + destinationLevel:dst_level + destinationOrigin:dst_origin]; + } else { + DEV_ASSERT(dst.textureType == MTLTextureType3D); + [blit copyFromTexture:src + sourceSlice:src_base_layer + layer + sourceLevel:src_level + sourceOrigin:src_origin + sourceSize:src_size + toTexture:dst + destinationSlice:dst_base_layer + destinationLevel:dst_level + destinationOrigin:MTLOriginMake(dst_origin.x, dst_origin.y, dst_origin.z + layer)]; + } + } + } + } +} + +void RenderingDeviceDriverMetal::command_resolve_texture(CommandBufferID p_cmd_buffer, TextureID p_src_texture, TextureLayout p_src_texture_layout, uint32_t p_src_layer, uint32_t p_src_mipmap, TextureID p_dst_texture, TextureLayout p_dst_texture_layout, uint32_t p_dst_layer, uint32_t p_dst_mipmap) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + id src_tex = rid::get(p_src_texture); + id dst_tex = rid::get(p_dst_texture); + + MTLRenderPassDescriptor *mtlRPD = [MTLRenderPassDescriptor renderPassDescriptor]; + MTLRenderPassColorAttachmentDescriptor *mtlColorAttDesc = mtlRPD.colorAttachments[0]; + mtlColorAttDesc.loadAction = MTLLoadActionLoad; + mtlColorAttDesc.storeAction = MTLStoreActionMultisampleResolve; + + mtlColorAttDesc.texture = src_tex; + mtlColorAttDesc.resolveTexture = dst_tex; + mtlColorAttDesc.level = p_src_mipmap; + mtlColorAttDesc.slice = p_src_layer; + mtlColorAttDesc.resolveLevel = p_dst_mipmap; + mtlColorAttDesc.resolveSlice = p_dst_layer; + cb->encodeRenderCommandEncoderWithDescriptor(mtlRPD, @"Resolve Image"); +} + +void RenderingDeviceDriverMetal::command_clear_color_texture(CommandBufferID p_cmd_buffer, TextureID p_texture, TextureLayout p_texture_layout, const Color &p_color, const TextureSubresourceRange &p_subresources) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + id src_tex = rid::get(p_texture); + + if (src_tex.parentTexture) { + // Clear via the parent texture rather than the view. + src_tex = src_tex.parentTexture; + } + + PixelFormats &pf = *pixel_formats; + + if (pf.isDepthFormat(src_tex.pixelFormat) || pf.isStencilFormat(src_tex.pixelFormat)) { + ERR_FAIL_MSG("invalid: depth or stencil texture format"); + } + + MTLRenderPassDescriptor *desc = MTLRenderPassDescriptor.renderPassDescriptor; + + if (p_subresources.aspect.has_flag(TEXTURE_ASPECT_COLOR_BIT)) { + MTLRenderPassColorAttachmentDescriptor *caDesc = desc.colorAttachments[0]; + caDesc.texture = src_tex; + caDesc.loadAction = MTLLoadActionClear; + caDesc.storeAction = MTLStoreActionStore; + caDesc.clearColor = MTLClearColorMake(p_color.r, p_color.g, p_color.b, p_color.a); + + // Extract the mipmap levels that are to be updated. + uint32_t mipLvlStart = p_subresources.base_mipmap; + uint32_t mipLvlCnt = p_subresources.mipmap_count; + uint32_t mipLvlEnd = mipLvlStart + mipLvlCnt; + + uint32_t levelCount = src_tex.mipmapLevelCount; + + // Extract the cube or array layers (slices) that are to be updated. 
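+		// For 3D textures the "layers" are depth planes, so the layer range is
+		// recomputed for each mip level from that level's depth further below.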
+ bool is3D = src_tex.textureType == MTLTextureType3D; + uint32_t layerStart = is3D ? 0 : p_subresources.base_layer; + uint32_t layerCnt = p_subresources.layer_count; + uint32_t layerEnd = layerStart + layerCnt; + + MetalFeatures const &features = (*metal_device_properties).features; + + // Iterate across mipmap levels and layers, and perform and empty render to clear each. + for (uint32_t mipLvl = mipLvlStart; mipLvl < mipLvlEnd; mipLvl++) { + ERR_FAIL_INDEX_MSG(mipLvl, levelCount, "mip level out of range"); + + caDesc.level = mipLvl; + + // If a 3D image, we need to get the depth for each level. + if (is3D) { + layerCnt = mipmapLevelSizeFromTexture(src_tex, mipLvl).depth; + layerEnd = layerStart + layerCnt; + } + + if ((features.layeredRendering && src_tex.sampleCount == 1) || features.multisampleLayeredRendering) { + // We can clear all layers at once. + if (is3D) { + caDesc.depthPlane = layerStart; + } else { + caDesc.slice = layerStart; + } + desc.renderTargetArrayLength = layerCnt; + cb->encodeRenderCommandEncoderWithDescriptor(desc, @"Clear Image"); + } else { + for (uint32_t layer = layerStart; layer < layerEnd; layer++) { + if (is3D) { + caDesc.depthPlane = layer; + } else { + caDesc.slice = layer; + } + cb->encodeRenderCommandEncoderWithDescriptor(desc, @"Clear Image"); + } + } + } + } +} + +API_AVAILABLE(macos(11.0), ios(14.0)) +bool isArrayTexture(MTLTextureType p_type) { + return (p_type == MTLTextureType3D || + p_type == MTLTextureType2DArray || + p_type == MTLTextureType2DMultisampleArray || + p_type == MTLTextureType1DArray); +} + +void RenderingDeviceDriverMetal::_copy_texture_buffer(CommandBufferID p_cmd_buffer, + CopySource p_source, + TextureID p_texture, + BufferID p_buffer, + VectorView p_regions) { + MDCommandBuffer *cmd = (MDCommandBuffer *)(p_cmd_buffer.id); + id buffer = rid::get(p_buffer); + id texture = rid::get(p_texture); + + id enc = cmd->blit_command_encoder(); + + PixelFormats &pf = *pixel_formats; + MTLPixelFormat mtlPixFmt = texture.pixelFormat; + + MTLBlitOption options = MTLBlitOptionNone; + if (pf.isPVRTCFormat(mtlPixFmt)) { + options |= MTLBlitOptionRowLinearPVRTC; + } + + for (uint32_t i = 0; i < p_regions.size(); i++) { + BufferTextureCopyRegion region = p_regions[i]; + + uint32_t mip_level = region.texture_subresources.mipmap; + MTLOrigin txt_origin = MTLOriginMake(region.texture_offset.x, region.texture_offset.y, region.texture_offset.z); + MTLSize src_extent = mipmapLevelSizeFromTexture(texture, mip_level); + MTLSize txt_size = clampMTLSize(MTLSizeMake(region.texture_region_size.x, region.texture_region_size.y, region.texture_region_size.z), + txt_origin, + src_extent); + + uint32_t buffImgWd = region.texture_region_size.x; + uint32_t buffImgHt = region.texture_region_size.y; + + NSUInteger bytesPerRow = pf.getBytesPerRow(mtlPixFmt, buffImgWd); + NSUInteger bytesPerImg = pf.getBytesPerLayer(mtlPixFmt, bytesPerRow, buffImgHt); + + MTLBlitOption blit_options = options; + + if (pf.isDepthFormat(mtlPixFmt) && pf.isStencilFormat(mtlPixFmt)) { + bool want_depth = flags::all(region.texture_subresources.aspect, TEXTURE_ASPECT_DEPTH_BIT); + bool want_stencil = flags::all(region.texture_subresources.aspect, TEXTURE_ASPECT_STENCIL_BIT); + + // The stencil component is always 1 byte per pixel. + // Don't reduce depths of 32-bit depth/stencil formats. 
+ if (want_depth && !want_stencil) { + if (pf.getBytesPerTexel(mtlPixFmt) != 4) { + bytesPerRow -= buffImgWd; + bytesPerImg -= buffImgWd * buffImgHt; + } + blit_options |= MTLBlitOptionDepthFromDepthStencil; + } else if (want_stencil && !want_depth) { + bytesPerRow = buffImgWd; + bytesPerImg = buffImgWd * buffImgHt; + blit_options |= MTLBlitOptionStencilFromDepthStencil; + } + } + + if (!isArrayTexture(texture.textureType)) { + bytesPerImg = 0; + } + + if (p_source == CopySource::Buffer) { + for (uint32_t lyrIdx = 0; lyrIdx < region.texture_subresources.layer_count; lyrIdx++) { + [enc copyFromBuffer:buffer + sourceOffset:region.buffer_offset + (bytesPerImg * lyrIdx) + sourceBytesPerRow:bytesPerRow + sourceBytesPerImage:bytesPerImg + sourceSize:txt_size + toTexture:texture + destinationSlice:region.texture_subresources.base_layer + lyrIdx + destinationLevel:mip_level + destinationOrigin:txt_origin + options:blit_options]; + } + } else { + for (uint32_t lyrIdx = 0; lyrIdx < region.texture_subresources.layer_count; lyrIdx++) { + [enc copyFromTexture:texture + sourceSlice:region.texture_subresources.base_layer + lyrIdx + sourceLevel:mip_level + sourceOrigin:txt_origin + sourceSize:txt_size + toBuffer:buffer + destinationOffset:region.buffer_offset + (bytesPerImg * lyrIdx) + destinationBytesPerRow:bytesPerRow + destinationBytesPerImage:bytesPerImg + options:blit_options]; + } + } + } +} + +void RenderingDeviceDriverMetal::command_copy_buffer_to_texture(CommandBufferID p_cmd_buffer, BufferID p_src_buffer, TextureID p_dst_texture, TextureLayout p_dst_texture_layout, VectorView p_regions) { + _copy_texture_buffer(p_cmd_buffer, CopySource::Buffer, p_dst_texture, p_src_buffer, p_regions); +} + +void RenderingDeviceDriverMetal::command_copy_texture_to_buffer(CommandBufferID p_cmd_buffer, TextureID p_src_texture, TextureLayout p_src_texture_layout, BufferID p_dst_buffer, VectorView p_regions) { + _copy_texture_buffer(p_cmd_buffer, CopySource::Texture, p_src_texture, p_dst_buffer, p_regions); +} + +#pragma mark - Pipeline + +void RenderingDeviceDriverMetal::pipeline_free(PipelineID p_pipeline_id) { + MDPipeline *obj = (MDPipeline *)(p_pipeline_id.id); + delete obj; +} + +// ----- BINDING ----- + +void RenderingDeviceDriverMetal::command_bind_push_constants(CommandBufferID p_cmd_buffer, ShaderID p_shader, uint32_t p_dst_first_index, VectorView p_data) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + MDShader *shader = (MDShader *)(p_shader.id); + shader->encode_push_constant_data(p_data, cb); +} + +// ----- CACHE ----- + +String RenderingDeviceDriverMetal::_pipeline_get_cache_path() const { + String path = OS::get_singleton()->get_user_data_dir() + "/metal/pipelines"; + path += "." 
+ context_device.name.validate_filename().replace(" ", "_").to_lower(); + if (Engine::get_singleton()->is_editor_hint()) { + path += ".editor"; + } + path += ".cache"; + + return path; +} + +bool RenderingDeviceDriverMetal::pipeline_cache_create(const Vector &p_data) { + return false; + CharString path = _pipeline_get_cache_path().utf8(); + NSString *nPath = [[NSString alloc] initWithBytesNoCopy:path.ptrw() + length:path.length() + encoding:NSUTF8StringEncoding + freeWhenDone:NO]; + MTLBinaryArchiveDescriptor *desc = [MTLBinaryArchiveDescriptor new]; + if ([[NSFileManager defaultManager] fileExistsAtPath:nPath]) { + desc.url = [NSURL fileURLWithPath:nPath]; + } + NSError *error = nil; + archive = [device newBinaryArchiveWithDescriptor:desc error:&error]; + return true; +} + +void RenderingDeviceDriverMetal::pipeline_cache_free() { + archive = nil; +} + +size_t RenderingDeviceDriverMetal::pipeline_cache_query_size() { + return archive_count * 1024; +} + +Vector RenderingDeviceDriverMetal::pipeline_cache_serialize() { + if (!archive) { + return Vector(); + } + + CharString path = _pipeline_get_cache_path().utf8(); + + NSString *nPath = [[NSString alloc] initWithBytesNoCopy:path.ptrw() + length:path.length() + encoding:NSUTF8StringEncoding + freeWhenDone:NO]; + NSURL *target = [NSURL fileURLWithPath:nPath]; + NSError *error = nil; + if ([archive serializeToURL:target error:&error]) { + return Vector(); + } else { + print_line(error.localizedDescription.UTF8String); + return Vector(); + } +} + +#pragma mark - Rendering + +// ----- SUBPASS ----- + +RDD::RenderPassID RenderingDeviceDriverMetal::render_pass_create(VectorView p_attachments, VectorView p_subpasses, VectorView p_subpass_dependencies, uint32_t p_view_count) { + PixelFormats &pf = *pixel_formats; + + size_t subpass_count = p_subpasses.size(); + + Vector subpasses; + subpasses.resize(subpass_count); + for (uint32_t i = 0; i < subpass_count; i++) { + MDSubpass &subpass = subpasses.write[i]; + subpass.subpass_index = i; + subpass.input_references = p_subpasses[i].input_references; + subpass.color_references = p_subpasses[i].color_references; + subpass.depth_stencil_reference = p_subpasses[i].depth_stencil_reference; + subpass.resolve_references = p_subpasses[i].resolve_references; + } + + static const MTLLoadAction LOAD_ACTIONS[] = { + [ATTACHMENT_LOAD_OP_LOAD] = MTLLoadActionLoad, + [ATTACHMENT_LOAD_OP_CLEAR] = MTLLoadActionClear, + [ATTACHMENT_LOAD_OP_DONT_CARE] = MTLLoadActionDontCare, + }; + + static const MTLStoreAction STORE_ACTIONS[] = { + [ATTACHMENT_STORE_OP_STORE] = MTLStoreActionStore, + [ATTACHMENT_STORE_OP_DONT_CARE] = MTLStoreActionDontCare, + }; + + Vector attachments; + attachments.resize(p_attachments.size()); + + for (uint32_t i = 0; i < p_attachments.size(); i++) { + Attachment const &a = p_attachments[i]; + MDAttachment &mda = attachments.write[i]; + MTLPixelFormat format = pf.getMTLPixelFormat(a.format); + mda.format = format; + if (a.samples > TEXTURE_SAMPLES_1) { + mda.samples = (*metal_device_properties).find_nearest_supported_sample_count(a.samples); + } + mda.loadAction = LOAD_ACTIONS[a.load_op]; + mda.storeAction = STORE_ACTIONS[a.store_op]; + bool is_depth = pf.isDepthFormat(format); + if (is_depth) { + mda.type |= MDAttachmentType::Depth; + } + bool is_stencil = pf.isStencilFormat(format); + if (is_stencil) { + mda.type |= MDAttachmentType::Stencil; + mda.stencilLoadAction = LOAD_ACTIONS[a.stencil_load_op]; + mda.stencilStoreAction = STORE_ACTIONS[a.stencil_store_op]; + } + if (!is_depth && !is_stencil) { + 
mda.type |= MDAttachmentType::Color; + } + } + MDRenderPass *obj = new MDRenderPass(attachments, subpasses); + return RenderPassID(obj); +} + +void RenderingDeviceDriverMetal::render_pass_free(RenderPassID p_render_pass) { + MDRenderPass *obj = (MDRenderPass *)(p_render_pass.id); + delete obj; +} + +// ----- COMMANDS ----- + +void RenderingDeviceDriverMetal::command_begin_render_pass(CommandBufferID p_cmd_buffer, RenderPassID p_render_pass, FramebufferID p_framebuffer, CommandBufferType p_cmd_buffer_type, const Rect2i &p_rect, VectorView p_clear_values) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_begin_pass(p_render_pass, p_framebuffer, p_cmd_buffer_type, p_rect, p_clear_values); +} + +void RenderingDeviceDriverMetal::command_end_render_pass(CommandBufferID p_cmd_buffer) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_end_pass(); +} + +void RenderingDeviceDriverMetal::command_next_render_subpass(CommandBufferID p_cmd_buffer, CommandBufferType p_cmd_buffer_type) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_next_subpass(); +} + +void RenderingDeviceDriverMetal::command_render_set_viewport(CommandBufferID p_cmd_buffer, VectorView p_viewports) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_set_viewport(p_viewports); +} + +void RenderingDeviceDriverMetal::command_render_set_scissor(CommandBufferID p_cmd_buffer, VectorView p_scissors) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_set_scissor(p_scissors); +} + +void RenderingDeviceDriverMetal::command_render_clear_attachments(CommandBufferID p_cmd_buffer, VectorView p_attachment_clears, VectorView p_rects) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_clear_attachments(p_attachment_clears, p_rects); +} + +void RenderingDeviceDriverMetal::command_bind_render_pipeline(CommandBufferID p_cmd_buffer, PipelineID p_pipeline) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->bind_pipeline(p_pipeline); +} + +void RenderingDeviceDriverMetal::command_bind_render_uniform_set(CommandBufferID p_cmd_buffer, UniformSetID p_uniform_set, ShaderID p_shader, uint32_t p_set_index) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_bind_uniform_set(p_uniform_set, p_shader, p_set_index); +} + +void RenderingDeviceDriverMetal::command_render_draw(CommandBufferID p_cmd_buffer, uint32_t p_vertex_count, uint32_t p_instance_count, uint32_t p_base_vertex, uint32_t p_first_instance) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_draw(p_vertex_count, p_instance_count, p_base_vertex, p_first_instance); +} + +void RenderingDeviceDriverMetal::command_render_draw_indexed(CommandBufferID p_cmd_buffer, uint32_t p_index_count, uint32_t p_instance_count, uint32_t p_first_index, int32_t p_vertex_offset, uint32_t p_first_instance) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_draw_indexed(p_index_count, p_instance_count, p_first_index, p_vertex_offset, p_first_instance); +} + +void RenderingDeviceDriverMetal::command_render_draw_indexed_indirect(CommandBufferID p_cmd_buffer, BufferID p_indirect_buffer, uint64_t p_offset, uint32_t p_draw_count, uint32_t p_stride) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_draw_indexed_indirect(p_indirect_buffer, p_offset, p_draw_count, p_stride); +} + +void 
RenderingDeviceDriverMetal::command_render_draw_indexed_indirect_count(CommandBufferID p_cmd_buffer, BufferID p_indirect_buffer, uint64_t p_offset, BufferID p_count_buffer, uint64_t p_count_buffer_offset, uint32_t p_max_draw_count, uint32_t p_stride) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_draw_indexed_indirect_count(p_indirect_buffer, p_offset, p_count_buffer, p_count_buffer_offset, p_max_draw_count, p_stride); +} + +void RenderingDeviceDriverMetal::command_render_draw_indirect(CommandBufferID p_cmd_buffer, BufferID p_indirect_buffer, uint64_t p_offset, uint32_t p_draw_count, uint32_t p_stride) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_draw_indirect(p_indirect_buffer, p_offset, p_draw_count, p_stride); +} + +void RenderingDeviceDriverMetal::command_render_draw_indirect_count(CommandBufferID p_cmd_buffer, BufferID p_indirect_buffer, uint64_t p_offset, BufferID p_count_buffer, uint64_t p_count_buffer_offset, uint32_t p_max_draw_count, uint32_t p_stride) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_draw_indirect_count(p_indirect_buffer, p_offset, p_count_buffer, p_count_buffer_offset, p_max_draw_count, p_stride); +} + +void RenderingDeviceDriverMetal::command_render_bind_vertex_buffers(CommandBufferID p_cmd_buffer, uint32_t p_binding_count, const BufferID *p_buffers, const uint64_t *p_offsets) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_bind_vertex_buffers(p_binding_count, p_buffers, p_offsets); +} + +void RenderingDeviceDriverMetal::command_render_bind_index_buffer(CommandBufferID p_cmd_buffer, BufferID p_buffer, IndexBufferFormat p_format, uint64_t p_offset) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_bind_index_buffer(p_buffer, p_format, p_offset); +} + +void RenderingDeviceDriverMetal::command_render_set_blend_constants(CommandBufferID p_cmd_buffer, const Color &p_constants) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->render_set_blend_constants(p_constants); +} + +void RenderingDeviceDriverMetal::command_render_set_line_width(CommandBufferID p_cmd_buffer, float p_width) { + if (!Math::is_equal_approx(p_width, 1.0f)) { + ERR_FAIL_MSG("Setting line widths other than 1.0 is not supported by the Metal rendering driver."); + } +} + +// ----- PIPELINE ----- + +RenderingDeviceDriverMetal::Result> RenderingDeviceDriverMetal::_create_function(id p_library, NSString *p_name, VectorView &p_specialization_constants) { + id function = [p_library newFunctionWithName:p_name]; + ERR_FAIL_NULL_V_MSG(function, ERR_CANT_CREATE, "No function named main0"); + + if (function.functionConstantsDictionary.count == 0) { + return function; + } + + NSArray *constants = function.functionConstantsDictionary.allValues; + bool is_sorted = true; + for (uint32_t i = 1; i < constants.count; i++) { + if (constants[i - 1].index < constants[i].index) { + is_sorted = false; + break; + } + } + + if (!is_sorted) { + constants = [constants sortedArrayUsingComparator:^NSComparisonResult(MTLFunctionConstant *a, MTLFunctionConstant *b) { + if (a.index < b.index) { + return NSOrderedAscending; + } else if (a.index > b.index) { + return NSOrderedDescending; + } else { + return NSOrderedSame; + } + }]; + } + + MTLFunctionConstantValues *constantValues = [MTLFunctionConstantValues new]; + uint32_t i = 0; + uint32_t j = 0; + while (i < constants.count && j < p_specialization_constants.size()) { + MTLFunctionConstant *curr = constants[i]; + 
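+		// constants is sorted by Metal function constant index and the provided
+		// specialization constants are assumed to be ordered by constant_id, so the
+		// two lists are merged by advancing whichever side has the smaller id.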
PipelineSpecializationConstant const &sc = p_specialization_constants[j]; + if (curr.index == sc.constant_id) { + switch (curr.type) { + case MTLDataTypeBool: + case MTLDataTypeFloat: + case MTLDataTypeInt: + case MTLDataTypeUInt: { + [constantValues setConstantValue:&sc.int_value + type:curr.type + atIndex:sc.constant_id]; + } break; + default: + ERR_FAIL_V_MSG(function, "Invalid specialization constant type"); + } + i++; + j++; + } else if (curr.index < sc.constant_id) { + i++; + } else { + j++; + } + } + + if (i != constants.count) { + MTLFunctionConstant *curr = constants[i]; + if (curr.index == R32UI_ALIGNMENT_CONSTANT_ID) { + uint32_t alignment = 16; // TODO(sgc): is this always correct? + [constantValues setConstantValue:&alignment + type:curr.type + atIndex:curr.index]; + i++; + } + } + + NSError *err = nil; + function = [p_library newFunctionWithName:@"main0" + constantValues:constantValues + error:&err]; + ERR_FAIL_NULL_V_MSG(function, ERR_CANT_CREATE, String("specialized function failed: ") + err.localizedDescription.UTF8String); + + return function; +} + +// RDD::PolygonCullMode == MTLCullMode. +static_assert(ENUM_MEMBERS_EQUAL(RDD::POLYGON_CULL_DISABLED, MTLCullModeNone)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::POLYGON_CULL_FRONT, MTLCullModeFront)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::POLYGON_CULL_BACK, MTLCullModeBack)); + +// RDD::StencilOperation == MTLStencilOperation. +static_assert(ENUM_MEMBERS_EQUAL(RDD::STENCIL_OP_KEEP, MTLStencilOperationKeep)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::STENCIL_OP_ZERO, MTLStencilOperationZero)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::STENCIL_OP_REPLACE, MTLStencilOperationReplace)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::STENCIL_OP_INCREMENT_AND_CLAMP, MTLStencilOperationIncrementClamp)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::STENCIL_OP_DECREMENT_AND_CLAMP, MTLStencilOperationDecrementClamp)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::STENCIL_OP_INVERT, MTLStencilOperationInvert)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::STENCIL_OP_INCREMENT_AND_WRAP, MTLStencilOperationIncrementWrap)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::STENCIL_OP_DECREMENT_AND_WRAP, MTLStencilOperationDecrementWrap)); + +// RDD::BlendOperation == MTLBlendOperation. 
+static_assert(ENUM_MEMBERS_EQUAL(RDD::BLEND_OP_ADD, MTLBlendOperationAdd)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::BLEND_OP_SUBTRACT, MTLBlendOperationSubtract)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::BLEND_OP_REVERSE_SUBTRACT, MTLBlendOperationReverseSubtract)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::BLEND_OP_MINIMUM, MTLBlendOperationMin)); +static_assert(ENUM_MEMBERS_EQUAL(RDD::BLEND_OP_MAXIMUM, MTLBlendOperationMax)); + +RDD::PipelineID RenderingDeviceDriverMetal::render_pipeline_create( + ShaderID p_shader, + VertexFormatID p_vertex_format, + RenderPrimitive p_render_primitive, + PipelineRasterizationState p_rasterization_state, + PipelineMultisampleState p_multisample_state, + PipelineDepthStencilState p_depth_stencil_state, + PipelineColorBlendState p_blend_state, + VectorView p_color_attachments, + BitField p_dynamic_state, + RenderPassID p_render_pass, + uint32_t p_render_subpass, + VectorView p_specialization_constants) { + MDRenderShader *shader = (MDRenderShader *)(p_shader.id); + MTLVertexDescriptor *vert_desc = rid::get(p_vertex_format); + MDRenderPass *pass = (MDRenderPass *)(p_render_pass.id); + + MTLRenderPipelineDescriptor *desc = [MTLRenderPipelineDescriptor new]; + + { + MDSubpass const &subpass = pass->subpasses[p_render_subpass]; + for (uint32_t i = 0; i < subpass.color_references.size(); i++) { + uint32_t attachment = subpass.color_references[i].attachment; + if (attachment != AttachmentReference::UNUSED) { + MDAttachment const &a = pass->attachments[attachment]; + desc.colorAttachments[i].pixelFormat = a.format; + } + } + + if (subpass.depth_stencil_reference.attachment != AttachmentReference::UNUSED) { + uint32_t attachment = subpass.depth_stencil_reference.attachment; + MDAttachment const &a = pass->attachments[attachment]; + + if (a.type & MDAttachmentType::Depth) { + desc.depthAttachmentPixelFormat = a.format; + } + + if (a.type & MDAttachmentType::Stencil) { + desc.stencilAttachmentPixelFormat = a.format; + } + } + } + + desc.vertexDescriptor = vert_desc; + desc.label = [NSString stringWithUTF8String:shader->name.get_data()]; + + // Input assembly & tessellation. 
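+	// The pipeline descriptor takes only the topology class, while the concrete
+	// MTLPrimitiveType is supplied at draw time, hence the two switches over
+	// p_render_primitive below.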
+ + MDRenderPipeline *pipeline = new MDRenderPipeline(); + + switch (p_render_primitive) { + case RENDER_PRIMITIVE_POINTS: + desc.inputPrimitiveTopology = MTLPrimitiveTopologyClassPoint; + break; + case RENDER_PRIMITIVE_LINES: + case RENDER_PRIMITIVE_LINES_WITH_ADJACENCY: + case RENDER_PRIMITIVE_LINESTRIPS_WITH_ADJACENCY: + case RENDER_PRIMITIVE_LINESTRIPS: + desc.inputPrimitiveTopology = MTLPrimitiveTopologyClassLine; + break; + case RENDER_PRIMITIVE_TRIANGLES: + case RENDER_PRIMITIVE_TRIANGLE_STRIPS: + case RENDER_PRIMITIVE_TRIANGLES_WITH_ADJACENCY: + case RENDER_PRIMITIVE_TRIANGLE_STRIPS_WITH_AJACENCY: + case RENDER_PRIMITIVE_TRIANGLE_STRIPS_WITH_RESTART_INDEX: + desc.inputPrimitiveTopology = MTLPrimitiveTopologyClassTriangle; + break; + case RENDER_PRIMITIVE_TESSELATION_PATCH: + desc.maxTessellationFactor = p_rasterization_state.patch_control_points; + desc.tessellationPartitionMode = MTLTessellationPartitionModeInteger; + ERR_FAIL_V_MSG(PipelineID(), "tessellation not implemented"); + break; + case RENDER_PRIMITIVE_MAX: + default: + desc.inputPrimitiveTopology = MTLPrimitiveTopologyClassUnspecified; + break; + } + + switch (p_render_primitive) { + case RENDER_PRIMITIVE_POINTS: + pipeline->raster_state.render_primitive = MTLPrimitiveTypePoint; + break; + case RENDER_PRIMITIVE_LINES: + case RENDER_PRIMITIVE_LINES_WITH_ADJACENCY: + pipeline->raster_state.render_primitive = MTLPrimitiveTypeLine; + break; + case RENDER_PRIMITIVE_LINESTRIPS: + case RENDER_PRIMITIVE_LINESTRIPS_WITH_ADJACENCY: + pipeline->raster_state.render_primitive = MTLPrimitiveTypeLineStrip; + break; + case RENDER_PRIMITIVE_TRIANGLES: + case RENDER_PRIMITIVE_TRIANGLES_WITH_ADJACENCY: + pipeline->raster_state.render_primitive = MTLPrimitiveTypeTriangle; + break; + case RENDER_PRIMITIVE_TRIANGLE_STRIPS: + case RENDER_PRIMITIVE_TRIANGLE_STRIPS_WITH_AJACENCY: + case RENDER_PRIMITIVE_TRIANGLE_STRIPS_WITH_RESTART_INDEX: + pipeline->raster_state.render_primitive = MTLPrimitiveTypeTriangleStrip; + break; + default: + break; + } + + // Rasterization. + desc.rasterizationEnabled = !p_rasterization_state.discard_primitives; + pipeline->raster_state.clip_mode = p_rasterization_state.enable_depth_clamp ? MTLDepthClipModeClamp : MTLDepthClipModeClip; + pipeline->raster_state.fill_mode = p_rasterization_state.wireframe ? MTLTriangleFillModeLines : MTLTriangleFillModeFill; + + static const MTLCullMode CULL_MODE[3] = { + MTLCullModeNone, + MTLCullModeFront, + MTLCullModeBack, + }; + pipeline->raster_state.cull_mode = CULL_MODE[p_rasterization_state.cull_mode]; + pipeline->raster_state.winding = (p_rasterization_state.front_face == POLYGON_FRONT_FACE_CLOCKWISE) ? MTLWindingClockwise : MTLWindingCounterClockwise; + pipeline->raster_state.depth_bias.enabled = p_rasterization_state.depth_bias_enabled; + pipeline->raster_state.depth_bias.depth_bias = p_rasterization_state.depth_bias_constant_factor; + pipeline->raster_state.depth_bias.slope_scale = p_rasterization_state.depth_bias_slope_factor; + pipeline->raster_state.depth_bias.clamp = p_rasterization_state.depth_bias_clamp; + // In Metal there is no line width. + if (!Math::is_equal_approx(p_rasterization_state.line_width, 1.0f)) { + WARN_PRINT("unsupported: line width"); + } + + // Multisample. 
+ if (p_multisample_state.enable_sample_shading) { + WARN_PRINT("unsupported: multi-sample shading"); + } + + if (p_multisample_state.sample_count > TEXTURE_SAMPLES_1) { + pipeline->sample_count = (*metal_device_properties).find_nearest_supported_sample_count(p_multisample_state.sample_count); + } + desc.rasterSampleCount = static_cast(pipeline->sample_count); + desc.alphaToCoverageEnabled = p_multisample_state.enable_alpha_to_coverage; + desc.alphaToOneEnabled = p_multisample_state.enable_alpha_to_one; + + // Depth stencil. + if (p_depth_stencil_state.enable_depth_test && desc.depthAttachmentPixelFormat != MTLPixelFormatInvalid) { + pipeline->raster_state.depth_test.enabled = true; + MTLDepthStencilDescriptor *ds_desc = [MTLDepthStencilDescriptor new]; + ds_desc.depthWriteEnabled = p_depth_stencil_state.enable_depth_write; + ds_desc.depthCompareFunction = COMPARE_OPERATORS[p_depth_stencil_state.depth_compare_operator]; + if (p_depth_stencil_state.enable_depth_range) { + WARN_PRINT("unsupported: depth range"); + } + + if (p_depth_stencil_state.enable_stencil) { + pipeline->raster_state.stencil.front_reference = p_depth_stencil_state.front_op.reference; + pipeline->raster_state.stencil.back_reference = p_depth_stencil_state.back_op.reference; + + { + // Front. + MTLStencilDescriptor *sd = [MTLStencilDescriptor new]; + sd.stencilFailureOperation = STENCIL_OPERATIONS[p_depth_stencil_state.front_op.fail]; + sd.depthStencilPassOperation = STENCIL_OPERATIONS[p_depth_stencil_state.front_op.pass]; + sd.depthFailureOperation = STENCIL_OPERATIONS[p_depth_stencil_state.front_op.depth_fail]; + sd.stencilCompareFunction = COMPARE_OPERATORS[p_depth_stencil_state.front_op.compare]; + sd.readMask = p_depth_stencil_state.front_op.compare_mask; + sd.writeMask = p_depth_stencil_state.front_op.write_mask; + ds_desc.frontFaceStencil = sd; + } + { + // Back. + MTLStencilDescriptor *sd = [MTLStencilDescriptor new]; + sd.stencilFailureOperation = STENCIL_OPERATIONS[p_depth_stencil_state.back_op.fail]; + sd.depthStencilPassOperation = STENCIL_OPERATIONS[p_depth_stencil_state.back_op.pass]; + sd.depthFailureOperation = STENCIL_OPERATIONS[p_depth_stencil_state.back_op.depth_fail]; + sd.stencilCompareFunction = COMPARE_OPERATORS[p_depth_stencil_state.back_op.compare]; + sd.readMask = p_depth_stencil_state.back_op.compare_mask; + sd.writeMask = p_depth_stencil_state.back_op.write_mask; + ds_desc.backFaceStencil = sd; + } + } + + pipeline->depth_stencil = [device newDepthStencilStateWithDescriptor:ds_desc]; + ERR_FAIL_NULL_V_MSG(pipeline->depth_stencil, PipelineID(), "Failed to create depth stencil state"); + } else { + // TODO(sgc): FB13671991 raised as Apple docs state calling setDepthStencilState:nil is valid, but currently generates an exception + pipeline->depth_stencil = get_resource_cache().get_depth_stencil_state(false, false); + } + + // Blend state. 
+ { + for (uint32_t i = 0; i < p_color_attachments.size(); i++) { + if (p_color_attachments[i] == ATTACHMENT_UNUSED) { + continue; + } + + const PipelineColorBlendState::Attachment &bs = p_blend_state.attachments[i]; + + MTLRenderPipelineColorAttachmentDescriptor *ca_desc = desc.colorAttachments[p_color_attachments[i]]; + ca_desc.blendingEnabled = bs.enable_blend; + + ca_desc.sourceRGBBlendFactor = BLEND_FACTORS[bs.src_color_blend_factor]; + ca_desc.destinationRGBBlendFactor = BLEND_FACTORS[bs.dst_color_blend_factor]; + ca_desc.rgbBlendOperation = BLEND_OPERATIONS[bs.color_blend_op]; + + ca_desc.sourceAlphaBlendFactor = BLEND_FACTORS[bs.src_alpha_blend_factor]; + ca_desc.destinationAlphaBlendFactor = BLEND_FACTORS[bs.dst_alpha_blend_factor]; + ca_desc.alphaBlendOperation = BLEND_OPERATIONS[bs.alpha_blend_op]; + + ca_desc.writeMask = MTLColorWriteMaskNone; + if (bs.write_r) { + ca_desc.writeMask |= MTLColorWriteMaskRed; + } + if (bs.write_g) { + ca_desc.writeMask |= MTLColorWriteMaskGreen; + } + if (bs.write_b) { + ca_desc.writeMask |= MTLColorWriteMaskBlue; + } + if (bs.write_a) { + ca_desc.writeMask |= MTLColorWriteMaskAlpha; + } + } + + pipeline->raster_state.blend.r = p_blend_state.blend_constant.r; + pipeline->raster_state.blend.g = p_blend_state.blend_constant.g; + pipeline->raster_state.blend.b = p_blend_state.blend_constant.b; + pipeline->raster_state.blend.a = p_blend_state.blend_constant.a; + } + + // Dynamic state. + + if (p_dynamic_state.has_flag(DYNAMIC_STATE_DEPTH_BIAS)) { + pipeline->raster_state.depth_bias.enabled = true; + } + + if (p_dynamic_state.has_flag(DYNAMIC_STATE_BLEND_CONSTANTS)) { + pipeline->raster_state.blend.enabled = true; + } + + if (p_dynamic_state.has_flag(DYNAMIC_STATE_DEPTH_BOUNDS)) { + // TODO(sgc): ?? + } + + if (p_dynamic_state.has_flag(DYNAMIC_STATE_STENCIL_COMPARE_MASK)) { + // TODO(sgc): ?? + } + + if (p_dynamic_state.has_flag(DYNAMIC_STATE_STENCIL_WRITE_MASK)) { + // TODO(sgc): ?? 
+ } + + if (p_dynamic_state.has_flag(DYNAMIC_STATE_STENCIL_REFERENCE)) { + pipeline->raster_state.stencil.enabled = true; + } + + if (shader->vert != nil) { + Result> function_or_err = _create_function(shader->vert, @"main0", p_specialization_constants); + ERR_FAIL_COND_V(std::holds_alternative(function_or_err), PipelineID()); + desc.vertexFunction = std::get>(function_or_err); + } + + if (shader->frag != nil) { + Result> function_or_err = _create_function(shader->frag, @"main0", p_specialization_constants); + ERR_FAIL_COND_V(std::holds_alternative(function_or_err), PipelineID()); + desc.fragmentFunction = std::get>(function_or_err); + } + + if (archive) { + desc.binaryArchives = @[ archive ]; + } + + NSError *error = nil; + pipeline->state = [device newRenderPipelineStateWithDescriptor:desc + error:&error]; + pipeline->shader = shader; + + ERR_FAIL_COND_V_MSG(error != nil, PipelineID(), ([NSString stringWithFormat:@"error creating pipeline: %@", error.localizedDescription].UTF8String)); + + if (archive) { + if ([archive addRenderPipelineFunctionsWithDescriptor:desc error:&error]) { + archive_count += 1; + } else { + print_error(error.localizedDescription.UTF8String); + } + } + + return PipelineID(pipeline); +} + +#pragma mark - Compute + +// ----- COMMANDS ----- + +void RenderingDeviceDriverMetal::command_bind_compute_pipeline(CommandBufferID p_cmd_buffer, PipelineID p_pipeline) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->bind_pipeline(p_pipeline); +} + +void RenderingDeviceDriverMetal::command_bind_compute_uniform_set(CommandBufferID p_cmd_buffer, UniformSetID p_uniform_set, ShaderID p_shader, uint32_t p_set_index) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->compute_bind_uniform_set(p_uniform_set, p_shader, p_set_index); +} + +void RenderingDeviceDriverMetal::command_compute_dispatch(CommandBufferID p_cmd_buffer, uint32_t p_x_groups, uint32_t p_y_groups, uint32_t p_z_groups) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->compute_dispatch(p_x_groups, p_y_groups, p_z_groups); +} + +void RenderingDeviceDriverMetal::command_compute_dispatch_indirect(CommandBufferID p_cmd_buffer, BufferID p_indirect_buffer, uint64_t p_offset) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + cb->compute_dispatch_indirect(p_indirect_buffer, p_offset); +} + +// ----- PIPELINE ----- + +RDD::PipelineID RenderingDeviceDriverMetal::compute_pipeline_create(ShaderID p_shader, VectorView p_specialization_constants) { + MDComputeShader *shader = (MDComputeShader *)(p_shader.id); + + id library = shader->kernel; + + Result> function_or_err = _create_function(library, @"main0", p_specialization_constants); + ERR_FAIL_COND_V(std::holds_alternative(function_or_err), PipelineID()); + id function = std::get>(function_or_err); + + MTLComputePipelineDescriptor *desc = [MTLComputePipelineDescriptor new]; + desc.computeFunction = function; + if (archive) { + desc.binaryArchives = @[ archive ]; + } + + NSError *error; + id state = [device newComputePipelineStateWithDescriptor:desc + options:MTLPipelineOptionNone + reflection:nil + error:&error]; + ERR_FAIL_COND_V_MSG(error != nil, PipelineID(), ([NSString stringWithFormat:@"error creating pipeline: %@", error.localizedDescription].UTF8String)); + + MDComputePipeline *pipeline = new MDComputePipeline(state); + pipeline->compute_state.local = shader->local; + pipeline->shader = shader; + + if (archive) { + if ([archive addComputePipelineFunctionsWithDescriptor:desc error:&error]) { + 
archive_count += 1; + } else { + print_error(error.localizedDescription.UTF8String); + } + } + + return PipelineID(pipeline); +} + +#pragma mark - Queries + +// ----- TIMESTAMP ----- + +RDD::QueryPoolID RenderingDeviceDriverMetal::timestamp_query_pool_create(uint32_t p_query_count) { + return QueryPoolID(1); +} + +void RenderingDeviceDriverMetal::timestamp_query_pool_free(QueryPoolID p_pool_id) { +} + +void RenderingDeviceDriverMetal::timestamp_query_pool_get_results(QueryPoolID p_pool_id, uint32_t p_query_count, uint64_t *r_results) { + // Metal doesn't support timestamp queries, so we just clear the buffer. + bzero(r_results, p_query_count * sizeof(uint64_t)); +} + +uint64_t RenderingDeviceDriverMetal::timestamp_query_result_to_time(uint64_t p_result) { + return p_result; +} + +void RenderingDeviceDriverMetal::command_timestamp_query_pool_reset(CommandBufferID p_cmd_buffer, QueryPoolID p_pool_id, uint32_t p_query_count) { +} + +void RenderingDeviceDriverMetal::command_timestamp_write(CommandBufferID p_cmd_buffer, QueryPoolID p_pool_id, uint32_t p_index) { +} + +#pragma mark - Labels + +void RenderingDeviceDriverMetal::command_begin_label(CommandBufferID p_cmd_buffer, const char *p_label_name, const Color &p_color) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + NSString *s = [[NSString alloc] initWithBytesNoCopy:(void *)p_label_name length:strlen(p_label_name) encoding:NSUTF8StringEncoding freeWhenDone:NO]; + [cb->get_command_buffer() pushDebugGroup:s]; +} + +void RenderingDeviceDriverMetal::command_end_label(CommandBufferID p_cmd_buffer) { + MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id); + [cb->get_command_buffer() popDebugGroup]; +} + +#pragma mark - Submission + +void RenderingDeviceDriverMetal::begin_segment(uint32_t p_frame_index, uint32_t p_frames_drawn) { +} + +void RenderingDeviceDriverMetal::end_segment() { +} + +#pragma mark - Misc + +void RenderingDeviceDriverMetal::set_object_name(ObjectType p_type, ID p_driver_id, const String &p_name) { + switch (p_type) { + case OBJECT_TYPE_TEXTURE: { + id tex = rid::get(p_driver_id); + tex.label = [NSString stringWithUTF8String:p_name.utf8().get_data()]; + } break; + case OBJECT_TYPE_SAMPLER: { + // Can't set label after creation. + } break; + case OBJECT_TYPE_BUFFER: { + id buffer = rid::get(p_driver_id); + buffer.label = [NSString stringWithUTF8String:p_name.utf8().get_data()]; + } break; + case OBJECT_TYPE_SHADER: { + MDShader *shader = (MDShader *)(p_driver_id.id); + if (MDRenderShader *rs = dynamic_cast(shader); rs != nullptr) { + rs->vert.label = [NSString stringWithUTF8String:p_name.utf8().get_data()]; + rs->frag.label = [NSString stringWithUTF8String:p_name.utf8().get_data()]; + } else if (MDComputeShader *cs = dynamic_cast(shader); cs != nullptr) { + cs->kernel.label = [NSString stringWithUTF8String:p_name.utf8().get_data()]; + } else { + DEV_ASSERT(false); + } + } break; + case OBJECT_TYPE_UNIFORM_SET: { + MDUniformSet *set = (MDUniformSet *)(p_driver_id.id); + for (KeyValue &keyval : set->bound_uniforms) { + keyval.value.buffer.label = [NSString stringWithUTF8String:p_name.utf8().get_data()]; + } + } break; + case OBJECT_TYPE_PIPELINE: { + // Can't set label after creation. 
+ } break; + default: { + DEV_ASSERT(false); + } + } +} + +uint64_t RenderingDeviceDriverMetal::get_resource_native_handle(DriverResource p_type, ID p_driver_id) { + switch (p_type) { + case DRIVER_RESOURCE_LOGICAL_DEVICE: { + return 0; + } + case DRIVER_RESOURCE_PHYSICAL_DEVICE: { + return 0; + } + case DRIVER_RESOURCE_TOPMOST_OBJECT: { + return 0; + } + case DRIVER_RESOURCE_COMMAND_QUEUE: { + return 0; + } + case DRIVER_RESOURCE_QUEUE_FAMILY: { + return 0; + } + case DRIVER_RESOURCE_TEXTURE: { + return p_driver_id.id; + } + case DRIVER_RESOURCE_TEXTURE_VIEW: { + return p_driver_id.id; + } + case DRIVER_RESOURCE_TEXTURE_DATA_FORMAT: { + return 0; + } + case DRIVER_RESOURCE_SAMPLER: { + return p_driver_id.id; + } + case DRIVER_RESOURCE_UNIFORM_SET: + return 0; + case DRIVER_RESOURCE_BUFFER: { + return p_driver_id.id; + } + case DRIVER_RESOURCE_COMPUTE_PIPELINE: + return 0; + case DRIVER_RESOURCE_RENDER_PIPELINE: + return 0; + default: { + return 0; + } + } +} + +uint64_t RenderingDeviceDriverMetal::get_total_memory_used() { + return device.currentAllocatedSize; +} + +uint64_t RenderingDeviceDriverMetal::limit_get(Limit p_limit) { + MetalDeviceProperties const &props = (*metal_device_properties); + MetalLimits const &limits = props.limits; + +#if defined(DEV_ENABLED) +#define UNKNOWN(NAME) \ + case NAME: \ + WARN_PRINT_ONCE("Returning maximum value for unknown limit " #NAME "."); \ + return (uint64_t)1 << 30; +#else +#define UNKNOWN(NAME) \ + case NAME: \ + return (uint64_t)1 << 30 +#endif + + // clang-format off + switch (p_limit) { + case LIMIT_MAX_BOUND_UNIFORM_SETS: + return limits.maxBoundDescriptorSets; + case LIMIT_MAX_FRAMEBUFFER_COLOR_ATTACHMENTS: + return limits.maxColorAttachments; + case LIMIT_MAX_TEXTURES_PER_UNIFORM_SET: + return limits.maxTexturesPerArgumentBuffer; + case LIMIT_MAX_SAMPLERS_PER_UNIFORM_SET: + return limits.maxSamplersPerArgumentBuffer; + case LIMIT_MAX_STORAGE_BUFFERS_PER_UNIFORM_SET: + return limits.maxBuffersPerArgumentBuffer; + case LIMIT_MAX_STORAGE_IMAGES_PER_UNIFORM_SET: + return limits.maxTexturesPerArgumentBuffer; + case LIMIT_MAX_UNIFORM_BUFFERS_PER_UNIFORM_SET: + return limits.maxBuffersPerArgumentBuffer; + case LIMIT_MAX_DRAW_INDEXED_INDEX: + return limits.maxDrawIndexedIndexValue; + case LIMIT_MAX_FRAMEBUFFER_HEIGHT: + return limits.maxFramebufferHeight; + case LIMIT_MAX_FRAMEBUFFER_WIDTH: + return limits.maxFramebufferWidth; + case LIMIT_MAX_TEXTURE_ARRAY_LAYERS: + return limits.maxImageArrayLayers; + case LIMIT_MAX_TEXTURE_SIZE_1D: + return limits.maxImageDimension1D; + case LIMIT_MAX_TEXTURE_SIZE_2D: + return limits.maxImageDimension2D; + case LIMIT_MAX_TEXTURE_SIZE_3D: + return limits.maxImageDimension3D; + case LIMIT_MAX_TEXTURE_SIZE_CUBE: + return limits.maxImageDimensionCube; + case LIMIT_MAX_TEXTURES_PER_SHADER_STAGE: + return limits.maxTexturesPerArgumentBuffer; + case LIMIT_MAX_SAMPLERS_PER_SHADER_STAGE: + return limits.maxSamplersPerArgumentBuffer; + case LIMIT_MAX_STORAGE_BUFFERS_PER_SHADER_STAGE: + return limits.maxBuffersPerArgumentBuffer; + case LIMIT_MAX_STORAGE_IMAGES_PER_SHADER_STAGE: + return limits.maxTexturesPerArgumentBuffer; + case LIMIT_MAX_UNIFORM_BUFFERS_PER_SHADER_STAGE: + return limits.maxBuffersPerArgumentBuffer; + case LIMIT_MAX_PUSH_CONSTANT_SIZE: + return limits.maxBufferLength; + case LIMIT_MAX_UNIFORM_BUFFER_SIZE: + return limits.maxBufferLength; + case LIMIT_MAX_VERTEX_INPUT_ATTRIBUTE_OFFSET: + return limits.maxVertexDescriptorLayoutStride; + case LIMIT_MAX_VERTEX_INPUT_ATTRIBUTES: + return 
limits.maxVertexInputAttributes; + case LIMIT_MAX_VERTEX_INPUT_BINDINGS: + return limits.maxVertexInputBindings; + case LIMIT_MAX_VERTEX_INPUT_BINDING_STRIDE: + return limits.maxVertexInputBindingStride; + case LIMIT_MIN_UNIFORM_BUFFER_OFFSET_ALIGNMENT: + return limits.minUniformBufferOffsetAlignment; + case LIMIT_MAX_COMPUTE_WORKGROUP_COUNT_X: + return limits.maxComputeWorkGroupCount.width; + case LIMIT_MAX_COMPUTE_WORKGROUP_COUNT_Y: + return limits.maxComputeWorkGroupCount.height; + case LIMIT_MAX_COMPUTE_WORKGROUP_COUNT_Z: + return limits.maxComputeWorkGroupCount.depth; + case LIMIT_MAX_COMPUTE_WORKGROUP_INVOCATIONS: + return std::max({ limits.maxThreadsPerThreadGroup.width, limits.maxThreadsPerThreadGroup.height, limits.maxThreadsPerThreadGroup.depth }); + case LIMIT_MAX_COMPUTE_WORKGROUP_SIZE_X: + return limits.maxThreadsPerThreadGroup.width; + case LIMIT_MAX_COMPUTE_WORKGROUP_SIZE_Y: + return limits.maxThreadsPerThreadGroup.height; + case LIMIT_MAX_COMPUTE_WORKGROUP_SIZE_Z: + return limits.maxThreadsPerThreadGroup.depth; + case LIMIT_MAX_VIEWPORT_DIMENSIONS_X: + return limits.maxViewportDimensionX; + case LIMIT_MAX_VIEWPORT_DIMENSIONS_Y: + return limits.maxViewportDimensionY; + case LIMIT_SUBGROUP_SIZE: + // MoltenVK sets the subgroupSize to the same as the maxSubgroupSize. + return limits.maxSubgroupSize; + case LIMIT_SUBGROUP_MIN_SIZE: + return limits.minSubgroupSize; + case LIMIT_SUBGROUP_MAX_SIZE: + return limits.maxSubgroupSize; + case LIMIT_SUBGROUP_IN_SHADERS: + return (int64_t)limits.subgroupSupportedShaderStages; + case LIMIT_SUBGROUP_OPERATIONS: + return (int64_t)limits.subgroupSupportedOperations; + UNKNOWN(LIMIT_VRS_TEXEL_WIDTH); + UNKNOWN(LIMIT_VRS_TEXEL_HEIGHT); + default: + ERR_FAIL_V(0); + } + // clang-format on + return 0; +} + +uint64_t RenderingDeviceDriverMetal::api_trait_get(ApiTrait p_trait) { + switch (p_trait) { + case API_TRAIT_HONORS_PIPELINE_BARRIERS: + return 0; + default: + return RenderingDeviceDriver::api_trait_get(p_trait); + } +} + +bool RenderingDeviceDriverMetal::has_feature(Features p_feature) { + switch (p_feature) { + case SUPPORTS_MULTIVIEW: + return true; + case SUPPORTS_FSR_HALF_FLOAT: + return true; + case SUPPORTS_ATTACHMENT_VRS: + // TODO(sgc): Maybe supported via https://developer.apple.com/documentation/metal/render_passes/rendering_at_different_rasterization_rates?language=objc + // See also: + // + // * https://forum.beyond3d.com/threads/variable-rate-shading-vs-variable-rate-rasterization.62243/post-2191363 + // + return false; + case SUPPORTS_FRAGMENT_SHADER_WITH_ONLY_SIDE_EFFECTS: + return true; + default: + return false; + } +} + +const RDD::MultiviewCapabilities &RenderingDeviceDriverMetal::get_multiview_capabilities() { + return multiview_capabilities; +} + +String RenderingDeviceDriverMetal::get_api_version() const { + return vformat("%d.%d", version_major, version_minor); +} + +String RenderingDeviceDriverMetal::get_pipeline_cache_uuid() const { + return pipeline_cache_id; +} + +const RDD::Capabilities &RenderingDeviceDriverMetal::get_capabilities() const { + return capabilities; +} + +bool RenderingDeviceDriverMetal::is_composite_alpha_supported(CommandQueueID p_queue) const { + // The CAMetalLayer.opaque property is configured according to this global setting. 
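+ // When per-pixel transparency is allowed, the window's CAMetalLayer is created with
+ // `opaque = NO`, so its contents can be composited with alpha. This method simply
+ // reports that project-wide setting rather than querying the queue itself.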
+ return OS::get_singleton()->is_layered_allowed(); +} + +size_t RenderingDeviceDriverMetal::get_texel_buffer_alignment_for_format(RDD::DataFormat p_format) const { + return [device minimumLinearTextureAlignmentForPixelFormat:pixel_formats->getMTLPixelFormat(p_format)]; +} + +size_t RenderingDeviceDriverMetal::get_texel_buffer_alignment_for_format(MTLPixelFormat p_format) const { + return [device minimumLinearTextureAlignmentForPixelFormat:p_format]; +} + +/******************/ + +RenderingDeviceDriverMetal::RenderingDeviceDriverMetal(RenderingContextDriverMetal *p_context_driver) : + context_driver(p_context_driver) { + DEV_ASSERT(p_context_driver != nullptr); +} + +RenderingDeviceDriverMetal::~RenderingDeviceDriverMetal() { + for (MDCommandBuffer *cb : command_buffers) { + delete cb; + } +} + +#pragma mark - Initialization + +Error RenderingDeviceDriverMetal::_create_device() { + device = context_driver->get_metal_device(); + + device_queue = [device newCommandQueue]; + ERR_FAIL_NULL_V(device_queue, ERR_CANT_CREATE); + + device_scope = [MTLCaptureManager.sharedCaptureManager newCaptureScopeWithCommandQueue:device_queue]; + device_scope.label = @"Godot Frame"; + [device_scope beginScope]; // Allow Xcode to capture the first frame, if desired. + + resource_cache = std::make_unique(this); + + return OK; +} + +Error RenderingDeviceDriverMetal::_check_capabilities() { + MTLCompileOptions *options = [MTLCompileOptions new]; + version_major = (options.languageVersion >> 0x10) & 0xff; + version_minor = (options.languageVersion >> 0x00) & 0xff; + + capabilities.device_family = DEVICE_METAL; + capabilities.version_major = version_major; + capabilities.version_minor = version_minor; + + return OK; +} + +Error RenderingDeviceDriverMetal::initialize(uint32_t p_device_index, uint32_t p_frame_count) { + context_device = context_driver->device_get(p_device_index); + Error err = _create_device(); + ERR_FAIL_COND_V(err, ERR_CANT_CREATE); + + err = _check_capabilities(); + ERR_FAIL_COND_V(err, ERR_CANT_CREATE); + + // Set the pipeline cache ID based on the Metal version. + pipeline_cache_id = "metal-driver-" + get_api_version(); + + metal_device_properties = memnew(MetalDeviceProperties(device)); + pixel_formats = memnew(PixelFormats(device)); + + // Check required features and abort if any of them is missing. + if (!metal_device_properties->features.imageCubeArray) { + // NOTE: Apple A11 (Apple4) GPUs support image cube arrays, which are devices from 2017 and newer. + String error_string = vformat("Your Apple GPU does not support the following features which are required to use Metal-based renderers in Godot:\n\n"); + if (!metal_device_properties->features.imageCubeArray) { + error_string += "- No support for image cube arrays.\n"; + } + +#if defined(IOS_ENABLED) + // iOS platform ports currently don't exit themselves when this method returns `ERR_CANT_CREATE`. + OS::get_singleton()->alert(error_string + "\nClick OK to exit (black screen will be visible)."); +#else + OS::get_singleton()->alert(error_string + "\nClick OK to exit."); +#endif + + return ERR_CANT_CREATE; + } + + return OK; +} diff --git a/editor/editor_node.cpp b/editor/editor_node.cpp index c7c00e04962..1f2c07c6a13 100644 --- a/editor/editor_node.cpp +++ b/editor/editor_node.cpp @@ -5032,6 +5032,8 @@ String EditorNode::_get_system_info() const { driver_name = "Vulkan"; } else if (driver_name.begins_with("opengl3")) { driver_name = "GLES3"; + } else if (driver_name == "metal") { + driver_name = "Metal"; } // Join info. 
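+ // The normalized names above ("Vulkan", "GLES3", "Metal", ...) end up in the editor's
+ // system info string used for bug reports, so "metal" is shown with readable casing.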
diff --git a/main/main.cpp b/main/main.cpp index 9ee88af60e1..f82df786bc2 100644 --- a/main/main.cpp +++ b/main/main.cpp @@ -1935,6 +1935,7 @@ Error Main::setup(const char *execpath, int argc, char *argv[], bool p_second_ph { String driver_hints = ""; String driver_hints_with_d3d12 = ""; + String driver_hints_with_metal = ""; { Vector driver_hints_arr; @@ -1947,18 +1948,25 @@ Error Main::setup(const char *execpath, int argc, char *argv[], bool p_second_ph driver_hints_arr.push_back("d3d12"); #endif driver_hints_with_d3d12 = String(",").join(driver_hints_arr); + +#ifdef METAL_ENABLED + // Make metal the preferred and default driver. + driver_hints_arr.insert(0, "metal"); +#endif + driver_hints_with_metal = String(",").join(driver_hints_arr); } String default_driver = driver_hints.get_slice(",", 0); String default_driver_with_d3d12 = driver_hints_with_d3d12.get_slice(",", 0); + String default_driver_with_metal = driver_hints_with_metal.get_slice(",", 0); // For now everything defaults to vulkan when available. This can change in future updates. GLOBAL_DEF_RST_NOVAL("rendering/rendering_device/driver", default_driver); GLOBAL_DEF_RST_NOVAL(PropertyInfo(Variant::STRING, "rendering/rendering_device/driver.windows", PROPERTY_HINT_ENUM, driver_hints_with_d3d12), default_driver_with_d3d12); GLOBAL_DEF_RST_NOVAL(PropertyInfo(Variant::STRING, "rendering/rendering_device/driver.linuxbsd", PROPERTY_HINT_ENUM, driver_hints), default_driver); GLOBAL_DEF_RST_NOVAL(PropertyInfo(Variant::STRING, "rendering/rendering_device/driver.android", PROPERTY_HINT_ENUM, driver_hints), default_driver); - GLOBAL_DEF_RST_NOVAL(PropertyInfo(Variant::STRING, "rendering/rendering_device/driver.ios", PROPERTY_HINT_ENUM, driver_hints), default_driver); - GLOBAL_DEF_RST_NOVAL(PropertyInfo(Variant::STRING, "rendering/rendering_device/driver.macos", PROPERTY_HINT_ENUM, driver_hints), default_driver); + GLOBAL_DEF_RST_NOVAL(PropertyInfo(Variant::STRING, "rendering/rendering_device/driver.ios", PROPERTY_HINT_ENUM, driver_hints_with_metal), default_driver_with_metal); + GLOBAL_DEF_RST_NOVAL(PropertyInfo(Variant::STRING, "rendering/rendering_device/driver.macos", PROPERTY_HINT_ENUM, driver_hints_with_metal), default_driver_with_metal); GLOBAL_DEF_RST("rendering/rendering_device/fallback_to_vulkan", true); GLOBAL_DEF_RST("rendering/rendering_device/fallback_to_d3d12", true); @@ -2232,6 +2240,9 @@ Error Main::setup(const char *execpath, int argc, char *argv[], bool p_second_ph #endif #ifdef D3D12_ENABLED available_drivers.push_back("d3d12"); +#endif +#ifdef METAL_ENABLED + available_drivers.push_back("metal"); #endif } #ifdef GLES3_ENABLED diff --git a/modules/glslang/config.py b/modules/glslang/config.py index 1169776a739..a024dcf9846 100644 --- a/modules/glslang/config.py +++ b/modules/glslang/config.py @@ -1,7 +1,7 @@ def can_build(env, platform): - # glslang is only needed when Vulkan or Direct3D 12-based renderers are available, + # glslang is only needed when Vulkan, Direct3D 12 or Metal-based renderers are available, # as OpenGL doesn't use glslang. - return env["vulkan"] or env["d3d12"] + return env["vulkan"] or env["d3d12"] or env["metal"] def configure(env): diff --git a/modules/glslang/register_types.cpp b/modules/glslang/register_types.cpp index b5f70fb98b8..81505f716a2 100644 --- a/modules/glslang/register_types.cpp +++ b/modules/glslang/register_types.cpp @@ -73,6 +73,9 @@ static Vector _compile_shader_glsl(RenderingDevice::ShaderStage p_stage // - SPIRV-Reflect won't be able to parse the compute workgroup size. 
// - We want to play it safe with NIR-DXIL. TargetVersion = glslang::EShTargetSpv_1_3; + } else if (capabilities.device_family == RDD::DEVICE_METAL) { + ClientVersion = glslang::EShTargetVulkan_1_1; + TargetVersion = glslang::EShTargetSpv_1_6; } else { // once we support other backends we'll need to do something here if (r_error) { diff --git a/platform/ios/detect.py b/platform/ios/detect.py index 53b367a0a7f..ecd1f3d32d0 100644 --- a/platform/ios/detect.py +++ b/platform/ios/detect.py @@ -51,6 +51,7 @@ def get_flags(): "arch": "arm64", "target": "template_debug", "use_volk": False, + "metal": True, "supported": ["mono"], "builtin_pcre2_with_jit": False, } @@ -154,8 +155,22 @@ def configure(env: "SConsEnvironment"): env.Prepend(CPPPATH=["#platform/ios"]) env.Append(CPPDEFINES=["IOS_ENABLED", "UNIX_ENABLED", "COREAUDIO_ENABLED"]) + if env["metal"] and env["arch"] != "arm64": + # Only supported on arm64, so skip it for x86_64 builds. + env["metal"] = False + + if env["metal"]: + env.AppendUnique(CPPDEFINES=["METAL_ENABLED", "RD_ENABLED"]) + env.Prepend( + CPPPATH=[ + "$IOS_SDK_PATH/System/Library/Frameworks/Metal.framework/Headers", + "$IOS_SDK_PATH/System/Library/Frameworks/QuartzCore.framework/Headers", + ] + ) + env.Prepend(CPPPATH=["#thirdparty/spirv-cross"]) + if env["vulkan"]: - env.Append(CPPDEFINES=["VULKAN_ENABLED", "RD_ENABLED"]) + env.AppendUnique(CPPDEFINES=["VULKAN_ENABLED", "RD_ENABLED"]) if env["opengl3"]: env.Append(CPPDEFINES=["GLES3_ENABLED", "GLES_SILENCE_DEPRECATION"]) diff --git a/platform/ios/display_server_ios.h b/platform/ios/display_server_ios.h index 4dded5aa293..bbb758074d9 100644 --- a/platform/ios/display_server_ios.h +++ b/platform/ios/display_server_ios.h @@ -47,6 +47,10 @@ #include #endif #endif // VULKAN_ENABLED + +#if defined(METAL_ENABLED) +#include "drivers/metal/rendering_context_driver_metal.h" +#endif // METAL_ENABLED #endif // RD_ENABLED #if defined(GLES3_ENABLED) diff --git a/platform/ios/display_server_ios.mm b/platform/ios/display_server_ios.mm index 802fbefc0db..5a027e01964 100644 --- a/platform/ios/display_server_ios.mm +++ b/platform/ios/display_server_ios.mm @@ -72,6 +72,13 @@ DisplayServerIOS::DisplayServerIOS(const String &p_rendering_driver, WindowMode union { #ifdef VULKAN_ENABLED RenderingContextDriverVulkanIOS::WindowPlatformData vulkan; +#endif +#ifdef METAL_ENABLED +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunguarded-availability" + // Eliminate "RenderingContextDriverMetal is only available on iOS 14.0 or newer". 
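+ // The warning is a false positive here: the `metal` member is only written inside the
+ // `if (@available(iOS 14.0, *))` branch below, so the type is never touched on older systems.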
+ RenderingContextDriverMetal::WindowPlatformData metal; +#pragma clang diagnostic pop #endif } wpd; @@ -85,7 +92,19 @@ DisplayServerIOS::DisplayServerIOS(const String &p_rendering_driver, WindowMode rendering_context = memnew(RenderingContextDriverVulkanIOS); } #endif - +#ifdef METAL_ENABLED + if (rendering_driver == "metal") { + if (@available(iOS 14.0, *)) { + layer = [AppDelegate.viewController.godotView initializeRenderingForDriver:@"metal"]; + wpd.metal.layer = (CAMetalLayer *)layer; + rendering_context = memnew(RenderingContextDriverMetal); + } else { + OS::get_singleton()->alert("Metal is only supported on iOS 14.0 and later."); + r_error = ERR_UNAVAILABLE; + return; + } + } +#endif if (rendering_context) { if (rendering_context->initialize() != OK) { ERR_PRINT(vformat("Failed to initialize %s context", rendering_driver)); @@ -172,6 +191,11 @@ Vector DisplayServerIOS::get_rendering_drivers_func() { #if defined(VULKAN_ENABLED) drivers.push_back("vulkan"); #endif +#if defined(METAL_ENABLED) + if (@available(ios 14.0, *)) { + drivers.push_back("metal"); + } +#endif #if defined(GLES3_ENABLED) drivers.push_back("opengl3"); #endif diff --git a/platform/ios/export/export_plugin.cpp b/platform/ios/export/export_plugin.cpp index 580e3ba7891..e4b5392c4e5 100644 --- a/platform/ios/export/export_plugin.cpp +++ b/platform/ios/export/export_plugin.cpp @@ -282,6 +282,7 @@ void EditorExportPlatformIOS::get_export_options(List *r_options) r_options->push_back(ExportOption(PropertyInfo(Variant::STRING, "application/short_version", PROPERTY_HINT_PLACEHOLDER_TEXT, "Leave empty to use project version"), "")); r_options->push_back(ExportOption(PropertyInfo(Variant::STRING, "application/version", PROPERTY_HINT_PLACEHOLDER_TEXT, "Leave empty to use project version"), "")); + // TODO(sgc): set to iOS 14.0 for Metal r_options->push_back(ExportOption(PropertyInfo(Variant::STRING, "application/min_ios_version"), "12.0")); r_options->push_back(ExportOption(PropertyInfo(Variant::STRING, "application/additional_plist_content", PROPERTY_HINT_MULTILINE_TEXT), "")); @@ -2656,6 +2657,13 @@ bool EditorExportPlatformIOS::has_valid_export_configuration(const Refget("application/min_ios_version").operator String().to_float(); + if (version < 14.0) { + err += TTR("Metal renderer require iOS 14+.") + "\n"; + } + } + if (!err.is_empty()) { r_error = err; } diff --git a/platform/ios/godot_view.mm b/platform/ios/godot_view.mm index e9100361983..552c4c262cf 100644 --- a/platform/ios/godot_view.mm +++ b/platform/ios/godot_view.mm @@ -71,7 +71,7 @@ static const float earth_gravity = 9.80665; CALayer *layer; - if ([driverName isEqualToString:@"vulkan"]) { + if ([driverName isEqualToString:@"vulkan"] || [driverName isEqualToString:@"metal"]) { #if defined(TARGET_OS_SIMULATOR) && TARGET_OS_SIMULATOR if (@available(iOS 13, *)) { layer = [GodotMetalLayer layer]; diff --git a/platform/macos/detect.py b/platform/macos/detect.py index 1ce7c86c7b8..bfd18bea99f 100644 --- a/platform/macos/detect.py +++ b/platform/macos/detect.py @@ -56,6 +56,7 @@ def get_flags(): return { "arch": detect_arch(), "use_volk": False, + "metal": True, "supported": ["mono"], } @@ -239,9 +240,22 @@ def configure(env: "SConsEnvironment"): env.Append(LINKFLAGS=["-rpath", "@executable_path/../Frameworks", "-rpath", "@executable_path"]) + if env["metal"] and env["arch"] != "arm64": + # Only supported on arm64, so skip it for x86_64 builds. 
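+ # Disabling it here only drops the Metal driver from the build; any other enabled
+ # RenderingDevice driver (e.g. Vulkan through MoltenVK) is still compiled in.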
+ env["metal"] = False + + extra_frameworks = set() + + if env["metal"]: + env.AppendUnique(CPPDEFINES=["METAL_ENABLED", "RD_ENABLED"]) + extra_frameworks.add("Metal") + extra_frameworks.add("MetalKit") + env.Prepend(CPPPATH=["#thirdparty/spirv-cross"]) + if env["vulkan"]: - env.Append(CPPDEFINES=["VULKAN_ENABLED", "RD_ENABLED"]) - env.Append(LINKFLAGS=["-framework", "Metal", "-framework", "IOSurface"]) + env.AppendUnique(CPPDEFINES=["VULKAN_ENABLED", "RD_ENABLED"]) + extra_frameworks.add("Metal") + extra_frameworks.add("IOSurface") if not env["use_volk"]: env.Append(LINKFLAGS=["-lMoltenVK"]) @@ -260,3 +274,7 @@ def configure(env: "SConsEnvironment"): "MoltenVK SDK installation directory not found, use 'vulkan_sdk_path' SCons parameter to specify SDK path." ) sys.exit(255) + + if len(extra_frameworks) > 0: + frameworks = [item for key in extra_frameworks for item in ["-framework", key]] + env.Append(LINKFLAGS=frameworks) diff --git a/platform/macos/display_server_macos.h b/platform/macos/display_server_macos.h index bd3c6af273b..97af6d0a5a2 100644 --- a/platform/macos/display_server_macos.h +++ b/platform/macos/display_server_macos.h @@ -47,6 +47,9 @@ #if defined(VULKAN_ENABLED) #include "rendering_context_driver_vulkan_macos.h" #endif // VULKAN_ENABLED +#if defined(METAL_ENABLED) +#include "drivers/metal/rendering_context_driver_metal.h" +#endif #endif // RD_ENABLED #define BitMap _QDBitMap // Suppress deprecated QuickDraw definition. diff --git a/platform/macos/display_server_macos.mm b/platform/macos/display_server_macos.mm index b2243dd8d5b..3e0a5efe526 100644 --- a/platform/macos/display_server_macos.mm +++ b/platform/macos/display_server_macos.mm @@ -138,12 +138,20 @@ DisplayServerMacOS::WindowID DisplayServerMacOS::_create_window(WindowMode p_mod union { #ifdef VULKAN_ENABLED RenderingContextDriverVulkanMacOS::WindowPlatformData vulkan; +#endif +#ifdef METAL_ENABLED + RenderingContextDriverMetal::WindowPlatformData metal; #endif } wpd; #ifdef VULKAN_ENABLED if (rendering_driver == "vulkan") { wpd.vulkan.layer_ptr = (CAMetalLayer *const *)&layer; } +#endif +#ifdef METAL_ENABLED + if (rendering_driver == "metal") { + wpd.metal.layer = (CAMetalLayer *)layer; + } #endif Error err = rendering_context->window_create(window_id_counter, &wpd); ERR_FAIL_COND_V_MSG(err != OK, INVALID_WINDOW_ID, vformat("Can't create a %s context", rendering_driver)); @@ -2700,7 +2708,7 @@ void DisplayServerMacOS::window_set_vsync_mode(DisplayServer::VSyncMode p_vsync_ gl_manager_legacy->set_use_vsync(p_vsync_mode != DisplayServer::VSYNC_DISABLED); } #endif -#if defined(VULKAN_ENABLED) +#if defined(RD_ENABLED) if (rendering_context) { rendering_context->window_set_vsync_mode(p_window, p_vsync_mode); } @@ -2717,7 +2725,7 @@ DisplayServer::VSyncMode DisplayServerMacOS::window_get_vsync_mode(WindowID p_wi return (gl_manager_legacy->is_using_vsync() ? 
DisplayServer::VSyncMode::VSYNC_ENABLED : DisplayServer::VSyncMode::VSYNC_DISABLED); } #endif -#if defined(VULKAN_ENABLED) +#if defined(RD_ENABLED) if (rendering_context) { return rendering_context->window_get_vsync_mode(p_window); } @@ -3301,6 +3309,9 @@ Vector DisplayServerMacOS::get_rendering_drivers_func() { #if defined(VULKAN_ENABLED) drivers.push_back("vulkan"); #endif +#if defined(METAL_ENABLED) + drivers.push_back("metal"); +#endif #if defined(GLES3_ENABLED) drivers.push_back("opengl3"); drivers.push_back("opengl3_angle"); @@ -3623,6 +3634,11 @@ DisplayServerMacOS::DisplayServerMacOS(const String &p_rendering_driver, WindowM rendering_context = memnew(RenderingContextDriverVulkanMacOS); } #endif +#if defined(METAL_ENABLED) + if (rendering_driver == "metal") { + rendering_context = memnew(RenderingContextDriverMetal); + } +#endif if (rendering_context) { if (rendering_context->initialize() != OK) { diff --git a/platform/macos/export/export_plugin.cpp b/platform/macos/export/export_plugin.cpp index 057fb4ec165..290b0082fc8 100644 --- a/platform/macos/export/export_plugin.cpp +++ b/platform/macos/export/export_plugin.cpp @@ -458,6 +458,7 @@ void EditorExportPlatformMacOS::get_export_options(List *r_options r_options->push_back(ExportOption(PropertyInfo(Variant::STRING, "application/additional_plist_content", PROPERTY_HINT_MULTILINE_TEXT), "")); r_options->push_back(ExportOption(PropertyInfo(Variant::STRING, "xcode/platform_build"), "14C18")); + // TODO(sgc): Need to set appropriate version when using Metal r_options->push_back(ExportOption(PropertyInfo(Variant::STRING, "xcode/sdk_version"), "13.1")); r_options->push_back(ExportOption(PropertyInfo(Variant::STRING, "xcode/sdk_build"), "22C55")); r_options->push_back(ExportOption(PropertyInfo(Variant::STRING, "xcode/sdk_name"), "macosx13.1")); diff --git a/servers/rendering/renderer_rd/cluster_builder_rd.h b/servers/rendering/renderer_rd/cluster_builder_rd.h index 3ca7af70ca9..1badce2b811 100644 --- a/servers/rendering/renderer_rd/cluster_builder_rd.h +++ b/servers/rendering/renderer_rd/cluster_builder_rd.h @@ -185,7 +185,14 @@ private: }; uint32_t cluster_size = 32; +#if defined(MACOS_ENABLED) || defined(IOS_ENABLED) + // Results in visual artifacts on macOS and iOS when using MSAA and subgroups. + // Using subgroups and disabling MSAA is the optimal solution for now and also works + // with MoltenVK. + bool use_msaa = false; +#else bool use_msaa = true; +#endif Divisor divisor = DIVISOR_4; Size2i screen_size; diff --git a/servers/rendering/renderer_rd/shader_rd.cpp b/servers/rendering/renderer_rd/shader_rd.cpp index 1c1b8366e5f..e6a745d3b44 100644 --- a/servers/rendering/renderer_rd/shader_rd.cpp +++ b/servers/rendering/renderer_rd/shader_rd.cpp @@ -191,9 +191,14 @@ void ShaderRD::_build_variant_code(StringBuilder &builder, uint32_t p_variant, c for (const KeyValue &E : p_version->code_sections) { builder.append(String("#define ") + String(E.key) + "_CODE_USED\n"); } -#if defined(MACOS_ENABLED) || defined(IOS_ENABLED) - builder.append("#define MOLTENVK_USED\n"); +#if (defined(MACOS_ENABLED) || defined(IOS_ENABLED)) + if (RD::get_singleton()->get_device_capabilities().device_family == RDD::DEVICE_VULKAN) { + builder.append("#define MOLTENVK_USED\n"); + } + // Image atomics are supported on Metal 3.1 but no support in MoltenVK or SPIRV-Cross yet. 
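+ // With NO_IMAGE_ATOMICS defined, the affected shaders (e.g. volumetric_fog) bind the
+ // atomically-updated data as a plain storage buffer and index it manually, replacing
+ // imageAtomicAdd()/imageAtomicOr() on a 3D image with atomicAdd()/atomicOr() on that buffer.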
+ builder.append("#define NO_IMAGE_ATOMICS\n"); #endif + builder.append(String("#define RENDER_DRIVER_") + OS::get_singleton()->get_current_rendering_driver_name().to_upper() + "\n"); } break; case StageTemplate::Chunk::TYPE_MATERIAL_UNIFORMS: { diff --git a/servers/rendering/renderer_rd/shaders/environment/volumetric_fog.glsl b/servers/rendering/renderer_rd/shaders/environment/volumetric_fog.glsl index 6ba1fe2819a..2360375c237 100644 --- a/servers/rendering/renderer_rd/shaders/environment/volumetric_fog.glsl +++ b/servers/rendering/renderer_rd/shaders/environment/volumetric_fog.glsl @@ -34,7 +34,7 @@ layout(push_constant, std430) uniform Params { } params; -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS layout(set = 1, binding = 1) volatile buffer emissive_only_map_buffer { uint emissive_only_map[]; }; @@ -64,7 +64,7 @@ layout(set = 1, binding = 2, std140) uniform SceneParams { } scene_params; -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS layout(set = 1, binding = 3) volatile buffer density_only_map_buffer { uint density_only_map[]; }; @@ -117,7 +117,7 @@ void main() { if (any(greaterThanEqual(pos, scene_params.fog_volume_size))) { return; //do not compute } -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS uint lpos = pos.z * scene_params.fog_volume_size.x * scene_params.fog_volume_size.y + pos.y * scene_params.fog_volume_size.x + pos.x; #endif @@ -222,7 +222,7 @@ void main() { density *= cull_mask; if (abs(density) > 0.001) { int final_density = int(density * DENSITY_SCALE); -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS atomicAdd(density_only_map[lpos], uint(final_density)); #else imageAtomicAdd(density_only_map, pos, uint(final_density)); @@ -236,7 +236,7 @@ void main() { uvec3 emission_u = uvec3(emission.r * 511.0, emission.g * 511.0, emission.b * 255.0); // R and G have 11 bits each and B has 10. Then pack them into a 32 bit uint uint final_emission = emission_u.r << 21 | emission_u.g << 10 | emission_u.b; -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS uint prev_emission = atomicAdd(emissive_only_map[lpos], final_emission); #else uint prev_emission = imageAtomicAdd(emissive_only_map, pos, final_emission); @@ -252,7 +252,7 @@ void main() { if (any(overflowing)) { uvec3 overflow_factor = mix(uvec3(0), uvec3(2047 << 21, 2047 << 10, 1023), overflowing); uint force_max = overflow_factor.r | overflow_factor.g | overflow_factor.b; -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS atomicOr(emissive_only_map[lpos], force_max); #else imageAtomicOr(emissive_only_map, pos, force_max); @@ -267,7 +267,7 @@ void main() { uvec3 scattering_u = uvec3(scattering.r * 2047.0, scattering.g * 2047.0, scattering.b * 1023.0); // R and G have 11 bits each and B has 10. 
Then pack them into a 32 bit uint uint final_scattering = scattering_u.r << 21 | scattering_u.g << 10 | scattering_u.b; -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS uint prev_scattering = atomicAdd(light_only_map[lpos], final_scattering); #else uint prev_scattering = imageAtomicAdd(light_only_map, pos, final_scattering); @@ -283,7 +283,7 @@ void main() { if (any(overflowing)) { uvec3 overflow_factor = mix(uvec3(0), uvec3(2047 << 21, 2047 << 10, 1023), overflowing); uint force_max = overflow_factor.r | overflow_factor.g | overflow_factor.b; -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS atomicOr(light_only_map[lpos], force_max); #else imageAtomicOr(light_only_map, pos, force_max); diff --git a/servers/rendering/renderer_rd/shaders/environment/volumetric_fog_process.glsl b/servers/rendering/renderer_rd/shaders/environment/volumetric_fog_process.glsl index d0cfe6a3b87..a797891ab65 100644 --- a/servers/rendering/renderer_rd/shaders/environment/volumetric_fog_process.glsl +++ b/servers/rendering/renderer_rd/shaders/environment/volumetric_fog_process.glsl @@ -190,7 +190,7 @@ params; #ifndef MODE_COPY layout(set = 0, binding = 15) uniform texture3D prev_density_texture; -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS layout(set = 0, binding = 16) buffer density_only_map_buffer { uint density_only_map[]; }; @@ -287,7 +287,7 @@ void main() { if (any(greaterThanEqual(pos, params.fog_volume_size))) { return; //do not compute } -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS uint lpos = pos.z * params.fog_volume_size.x * params.fog_volume_size.y + pos.y * params.fog_volume_size.x + pos.x; #endif @@ -353,7 +353,7 @@ void main() { vec3 total_light = vec3(0.0); float total_density = params.base_density; -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS uint local_density = density_only_map[lpos]; #else uint local_density = imageLoad(density_only_map, pos).x; @@ -362,7 +362,7 @@ void main() { total_density += float(int(local_density)) / DENSITY_SCALE; total_density = max(0.0, total_density); -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS uint scattering_u = light_only_map[lpos]; #else uint scattering_u = imageLoad(light_only_map, pos).x; @@ -370,7 +370,7 @@ void main() { vec3 scattering = vec3(scattering_u >> 21, (scattering_u << 11) >> 21, scattering_u % 1024) / vec3(2047.0, 2047.0, 1023.0); scattering += params.base_scattering * params.base_density; -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS uint emission_u = emissive_only_map[lpos]; #else uint emission_u = imageLoad(emissive_only_map, pos).x; @@ -710,7 +710,7 @@ void main() { final_density = mix(final_density, reprojected_density, reproject_amount); imageStore(density_map, pos, final_density); -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS density_only_map[lpos] = 0; light_only_map[lpos] = 0; emissive_only_map[lpos] = 0; diff --git a/servers/rendering/renderer_rd/shaders/forward_clustered/scene_forward_clustered.glsl b/servers/rendering/renderer_rd/shaders/forward_clustered/scene_forward_clustered.glsl index 5706f670eb4..c2d36adc923 100644 --- a/servers/rendering/renderer_rd/shaders/forward_clustered/scene_forward_clustered.glsl +++ b/servers/rendering/renderer_rd/shaders/forward_clustered/scene_forward_clustered.glsl @@ -2374,7 +2374,7 @@ void fragment_shader(in SceneData scene_data) { } } -#ifdef MOLTENVK_USED +#ifdef NO_IMAGE_ATOMICS imageStore(geom_facing_grid, grid_pos, uvec4(imageLoad(geom_facing_grid, grid_pos).r | facing_bits)); //store facing bits #else imageAtomicOr(geom_facing_grid, grid_pos, facing_bits); //store facing bits diff --git 
a/servers/rendering/rendering_device.cpp b/servers/rendering/rendering_device.cpp index da48b24a1a2..aef6978e259 100644 --- a/servers/rendering/rendering_device.cpp +++ b/servers/rendering/rendering_device.cpp @@ -2826,6 +2826,7 @@ RID RenderingDevice::uniform_set_create(const Vector &p_uniforms, RID p for (int j = 0; j < (int)uniform_count; j++) { if (uniforms[j].binding == set_uniform.binding) { uniform_idx = j; + break; } } ERR_FAIL_COND_V_MSG(uniform_idx == -1, RID(), @@ -3240,6 +3241,7 @@ RID RenderingDevice::render_pipeline_create(RID p_shader, FramebufferFormatID p_ for (int j = 0; j < vd.vertex_formats.size(); j++) { if (vd.vertex_formats[j].location == i) { found = true; + break; } } diff --git a/servers/rendering/rendering_device_driver.h b/servers/rendering/rendering_device_driver.h index 0b5fc51a1d8..934536ad9f4 100644 --- a/servers/rendering/rendering_device_driver.h +++ b/servers/rendering/rendering_device_driver.h @@ -759,6 +759,7 @@ public: DEVICE_OPENGL, DEVICE_VULKAN, DEVICE_DIRECTX, + DEVICE_METAL, }; struct Capabilities { diff --git a/thirdparty/README.md b/thirdparty/README.md index f241ce56d11..ddc28375f6d 100644 --- a/thirdparty/README.md +++ b/thirdparty/README.md @@ -827,6 +827,22 @@ and solve conflicts and also enrich the feature set originally proposed by these libraries and better integrate them with Godot. +## spirv-cross + +- Upstream: https://github.com/KhronosGroup/SPIRV-Cross +- Version: vulkan-sdk-1.3.290.0 (5d127b917f080c6f052553c47170ec0ba702e54f, 2024) +- License: Apache 2.0 + +Files extracted from upstream source: + +- All `.cpp`, `.hpp` and `.h` files, minus `main.cpp`, `spirv_cross_c.*`, `spirv_hlsl.*`, `spirv_cpp.*` +- `include/` folder +- `LICENSE` and `LICENSES/` folder, minus `CC-BY-4.0.txt` + +Versions of this SDK do not have to match the `vulkan` section, as this SDK is required +to generate Metal source from Vulkan SPIR-V. + + ## spirv-reflect - Upstream: https://github.com/KhronosGroup/SPIRV-Reflect diff --git a/thirdparty/spirv-cross/GLSL.std.450.h b/thirdparty/spirv-cross/GLSL.std.450.h new file mode 100644 index 00000000000..2686fc4ea7e --- /dev/null +++ b/thirdparty/spirv-cross/GLSL.std.450.h @@ -0,0 +1,114 @@ +/* + * Copyright 2014-2016,2021 The Khronos Group, Inc. + * SPDX-License-Identifier: MIT + * + * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS KHRONOS + * STANDARDS. 
THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS SPECIFICATIONS AND + * HEADER INFORMATION ARE LOCATED AT https://www.khronos.org/registry/ +*/ + +#ifndef GLSLstd450_H +#define GLSLstd450_H + +static const int GLSLstd450Version = 100; +static const int GLSLstd450Revision = 3; + +enum GLSLstd450 { + GLSLstd450Bad = 0, // Don't use + + GLSLstd450Round = 1, + GLSLstd450RoundEven = 2, + GLSLstd450Trunc = 3, + GLSLstd450FAbs = 4, + GLSLstd450SAbs = 5, + GLSLstd450FSign = 6, + GLSLstd450SSign = 7, + GLSLstd450Floor = 8, + GLSLstd450Ceil = 9, + GLSLstd450Fract = 10, + + GLSLstd450Radians = 11, + GLSLstd450Degrees = 12, + GLSLstd450Sin = 13, + GLSLstd450Cos = 14, + GLSLstd450Tan = 15, + GLSLstd450Asin = 16, + GLSLstd450Acos = 17, + GLSLstd450Atan = 18, + GLSLstd450Sinh = 19, + GLSLstd450Cosh = 20, + GLSLstd450Tanh = 21, + GLSLstd450Asinh = 22, + GLSLstd450Acosh = 23, + GLSLstd450Atanh = 24, + GLSLstd450Atan2 = 25, + + GLSLstd450Pow = 26, + GLSLstd450Exp = 27, + GLSLstd450Log = 28, + GLSLstd450Exp2 = 29, + GLSLstd450Log2 = 30, + GLSLstd450Sqrt = 31, + GLSLstd450InverseSqrt = 32, + + GLSLstd450Determinant = 33, + GLSLstd450MatrixInverse = 34, + + GLSLstd450Modf = 35, // second operand needs an OpVariable to write to + GLSLstd450ModfStruct = 36, // no OpVariable operand + GLSLstd450FMin = 37, + GLSLstd450UMin = 38, + GLSLstd450SMin = 39, + GLSLstd450FMax = 40, + GLSLstd450UMax = 41, + GLSLstd450SMax = 42, + GLSLstd450FClamp = 43, + GLSLstd450UClamp = 44, + GLSLstd450SClamp = 45, + GLSLstd450FMix = 46, + GLSLstd450IMix = 47, // Reserved + GLSLstd450Step = 48, + GLSLstd450SmoothStep = 49, + + GLSLstd450Fma = 50, + GLSLstd450Frexp = 51, // second operand needs an OpVariable to write to + GLSLstd450FrexpStruct = 52, // no OpVariable operand + GLSLstd450Ldexp = 53, + + GLSLstd450PackSnorm4x8 = 54, + GLSLstd450PackUnorm4x8 = 55, + GLSLstd450PackSnorm2x16 = 56, + GLSLstd450PackUnorm2x16 = 57, + GLSLstd450PackHalf2x16 = 58, + GLSLstd450PackDouble2x32 = 59, + GLSLstd450UnpackSnorm2x16 = 60, + GLSLstd450UnpackUnorm2x16 = 61, + GLSLstd450UnpackHalf2x16 = 62, + GLSLstd450UnpackSnorm4x8 = 63, + GLSLstd450UnpackUnorm4x8 = 64, + GLSLstd450UnpackDouble2x32 = 65, + + GLSLstd450Length = 66, + GLSLstd450Distance = 67, + GLSLstd450Cross = 68, + GLSLstd450Normalize = 69, + GLSLstd450FaceForward = 70, + GLSLstd450Reflect = 71, + GLSLstd450Refract = 72, + + GLSLstd450FindILsb = 73, + GLSLstd450FindSMsb = 74, + GLSLstd450FindUMsb = 75, + + GLSLstd450InterpolateAtCentroid = 76, + GLSLstd450InterpolateAtSample = 77, + GLSLstd450InterpolateAtOffset = 78, + + GLSLstd450NMin = 79, + GLSLstd450NMax = 80, + GLSLstd450NClamp = 81, + + GLSLstd450Count +}; + +#endif // #ifndef GLSLstd450_H diff --git a/thirdparty/spirv-cross/LICENSE b/thirdparty/spirv-cross/LICENSE new file mode 100644 index 00000000000..d6456956733 --- /dev/null +++ b/thirdparty/spirv-cross/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/thirdparty/spirv-cross/LICENSES/Apache-2.0.txt b/thirdparty/spirv-cross/LICENSES/Apache-2.0.txt new file mode 100644 index 00000000000..4ed90b95224 --- /dev/null +++ b/thirdparty/spirv-cross/LICENSES/Apache-2.0.txt @@ -0,0 +1,208 @@ +Apache License + +Version 2.0, January 2004 + +http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, +AND DISTRIBUTION + + 1. Definitions. + + + +"License" shall mean the terms and conditions for use, reproduction, and distribution +as defined by Sections 1 through 9 of this document. + + + +"Licensor" shall mean the copyright owner or entity authorized by the copyright +owner that is granting the License. + + + +"Legal Entity" shall mean the union of the acting entity and all other entities +that control, are controlled by, or are under common control with that entity. +For the purposes of this definition, "control" means (i) the power, direct +or indirect, to cause the direction or management of such entity, whether +by contract or otherwise, or (ii) ownership of fifty percent (50%) or more +of the outstanding shares, or (iii) beneficial ownership of such entity. + + + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions +granted by this License. + + + +"Source" form shall mean the preferred form for making modifications, including +but not limited to software source code, documentation source, and configuration +files. + + + +"Object" form shall mean any form resulting from mechanical transformation +or translation of a Source form, including but not limited to compiled object +code, generated documentation, and conversions to other media types. + + + +"Work" shall mean the work of authorship, whether in Source or Object form, +made available under the License, as indicated by a copyright notice that +is included in or attached to the work (an example is provided in the Appendix +below). + + + +"Derivative Works" shall mean any work, whether in Source or Object form, +that is based on (or derived from) the Work and for which the editorial revisions, +annotations, elaborations, or other modifications represent, as a whole, an +original work of authorship. For the purposes of this License, Derivative +Works shall not include works that remain separable from, or merely link (or +bind by name) to the interfaces of, the Work and Derivative Works thereof. + + + +"Contribution" shall mean any work of authorship, including the original version +of the Work and any modifications or additions to that Work or Derivative +Works thereof, that is intentionally submitted to Licensor for inclusion in +the Work by the copyright owner or by an individual or Legal Entity authorized +to submit on behalf of the copyright owner. For the purposes of this definition, +"submitted" means any form of electronic, verbal, or written communication +sent to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, and +issue tracking systems that are managed by, or on behalf of, the Licensor +for the purpose of discussing and improving the Work, but excluding communication +that is conspicuously marked or otherwise designated in writing by the copyright +owner as "Not a Contribution." + + + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf +of whom a Contribution has been received by Licensor and subsequently incorporated +within the Work. + +2. Grant of Copyright License. 
Subject to the terms and conditions of this +License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable copyright license to reproduce, prepare +Derivative Works of, publicly display, publicly perform, sublicense, and distribute +the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, +each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) patent +license to make, have made, use, offer to sell, sell, import, and otherwise +transfer the Work, where such license applies only to those patent claims +licensable by such Contributor that are necessarily infringed by their Contribution(s) +alone or by combination of their Contribution(s) with the Work to which such +Contribution(s) was submitted. If You institute patent litigation against +any entity (including a cross-claim or counterclaim in a lawsuit) alleging +that the Work or a Contribution incorporated within the Work constitutes direct +or contributory patent infringement, then any patent licenses granted to You +under this License for that Work shall terminate as of the date such litigation +is filed. + +4. Redistribution. You may reproduce and distribute copies of the Work or +Derivative Works thereof in any medium, with or without modifications, and +in Source or Object form, provided that You meet the following conditions: + +(a) You must give any other recipients of the Work or Derivative Works a copy +of this License; and + +(b) You must cause any modified files to carry prominent notices stating that +You changed the files; and + +(c) You must retain, in the Source form of any Derivative Works that You distribute, +all copyright, patent, trademark, and attribution notices from the Source +form of the Work, excluding those notices that do not pertain to any part +of the Derivative Works; and + +(d) If the Work includes a "NOTICE" text file as part of its distribution, +then any Derivative Works that You distribute must include a readable copy +of the attribution notices contained within such NOTICE file, excluding those +notices that do not pertain to any part of the Derivative Works, in at least +one of the following places: within a NOTICE text file distributed as part +of the Derivative Works; within the Source form or documentation, if provided +along with the Derivative Works; or, within a display generated by the Derivative +Works, if and wherever such third-party notices normally appear. The contents +of the NOTICE file are for informational purposes only and do not modify the +License. You may add Your own attribution notices within Derivative Works +that You distribute, alongside or as an addendum to the NOTICE text from the +Work, provided that such additional attribution notices cannot be construed +as modifying the License. + +You may add Your own copyright statement to Your modifications and may provide +additional or different license terms and conditions for use, reproduction, +or distribution of Your modifications, or for any such Derivative Works as +a whole, provided Your use, reproduction, and distribution of the Work otherwise +complies with the conditions stated in this License. + +5. Submission of Contributions. 
Unless You explicitly state otherwise, any +Contribution intentionally submitted for inclusion in the Work by You to the +Licensor shall be under the terms and conditions of this License, without +any additional terms or conditions. Notwithstanding the above, nothing herein +shall supersede or modify the terms of any separate license agreement you +may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, +trademarks, service marks, or product names of the Licensor, except as required +for reasonable and customary use in describing the origin of the Work and +reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to +in writing, Licensor provides the Work (and each Contributor provides its +Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied, including, without limitation, any warranties +or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR +A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness +of using or redistributing the Work and assume any risks associated with Your +exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, whether +in tort (including negligence), contract, or otherwise, unless required by +applicable law (such as deliberate and grossly negligent acts) or agreed to +in writing, shall any Contributor be liable to You for damages, including +any direct, indirect, special, incidental, or consequential damages of any +character arising as a result of this License or out of the use or inability +to use the Work (including but not limited to damages for loss of goodwill, +work stoppage, computer failure or malfunction, or any and all other commercial +damages or losses), even if such Contributor has been advised of the possibility +of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work +or Derivative Works thereof, You may choose to offer, and charge a fee for, +acceptance of support, warranty, indemnity, or other liability obligations +and/or rights consistent with this License. However, in accepting such obligations, +You may act only on Your own behalf and on Your sole responsibility, not on +behalf of any other Contributor, and only if You agree to indemnify, defend, +and hold each Contributor harmless for any liability incurred by, or claims +asserted against, such Contributor by reason of your accepting any such warranty +or additional liability. END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following boilerplate +notice, with the fields enclosed by brackets "[]" replaced with your own identifying +information. (Don't include the brackets!) The text should be enclosed in +the appropriate comment syntax for the file format. We also recommend that +a file or class name and description of purpose be included on the same "printed +page" as the copyright notice for easier identification within third-party +archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); + +you may not use this file except in compliance with the License. 
+ +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software + +distributed under the License is distributed on an "AS IS" BASIS, + +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +See the License for the specific language governing permissions and + +limitations under the License. diff --git a/thirdparty/spirv-cross/LICENSES/LicenseRef-KhronosFreeUse.txt b/thirdparty/spirv-cross/LICENSES/LicenseRef-KhronosFreeUse.txt new file mode 100644 index 00000000000..430863bc98a --- /dev/null +++ b/thirdparty/spirv-cross/LICENSES/LicenseRef-KhronosFreeUse.txt @@ -0,0 +1,23 @@ +Copyright (c) 2014-2020 The Khronos Group Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and/or associated documentation files (the "Materials"), +to deal in the Materials without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Materials, and to permit persons to whom the +Materials are furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Materials. + +MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS KHRONOS +STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS SPECIFICATIONS AND +HEADER INFORMATION ARE LOCATED AT https://www.khronos.org/registry/ + +THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM,OUT OF OR IN CONNECTION WITH THE MATERIALS OR THE USE OR OTHER DEALINGS +IN THE MATERIALS. diff --git a/thirdparty/spirv-cross/LICENSES/MIT.txt b/thirdparty/spirv-cross/LICENSES/MIT.txt new file mode 100644 index 00000000000..204b93da48d --- /dev/null +++ b/thirdparty/spirv-cross/LICENSES/MIT.txt @@ -0,0 +1,19 @@ +MIT License Copyright (c) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is furnished +to do so, subject to the following conditions: + +The above copyright notice and this permission notice (including the next +paragraph) shall be included in all copies or substantial portions of the +Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS +OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF +OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
diff --git a/thirdparty/spirv-cross/include/spirv_cross/barrier.hpp b/thirdparty/spirv-cross/include/spirv_cross/barrier.hpp new file mode 100644 index 00000000000..4ca7f4d77cb --- /dev/null +++ b/thirdparty/spirv-cross/include/spirv_cross/barrier.hpp @@ -0,0 +1,80 @@ +/* + * Copyright 2015-2017 ARM Limited + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPIRV_CROSS_BARRIER_HPP +#define SPIRV_CROSS_BARRIER_HPP + +#include <atomic> +#include <thread> + +namespace spirv_cross +{ +class Barrier +{ +public: + Barrier() + { + count.store(0); + iteration.store(0); + } + + void set_release_divisor(unsigned divisor) + { + this->divisor = divisor; + } + + static inline void memoryBarrier() + { + std::atomic_thread_fence(std::memory_order_seq_cst); + } + + void reset_counter() + { + count.store(0); + iteration.store(0); + } + + void wait() + { + unsigned target_iteration = iteration.load(std::memory_order_relaxed) + 1; + // Overflows cleanly. + unsigned target_count = divisor * target_iteration; + + // Barriers don't enforce memory ordering. + // Be as relaxed about the barrier as we possibly can! + unsigned c = count.fetch_add(1u, std::memory_order_relaxed); + + if (c + 1 == target_count) + { + iteration.store(target_iteration, std::memory_order_relaxed); + } + else + { + // If we have more threads than the CPU, don't hog the CPU for very long periods of time. + while (iteration.load(std::memory_order_relaxed) != target_iteration) + std::this_thread::yield(); + } + } + +private: + unsigned divisor = 1; + std::atomic<unsigned> count; + std::atomic<unsigned> iteration; +}; +} + +#endif diff --git a/thirdparty/spirv-cross/include/spirv_cross/external_interface.h b/thirdparty/spirv-cross/include/spirv_cross/external_interface.h new file mode 100644 index 00000000000..949654f5bff --- /dev/null +++ b/thirdparty/spirv-cross/include/spirv_cross/external_interface.h @@ -0,0 +1,127 @@ +/* + * Copyright 2015-2017 ARM Limited + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef SPIRV_CROSS_EXTERNAL_INTERFACE_H +#define SPIRV_CROSS_EXTERNAL_INTERFACE_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include + +typedef struct spirv_cross_shader spirv_cross_shader_t; + +struct spirv_cross_interface +{ + spirv_cross_shader_t *(*construct)(void); + void (*destruct)(spirv_cross_shader_t *thiz); + void (*invoke)(spirv_cross_shader_t *thiz); +}; + +void spirv_cross_set_stage_input(spirv_cross_shader_t *thiz, unsigned location, void *data, size_t size); + +void spirv_cross_set_stage_output(spirv_cross_shader_t *thiz, unsigned location, void *data, size_t size); + +void spirv_cross_set_push_constant(spirv_cross_shader_t *thiz, void *data, size_t size); + +void spirv_cross_set_uniform_constant(spirv_cross_shader_t *thiz, unsigned location, void *data, size_t size); + +void spirv_cross_set_resource(spirv_cross_shader_t *thiz, unsigned set, unsigned binding, void **data, size_t size); + +const struct spirv_cross_interface *spirv_cross_get_interface(void); + +typedef enum spirv_cross_builtin { + SPIRV_CROSS_BUILTIN_POSITION = 0, + SPIRV_CROSS_BUILTIN_FRAG_COORD = 1, + SPIRV_CROSS_BUILTIN_WORK_GROUP_ID = 2, + SPIRV_CROSS_BUILTIN_NUM_WORK_GROUPS = 3, + SPIRV_CROSS_NUM_BUILTINS +} spirv_cross_builtin; + +void spirv_cross_set_builtin(spirv_cross_shader_t *thiz, spirv_cross_builtin builtin, void *data, size_t size); + +#define SPIRV_CROSS_NUM_DESCRIPTOR_SETS 4 +#define SPIRV_CROSS_NUM_DESCRIPTOR_BINDINGS 16 +#define SPIRV_CROSS_NUM_STAGE_INPUTS 16 +#define SPIRV_CROSS_NUM_STAGE_OUTPUTS 16 +#define SPIRV_CROSS_NUM_UNIFORM_CONSTANTS 32 + +enum spirv_cross_format +{ + SPIRV_CROSS_FORMAT_R8_UNORM = 0, + SPIRV_CROSS_FORMAT_R8G8_UNORM = 1, + SPIRV_CROSS_FORMAT_R8G8B8_UNORM = 2, + SPIRV_CROSS_FORMAT_R8G8B8A8_UNORM = 3, + + SPIRV_CROSS_NUM_FORMATS +}; + +enum spirv_cross_wrap +{ + SPIRV_CROSS_WRAP_CLAMP_TO_EDGE = 0, + SPIRV_CROSS_WRAP_REPEAT = 1, + + SPIRV_CROSS_NUM_WRAP +}; + +enum spirv_cross_filter +{ + SPIRV_CROSS_FILTER_NEAREST = 0, + SPIRV_CROSS_FILTER_LINEAR = 1, + + SPIRV_CROSS_NUM_FILTER +}; + +enum spirv_cross_mipfilter +{ + SPIRV_CROSS_MIPFILTER_BASE = 0, + SPIRV_CROSS_MIPFILTER_NEAREST = 1, + SPIRV_CROSS_MIPFILTER_LINEAR = 2, + + SPIRV_CROSS_NUM_MIPFILTER +}; + +struct spirv_cross_miplevel +{ + const void *data; + unsigned width, height; + size_t stride; +}; + +struct spirv_cross_sampler_info +{ + const struct spirv_cross_miplevel *mipmaps; + unsigned num_mipmaps; + + enum spirv_cross_format format; + enum spirv_cross_wrap wrap_s; + enum spirv_cross_wrap wrap_t; + enum spirv_cross_filter min_filter; + enum spirv_cross_filter mag_filter; + enum spirv_cross_mipfilter mip_filter; +}; + +typedef struct spirv_cross_sampler_2d spirv_cross_sampler_2d_t; +spirv_cross_sampler_2d_t *spirv_cross_create_sampler_2d(const struct spirv_cross_sampler_info *info); +void spirv_cross_destroy_sampler_2d(spirv_cross_sampler_2d_t *samp); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/thirdparty/spirv-cross/include/spirv_cross/image.hpp b/thirdparty/spirv-cross/include/spirv_cross/image.hpp new file mode 100644 index 00000000000..a41ccdfbb40 --- /dev/null +++ b/thirdparty/spirv-cross/include/spirv_cross/image.hpp @@ -0,0 +1,63 @@ +/* + * Copyright 2015-2017 ARM Limited + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPIRV_CROSS_IMAGE_HPP +#define SPIRV_CROSS_IMAGE_HPP + +#ifndef GLM_SWIZZLE +#define GLM_SWIZZLE +#endif + +#ifndef GLM_FORCE_RADIANS +#define GLM_FORCE_RADIANS +#endif + +#include + +namespace spirv_cross +{ +template +struct image2DBase +{ + virtual ~image2DBase() = default; + inline virtual T load(glm::ivec2 coord) const + { + return T(0, 0, 0, 1); + } + inline virtual void store(glm::ivec2 coord, const T &v) + { + } +}; + +typedef image2DBase image2D; +typedef image2DBase iimage2D; +typedef image2DBase uimage2D; + +template +inline T imageLoad(const image2DBase &image, glm::ivec2 coord) +{ + return image.load(coord); +} + +template +void imageStore(image2DBase &image, glm::ivec2 coord, const T &value) +{ + image.store(coord, value); +} +} + +#endif diff --git a/thirdparty/spirv-cross/include/spirv_cross/internal_interface.hpp b/thirdparty/spirv-cross/include/spirv_cross/internal_interface.hpp new file mode 100644 index 00000000000..3ff7f8e258c --- /dev/null +++ b/thirdparty/spirv-cross/include/spirv_cross/internal_interface.hpp @@ -0,0 +1,604 @@ +/* + * Copyright 2015-2017 ARM Limited + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPIRV_CROSS_INTERNAL_INTERFACE_HPP +#define SPIRV_CROSS_INTERNAL_INTERFACE_HPP + +// This file must only be included by the shader generated by spirv-cross! + +#ifndef GLM_FORCE_SWIZZLE +#define GLM_FORCE_SWIZZLE +#endif + +#ifndef GLM_FORCE_RADIANS +#define GLM_FORCE_RADIANS +#endif + +#include + +#include "barrier.hpp" +#include "external_interface.h" +#include "image.hpp" +#include "sampler.hpp" +#include "thread_group.hpp" +#include +#include + +namespace internal +{ +// Adaptor helpers to adapt GLSL access chain syntax to C++. +// Don't bother with arrays of arrays on uniforms ... +// Would likely need horribly complex variadic template munging. + +template +struct Interface +{ + enum + { + ArraySize = 1, + Size = sizeof(T) + }; + + Interface() + : ptr(0) + { + } + T &get() + { + assert(ptr); + return *ptr; + } + + T *ptr; +}; + +// For array types, return a pointer instead. +template +struct Interface +{ + enum + { + ArraySize = U, + Size = U * sizeof(T) + }; + + Interface() + : ptr(0) + { + } + T *get() + { + assert(ptr); + return ptr; + } + + T *ptr; +}; + +// For case when array size is 1, avoid double dereference. 
+template +struct PointerInterface +{ + enum + { + ArraySize = 1, + Size = sizeof(T *) + }; + enum + { + PreDereference = true + }; + + PointerInterface() + : ptr(0) + { + } + + T &get() + { + assert(ptr); + return *ptr; + } + + T *ptr; +}; + +// Automatically converts a pointer down to reference to match GLSL syntax. +template +struct DereferenceAdaptor +{ + DereferenceAdaptor(T **ptr) + : ptr(ptr) + { + } + T &operator[](unsigned index) const + { + return *(ptr[index]); + } + T **ptr; +}; + +// We can't have a linear array of T* since T* can be an abstract type in case of samplers. +// We also need a list of pointers since we can have run-time length SSBOs. +template +struct PointerInterface +{ + enum + { + ArraySize = U, + Size = sizeof(T *) * U + }; + enum + { + PreDereference = false + }; + PointerInterface() + : ptr(0) + { + } + + DereferenceAdaptor get() + { + assert(ptr); + return DereferenceAdaptor(ptr); + } + + T **ptr; +}; + +// Resources can be more abstract and be unsized, +// so we need to have an array of pointers for those cases. +template +struct Resource : PointerInterface +{ +}; + +// POD with no unknown sizes, so we can express these as flat arrays. +template +struct UniformConstant : Interface +{ +}; +template +struct StageInput : Interface +{ +}; +template +struct StageOutput : Interface +{ +}; +template +struct PushConstant : Interface +{ +}; +} + +struct spirv_cross_shader +{ + struct PPSize + { + PPSize() + : ptr(0) + , size(0) + { + } + void **ptr; + size_t size; + }; + + struct PPSizeResource + { + PPSizeResource() + : ptr(0) + , size(0) + , pre_dereference(false) + { + } + void **ptr; + size_t size; + bool pre_dereference; + }; + + PPSizeResource resources[SPIRV_CROSS_NUM_DESCRIPTOR_SETS][SPIRV_CROSS_NUM_DESCRIPTOR_BINDINGS]; + PPSize stage_inputs[SPIRV_CROSS_NUM_STAGE_INPUTS]; + PPSize stage_outputs[SPIRV_CROSS_NUM_STAGE_OUTPUTS]; + PPSize uniform_constants[SPIRV_CROSS_NUM_UNIFORM_CONSTANTS]; + PPSize push_constant; + PPSize builtins[SPIRV_CROSS_NUM_BUILTINS]; + + template + void register_builtin(spirv_cross_builtin builtin, const U &value) + { + assert(!builtins[builtin].ptr); + + builtins[builtin].ptr = (void **)&value.ptr; + builtins[builtin].size = sizeof(*value.ptr) * U::ArraySize; + } + + void set_builtin(spirv_cross_builtin builtin, void *data, size_t size) + { + assert(builtins[builtin].ptr); + assert(size >= builtins[builtin].size); + + *builtins[builtin].ptr = data; + } + + template + void register_resource(const internal::Resource &value, unsigned set, unsigned binding) + { + assert(set < SPIRV_CROSS_NUM_DESCRIPTOR_SETS); + assert(binding < SPIRV_CROSS_NUM_DESCRIPTOR_BINDINGS); + assert(!resources[set][binding].ptr); + + resources[set][binding].ptr = (void **)&value.ptr; + resources[set][binding].size = internal::Resource::Size; + resources[set][binding].pre_dereference = internal::Resource::PreDereference; + } + + template + void register_stage_input(const internal::StageInput &value, unsigned location) + { + assert(location < SPIRV_CROSS_NUM_STAGE_INPUTS); + assert(!stage_inputs[location].ptr); + + stage_inputs[location].ptr = (void **)&value.ptr; + stage_inputs[location].size = internal::StageInput::Size; + } + + template + void register_stage_output(const internal::StageOutput &value, unsigned location) + { + assert(location < SPIRV_CROSS_NUM_STAGE_OUTPUTS); + assert(!stage_outputs[location].ptr); + + stage_outputs[location].ptr = (void **)&value.ptr; + stage_outputs[location].size = internal::StageOutput::Size; + } + + template + void 
register_uniform_constant(const internal::UniformConstant &value, unsigned location) + { + assert(location < SPIRV_CROSS_NUM_UNIFORM_CONSTANTS); + assert(!uniform_constants[location].ptr); + + uniform_constants[location].ptr = (void **)&value.ptr; + uniform_constants[location].size = internal::UniformConstant::Size; + } + + template + void register_push_constant(const internal::PushConstant &value) + { + assert(!push_constant.ptr); + + push_constant.ptr = (void **)&value.ptr; + push_constant.size = internal::PushConstant::Size; + } + + void set_stage_input(unsigned location, void *data, size_t size) + { + assert(location < SPIRV_CROSS_NUM_STAGE_INPUTS); + assert(stage_inputs[location].ptr); + assert(size >= stage_inputs[location].size); + + *stage_inputs[location].ptr = data; + } + + void set_stage_output(unsigned location, void *data, size_t size) + { + assert(location < SPIRV_CROSS_NUM_STAGE_OUTPUTS); + assert(stage_outputs[location].ptr); + assert(size >= stage_outputs[location].size); + + *stage_outputs[location].ptr = data; + } + + void set_uniform_constant(unsigned location, void *data, size_t size) + { + assert(location < SPIRV_CROSS_NUM_UNIFORM_CONSTANTS); + assert(uniform_constants[location].ptr); + assert(size >= uniform_constants[location].size); + + *uniform_constants[location].ptr = data; + } + + void set_push_constant(void *data, size_t size) + { + assert(push_constant.ptr); + assert(size >= push_constant.size); + + *push_constant.ptr = data; + } + + void set_resource(unsigned set, unsigned binding, void **data, size_t size) + { + assert(set < SPIRV_CROSS_NUM_DESCRIPTOR_SETS); + assert(binding < SPIRV_CROSS_NUM_DESCRIPTOR_BINDINGS); + assert(resources[set][binding].ptr); + assert(size >= resources[set][binding].size); + + // We're using the regular PointerInterface, dereference ahead of time. 
+ if (resources[set][binding].pre_dereference) + *resources[set][binding].ptr = *data; + else + *resources[set][binding].ptr = data; + } +}; + +namespace spirv_cross +{ +template +struct BaseShader : spirv_cross_shader +{ + void invoke() + { + static_cast(this)->main(); + } +}; + +struct FragmentResources +{ + internal::StageOutput gl_FragCoord; + void init(spirv_cross_shader &s) + { + s.register_builtin(SPIRV_CROSS_BUILTIN_FRAG_COORD, gl_FragCoord); + } +#define gl_FragCoord __res->gl_FragCoord.get() +}; + +template +struct FragmentShader : BaseShader> +{ + inline void main() + { + impl.main(); + } + + FragmentShader() + { + resources.init(*this); + impl.__res = &resources; + } + + T impl; + Res resources; +}; + +struct VertexResources +{ + internal::StageOutput gl_Position; + void init(spirv_cross_shader &s) + { + s.register_builtin(SPIRV_CROSS_BUILTIN_POSITION, gl_Position); + } +#define gl_Position __res->gl_Position.get() +}; + +template +struct VertexShader : BaseShader> +{ + inline void main() + { + impl.main(); + } + + VertexShader() + { + resources.init(*this); + impl.__res = &resources; + } + + T impl; + Res resources; +}; + +struct TessEvaluationResources +{ + inline void init(spirv_cross_shader &) + { + } +}; + +template +struct TessEvaluationShader : BaseShader> +{ + inline void main() + { + impl.main(); + } + + TessEvaluationShader() + { + resources.init(*this); + impl.__res = &resources; + } + + T impl; + Res resources; +}; + +struct TessControlResources +{ + inline void init(spirv_cross_shader &) + { + } +}; + +template +struct TessControlShader : BaseShader> +{ + inline void main() + { + impl.main(); + } + + TessControlShader() + { + resources.init(*this); + impl.__res = &resources; + } + + T impl; + Res resources; +}; + +struct GeometryResources +{ + inline void init(spirv_cross_shader &) + { + } +}; + +template +struct GeometryShader : BaseShader> +{ + inline void main() + { + impl.main(); + } + + GeometryShader() + { + resources.init(*this); + impl.__res = &resources; + } + + T impl; + Res resources; +}; + +struct ComputeResources +{ + internal::StageInput gl_WorkGroupID__; + internal::StageInput gl_NumWorkGroups__; + void init(spirv_cross_shader &s) + { + s.register_builtin(SPIRV_CROSS_BUILTIN_WORK_GROUP_ID, gl_WorkGroupID__); + s.register_builtin(SPIRV_CROSS_BUILTIN_NUM_WORK_GROUPS, gl_NumWorkGroups__); + } +#define gl_WorkGroupID __res->gl_WorkGroupID__.get() +#define gl_NumWorkGroups __res->gl_NumWorkGroups__.get() + + Barrier barrier__; +#define barrier() __res->barrier__.wait() +}; + +struct ComputePrivateResources +{ + uint32_t gl_LocalInvocationIndex__; +#define gl_LocalInvocationIndex __priv_res.gl_LocalInvocationIndex__ + glm::uvec3 gl_LocalInvocationID__; +#define gl_LocalInvocationID __priv_res.gl_LocalInvocationID__ + glm::uvec3 gl_GlobalInvocationID__; +#define gl_GlobalInvocationID __priv_res.gl_GlobalInvocationID__ +}; + +template +struct ComputeShader : BaseShader> +{ + inline void main() + { + resources.barrier__.reset_counter(); + + for (unsigned z = 0; z < WorkGroupZ; z++) + for (unsigned y = 0; y < WorkGroupY; y++) + for (unsigned x = 0; x < WorkGroupX; x++) + impl[z][y][x].__priv_res.gl_GlobalInvocationID__ = + glm::uvec3(WorkGroupX, WorkGroupY, WorkGroupZ) * resources.gl_WorkGroupID__.get() + + glm::uvec3(x, y, z); + + group.run(); + group.wait(); + } + + ComputeShader() + : group(&impl[0][0][0]) + { + resources.init(*this); + resources.barrier__.set_release_divisor(WorkGroupX * WorkGroupY * WorkGroupZ); + + unsigned i = 0; + for (unsigned z = 0; 
z < WorkGroupZ; z++) + { + for (unsigned y = 0; y < WorkGroupY; y++) + { + for (unsigned x = 0; x < WorkGroupX; x++) + { + impl[z][y][x].__priv_res.gl_LocalInvocationID__ = glm::uvec3(x, y, z); + impl[z][y][x].__priv_res.gl_LocalInvocationIndex__ = i++; + impl[z][y][x].__res = &resources; + } + } + } + } + + T impl[WorkGroupZ][WorkGroupY][WorkGroupX]; + ThreadGroup group; + Res resources; +}; + +inline void memoryBarrierShared() +{ + Barrier::memoryBarrier(); +} +inline void memoryBarrier() +{ + Barrier::memoryBarrier(); +} +// TODO: Rest of the barriers. + +// Atomics +template +inline T atomicAdd(T &v, T a) +{ + static_assert(sizeof(std::atomic) == sizeof(T), "Cannot cast properly to std::atomic."); + + // We need explicit memory barriers in GLSL to enfore any ordering. + // FIXME: Can we really cast this? There is no other way I think ... + return std::atomic_fetch_add_explicit(reinterpret_cast *>(&v), a, std::memory_order_relaxed); +} +} + +void spirv_cross_set_stage_input(spirv_cross_shader_t *shader, unsigned location, void *data, size_t size) +{ + shader->set_stage_input(location, data, size); +} + +void spirv_cross_set_stage_output(spirv_cross_shader_t *shader, unsigned location, void *data, size_t size) +{ + shader->set_stage_output(location, data, size); +} + +void spirv_cross_set_uniform_constant(spirv_cross_shader_t *shader, unsigned location, void *data, size_t size) +{ + shader->set_uniform_constant(location, data, size); +} + +void spirv_cross_set_resource(spirv_cross_shader_t *shader, unsigned set, unsigned binding, void **data, size_t size) +{ + shader->set_resource(set, binding, data, size); +} + +void spirv_cross_set_push_constant(spirv_cross_shader_t *shader, void *data, size_t size) +{ + shader->set_push_constant(data, size); +} + +void spirv_cross_set_builtin(spirv_cross_shader_t *shader, spirv_cross_builtin builtin, void *data, size_t size) +{ + shader->set_builtin(builtin, data, size); +} + +#endif diff --git a/thirdparty/spirv-cross/include/spirv_cross/sampler.hpp b/thirdparty/spirv-cross/include/spirv_cross/sampler.hpp new file mode 100644 index 00000000000..02084809514 --- /dev/null +++ b/thirdparty/spirv-cross/include/spirv_cross/sampler.hpp @@ -0,0 +1,106 @@ +/* + * Copyright 2015-2017 ARM Limited + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef SPIRV_CROSS_SAMPLER_HPP +#define SPIRV_CROSS_SAMPLER_HPP + +#include + +namespace spirv_cross +{ +struct spirv_cross_sampler_2d +{ + inline virtual ~spirv_cross_sampler_2d() + { + } +}; + +template +struct sampler2DBase : spirv_cross_sampler_2d +{ + sampler2DBase(const spirv_cross_sampler_info *info) + { + mips.insert(mips.end(), info->mipmaps, info->mipmaps + info->num_mipmaps); + format = info->format; + wrap_s = info->wrap_s; + wrap_t = info->wrap_t; + min_filter = info->min_filter; + mag_filter = info->mag_filter; + mip_filter = info->mip_filter; + } + + inline virtual T sample(glm::vec2 uv, float bias) + { + return sampleLod(uv, bias); + } + + inline virtual T sampleLod(glm::vec2 uv, float lod) + { + if (mag_filter == SPIRV_CROSS_FILTER_NEAREST) + { + uv.x = wrap(uv.x, wrap_s, mips[0].width); + uv.y = wrap(uv.y, wrap_t, mips[0].height); + glm::vec2 uv_full = uv * glm::vec2(mips[0].width, mips[0].height); + + int x = int(uv_full.x); + int y = int(uv_full.y); + return sample(x, y, 0); + } + else + { + return T(0, 0, 0, 1); + } + } + + inline float wrap(float v, spirv_cross_wrap wrap, unsigned size) + { + switch (wrap) + { + case SPIRV_CROSS_WRAP_REPEAT: + return v - glm::floor(v); + case SPIRV_CROSS_WRAP_CLAMP_TO_EDGE: + { + float half = 0.5f / size; + return glm::clamp(v, half, 1.0f - half); + } + + default: + return 0.0f; + } + } + + std::vector mips; + spirv_cross_format format; + spirv_cross_wrap wrap_s; + spirv_cross_wrap wrap_t; + spirv_cross_filter min_filter; + spirv_cross_filter mag_filter; + spirv_cross_mipfilter mip_filter; +}; + +typedef sampler2DBase sampler2D; +typedef sampler2DBase isampler2D; +typedef sampler2DBase usampler2D; + +template +inline T texture(const sampler2DBase &samp, const glm::vec2 &uv, float bias = 0.0f) +{ + return samp.sample(uv, bias); +} +} + +#endif diff --git a/thirdparty/spirv-cross/include/spirv_cross/thread_group.hpp b/thirdparty/spirv-cross/include/spirv_cross/thread_group.hpp new file mode 100644 index 00000000000..b2155815625 --- /dev/null +++ b/thirdparty/spirv-cross/include/spirv_cross/thread_group.hpp @@ -0,0 +1,114 @@ +/* + * Copyright 2015-2017 ARM Limited + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef SPIRV_CROSS_THREAD_GROUP_HPP +#define SPIRV_CROSS_THREAD_GROUP_HPP + +#include <condition_variable> +#include <mutex> +#include <thread> + +namespace spirv_cross +{ +template <typename T, unsigned Size> +class ThreadGroup +{ +public: + ThreadGroup(T *impl) + { + for (unsigned i = 0; i < Size; i++) + workers[i].start(&impl[i]); + } + + void run() + { + for (auto &worker : workers) + worker.run(); + } + + void wait() + { + for (auto &worker : workers) + worker.wait(); + } + +private: + struct Thread + { + enum State + { + Idle, + Running, + Dying + }; + State state = Idle; + + void start(T *impl) + { + worker = std::thread([impl, this] { + for (;;) + { + { + std::unique_lock<std::mutex> l{ lock }; + cond.wait(l, [this] { return state != Idle; }); + if (state == Dying) + break; + } + + impl->main(); + + std::lock_guard<std::mutex> l{ lock }; + state = Idle; + cond.notify_one(); + } + }); + } + + void wait() + { + std::unique_lock<std::mutex> l{ lock }; + cond.wait(l, [this] { return state == Idle; }); + } + + void run() + { + std::lock_guard<std::mutex> l{ lock }; + state = Running; + cond.notify_one(); + } + + ~Thread() + { + if (worker.joinable()) + { + { + std::lock_guard<std::mutex> l{ lock }; + state = Dying; + cond.notify_one(); + } + worker.join(); + } + } + std::thread worker; + std::condition_variable cond; + std::mutex lock; + }; + Thread workers[Size]; +}; +} + +#endif diff --git a/thirdparty/spirv-cross/spirv.hpp b/thirdparty/spirv-cross/spirv.hpp new file mode 100644 index 00000000000..f2ee9096bdd --- /dev/null +++ b/thirdparty/spirv-cross/spirv.hpp @@ -0,0 +1,2592 @@ +// Copyright (c) 2014-2020 The Khronos Group Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and/or associated documentation files (the "Materials"), +// to deal in the Materials without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Materials, and to permit persons to whom the +// Materials are furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Materials. +// +// MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS KHRONOS +// STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS SPECIFICATIONS AND +// HEADER INFORMATION ARE LOCATED AT https://www.khronos.org/registry/ +// +// THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM,OUT OF OR IN CONNECTION WITH THE MATERIALS OR THE USE OR OTHER DEALINGS +// IN THE MATERIALS. + +// This header is automatically generated by the same tool that creates +// the Binary Section of the SPIR-V specification.
+ +// Enumeration tokens for SPIR-V, in various styles: +// C, C++, C++11, JSON, Lua, Python, C#, D, Beef +// +// - C will have tokens with a "Spv" prefix, e.g.: SpvSourceLanguageGLSL +// - C++ will have tokens in the "spv" name space, e.g.: spv::SourceLanguageGLSL +// - C++11 will use enum classes in the spv namespace, e.g.: spv::SourceLanguage::GLSL +// - Lua will use tables, e.g.: spv.SourceLanguage.GLSL +// - Python will use dictionaries, e.g.: spv['SourceLanguage']['GLSL'] +// - C# will use enum classes in the Specification class located in the "Spv" namespace, +// e.g.: Spv.Specification.SourceLanguage.GLSL +// - D will have tokens under the "spv" module, e.g: spv.SourceLanguage.GLSL +// - Beef will use enum classes in the Specification class located in the "Spv" namespace, +// e.g.: Spv.Specification.SourceLanguage.GLSL +// +// Some tokens act like mask values, which can be OR'd together, +// while others are mutually exclusive. The mask-like ones have +// "Mask" in their name, and a parallel enum that has the shift +// amount (1 << x) for each corresponding enumerant. + +#ifndef spirv_HPP +#define spirv_HPP + +namespace spv { + +typedef unsigned int Id; + +#define SPV_VERSION 0x10600 +#define SPV_REVISION 1 + +static const unsigned int MagicNumber = 0x07230203; +static const unsigned int Version = 0x00010600; +static const unsigned int Revision = 1; +static const unsigned int OpCodeMask = 0xffff; +static const unsigned int WordCountShift = 16; + +enum SourceLanguage { + SourceLanguageUnknown = 0, + SourceLanguageESSL = 1, + SourceLanguageGLSL = 2, + SourceLanguageOpenCL_C = 3, + SourceLanguageOpenCL_CPP = 4, + SourceLanguageHLSL = 5, + SourceLanguageCPP_for_OpenCL = 6, + SourceLanguageSYCL = 7, + SourceLanguageMax = 0x7fffffff, +}; + +enum ExecutionModel { + ExecutionModelVertex = 0, + ExecutionModelTessellationControl = 1, + ExecutionModelTessellationEvaluation = 2, + ExecutionModelGeometry = 3, + ExecutionModelFragment = 4, + ExecutionModelGLCompute = 5, + ExecutionModelKernel = 6, + ExecutionModelTaskNV = 5267, + ExecutionModelMeshNV = 5268, + ExecutionModelRayGenerationKHR = 5313, + ExecutionModelRayGenerationNV = 5313, + ExecutionModelIntersectionKHR = 5314, + ExecutionModelIntersectionNV = 5314, + ExecutionModelAnyHitKHR = 5315, + ExecutionModelAnyHitNV = 5315, + ExecutionModelClosestHitKHR = 5316, + ExecutionModelClosestHitNV = 5316, + ExecutionModelMissKHR = 5317, + ExecutionModelMissNV = 5317, + ExecutionModelCallableKHR = 5318, + ExecutionModelCallableNV = 5318, + ExecutionModelTaskEXT = 5364, + ExecutionModelMeshEXT = 5365, + ExecutionModelMax = 0x7fffffff, +}; + +enum AddressingModel { + AddressingModelLogical = 0, + AddressingModelPhysical32 = 1, + AddressingModelPhysical64 = 2, + AddressingModelPhysicalStorageBuffer64 = 5348, + AddressingModelPhysicalStorageBuffer64EXT = 5348, + AddressingModelMax = 0x7fffffff, +}; + +enum MemoryModel { + MemoryModelSimple = 0, + MemoryModelGLSL450 = 1, + MemoryModelOpenCL = 2, + MemoryModelVulkan = 3, + MemoryModelVulkanKHR = 3, + MemoryModelMax = 0x7fffffff, +}; + +enum ExecutionMode { + ExecutionModeInvocations = 0, + ExecutionModeSpacingEqual = 1, + ExecutionModeSpacingFractionalEven = 2, + ExecutionModeSpacingFractionalOdd = 3, + ExecutionModeVertexOrderCw = 4, + ExecutionModeVertexOrderCcw = 5, + ExecutionModePixelCenterInteger = 6, + ExecutionModeOriginUpperLeft = 7, + ExecutionModeOriginLowerLeft = 8, + ExecutionModeEarlyFragmentTests = 9, + ExecutionModePointMode = 10, + ExecutionModeXfb = 11, + ExecutionModeDepthReplacing = 
12, + ExecutionModeDepthGreater = 14, + ExecutionModeDepthLess = 15, + ExecutionModeDepthUnchanged = 16, + ExecutionModeLocalSize = 17, + ExecutionModeLocalSizeHint = 18, + ExecutionModeInputPoints = 19, + ExecutionModeInputLines = 20, + ExecutionModeInputLinesAdjacency = 21, + ExecutionModeTriangles = 22, + ExecutionModeInputTrianglesAdjacency = 23, + ExecutionModeQuads = 24, + ExecutionModeIsolines = 25, + ExecutionModeOutputVertices = 26, + ExecutionModeOutputPoints = 27, + ExecutionModeOutputLineStrip = 28, + ExecutionModeOutputTriangleStrip = 29, + ExecutionModeVecTypeHint = 30, + ExecutionModeContractionOff = 31, + ExecutionModeInitializer = 33, + ExecutionModeFinalizer = 34, + ExecutionModeSubgroupSize = 35, + ExecutionModeSubgroupsPerWorkgroup = 36, + ExecutionModeSubgroupsPerWorkgroupId = 37, + ExecutionModeLocalSizeId = 38, + ExecutionModeLocalSizeHintId = 39, + ExecutionModeSubgroupUniformControlFlowKHR = 4421, + ExecutionModePostDepthCoverage = 4446, + ExecutionModeDenormPreserve = 4459, + ExecutionModeDenormFlushToZero = 4460, + ExecutionModeSignedZeroInfNanPreserve = 4461, + ExecutionModeRoundingModeRTE = 4462, + ExecutionModeRoundingModeRTZ = 4463, + ExecutionModeEarlyAndLateFragmentTestsAMD = 5017, + ExecutionModeStencilRefReplacingEXT = 5027, + ExecutionModeStencilRefUnchangedFrontAMD = 5079, + ExecutionModeStencilRefGreaterFrontAMD = 5080, + ExecutionModeStencilRefLessFrontAMD = 5081, + ExecutionModeStencilRefUnchangedBackAMD = 5082, + ExecutionModeStencilRefGreaterBackAMD = 5083, + ExecutionModeStencilRefLessBackAMD = 5084, + ExecutionModeOutputLinesEXT = 5269, + ExecutionModeOutputLinesNV = 5269, + ExecutionModeOutputPrimitivesEXT = 5270, + ExecutionModeOutputPrimitivesNV = 5270, + ExecutionModeDerivativeGroupQuadsNV = 5289, + ExecutionModeDerivativeGroupLinearNV = 5290, + ExecutionModeOutputTrianglesEXT = 5298, + ExecutionModeOutputTrianglesNV = 5298, + ExecutionModePixelInterlockOrderedEXT = 5366, + ExecutionModePixelInterlockUnorderedEXT = 5367, + ExecutionModeSampleInterlockOrderedEXT = 5368, + ExecutionModeSampleInterlockUnorderedEXT = 5369, + ExecutionModeShadingRateInterlockOrderedEXT = 5370, + ExecutionModeShadingRateInterlockUnorderedEXT = 5371, + ExecutionModeSharedLocalMemorySizeINTEL = 5618, + ExecutionModeRoundingModeRTPINTEL = 5620, + ExecutionModeRoundingModeRTNINTEL = 5621, + ExecutionModeFloatingPointModeALTINTEL = 5622, + ExecutionModeFloatingPointModeIEEEINTEL = 5623, + ExecutionModeMaxWorkgroupSizeINTEL = 5893, + ExecutionModeMaxWorkDimINTEL = 5894, + ExecutionModeNoGlobalOffsetINTEL = 5895, + ExecutionModeNumSIMDWorkitemsINTEL = 5896, + ExecutionModeSchedulerTargetFmaxMhzINTEL = 5903, + ExecutionModeNamedBarrierCountINTEL = 6417, + ExecutionModeMax = 0x7fffffff, +}; + +enum StorageClass { + StorageClassUniformConstant = 0, + StorageClassInput = 1, + StorageClassUniform = 2, + StorageClassOutput = 3, + StorageClassWorkgroup = 4, + StorageClassCrossWorkgroup = 5, + StorageClassPrivate = 6, + StorageClassFunction = 7, + StorageClassGeneric = 8, + StorageClassPushConstant = 9, + StorageClassAtomicCounter = 10, + StorageClassImage = 11, + StorageClassStorageBuffer = 12, + StorageClassCallableDataKHR = 5328, + StorageClassCallableDataNV = 5328, + StorageClassIncomingCallableDataKHR = 5329, + StorageClassIncomingCallableDataNV = 5329, + StorageClassRayPayloadKHR = 5338, + StorageClassRayPayloadNV = 5338, + StorageClassHitAttributeKHR = 5339, + StorageClassHitAttributeNV = 5339, + StorageClassIncomingRayPayloadKHR = 5342, + StorageClassIncomingRayPayloadNV 
= 5342, + StorageClassShaderRecordBufferKHR = 5343, + StorageClassShaderRecordBufferNV = 5343, + StorageClassPhysicalStorageBuffer = 5349, + StorageClassPhysicalStorageBufferEXT = 5349, + StorageClassTaskPayloadWorkgroupEXT = 5402, + StorageClassCodeSectionINTEL = 5605, + StorageClassDeviceOnlyINTEL = 5936, + StorageClassHostOnlyINTEL = 5937, + StorageClassMax = 0x7fffffff, +}; + +enum Dim { + Dim1D = 0, + Dim2D = 1, + Dim3D = 2, + DimCube = 3, + DimRect = 4, + DimBuffer = 5, + DimSubpassData = 6, + DimMax = 0x7fffffff, +}; + +enum SamplerAddressingMode { + SamplerAddressingModeNone = 0, + SamplerAddressingModeClampToEdge = 1, + SamplerAddressingModeClamp = 2, + SamplerAddressingModeRepeat = 3, + SamplerAddressingModeRepeatMirrored = 4, + SamplerAddressingModeMax = 0x7fffffff, +}; + +enum SamplerFilterMode { + SamplerFilterModeNearest = 0, + SamplerFilterModeLinear = 1, + SamplerFilterModeMax = 0x7fffffff, +}; + +enum ImageFormat { + ImageFormatUnknown = 0, + ImageFormatRgba32f = 1, + ImageFormatRgba16f = 2, + ImageFormatR32f = 3, + ImageFormatRgba8 = 4, + ImageFormatRgba8Snorm = 5, + ImageFormatRg32f = 6, + ImageFormatRg16f = 7, + ImageFormatR11fG11fB10f = 8, + ImageFormatR16f = 9, + ImageFormatRgba16 = 10, + ImageFormatRgb10A2 = 11, + ImageFormatRg16 = 12, + ImageFormatRg8 = 13, + ImageFormatR16 = 14, + ImageFormatR8 = 15, + ImageFormatRgba16Snorm = 16, + ImageFormatRg16Snorm = 17, + ImageFormatRg8Snorm = 18, + ImageFormatR16Snorm = 19, + ImageFormatR8Snorm = 20, + ImageFormatRgba32i = 21, + ImageFormatRgba16i = 22, + ImageFormatRgba8i = 23, + ImageFormatR32i = 24, + ImageFormatRg32i = 25, + ImageFormatRg16i = 26, + ImageFormatRg8i = 27, + ImageFormatR16i = 28, + ImageFormatR8i = 29, + ImageFormatRgba32ui = 30, + ImageFormatRgba16ui = 31, + ImageFormatRgba8ui = 32, + ImageFormatR32ui = 33, + ImageFormatRgb10a2ui = 34, + ImageFormatRg32ui = 35, + ImageFormatRg16ui = 36, + ImageFormatRg8ui = 37, + ImageFormatR16ui = 38, + ImageFormatR8ui = 39, + ImageFormatR64ui = 40, + ImageFormatR64i = 41, + ImageFormatMax = 0x7fffffff, +}; + +enum ImageChannelOrder { + ImageChannelOrderR = 0, + ImageChannelOrderA = 1, + ImageChannelOrderRG = 2, + ImageChannelOrderRA = 3, + ImageChannelOrderRGB = 4, + ImageChannelOrderRGBA = 5, + ImageChannelOrderBGRA = 6, + ImageChannelOrderARGB = 7, + ImageChannelOrderIntensity = 8, + ImageChannelOrderLuminance = 9, + ImageChannelOrderRx = 10, + ImageChannelOrderRGx = 11, + ImageChannelOrderRGBx = 12, + ImageChannelOrderDepth = 13, + ImageChannelOrderDepthStencil = 14, + ImageChannelOrdersRGB = 15, + ImageChannelOrdersRGBx = 16, + ImageChannelOrdersRGBA = 17, + ImageChannelOrdersBGRA = 18, + ImageChannelOrderABGR = 19, + ImageChannelOrderMax = 0x7fffffff, +}; + +enum ImageChannelDataType { + ImageChannelDataTypeSnormInt8 = 0, + ImageChannelDataTypeSnormInt16 = 1, + ImageChannelDataTypeUnormInt8 = 2, + ImageChannelDataTypeUnormInt16 = 3, + ImageChannelDataTypeUnormShort565 = 4, + ImageChannelDataTypeUnormShort555 = 5, + ImageChannelDataTypeUnormInt101010 = 6, + ImageChannelDataTypeSignedInt8 = 7, + ImageChannelDataTypeSignedInt16 = 8, + ImageChannelDataTypeSignedInt32 = 9, + ImageChannelDataTypeUnsignedInt8 = 10, + ImageChannelDataTypeUnsignedInt16 = 11, + ImageChannelDataTypeUnsignedInt32 = 12, + ImageChannelDataTypeHalfFloat = 13, + ImageChannelDataTypeFloat = 14, + ImageChannelDataTypeUnormInt24 = 15, + ImageChannelDataTypeUnormInt101010_2 = 16, + ImageChannelDataTypeMax = 0x7fffffff, +}; + +enum ImageOperandsShift { + ImageOperandsBiasShift = 0, + 
ImageOperandsLodShift = 1, + ImageOperandsGradShift = 2, + ImageOperandsConstOffsetShift = 3, + ImageOperandsOffsetShift = 4, + ImageOperandsConstOffsetsShift = 5, + ImageOperandsSampleShift = 6, + ImageOperandsMinLodShift = 7, + ImageOperandsMakeTexelAvailableShift = 8, + ImageOperandsMakeTexelAvailableKHRShift = 8, + ImageOperandsMakeTexelVisibleShift = 9, + ImageOperandsMakeTexelVisibleKHRShift = 9, + ImageOperandsNonPrivateTexelShift = 10, + ImageOperandsNonPrivateTexelKHRShift = 10, + ImageOperandsVolatileTexelShift = 11, + ImageOperandsVolatileTexelKHRShift = 11, + ImageOperandsSignExtendShift = 12, + ImageOperandsZeroExtendShift = 13, + ImageOperandsNontemporalShift = 14, + ImageOperandsOffsetsShift = 16, + ImageOperandsMax = 0x7fffffff, +}; + +enum ImageOperandsMask { + ImageOperandsMaskNone = 0, + ImageOperandsBiasMask = 0x00000001, + ImageOperandsLodMask = 0x00000002, + ImageOperandsGradMask = 0x00000004, + ImageOperandsConstOffsetMask = 0x00000008, + ImageOperandsOffsetMask = 0x00000010, + ImageOperandsConstOffsetsMask = 0x00000020, + ImageOperandsSampleMask = 0x00000040, + ImageOperandsMinLodMask = 0x00000080, + ImageOperandsMakeTexelAvailableMask = 0x00000100, + ImageOperandsMakeTexelAvailableKHRMask = 0x00000100, + ImageOperandsMakeTexelVisibleMask = 0x00000200, + ImageOperandsMakeTexelVisibleKHRMask = 0x00000200, + ImageOperandsNonPrivateTexelMask = 0x00000400, + ImageOperandsNonPrivateTexelKHRMask = 0x00000400, + ImageOperandsVolatileTexelMask = 0x00000800, + ImageOperandsVolatileTexelKHRMask = 0x00000800, + ImageOperandsSignExtendMask = 0x00001000, + ImageOperandsZeroExtendMask = 0x00002000, + ImageOperandsNontemporalMask = 0x00004000, + ImageOperandsOffsetsMask = 0x00010000, +}; + +enum FPFastMathModeShift { + FPFastMathModeNotNaNShift = 0, + FPFastMathModeNotInfShift = 1, + FPFastMathModeNSZShift = 2, + FPFastMathModeAllowRecipShift = 3, + FPFastMathModeFastShift = 4, + FPFastMathModeAllowContractFastINTELShift = 16, + FPFastMathModeAllowReassocINTELShift = 17, + FPFastMathModeMax = 0x7fffffff, +}; + +enum FPFastMathModeMask { + FPFastMathModeMaskNone = 0, + FPFastMathModeNotNaNMask = 0x00000001, + FPFastMathModeNotInfMask = 0x00000002, + FPFastMathModeNSZMask = 0x00000004, + FPFastMathModeAllowRecipMask = 0x00000008, + FPFastMathModeFastMask = 0x00000010, + FPFastMathModeAllowContractFastINTELMask = 0x00010000, + FPFastMathModeAllowReassocINTELMask = 0x00020000, +}; + +enum FPRoundingMode { + FPRoundingModeRTE = 0, + FPRoundingModeRTZ = 1, + FPRoundingModeRTP = 2, + FPRoundingModeRTN = 3, + FPRoundingModeMax = 0x7fffffff, +}; + +enum LinkageType { + LinkageTypeExport = 0, + LinkageTypeImport = 1, + LinkageTypeLinkOnceODR = 2, + LinkageTypeMax = 0x7fffffff, +}; + +enum AccessQualifier { + AccessQualifierReadOnly = 0, + AccessQualifierWriteOnly = 1, + AccessQualifierReadWrite = 2, + AccessQualifierMax = 0x7fffffff, +}; + +enum FunctionParameterAttribute { + FunctionParameterAttributeZext = 0, + FunctionParameterAttributeSext = 1, + FunctionParameterAttributeByVal = 2, + FunctionParameterAttributeSret = 3, + FunctionParameterAttributeNoAlias = 4, + FunctionParameterAttributeNoCapture = 5, + FunctionParameterAttributeNoWrite = 6, + FunctionParameterAttributeNoReadWrite = 7, + FunctionParameterAttributeMax = 0x7fffffff, +}; + +enum Decoration { + DecorationRelaxedPrecision = 0, + DecorationSpecId = 1, + DecorationBlock = 2, + DecorationBufferBlock = 3, + DecorationRowMajor = 4, + DecorationColMajor = 5, + DecorationArrayStride = 6, + DecorationMatrixStride = 7, + 
DecorationGLSLShared = 8, + DecorationGLSLPacked = 9, + DecorationCPacked = 10, + DecorationBuiltIn = 11, + DecorationNoPerspective = 13, + DecorationFlat = 14, + DecorationPatch = 15, + DecorationCentroid = 16, + DecorationSample = 17, + DecorationInvariant = 18, + DecorationRestrict = 19, + DecorationAliased = 20, + DecorationVolatile = 21, + DecorationConstant = 22, + DecorationCoherent = 23, + DecorationNonWritable = 24, + DecorationNonReadable = 25, + DecorationUniform = 26, + DecorationUniformId = 27, + DecorationSaturatedConversion = 28, + DecorationStream = 29, + DecorationLocation = 30, + DecorationComponent = 31, + DecorationIndex = 32, + DecorationBinding = 33, + DecorationDescriptorSet = 34, + DecorationOffset = 35, + DecorationXfbBuffer = 36, + DecorationXfbStride = 37, + DecorationFuncParamAttr = 38, + DecorationFPRoundingMode = 39, + DecorationFPFastMathMode = 40, + DecorationLinkageAttributes = 41, + DecorationNoContraction = 42, + DecorationInputAttachmentIndex = 43, + DecorationAlignment = 44, + DecorationMaxByteOffset = 45, + DecorationAlignmentId = 46, + DecorationMaxByteOffsetId = 47, + DecorationNoSignedWrap = 4469, + DecorationNoUnsignedWrap = 4470, + DecorationWeightTextureQCOM = 4487, + DecorationBlockMatchTextureQCOM = 4488, + DecorationExplicitInterpAMD = 4999, + DecorationOverrideCoverageNV = 5248, + DecorationPassthroughNV = 5250, + DecorationViewportRelativeNV = 5252, + DecorationSecondaryViewportRelativeNV = 5256, + DecorationPerPrimitiveEXT = 5271, + DecorationPerPrimitiveNV = 5271, + DecorationPerViewNV = 5272, + DecorationPerTaskNV = 5273, + DecorationPerVertexKHR = 5285, + DecorationPerVertexNV = 5285, + DecorationNonUniform = 5300, + DecorationNonUniformEXT = 5300, + DecorationRestrictPointer = 5355, + DecorationRestrictPointerEXT = 5355, + DecorationAliasedPointer = 5356, + DecorationAliasedPointerEXT = 5356, + DecorationBindlessSamplerNV = 5398, + DecorationBindlessImageNV = 5399, + DecorationBoundSamplerNV = 5400, + DecorationBoundImageNV = 5401, + DecorationSIMTCallINTEL = 5599, + DecorationReferencedIndirectlyINTEL = 5602, + DecorationClobberINTEL = 5607, + DecorationSideEffectsINTEL = 5608, + DecorationVectorComputeVariableINTEL = 5624, + DecorationFuncParamIOKindINTEL = 5625, + DecorationVectorComputeFunctionINTEL = 5626, + DecorationStackCallINTEL = 5627, + DecorationGlobalVariableOffsetINTEL = 5628, + DecorationCounterBuffer = 5634, + DecorationHlslCounterBufferGOOGLE = 5634, + DecorationHlslSemanticGOOGLE = 5635, + DecorationUserSemantic = 5635, + DecorationUserTypeGOOGLE = 5636, + DecorationFunctionRoundingModeINTEL = 5822, + DecorationFunctionDenormModeINTEL = 5823, + DecorationRegisterINTEL = 5825, + DecorationMemoryINTEL = 5826, + DecorationNumbanksINTEL = 5827, + DecorationBankwidthINTEL = 5828, + DecorationMaxPrivateCopiesINTEL = 5829, + DecorationSinglepumpINTEL = 5830, + DecorationDoublepumpINTEL = 5831, + DecorationMaxReplicatesINTEL = 5832, + DecorationSimpleDualPortINTEL = 5833, + DecorationMergeINTEL = 5834, + DecorationBankBitsINTEL = 5835, + DecorationForcePow2DepthINTEL = 5836, + DecorationBurstCoalesceINTEL = 5899, + DecorationCacheSizeINTEL = 5900, + DecorationDontStaticallyCoalesceINTEL = 5901, + DecorationPrefetchINTEL = 5902, + DecorationStallEnableINTEL = 5905, + DecorationFuseLoopsInFunctionINTEL = 5907, + DecorationAliasScopeINTEL = 5914, + DecorationNoAliasINTEL = 5915, + DecorationBufferLocationINTEL = 5921, + DecorationIOPipeStorageINTEL = 5944, + DecorationFunctionFloatingPointModeINTEL = 6080, + 
DecorationSingleElementVectorINTEL = 6085, + DecorationVectorComputeCallableFunctionINTEL = 6087, + DecorationMediaBlockIOINTEL = 6140, + DecorationMax = 0x7fffffff, +}; + +enum BuiltIn { + BuiltInPosition = 0, + BuiltInPointSize = 1, + BuiltInClipDistance = 3, + BuiltInCullDistance = 4, + BuiltInVertexId = 5, + BuiltInInstanceId = 6, + BuiltInPrimitiveId = 7, + BuiltInInvocationId = 8, + BuiltInLayer = 9, + BuiltInViewportIndex = 10, + BuiltInTessLevelOuter = 11, + BuiltInTessLevelInner = 12, + BuiltInTessCoord = 13, + BuiltInPatchVertices = 14, + BuiltInFragCoord = 15, + BuiltInPointCoord = 16, + BuiltInFrontFacing = 17, + BuiltInSampleId = 18, + BuiltInSamplePosition = 19, + BuiltInSampleMask = 20, + BuiltInFragDepth = 22, + BuiltInHelperInvocation = 23, + BuiltInNumWorkgroups = 24, + BuiltInWorkgroupSize = 25, + BuiltInWorkgroupId = 26, + BuiltInLocalInvocationId = 27, + BuiltInGlobalInvocationId = 28, + BuiltInLocalInvocationIndex = 29, + BuiltInWorkDim = 30, + BuiltInGlobalSize = 31, + BuiltInEnqueuedWorkgroupSize = 32, + BuiltInGlobalOffset = 33, + BuiltInGlobalLinearId = 34, + BuiltInSubgroupSize = 36, + BuiltInSubgroupMaxSize = 37, + BuiltInNumSubgroups = 38, + BuiltInNumEnqueuedSubgroups = 39, + BuiltInSubgroupId = 40, + BuiltInSubgroupLocalInvocationId = 41, + BuiltInVertexIndex = 42, + BuiltInInstanceIndex = 43, + BuiltInSubgroupEqMask = 4416, + BuiltInSubgroupEqMaskKHR = 4416, + BuiltInSubgroupGeMask = 4417, + BuiltInSubgroupGeMaskKHR = 4417, + BuiltInSubgroupGtMask = 4418, + BuiltInSubgroupGtMaskKHR = 4418, + BuiltInSubgroupLeMask = 4419, + BuiltInSubgroupLeMaskKHR = 4419, + BuiltInSubgroupLtMask = 4420, + BuiltInSubgroupLtMaskKHR = 4420, + BuiltInBaseVertex = 4424, + BuiltInBaseInstance = 4425, + BuiltInDrawIndex = 4426, + BuiltInPrimitiveShadingRateKHR = 4432, + BuiltInDeviceIndex = 4438, + BuiltInViewIndex = 4440, + BuiltInShadingRateKHR = 4444, + BuiltInBaryCoordNoPerspAMD = 4992, + BuiltInBaryCoordNoPerspCentroidAMD = 4993, + BuiltInBaryCoordNoPerspSampleAMD = 4994, + BuiltInBaryCoordSmoothAMD = 4995, + BuiltInBaryCoordSmoothCentroidAMD = 4996, + BuiltInBaryCoordSmoothSampleAMD = 4997, + BuiltInBaryCoordPullModelAMD = 4998, + BuiltInFragStencilRefEXT = 5014, + BuiltInViewportMaskNV = 5253, + BuiltInSecondaryPositionNV = 5257, + BuiltInSecondaryViewportMaskNV = 5258, + BuiltInPositionPerViewNV = 5261, + BuiltInViewportMaskPerViewNV = 5262, + BuiltInFullyCoveredEXT = 5264, + BuiltInTaskCountNV = 5274, + BuiltInPrimitiveCountNV = 5275, + BuiltInPrimitiveIndicesNV = 5276, + BuiltInClipDistancePerViewNV = 5277, + BuiltInCullDistancePerViewNV = 5278, + BuiltInLayerPerViewNV = 5279, + BuiltInMeshViewCountNV = 5280, + BuiltInMeshViewIndicesNV = 5281, + BuiltInBaryCoordKHR = 5286, + BuiltInBaryCoordNV = 5286, + BuiltInBaryCoordNoPerspKHR = 5287, + BuiltInBaryCoordNoPerspNV = 5287, + BuiltInFragSizeEXT = 5292, + BuiltInFragmentSizeNV = 5292, + BuiltInFragInvocationCountEXT = 5293, + BuiltInInvocationsPerPixelNV = 5293, + BuiltInPrimitivePointIndicesEXT = 5294, + BuiltInPrimitiveLineIndicesEXT = 5295, + BuiltInPrimitiveTriangleIndicesEXT = 5296, + BuiltInCullPrimitiveEXT = 5299, + BuiltInLaunchIdKHR = 5319, + BuiltInLaunchIdNV = 5319, + BuiltInLaunchSizeKHR = 5320, + BuiltInLaunchSizeNV = 5320, + BuiltInWorldRayOriginKHR = 5321, + BuiltInWorldRayOriginNV = 5321, + BuiltInWorldRayDirectionKHR = 5322, + BuiltInWorldRayDirectionNV = 5322, + BuiltInObjectRayOriginKHR = 5323, + BuiltInObjectRayOriginNV = 5323, + BuiltInObjectRayDirectionKHR = 5324, + BuiltInObjectRayDirectionNV = 5324, 
+ BuiltInRayTminKHR = 5325, + BuiltInRayTminNV = 5325, + BuiltInRayTmaxKHR = 5326, + BuiltInRayTmaxNV = 5326, + BuiltInInstanceCustomIndexKHR = 5327, + BuiltInInstanceCustomIndexNV = 5327, + BuiltInObjectToWorldKHR = 5330, + BuiltInObjectToWorldNV = 5330, + BuiltInWorldToObjectKHR = 5331, + BuiltInWorldToObjectNV = 5331, + BuiltInHitTNV = 5332, + BuiltInHitKindKHR = 5333, + BuiltInHitKindNV = 5333, + BuiltInCurrentRayTimeNV = 5334, + BuiltInIncomingRayFlagsKHR = 5351, + BuiltInIncomingRayFlagsNV = 5351, + BuiltInRayGeometryIndexKHR = 5352, + BuiltInWarpsPerSMNV = 5374, + BuiltInSMCountNV = 5375, + BuiltInWarpIDNV = 5376, + BuiltInSMIDNV = 5377, + BuiltInCullMaskKHR = 6021, + BuiltInMax = 0x7fffffff, +}; + +enum SelectionControlShift { + SelectionControlFlattenShift = 0, + SelectionControlDontFlattenShift = 1, + SelectionControlMax = 0x7fffffff, +}; + +enum SelectionControlMask { + SelectionControlMaskNone = 0, + SelectionControlFlattenMask = 0x00000001, + SelectionControlDontFlattenMask = 0x00000002, +}; + +enum LoopControlShift { + LoopControlUnrollShift = 0, + LoopControlDontUnrollShift = 1, + LoopControlDependencyInfiniteShift = 2, + LoopControlDependencyLengthShift = 3, + LoopControlMinIterationsShift = 4, + LoopControlMaxIterationsShift = 5, + LoopControlIterationMultipleShift = 6, + LoopControlPeelCountShift = 7, + LoopControlPartialCountShift = 8, + LoopControlInitiationIntervalINTELShift = 16, + LoopControlMaxConcurrencyINTELShift = 17, + LoopControlDependencyArrayINTELShift = 18, + LoopControlPipelineEnableINTELShift = 19, + LoopControlLoopCoalesceINTELShift = 20, + LoopControlMaxInterleavingINTELShift = 21, + LoopControlSpeculatedIterationsINTELShift = 22, + LoopControlNoFusionINTELShift = 23, + LoopControlMax = 0x7fffffff, +}; + +enum LoopControlMask { + LoopControlMaskNone = 0, + LoopControlUnrollMask = 0x00000001, + LoopControlDontUnrollMask = 0x00000002, + LoopControlDependencyInfiniteMask = 0x00000004, + LoopControlDependencyLengthMask = 0x00000008, + LoopControlMinIterationsMask = 0x00000010, + LoopControlMaxIterationsMask = 0x00000020, + LoopControlIterationMultipleMask = 0x00000040, + LoopControlPeelCountMask = 0x00000080, + LoopControlPartialCountMask = 0x00000100, + LoopControlInitiationIntervalINTELMask = 0x00010000, + LoopControlMaxConcurrencyINTELMask = 0x00020000, + LoopControlDependencyArrayINTELMask = 0x00040000, + LoopControlPipelineEnableINTELMask = 0x00080000, + LoopControlLoopCoalesceINTELMask = 0x00100000, + LoopControlMaxInterleavingINTELMask = 0x00200000, + LoopControlSpeculatedIterationsINTELMask = 0x00400000, + LoopControlNoFusionINTELMask = 0x00800000, +}; + +enum FunctionControlShift { + FunctionControlInlineShift = 0, + FunctionControlDontInlineShift = 1, + FunctionControlPureShift = 2, + FunctionControlConstShift = 3, + FunctionControlOptNoneINTELShift = 16, + FunctionControlMax = 0x7fffffff, +}; + +enum FunctionControlMask { + FunctionControlMaskNone = 0, + FunctionControlInlineMask = 0x00000001, + FunctionControlDontInlineMask = 0x00000002, + FunctionControlPureMask = 0x00000004, + FunctionControlConstMask = 0x00000008, + FunctionControlOptNoneINTELMask = 0x00010000, +}; + +enum MemorySemanticsShift { + MemorySemanticsAcquireShift = 1, + MemorySemanticsReleaseShift = 2, + MemorySemanticsAcquireReleaseShift = 3, + MemorySemanticsSequentiallyConsistentShift = 4, + MemorySemanticsUniformMemoryShift = 6, + MemorySemanticsSubgroupMemoryShift = 7, + MemorySemanticsWorkgroupMemoryShift = 8, + MemorySemanticsCrossWorkgroupMemoryShift = 9, + 
MemorySemanticsAtomicCounterMemoryShift = 10, + MemorySemanticsImageMemoryShift = 11, + MemorySemanticsOutputMemoryShift = 12, + MemorySemanticsOutputMemoryKHRShift = 12, + MemorySemanticsMakeAvailableShift = 13, + MemorySemanticsMakeAvailableKHRShift = 13, + MemorySemanticsMakeVisibleShift = 14, + MemorySemanticsMakeVisibleKHRShift = 14, + MemorySemanticsVolatileShift = 15, + MemorySemanticsMax = 0x7fffffff, +}; + +enum MemorySemanticsMask { + MemorySemanticsMaskNone = 0, + MemorySemanticsAcquireMask = 0x00000002, + MemorySemanticsReleaseMask = 0x00000004, + MemorySemanticsAcquireReleaseMask = 0x00000008, + MemorySemanticsSequentiallyConsistentMask = 0x00000010, + MemorySemanticsUniformMemoryMask = 0x00000040, + MemorySemanticsSubgroupMemoryMask = 0x00000080, + MemorySemanticsWorkgroupMemoryMask = 0x00000100, + MemorySemanticsCrossWorkgroupMemoryMask = 0x00000200, + MemorySemanticsAtomicCounterMemoryMask = 0x00000400, + MemorySemanticsImageMemoryMask = 0x00000800, + MemorySemanticsOutputMemoryMask = 0x00001000, + MemorySemanticsOutputMemoryKHRMask = 0x00001000, + MemorySemanticsMakeAvailableMask = 0x00002000, + MemorySemanticsMakeAvailableKHRMask = 0x00002000, + MemorySemanticsMakeVisibleMask = 0x00004000, + MemorySemanticsMakeVisibleKHRMask = 0x00004000, + MemorySemanticsVolatileMask = 0x00008000, +}; + +enum MemoryAccessShift { + MemoryAccessVolatileShift = 0, + MemoryAccessAlignedShift = 1, + MemoryAccessNontemporalShift = 2, + MemoryAccessMakePointerAvailableShift = 3, + MemoryAccessMakePointerAvailableKHRShift = 3, + MemoryAccessMakePointerVisibleShift = 4, + MemoryAccessMakePointerVisibleKHRShift = 4, + MemoryAccessNonPrivatePointerShift = 5, + MemoryAccessNonPrivatePointerKHRShift = 5, + MemoryAccessAliasScopeINTELMaskShift = 16, + MemoryAccessNoAliasINTELMaskShift = 17, + MemoryAccessMax = 0x7fffffff, +}; + +enum MemoryAccessMask { + MemoryAccessMaskNone = 0, + MemoryAccessVolatileMask = 0x00000001, + MemoryAccessAlignedMask = 0x00000002, + MemoryAccessNontemporalMask = 0x00000004, + MemoryAccessMakePointerAvailableMask = 0x00000008, + MemoryAccessMakePointerAvailableKHRMask = 0x00000008, + MemoryAccessMakePointerVisibleMask = 0x00000010, + MemoryAccessMakePointerVisibleKHRMask = 0x00000010, + MemoryAccessNonPrivatePointerMask = 0x00000020, + MemoryAccessNonPrivatePointerKHRMask = 0x00000020, + MemoryAccessAliasScopeINTELMaskMask = 0x00010000, + MemoryAccessNoAliasINTELMaskMask = 0x00020000, +}; + +enum Scope { + ScopeCrossDevice = 0, + ScopeDevice = 1, + ScopeWorkgroup = 2, + ScopeSubgroup = 3, + ScopeInvocation = 4, + ScopeQueueFamily = 5, + ScopeQueueFamilyKHR = 5, + ScopeShaderCallKHR = 6, + ScopeMax = 0x7fffffff, +}; + +enum GroupOperation { + GroupOperationReduce = 0, + GroupOperationInclusiveScan = 1, + GroupOperationExclusiveScan = 2, + GroupOperationClusteredReduce = 3, + GroupOperationPartitionedReduceNV = 6, + GroupOperationPartitionedInclusiveScanNV = 7, + GroupOperationPartitionedExclusiveScanNV = 8, + GroupOperationMax = 0x7fffffff, +}; + +enum KernelEnqueueFlags { + KernelEnqueueFlagsNoWait = 0, + KernelEnqueueFlagsWaitKernel = 1, + KernelEnqueueFlagsWaitWorkGroup = 2, + KernelEnqueueFlagsMax = 0x7fffffff, +}; + +enum KernelProfilingInfoShift { + KernelProfilingInfoCmdExecTimeShift = 0, + KernelProfilingInfoMax = 0x7fffffff, +}; + +enum KernelProfilingInfoMask { + KernelProfilingInfoMaskNone = 0, + KernelProfilingInfoCmdExecTimeMask = 0x00000001, +}; + +enum Capability { + CapabilityMatrix = 0, + CapabilityShader = 1, + CapabilityGeometry = 2, + 
CapabilityTessellation = 3, + CapabilityAddresses = 4, + CapabilityLinkage = 5, + CapabilityKernel = 6, + CapabilityVector16 = 7, + CapabilityFloat16Buffer = 8, + CapabilityFloat16 = 9, + CapabilityFloat64 = 10, + CapabilityInt64 = 11, + CapabilityInt64Atomics = 12, + CapabilityImageBasic = 13, + CapabilityImageReadWrite = 14, + CapabilityImageMipmap = 15, + CapabilityPipes = 17, + CapabilityGroups = 18, + CapabilityDeviceEnqueue = 19, + CapabilityLiteralSampler = 20, + CapabilityAtomicStorage = 21, + CapabilityInt16 = 22, + CapabilityTessellationPointSize = 23, + CapabilityGeometryPointSize = 24, + CapabilityImageGatherExtended = 25, + CapabilityStorageImageMultisample = 27, + CapabilityUniformBufferArrayDynamicIndexing = 28, + CapabilitySampledImageArrayDynamicIndexing = 29, + CapabilityStorageBufferArrayDynamicIndexing = 30, + CapabilityStorageImageArrayDynamicIndexing = 31, + CapabilityClipDistance = 32, + CapabilityCullDistance = 33, + CapabilityImageCubeArray = 34, + CapabilitySampleRateShading = 35, + CapabilityImageRect = 36, + CapabilitySampledRect = 37, + CapabilityGenericPointer = 38, + CapabilityInt8 = 39, + CapabilityInputAttachment = 40, + CapabilitySparseResidency = 41, + CapabilityMinLod = 42, + CapabilitySampled1D = 43, + CapabilityImage1D = 44, + CapabilitySampledCubeArray = 45, + CapabilitySampledBuffer = 46, + CapabilityImageBuffer = 47, + CapabilityImageMSArray = 48, + CapabilityStorageImageExtendedFormats = 49, + CapabilityImageQuery = 50, + CapabilityDerivativeControl = 51, + CapabilityInterpolationFunction = 52, + CapabilityTransformFeedback = 53, + CapabilityGeometryStreams = 54, + CapabilityStorageImageReadWithoutFormat = 55, + CapabilityStorageImageWriteWithoutFormat = 56, + CapabilityMultiViewport = 57, + CapabilitySubgroupDispatch = 58, + CapabilityNamedBarrier = 59, + CapabilityPipeStorage = 60, + CapabilityGroupNonUniform = 61, + CapabilityGroupNonUniformVote = 62, + CapabilityGroupNonUniformArithmetic = 63, + CapabilityGroupNonUniformBallot = 64, + CapabilityGroupNonUniformShuffle = 65, + CapabilityGroupNonUniformShuffleRelative = 66, + CapabilityGroupNonUniformClustered = 67, + CapabilityGroupNonUniformQuad = 68, + CapabilityShaderLayer = 69, + CapabilityShaderViewportIndex = 70, + CapabilityUniformDecoration = 71, + CapabilityFragmentShadingRateKHR = 4422, + CapabilitySubgroupBallotKHR = 4423, + CapabilityDrawParameters = 4427, + CapabilityWorkgroupMemoryExplicitLayoutKHR = 4428, + CapabilityWorkgroupMemoryExplicitLayout8BitAccessKHR = 4429, + CapabilityWorkgroupMemoryExplicitLayout16BitAccessKHR = 4430, + CapabilitySubgroupVoteKHR = 4431, + CapabilityStorageBuffer16BitAccess = 4433, + CapabilityStorageUniformBufferBlock16 = 4433, + CapabilityStorageUniform16 = 4434, + CapabilityUniformAndStorageBuffer16BitAccess = 4434, + CapabilityStoragePushConstant16 = 4435, + CapabilityStorageInputOutput16 = 4436, + CapabilityDeviceGroup = 4437, + CapabilityMultiView = 4439, + CapabilityVariablePointersStorageBuffer = 4441, + CapabilityVariablePointers = 4442, + CapabilityAtomicStorageOps = 4445, + CapabilitySampleMaskPostDepthCoverage = 4447, + CapabilityStorageBuffer8BitAccess = 4448, + CapabilityUniformAndStorageBuffer8BitAccess = 4449, + CapabilityStoragePushConstant8 = 4450, + CapabilityDenormPreserve = 4464, + CapabilityDenormFlushToZero = 4465, + CapabilitySignedZeroInfNanPreserve = 4466, + CapabilityRoundingModeRTE = 4467, + CapabilityRoundingModeRTZ = 4468, + CapabilityRayQueryProvisionalKHR = 4471, + CapabilityRayQueryKHR = 4472, + 
CapabilityRayTraversalPrimitiveCullingKHR = 4478, + CapabilityRayTracingKHR = 4479, + CapabilityTextureSampleWeightedQCOM = 4484, + CapabilityTextureBoxFilterQCOM = 4485, + CapabilityTextureBlockMatchQCOM = 4486, + CapabilityFloat16ImageAMD = 5008, + CapabilityImageGatherBiasLodAMD = 5009, + CapabilityFragmentMaskAMD = 5010, + CapabilityStencilExportEXT = 5013, + CapabilityImageReadWriteLodAMD = 5015, + CapabilityInt64ImageEXT = 5016, + CapabilityShaderClockKHR = 5055, + CapabilitySampleMaskOverrideCoverageNV = 5249, + CapabilityGeometryShaderPassthroughNV = 5251, + CapabilityShaderViewportIndexLayerEXT = 5254, + CapabilityShaderViewportIndexLayerNV = 5254, + CapabilityShaderViewportMaskNV = 5255, + CapabilityShaderStereoViewNV = 5259, + CapabilityPerViewAttributesNV = 5260, + CapabilityFragmentFullyCoveredEXT = 5265, + CapabilityMeshShadingNV = 5266, + CapabilityImageFootprintNV = 5282, + CapabilityMeshShadingEXT = 5283, + CapabilityFragmentBarycentricKHR = 5284, + CapabilityFragmentBarycentricNV = 5284, + CapabilityComputeDerivativeGroupQuadsNV = 5288, + CapabilityFragmentDensityEXT = 5291, + CapabilityShadingRateNV = 5291, + CapabilityGroupNonUniformPartitionedNV = 5297, + CapabilityShaderNonUniform = 5301, + CapabilityShaderNonUniformEXT = 5301, + CapabilityRuntimeDescriptorArray = 5302, + CapabilityRuntimeDescriptorArrayEXT = 5302, + CapabilityInputAttachmentArrayDynamicIndexing = 5303, + CapabilityInputAttachmentArrayDynamicIndexingEXT = 5303, + CapabilityUniformTexelBufferArrayDynamicIndexing = 5304, + CapabilityUniformTexelBufferArrayDynamicIndexingEXT = 5304, + CapabilityStorageTexelBufferArrayDynamicIndexing = 5305, + CapabilityStorageTexelBufferArrayDynamicIndexingEXT = 5305, + CapabilityUniformBufferArrayNonUniformIndexing = 5306, + CapabilityUniformBufferArrayNonUniformIndexingEXT = 5306, + CapabilitySampledImageArrayNonUniformIndexing = 5307, + CapabilitySampledImageArrayNonUniformIndexingEXT = 5307, + CapabilityStorageBufferArrayNonUniformIndexing = 5308, + CapabilityStorageBufferArrayNonUniformIndexingEXT = 5308, + CapabilityStorageImageArrayNonUniformIndexing = 5309, + CapabilityStorageImageArrayNonUniformIndexingEXT = 5309, + CapabilityInputAttachmentArrayNonUniformIndexing = 5310, + CapabilityInputAttachmentArrayNonUniformIndexingEXT = 5310, + CapabilityUniformTexelBufferArrayNonUniformIndexing = 5311, + CapabilityUniformTexelBufferArrayNonUniformIndexingEXT = 5311, + CapabilityStorageTexelBufferArrayNonUniformIndexing = 5312, + CapabilityStorageTexelBufferArrayNonUniformIndexingEXT = 5312, + CapabilityRayTracingNV = 5340, + CapabilityRayTracingMotionBlurNV = 5341, + CapabilityVulkanMemoryModel = 5345, + CapabilityVulkanMemoryModelKHR = 5345, + CapabilityVulkanMemoryModelDeviceScope = 5346, + CapabilityVulkanMemoryModelDeviceScopeKHR = 5346, + CapabilityPhysicalStorageBufferAddresses = 5347, + CapabilityPhysicalStorageBufferAddressesEXT = 5347, + CapabilityComputeDerivativeGroupLinearNV = 5350, + CapabilityRayTracingProvisionalKHR = 5353, + CapabilityCooperativeMatrixNV = 5357, + CapabilityFragmentShaderSampleInterlockEXT = 5363, + CapabilityFragmentShaderShadingRateInterlockEXT = 5372, + CapabilityShaderSMBuiltinsNV = 5373, + CapabilityFragmentShaderPixelInterlockEXT = 5378, + CapabilityDemoteToHelperInvocation = 5379, + CapabilityDemoteToHelperInvocationEXT = 5379, + CapabilityBindlessTextureNV = 5390, + CapabilitySubgroupShuffleINTEL = 5568, + CapabilitySubgroupBufferBlockIOINTEL = 5569, + CapabilitySubgroupImageBlockIOINTEL = 5570, + 
CapabilitySubgroupImageMediaBlockIOINTEL = 5579, + CapabilityRoundToInfinityINTEL = 5582, + CapabilityFloatingPointModeINTEL = 5583, + CapabilityIntegerFunctions2INTEL = 5584, + CapabilityFunctionPointersINTEL = 5603, + CapabilityIndirectReferencesINTEL = 5604, + CapabilityAsmINTEL = 5606, + CapabilityAtomicFloat32MinMaxEXT = 5612, + CapabilityAtomicFloat64MinMaxEXT = 5613, + CapabilityAtomicFloat16MinMaxEXT = 5616, + CapabilityVectorComputeINTEL = 5617, + CapabilityVectorAnyINTEL = 5619, + CapabilityExpectAssumeKHR = 5629, + CapabilitySubgroupAvcMotionEstimationINTEL = 5696, + CapabilitySubgroupAvcMotionEstimationIntraINTEL = 5697, + CapabilitySubgroupAvcMotionEstimationChromaINTEL = 5698, + CapabilityVariableLengthArrayINTEL = 5817, + CapabilityFunctionFloatControlINTEL = 5821, + CapabilityFPGAMemoryAttributesINTEL = 5824, + CapabilityFPFastMathModeINTEL = 5837, + CapabilityArbitraryPrecisionIntegersINTEL = 5844, + CapabilityArbitraryPrecisionFloatingPointINTEL = 5845, + CapabilityUnstructuredLoopControlsINTEL = 5886, + CapabilityFPGALoopControlsINTEL = 5888, + CapabilityKernelAttributesINTEL = 5892, + CapabilityFPGAKernelAttributesINTEL = 5897, + CapabilityFPGAMemoryAccessesINTEL = 5898, + CapabilityFPGAClusterAttributesINTEL = 5904, + CapabilityLoopFuseINTEL = 5906, + CapabilityMemoryAccessAliasingINTEL = 5910, + CapabilityFPGABufferLocationINTEL = 5920, + CapabilityArbitraryPrecisionFixedPointINTEL = 5922, + CapabilityUSMStorageClassesINTEL = 5935, + CapabilityIOPipesINTEL = 5943, + CapabilityBlockingPipesINTEL = 5945, + CapabilityFPGARegINTEL = 5948, + CapabilityDotProductInputAll = 6016, + CapabilityDotProductInputAllKHR = 6016, + CapabilityDotProductInput4x8Bit = 6017, + CapabilityDotProductInput4x8BitKHR = 6017, + CapabilityDotProductInput4x8BitPacked = 6018, + CapabilityDotProductInput4x8BitPackedKHR = 6018, + CapabilityDotProduct = 6019, + CapabilityDotProductKHR = 6019, + CapabilityRayCullMaskKHR = 6020, + CapabilityBitInstructions = 6025, + CapabilityGroupNonUniformRotateKHR = 6026, + CapabilityAtomicFloat32AddEXT = 6033, + CapabilityAtomicFloat64AddEXT = 6034, + CapabilityLongConstantCompositeINTEL = 6089, + CapabilityOptNoneINTEL = 6094, + CapabilityAtomicFloat16AddEXT = 6095, + CapabilityDebugInfoModuleINTEL = 6114, + CapabilitySplitBarrierINTEL = 6141, + CapabilityGroupUniformArithmeticKHR = 6400, + CapabilityMax = 0x7fffffff, +}; + +enum RayFlagsShift { + RayFlagsOpaqueKHRShift = 0, + RayFlagsNoOpaqueKHRShift = 1, + RayFlagsTerminateOnFirstHitKHRShift = 2, + RayFlagsSkipClosestHitShaderKHRShift = 3, + RayFlagsCullBackFacingTrianglesKHRShift = 4, + RayFlagsCullFrontFacingTrianglesKHRShift = 5, + RayFlagsCullOpaqueKHRShift = 6, + RayFlagsCullNoOpaqueKHRShift = 7, + RayFlagsSkipTrianglesKHRShift = 8, + RayFlagsSkipAABBsKHRShift = 9, + RayFlagsMax = 0x7fffffff, +}; + +enum RayFlagsMask { + RayFlagsMaskNone = 0, + RayFlagsOpaqueKHRMask = 0x00000001, + RayFlagsNoOpaqueKHRMask = 0x00000002, + RayFlagsTerminateOnFirstHitKHRMask = 0x00000004, + RayFlagsSkipClosestHitShaderKHRMask = 0x00000008, + RayFlagsCullBackFacingTrianglesKHRMask = 0x00000010, + RayFlagsCullFrontFacingTrianglesKHRMask = 0x00000020, + RayFlagsCullOpaqueKHRMask = 0x00000040, + RayFlagsCullNoOpaqueKHRMask = 0x00000080, + RayFlagsSkipTrianglesKHRMask = 0x00000100, + RayFlagsSkipAABBsKHRMask = 0x00000200, +}; + +enum RayQueryIntersection { + RayQueryIntersectionRayQueryCandidateIntersectionKHR = 0, + RayQueryIntersectionRayQueryCommittedIntersectionKHR = 1, + RayQueryIntersectionMax = 0x7fffffff, +}; + +enum 
RayQueryCommittedIntersectionType { + RayQueryCommittedIntersectionTypeRayQueryCommittedIntersectionNoneKHR = 0, + RayQueryCommittedIntersectionTypeRayQueryCommittedIntersectionTriangleKHR = 1, + RayQueryCommittedIntersectionTypeRayQueryCommittedIntersectionGeneratedKHR = 2, + RayQueryCommittedIntersectionTypeMax = 0x7fffffff, +}; + +enum RayQueryCandidateIntersectionType { + RayQueryCandidateIntersectionTypeRayQueryCandidateIntersectionTriangleKHR = 0, + RayQueryCandidateIntersectionTypeRayQueryCandidateIntersectionAABBKHR = 1, + RayQueryCandidateIntersectionTypeMax = 0x7fffffff, +}; + +enum FragmentShadingRateShift { + FragmentShadingRateVertical2PixelsShift = 0, + FragmentShadingRateVertical4PixelsShift = 1, + FragmentShadingRateHorizontal2PixelsShift = 2, + FragmentShadingRateHorizontal4PixelsShift = 3, + FragmentShadingRateMax = 0x7fffffff, +}; + +enum FragmentShadingRateMask { + FragmentShadingRateMaskNone = 0, + FragmentShadingRateVertical2PixelsMask = 0x00000001, + FragmentShadingRateVertical4PixelsMask = 0x00000002, + FragmentShadingRateHorizontal2PixelsMask = 0x00000004, + FragmentShadingRateHorizontal4PixelsMask = 0x00000008, +}; + +enum FPDenormMode { + FPDenormModePreserve = 0, + FPDenormModeFlushToZero = 1, + FPDenormModeMax = 0x7fffffff, +}; + +enum FPOperationMode { + FPOperationModeIEEE = 0, + FPOperationModeALT = 1, + FPOperationModeMax = 0x7fffffff, +}; + +enum QuantizationModes { + QuantizationModesTRN = 0, + QuantizationModesTRN_ZERO = 1, + QuantizationModesRND = 2, + QuantizationModesRND_ZERO = 3, + QuantizationModesRND_INF = 4, + QuantizationModesRND_MIN_INF = 5, + QuantizationModesRND_CONV = 6, + QuantizationModesRND_CONV_ODD = 7, + QuantizationModesMax = 0x7fffffff, +}; + +enum OverflowModes { + OverflowModesWRAP = 0, + OverflowModesSAT = 1, + OverflowModesSAT_ZERO = 2, + OverflowModesSAT_SYM = 3, + OverflowModesMax = 0x7fffffff, +}; + +enum PackedVectorFormat { + PackedVectorFormatPackedVectorFormat4x8Bit = 0, + PackedVectorFormatPackedVectorFormat4x8BitKHR = 0, + PackedVectorFormatMax = 0x7fffffff, +}; + +enum Op { + OpNop = 0, + OpUndef = 1, + OpSourceContinued = 2, + OpSource = 3, + OpSourceExtension = 4, + OpName = 5, + OpMemberName = 6, + OpString = 7, + OpLine = 8, + OpExtension = 10, + OpExtInstImport = 11, + OpExtInst = 12, + OpMemoryModel = 14, + OpEntryPoint = 15, + OpExecutionMode = 16, + OpCapability = 17, + OpTypeVoid = 19, + OpTypeBool = 20, + OpTypeInt = 21, + OpTypeFloat = 22, + OpTypeVector = 23, + OpTypeMatrix = 24, + OpTypeImage = 25, + OpTypeSampler = 26, + OpTypeSampledImage = 27, + OpTypeArray = 28, + OpTypeRuntimeArray = 29, + OpTypeStruct = 30, + OpTypeOpaque = 31, + OpTypePointer = 32, + OpTypeFunction = 33, + OpTypeEvent = 34, + OpTypeDeviceEvent = 35, + OpTypeReserveId = 36, + OpTypeQueue = 37, + OpTypePipe = 38, + OpTypeForwardPointer = 39, + OpConstantTrue = 41, + OpConstantFalse = 42, + OpConstant = 43, + OpConstantComposite = 44, + OpConstantSampler = 45, + OpConstantNull = 46, + OpSpecConstantTrue = 48, + OpSpecConstantFalse = 49, + OpSpecConstant = 50, + OpSpecConstantComposite = 51, + OpSpecConstantOp = 52, + OpFunction = 54, + OpFunctionParameter = 55, + OpFunctionEnd = 56, + OpFunctionCall = 57, + OpVariable = 59, + OpImageTexelPointer = 60, + OpLoad = 61, + OpStore = 62, + OpCopyMemory = 63, + OpCopyMemorySized = 64, + OpAccessChain = 65, + OpInBoundsAccessChain = 66, + OpPtrAccessChain = 67, + OpArrayLength = 68, + OpGenericPtrMemSemantics = 69, + OpInBoundsPtrAccessChain = 70, + OpDecorate = 71, + OpMemberDecorate = 72, + 
OpDecorationGroup = 73, + OpGroupDecorate = 74, + OpGroupMemberDecorate = 75, + OpVectorExtractDynamic = 77, + OpVectorInsertDynamic = 78, + OpVectorShuffle = 79, + OpCompositeConstruct = 80, + OpCompositeExtract = 81, + OpCompositeInsert = 82, + OpCopyObject = 83, + OpTranspose = 84, + OpSampledImage = 86, + OpImageSampleImplicitLod = 87, + OpImageSampleExplicitLod = 88, + OpImageSampleDrefImplicitLod = 89, + OpImageSampleDrefExplicitLod = 90, + OpImageSampleProjImplicitLod = 91, + OpImageSampleProjExplicitLod = 92, + OpImageSampleProjDrefImplicitLod = 93, + OpImageSampleProjDrefExplicitLod = 94, + OpImageFetch = 95, + OpImageGather = 96, + OpImageDrefGather = 97, + OpImageRead = 98, + OpImageWrite = 99, + OpImage = 100, + OpImageQueryFormat = 101, + OpImageQueryOrder = 102, + OpImageQuerySizeLod = 103, + OpImageQuerySize = 104, + OpImageQueryLod = 105, + OpImageQueryLevels = 106, + OpImageQuerySamples = 107, + OpConvertFToU = 109, + OpConvertFToS = 110, + OpConvertSToF = 111, + OpConvertUToF = 112, + OpUConvert = 113, + OpSConvert = 114, + OpFConvert = 115, + OpQuantizeToF16 = 116, + OpConvertPtrToU = 117, + OpSatConvertSToU = 118, + OpSatConvertUToS = 119, + OpConvertUToPtr = 120, + OpPtrCastToGeneric = 121, + OpGenericCastToPtr = 122, + OpGenericCastToPtrExplicit = 123, + OpBitcast = 124, + OpSNegate = 126, + OpFNegate = 127, + OpIAdd = 128, + OpFAdd = 129, + OpISub = 130, + OpFSub = 131, + OpIMul = 132, + OpFMul = 133, + OpUDiv = 134, + OpSDiv = 135, + OpFDiv = 136, + OpUMod = 137, + OpSRem = 138, + OpSMod = 139, + OpFRem = 140, + OpFMod = 141, + OpVectorTimesScalar = 142, + OpMatrixTimesScalar = 143, + OpVectorTimesMatrix = 144, + OpMatrixTimesVector = 145, + OpMatrixTimesMatrix = 146, + OpOuterProduct = 147, + OpDot = 148, + OpIAddCarry = 149, + OpISubBorrow = 150, + OpUMulExtended = 151, + OpSMulExtended = 152, + OpAny = 154, + OpAll = 155, + OpIsNan = 156, + OpIsInf = 157, + OpIsFinite = 158, + OpIsNormal = 159, + OpSignBitSet = 160, + OpLessOrGreater = 161, + OpOrdered = 162, + OpUnordered = 163, + OpLogicalEqual = 164, + OpLogicalNotEqual = 165, + OpLogicalOr = 166, + OpLogicalAnd = 167, + OpLogicalNot = 168, + OpSelect = 169, + OpIEqual = 170, + OpINotEqual = 171, + OpUGreaterThan = 172, + OpSGreaterThan = 173, + OpUGreaterThanEqual = 174, + OpSGreaterThanEqual = 175, + OpULessThan = 176, + OpSLessThan = 177, + OpULessThanEqual = 178, + OpSLessThanEqual = 179, + OpFOrdEqual = 180, + OpFUnordEqual = 181, + OpFOrdNotEqual = 182, + OpFUnordNotEqual = 183, + OpFOrdLessThan = 184, + OpFUnordLessThan = 185, + OpFOrdGreaterThan = 186, + OpFUnordGreaterThan = 187, + OpFOrdLessThanEqual = 188, + OpFUnordLessThanEqual = 189, + OpFOrdGreaterThanEqual = 190, + OpFUnordGreaterThanEqual = 191, + OpShiftRightLogical = 194, + OpShiftRightArithmetic = 195, + OpShiftLeftLogical = 196, + OpBitwiseOr = 197, + OpBitwiseXor = 198, + OpBitwiseAnd = 199, + OpNot = 200, + OpBitFieldInsert = 201, + OpBitFieldSExtract = 202, + OpBitFieldUExtract = 203, + OpBitReverse = 204, + OpBitCount = 205, + OpDPdx = 207, + OpDPdy = 208, + OpFwidth = 209, + OpDPdxFine = 210, + OpDPdyFine = 211, + OpFwidthFine = 212, + OpDPdxCoarse = 213, + OpDPdyCoarse = 214, + OpFwidthCoarse = 215, + OpEmitVertex = 218, + OpEndPrimitive = 219, + OpEmitStreamVertex = 220, + OpEndStreamPrimitive = 221, + OpControlBarrier = 224, + OpMemoryBarrier = 225, + OpAtomicLoad = 227, + OpAtomicStore = 228, + OpAtomicExchange = 229, + OpAtomicCompareExchange = 230, + OpAtomicCompareExchangeWeak = 231, + OpAtomicIIncrement = 232, + 
OpAtomicIDecrement = 233, + OpAtomicIAdd = 234, + OpAtomicISub = 235, + OpAtomicSMin = 236, + OpAtomicUMin = 237, + OpAtomicSMax = 238, + OpAtomicUMax = 239, + OpAtomicAnd = 240, + OpAtomicOr = 241, + OpAtomicXor = 242, + OpPhi = 245, + OpLoopMerge = 246, + OpSelectionMerge = 247, + OpLabel = 248, + OpBranch = 249, + OpBranchConditional = 250, + OpSwitch = 251, + OpKill = 252, + OpReturn = 253, + OpReturnValue = 254, + OpUnreachable = 255, + OpLifetimeStart = 256, + OpLifetimeStop = 257, + OpGroupAsyncCopy = 259, + OpGroupWaitEvents = 260, + OpGroupAll = 261, + OpGroupAny = 262, + OpGroupBroadcast = 263, + OpGroupIAdd = 264, + OpGroupFAdd = 265, + OpGroupFMin = 266, + OpGroupUMin = 267, + OpGroupSMin = 268, + OpGroupFMax = 269, + OpGroupUMax = 270, + OpGroupSMax = 271, + OpReadPipe = 274, + OpWritePipe = 275, + OpReservedReadPipe = 276, + OpReservedWritePipe = 277, + OpReserveReadPipePackets = 278, + OpReserveWritePipePackets = 279, + OpCommitReadPipe = 280, + OpCommitWritePipe = 281, + OpIsValidReserveId = 282, + OpGetNumPipePackets = 283, + OpGetMaxPipePackets = 284, + OpGroupReserveReadPipePackets = 285, + OpGroupReserveWritePipePackets = 286, + OpGroupCommitReadPipe = 287, + OpGroupCommitWritePipe = 288, + OpEnqueueMarker = 291, + OpEnqueueKernel = 292, + OpGetKernelNDrangeSubGroupCount = 293, + OpGetKernelNDrangeMaxSubGroupSize = 294, + OpGetKernelWorkGroupSize = 295, + OpGetKernelPreferredWorkGroupSizeMultiple = 296, + OpRetainEvent = 297, + OpReleaseEvent = 298, + OpCreateUserEvent = 299, + OpIsValidEvent = 300, + OpSetUserEventStatus = 301, + OpCaptureEventProfilingInfo = 302, + OpGetDefaultQueue = 303, + OpBuildNDRange = 304, + OpImageSparseSampleImplicitLod = 305, + OpImageSparseSampleExplicitLod = 306, + OpImageSparseSampleDrefImplicitLod = 307, + OpImageSparseSampleDrefExplicitLod = 308, + OpImageSparseSampleProjImplicitLod = 309, + OpImageSparseSampleProjExplicitLod = 310, + OpImageSparseSampleProjDrefImplicitLod = 311, + OpImageSparseSampleProjDrefExplicitLod = 312, + OpImageSparseFetch = 313, + OpImageSparseGather = 314, + OpImageSparseDrefGather = 315, + OpImageSparseTexelsResident = 316, + OpNoLine = 317, + OpAtomicFlagTestAndSet = 318, + OpAtomicFlagClear = 319, + OpImageSparseRead = 320, + OpSizeOf = 321, + OpTypePipeStorage = 322, + OpConstantPipeStorage = 323, + OpCreatePipeFromPipeStorage = 324, + OpGetKernelLocalSizeForSubgroupCount = 325, + OpGetKernelMaxNumSubgroups = 326, + OpTypeNamedBarrier = 327, + OpNamedBarrierInitialize = 328, + OpMemoryNamedBarrier = 329, + OpModuleProcessed = 330, + OpExecutionModeId = 331, + OpDecorateId = 332, + OpGroupNonUniformElect = 333, + OpGroupNonUniformAll = 334, + OpGroupNonUniformAny = 335, + OpGroupNonUniformAllEqual = 336, + OpGroupNonUniformBroadcast = 337, + OpGroupNonUniformBroadcastFirst = 338, + OpGroupNonUniformBallot = 339, + OpGroupNonUniformInverseBallot = 340, + OpGroupNonUniformBallotBitExtract = 341, + OpGroupNonUniformBallotBitCount = 342, + OpGroupNonUniformBallotFindLSB = 343, + OpGroupNonUniformBallotFindMSB = 344, + OpGroupNonUniformShuffle = 345, + OpGroupNonUniformShuffleXor = 346, + OpGroupNonUniformShuffleUp = 347, + OpGroupNonUniformShuffleDown = 348, + OpGroupNonUniformIAdd = 349, + OpGroupNonUniformFAdd = 350, + OpGroupNonUniformIMul = 351, + OpGroupNonUniformFMul = 352, + OpGroupNonUniformSMin = 353, + OpGroupNonUniformUMin = 354, + OpGroupNonUniformFMin = 355, + OpGroupNonUniformSMax = 356, + OpGroupNonUniformUMax = 357, + OpGroupNonUniformFMax = 358, + OpGroupNonUniformBitwiseAnd = 359, + 
OpGroupNonUniformBitwiseOr = 360, + OpGroupNonUniformBitwiseXor = 361, + OpGroupNonUniformLogicalAnd = 362, + OpGroupNonUniformLogicalOr = 363, + OpGroupNonUniformLogicalXor = 364, + OpGroupNonUniformQuadBroadcast = 365, + OpGroupNonUniformQuadSwap = 366, + OpCopyLogical = 400, + OpPtrEqual = 401, + OpPtrNotEqual = 402, + OpPtrDiff = 403, + OpTerminateInvocation = 4416, + OpSubgroupBallotKHR = 4421, + OpSubgroupFirstInvocationKHR = 4422, + OpSubgroupAllKHR = 4428, + OpSubgroupAnyKHR = 4429, + OpSubgroupAllEqualKHR = 4430, + OpGroupNonUniformRotateKHR = 4431, + OpSubgroupReadInvocationKHR = 4432, + OpTraceRayKHR = 4445, + OpExecuteCallableKHR = 4446, + OpConvertUToAccelerationStructureKHR = 4447, + OpIgnoreIntersectionKHR = 4448, + OpTerminateRayKHR = 4449, + OpSDot = 4450, + OpSDotKHR = 4450, + OpUDot = 4451, + OpUDotKHR = 4451, + OpSUDot = 4452, + OpSUDotKHR = 4452, + OpSDotAccSat = 4453, + OpSDotAccSatKHR = 4453, + OpUDotAccSat = 4454, + OpUDotAccSatKHR = 4454, + OpSUDotAccSat = 4455, + OpSUDotAccSatKHR = 4455, + OpTypeRayQueryKHR = 4472, + OpRayQueryInitializeKHR = 4473, + OpRayQueryTerminateKHR = 4474, + OpRayQueryGenerateIntersectionKHR = 4475, + OpRayQueryConfirmIntersectionKHR = 4476, + OpRayQueryProceedKHR = 4477, + OpRayQueryGetIntersectionTypeKHR = 4479, + OpImageSampleWeightedQCOM = 4480, + OpImageBoxFilterQCOM = 4481, + OpImageBlockMatchSSDQCOM = 4482, + OpImageBlockMatchSADQCOM = 4483, + OpGroupIAddNonUniformAMD = 5000, + OpGroupFAddNonUniformAMD = 5001, + OpGroupFMinNonUniformAMD = 5002, + OpGroupUMinNonUniformAMD = 5003, + OpGroupSMinNonUniformAMD = 5004, + OpGroupFMaxNonUniformAMD = 5005, + OpGroupUMaxNonUniformAMD = 5006, + OpGroupSMaxNonUniformAMD = 5007, + OpFragmentMaskFetchAMD = 5011, + OpFragmentFetchAMD = 5012, + OpReadClockKHR = 5056, + OpImageSampleFootprintNV = 5283, + OpEmitMeshTasksEXT = 5294, + OpSetMeshOutputsEXT = 5295, + OpGroupNonUniformPartitionNV = 5296, + OpWritePackedPrimitiveIndices4x8NV = 5299, + OpReportIntersectionKHR = 5334, + OpReportIntersectionNV = 5334, + OpIgnoreIntersectionNV = 5335, + OpTerminateRayNV = 5336, + OpTraceNV = 5337, + OpTraceMotionNV = 5338, + OpTraceRayMotionNV = 5339, + OpTypeAccelerationStructureKHR = 5341, + OpTypeAccelerationStructureNV = 5341, + OpExecuteCallableNV = 5344, + OpTypeCooperativeMatrixNV = 5358, + OpCooperativeMatrixLoadNV = 5359, + OpCooperativeMatrixStoreNV = 5360, + OpCooperativeMatrixMulAddNV = 5361, + OpCooperativeMatrixLengthNV = 5362, + OpBeginInvocationInterlockEXT = 5364, + OpEndInvocationInterlockEXT = 5365, + OpDemoteToHelperInvocation = 5380, + OpDemoteToHelperInvocationEXT = 5380, + OpIsHelperInvocationEXT = 5381, + OpConvertUToImageNV = 5391, + OpConvertUToSamplerNV = 5392, + OpConvertImageToUNV = 5393, + OpConvertSamplerToUNV = 5394, + OpConvertUToSampledImageNV = 5395, + OpConvertSampledImageToUNV = 5396, + OpSamplerImageAddressingModeNV = 5397, + OpSubgroupShuffleINTEL = 5571, + OpSubgroupShuffleDownINTEL = 5572, + OpSubgroupShuffleUpINTEL = 5573, + OpSubgroupShuffleXorINTEL = 5574, + OpSubgroupBlockReadINTEL = 5575, + OpSubgroupBlockWriteINTEL = 5576, + OpSubgroupImageBlockReadINTEL = 5577, + OpSubgroupImageBlockWriteINTEL = 5578, + OpSubgroupImageMediaBlockReadINTEL = 5580, + OpSubgroupImageMediaBlockWriteINTEL = 5581, + OpUCountLeadingZerosINTEL = 5585, + OpUCountTrailingZerosINTEL = 5586, + OpAbsISubINTEL = 5587, + OpAbsUSubINTEL = 5588, + OpIAddSatINTEL = 5589, + OpUAddSatINTEL = 5590, + OpIAverageINTEL = 5591, + OpUAverageINTEL = 5592, + OpIAverageRoundedINTEL = 5593, + 
OpUAverageRoundedINTEL = 5594, + OpISubSatINTEL = 5595, + OpUSubSatINTEL = 5596, + OpIMul32x16INTEL = 5597, + OpUMul32x16INTEL = 5598, + OpConstantFunctionPointerINTEL = 5600, + OpFunctionPointerCallINTEL = 5601, + OpAsmTargetINTEL = 5609, + OpAsmINTEL = 5610, + OpAsmCallINTEL = 5611, + OpAtomicFMinEXT = 5614, + OpAtomicFMaxEXT = 5615, + OpAssumeTrueKHR = 5630, + OpExpectKHR = 5631, + OpDecorateString = 5632, + OpDecorateStringGOOGLE = 5632, + OpMemberDecorateString = 5633, + OpMemberDecorateStringGOOGLE = 5633, + OpVmeImageINTEL = 5699, + OpTypeVmeImageINTEL = 5700, + OpTypeAvcImePayloadINTEL = 5701, + OpTypeAvcRefPayloadINTEL = 5702, + OpTypeAvcSicPayloadINTEL = 5703, + OpTypeAvcMcePayloadINTEL = 5704, + OpTypeAvcMceResultINTEL = 5705, + OpTypeAvcImeResultINTEL = 5706, + OpTypeAvcImeResultSingleReferenceStreamoutINTEL = 5707, + OpTypeAvcImeResultDualReferenceStreamoutINTEL = 5708, + OpTypeAvcImeSingleReferenceStreaminINTEL = 5709, + OpTypeAvcImeDualReferenceStreaminINTEL = 5710, + OpTypeAvcRefResultINTEL = 5711, + OpTypeAvcSicResultINTEL = 5712, + OpSubgroupAvcMceGetDefaultInterBaseMultiReferencePenaltyINTEL = 5713, + OpSubgroupAvcMceSetInterBaseMultiReferencePenaltyINTEL = 5714, + OpSubgroupAvcMceGetDefaultInterShapePenaltyINTEL = 5715, + OpSubgroupAvcMceSetInterShapePenaltyINTEL = 5716, + OpSubgroupAvcMceGetDefaultInterDirectionPenaltyINTEL = 5717, + OpSubgroupAvcMceSetInterDirectionPenaltyINTEL = 5718, + OpSubgroupAvcMceGetDefaultIntraLumaShapePenaltyINTEL = 5719, + OpSubgroupAvcMceGetDefaultInterMotionVectorCostTableINTEL = 5720, + OpSubgroupAvcMceGetDefaultHighPenaltyCostTableINTEL = 5721, + OpSubgroupAvcMceGetDefaultMediumPenaltyCostTableINTEL = 5722, + OpSubgroupAvcMceGetDefaultLowPenaltyCostTableINTEL = 5723, + OpSubgroupAvcMceSetMotionVectorCostFunctionINTEL = 5724, + OpSubgroupAvcMceGetDefaultIntraLumaModePenaltyINTEL = 5725, + OpSubgroupAvcMceGetDefaultNonDcLumaIntraPenaltyINTEL = 5726, + OpSubgroupAvcMceGetDefaultIntraChromaModeBasePenaltyINTEL = 5727, + OpSubgroupAvcMceSetAcOnlyHaarINTEL = 5728, + OpSubgroupAvcMceSetSourceInterlacedFieldPolarityINTEL = 5729, + OpSubgroupAvcMceSetSingleReferenceInterlacedFieldPolarityINTEL = 5730, + OpSubgroupAvcMceSetDualReferenceInterlacedFieldPolaritiesINTEL = 5731, + OpSubgroupAvcMceConvertToImePayloadINTEL = 5732, + OpSubgroupAvcMceConvertToImeResultINTEL = 5733, + OpSubgroupAvcMceConvertToRefPayloadINTEL = 5734, + OpSubgroupAvcMceConvertToRefResultINTEL = 5735, + OpSubgroupAvcMceConvertToSicPayloadINTEL = 5736, + OpSubgroupAvcMceConvertToSicResultINTEL = 5737, + OpSubgroupAvcMceGetMotionVectorsINTEL = 5738, + OpSubgroupAvcMceGetInterDistortionsINTEL = 5739, + OpSubgroupAvcMceGetBestInterDistortionsINTEL = 5740, + OpSubgroupAvcMceGetInterMajorShapeINTEL = 5741, + OpSubgroupAvcMceGetInterMinorShapeINTEL = 5742, + OpSubgroupAvcMceGetInterDirectionsINTEL = 5743, + OpSubgroupAvcMceGetInterMotionVectorCountINTEL = 5744, + OpSubgroupAvcMceGetInterReferenceIdsINTEL = 5745, + OpSubgroupAvcMceGetInterReferenceInterlacedFieldPolaritiesINTEL = 5746, + OpSubgroupAvcImeInitializeINTEL = 5747, + OpSubgroupAvcImeSetSingleReferenceINTEL = 5748, + OpSubgroupAvcImeSetDualReferenceINTEL = 5749, + OpSubgroupAvcImeRefWindowSizeINTEL = 5750, + OpSubgroupAvcImeAdjustRefOffsetINTEL = 5751, + OpSubgroupAvcImeConvertToMcePayloadINTEL = 5752, + OpSubgroupAvcImeSetMaxMotionVectorCountINTEL = 5753, + OpSubgroupAvcImeSetUnidirectionalMixDisableINTEL = 5754, + OpSubgroupAvcImeSetEarlySearchTerminationThresholdINTEL = 5755, + OpSubgroupAvcImeSetWeightedSadINTEL = 5756, 
+ OpSubgroupAvcImeEvaluateWithSingleReferenceINTEL = 5757, + OpSubgroupAvcImeEvaluateWithDualReferenceINTEL = 5758, + OpSubgroupAvcImeEvaluateWithSingleReferenceStreaminINTEL = 5759, + OpSubgroupAvcImeEvaluateWithDualReferenceStreaminINTEL = 5760, + OpSubgroupAvcImeEvaluateWithSingleReferenceStreamoutINTEL = 5761, + OpSubgroupAvcImeEvaluateWithDualReferenceStreamoutINTEL = 5762, + OpSubgroupAvcImeEvaluateWithSingleReferenceStreaminoutINTEL = 5763, + OpSubgroupAvcImeEvaluateWithDualReferenceStreaminoutINTEL = 5764, + OpSubgroupAvcImeConvertToMceResultINTEL = 5765, + OpSubgroupAvcImeGetSingleReferenceStreaminINTEL = 5766, + OpSubgroupAvcImeGetDualReferenceStreaminINTEL = 5767, + OpSubgroupAvcImeStripSingleReferenceStreamoutINTEL = 5768, + OpSubgroupAvcImeStripDualReferenceStreamoutINTEL = 5769, + OpSubgroupAvcImeGetStreamoutSingleReferenceMajorShapeMotionVectorsINTEL = 5770, + OpSubgroupAvcImeGetStreamoutSingleReferenceMajorShapeDistortionsINTEL = 5771, + OpSubgroupAvcImeGetStreamoutSingleReferenceMajorShapeReferenceIdsINTEL = 5772, + OpSubgroupAvcImeGetStreamoutDualReferenceMajorShapeMotionVectorsINTEL = 5773, + OpSubgroupAvcImeGetStreamoutDualReferenceMajorShapeDistortionsINTEL = 5774, + OpSubgroupAvcImeGetStreamoutDualReferenceMajorShapeReferenceIdsINTEL = 5775, + OpSubgroupAvcImeGetBorderReachedINTEL = 5776, + OpSubgroupAvcImeGetTruncatedSearchIndicationINTEL = 5777, + OpSubgroupAvcImeGetUnidirectionalEarlySearchTerminationINTEL = 5778, + OpSubgroupAvcImeGetWeightingPatternMinimumMotionVectorINTEL = 5779, + OpSubgroupAvcImeGetWeightingPatternMinimumDistortionINTEL = 5780, + OpSubgroupAvcFmeInitializeINTEL = 5781, + OpSubgroupAvcBmeInitializeINTEL = 5782, + OpSubgroupAvcRefConvertToMcePayloadINTEL = 5783, + OpSubgroupAvcRefSetBidirectionalMixDisableINTEL = 5784, + OpSubgroupAvcRefSetBilinearFilterEnableINTEL = 5785, + OpSubgroupAvcRefEvaluateWithSingleReferenceINTEL = 5786, + OpSubgroupAvcRefEvaluateWithDualReferenceINTEL = 5787, + OpSubgroupAvcRefEvaluateWithMultiReferenceINTEL = 5788, + OpSubgroupAvcRefEvaluateWithMultiReferenceInterlacedINTEL = 5789, + OpSubgroupAvcRefConvertToMceResultINTEL = 5790, + OpSubgroupAvcSicInitializeINTEL = 5791, + OpSubgroupAvcSicConfigureSkcINTEL = 5792, + OpSubgroupAvcSicConfigureIpeLumaINTEL = 5793, + OpSubgroupAvcSicConfigureIpeLumaChromaINTEL = 5794, + OpSubgroupAvcSicGetMotionVectorMaskINTEL = 5795, + OpSubgroupAvcSicConvertToMcePayloadINTEL = 5796, + OpSubgroupAvcSicSetIntraLumaShapePenaltyINTEL = 5797, + OpSubgroupAvcSicSetIntraLumaModeCostFunctionINTEL = 5798, + OpSubgroupAvcSicSetIntraChromaModeCostFunctionINTEL = 5799, + OpSubgroupAvcSicSetBilinearFilterEnableINTEL = 5800, + OpSubgroupAvcSicSetSkcForwardTransformEnableINTEL = 5801, + OpSubgroupAvcSicSetBlockBasedRawSkipSadINTEL = 5802, + OpSubgroupAvcSicEvaluateIpeINTEL = 5803, + OpSubgroupAvcSicEvaluateWithSingleReferenceINTEL = 5804, + OpSubgroupAvcSicEvaluateWithDualReferenceINTEL = 5805, + OpSubgroupAvcSicEvaluateWithMultiReferenceINTEL = 5806, + OpSubgroupAvcSicEvaluateWithMultiReferenceInterlacedINTEL = 5807, + OpSubgroupAvcSicConvertToMceResultINTEL = 5808, + OpSubgroupAvcSicGetIpeLumaShapeINTEL = 5809, + OpSubgroupAvcSicGetBestIpeLumaDistortionINTEL = 5810, + OpSubgroupAvcSicGetBestIpeChromaDistortionINTEL = 5811, + OpSubgroupAvcSicGetPackedIpeLumaModesINTEL = 5812, + OpSubgroupAvcSicGetIpeChromaModeINTEL = 5813, + OpSubgroupAvcSicGetPackedSkcLumaCountThresholdINTEL = 5814, + OpSubgroupAvcSicGetPackedSkcLumaSumThresholdINTEL = 5815, + OpSubgroupAvcSicGetInterRawSadsINTEL = 5816, + 
OpVariableLengthArrayINTEL = 5818, + OpSaveMemoryINTEL = 5819, + OpRestoreMemoryINTEL = 5820, + OpArbitraryFloatSinCosPiINTEL = 5840, + OpArbitraryFloatCastINTEL = 5841, + OpArbitraryFloatCastFromIntINTEL = 5842, + OpArbitraryFloatCastToIntINTEL = 5843, + OpArbitraryFloatAddINTEL = 5846, + OpArbitraryFloatSubINTEL = 5847, + OpArbitraryFloatMulINTEL = 5848, + OpArbitraryFloatDivINTEL = 5849, + OpArbitraryFloatGTINTEL = 5850, + OpArbitraryFloatGEINTEL = 5851, + OpArbitraryFloatLTINTEL = 5852, + OpArbitraryFloatLEINTEL = 5853, + OpArbitraryFloatEQINTEL = 5854, + OpArbitraryFloatRecipINTEL = 5855, + OpArbitraryFloatRSqrtINTEL = 5856, + OpArbitraryFloatCbrtINTEL = 5857, + OpArbitraryFloatHypotINTEL = 5858, + OpArbitraryFloatSqrtINTEL = 5859, + OpArbitraryFloatLogINTEL = 5860, + OpArbitraryFloatLog2INTEL = 5861, + OpArbitraryFloatLog10INTEL = 5862, + OpArbitraryFloatLog1pINTEL = 5863, + OpArbitraryFloatExpINTEL = 5864, + OpArbitraryFloatExp2INTEL = 5865, + OpArbitraryFloatExp10INTEL = 5866, + OpArbitraryFloatExpm1INTEL = 5867, + OpArbitraryFloatSinINTEL = 5868, + OpArbitraryFloatCosINTEL = 5869, + OpArbitraryFloatSinCosINTEL = 5870, + OpArbitraryFloatSinPiINTEL = 5871, + OpArbitraryFloatCosPiINTEL = 5872, + OpArbitraryFloatASinINTEL = 5873, + OpArbitraryFloatASinPiINTEL = 5874, + OpArbitraryFloatACosINTEL = 5875, + OpArbitraryFloatACosPiINTEL = 5876, + OpArbitraryFloatATanINTEL = 5877, + OpArbitraryFloatATanPiINTEL = 5878, + OpArbitraryFloatATan2INTEL = 5879, + OpArbitraryFloatPowINTEL = 5880, + OpArbitraryFloatPowRINTEL = 5881, + OpArbitraryFloatPowNINTEL = 5882, + OpLoopControlINTEL = 5887, + OpAliasDomainDeclINTEL = 5911, + OpAliasScopeDeclINTEL = 5912, + OpAliasScopeListDeclINTEL = 5913, + OpFixedSqrtINTEL = 5923, + OpFixedRecipINTEL = 5924, + OpFixedRsqrtINTEL = 5925, + OpFixedSinINTEL = 5926, + OpFixedCosINTEL = 5927, + OpFixedSinCosINTEL = 5928, + OpFixedSinPiINTEL = 5929, + OpFixedCosPiINTEL = 5930, + OpFixedSinCosPiINTEL = 5931, + OpFixedLogINTEL = 5932, + OpFixedExpINTEL = 5933, + OpPtrCastToCrossWorkgroupINTEL = 5934, + OpCrossWorkgroupCastToPtrINTEL = 5938, + OpReadPipeBlockingINTEL = 5946, + OpWritePipeBlockingINTEL = 5947, + OpFPGARegINTEL = 5949, + OpRayQueryGetRayTMinKHR = 6016, + OpRayQueryGetRayFlagsKHR = 6017, + OpRayQueryGetIntersectionTKHR = 6018, + OpRayQueryGetIntersectionInstanceCustomIndexKHR = 6019, + OpRayQueryGetIntersectionInstanceIdKHR = 6020, + OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR = 6021, + OpRayQueryGetIntersectionGeometryIndexKHR = 6022, + OpRayQueryGetIntersectionPrimitiveIndexKHR = 6023, + OpRayQueryGetIntersectionBarycentricsKHR = 6024, + OpRayQueryGetIntersectionFrontFaceKHR = 6025, + OpRayQueryGetIntersectionCandidateAABBOpaqueKHR = 6026, + OpRayQueryGetIntersectionObjectRayDirectionKHR = 6027, + OpRayQueryGetIntersectionObjectRayOriginKHR = 6028, + OpRayQueryGetWorldRayDirectionKHR = 6029, + OpRayQueryGetWorldRayOriginKHR = 6030, + OpRayQueryGetIntersectionObjectToWorldKHR = 6031, + OpRayQueryGetIntersectionWorldToObjectKHR = 6032, + OpAtomicFAddEXT = 6035, + OpTypeBufferSurfaceINTEL = 6086, + OpTypeStructContinuedINTEL = 6090, + OpConstantCompositeContinuedINTEL = 6091, + OpSpecConstantCompositeContinuedINTEL = 6092, + OpControlBarrierArriveINTEL = 6142, + OpControlBarrierWaitINTEL = 6143, + OpGroupIMulKHR = 6401, + OpGroupFMulKHR = 6402, + OpGroupBitwiseAndKHR = 6403, + OpGroupBitwiseOrKHR = 6404, + OpGroupBitwiseXorKHR = 6405, + OpGroupLogicalAndKHR = 6406, + OpGroupLogicalOrKHR = 6407, + OpGroupLogicalXorKHR = 6408, + OpMax 
= 0x7fffffff, +}; + +#ifdef SPV_ENABLE_UTILITY_CODE +#ifndef __cplusplus +#include <stdbool.h> +#endif +inline void HasResultAndType(Op opcode, bool *hasResult, bool *hasResultType) { + *hasResult = *hasResultType = false; + switch (opcode) { + default: /* unknown opcode */ break; + case OpNop: *hasResult = false; *hasResultType = false; break; + case OpUndef: *hasResult = true; *hasResultType = true; break; + case OpSourceContinued: *hasResult = false; *hasResultType = false; break; + case OpSource: *hasResult = false; *hasResultType = false; break; + case OpSourceExtension: *hasResult = false; *hasResultType = false; break; + case OpName: *hasResult = false; *hasResultType = false; break; + case OpMemberName: *hasResult = false; *hasResultType = false; break; + case OpString: *hasResult = true; *hasResultType = false; break; + case OpLine: *hasResult = false; *hasResultType = false; break; + case OpExtension: *hasResult = false; *hasResultType = false; break; + case OpExtInstImport: *hasResult = true; *hasResultType = false; break; + case OpExtInst: *hasResult = true; *hasResultType = true; break; + case OpMemoryModel: *hasResult = false; *hasResultType = false; break; + case OpEntryPoint: *hasResult = false; *hasResultType = false; break; + case OpExecutionMode: *hasResult = false; *hasResultType = false; break; + case OpCapability: *hasResult = false; *hasResultType = false; break; + case OpTypeVoid: *hasResult = true; *hasResultType = false; break; + case OpTypeBool: *hasResult = true; *hasResultType = false; break; + case OpTypeInt: *hasResult = true; *hasResultType = false; break; + case OpTypeFloat: *hasResult = true; *hasResultType = false; break; + case OpTypeVector: *hasResult = true; *hasResultType = false; break; + case OpTypeMatrix: *hasResult = true; *hasResultType = false; break; + case OpTypeImage: *hasResult = true; *hasResultType = false; break; + case OpTypeSampler: *hasResult = true; *hasResultType = false; break; + case OpTypeSampledImage: *hasResult = true; *hasResultType = false; break; + case OpTypeArray: *hasResult = true; *hasResultType = false; break; + case OpTypeRuntimeArray: *hasResult = true; *hasResultType = false; break; + case OpTypeStruct: *hasResult = true; *hasResultType = false; break; + case OpTypeOpaque: *hasResult = true; *hasResultType = false; break; + case OpTypePointer: *hasResult = true; *hasResultType = false; break; + case OpTypeFunction: *hasResult = true; *hasResultType = false; break; + case OpTypeEvent: *hasResult = true; *hasResultType = false; break; + case OpTypeDeviceEvent: *hasResult = true; *hasResultType = false; break; + case OpTypeReserveId: *hasResult = true; *hasResultType = false; break; + case OpTypeQueue: *hasResult = true; *hasResultType = false; break; + case OpTypePipe: *hasResult = true; *hasResultType = false; break; + case OpTypeForwardPointer: *hasResult = false; *hasResultType = false; break; + case OpConstantTrue: *hasResult = true; *hasResultType = true; break; + case OpConstantFalse: *hasResult = true; *hasResultType = true; break; + case OpConstant: *hasResult = true; *hasResultType = true; break; + case OpConstantComposite: *hasResult = true; *hasResultType = true; break; + case OpConstantSampler: *hasResult = true; *hasResultType = true; break; + case OpConstantNull: *hasResult = true; *hasResultType = true; break; + case OpSpecConstantTrue: *hasResult = true; *hasResultType = true; break; + case OpSpecConstantFalse: *hasResult = true; *hasResultType = true; break; + case OpSpecConstant: *hasResult = true; *hasResultType
= true; break; + case OpSpecConstantComposite: *hasResult = true; *hasResultType = true; break; + case OpSpecConstantOp: *hasResult = true; *hasResultType = true; break; + case OpFunction: *hasResult = true; *hasResultType = true; break; + case OpFunctionParameter: *hasResult = true; *hasResultType = true; break; + case OpFunctionEnd: *hasResult = false; *hasResultType = false; break; + case OpFunctionCall: *hasResult = true; *hasResultType = true; break; + case OpVariable: *hasResult = true; *hasResultType = true; break; + case OpImageTexelPointer: *hasResult = true; *hasResultType = true; break; + case OpLoad: *hasResult = true; *hasResultType = true; break; + case OpStore: *hasResult = false; *hasResultType = false; break; + case OpCopyMemory: *hasResult = false; *hasResultType = false; break; + case OpCopyMemorySized: *hasResult = false; *hasResultType = false; break; + case OpAccessChain: *hasResult = true; *hasResultType = true; break; + case OpInBoundsAccessChain: *hasResult = true; *hasResultType = true; break; + case OpPtrAccessChain: *hasResult = true; *hasResultType = true; break; + case OpArrayLength: *hasResult = true; *hasResultType = true; break; + case OpGenericPtrMemSemantics: *hasResult = true; *hasResultType = true; break; + case OpInBoundsPtrAccessChain: *hasResult = true; *hasResultType = true; break; + case OpDecorate: *hasResult = false; *hasResultType = false; break; + case OpMemberDecorate: *hasResult = false; *hasResultType = false; break; + case OpDecorationGroup: *hasResult = true; *hasResultType = false; break; + case OpGroupDecorate: *hasResult = false; *hasResultType = false; break; + case OpGroupMemberDecorate: *hasResult = false; *hasResultType = false; break; + case OpVectorExtractDynamic: *hasResult = true; *hasResultType = true; break; + case OpVectorInsertDynamic: *hasResult = true; *hasResultType = true; break; + case OpVectorShuffle: *hasResult = true; *hasResultType = true; break; + case OpCompositeConstruct: *hasResult = true; *hasResultType = true; break; + case OpCompositeExtract: *hasResult = true; *hasResultType = true; break; + case OpCompositeInsert: *hasResult = true; *hasResultType = true; break; + case OpCopyObject: *hasResult = true; *hasResultType = true; break; + case OpTranspose: *hasResult = true; *hasResultType = true; break; + case OpSampledImage: *hasResult = true; *hasResultType = true; break; + case OpImageSampleImplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSampleExplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSampleDrefImplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSampleDrefExplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSampleProjImplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSampleProjExplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSampleProjDrefImplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSampleProjDrefExplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageFetch: *hasResult = true; *hasResultType = true; break; + case OpImageGather: *hasResult = true; *hasResultType = true; break; + case OpImageDrefGather: *hasResult = true; *hasResultType = true; break; + case OpImageRead: *hasResult = true; *hasResultType = true; break; + case OpImageWrite: *hasResult = false; *hasResultType = false; break; + case OpImage: *hasResult = true; *hasResultType = true; break; + case OpImageQueryFormat: *hasResult = true; 
*hasResultType = true; break; + case OpImageQueryOrder: *hasResult = true; *hasResultType = true; break; + case OpImageQuerySizeLod: *hasResult = true; *hasResultType = true; break; + case OpImageQuerySize: *hasResult = true; *hasResultType = true; break; + case OpImageQueryLod: *hasResult = true; *hasResultType = true; break; + case OpImageQueryLevels: *hasResult = true; *hasResultType = true; break; + case OpImageQuerySamples: *hasResult = true; *hasResultType = true; break; + case OpConvertFToU: *hasResult = true; *hasResultType = true; break; + case OpConvertFToS: *hasResult = true; *hasResultType = true; break; + case OpConvertSToF: *hasResult = true; *hasResultType = true; break; + case OpConvertUToF: *hasResult = true; *hasResultType = true; break; + case OpUConvert: *hasResult = true; *hasResultType = true; break; + case OpSConvert: *hasResult = true; *hasResultType = true; break; + case OpFConvert: *hasResult = true; *hasResultType = true; break; + case OpQuantizeToF16: *hasResult = true; *hasResultType = true; break; + case OpConvertPtrToU: *hasResult = true; *hasResultType = true; break; + case OpSatConvertSToU: *hasResult = true; *hasResultType = true; break; + case OpSatConvertUToS: *hasResult = true; *hasResultType = true; break; + case OpConvertUToPtr: *hasResult = true; *hasResultType = true; break; + case OpPtrCastToGeneric: *hasResult = true; *hasResultType = true; break; + case OpGenericCastToPtr: *hasResult = true; *hasResultType = true; break; + case OpGenericCastToPtrExplicit: *hasResult = true; *hasResultType = true; break; + case OpBitcast: *hasResult = true; *hasResultType = true; break; + case OpSNegate: *hasResult = true; *hasResultType = true; break; + case OpFNegate: *hasResult = true; *hasResultType = true; break; + case OpIAdd: *hasResult = true; *hasResultType = true; break; + case OpFAdd: *hasResult = true; *hasResultType = true; break; + case OpISub: *hasResult = true; *hasResultType = true; break; + case OpFSub: *hasResult = true; *hasResultType = true; break; + case OpIMul: *hasResult = true; *hasResultType = true; break; + case OpFMul: *hasResult = true; *hasResultType = true; break; + case OpUDiv: *hasResult = true; *hasResultType = true; break; + case OpSDiv: *hasResult = true; *hasResultType = true; break; + case OpFDiv: *hasResult = true; *hasResultType = true; break; + case OpUMod: *hasResult = true; *hasResultType = true; break; + case OpSRem: *hasResult = true; *hasResultType = true; break; + case OpSMod: *hasResult = true; *hasResultType = true; break; + case OpFRem: *hasResult = true; *hasResultType = true; break; + case OpFMod: *hasResult = true; *hasResultType = true; break; + case OpVectorTimesScalar: *hasResult = true; *hasResultType = true; break; + case OpMatrixTimesScalar: *hasResult = true; *hasResultType = true; break; + case OpVectorTimesMatrix: *hasResult = true; *hasResultType = true; break; + case OpMatrixTimesVector: *hasResult = true; *hasResultType = true; break; + case OpMatrixTimesMatrix: *hasResult = true; *hasResultType = true; break; + case OpOuterProduct: *hasResult = true; *hasResultType = true; break; + case OpDot: *hasResult = true; *hasResultType = true; break; + case OpIAddCarry: *hasResult = true; *hasResultType = true; break; + case OpISubBorrow: *hasResult = true; *hasResultType = true; break; + case OpUMulExtended: *hasResult = true; *hasResultType = true; break; + case OpSMulExtended: *hasResult = true; *hasResultType = true; break; + case OpAny: *hasResult = true; *hasResultType = true; break; + case OpAll: 
*hasResult = true; *hasResultType = true; break; + case OpIsNan: *hasResult = true; *hasResultType = true; break; + case OpIsInf: *hasResult = true; *hasResultType = true; break; + case OpIsFinite: *hasResult = true; *hasResultType = true; break; + case OpIsNormal: *hasResult = true; *hasResultType = true; break; + case OpSignBitSet: *hasResult = true; *hasResultType = true; break; + case OpLessOrGreater: *hasResult = true; *hasResultType = true; break; + case OpOrdered: *hasResult = true; *hasResultType = true; break; + case OpUnordered: *hasResult = true; *hasResultType = true; break; + case OpLogicalEqual: *hasResult = true; *hasResultType = true; break; + case OpLogicalNotEqual: *hasResult = true; *hasResultType = true; break; + case OpLogicalOr: *hasResult = true; *hasResultType = true; break; + case OpLogicalAnd: *hasResult = true; *hasResultType = true; break; + case OpLogicalNot: *hasResult = true; *hasResultType = true; break; + case OpSelect: *hasResult = true; *hasResultType = true; break; + case OpIEqual: *hasResult = true; *hasResultType = true; break; + case OpINotEqual: *hasResult = true; *hasResultType = true; break; + case OpUGreaterThan: *hasResult = true; *hasResultType = true; break; + case OpSGreaterThan: *hasResult = true; *hasResultType = true; break; + case OpUGreaterThanEqual: *hasResult = true; *hasResultType = true; break; + case OpSGreaterThanEqual: *hasResult = true; *hasResultType = true; break; + case OpULessThan: *hasResult = true; *hasResultType = true; break; + case OpSLessThan: *hasResult = true; *hasResultType = true; break; + case OpULessThanEqual: *hasResult = true; *hasResultType = true; break; + case OpSLessThanEqual: *hasResult = true; *hasResultType = true; break; + case OpFOrdEqual: *hasResult = true; *hasResultType = true; break; + case OpFUnordEqual: *hasResult = true; *hasResultType = true; break; + case OpFOrdNotEqual: *hasResult = true; *hasResultType = true; break; + case OpFUnordNotEqual: *hasResult = true; *hasResultType = true; break; + case OpFOrdLessThan: *hasResult = true; *hasResultType = true; break; + case OpFUnordLessThan: *hasResult = true; *hasResultType = true; break; + case OpFOrdGreaterThan: *hasResult = true; *hasResultType = true; break; + case OpFUnordGreaterThan: *hasResult = true; *hasResultType = true; break; + case OpFOrdLessThanEqual: *hasResult = true; *hasResultType = true; break; + case OpFUnordLessThanEqual: *hasResult = true; *hasResultType = true; break; + case OpFOrdGreaterThanEqual: *hasResult = true; *hasResultType = true; break; + case OpFUnordGreaterThanEqual: *hasResult = true; *hasResultType = true; break; + case OpShiftRightLogical: *hasResult = true; *hasResultType = true; break; + case OpShiftRightArithmetic: *hasResult = true; *hasResultType = true; break; + case OpShiftLeftLogical: *hasResult = true; *hasResultType = true; break; + case OpBitwiseOr: *hasResult = true; *hasResultType = true; break; + case OpBitwiseXor: *hasResult = true; *hasResultType = true; break; + case OpBitwiseAnd: *hasResult = true; *hasResultType = true; break; + case OpNot: *hasResult = true; *hasResultType = true; break; + case OpBitFieldInsert: *hasResult = true; *hasResultType = true; break; + case OpBitFieldSExtract: *hasResult = true; *hasResultType = true; break; + case OpBitFieldUExtract: *hasResult = true; *hasResultType = true; break; + case OpBitReverse: *hasResult = true; *hasResultType = true; break; + case OpBitCount: *hasResult = true; *hasResultType = true; break; + case OpDPdx: *hasResult = true; *hasResultType 
= true; break; + case OpDPdy: *hasResult = true; *hasResultType = true; break; + case OpFwidth: *hasResult = true; *hasResultType = true; break; + case OpDPdxFine: *hasResult = true; *hasResultType = true; break; + case OpDPdyFine: *hasResult = true; *hasResultType = true; break; + case OpFwidthFine: *hasResult = true; *hasResultType = true; break; + case OpDPdxCoarse: *hasResult = true; *hasResultType = true; break; + case OpDPdyCoarse: *hasResult = true; *hasResultType = true; break; + case OpFwidthCoarse: *hasResult = true; *hasResultType = true; break; + case OpEmitVertex: *hasResult = false; *hasResultType = false; break; + case OpEndPrimitive: *hasResult = false; *hasResultType = false; break; + case OpEmitStreamVertex: *hasResult = false; *hasResultType = false; break; + case OpEndStreamPrimitive: *hasResult = false; *hasResultType = false; break; + case OpControlBarrier: *hasResult = false; *hasResultType = false; break; + case OpMemoryBarrier: *hasResult = false; *hasResultType = false; break; + case OpAtomicLoad: *hasResult = true; *hasResultType = true; break; + case OpAtomicStore: *hasResult = false; *hasResultType = false; break; + case OpAtomicExchange: *hasResult = true; *hasResultType = true; break; + case OpAtomicCompareExchange: *hasResult = true; *hasResultType = true; break; + case OpAtomicCompareExchangeWeak: *hasResult = true; *hasResultType = true; break; + case OpAtomicIIncrement: *hasResult = true; *hasResultType = true; break; + case OpAtomicIDecrement: *hasResult = true; *hasResultType = true; break; + case OpAtomicIAdd: *hasResult = true; *hasResultType = true; break; + case OpAtomicISub: *hasResult = true; *hasResultType = true; break; + case OpAtomicSMin: *hasResult = true; *hasResultType = true; break; + case OpAtomicUMin: *hasResult = true; *hasResultType = true; break; + case OpAtomicSMax: *hasResult = true; *hasResultType = true; break; + case OpAtomicUMax: *hasResult = true; *hasResultType = true; break; + case OpAtomicAnd: *hasResult = true; *hasResultType = true; break; + case OpAtomicOr: *hasResult = true; *hasResultType = true; break; + case OpAtomicXor: *hasResult = true; *hasResultType = true; break; + case OpPhi: *hasResult = true; *hasResultType = true; break; + case OpLoopMerge: *hasResult = false; *hasResultType = false; break; + case OpSelectionMerge: *hasResult = false; *hasResultType = false; break; + case OpLabel: *hasResult = true; *hasResultType = false; break; + case OpBranch: *hasResult = false; *hasResultType = false; break; + case OpBranchConditional: *hasResult = false; *hasResultType = false; break; + case OpSwitch: *hasResult = false; *hasResultType = false; break; + case OpKill: *hasResult = false; *hasResultType = false; break; + case OpReturn: *hasResult = false; *hasResultType = false; break; + case OpReturnValue: *hasResult = false; *hasResultType = false; break; + case OpUnreachable: *hasResult = false; *hasResultType = false; break; + case OpLifetimeStart: *hasResult = false; *hasResultType = false; break; + case OpLifetimeStop: *hasResult = false; *hasResultType = false; break; + case OpGroupAsyncCopy: *hasResult = true; *hasResultType = true; break; + case OpGroupWaitEvents: *hasResult = false; *hasResultType = false; break; + case OpGroupAll: *hasResult = true; *hasResultType = true; break; + case OpGroupAny: *hasResult = true; *hasResultType = true; break; + case OpGroupBroadcast: *hasResult = true; *hasResultType = true; break; + case OpGroupIAdd: *hasResult = true; *hasResultType = true; break; + case OpGroupFAdd: 
*hasResult = true; *hasResultType = true; break; + case OpGroupFMin: *hasResult = true; *hasResultType = true; break; + case OpGroupUMin: *hasResult = true; *hasResultType = true; break; + case OpGroupSMin: *hasResult = true; *hasResultType = true; break; + case OpGroupFMax: *hasResult = true; *hasResultType = true; break; + case OpGroupUMax: *hasResult = true; *hasResultType = true; break; + case OpGroupSMax: *hasResult = true; *hasResultType = true; break; + case OpReadPipe: *hasResult = true; *hasResultType = true; break; + case OpWritePipe: *hasResult = true; *hasResultType = true; break; + case OpReservedReadPipe: *hasResult = true; *hasResultType = true; break; + case OpReservedWritePipe: *hasResult = true; *hasResultType = true; break; + case OpReserveReadPipePackets: *hasResult = true; *hasResultType = true; break; + case OpReserveWritePipePackets: *hasResult = true; *hasResultType = true; break; + case OpCommitReadPipe: *hasResult = false; *hasResultType = false; break; + case OpCommitWritePipe: *hasResult = false; *hasResultType = false; break; + case OpIsValidReserveId: *hasResult = true; *hasResultType = true; break; + case OpGetNumPipePackets: *hasResult = true; *hasResultType = true; break; + case OpGetMaxPipePackets: *hasResult = true; *hasResultType = true; break; + case OpGroupReserveReadPipePackets: *hasResult = true; *hasResultType = true; break; + case OpGroupReserveWritePipePackets: *hasResult = true; *hasResultType = true; break; + case OpGroupCommitReadPipe: *hasResult = false; *hasResultType = false; break; + case OpGroupCommitWritePipe: *hasResult = false; *hasResultType = false; break; + case OpEnqueueMarker: *hasResult = true; *hasResultType = true; break; + case OpEnqueueKernel: *hasResult = true; *hasResultType = true; break; + case OpGetKernelNDrangeSubGroupCount: *hasResult = true; *hasResultType = true; break; + case OpGetKernelNDrangeMaxSubGroupSize: *hasResult = true; *hasResultType = true; break; + case OpGetKernelWorkGroupSize: *hasResult = true; *hasResultType = true; break; + case OpGetKernelPreferredWorkGroupSizeMultiple: *hasResult = true; *hasResultType = true; break; + case OpRetainEvent: *hasResult = false; *hasResultType = false; break; + case OpReleaseEvent: *hasResult = false; *hasResultType = false; break; + case OpCreateUserEvent: *hasResult = true; *hasResultType = true; break; + case OpIsValidEvent: *hasResult = true; *hasResultType = true; break; + case OpSetUserEventStatus: *hasResult = false; *hasResultType = false; break; + case OpCaptureEventProfilingInfo: *hasResult = false; *hasResultType = false; break; + case OpGetDefaultQueue: *hasResult = true; *hasResultType = true; break; + case OpBuildNDRange: *hasResult = true; *hasResultType = true; break; + case OpImageSparseSampleImplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSparseSampleExplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSparseSampleDrefImplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSparseSampleDrefExplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSparseSampleProjImplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSparseSampleProjExplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSparseSampleProjDrefImplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSparseSampleProjDrefExplicitLod: *hasResult = true; *hasResultType = true; break; + case OpImageSparseFetch: *hasResult = true; *hasResultType = 
true; break; + case OpImageSparseGather: *hasResult = true; *hasResultType = true; break; + case OpImageSparseDrefGather: *hasResult = true; *hasResultType = true; break; + case OpImageSparseTexelsResident: *hasResult = true; *hasResultType = true; break; + case OpNoLine: *hasResult = false; *hasResultType = false; break; + case OpAtomicFlagTestAndSet: *hasResult = true; *hasResultType = true; break; + case OpAtomicFlagClear: *hasResult = false; *hasResultType = false; break; + case OpImageSparseRead: *hasResult = true; *hasResultType = true; break; + case OpSizeOf: *hasResult = true; *hasResultType = true; break; + case OpTypePipeStorage: *hasResult = true; *hasResultType = false; break; + case OpConstantPipeStorage: *hasResult = true; *hasResultType = true; break; + case OpCreatePipeFromPipeStorage: *hasResult = true; *hasResultType = true; break; + case OpGetKernelLocalSizeForSubgroupCount: *hasResult = true; *hasResultType = true; break; + case OpGetKernelMaxNumSubgroups: *hasResult = true; *hasResultType = true; break; + case OpTypeNamedBarrier: *hasResult = true; *hasResultType = false; break; + case OpNamedBarrierInitialize: *hasResult = true; *hasResultType = true; break; + case OpMemoryNamedBarrier: *hasResult = false; *hasResultType = false; break; + case OpModuleProcessed: *hasResult = false; *hasResultType = false; break; + case OpExecutionModeId: *hasResult = false; *hasResultType = false; break; + case OpDecorateId: *hasResult = false; *hasResultType = false; break; + case OpGroupNonUniformElect: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformAll: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformAny: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformAllEqual: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformBroadcast: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformBroadcastFirst: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformBallot: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformInverseBallot: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformBallotBitExtract: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformBallotBitCount: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformBallotFindLSB: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformBallotFindMSB: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformShuffle: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformShuffleXor: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformShuffleUp: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformShuffleDown: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformIAdd: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformFAdd: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformIMul: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformFMul: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformSMin: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformUMin: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformFMin: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformSMax: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformUMax: *hasResult = true; *hasResultType 
= true; break; + case OpGroupNonUniformFMax: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformBitwiseAnd: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformBitwiseOr: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformBitwiseXor: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformLogicalAnd: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformLogicalOr: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformLogicalXor: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformQuadBroadcast: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformQuadSwap: *hasResult = true; *hasResultType = true; break; + case OpCopyLogical: *hasResult = true; *hasResultType = true; break; + case OpPtrEqual: *hasResult = true; *hasResultType = true; break; + case OpPtrNotEqual: *hasResult = true; *hasResultType = true; break; + case OpPtrDiff: *hasResult = true; *hasResultType = true; break; + case OpTerminateInvocation: *hasResult = false; *hasResultType = false; break; + case OpSubgroupBallotKHR: *hasResult = true; *hasResultType = true; break; + case OpSubgroupFirstInvocationKHR: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAllKHR: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAnyKHR: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAllEqualKHR: *hasResult = true; *hasResultType = true; break; + case OpGroupNonUniformRotateKHR: *hasResult = true; *hasResultType = true; break; + case OpSubgroupReadInvocationKHR: *hasResult = true; *hasResultType = true; break; + case OpTraceRayKHR: *hasResult = false; *hasResultType = false; break; + case OpExecuteCallableKHR: *hasResult = false; *hasResultType = false; break; + case OpConvertUToAccelerationStructureKHR: *hasResult = true; *hasResultType = true; break; + case OpIgnoreIntersectionKHR: *hasResult = false; *hasResultType = false; break; + case OpTerminateRayKHR: *hasResult = false; *hasResultType = false; break; + case OpSDot: *hasResult = true; *hasResultType = true; break; + case OpUDot: *hasResult = true; *hasResultType = true; break; + case OpSUDot: *hasResult = true; *hasResultType = true; break; + case OpSDotAccSat: *hasResult = true; *hasResultType = true; break; + case OpUDotAccSat: *hasResult = true; *hasResultType = true; break; + case OpSUDotAccSat: *hasResult = true; *hasResultType = true; break; + case OpTypeRayQueryKHR: *hasResult = true; *hasResultType = false; break; + case OpRayQueryInitializeKHR: *hasResult = false; *hasResultType = false; break; + case OpRayQueryTerminateKHR: *hasResult = false; *hasResultType = false; break; + case OpRayQueryGenerateIntersectionKHR: *hasResult = false; *hasResultType = false; break; + case OpRayQueryConfirmIntersectionKHR: *hasResult = false; *hasResultType = false; break; + case OpRayQueryProceedKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetIntersectionTypeKHR: *hasResult = true; *hasResultType = true; break; + case OpImageSampleWeightedQCOM: *hasResult = true; *hasResultType = true; break; + case OpImageBoxFilterQCOM: *hasResult = true; *hasResultType = true; break; + case OpImageBlockMatchSSDQCOM: *hasResult = true; *hasResultType = true; break; + case OpImageBlockMatchSADQCOM: *hasResult = true; *hasResultType = true; break; + case OpGroupIAddNonUniformAMD: *hasResult = true; *hasResultType = true; break; + case OpGroupFAddNonUniformAMD: *hasResult = true; 
*hasResultType = true; break; + case OpGroupFMinNonUniformAMD: *hasResult = true; *hasResultType = true; break; + case OpGroupUMinNonUniformAMD: *hasResult = true; *hasResultType = true; break; + case OpGroupSMinNonUniformAMD: *hasResult = true; *hasResultType = true; break; + case OpGroupFMaxNonUniformAMD: *hasResult = true; *hasResultType = true; break; + case OpGroupUMaxNonUniformAMD: *hasResult = true; *hasResultType = true; break; + case OpGroupSMaxNonUniformAMD: *hasResult = true; *hasResultType = true; break; + case OpFragmentMaskFetchAMD: *hasResult = true; *hasResultType = true; break; + case OpFragmentFetchAMD: *hasResult = true; *hasResultType = true; break; + case OpReadClockKHR: *hasResult = true; *hasResultType = true; break; + case OpImageSampleFootprintNV: *hasResult = true; *hasResultType = true; break; + case OpEmitMeshTasksEXT: *hasResult = false; *hasResultType = false; break; + case OpSetMeshOutputsEXT: *hasResult = false; *hasResultType = false; break; + case OpGroupNonUniformPartitionNV: *hasResult = true; *hasResultType = true; break; + case OpWritePackedPrimitiveIndices4x8NV: *hasResult = false; *hasResultType = false; break; + case OpReportIntersectionNV: *hasResult = true; *hasResultType = true; break; + case OpIgnoreIntersectionNV: *hasResult = false; *hasResultType = false; break; + case OpTerminateRayNV: *hasResult = false; *hasResultType = false; break; + case OpTraceNV: *hasResult = false; *hasResultType = false; break; + case OpTraceMotionNV: *hasResult = false; *hasResultType = false; break; + case OpTraceRayMotionNV: *hasResult = false; *hasResultType = false; break; + case OpTypeAccelerationStructureNV: *hasResult = true; *hasResultType = false; break; + case OpExecuteCallableNV: *hasResult = false; *hasResultType = false; break; + case OpTypeCooperativeMatrixNV: *hasResult = true; *hasResultType = false; break; + case OpCooperativeMatrixLoadNV: *hasResult = true; *hasResultType = true; break; + case OpCooperativeMatrixStoreNV: *hasResult = false; *hasResultType = false; break; + case OpCooperativeMatrixMulAddNV: *hasResult = true; *hasResultType = true; break; + case OpCooperativeMatrixLengthNV: *hasResult = true; *hasResultType = true; break; + case OpBeginInvocationInterlockEXT: *hasResult = false; *hasResultType = false; break; + case OpEndInvocationInterlockEXT: *hasResult = false; *hasResultType = false; break; + case OpDemoteToHelperInvocation: *hasResult = false; *hasResultType = false; break; + case OpIsHelperInvocationEXT: *hasResult = true; *hasResultType = true; break; + case OpConvertUToImageNV: *hasResult = true; *hasResultType = true; break; + case OpConvertUToSamplerNV: *hasResult = true; *hasResultType = true; break; + case OpConvertImageToUNV: *hasResult = true; *hasResultType = true; break; + case OpConvertSamplerToUNV: *hasResult = true; *hasResultType = true; break; + case OpConvertUToSampledImageNV: *hasResult = true; *hasResultType = true; break; + case OpConvertSampledImageToUNV: *hasResult = true; *hasResultType = true; break; + case OpSamplerImageAddressingModeNV: *hasResult = false; *hasResultType = false; break; + case OpSubgroupShuffleINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupShuffleDownINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupShuffleUpINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupShuffleXorINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupBlockReadINTEL: *hasResult = true; *hasResultType = true; break; + case 
OpSubgroupBlockWriteINTEL: *hasResult = false; *hasResultType = false; break; + case OpSubgroupImageBlockReadINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupImageBlockWriteINTEL: *hasResult = false; *hasResultType = false; break; + case OpSubgroupImageMediaBlockReadINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupImageMediaBlockWriteINTEL: *hasResult = false; *hasResultType = false; break; + case OpUCountLeadingZerosINTEL: *hasResult = true; *hasResultType = true; break; + case OpUCountTrailingZerosINTEL: *hasResult = true; *hasResultType = true; break; + case OpAbsISubINTEL: *hasResult = true; *hasResultType = true; break; + case OpAbsUSubINTEL: *hasResult = true; *hasResultType = true; break; + case OpIAddSatINTEL: *hasResult = true; *hasResultType = true; break; + case OpUAddSatINTEL: *hasResult = true; *hasResultType = true; break; + case OpIAverageINTEL: *hasResult = true; *hasResultType = true; break; + case OpUAverageINTEL: *hasResult = true; *hasResultType = true; break; + case OpIAverageRoundedINTEL: *hasResult = true; *hasResultType = true; break; + case OpUAverageRoundedINTEL: *hasResult = true; *hasResultType = true; break; + case OpISubSatINTEL: *hasResult = true; *hasResultType = true; break; + case OpUSubSatINTEL: *hasResult = true; *hasResultType = true; break; + case OpIMul32x16INTEL: *hasResult = true; *hasResultType = true; break; + case OpUMul32x16INTEL: *hasResult = true; *hasResultType = true; break; + case OpConstantFunctionPointerINTEL: *hasResult = true; *hasResultType = true; break; + case OpFunctionPointerCallINTEL: *hasResult = true; *hasResultType = true; break; + case OpAsmTargetINTEL: *hasResult = true; *hasResultType = true; break; + case OpAsmINTEL: *hasResult = true; *hasResultType = true; break; + case OpAsmCallINTEL: *hasResult = true; *hasResultType = true; break; + case OpAtomicFMinEXT: *hasResult = true; *hasResultType = true; break; + case OpAtomicFMaxEXT: *hasResult = true; *hasResultType = true; break; + case OpAssumeTrueKHR: *hasResult = false; *hasResultType = false; break; + case OpExpectKHR: *hasResult = true; *hasResultType = true; break; + case OpDecorateString: *hasResult = false; *hasResultType = false; break; + case OpMemberDecorateString: *hasResult = false; *hasResultType = false; break; + case OpVmeImageINTEL: *hasResult = true; *hasResultType = true; break; + case OpTypeVmeImageINTEL: *hasResult = true; *hasResultType = false; break; + case OpTypeAvcImePayloadINTEL: *hasResult = true; *hasResultType = false; break; + case OpTypeAvcRefPayloadINTEL: *hasResult = true; *hasResultType = false; break; + case OpTypeAvcSicPayloadINTEL: *hasResult = true; *hasResultType = false; break; + case OpTypeAvcMcePayloadINTEL: *hasResult = true; *hasResultType = false; break; + case OpTypeAvcMceResultINTEL: *hasResult = true; *hasResultType = false; break; + case OpTypeAvcImeResultINTEL: *hasResult = true; *hasResultType = false; break; + case OpTypeAvcImeResultSingleReferenceStreamoutINTEL: *hasResult = true; *hasResultType = false; break; + case OpTypeAvcImeResultDualReferenceStreamoutINTEL: *hasResult = true; *hasResultType = false; break; + case OpTypeAvcImeSingleReferenceStreaminINTEL: *hasResult = true; *hasResultType = false; break; + case OpTypeAvcImeDualReferenceStreaminINTEL: *hasResult = true; *hasResultType = false; break; + case OpTypeAvcRefResultINTEL: *hasResult = true; *hasResultType = false; break; + case OpTypeAvcSicResultINTEL: *hasResult = true; *hasResultType = false; break; + case 
OpSubgroupAvcMceGetDefaultInterBaseMultiReferencePenaltyINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceSetInterBaseMultiReferencePenaltyINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetDefaultInterShapePenaltyINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceSetInterShapePenaltyINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetDefaultInterDirectionPenaltyINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceSetInterDirectionPenaltyINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetDefaultIntraLumaShapePenaltyINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetDefaultInterMotionVectorCostTableINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetDefaultHighPenaltyCostTableINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetDefaultMediumPenaltyCostTableINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetDefaultLowPenaltyCostTableINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceSetMotionVectorCostFunctionINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetDefaultIntraLumaModePenaltyINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetDefaultNonDcLumaIntraPenaltyINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetDefaultIntraChromaModeBasePenaltyINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceSetAcOnlyHaarINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceSetSourceInterlacedFieldPolarityINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceSetSingleReferenceInterlacedFieldPolarityINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceSetDualReferenceInterlacedFieldPolaritiesINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceConvertToImePayloadINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceConvertToImeResultINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceConvertToRefPayloadINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceConvertToRefResultINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceConvertToSicPayloadINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceConvertToSicResultINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetMotionVectorsINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetInterDistortionsINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetBestInterDistortionsINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetInterMajorShapeINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetInterMinorShapeINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetInterDirectionsINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetInterMotionVectorCountINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetInterReferenceIdsINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcMceGetInterReferenceInterlacedFieldPolaritiesINTEL: *hasResult = true; 
*hasResultType = true; break; + case OpSubgroupAvcImeInitializeINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeSetSingleReferenceINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeSetDualReferenceINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeRefWindowSizeINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeAdjustRefOffsetINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeConvertToMcePayloadINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeSetMaxMotionVectorCountINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeSetUnidirectionalMixDisableINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeSetEarlySearchTerminationThresholdINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeSetWeightedSadINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeEvaluateWithSingleReferenceINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeEvaluateWithDualReferenceINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeEvaluateWithSingleReferenceStreaminINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeEvaluateWithDualReferenceStreaminINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeEvaluateWithSingleReferenceStreamoutINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeEvaluateWithDualReferenceStreamoutINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeEvaluateWithSingleReferenceStreaminoutINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeEvaluateWithDualReferenceStreaminoutINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeConvertToMceResultINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeGetSingleReferenceStreaminINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeGetDualReferenceStreaminINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeStripSingleReferenceStreamoutINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeStripDualReferenceStreamoutINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeGetStreamoutSingleReferenceMajorShapeMotionVectorsINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeGetStreamoutSingleReferenceMajorShapeDistortionsINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeGetStreamoutSingleReferenceMajorShapeReferenceIdsINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeGetStreamoutDualReferenceMajorShapeMotionVectorsINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeGetStreamoutDualReferenceMajorShapeDistortionsINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeGetStreamoutDualReferenceMajorShapeReferenceIdsINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeGetBorderReachedINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeGetTruncatedSearchIndicationINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeGetUnidirectionalEarlySearchTerminationINTEL: *hasResult = true; *hasResultType = true; break; + case 
OpSubgroupAvcImeGetWeightingPatternMinimumMotionVectorINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcImeGetWeightingPatternMinimumDistortionINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcFmeInitializeINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcBmeInitializeINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcRefConvertToMcePayloadINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcRefSetBidirectionalMixDisableINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcRefSetBilinearFilterEnableINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcRefEvaluateWithSingleReferenceINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcRefEvaluateWithDualReferenceINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcRefEvaluateWithMultiReferenceINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcRefEvaluateWithMultiReferenceInterlacedINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcRefConvertToMceResultINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicInitializeINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicConfigureSkcINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicConfigureIpeLumaINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicConfigureIpeLumaChromaINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicGetMotionVectorMaskINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicConvertToMcePayloadINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicSetIntraLumaShapePenaltyINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicSetIntraLumaModeCostFunctionINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicSetIntraChromaModeCostFunctionINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicSetBilinearFilterEnableINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicSetSkcForwardTransformEnableINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicSetBlockBasedRawSkipSadINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicEvaluateIpeINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicEvaluateWithSingleReferenceINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicEvaluateWithDualReferenceINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicEvaluateWithMultiReferenceINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicEvaluateWithMultiReferenceInterlacedINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicConvertToMceResultINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicGetIpeLumaShapeINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicGetBestIpeLumaDistortionINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicGetBestIpeChromaDistortionINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicGetPackedIpeLumaModesINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicGetIpeChromaModeINTEL: *hasResult = true; *hasResultType = true; break; + 
case OpSubgroupAvcSicGetPackedSkcLumaCountThresholdINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicGetPackedSkcLumaSumThresholdINTEL: *hasResult = true; *hasResultType = true; break; + case OpSubgroupAvcSicGetInterRawSadsINTEL: *hasResult = true; *hasResultType = true; break; + case OpVariableLengthArrayINTEL: *hasResult = true; *hasResultType = true; break; + case OpSaveMemoryINTEL: *hasResult = true; *hasResultType = true; break; + case OpRestoreMemoryINTEL: *hasResult = false; *hasResultType = false; break; + case OpArbitraryFloatSinCosPiINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatCastINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatCastFromIntINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatCastToIntINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatAddINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatSubINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatMulINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatDivINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatGTINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatGEINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatLTINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatLEINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatEQINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatRecipINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatRSqrtINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatCbrtINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatHypotINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatSqrtINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatLogINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatLog2INTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatLog10INTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatLog1pINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatExpINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatExp2INTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatExp10INTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatExpm1INTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatSinINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatCosINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatSinCosINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatSinPiINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatCosPiINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatASinINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatASinPiINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatACosINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatACosPiINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatATanINTEL: *hasResult = true; *hasResultType = true; break; + case 
OpArbitraryFloatATanPiINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatATan2INTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatPowINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatPowRINTEL: *hasResult = true; *hasResultType = true; break; + case OpArbitraryFloatPowNINTEL: *hasResult = true; *hasResultType = true; break; + case OpLoopControlINTEL: *hasResult = false; *hasResultType = false; break; + case OpAliasDomainDeclINTEL: *hasResult = true; *hasResultType = false; break; + case OpAliasScopeDeclINTEL: *hasResult = true; *hasResultType = false; break; + case OpAliasScopeListDeclINTEL: *hasResult = true; *hasResultType = false; break; + case OpFixedSqrtINTEL: *hasResult = true; *hasResultType = true; break; + case OpFixedRecipINTEL: *hasResult = true; *hasResultType = true; break; + case OpFixedRsqrtINTEL: *hasResult = true; *hasResultType = true; break; + case OpFixedSinINTEL: *hasResult = true; *hasResultType = true; break; + case OpFixedCosINTEL: *hasResult = true; *hasResultType = true; break; + case OpFixedSinCosINTEL: *hasResult = true; *hasResultType = true; break; + case OpFixedSinPiINTEL: *hasResult = true; *hasResultType = true; break; + case OpFixedCosPiINTEL: *hasResult = true; *hasResultType = true; break; + case OpFixedSinCosPiINTEL: *hasResult = true; *hasResultType = true; break; + case OpFixedLogINTEL: *hasResult = true; *hasResultType = true; break; + case OpFixedExpINTEL: *hasResult = true; *hasResultType = true; break; + case OpPtrCastToCrossWorkgroupINTEL: *hasResult = true; *hasResultType = true; break; + case OpCrossWorkgroupCastToPtrINTEL: *hasResult = true; *hasResultType = true; break; + case OpReadPipeBlockingINTEL: *hasResult = true; *hasResultType = true; break; + case OpWritePipeBlockingINTEL: *hasResult = true; *hasResultType = true; break; + case OpFPGARegINTEL: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetRayTMinKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetRayFlagsKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetIntersectionTKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetIntersectionInstanceCustomIndexKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetIntersectionInstanceIdKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetIntersectionGeometryIndexKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetIntersectionPrimitiveIndexKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetIntersectionBarycentricsKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetIntersectionFrontFaceKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetIntersectionObjectRayDirectionKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetIntersectionObjectRayOriginKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetWorldRayDirectionKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetWorldRayOriginKHR: *hasResult = true; *hasResultType = true; break; + case OpRayQueryGetIntersectionObjectToWorldKHR: *hasResult = true; *hasResultType = true; break; + case 
OpRayQueryGetIntersectionWorldToObjectKHR: *hasResult = true; *hasResultType = true; break; + case OpAtomicFAddEXT: *hasResult = true; *hasResultType = true; break; + case OpTypeBufferSurfaceINTEL: *hasResult = true; *hasResultType = false; break; + case OpTypeStructContinuedINTEL: *hasResult = false; *hasResultType = false; break; + case OpConstantCompositeContinuedINTEL: *hasResult = false; *hasResultType = false; break; + case OpSpecConstantCompositeContinuedINTEL: *hasResult = false; *hasResultType = false; break; + case OpControlBarrierArriveINTEL: *hasResult = false; *hasResultType = false; break; + case OpControlBarrierWaitINTEL: *hasResult = false; *hasResultType = false; break; + case OpGroupIMulKHR: *hasResult = true; *hasResultType = true; break; + case OpGroupFMulKHR: *hasResult = true; *hasResultType = true; break; + case OpGroupBitwiseAndKHR: *hasResult = true; *hasResultType = true; break; + case OpGroupBitwiseOrKHR: *hasResult = true; *hasResultType = true; break; + case OpGroupBitwiseXorKHR: *hasResult = true; *hasResultType = true; break; + case OpGroupLogicalAndKHR: *hasResult = true; *hasResultType = true; break; + case OpGroupLogicalOrKHR: *hasResult = true; *hasResultType = true; break; + case OpGroupLogicalXorKHR: *hasResult = true; *hasResultType = true; break; + } +} +#endif /* SPV_ENABLE_UTILITY_CODE */ + +// Overload operator| for mask bit combining + +inline ImageOperandsMask operator|(ImageOperandsMask a, ImageOperandsMask b) { return ImageOperandsMask(unsigned(a) | unsigned(b)); } +inline FPFastMathModeMask operator|(FPFastMathModeMask a, FPFastMathModeMask b) { return FPFastMathModeMask(unsigned(a) | unsigned(b)); } +inline SelectionControlMask operator|(SelectionControlMask a, SelectionControlMask b) { return SelectionControlMask(unsigned(a) | unsigned(b)); } +inline LoopControlMask operator|(LoopControlMask a, LoopControlMask b) { return LoopControlMask(unsigned(a) | unsigned(b)); } +inline FunctionControlMask operator|(FunctionControlMask a, FunctionControlMask b) { return FunctionControlMask(unsigned(a) | unsigned(b)); } +inline MemorySemanticsMask operator|(MemorySemanticsMask a, MemorySemanticsMask b) { return MemorySemanticsMask(unsigned(a) | unsigned(b)); } +inline MemoryAccessMask operator|(MemoryAccessMask a, MemoryAccessMask b) { return MemoryAccessMask(unsigned(a) | unsigned(b)); } +inline KernelProfilingInfoMask operator|(KernelProfilingInfoMask a, KernelProfilingInfoMask b) { return KernelProfilingInfoMask(unsigned(a) | unsigned(b)); } +inline RayFlagsMask operator|(RayFlagsMask a, RayFlagsMask b) { return RayFlagsMask(unsigned(a) | unsigned(b)); } +inline FragmentShadingRateMask operator|(FragmentShadingRateMask a, FragmentShadingRateMask b) { return FragmentShadingRateMask(unsigned(a) | unsigned(b)); } + +} // end namespace spv + +#endif // #ifndef spirv_HPP + diff --git a/thirdparty/spirv-cross/spirv_cfg.cpp b/thirdparty/spirv-cross/spirv_cfg.cpp new file mode 100644 index 00000000000..93299479815 --- /dev/null +++ b/thirdparty/spirv-cross/spirv_cfg.cpp @@ -0,0 +1,430 @@ +/* + * Copyright 2016-2021 Arm Limited + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . + */ + +#include "spirv_cfg.hpp" +#include "spirv_cross.hpp" +#include +#include + +using namespace std; + +namespace SPIRV_CROSS_NAMESPACE +{ +CFG::CFG(Compiler &compiler_, const SPIRFunction &func_) + : compiler(compiler_) + , func(func_) +{ + build_post_order_visit_order(); + build_immediate_dominators(); +} + +uint32_t CFG::find_common_dominator(uint32_t a, uint32_t b) const +{ + while (a != b) + { + if (get_visit_order(a) < get_visit_order(b)) + a = get_immediate_dominator(a); + else + b = get_immediate_dominator(b); + } + return a; +} + +void CFG::build_immediate_dominators() +{ + // Traverse the post-order in reverse and build up the immediate dominator tree. + immediate_dominators.clear(); + immediate_dominators[func.entry_block] = func.entry_block; + + for (auto i = post_order.size(); i; i--) + { + uint32_t block = post_order[i - 1]; + auto &pred = preceding_edges[block]; + if (pred.empty()) // This is for the entry block, but we've already set up the dominators. + continue; + + for (auto &edge : pred) + { + if (immediate_dominators[block]) + { + assert(immediate_dominators[edge]); + immediate_dominators[block] = find_common_dominator(immediate_dominators[block], edge); + } + else + immediate_dominators[block] = edge; + } + } +} + +bool CFG::is_back_edge(uint32_t to) const +{ + // We have a back edge if the visit order is set with the temporary magic value 0. + // Crossing edges will have already been recorded with a visit order. + auto itr = visit_order.find(to); + return itr != end(visit_order) && itr->second.get() == 0; +} + +bool CFG::has_visited_forward_edge(uint32_t to) const +{ + // If > 0, we have visited the edge already, and this is not a back edge branch. + auto itr = visit_order.find(to); + return itr != end(visit_order) && itr->second.get() > 0; +} + +bool CFG::post_order_visit(uint32_t block_id) +{ + // If we have already branched to this block (back edge), stop recursion. + // If our branches are back-edges, we do not record them. + // We have to record crossing edges however. + if (has_visited_forward_edge(block_id)) + return true; + else if (is_back_edge(block_id)) + return false; + + // Block back-edges from recursively revisiting ourselves. + visit_order[block_id].get() = 0; + + auto &block = compiler.get(block_id); + + // If this is a loop header, add an implied branch to the merge target. + // This is needed to avoid annoying cases with do { ... } while(false) loops often generated by inliners. + // To the CFG, this is linear control flow, but we risk picking the do/while scope as our dominating block. + // This makes sure that if we are accessing a variable outside the do/while, we choose the loop header as dominator. + // We could use has_visited_forward_edge, but this break code-gen where the merge block is unreachable in the CFG. + + // Make a point out of visiting merge target first. 
This is to make sure that post visit order outside the loop + // is lower than inside the loop, which is going to be key for some traversal algorithms like post-dominance analysis. + // For selection constructs true/false blocks will end up visiting the merge block directly and it works out fine, + // but for loops, only the header might end up actually branching to merge block. + if (block.merge == SPIRBlock::MergeLoop && post_order_visit(block.merge_block)) + add_branch(block_id, block.merge_block); + + // First visit our branch targets. + switch (block.terminator) + { + case SPIRBlock::Direct: + if (post_order_visit(block.next_block)) + add_branch(block_id, block.next_block); + break; + + case SPIRBlock::Select: + if (post_order_visit(block.true_block)) + add_branch(block_id, block.true_block); + if (post_order_visit(block.false_block)) + add_branch(block_id, block.false_block); + break; + + case SPIRBlock::MultiSelect: + { + const auto &cases = compiler.get_case_list(block); + for (const auto &target : cases) + { + if (post_order_visit(target.block)) + add_branch(block_id, target.block); + } + if (block.default_block && post_order_visit(block.default_block)) + add_branch(block_id, block.default_block); + break; + } + default: + break; + } + + // If this is a selection merge, add an implied branch to the merge target. + // This is needed to avoid cases where an inner branch dominates the outer branch. + // This can happen if one of the branches exit early, e.g.: + // if (cond) { ...; break; } else { var = 100 } use_var(var); + // We can use the variable without a Phi since there is only one possible parent here. + // However, in this case, we need to hoist out the inner variable to outside the branch. + // Use same strategy as loops. + if (block.merge == SPIRBlock::MergeSelection && post_order_visit(block.next_block)) + { + // If there is only one preceding edge to the merge block and it's not ourselves, we need a fixup. + // Add a fake branch so any dominator in either the if (), or else () block, or a lone case statement + // will be hoisted out to outside the selection merge. + // If size > 1, the variable will be automatically hoisted, so we should not mess with it. + // The exception here is switch blocks, where we can have multiple edges to merge block, + // all coming from same scope, so be more conservative in this case. + // Adding fake branches unconditionally breaks parameter preservation analysis, + // which looks at how variables are accessed through the CFG. + auto pred_itr = preceding_edges.find(block.next_block); + if (pred_itr != end(preceding_edges)) + { + auto &pred = pred_itr->second; + auto succ_itr = succeeding_edges.find(block_id); + size_t num_succeeding_edges = 0; + if (succ_itr != end(succeeding_edges)) + num_succeeding_edges = succ_itr->second.size(); + + if (block.terminator == SPIRBlock::MultiSelect && num_succeeding_edges == 1) + { + // Multiple branches can come from the same scope due to "break;", so we need to assume that all branches + // come from same case scope in worst case, even if there are multiple preceding edges. + // If we have more than one succeeding edge from the block header, it should be impossible + // to have a dominator be inside the block. + // Only case this can go wrong is if we have 2 or more edges from block header and + // 2 or more edges to merge block, and still have dominator be inside a case label. 
+ if (!pred.empty()) + add_branch(block_id, block.next_block); + } + else + { + if (pred.size() == 1 && *pred.begin() != block_id) + add_branch(block_id, block.next_block); + } + } + else + { + // If the merge block does not have any preceding edges, i.e. unreachable, hallucinate it. + // We're going to do code-gen for it, and domination analysis requires that we have at least one preceding edge. + add_branch(block_id, block.next_block); + } + } + + // Then visit ourselves. Start counting at one, to let 0 be a magic value for testing back vs. crossing edges. + visit_order[block_id].get() = ++visit_count; + post_order.push_back(block_id); + return true; +} + +void CFG::build_post_order_visit_order() +{ + uint32_t block = func.entry_block; + visit_count = 0; + visit_order.clear(); + post_order.clear(); + post_order_visit(block); +} + +void CFG::add_branch(uint32_t from, uint32_t to) +{ + const auto add_unique = [](SmallVector &l, uint32_t value) { + auto itr = find(begin(l), end(l), value); + if (itr == end(l)) + l.push_back(value); + }; + add_unique(preceding_edges[to], from); + add_unique(succeeding_edges[from], to); +} + +uint32_t CFG::find_loop_dominator(uint32_t block_id) const +{ + while (block_id != SPIRBlock::NoDominator) + { + auto itr = preceding_edges.find(block_id); + if (itr == end(preceding_edges)) + return SPIRBlock::NoDominator; + if (itr->second.empty()) + return SPIRBlock::NoDominator; + + uint32_t pred_block_id = SPIRBlock::NoDominator; + bool ignore_loop_header = false; + + // If we are a merge block, go directly to the header block. + // Only consider a loop dominator if we are branching from inside a block to a loop header. + // NOTE: In the CFG we forced an edge from header to merge block always to support variable scopes properly. + for (auto &pred : itr->second) + { + auto &pred_block = compiler.get(pred); + if (pred_block.merge == SPIRBlock::MergeLoop && pred_block.merge_block == ID(block_id)) + { + pred_block_id = pred; + ignore_loop_header = true; + break; + } + else if (pred_block.merge == SPIRBlock::MergeSelection && pred_block.next_block == ID(block_id)) + { + pred_block_id = pred; + break; + } + } + + // No merge block means we can just pick any edge. Loop headers dominate the inner loop, so any path we + // take will lead there. + if (pred_block_id == SPIRBlock::NoDominator) + pred_block_id = itr->second.front(); + + block_id = pred_block_id; + + if (!ignore_loop_header && block_id) + { + auto &block = compiler.get(block_id); + if (block.merge == SPIRBlock::MergeLoop) + return block_id; + } + } + + return block_id; +} + +bool CFG::node_terminates_control_flow_in_sub_graph(BlockID from, BlockID to) const +{ + // Walk backwards, starting from "to" block. + // Only follow pred edges if they have a 1:1 relationship, or a merge relationship. + // If we cannot find a path to "from", we must assume that to is inside control flow in some way. 
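As an illustration of the dominator queries this backward walk relies on (CFG::find_common_dominator above and the DominatorBuilder used below): the core step is essentially the post-order "intersect" walk from the simple dominance algorithm of Cooper, Harvey and Kennedy, which lifts whichever node has the smaller post-order index up to its immediate dominator until both sides meet. The following is only a sketch and not part of the vendored sources; idom and post_index are hypothetical stand-ins for the CFG's internal tables.

#include <cstdint>
#include <unordered_map>

// Sketch: common dominator of a and b, given immediate dominators and
// post-order indices for every reachable block (hypothetical inputs).
static uint32_t common_dominator(uint32_t a, uint32_t b,
                                 const std::unordered_map<uint32_t, uint32_t> &idom,
                                 const std::unordered_map<uint32_t, uint32_t> &post_index)
{
	while (a != b)
	{
		if (post_index.at(a) < post_index.at(b))
			a = idom.at(a); // a sits deeper in the post order, walk it up
		else
			b = idom.at(b);
	}
	return a;
}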
+ + auto &from_block = compiler.get(from); + BlockID ignore_block_id = 0; + if (from_block.merge == SPIRBlock::MergeLoop) + ignore_block_id = from_block.merge_block; + + while (to != from) + { + auto pred_itr = preceding_edges.find(to); + if (pred_itr == end(preceding_edges)) + return false; + + DominatorBuilder builder(*this); + for (auto &edge : pred_itr->second) + builder.add_block(edge); + + uint32_t dominator = builder.get_dominator(); + if (dominator == 0) + return false; + + auto &dom = compiler.get(dominator); + + bool true_path_ignore = false; + bool false_path_ignore = false; + + bool merges_to_nothing = dom.merge == SPIRBlock::MergeNone || + (dom.merge == SPIRBlock::MergeSelection && dom.next_block && + compiler.get(dom.next_block).terminator == SPIRBlock::Unreachable) || + (dom.merge == SPIRBlock::MergeLoop && dom.merge_block && + compiler.get(dom.merge_block).terminator == SPIRBlock::Unreachable); + + if (dom.self == from || merges_to_nothing) + { + // We can only ignore inner branchy paths if there is no merge, + // i.e. no code is generated afterwards. E.g. this allows us to elide continue: + // for (;;) { if (cond) { continue; } else { break; } }. + // Codegen here in SPIR-V will be something like either no merge if one path directly breaks, or + // we merge to Unreachable. + if (ignore_block_id && dom.terminator == SPIRBlock::Select) + { + auto &true_block = compiler.get(dom.true_block); + auto &false_block = compiler.get(dom.false_block); + auto &ignore_block = compiler.get(ignore_block_id); + true_path_ignore = compiler.execution_is_branchless(true_block, ignore_block); + false_path_ignore = compiler.execution_is_branchless(false_block, ignore_block); + } + } + + // Cases where we allow traversal. This serves as a proxy for post-dominance in a loop body. + // TODO: Might want to do full post-dominance analysis, but it's a lot of churn for something like this ... + // - We're the merge block of a selection construct. Jump to header. + // - We're the merge block of a loop. Jump to header. + // - Direct branch. Trivial. + // - Allow cases inside a branch if the header cannot merge execution before loop exit. + if ((dom.merge == SPIRBlock::MergeSelection && dom.next_block == to) || + (dom.merge == SPIRBlock::MergeLoop && dom.merge_block == to) || + (dom.terminator == SPIRBlock::Direct && dom.next_block == to) || + (dom.terminator == SPIRBlock::Select && dom.true_block == to && false_path_ignore) || + (dom.terminator == SPIRBlock::Select && dom.false_block == to && true_path_ignore)) + { + // Allow walking selection constructs if the other branch reaches out of a loop construct. + // It cannot be in-scope anymore. + to = dominator; + } + else + return false; + } + + return true; +} + +DominatorBuilder::DominatorBuilder(const CFG &cfg_) + : cfg(cfg_) +{ +} + +void DominatorBuilder::add_block(uint32_t block) +{ + if (!cfg.get_immediate_dominator(block)) + { + // Unreachable block via the CFG, we will never emit this code anyways. + return; + } + + if (!dominator) + { + dominator = block; + return; + } + + if (block != dominator) + dominator = cfg.find_common_dominator(block, dominator); +} + +void DominatorBuilder::lift_continue_block_dominator() +{ + // It is possible for a continue block to be the dominator of a variable is only accessed inside the while block of a do-while loop. + // We cannot safely declare variables inside a continue block, so move any variable declared + // in a continue block to the entry block to simplify. 
+ // It makes very little sense for a continue block to ever be a dominator, so fall back to the simplest + // solution. + + if (!dominator) + return; + + auto &block = cfg.get_compiler().get(dominator); + auto post_order = cfg.get_visit_order(dominator); + + // If we are branching to a block with a higher post-order traversal index (continue blocks), we have a problem + // since we cannot create sensible GLSL code for this, fallback to entry block. + bool back_edge_dominator = false; + switch (block.terminator) + { + case SPIRBlock::Direct: + if (cfg.get_visit_order(block.next_block) > post_order) + back_edge_dominator = true; + break; + + case SPIRBlock::Select: + if (cfg.get_visit_order(block.true_block) > post_order) + back_edge_dominator = true; + if (cfg.get_visit_order(block.false_block) > post_order) + back_edge_dominator = true; + break; + + case SPIRBlock::MultiSelect: + { + auto &cases = cfg.get_compiler().get_case_list(block); + for (auto &target : cases) + { + if (cfg.get_visit_order(target.block) > post_order) + back_edge_dominator = true; + } + if (block.default_block && cfg.get_visit_order(block.default_block) > post_order) + back_edge_dominator = true; + break; + } + + default: + break; + } + + if (back_edge_dominator) + dominator = cfg.get_function().entry_block; +} +} // namespace SPIRV_CROSS_NAMESPACE diff --git a/thirdparty/spirv-cross/spirv_cfg.hpp b/thirdparty/spirv-cross/spirv_cfg.hpp new file mode 100644 index 00000000000..1d85fe0a97b --- /dev/null +++ b/thirdparty/spirv-cross/spirv_cfg.hpp @@ -0,0 +1,168 @@ +/* + * Copyright 2016-2021 Arm Limited + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . 
+ */ + +#ifndef SPIRV_CROSS_CFG_HPP +#define SPIRV_CROSS_CFG_HPP + +#include "spirv_common.hpp" +#include + +namespace SPIRV_CROSS_NAMESPACE +{ +class Compiler; +class CFG +{ +public: + CFG(Compiler &compiler, const SPIRFunction &function); + + Compiler &get_compiler() + { + return compiler; + } + + const Compiler &get_compiler() const + { + return compiler; + } + + const SPIRFunction &get_function() const + { + return func; + } + + uint32_t get_immediate_dominator(uint32_t block) const + { + auto itr = immediate_dominators.find(block); + if (itr != std::end(immediate_dominators)) + return itr->second; + else + return 0; + } + + bool is_reachable(uint32_t block) const + { + return visit_order.count(block) != 0; + } + + uint32_t get_visit_order(uint32_t block) const + { + auto itr = visit_order.find(block); + assert(itr != std::end(visit_order)); + int v = itr->second.get(); + assert(v > 0); + return uint32_t(v); + } + + uint32_t find_common_dominator(uint32_t a, uint32_t b) const; + + const SmallVector &get_preceding_edges(uint32_t block) const + { + auto itr = preceding_edges.find(block); + if (itr != std::end(preceding_edges)) + return itr->second; + else + return empty_vector; + } + + const SmallVector &get_succeeding_edges(uint32_t block) const + { + auto itr = succeeding_edges.find(block); + if (itr != std::end(succeeding_edges)) + return itr->second; + else + return empty_vector; + } + + template + void walk_from(std::unordered_set &seen_blocks, uint32_t block, const Op &op) const + { + if (seen_blocks.count(block)) + return; + seen_blocks.insert(block); + + if (op(block)) + { + for (auto b : get_succeeding_edges(block)) + walk_from(seen_blocks, b, op); + } + } + + uint32_t find_loop_dominator(uint32_t block) const; + + bool node_terminates_control_flow_in_sub_graph(BlockID from, BlockID to) const; + +private: + struct VisitOrder + { + int &get() + { + return v; + } + + const int &get() const + { + return v; + } + + int v = -1; + }; + + Compiler &compiler; + const SPIRFunction &func; + std::unordered_map> preceding_edges; + std::unordered_map> succeeding_edges; + std::unordered_map immediate_dominators; + std::unordered_map visit_order; + SmallVector post_order; + SmallVector empty_vector; + + void add_branch(uint32_t from, uint32_t to); + void build_post_order_visit_order(); + void build_immediate_dominators(); + bool post_order_visit(uint32_t block); + uint32_t visit_count = 0; + + bool is_back_edge(uint32_t to) const; + bool has_visited_forward_edge(uint32_t to) const; +}; + +class DominatorBuilder +{ +public: + DominatorBuilder(const CFG &cfg); + + void add_block(uint32_t block); + uint32_t get_dominator() const + { + return dominator; + } + + void lift_continue_block_dominator(); + +private: + const CFG &cfg; + uint32_t dominator = 0; +}; +} // namespace SPIRV_CROSS_NAMESPACE + +#endif diff --git a/thirdparty/spirv-cross/spirv_common.hpp b/thirdparty/spirv-cross/spirv_common.hpp new file mode 100644 index 00000000000..93b26697709 --- /dev/null +++ b/thirdparty/spirv-cross/spirv_common.hpp @@ -0,0 +1,1943 @@ +/* + * Copyright 2015-2021 Arm Limited + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . + */ + +#ifndef SPIRV_CROSS_COMMON_HPP +#define SPIRV_CROSS_COMMON_HPP + +#ifndef SPV_ENABLE_UTILITY_CODE +#define SPV_ENABLE_UTILITY_CODE +#endif +#include "spirv.hpp" + +#include "spirv_cross_containers.hpp" +#include "spirv_cross_error_handling.hpp" +#include + +// A bit crude, but allows projects which embed SPIRV-Cross statically to +// effectively hide all the symbols from other projects. +// There is a case where we have: +// - Project A links against SPIRV-Cross statically. +// - Project A links against Project B statically. +// - Project B links against SPIRV-Cross statically (might be a different version). +// This leads to a conflict with extremely bizarre results. +// By overriding the namespace in one of the project builds, we can work around this. +// If SPIRV-Cross is embedded in dynamic libraries, +// prefer using -fvisibility=hidden on GCC/Clang instead. +#ifdef SPIRV_CROSS_NAMESPACE_OVERRIDE +#define SPIRV_CROSS_NAMESPACE SPIRV_CROSS_NAMESPACE_OVERRIDE +#else +#define SPIRV_CROSS_NAMESPACE spirv_cross +#endif + +namespace SPIRV_CROSS_NAMESPACE +{ +namespace inner +{ +template +void join_helper(StringStream<> &stream, T &&t) +{ + stream << std::forward(t); +} + +template +void join_helper(StringStream<> &stream, T &&t, Ts &&... ts) +{ + stream << std::forward(t); + join_helper(stream, std::forward(ts)...); +} +} // namespace inner + +class Bitset +{ +public: + Bitset() = default; + explicit inline Bitset(uint64_t lower_) + : lower(lower_) + { + } + + inline bool get(uint32_t bit) const + { + if (bit < 64) + return (lower & (1ull << bit)) != 0; + else + return higher.count(bit) != 0; + } + + inline void set(uint32_t bit) + { + if (bit < 64) + lower |= 1ull << bit; + else + higher.insert(bit); + } + + inline void clear(uint32_t bit) + { + if (bit < 64) + lower &= ~(1ull << bit); + else + higher.erase(bit); + } + + inline uint64_t get_lower() const + { + return lower; + } + + inline void reset() + { + lower = 0; + higher.clear(); + } + + inline void merge_and(const Bitset &other) + { + lower &= other.lower; + std::unordered_set tmp_set; + for (auto &v : higher) + if (other.higher.count(v) != 0) + tmp_set.insert(v); + higher = std::move(tmp_set); + } + + inline void merge_or(const Bitset &other) + { + lower |= other.lower; + for (auto &v : other.higher) + higher.insert(v); + } + + inline bool operator==(const Bitset &other) const + { + if (lower != other.lower) + return false; + + if (higher.size() != other.higher.size()) + return false; + + for (auto &v : higher) + if (other.higher.count(v) == 0) + return false; + + return true; + } + + inline bool operator!=(const Bitset &other) const + { + return !(*this == other); + } + + template + void for_each_bit(const Op &op) const + { + // TODO: Add ctz-based iteration. 
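The TODO above refers to iterating set bits with a count-trailing-zeros loop instead of scanning all 64 positions. A minimal sketch of that idea, assuming a C++20 toolchain with std::countr_zero; this is an aside for illustration only and not part of the vendored header.

#include <bit>
#include <cstdint>

// Sketch: call op(bit_index) for every set bit of a 64-bit mask by repeatedly
// taking the lowest set bit's index and then clearing that bit.
template <typename Op>
void for_each_set_bit(uint64_t mask, const Op &op)
{
	while (mask != 0)
	{
		op(uint32_t(std::countr_zero(mask))); // index of the lowest set bit
		mask &= mask - 1;                     // clear the lowest set bit
	}
}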
+ for (uint32_t i = 0; i < 64; i++) + { + if (lower & (1ull << i)) + op(i); + } + + if (higher.empty()) + return; + + // Need to enforce an order here for reproducible results, + // but hitting this path should happen extremely rarely, so having this slow path is fine. + SmallVector bits; + bits.reserve(higher.size()); + for (auto &v : higher) + bits.push_back(v); + std::sort(std::begin(bits), std::end(bits)); + + for (auto &v : bits) + op(v); + } + + inline bool empty() const + { + return lower == 0 && higher.empty(); + } + +private: + // The most common bits to set are all lower than 64, + // so optimize for this case. Bits spilling outside 64 go into a slower data structure. + // In almost all cases, higher data structure will not be used. + uint64_t lower = 0; + std::unordered_set higher; +}; + +// Helper template to avoid lots of nasty string temporary munging. +template +std::string join(Ts &&... ts) +{ + StringStream<> stream; + inner::join_helper(stream, std::forward(ts)...); + return stream.str(); +} + +inline std::string merge(const SmallVector &list, const char *between = ", ") +{ + StringStream<> stream; + for (auto &elem : list) + { + stream << elem; + if (&elem != &list.back()) + stream << between; + } + return stream.str(); +} + +// Make sure we don't accidentally call this with float or doubles with SFINAE. +// Have to use the radix-aware overload. +template ::value, int>::type = 0> +inline std::string convert_to_string(const T &t) +{ + return std::to_string(t); +} + +static inline std::string convert_to_string(int32_t value) +{ + // INT_MIN is ... special on some backends. If we use a decimal literal, and negate it, we + // could accidentally promote the literal to long first, then negate. + // To workaround it, emit int(0x80000000) instead. + if (value == (std::numeric_limits::min)()) + return "int(0x80000000)"; + else + return std::to_string(value); +} + +static inline std::string convert_to_string(int64_t value, const std::string &int64_type, bool long_long_literal_suffix) +{ + // INT64_MIN is ... special on some backends. + // If we use a decimal literal, and negate it, we might overflow the representable numbers. + // To workaround it, emit int(0x80000000) instead. + if (value == (std::numeric_limits::min)()) + return join(int64_type, "(0x8000000000000000u", (long_long_literal_suffix ? "ll" : "l"), ")"); + else + return std::to_string(value) + (long_long_literal_suffix ? "ll" : "l"); +} + +// Allow implementations to set a convenient standard precision +#ifndef SPIRV_CROSS_FLT_FMT +#define SPIRV_CROSS_FLT_FMT "%.32g" +#endif + +// Disable sprintf and strcat warnings. +// We cannot rely on snprintf and family existing because, ..., MSVC. +#if defined(__clang__) || defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#elif defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + +static inline void fixup_radix_point(char *str, char radix_point) +{ + // Setting locales is a very risky business in multi-threaded program, + // so just fixup locales instead. We only need to care about the radix point. + if (radix_point != '.') + { + while (*str != '\0') + { + if (*str == radix_point) + *str = '.'; + str++; + } + } +} + +inline std::string convert_to_string(float t, char locale_radix_point) +{ + // std::to_string for floating point values is broken. + // Fallback to something more sane. 
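To make the comment above concrete: std::to_string for floating point is specified to behave like sprintf with "%f", so it uses a fixed six decimals and the current C locale's radix point, which respectively lose precision on small magnitudes and can emit a ',' that is not a valid shader literal. A small hypothetical demo of the failure mode the wide "%.32g" path below avoids:

#include <cstdio>
#include <string>

int main()
{
	// Fixed "%f" formatting collapses small magnitudes entirely.
	std::printf("%s\n", std::to_string(1e-9f).c_str()); // prints 0.000000
	// The wide "%.32g" format used via SPIRV_CROSS_FLT_FMT keeps the value.
	std::printf("%.32g\n", 1e-9);
	return 0;
}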
+ char buf[64]; + sprintf(buf, SPIRV_CROSS_FLT_FMT, t); + fixup_radix_point(buf, locale_radix_point); + + // Ensure that the literal is float. + if (!strchr(buf, '.') && !strchr(buf, 'e')) + strcat(buf, ".0"); + return buf; +} + +inline std::string convert_to_string(double t, char locale_radix_point) +{ + // std::to_string for floating point values is broken. + // Fallback to something more sane. + char buf[64]; + sprintf(buf, SPIRV_CROSS_FLT_FMT, t); + fixup_radix_point(buf, locale_radix_point); + + // Ensure that the literal is float. + if (!strchr(buf, '.') && !strchr(buf, 'e')) + strcat(buf, ".0"); + return buf; +} + +#if defined(__clang__) || defined(__GNUC__) +#pragma GCC diagnostic pop +#elif defined(_MSC_VER) +#pragma warning(pop) +#endif + +class FloatFormatter +{ +public: + virtual ~FloatFormatter() = default; + virtual std::string format_float(float value) = 0; + virtual std::string format_double(double value) = 0; +}; + +template +struct ValueSaver +{ + explicit ValueSaver(T ¤t_) + : current(current_) + , saved(current_) + { + } + + void release() + { + current = saved; + } + + ~ValueSaver() + { + release(); + } + + T ¤t; + T saved; +}; + +struct Instruction +{ + uint16_t op = 0; + uint16_t count = 0; + // If offset is 0 (not a valid offset into the instruction stream), + // we have an instruction stream which is embedded in the object. + uint32_t offset = 0; + uint32_t length = 0; + + inline bool is_embedded() const + { + return offset == 0; + } +}; + +struct EmbeddedInstruction : Instruction +{ + SmallVector ops; +}; + +enum Types +{ + TypeNone, + TypeType, + TypeVariable, + TypeConstant, + TypeFunction, + TypeFunctionPrototype, + TypeBlock, + TypeExtension, + TypeExpression, + TypeConstantOp, + TypeCombinedImageSampler, + TypeAccessChain, + TypeUndef, + TypeString, + TypeCount +}; + +template +class TypedID; + +template <> +class TypedID +{ +public: + TypedID() = default; + TypedID(uint32_t id_) + : id(id_) + { + } + + template + TypedID(const TypedID &other) + { + *this = other; + } + + template + TypedID &operator=(const TypedID &other) + { + id = uint32_t(other); + return *this; + } + + // Implicit conversion to u32 is desired here. + // As long as we block implicit conversion between TypedID and TypedID we're good. + operator uint32_t() const + { + return id; + } + + template + operator TypedID() const + { + return TypedID(*this); + } + +private: + uint32_t id = 0; +}; + +template +class TypedID +{ +public: + TypedID() = default; + TypedID(uint32_t id_) + : id(id_) + { + } + + explicit TypedID(const TypedID &other) + : id(uint32_t(other)) + { + } + + operator uint32_t() const + { + return id; + } + +private: + uint32_t id = 0; +}; + +using VariableID = TypedID; +using TypeID = TypedID; +using ConstantID = TypedID; +using FunctionID = TypedID; +using BlockID = TypedID; +using ID = TypedID; + +// Helper for Variant interface. 
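Before the Variant helpers that follow, a short usage sketch of the TypedID wrappers defined above (illustrative only, not taken from the library): each ID category becomes a distinct C++ type, so passing, for example, a TypeID where a VariableID is expected does not compile, while the untyped ID converts freely in both directions.

#include "spirv_common.hpp"

using namespace SPIRV_CROSS_NAMESPACE;

// Sketch: typed IDs interoperate with the generic ID, but not with each other.
static void typed_id_demo(VariableID var, TypeID type)
{
	ID generic = var;          // any typed ID converts to the generic ID
	VariableID same = generic; // and the generic ID converts back to a typed one
	// VariableID oops = type; // would not compile: distinct typed IDs never mix
	(void)same;
	(void)type;
}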
+struct IVariant +{ + virtual ~IVariant() = default; + virtual IVariant *clone(ObjectPoolBase *pool) = 0; + ID self = 0; + +protected: + IVariant() = default; + IVariant(const IVariant&) = default; + IVariant &operator=(const IVariant&) = default; +}; + +#define SPIRV_CROSS_DECLARE_CLONE(T) \ + IVariant *clone(ObjectPoolBase *pool) override \ + { \ + return static_cast *>(pool)->allocate(*this); \ + } + +struct SPIRUndef : IVariant +{ + enum + { + type = TypeUndef + }; + + explicit SPIRUndef(TypeID basetype_) + : basetype(basetype_) + { + } + TypeID basetype; + + SPIRV_CROSS_DECLARE_CLONE(SPIRUndef) +}; + +struct SPIRString : IVariant +{ + enum + { + type = TypeString + }; + + explicit SPIRString(std::string str_) + : str(std::move(str_)) + { + } + + std::string str; + + SPIRV_CROSS_DECLARE_CLONE(SPIRString) +}; + +// This type is only used by backends which need to access the combined image and sampler IDs separately after +// the OpSampledImage opcode. +struct SPIRCombinedImageSampler : IVariant +{ + enum + { + type = TypeCombinedImageSampler + }; + SPIRCombinedImageSampler(TypeID type_, VariableID image_, VariableID sampler_) + : combined_type(type_) + , image(image_) + , sampler(sampler_) + { + } + TypeID combined_type; + VariableID image; + VariableID sampler; + + SPIRV_CROSS_DECLARE_CLONE(SPIRCombinedImageSampler) +}; + +struct SPIRConstantOp : IVariant +{ + enum + { + type = TypeConstantOp + }; + + SPIRConstantOp(TypeID result_type, spv::Op op, const uint32_t *args, uint32_t length) + : opcode(op) + , basetype(result_type) + { + arguments.reserve(length); + for (uint32_t i = 0; i < length; i++) + arguments.push_back(args[i]); + } + + spv::Op opcode; + SmallVector arguments; + TypeID basetype; + + SPIRV_CROSS_DECLARE_CLONE(SPIRConstantOp) +}; + +struct SPIRType : IVariant +{ + enum + { + type = TypeType + }; + + spv::Op op = spv::Op::OpNop; + explicit SPIRType(spv::Op op_) : op(op_) {} + + enum BaseType + { + Unknown, + Void, + Boolean, + SByte, + UByte, + Short, + UShort, + Int, + UInt, + Int64, + UInt64, + AtomicCounter, + Half, + Float, + Double, + Struct, + Image, + SampledImage, + Sampler, + AccelerationStructure, + RayQuery, + + // Keep internal types at the end. + ControlPointArray, + Interpolant, + Char + }; + + // Scalar/vector/matrix support. + BaseType basetype = Unknown; + uint32_t width = 0; + uint32_t vecsize = 1; + uint32_t columns = 1; + + // Arrays, support array of arrays by having a vector of array sizes. + SmallVector array; + + // Array elements can be either specialization constants or specialization ops. + // This array determines how to interpret the array size. + // If an element is true, the element is a literal, + // otherwise, it's an expression, which must be resolved on demand. + // The actual size is not really known until runtime. + SmallVector array_size_literal; + + // Pointers + // Keep track of how many pointer layers we have. + uint32_t pointer_depth = 0; + bool pointer = false; + bool forward_pointer = false; + + spv::StorageClass storage = spv::StorageClassGeneric; + + SmallVector member_types; + + // If member order has been rewritten to handle certain scenarios with Offset, + // allow codegen to rewrite the index. + SmallVector member_type_index_redirection; + + struct ImageType + { + TypeID type; + spv::Dim dim; + bool depth; + bool arrayed; + bool ms; + uint32_t sampled; + spv::ImageFormat format; + spv::AccessQualifier access; + } image = {}; + + // Structs can be declared multiple times if they are used as part of interface blocks. 
+ // We want to detect this so that we only emit the struct definition once. + // Since we cannot rely on OpName to be equal, we need to figure out aliases. + TypeID type_alias = 0; + + // Denotes the type which this type is based on. + // Allows the backend to traverse how a complex type is built up during access chains. + TypeID parent_type = 0; + + // Used in backends to avoid emitting members with conflicting names. + std::unordered_set member_name_cache; + + SPIRV_CROSS_DECLARE_CLONE(SPIRType) +}; + +struct SPIRExtension : IVariant +{ + enum + { + type = TypeExtension + }; + + enum Extension + { + Unsupported, + GLSL, + SPV_debug_info, + SPV_AMD_shader_ballot, + SPV_AMD_shader_explicit_vertex_parameter, + SPV_AMD_shader_trinary_minmax, + SPV_AMD_gcn_shader, + NonSemanticDebugPrintf, + NonSemanticShaderDebugInfo, + NonSemanticGeneric + }; + + explicit SPIRExtension(Extension ext_) + : ext(ext_) + { + } + + Extension ext; + SPIRV_CROSS_DECLARE_CLONE(SPIRExtension) +}; + +// SPIREntryPoint is not a variant since its IDs are used to decorate OpFunction, +// so in order to avoid conflicts, we can't stick them in the ids array. +struct SPIREntryPoint +{ + SPIREntryPoint(FunctionID self_, spv::ExecutionModel execution_model, const std::string &entry_name) + : self(self_) + , name(entry_name) + , orig_name(entry_name) + , model(execution_model) + { + } + SPIREntryPoint() = default; + + FunctionID self = 0; + std::string name; + std::string orig_name; + SmallVector interface_variables; + + Bitset flags; + struct WorkgroupSize + { + uint32_t x = 0, y = 0, z = 0; + uint32_t id_x = 0, id_y = 0, id_z = 0; + uint32_t constant = 0; // Workgroup size can be expressed as a constant/spec-constant instead. + } workgroup_size; + uint32_t invocations = 0; + uint32_t output_vertices = 0; + uint32_t output_primitives = 0; + spv::ExecutionModel model = spv::ExecutionModelMax; + bool geometry_passthrough = false; +}; + +struct SPIRExpression : IVariant +{ + enum + { + type = TypeExpression + }; + + // Only created by the backend target to avoid creating tons of temporaries. + SPIRExpression(std::string expr, TypeID expression_type_, bool immutable_) + : expression(std::move(expr)) + , expression_type(expression_type_) + , immutable(immutable_) + { + } + + // If non-zero, prepend expression with to_expression(base_expression). + // Used in amortizing multiple calls to to_expression() + // where in certain cases that would quickly force a temporary when not needed. + ID base_expression = 0; + + std::string expression; + TypeID expression_type = 0; + + // If this expression is a forwarded load, + // allow us to reference the original variable. + ID loaded_from = 0; + + // If this expression will never change, we can avoid lots of temporaries + // in high level source. + // An expression being immutable can be speculative, + // it is assumed that this is true almost always. + bool immutable = false; + + // Before use, this expression must be transposed. + // This is needed for targets which don't support row_major layouts. + bool need_transpose = false; + + // Whether or not this is an access chain expression. + bool access_chain = false; + + // Whether or not gl_MeshVerticesEXT[].gl_Position (as a whole or .y) is referenced + bool access_meshlet_position_y = false; + + // A list of expressions which this expression depends on. + SmallVector expression_dependencies; + + // By reading this expression, we implicitly read these expressions as well. 
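The WorkgroupSize member of SPIREntryPoint records a compute shader's local size in one of three forms: literal x/y/z values, per-dimension specialization-constant IDs (id_x/id_y/id_z), or a single constant composite (`constant`). A hedged sketch of how a caller might report which form is in use; the function name is made up for illustration:

```cpp
#include "spirv_common.hpp"
#include <cstdio>

using namespace SPIRV_CROSS_NAMESPACE;

// Illustrative only: prints how the workgroup size of an entry point is expressed.
static void report_workgroup_size(const SPIREntryPoint &entry)
{
	const auto &wg = entry.workgroup_size;

	if (wg.constant != 0)
	{
		// The whole size comes from one (spec-)constant composite; the literal
		// x/y/z fields may not be meaningful in this case.
		printf("workgroup size driven by constant id %u\n", wg.constant);
	}
	else if (wg.id_x != 0 || wg.id_y != 0 || wg.id_z != 0)
	{
		printf("workgroup size driven by spec constants (%u, %u, %u)\n", wg.id_x, wg.id_y, wg.id_z);
	}
	else
	{
		printf("literal workgroup size (%u, %u, %u)\n", wg.x, wg.y, wg.z);
	}
}
```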
+ // Used by access chain Store and Load since we read multiple expressions in this case. + SmallVector implied_read_expressions; + + // The expression was emitted at a certain scope. Lets us track when an expression read means multiple reads. + uint32_t emitted_loop_level = 0; + + SPIRV_CROSS_DECLARE_CLONE(SPIRExpression) +}; + +struct SPIRFunctionPrototype : IVariant +{ + enum + { + type = TypeFunctionPrototype + }; + + explicit SPIRFunctionPrototype(TypeID return_type_) + : return_type(return_type_) + { + } + + TypeID return_type; + SmallVector parameter_types; + + SPIRV_CROSS_DECLARE_CLONE(SPIRFunctionPrototype) +}; + +struct SPIRBlock : IVariant +{ + enum + { + type = TypeBlock + }; + + enum Terminator + { + Unknown, + Direct, // Emit next block directly without a particular condition. + + Select, // Block ends with an if/else block. + MultiSelect, // Block ends with switch statement. + + Return, // Block ends with return. + Unreachable, // Noop + Kill, // Discard + IgnoreIntersection, // Ray Tracing + TerminateRay, // Ray Tracing + EmitMeshTasks // Mesh shaders + }; + + enum Merge + { + MergeNone, + MergeLoop, + MergeSelection + }; + + enum Hints + { + HintNone, + HintUnroll, + HintDontUnroll, + HintFlatten, + HintDontFlatten + }; + + enum Method + { + MergeToSelectForLoop, + MergeToDirectForLoop, + MergeToSelectContinueForLoop + }; + + enum ContinueBlockType + { + ContinueNone, + + // Continue block is branchless and has at least one instruction. + ForLoop, + + // Noop continue block. + WhileLoop, + + // Continue block is conditional. + DoWhileLoop, + + // Highly unlikely that anything will use this, + // since it is really awkward/impossible to express in GLSL. + ComplexLoop + }; + + enum : uint32_t + { + NoDominator = 0xffffffffu + }; + + Terminator terminator = Unknown; + Merge merge = MergeNone; + Hints hint = HintNone; + BlockID next_block = 0; + BlockID merge_block = 0; + BlockID continue_block = 0; + + ID return_value = 0; // If 0, return nothing (void). + ID condition = 0; + BlockID true_block = 0; + BlockID false_block = 0; + BlockID default_block = 0; + + // If terminator is EmitMeshTasksEXT. + struct + { + ID groups[3]; + ID payload; + } mesh = {}; + + SmallVector ops; + + struct Phi + { + ID local_variable; // flush local variable ... + BlockID parent; // If we're in from_block and want to branch into this block ... + VariableID function_variable; // to this function-global "phi" variable first. + }; + + // Before entering this block flush out local variables to magical "phi" variables. + SmallVector phi_variables; + + // Declare these temporaries before beginning the block. + // Used for handling complex continue blocks which have side effects. + SmallVector> declare_temporary; + + // Declare these temporaries, but only conditionally if this block turns out to be + // a complex loop header. + SmallVector> potential_declare_temporary; + + struct Case + { + uint64_t value; + BlockID block; + }; + SmallVector cases_32bit; + SmallVector cases_64bit; + + // If we have tried to optimize code for this block but failed, + // keep track of this. + bool disable_block_optimization = false; + + // If the continue block is complex, fallback to "dumb" for loops. + bool complex_continue = false; + + // Do we need a ladder variable to defer breaking out of a loop construct after a switch block? + bool need_ladder_break = false; + + // If marked, we have explicitly handled Phi from this block, so skip any flushes related to that on a branch. 
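SPIRBlock carries the structured control-flow information directly: a block ending in an if/else has terminator = Select plus condition, true_block and false_block, and records its selection merge in merge/merge_block. A minimal sketch of what those fields would hold for `if (c) { A } else { B }`; the block and condition IDs are placeholders supplied by the caller:

```cpp
#include "spirv_common.hpp"

using namespace SPIRV_CROSS_NAMESPACE;

// Illustrative only: fills in the fields a parser would set for a simple if/else.
static SPIRBlock make_selection_header(ID condition_id, BlockID then_id, BlockID else_id, BlockID merge_id)
{
	SPIRBlock block;
	block.terminator = SPIRBlock::Select;    // Ends with an if/else.
	block.condition = condition_id;          // The boolean driving the branch.
	block.true_block = then_id;              // Taken when the condition is true.
	block.false_block = else_id;             // Taken when the condition is false.
	block.merge = SPIRBlock::MergeSelection; // A selection merge was declared...
	block.merge_block = merge_id;            // ...and control re-joins here.
	return block;
}
```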
+ // Used to handle an edge case with switch and case-label fallthrough where fall-through writes to Phi. + BlockID ignore_phi_from_block = 0; + + // The dominating block which this block might be within. + // Used in continue; blocks to determine if we really need to write continue. + BlockID loop_dominator = 0; + + // All access to these variables are dominated by this block, + // so before branching anywhere we need to make sure that we declare these variables. + SmallVector dominated_variables; + + // These are variables which should be declared in a for loop header, if we + // fail to use a classic for-loop, + // we remove these variables, and fall back to regular variables outside the loop. + SmallVector loop_variables; + + // Some expressions are control-flow dependent, i.e. any instruction which relies on derivatives or + // sub-group-like operations. + // Make sure that we only use these expressions in the original block. + SmallVector invalidate_expressions; + + SPIRV_CROSS_DECLARE_CLONE(SPIRBlock) +}; + +struct SPIRFunction : IVariant +{ + enum + { + type = TypeFunction + }; + + SPIRFunction(TypeID return_type_, TypeID function_type_) + : return_type(return_type_) + , function_type(function_type_) + { + } + + struct Parameter + { + TypeID type; + ID id; + uint32_t read_count; + uint32_t write_count; + + // Set to true if this parameter aliases a global variable, + // used mostly in Metal where global variables + // have to be passed down to functions as regular arguments. + // However, for this kind of variable, we should not care about + // read and write counts as access to the function arguments + // is not local to the function in question. + bool alias_global_variable; + }; + + // When calling a function, and we're remapping separate image samplers, + // resolve these arguments into combined image samplers and pass them + // as additional arguments in this order. + // It gets more complicated as functions can pull in their own globals + // and combine them with parameters, + // so we need to distinguish if something is local parameter index + // or a global ID. + struct CombinedImageSamplerParameter + { + VariableID id; + VariableID image_id; + VariableID sampler_id; + bool global_image; + bool global_sampler; + bool depth; + }; + + TypeID return_type; + TypeID function_type; + SmallVector arguments; + + // Can be used by backends to add magic arguments. + // Currently used by combined image/sampler implementation. + + SmallVector shadow_arguments; + SmallVector local_variables; + BlockID entry_block = 0; + SmallVector blocks; + SmallVector combined_parameters; + + struct EntryLine + { + uint32_t file_id = 0; + uint32_t line_literal = 0; + }; + EntryLine entry_line; + + void add_local_variable(VariableID id) + { + local_variables.push_back(id); + } + + void add_parameter(TypeID parameter_type, ID id, bool alias_global_variable = false) + { + // Arguments are read-only until proven otherwise. + arguments.push_back({ parameter_type, id, 0u, 0u, alias_global_variable }); + } + + // Hooks to be run when the function returns. + // Mostly used for lowering internal data structures onto flattened structures. + // Need to defer this, because they might rely on things which change during compilation. + // Intentionally not a small vector, this one is rare, and std::function can be large. + Vector> fixup_hooks_out; + + // Hooks to be run when the function begins. + // Mostly used for populating internal data structures from flattened structures. 
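Parameter tracks read_count and write_count (plus alias_global_variable) so a backend can decide how to qualify an argument when emitting a function signature. One plausible mapping, sketched as a hypothetical helper rather than the actual backend logic:

```cpp
#include "spirv_common.hpp"

using namespace SPIRV_CROSS_NAMESPACE;

// Illustrative only: derives a GLSL-style qualifier from the access counters.
static const char *guess_parameter_qualifier(const SPIRFunction::Parameter &param)
{
	// Globals passed down as arguments (e.g. on Metal) are handled specially;
	// their counters are not meaningful, per the comment on alias_global_variable.
	if (param.alias_global_variable)
		return "";

	if (param.write_count == 0)
		return "in";    // Never written: plain input.
	if (param.read_count == 0)
		return "out";   // Written but never read: output only.
	return "inout";     // Both read and written.
}
```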
+ // Need to defer this, because they might rely on things which change during compilation. + // Intentionally not a small vector, this one is rare, and std::function can be large. + Vector> fixup_hooks_in; + + // On function entry, make sure to copy a constant array into thread addr space to work around + // the case where we are passing a constant array by value to a function on backends which do not + // consider arrays value types. + SmallVector constant_arrays_needed_on_stack; + + bool active = false; + bool flush_undeclared = true; + bool do_combined_parameters = true; + + SPIRV_CROSS_DECLARE_CLONE(SPIRFunction) +}; + +struct SPIRAccessChain : IVariant +{ + enum + { + type = TypeAccessChain + }; + + SPIRAccessChain(TypeID basetype_, spv::StorageClass storage_, std::string base_, std::string dynamic_index_, + int32_t static_index_) + : basetype(basetype_) + , storage(storage_) + , base(std::move(base_)) + , dynamic_index(std::move(dynamic_index_)) + , static_index(static_index_) + { + } + + // The access chain represents an offset into a buffer. + // Some backends need more complicated handling of access chains to be able to use buffers, like HLSL + // which has no usable buffer type ala GLSL SSBOs. + // StructuredBuffer is too limited, so our only option is to deal with ByteAddressBuffer which works with raw addresses. + + TypeID basetype; + spv::StorageClass storage; + std::string base; + std::string dynamic_index; + int32_t static_index; + + VariableID loaded_from = 0; + uint32_t matrix_stride = 0; + uint32_t array_stride = 0; + bool row_major_matrix = false; + bool immutable = false; + + // By reading this expression, we implicitly read these expressions as well. + // Used by access chain Store and Load since we read multiple expressions in this case. + SmallVector implied_read_expressions; + + SPIRV_CROSS_DECLARE_CLONE(SPIRAccessChain) +}; + +struct SPIRVariable : IVariant +{ + enum + { + type = TypeVariable + }; + + SPIRVariable() = default; + SPIRVariable(TypeID basetype_, spv::StorageClass storage_, ID initializer_ = 0, VariableID basevariable_ = 0) + : basetype(basetype_) + , storage(storage_) + , initializer(initializer_) + , basevariable(basevariable_) + { + } + + TypeID basetype = 0; + spv::StorageClass storage = spv::StorageClassGeneric; + uint32_t decoration = 0; + ID initializer = 0; + VariableID basevariable = 0; + + SmallVector dereference_chain; + bool compat_builtin = false; + + // If a variable is shadowed, we only statically assign to it + // and never actually emit a statement for it. + // When we read the variable as an expression, just forward + // shadowed_id as the expression. + bool statically_assigned = false; + ID static_expression = 0; + + // Temporaries which can remain forwarded as long as this variable is not modified. + SmallVector dependees; + + bool deferred_declaration = false; + bool phi_variable = false; + + // Used to deal with Phi variable flushes. See flush_phi(). + bool allocate_temporary_copy = false; + + bool remapped_variable = false; + uint32_t remapped_components = 0; + + // The block which dominates all access to this variable. + BlockID dominator = 0; + // If true, this variable is a loop variable, when accessing the variable + // outside a loop, + // we should statically forward it. + bool loop_variable = false; + // Set to true while we're inside the for loop. 
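SPIRAccessChain models a raw offset into a buffer for backends such as HLSL's ByteAddressBuffer path: `base` names the buffer expression, `dynamic_index` is a run-time expression fragment, and `static_index` is a known offset applied on top of it. A hedged sketch of constructing one and assembling a readable address string; the formatting, names and type ID are illustrative, not what the HLSL backend actually emits:

```cpp
#include "spirv_common.hpp"
#include <string>

using namespace SPIRV_CROSS_NAMESPACE;

// Illustrative only: builds a human-readable address from an access chain.
static std::string debug_address(const SPIRAccessChain &chain)
{
	std::string addr = chain.base + " @ ";
	if (!chain.dynamic_index.empty())
		addr += chain.dynamic_index + " + ";
	addr += std::to_string(chain.static_index) + " bytes";
	return addr;
}

static std::string access_chain_demo()
{
	// Type ID 12 and the StorageBuffer storage class are placeholders.
	SPIRAccessChain chain(TypeID(12), spv::StorageClassStorageBuffer,
	                      "ssbo_data", "16 * index", 8);
	return debug_address(chain); // "ssbo_data @ 16 * index + 8 bytes"
}
```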
+ bool loop_variable_enable = false; + + // Used to find global LUTs + bool is_written_to = false; + + SPIRFunction::Parameter *parameter = nullptr; + + SPIRV_CROSS_DECLARE_CLONE(SPIRVariable) +}; + +struct SPIRConstant : IVariant +{ + enum + { + type = TypeConstant + }; + + union Constant + { + uint32_t u32; + int32_t i32; + float f32; + + uint64_t u64; + int64_t i64; + double f64; + }; + + struct ConstantVector + { + Constant r[4]; + // If != 0, this element is a specialization constant, and we should keep track of it as such. + ID id[4]; + uint32_t vecsize = 1; + + ConstantVector() + { + memset(r, 0, sizeof(r)); + } + }; + + struct ConstantMatrix + { + ConstantVector c[4]; + // If != 0, this column is a specialization constant, and we should keep track of it as such. + ID id[4]; + uint32_t columns = 1; + }; + + static inline float f16_to_f32(uint16_t u16_value) + { + // Based on the GLM implementation. + int s = (u16_value >> 15) & 0x1; + int e = (u16_value >> 10) & 0x1f; + int m = (u16_value >> 0) & 0x3ff; + + union + { + float f32; + uint32_t u32; + } u; + + if (e == 0) + { + if (m == 0) + { + u.u32 = uint32_t(s) << 31; + return u.f32; + } + else + { + while ((m & 0x400) == 0) + { + m <<= 1; + e--; + } + + e++; + m &= ~0x400; + } + } + else if (e == 31) + { + if (m == 0) + { + u.u32 = (uint32_t(s) << 31) | 0x7f800000u; + return u.f32; + } + else + { + u.u32 = (uint32_t(s) << 31) | 0x7f800000u | (m << 13); + return u.f32; + } + } + + e += 127 - 15; + m <<= 13; + u.u32 = (uint32_t(s) << 31) | (e << 23) | m; + return u.f32; + } + + inline uint32_t specialization_constant_id(uint32_t col, uint32_t row) const + { + return m.c[col].id[row]; + } + + inline uint32_t specialization_constant_id(uint32_t col) const + { + return m.id[col]; + } + + inline uint32_t scalar(uint32_t col = 0, uint32_t row = 0) const + { + return m.c[col].r[row].u32; + } + + inline int16_t scalar_i16(uint32_t col = 0, uint32_t row = 0) const + { + return int16_t(m.c[col].r[row].u32 & 0xffffu); + } + + inline uint16_t scalar_u16(uint32_t col = 0, uint32_t row = 0) const + { + return uint16_t(m.c[col].r[row].u32 & 0xffffu); + } + + inline int8_t scalar_i8(uint32_t col = 0, uint32_t row = 0) const + { + return int8_t(m.c[col].r[row].u32 & 0xffu); + } + + inline uint8_t scalar_u8(uint32_t col = 0, uint32_t row = 0) const + { + return uint8_t(m.c[col].r[row].u32 & 0xffu); + } + + inline float scalar_f16(uint32_t col = 0, uint32_t row = 0) const + { + return f16_to_f32(scalar_u16(col, row)); + } + + inline float scalar_f32(uint32_t col = 0, uint32_t row = 0) const + { + return m.c[col].r[row].f32; + } + + inline int32_t scalar_i32(uint32_t col = 0, uint32_t row = 0) const + { + return m.c[col].r[row].i32; + } + + inline double scalar_f64(uint32_t col = 0, uint32_t row = 0) const + { + return m.c[col].r[row].f64; + } + + inline int64_t scalar_i64(uint32_t col = 0, uint32_t row = 0) const + { + return m.c[col].r[row].i64; + } + + inline uint64_t scalar_u64(uint32_t col = 0, uint32_t row = 0) const + { + return m.c[col].r[row].u64; + } + + inline const ConstantVector &vector() const + { + return m.c[0]; + } + + inline uint32_t vector_size() const + { + return m.c[0].vecsize; + } + + inline uint32_t columns() const + { + return m.columns; + } + + inline void make_null(const SPIRType &constant_type_) + { + m = {}; + m.columns = constant_type_.columns; + for (auto &c : m.c) + c.vecsize = constant_type_.vecsize; + } + + inline bool constant_is_null() const + { + if (specialization) + return false; + if (!subconstants.empty()) + 
return false; + + for (uint32_t col = 0; col < columns(); col++) + for (uint32_t row = 0; row < vector_size(); row++) + if (scalar_u64(col, row) != 0) + return false; + + return true; + } + + explicit SPIRConstant(uint32_t constant_type_) + : constant_type(constant_type_) + { + } + + SPIRConstant() = default; + + SPIRConstant(TypeID constant_type_, const uint32_t *elements, uint32_t num_elements, bool specialized) + : constant_type(constant_type_) + , specialization(specialized) + { + subconstants.reserve(num_elements); + for (uint32_t i = 0; i < num_elements; i++) + subconstants.push_back(elements[i]); + specialization = specialized; + } + + // Construct scalar (32-bit). + SPIRConstant(TypeID constant_type_, uint32_t v0, bool specialized) + : constant_type(constant_type_) + , specialization(specialized) + { + m.c[0].r[0].u32 = v0; + m.c[0].vecsize = 1; + m.columns = 1; + } + + // Construct scalar (64-bit). + SPIRConstant(TypeID constant_type_, uint64_t v0, bool specialized) + : constant_type(constant_type_) + , specialization(specialized) + { + m.c[0].r[0].u64 = v0; + m.c[0].vecsize = 1; + m.columns = 1; + } + + // Construct vectors and matrices. + SPIRConstant(TypeID constant_type_, const SPIRConstant *const *vector_elements, uint32_t num_elements, + bool specialized) + : constant_type(constant_type_) + , specialization(specialized) + { + bool matrix = vector_elements[0]->m.c[0].vecsize > 1; + + if (matrix) + { + m.columns = num_elements; + + for (uint32_t i = 0; i < num_elements; i++) + { + m.c[i] = vector_elements[i]->m.c[0]; + if (vector_elements[i]->specialization) + m.id[i] = vector_elements[i]->self; + } + } + else + { + m.c[0].vecsize = num_elements; + m.columns = 1; + + for (uint32_t i = 0; i < num_elements; i++) + { + m.c[0].r[i] = vector_elements[i]->m.c[0].r[0]; + if (vector_elements[i]->specialization) + m.c[0].id[i] = vector_elements[i]->self; + } + } + } + + TypeID constant_type = 0; + ConstantMatrix m; + + // If this constant is a specialization constant (i.e. created with OpSpecConstant*). + bool specialization = false; + // If this constant is used as an array length which creates specialization restrictions on some backends. + bool is_used_as_array_length = false; + + // If true, this is a LUT, and should always be declared in the outer scope. + bool is_used_as_lut = false; + + // For composites which are constant arrays, etc. + SmallVector subconstants; + + // Non-Vulkan GLSL, HLSL and sometimes MSL emits defines for each specialization constant, + // and uses them to initialize the constant. This allows the user + // to still be able to specialize the value by supplying corresponding + // preprocessor directives before compiling the shader. + std::string specialization_constant_macro_name; + + SPIRV_CROSS_DECLARE_CLONE(SPIRConstant) +}; + +// Variants have a very specific allocation scheme. +struct ObjectPoolGroup +{ + std::unique_ptr pools[TypeCount]; +}; + +class Variant +{ +public: + explicit Variant(ObjectPoolGroup *group_) + : group(group_) + { + } + + ~Variant() + { + if (holder) + group->pools[type]->deallocate_opaque(holder); + } + + // Marking custom move constructor as noexcept is important. + Variant(Variant &&other) SPIRV_CROSS_NOEXCEPT + { + *this = std::move(other); + } + + // We cannot copy from other variant without our own pool group. + // Have to explicitly copy. + Variant(const Variant &variant) = delete; + + // Marking custom move constructor as noexcept is important. 
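SPIRConstant stores up to a 4x4 block of scalars and exposes typed per-column/row accessors, together with the static half-float helper f16_to_f32. A small sketch that builds a scalar constant from raw bits and walks it with the accessors defined above; the type ID 7 is a placeholder:

```cpp
#include "spirv_common.hpp"
#include <cassert>
#include <cstdint>
#include <cstdio>

using namespace SPIRV_CROSS_NAMESPACE;

static void constant_demo()
{
	// The scalar 32-bit constructor takes the raw bit pattern: 0x3F800000 is 1.0f.
	SPIRConstant one(TypeID(7), 0x3F800000u, false);
	assert(one.scalar_f32() == 1.0f);
	assert(one.vector_size() == 1 && one.columns() == 1);

	// The half-float decoder can be used on its own.
	assert(SPIRConstant::f16_to_f32(0x3C00) == 1.0f);  // +1.0 in IEEE half.
	assert(SPIRConstant::f16_to_f32(0xC000) == -2.0f); // -2.0 in IEEE half.

	// Generic walk over whatever shape a constant has.
	for (uint32_t col = 0; col < one.columns(); col++)
		for (uint32_t row = 0; row < one.vector_size(); row++)
			printf("c[%u][%u] = %f\n", col, row, one.scalar_f32(col, row));
}
```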
+ Variant &operator=(Variant &&other) SPIRV_CROSS_NOEXCEPT + { + if (this != &other) + { + if (holder) + group->pools[type]->deallocate_opaque(holder); + holder = other.holder; + group = other.group; + type = other.type; + allow_type_rewrite = other.allow_type_rewrite; + + other.holder = nullptr; + other.type = TypeNone; + } + return *this; + } + + // This copy/clone should only be called in the Compiler constructor. + // If this is called inside ::compile(), we invalidate any references we took higher in the stack. + // This should never happen. + Variant &operator=(const Variant &other) + { +//#define SPIRV_CROSS_COPY_CONSTRUCTOR_SANITIZE +#ifdef SPIRV_CROSS_COPY_CONSTRUCTOR_SANITIZE + abort(); +#endif + if (this != &other) + { + if (holder) + group->pools[type]->deallocate_opaque(holder); + + if (other.holder) + holder = other.holder->clone(group->pools[other.type].get()); + else + holder = nullptr; + + type = other.type; + allow_type_rewrite = other.allow_type_rewrite; + } + return *this; + } + + void set(IVariant *val, Types new_type) + { + if (holder) + group->pools[type]->deallocate_opaque(holder); + holder = nullptr; + + if (!allow_type_rewrite && type != TypeNone && type != new_type) + { + if (val) + group->pools[new_type]->deallocate_opaque(val); + SPIRV_CROSS_THROW("Overwriting a variant with new type."); + } + + holder = val; + type = new_type; + allow_type_rewrite = false; + } + + template + T *allocate_and_set(Types new_type, Ts &&... ts) + { + T *val = static_cast &>(*group->pools[new_type]).allocate(std::forward(ts)...); + set(val, new_type); + return val; + } + + template + T &get() + { + if (!holder) + SPIRV_CROSS_THROW("nullptr"); + if (static_cast(T::type) != type) + SPIRV_CROSS_THROW("Bad cast"); + return *static_cast(holder); + } + + template + const T &get() const + { + if (!holder) + SPIRV_CROSS_THROW("nullptr"); + if (static_cast(T::type) != type) + SPIRV_CROSS_THROW("Bad cast"); + return *static_cast(holder); + } + + Types get_type() const + { + return type; + } + + ID get_id() const + { + return holder ? holder->self : ID(0); + } + + bool empty() const + { + return !holder; + } + + void reset() + { + if (holder) + group->pools[type]->deallocate_opaque(holder); + holder = nullptr; + type = TypeNone; + } + + void set_allow_type_rewrite() + { + allow_type_rewrite = true; + } + +private: + ObjectPoolGroup *group = nullptr; + IVariant *holder = nullptr; + Types type = TypeNone; + bool allow_type_rewrite = false; +}; + +template +T &variant_get(Variant &var) +{ + return var.get(); +} + +template +const T &variant_get(const Variant &var) +{ + return var.get(); +} + +template +T &variant_set(Variant &var, P &&... args) +{ + auto *ptr = var.allocate_and_set(static_cast(T::type), std::forward
(args)...); + return *ptr; +} + +struct AccessChainMeta +{ + uint32_t storage_physical_type = 0; + bool need_transpose = false; + bool storage_is_packed = false; + bool storage_is_invariant = false; + bool flattened_struct = false; + bool relaxed_precision = false; + bool access_meshlet_position_y = false; +}; + +enum ExtendedDecorations +{ + // Marks if a buffer block is re-packed, i.e. member declaration might be subject to PhysicalTypeID remapping and padding. + SPIRVCrossDecorationBufferBlockRepacked = 0, + + // A type in a buffer block might be declared with a different physical type than the logical type. + // If this is not set, PhysicalTypeID == the SPIR-V type as declared. + SPIRVCrossDecorationPhysicalTypeID, + + // Marks if the physical type is to be declared with tight packing rules, i.e. packed_floatN on MSL and friends. + // If this is set, PhysicalTypeID might also be set. It can be set to same as logical type if all we're doing + // is converting float3 to packed_float3 for example. + // If this is marked on a struct, it means the struct itself must use only Packed types for all its members. + SPIRVCrossDecorationPhysicalTypePacked, + + // The padding in bytes before declaring this struct member. + // If used on a struct type, marks the target size of a struct. + SPIRVCrossDecorationPaddingTarget, + + SPIRVCrossDecorationInterfaceMemberIndex, + SPIRVCrossDecorationInterfaceOrigID, + SPIRVCrossDecorationResourceIndexPrimary, + // Used for decorations like resource indices for samplers when part of combined image samplers. + // A variable might need to hold two resource indices in this case. + SPIRVCrossDecorationResourceIndexSecondary, + // Used for resource indices for multiplanar images when part of combined image samplers. + SPIRVCrossDecorationResourceIndexTertiary, + SPIRVCrossDecorationResourceIndexQuaternary, + + // Marks a buffer block for using explicit offsets (GLSL/HLSL). + SPIRVCrossDecorationExplicitOffset, + + // Apply to a variable in the Input storage class; marks it as holding the base group passed to vkCmdDispatchBase(), + // or the base vertex and instance indices passed to vkCmdDrawIndexed(). + // In MSL, this is used to adjust the WorkgroupId and GlobalInvocationId variables in compute shaders, + // and to hold the BaseVertex and BaseInstance variables in vertex shaders. + SPIRVCrossDecorationBuiltInDispatchBase, + + // Apply to a variable that is a function parameter; marks it as being a "dynamic" + // combined image-sampler. In MSL, this is used when a function parameter might hold + // either a regular combined image-sampler or one that has an attached sampler + // Y'CbCr conversion. + SPIRVCrossDecorationDynamicImageSampler, + + // Apply to a variable in the Input storage class; marks it as holding the size of the stage + // input grid. + // In MSL, this is used to hold the vertex and instance counts in a tessellation pipeline + // vertex shader. + SPIRVCrossDecorationBuiltInStageInputSize, + + // Apply to any access chain of a tessellation I/O variable; stores the type of the sub-object + // that was chained to, as recorded in the input variable itself. This is used in case the pointer + // is itself used as the base of an access chain, to calculate the original type of the sub-object + // chained to, in case a swizzle needs to be applied. This should not happen normally with valid + // SPIR-V, but the MSL backend can change the type of input variables, necessitating the + // addition of swizzles to keep the generated code compiling. 
+ SPIRVCrossDecorationTessIOOriginalInputTypeID, + + // Apply to any access chain of an interface variable used with pull-model interpolation, where the variable is a + // vector but the resulting pointer is a scalar; stores the component index that is to be accessed by the chain. + // This is used when emitting calls to interpolation functions on the chain in MSL: in this case, the component + // must be applied to the result, since pull-model interpolants in MSL cannot be swizzled directly, but the + // results of interpolation can. + SPIRVCrossDecorationInterpolantComponentExpr, + + // Apply to any struct type that is used in the Workgroup storage class. + // This causes matrices in MSL prior to Metal 3.0 to be emitted using a special + // class that is convertible to the standard matrix type, to work around the + // lack of constructors in the 'threadgroup' address space. + SPIRVCrossDecorationWorkgroupStruct, + + SPIRVCrossDecorationOverlappingBinding, + + SPIRVCrossDecorationCount +}; + +struct Meta +{ + struct Decoration + { + std::string alias; + std::string qualified_alias; + std::string hlsl_semantic; + std::string user_type; + Bitset decoration_flags; + spv::BuiltIn builtin_type = spv::BuiltInMax; + uint32_t location = 0; + uint32_t component = 0; + uint32_t set = 0; + uint32_t binding = 0; + uint32_t offset = 0; + uint32_t xfb_buffer = 0; + uint32_t xfb_stride = 0; + uint32_t stream = 0; + uint32_t array_stride = 0; + uint32_t matrix_stride = 0; + uint32_t input_attachment = 0; + uint32_t spec_id = 0; + uint32_t index = 0; + spv::FPRoundingMode fp_rounding_mode = spv::FPRoundingModeMax; + bool builtin = false; + bool qualified_alias_explicit_override = false; + + struct Extended + { + Extended() + { + // MSVC 2013 workaround to init like this. + for (auto &v : values) + v = 0; + } + + Bitset flags; + uint32_t values[SPIRVCrossDecorationCount]; + } extended; + }; + + Decoration decoration; + + // Intentionally not a SmallVector. Decoration is large and somewhat rare. + Vector members; + + std::unordered_map decoration_word_offset; + + // For SPV_GOOGLE_hlsl_functionality1. + bool hlsl_is_magic_counter_buffer = false; + // ID for the sibling counter buffer. + uint32_t hlsl_magic_counter_buffer = 0; +}; + +// A user callback that remaps the type of any variable. +// var_name is the declared name of the variable. +// name_of_type is the textual name of the type which will be used in the code unless written to by the callback. 
+using VariableTypeRemapCallback = + std::function; + +class Hasher +{ +public: + inline void u32(uint32_t value) + { + h = (h * 0x100000001b3ull) ^ value; + } + + inline uint64_t get() const + { + return h; + } + +private: + uint64_t h = 0xcbf29ce484222325ull; +}; + +static inline bool type_is_floating_point(const SPIRType &type) +{ + return type.basetype == SPIRType::Half || type.basetype == SPIRType::Float || type.basetype == SPIRType::Double; +} + +static inline bool type_is_integral(const SPIRType &type) +{ + return type.basetype == SPIRType::SByte || type.basetype == SPIRType::UByte || type.basetype == SPIRType::Short || + type.basetype == SPIRType::UShort || type.basetype == SPIRType::Int || type.basetype == SPIRType::UInt || + type.basetype == SPIRType::Int64 || type.basetype == SPIRType::UInt64; +} + +static inline SPIRType::BaseType to_signed_basetype(uint32_t width) +{ + switch (width) + { + case 8: + return SPIRType::SByte; + case 16: + return SPIRType::Short; + case 32: + return SPIRType::Int; + case 64: + return SPIRType::Int64; + default: + SPIRV_CROSS_THROW("Invalid bit width."); + } +} + +static inline SPIRType::BaseType to_unsigned_basetype(uint32_t width) +{ + switch (width) + { + case 8: + return SPIRType::UByte; + case 16: + return SPIRType::UShort; + case 32: + return SPIRType::UInt; + case 64: + return SPIRType::UInt64; + default: + SPIRV_CROSS_THROW("Invalid bit width."); + } +} + +// Returns true if an arithmetic operation does not change behavior depending on signedness. +static inline bool opcode_is_sign_invariant(spv::Op opcode) +{ + switch (opcode) + { + case spv::OpIEqual: + case spv::OpINotEqual: + case spv::OpISub: + case spv::OpIAdd: + case spv::OpIMul: + case spv::OpShiftLeftLogical: + case spv::OpBitwiseOr: + case spv::OpBitwiseXor: + case spv::OpBitwiseAnd: + return true; + + default: + return false; + } +} + +static inline bool opcode_can_promote_integer_implicitly(spv::Op opcode) +{ + switch (opcode) + { + case spv::OpSNegate: + case spv::OpNot: + case spv::OpBitwiseAnd: + case spv::OpBitwiseOr: + case spv::OpBitwiseXor: + case spv::OpShiftLeftLogical: + case spv::OpShiftRightLogical: + case spv::OpShiftRightArithmetic: + case spv::OpIAdd: + case spv::OpISub: + case spv::OpIMul: + case spv::OpSDiv: + case spv::OpUDiv: + case spv::OpSRem: + case spv::OpUMod: + case spv::OpSMod: + return true; + + default: + return false; + } +} + +struct SetBindingPair +{ + uint32_t desc_set; + uint32_t binding; + + inline bool operator==(const SetBindingPair &other) const + { + return desc_set == other.desc_set && binding == other.binding; + } + + inline bool operator<(const SetBindingPair &other) const + { + return desc_set < other.desc_set || (desc_set == other.desc_set && binding < other.binding); + } +}; + +struct LocationComponentPair +{ + uint32_t location; + uint32_t component; + + inline bool operator==(const LocationComponentPair &other) const + { + return location == other.location && component == other.component; + } + + inline bool operator<(const LocationComponentPair &other) const + { + return location < other.location || (location == other.location && component < other.component); + } +}; + +struct StageSetBinding +{ + spv::ExecutionModel model; + uint32_t desc_set; + uint32_t binding; + + inline bool operator==(const StageSetBinding &other) const + { + return model == other.model && desc_set == other.desc_set && binding == other.binding; + } +}; + +struct InternalHasher +{ + inline size_t operator()(const SetBindingPair &value) const + { + // Quality of 
hash doesn't really matter here. + auto hash_set = std::hash()(value.desc_set); + auto hash_binding = std::hash()(value.binding); + return (hash_set * 0x10001b31) ^ hash_binding; + } + + inline size_t operator()(const LocationComponentPair &value) const + { + // Quality of hash doesn't really matter here. + auto hash_set = std::hash()(value.location); + auto hash_binding = std::hash()(value.component); + return (hash_set * 0x10001b31) ^ hash_binding; + } + + inline size_t operator()(const StageSetBinding &value) const + { + // Quality of hash doesn't really matter here. + auto hash_model = std::hash()(value.model); + auto hash_set = std::hash()(value.desc_set); + auto tmp_hash = (hash_model * 0x10001b31) ^ hash_set; + return (tmp_hash * 0x10001b31) ^ value.binding; + } +}; + +// Special constant used in a {MSL,HLSL}ResourceBinding desc_set +// element to indicate the bindings for the push constants. +static const uint32_t ResourceBindingPushConstantDescriptorSet = ~(0u); + +// Special constant used in a {MSL,HLSL}ResourceBinding binding +// element to indicate the bindings for the push constants. +static const uint32_t ResourceBindingPushConstantBinding = 0; +} // namespace SPIRV_CROSS_NAMESPACE + +namespace std +{ +template +struct hash> +{ + size_t operator()(const SPIRV_CROSS_NAMESPACE::TypedID &value) const + { + return std::hash()(value); + } +}; +} // namespace std + +#endif diff --git a/thirdparty/spirv-cross/spirv_cross.cpp b/thirdparty/spirv-cross/spirv_cross.cpp new file mode 100644 index 00000000000..8c3e7d38120 --- /dev/null +++ b/thirdparty/spirv-cross/spirv_cross.cpp @@ -0,0 +1,5668 @@ +/* + * Copyright 2015-2021 Arm Limited + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . 
+ */ + +#include "spirv_cross.hpp" +#include "GLSL.std.450.h" +#include "spirv_cfg.hpp" +#include "spirv_common.hpp" +#include "spirv_parser.hpp" +#include +#include +#include + +using namespace std; +using namespace spv; +using namespace SPIRV_CROSS_NAMESPACE; + +Compiler::Compiler(vector ir_) +{ + Parser parser(std::move(ir_)); + parser.parse(); + set_ir(std::move(parser.get_parsed_ir())); +} + +Compiler::Compiler(const uint32_t *ir_, size_t word_count) +{ + Parser parser(ir_, word_count); + parser.parse(); + set_ir(std::move(parser.get_parsed_ir())); +} + +Compiler::Compiler(const ParsedIR &ir_) +{ + set_ir(ir_); +} + +Compiler::Compiler(ParsedIR &&ir_) +{ + set_ir(std::move(ir_)); +} + +void Compiler::set_ir(ParsedIR &&ir_) +{ + ir = std::move(ir_); + parse_fixup(); +} + +void Compiler::set_ir(const ParsedIR &ir_) +{ + ir = ir_; + parse_fixup(); +} + +string Compiler::compile() +{ + return ""; +} + +bool Compiler::variable_storage_is_aliased(const SPIRVariable &v) +{ + auto &type = get(v.basetype); + bool ssbo = v.storage == StorageClassStorageBuffer || + ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock); + bool image = type.basetype == SPIRType::Image; + bool counter = type.basetype == SPIRType::AtomicCounter; + bool buffer_reference = type.storage == StorageClassPhysicalStorageBufferEXT; + + bool is_restrict; + if (ssbo) + is_restrict = ir.get_buffer_block_flags(v).get(DecorationRestrict); + else + is_restrict = has_decoration(v.self, DecorationRestrict); + + return !is_restrict && (ssbo || image || counter || buffer_reference); +} + +bool Compiler::block_is_control_dependent(const SPIRBlock &block) +{ + for (auto &i : block.ops) + { + auto ops = stream(i); + auto op = static_cast(i.op); + + switch (op) + { + case OpFunctionCall: + { + uint32_t func = ops[2]; + if (function_is_control_dependent(get(func))) + return true; + break; + } + + // Derivatives + case OpDPdx: + case OpDPdxCoarse: + case OpDPdxFine: + case OpDPdy: + case OpDPdyCoarse: + case OpDPdyFine: + case OpFwidth: + case OpFwidthCoarse: + case OpFwidthFine: + + // Anything implicit LOD + case OpImageSampleImplicitLod: + case OpImageSampleDrefImplicitLod: + case OpImageSampleProjImplicitLod: + case OpImageSampleProjDrefImplicitLod: + case OpImageSparseSampleImplicitLod: + case OpImageSparseSampleDrefImplicitLod: + case OpImageSparseSampleProjImplicitLod: + case OpImageSparseSampleProjDrefImplicitLod: + case OpImageQueryLod: + case OpImageDrefGather: + case OpImageGather: + case OpImageSparseDrefGather: + case OpImageSparseGather: + + // Anything subgroups + case OpGroupNonUniformElect: + case OpGroupNonUniformAll: + case OpGroupNonUniformAny: + case OpGroupNonUniformAllEqual: + case OpGroupNonUniformBroadcast: + case OpGroupNonUniformBroadcastFirst: + case OpGroupNonUniformBallot: + case OpGroupNonUniformInverseBallot: + case OpGroupNonUniformBallotBitExtract: + case OpGroupNonUniformBallotBitCount: + case OpGroupNonUniformBallotFindLSB: + case OpGroupNonUniformBallotFindMSB: + case OpGroupNonUniformShuffle: + case OpGroupNonUniformShuffleXor: + case OpGroupNonUniformShuffleUp: + case OpGroupNonUniformShuffleDown: + case OpGroupNonUniformIAdd: + case OpGroupNonUniformFAdd: + case OpGroupNonUniformIMul: + case OpGroupNonUniformFMul: + case OpGroupNonUniformSMin: + case OpGroupNonUniformUMin: + case OpGroupNonUniformFMin: + case OpGroupNonUniformSMax: + case OpGroupNonUniformUMax: + case OpGroupNonUniformFMax: + case OpGroupNonUniformBitwiseAnd: + case OpGroupNonUniformBitwiseOr: + case 
OpGroupNonUniformBitwiseXor: + case OpGroupNonUniformLogicalAnd: + case OpGroupNonUniformLogicalOr: + case OpGroupNonUniformLogicalXor: + case OpGroupNonUniformQuadBroadcast: + case OpGroupNonUniformQuadSwap: + + // Control barriers + case OpControlBarrier: + return true; + + default: + break; + } + } + + return false; +} + +bool Compiler::block_is_pure(const SPIRBlock &block) +{ + // This is a global side effect of the function. + if (block.terminator == SPIRBlock::Kill || + block.terminator == SPIRBlock::TerminateRay || + block.terminator == SPIRBlock::IgnoreIntersection || + block.terminator == SPIRBlock::EmitMeshTasks) + return false; + + for (auto &i : block.ops) + { + auto ops = stream(i); + auto op = static_cast(i.op); + + switch (op) + { + case OpFunctionCall: + { + uint32_t func = ops[2]; + if (!function_is_pure(get(func))) + return false; + break; + } + + case OpCopyMemory: + case OpStore: + { + auto &type = expression_type(ops[0]); + if (type.storage != StorageClassFunction) + return false; + break; + } + + case OpImageWrite: + return false; + + // Atomics are impure. + case OpAtomicLoad: + case OpAtomicStore: + case OpAtomicExchange: + case OpAtomicCompareExchange: + case OpAtomicCompareExchangeWeak: + case OpAtomicIIncrement: + case OpAtomicIDecrement: + case OpAtomicIAdd: + case OpAtomicISub: + case OpAtomicSMin: + case OpAtomicUMin: + case OpAtomicSMax: + case OpAtomicUMax: + case OpAtomicAnd: + case OpAtomicOr: + case OpAtomicXor: + return false; + + // Geometry shader builtins modify global state. + case OpEndPrimitive: + case OpEmitStreamVertex: + case OpEndStreamPrimitive: + case OpEmitVertex: + return false; + + // Mesh shader functions modify global state. + // (EmitMeshTasks is a terminator). + case OpSetMeshOutputsEXT: + return false; + + // Barriers disallow any reordering, so we should treat blocks with barrier as writing. + case OpControlBarrier: + case OpMemoryBarrier: + return false; + + // Ray tracing builtins are impure. + case OpReportIntersectionKHR: + case OpIgnoreIntersectionNV: + case OpTerminateRayNV: + case OpTraceNV: + case OpTraceRayKHR: + case OpExecuteCallableNV: + case OpExecuteCallableKHR: + case OpRayQueryInitializeKHR: + case OpRayQueryTerminateKHR: + case OpRayQueryGenerateIntersectionKHR: + case OpRayQueryConfirmIntersectionKHR: + case OpRayQueryProceedKHR: + // There are various getters in ray query, but they are considered pure. + return false; + + // OpExtInst is potentially impure depending on extension, but GLSL builtins are at least pure. + + case OpDemoteToHelperInvocationEXT: + // This is a global side effect of the function. + return false; + + case OpExtInst: + { + uint32_t extension_set = ops[2]; + if (get(extension_set).ext == SPIRExtension::GLSL) + { + auto op_450 = static_cast(ops[3]); + switch (op_450) + { + case GLSLstd450Modf: + case GLSLstd450Frexp: + { + auto &type = expression_type(ops[5]); + if (type.storage != StorageClassFunction) + return false; + break; + } + + default: + break; + } + } + break; + } + + default: + break; + } + } + + return true; +} + +string Compiler::to_name(uint32_t id, bool allow_alias) const +{ + if (allow_alias && ir.ids[id].get_type() == TypeType) + { + // If this type is a simple alias, emit the + // name of the original type instead. + // We don't want to override the meta alias + // as that can be overridden by the reflection APIs after parse. 
+ auto &type = get(id); + if (type.type_alias) + { + // If the alias master has been specially packed, we will have emitted a clean variant as well, + // so skip the name aliasing here. + if (!has_extended_decoration(type.type_alias, SPIRVCrossDecorationBufferBlockRepacked)) + return to_name(type.type_alias); + } + } + + auto &alias = ir.get_name(id); + if (alias.empty()) + return join("_", id); + else + return alias; +} + +bool Compiler::function_is_pure(const SPIRFunction &func) +{ + for (auto block : func.blocks) + if (!block_is_pure(get(block))) + return false; + + return true; +} + +bool Compiler::function_is_control_dependent(const SPIRFunction &func) +{ + for (auto block : func.blocks) + if (block_is_control_dependent(get(block))) + return true; + + return false; +} + +void Compiler::register_global_read_dependencies(const SPIRBlock &block, uint32_t id) +{ + for (auto &i : block.ops) + { + auto ops = stream(i); + auto op = static_cast(i.op); + + switch (op) + { + case OpFunctionCall: + { + uint32_t func = ops[2]; + register_global_read_dependencies(get(func), id); + break; + } + + case OpLoad: + case OpImageRead: + { + // If we're in a storage class which does not get invalidated, adding dependencies here is no big deal. + auto *var = maybe_get_backing_variable(ops[2]); + if (var && var->storage != StorageClassFunction) + { + auto &type = get(var->basetype); + + // InputTargets are immutable. + if (type.basetype != SPIRType::Image && type.image.dim != DimSubpassData) + var->dependees.push_back(id); + } + break; + } + + default: + break; + } + } +} + +void Compiler::register_global_read_dependencies(const SPIRFunction &func, uint32_t id) +{ + for (auto block : func.blocks) + register_global_read_dependencies(get(block), id); +} + +SPIRVariable *Compiler::maybe_get_backing_variable(uint32_t chain) +{ + auto *var = maybe_get(chain); + if (!var) + { + auto *cexpr = maybe_get(chain); + if (cexpr) + var = maybe_get(cexpr->loaded_from); + + auto *access_chain = maybe_get(chain); + if (access_chain) + var = maybe_get(access_chain->loaded_from); + } + + return var; +} + +void Compiler::register_read(uint32_t expr, uint32_t chain, bool forwarded) +{ + auto &e = get(expr); + auto *var = maybe_get_backing_variable(chain); + + if (var) + { + e.loaded_from = var->self; + + // If the backing variable is immutable, we do not need to depend on the variable. + if (forwarded && !is_immutable(var->self)) + var->dependees.push_back(e.self); + + // If we load from a parameter, make sure we create "inout" if we also write to the parameter. + // The default is "in" however, so we never invalidate our compilation by reading. + if (var && var->parameter) + var->parameter->read_count++; + } +} + +void Compiler::register_write(uint32_t chain) +{ + auto *var = maybe_get(chain); + if (!var) + { + // If we're storing through an access chain, invalidate the backing variable instead. + auto *expr = maybe_get(chain); + if (expr && expr->loaded_from) + var = maybe_get(expr->loaded_from); + + auto *access_chain = maybe_get(chain); + if (access_chain && access_chain->loaded_from) + var = maybe_get(access_chain->loaded_from); + } + + auto &chain_type = expression_type(chain); + + if (var) + { + bool check_argument_storage_qualifier = true; + auto &type = expression_type(chain); + + // If our variable is in a storage class which can alias with other buffers, + // invalidate all variables which depend on aliased variables. And if this is a + // variable pointer, then invalidate all variables regardless. 
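register_read and register_write implement the forwarding model: reading through a variable records the temporary in that variable's dependees, and a later store flushes those dependees into invalid_expressions so the backend knows it must re-materialize them. A simplified standalone sketch of that bookkeeping, independent of the Compiler class:

```cpp
#include <cstdint>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Simplified model of the dependee/invalidation scheme used above.
struct ForwardingState
{
	// Variable id -> forwarded expression ids that currently read from it.
	std::unordered_map<uint32_t, std::vector<uint32_t>> dependees;
	// Expressions that can no longer be reused verbatim.
	std::unordered_set<uint32_t> invalid_expressions;

	void register_read(uint32_t expr, uint32_t variable)
	{
		dependees[variable].push_back(expr);
	}

	void register_write(uint32_t variable)
	{
		// Everything forwarded from this variable is now stale.
		auto it = dependees.find(variable);
		if (it == dependees.end())
			return;
		for (uint32_t expr : it->second)
			invalid_expressions.insert(expr);
		it->second.clear();
	}
};
```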
+ if (get_variable_data_type(*var).pointer) + { + flush_all_active_variables(); + + if (type.pointer_depth == 1) + { + // We have a backing variable which is a pointer-to-pointer type. + // We are storing some data through a pointer acquired through that variable, + // but we are not writing to the value of the variable itself, + // i.e., we are not modifying the pointer directly. + // If we are storing a non-pointer type (pointer_depth == 1), + // we know that we are storing some unrelated data. + // A case here would be + // void foo(Foo * const *arg) { + // Foo *bar = *arg; + // bar->unrelated = 42; + // } + // arg, the argument is constant. + check_argument_storage_qualifier = false; + } + } + + if (type.storage == StorageClassPhysicalStorageBufferEXT || variable_storage_is_aliased(*var)) + flush_all_aliased_variables(); + else if (var) + flush_dependees(*var); + + // We tried to write to a parameter which is not marked with out qualifier, force a recompile. + if (check_argument_storage_qualifier && var->parameter && var->parameter->write_count == 0) + { + var->parameter->write_count++; + force_recompile(); + } + } + else if (chain_type.pointer) + { + // If we stored through a variable pointer, then we don't know which + // variable we stored to. So *all* expressions after this point need to + // be invalidated. + // FIXME: If we can prove that the variable pointer will point to + // only certain variables, we can invalidate only those. + flush_all_active_variables(); + } + + // If chain_type.pointer is false, we're not writing to memory backed variables, but temporaries instead. + // This can happen in copy_logical_type where we unroll complex reads and writes to temporaries. +} + +void Compiler::flush_dependees(SPIRVariable &var) +{ + for (auto expr : var.dependees) + invalid_expressions.insert(expr); + var.dependees.clear(); +} + +void Compiler::flush_all_aliased_variables() +{ + for (auto aliased : aliased_variables) + flush_dependees(get(aliased)); +} + +void Compiler::flush_all_atomic_capable_variables() +{ + for (auto global : global_variables) + flush_dependees(get(global)); + flush_all_aliased_variables(); +} + +void Compiler::flush_control_dependent_expressions(uint32_t block_id) +{ + auto &block = get(block_id); + for (auto &expr : block.invalidate_expressions) + invalid_expressions.insert(expr); + block.invalidate_expressions.clear(); +} + +void Compiler::flush_all_active_variables() +{ + // Invalidate all temporaries we read from variables in this block since they were forwarded. + // Invalidate all temporaries we read from globals. 
+ for (auto &v : current_function->local_variables) + flush_dependees(get(v)); + for (auto &arg : current_function->arguments) + flush_dependees(get(arg.id)); + for (auto global : global_variables) + flush_dependees(get(global)); + + flush_all_aliased_variables(); +} + +uint32_t Compiler::expression_type_id(uint32_t id) const +{ + switch (ir.ids[id].get_type()) + { + case TypeVariable: + return get(id).basetype; + + case TypeExpression: + return get(id).expression_type; + + case TypeConstant: + return get(id).constant_type; + + case TypeConstantOp: + return get(id).basetype; + + case TypeUndef: + return get(id).basetype; + + case TypeCombinedImageSampler: + return get(id).combined_type; + + case TypeAccessChain: + return get(id).basetype; + + default: + SPIRV_CROSS_THROW("Cannot resolve expression type."); + } +} + +const SPIRType &Compiler::expression_type(uint32_t id) const +{ + return get(expression_type_id(id)); +} + +bool Compiler::expression_is_lvalue(uint32_t id) const +{ + auto &type = expression_type(id); + switch (type.basetype) + { + case SPIRType::SampledImage: + case SPIRType::Image: + case SPIRType::Sampler: + return false; + + default: + return true; + } +} + +bool Compiler::is_immutable(uint32_t id) const +{ + if (ir.ids[id].get_type() == TypeVariable) + { + auto &var = get(id); + + // Anything we load from the UniformConstant address space is guaranteed to be immutable. + bool pointer_to_const = var.storage == StorageClassUniformConstant; + return pointer_to_const || var.phi_variable || !expression_is_lvalue(id); + } + else if (ir.ids[id].get_type() == TypeAccessChain) + return get(id).immutable; + else if (ir.ids[id].get_type() == TypeExpression) + return get(id).immutable; + else if (ir.ids[id].get_type() == TypeConstant || ir.ids[id].get_type() == TypeConstantOp || + ir.ids[id].get_type() == TypeUndef) + return true; + else + return false; +} + +static inline bool storage_class_is_interface(spv::StorageClass storage) +{ + switch (storage) + { + case StorageClassInput: + case StorageClassOutput: + case StorageClassUniform: + case StorageClassUniformConstant: + case StorageClassAtomicCounter: + case StorageClassPushConstant: + case StorageClassStorageBuffer: + return true; + + default: + return false; + } +} + +bool Compiler::is_hidden_variable(const SPIRVariable &var, bool include_builtins) const +{ + if ((is_builtin_variable(var) && !include_builtins) || var.remapped_variable) + return true; + + // Combined image samplers are always considered active as they are "magic" variables. + if (find_if(begin(combined_image_samplers), end(combined_image_samplers), [&var](const CombinedImageSampler &samp) { + return samp.combined_id == var.self; + }) != end(combined_image_samplers)) + { + return false; + } + + // In SPIR-V 1.4 and up we must also use the active variable interface to disable global variables + // which are not part of the entry point. + if (ir.get_spirv_version() >= 0x10400 && var.storage != spv::StorageClassGeneric && + var.storage != spv::StorageClassFunction && !interface_variable_exists_in_entry_point(var.self)) + { + return true; + } + + return check_active_interface_variables && storage_class_is_interface(var.storage) && + active_interface_variables.find(var.self) == end(active_interface_variables); +} + +bool Compiler::is_builtin_type(const SPIRType &type) const +{ + auto *type_meta = ir.find_meta(type.self); + + // We can have builtin structs as well. If one member of a struct is builtin, the struct must also be builtin. 
+ if (type_meta) + for (auto &m : type_meta->members) + if (m.builtin) + return true; + + return false; +} + +bool Compiler::is_builtin_variable(const SPIRVariable &var) const +{ + auto *m = ir.find_meta(var.self); + + if (var.compat_builtin || (m && m->decoration.builtin)) + return true; + else + return is_builtin_type(get(var.basetype)); +} + +bool Compiler::is_member_builtin(const SPIRType &type, uint32_t index, BuiltIn *builtin) const +{ + auto *type_meta = ir.find_meta(type.self); + + if (type_meta) + { + auto &memb = type_meta->members; + if (index < memb.size() && memb[index].builtin) + { + if (builtin) + *builtin = memb[index].builtin_type; + return true; + } + } + + return false; +} + +bool Compiler::is_scalar(const SPIRType &type) const +{ + return type.basetype != SPIRType::Struct && type.vecsize == 1 && type.columns == 1; +} + +bool Compiler::is_vector(const SPIRType &type) const +{ + return type.vecsize > 1 && type.columns == 1; +} + +bool Compiler::is_matrix(const SPIRType &type) const +{ + return type.vecsize > 1 && type.columns > 1; +} + +bool Compiler::is_array(const SPIRType &type) const +{ + return type.op == OpTypeArray || type.op == OpTypeRuntimeArray; +} + +bool Compiler::is_pointer(const SPIRType &type) const +{ + return type.op == OpTypePointer && type.basetype != SPIRType::Unknown; // Ignore function pointers. +} + +bool Compiler::is_physical_pointer(const SPIRType &type) const +{ + return type.op == OpTypePointer && type.storage == StorageClassPhysicalStorageBuffer; +} + +bool Compiler::is_physical_pointer_to_buffer_block(const SPIRType &type) const +{ + return is_physical_pointer(type) && get_pointee_type(type).self == type.parent_type && + (has_decoration(type.self, DecorationBlock) || + has_decoration(type.self, DecorationBufferBlock)); +} + +bool Compiler::is_runtime_size_array(const SPIRType &type) +{ + return type.op == OpTypeRuntimeArray; +} + +ShaderResources Compiler::get_shader_resources() const +{ + return get_shader_resources(nullptr); +} + +ShaderResources Compiler::get_shader_resources(const unordered_set &active_variables) const +{ + return get_shader_resources(&active_variables); +} + +bool Compiler::InterfaceVariableAccessHandler::handle(Op opcode, const uint32_t *args, uint32_t length) +{ + uint32_t variable = 0; + switch (opcode) + { + // Need this first, otherwise, GCC complains about unhandled switch statements. + default: + break; + + case OpFunctionCall: + { + // Invalid SPIR-V. + if (length < 3) + return false; + + uint32_t count = length - 3; + args += 3; + for (uint32_t i = 0; i < count; i++) + { + auto *var = compiler.maybe_get(args[i]); + if (var && storage_class_is_interface(var->storage)) + variables.insert(args[i]); + } + break; + } + + case OpSelect: + { + // Invalid SPIR-V. + if (length < 5) + return false; + + uint32_t count = length - 3; + args += 3; + for (uint32_t i = 0; i < count; i++) + { + auto *var = compiler.maybe_get(args[i]); + if (var && storage_class_is_interface(var->storage)) + variables.insert(args[i]); + } + break; + } + + case OpPhi: + { + // Invalid SPIR-V. + if (length < 2) + return false; + + uint32_t count = length - 2; + args += 2; + for (uint32_t i = 0; i < count; i += 2) + { + auto *var = compiler.maybe_get(args[i]); + if (var && storage_class_is_interface(var->storage)) + variables.insert(args[i]); + } + break; + } + + case OpAtomicStore: + case OpStore: + // Invalid SPIR-V. 
+ if (length < 1) + return false; + variable = args[0]; + break; + + case OpCopyMemory: + { + if (length < 2) + return false; + + auto *var = compiler.maybe_get(args[0]); + if (var && storage_class_is_interface(var->storage)) + variables.insert(args[0]); + + var = compiler.maybe_get(args[1]); + if (var && storage_class_is_interface(var->storage)) + variables.insert(args[1]); + break; + } + + case OpExtInst: + { + if (length < 3) + return false; + auto &extension_set = compiler.get(args[2]); + switch (extension_set.ext) + { + case SPIRExtension::GLSL: + { + auto op = static_cast(args[3]); + + switch (op) + { + case GLSLstd450InterpolateAtCentroid: + case GLSLstd450InterpolateAtSample: + case GLSLstd450InterpolateAtOffset: + { + auto *var = compiler.maybe_get(args[4]); + if (var && storage_class_is_interface(var->storage)) + variables.insert(args[4]); + break; + } + + case GLSLstd450Modf: + case GLSLstd450Fract: + { + auto *var = compiler.maybe_get(args[5]); + if (var && storage_class_is_interface(var->storage)) + variables.insert(args[5]); + break; + } + + default: + break; + } + break; + } + case SPIRExtension::SPV_AMD_shader_explicit_vertex_parameter: + { + enum AMDShaderExplicitVertexParameter + { + InterpolateAtVertexAMD = 1 + }; + + auto op = static_cast(args[3]); + + switch (op) + { + case InterpolateAtVertexAMD: + { + auto *var = compiler.maybe_get(args[4]); + if (var && storage_class_is_interface(var->storage)) + variables.insert(args[4]); + break; + } + + default: + break; + } + break; + } + default: + break; + } + break; + } + + case OpAccessChain: + case OpInBoundsAccessChain: + case OpPtrAccessChain: + case OpLoad: + case OpCopyObject: + case OpImageTexelPointer: + case OpAtomicLoad: + case OpAtomicExchange: + case OpAtomicCompareExchange: + case OpAtomicCompareExchangeWeak: + case OpAtomicIIncrement: + case OpAtomicIDecrement: + case OpAtomicIAdd: + case OpAtomicISub: + case OpAtomicSMin: + case OpAtomicUMin: + case OpAtomicSMax: + case OpAtomicUMax: + case OpAtomicAnd: + case OpAtomicOr: + case OpAtomicXor: + case OpArrayLength: + // Invalid SPIR-V. + if (length < 3) + return false; + variable = args[2]; + break; + } + + if (variable) + { + auto *var = compiler.maybe_get(variable); + if (var && storage_class_is_interface(var->storage)) + variables.insert(variable); + } + return true; +} + +unordered_set Compiler::get_active_interface_variables() const +{ + // Traverse the call graph and find all interface variables which are in use. + unordered_set variables; + InterfaceVariableAccessHandler handler(*this, variables); + traverse_all_reachable_opcodes(get(ir.default_entry_point), handler); + + ir.for_each_typed_id([&](uint32_t, const SPIRVariable &var) { + if (var.storage != StorageClassOutput) + return; + if (!interface_variable_exists_in_entry_point(var.self)) + return; + + // An output variable which is just declared (but uninitialized) might be read by subsequent stages + // so we should force-enable these outputs, + // since compilation will fail if a subsequent stage attempts to read from the variable in question. + // Also, make sure we preserve output variables which are only initialized, but never accessed by any code. + if (var.initializer != ID(0) || get_execution_model() != ExecutionModelFragment) + variables.insert(var.self); + }); + + // If we needed to create one, we'll need it. 
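get_active_interface_variables(), set_enabled_interface_variables() and the get_shader_resources() overload taking an active set are designed to be used together, so reflection can skip bindings the entry point never actually touches. A sketch of that flow, assuming `spirv_words` already holds a valid SPIR-V module:

```cpp
#include "spirv_cross.hpp"
#include <cstdint>
#include <utility>
#include <vector>

using namespace SPIRV_CROSS_NAMESPACE;

// Illustrative only: reflect just the statically used interface variables.
static ShaderResources reflect_active_resources(std::vector<uint32_t> spirv_words)
{
	Compiler compiler(std::move(spirv_words));

	// Collect variables actually reachable from the default entry point.
	auto active = compiler.get_active_interface_variables();

	// Restrict reflection (and later compilation) to that set.
	ShaderResources resources = compiler.get_shader_resources(active);
	compiler.set_enabled_interface_variables(std::move(active));

	return resources;
}
```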
+ if (dummy_sampler_id) + variables.insert(dummy_sampler_id); + + return variables; +} + +void Compiler::set_enabled_interface_variables(std::unordered_set active_variables) +{ + active_interface_variables = std::move(active_variables); + check_active_interface_variables = true; +} + +ShaderResources Compiler::get_shader_resources(const unordered_set *active_variables) const +{ + ShaderResources res; + + bool ssbo_instance_name = reflection_ssbo_instance_name_is_significant(); + + ir.for_each_typed_id([&](uint32_t, const SPIRVariable &var) { + auto &type = this->get(var.basetype); + + // It is possible for uniform storage classes to be passed as function parameters, so detect + // that. To detect function parameters, check of StorageClass of variable is function scope. + if (var.storage == StorageClassFunction || !type.pointer) + return; + + if (active_variables && active_variables->find(var.self) == end(*active_variables)) + return; + + // In SPIR-V 1.4 and up, every global must be present in the entry point interface list, + // not just IO variables. + bool active_in_entry_point = true; + if (ir.get_spirv_version() < 0x10400) + { + if (var.storage == StorageClassInput || var.storage == StorageClassOutput) + active_in_entry_point = interface_variable_exists_in_entry_point(var.self); + } + else + active_in_entry_point = interface_variable_exists_in_entry_point(var.self); + + if (!active_in_entry_point) + return; + + bool is_builtin = is_builtin_variable(var); + + if (is_builtin) + { + if (var.storage != StorageClassInput && var.storage != StorageClassOutput) + return; + + auto &list = var.storage == StorageClassInput ? res.builtin_inputs : res.builtin_outputs; + BuiltInResource resource; + + if (has_decoration(type.self, DecorationBlock)) + { + resource.resource = { var.self, var.basetype, type.self, + get_remapped_declared_block_name(var.self, false) }; + + for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++) + { + resource.value_type_id = type.member_types[i]; + resource.builtin = BuiltIn(get_member_decoration(type.self, i, DecorationBuiltIn)); + list.push_back(resource); + } + } + else + { + bool strip_array = + !has_decoration(var.self, DecorationPatch) && ( + get_execution_model() == ExecutionModelTessellationControl || + (get_execution_model() == ExecutionModelTessellationEvaluation && + var.storage == StorageClassInput)); + + resource.resource = { var.self, var.basetype, type.self, get_name(var.self) }; + + if (strip_array && !type.array.empty()) + resource.value_type_id = get_variable_data_type(var).parent_type; + else + resource.value_type_id = get_variable_data_type_id(var); + + assert(resource.value_type_id); + + resource.builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn)); + list.push_back(std::move(resource)); + } + return; + } + + // Input + if (var.storage == StorageClassInput) + { + if (has_decoration(type.self, DecorationBlock)) + { + res.stage_inputs.push_back( + { var.self, var.basetype, type.self, + get_remapped_declared_block_name(var.self, false) }); + } + else + res.stage_inputs.push_back({ var.self, var.basetype, type.self, get_name(var.self) }); + } + // Subpass inputs + else if (var.storage == StorageClassUniformConstant && type.image.dim == DimSubpassData) + { + res.subpass_inputs.push_back({ var.self, var.basetype, type.self, get_name(var.self) }); + } + // Outputs + else if (var.storage == StorageClassOutput) + { + if (has_decoration(type.self, DecorationBlock)) + { + res.stage_outputs.push_back( + { var.self, var.basetype, type.self, 
get_remapped_declared_block_name(var.self, false) }); + } + else + res.stage_outputs.push_back({ var.self, var.basetype, type.self, get_name(var.self) }); + } + // UBOs + else if (type.storage == StorageClassUniform && has_decoration(type.self, DecorationBlock)) + { + res.uniform_buffers.push_back( + { var.self, var.basetype, type.self, get_remapped_declared_block_name(var.self, false) }); + } + // Old way to declare SSBOs. + else if (type.storage == StorageClassUniform && has_decoration(type.self, DecorationBufferBlock)) + { + res.storage_buffers.push_back( + { var.self, var.basetype, type.self, get_remapped_declared_block_name(var.self, ssbo_instance_name) }); + } + // Modern way to declare SSBOs. + else if (type.storage == StorageClassStorageBuffer) + { + res.storage_buffers.push_back( + { var.self, var.basetype, type.self, get_remapped_declared_block_name(var.self, ssbo_instance_name) }); + } + // Push constant blocks + else if (type.storage == StorageClassPushConstant) + { + // There can only be one push constant block, but keep the vector in case this restriction is lifted + // in the future. + res.push_constant_buffers.push_back({ var.self, var.basetype, type.self, get_name(var.self) }); + } + else if (type.storage == StorageClassShaderRecordBufferKHR) + { + res.shader_record_buffers.push_back({ var.self, var.basetype, type.self, get_remapped_declared_block_name(var.self, ssbo_instance_name) }); + } + // Atomic counters + else if (type.storage == StorageClassAtomicCounter) + { + res.atomic_counters.push_back({ var.self, var.basetype, type.self, get_name(var.self) }); + } + else if (type.storage == StorageClassUniformConstant) + { + if (type.basetype == SPIRType::Image) + { + // Images + if (type.image.sampled == 2) + { + res.storage_images.push_back({ var.self, var.basetype, type.self, get_name(var.self) }); + } + // Separate images + else if (type.image.sampled == 1) + { + res.separate_images.push_back({ var.self, var.basetype, type.self, get_name(var.self) }); + } + } + // Separate samplers + else if (type.basetype == SPIRType::Sampler) + { + res.separate_samplers.push_back({ var.self, var.basetype, type.self, get_name(var.self) }); + } + // Textures + else if (type.basetype == SPIRType::SampledImage) + { + res.sampled_images.push_back({ var.self, var.basetype, type.self, get_name(var.self) }); + } + // Acceleration structures + else if (type.basetype == SPIRType::AccelerationStructure) + { + res.acceleration_structures.push_back({ var.self, var.basetype, type.self, get_name(var.self) }); + } + else + { + res.gl_plain_uniforms.push_back({ var.self, var.basetype, type.self, get_name(var.self) }); + } + } + }); + + return res; +} + +bool Compiler::type_is_top_level_block(const SPIRType &type) const +{ + if (type.basetype != SPIRType::Struct) + return false; + return has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock); +} + +bool Compiler::type_is_block_like(const SPIRType &type) const +{ + if (type_is_top_level_block(type)) + return true; + + if (type.basetype == SPIRType::Struct) + { + // Block-like types may have Offset decorations. + for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++) + if (has_member_decoration(type.self, i, DecorationOffset)) + return true; + } + + return false; +} + +void Compiler::parse_fixup() +{ + // Figure out specialization constants for work group sizes. 
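+ // For orientation, a sketch of how the constant recorded below surfaces through the public API
+ // (illustrative; `comp` is an assumed spirv_cross::Compiler instance):
+ //
+ //     spirv_cross::SpecializationConstant wx, wy, wz;
+ //     uint32_t workgroup_constant = comp.get_work_group_size_specialization_constants(wx, wy, wz);
+ //     // workgroup_constant is non-zero when a WorkgroupSize builtin constant exists; wx/wy/wz then
+ //     // carry the per-dimension constant IDs and their SpecId decorations.
+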
+ for (auto id_ : ir.ids_for_constant_or_variable) + { + auto &id = ir.ids[id_]; + + if (id.get_type() == TypeConstant) + { + auto &c = id.get(); + if (has_decoration(c.self, DecorationBuiltIn) && + BuiltIn(get_decoration(c.self, DecorationBuiltIn)) == BuiltInWorkgroupSize) + { + // In current SPIR-V, there can be just one constant like this. + // All entry points will receive the constant value. + // WorkgroupSize take precedence over LocalSizeId. + for (auto &entry : ir.entry_points) + { + entry.second.workgroup_size.constant = c.self; + entry.second.workgroup_size.x = c.scalar(0, 0); + entry.second.workgroup_size.y = c.scalar(0, 1); + entry.second.workgroup_size.z = c.scalar(0, 2); + } + } + } + else if (id.get_type() == TypeVariable) + { + auto &var = id.get(); + if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup || + var.storage == StorageClassTaskPayloadWorkgroupEXT || + var.storage == StorageClassOutput) + { + global_variables.push_back(var.self); + } + if (variable_storage_is_aliased(var)) + aliased_variables.push_back(var.self); + } + } +} + +void Compiler::update_name_cache(unordered_set &cache_primary, const unordered_set &cache_secondary, + string &name) +{ + if (name.empty()) + return; + + const auto find_name = [&](const string &n) -> bool { + if (cache_primary.find(n) != end(cache_primary)) + return true; + + if (&cache_primary != &cache_secondary) + if (cache_secondary.find(n) != end(cache_secondary)) + return true; + + return false; + }; + + const auto insert_name = [&](const string &n) { cache_primary.insert(n); }; + + if (!find_name(name)) + { + insert_name(name); + return; + } + + uint32_t counter = 0; + auto tmpname = name; + + bool use_linked_underscore = true; + + if (tmpname == "_") + { + // We cannot just append numbers, as we will end up creating internally reserved names. + // Make it like _0_ instead. + tmpname += "0"; + } + else if (tmpname.back() == '_') + { + // The last_character is an underscore, so we don't need to link in underscore. + // This would violate double underscore rules. + use_linked_underscore = false; + } + + // If there is a collision (very rare), + // keep tacking on extra identifier until it's unique. + do + { + counter++; + name = tmpname + (use_linked_underscore ? 
"_" : "") + convert_to_string(counter); + } while (find_name(name)); + insert_name(name); +} + +void Compiler::update_name_cache(unordered_set &cache, string &name) +{ + update_name_cache(cache, cache, name); +} + +void Compiler::set_name(ID id, const std::string &name) +{ + ir.set_name(id, name); +} + +const SPIRType &Compiler::get_type(TypeID id) const +{ + return get(id); +} + +const SPIRType &Compiler::get_type_from_variable(VariableID id) const +{ + return get(get(id).basetype); +} + +uint32_t Compiler::get_pointee_type_id(uint32_t type_id) const +{ + auto *p_type = &get(type_id); + if (p_type->pointer) + { + assert(p_type->parent_type); + type_id = p_type->parent_type; + } + return type_id; +} + +const SPIRType &Compiler::get_pointee_type(const SPIRType &type) const +{ + auto *p_type = &type; + if (p_type->pointer) + { + assert(p_type->parent_type); + p_type = &get(p_type->parent_type); + } + return *p_type; +} + +const SPIRType &Compiler::get_pointee_type(uint32_t type_id) const +{ + return get_pointee_type(get(type_id)); +} + +uint32_t Compiler::get_variable_data_type_id(const SPIRVariable &var) const +{ + if (var.phi_variable || var.storage == spv::StorageClass::StorageClassAtomicCounter) + return var.basetype; + return get_pointee_type_id(var.basetype); +} + +SPIRType &Compiler::get_variable_data_type(const SPIRVariable &var) +{ + return get(get_variable_data_type_id(var)); +} + +const SPIRType &Compiler::get_variable_data_type(const SPIRVariable &var) const +{ + return get(get_variable_data_type_id(var)); +} + +SPIRType &Compiler::get_variable_element_type(const SPIRVariable &var) +{ + SPIRType *type = &get_variable_data_type(var); + if (is_array(*type)) + type = &get(type->parent_type); + return *type; +} + +const SPIRType &Compiler::get_variable_element_type(const SPIRVariable &var) const +{ + const SPIRType *type = &get_variable_data_type(var); + if (is_array(*type)) + type = &get(type->parent_type); + return *type; +} + +bool Compiler::is_sampled_image_type(const SPIRType &type) +{ + return (type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage) && type.image.sampled == 1 && + type.image.dim != DimBuffer; +} + +void Compiler::set_member_decoration_string(TypeID id, uint32_t index, spv::Decoration decoration, + const std::string &argument) +{ + ir.set_member_decoration_string(id, index, decoration, argument); +} + +void Compiler::set_member_decoration(TypeID id, uint32_t index, Decoration decoration, uint32_t argument) +{ + ir.set_member_decoration(id, index, decoration, argument); +} + +void Compiler::set_member_name(TypeID id, uint32_t index, const std::string &name) +{ + ir.set_member_name(id, index, name); +} + +const std::string &Compiler::get_member_name(TypeID id, uint32_t index) const +{ + return ir.get_member_name(id, index); +} + +void Compiler::set_qualified_name(uint32_t id, const string &name) +{ + ir.meta[id].decoration.qualified_alias = name; +} + +void Compiler::set_member_qualified_name(uint32_t type_id, uint32_t index, const std::string &name) +{ + ir.meta[type_id].members.resize(max(ir.meta[type_id].members.size(), size_t(index) + 1)); + ir.meta[type_id].members[index].qualified_alias = name; +} + +const string &Compiler::get_member_qualified_name(TypeID type_id, uint32_t index) const +{ + auto *m = ir.find_meta(type_id); + if (m && index < m->members.size()) + return m->members[index].qualified_alias; + else + return ir.get_empty_string(); +} + +uint32_t Compiler::get_member_decoration(TypeID id, uint32_t index, Decoration decoration) 
const +{ + return ir.get_member_decoration(id, index, decoration); +} + +const Bitset &Compiler::get_member_decoration_bitset(TypeID id, uint32_t index) const +{ + return ir.get_member_decoration_bitset(id, index); +} + +bool Compiler::has_member_decoration(TypeID id, uint32_t index, Decoration decoration) const +{ + return ir.has_member_decoration(id, index, decoration); +} + +void Compiler::unset_member_decoration(TypeID id, uint32_t index, Decoration decoration) +{ + ir.unset_member_decoration(id, index, decoration); +} + +void Compiler::set_decoration_string(ID id, spv::Decoration decoration, const std::string &argument) +{ + ir.set_decoration_string(id, decoration, argument); +} + +void Compiler::set_decoration(ID id, Decoration decoration, uint32_t argument) +{ + ir.set_decoration(id, decoration, argument); +} + +void Compiler::set_extended_decoration(uint32_t id, ExtendedDecorations decoration, uint32_t value) +{ + auto &dec = ir.meta[id].decoration; + dec.extended.flags.set(decoration); + dec.extended.values[decoration] = value; +} + +void Compiler::set_extended_member_decoration(uint32_t type, uint32_t index, ExtendedDecorations decoration, + uint32_t value) +{ + ir.meta[type].members.resize(max(ir.meta[type].members.size(), size_t(index) + 1)); + auto &dec = ir.meta[type].members[index]; + dec.extended.flags.set(decoration); + dec.extended.values[decoration] = value; +} + +static uint32_t get_default_extended_decoration(ExtendedDecorations decoration) +{ + switch (decoration) + { + case SPIRVCrossDecorationResourceIndexPrimary: + case SPIRVCrossDecorationResourceIndexSecondary: + case SPIRVCrossDecorationResourceIndexTertiary: + case SPIRVCrossDecorationResourceIndexQuaternary: + case SPIRVCrossDecorationInterfaceMemberIndex: + return ~(0u); + + default: + return 0; + } +} + +uint32_t Compiler::get_extended_decoration(uint32_t id, ExtendedDecorations decoration) const +{ + auto *m = ir.find_meta(id); + if (!m) + return 0; + + auto &dec = m->decoration; + + if (!dec.extended.flags.get(decoration)) + return get_default_extended_decoration(decoration); + + return dec.extended.values[decoration]; +} + +uint32_t Compiler::get_extended_member_decoration(uint32_t type, uint32_t index, ExtendedDecorations decoration) const +{ + auto *m = ir.find_meta(type); + if (!m) + return 0; + + if (index >= m->members.size()) + return 0; + + auto &dec = m->members[index]; + if (!dec.extended.flags.get(decoration)) + return get_default_extended_decoration(decoration); + return dec.extended.values[decoration]; +} + +bool Compiler::has_extended_decoration(uint32_t id, ExtendedDecorations decoration) const +{ + auto *m = ir.find_meta(id); + if (!m) + return false; + + auto &dec = m->decoration; + return dec.extended.flags.get(decoration); +} + +bool Compiler::has_extended_member_decoration(uint32_t type, uint32_t index, ExtendedDecorations decoration) const +{ + auto *m = ir.find_meta(type); + if (!m) + return false; + + if (index >= m->members.size()) + return false; + + auto &dec = m->members[index]; + return dec.extended.flags.get(decoration); +} + +void Compiler::unset_extended_decoration(uint32_t id, ExtendedDecorations decoration) +{ + auto &dec = ir.meta[id].decoration; + dec.extended.flags.clear(decoration); + dec.extended.values[decoration] = 0; +} + +void Compiler::unset_extended_member_decoration(uint32_t type, uint32_t index, ExtendedDecorations decoration) +{ + ir.meta[type].members.resize(max(ir.meta[type].members.size(), size_t(index) + 1)); + auto &dec = ir.meta[type].members[index]; + 
dec.extended.flags.clear(decoration); + dec.extended.values[decoration] = 0; +} + +StorageClass Compiler::get_storage_class(VariableID id) const +{ + return get(id).storage; +} + +const std::string &Compiler::get_name(ID id) const +{ + return ir.get_name(id); +} + +const std::string Compiler::get_fallback_name(ID id) const +{ + return join("_", id); +} + +const std::string Compiler::get_block_fallback_name(VariableID id) const +{ + auto &var = get(id); + if (get_name(id).empty()) + return join("_", get(var.basetype).self, "_", id); + else + return get_name(id); +} + +const Bitset &Compiler::get_decoration_bitset(ID id) const +{ + return ir.get_decoration_bitset(id); +} + +bool Compiler::has_decoration(ID id, Decoration decoration) const +{ + return ir.has_decoration(id, decoration); +} + +const string &Compiler::get_decoration_string(ID id, Decoration decoration) const +{ + return ir.get_decoration_string(id, decoration); +} + +const string &Compiler::get_member_decoration_string(TypeID id, uint32_t index, Decoration decoration) const +{ + return ir.get_member_decoration_string(id, index, decoration); +} + +uint32_t Compiler::get_decoration(ID id, Decoration decoration) const +{ + return ir.get_decoration(id, decoration); +} + +void Compiler::unset_decoration(ID id, Decoration decoration) +{ + ir.unset_decoration(id, decoration); +} + +bool Compiler::get_binary_offset_for_decoration(VariableID id, spv::Decoration decoration, uint32_t &word_offset) const +{ + auto *m = ir.find_meta(id); + if (!m) + return false; + + auto &word_offsets = m->decoration_word_offset; + auto itr = word_offsets.find(decoration); + if (itr == end(word_offsets)) + return false; + + word_offset = itr->second; + return true; +} + +bool Compiler::block_is_noop(const SPIRBlock &block) const +{ + if (block.terminator != SPIRBlock::Direct) + return false; + + auto &child = get(block.next_block); + + // If this block participates in PHI, the block isn't really noop. + for (auto &phi : block.phi_variables) + if (phi.parent == block.self || phi.parent == child.self) + return false; + + for (auto &phi : child.phi_variables) + if (phi.parent == block.self) + return false; + + // Verify all instructions have no semantic impact. + for (auto &i : block.ops) + { + auto op = static_cast(i.op); + + switch (op) + { + // Non-Semantic instructions. + case OpLine: + case OpNoLine: + break; + + case OpExtInst: + { + auto *ops = stream(i); + auto ext = get(ops[2]).ext; + + bool ext_is_nonsemantic_only = + ext == SPIRExtension::NonSemanticShaderDebugInfo || + ext == SPIRExtension::SPV_debug_info || + ext == SPIRExtension::NonSemanticGeneric; + + if (!ext_is_nonsemantic_only) + return false; + + break; + } + + default: + return false; + } + } + + return true; +} + +bool Compiler::block_is_loop_candidate(const SPIRBlock &block, SPIRBlock::Method method) const +{ + // Tried and failed. + if (block.disable_block_optimization || block.complex_continue) + return false; + + if (method == SPIRBlock::MergeToSelectForLoop || method == SPIRBlock::MergeToSelectContinueForLoop) + { + // Try to detect common for loop pattern + // which the code backend can use to create cleaner code. + // for(;;) { if (cond) { some_body; } else { break; } } + // is the pattern we're looking for. 
+ const auto *false_block = maybe_get(block.false_block); + const auto *true_block = maybe_get(block.true_block); + const auto *merge_block = maybe_get(block.merge_block); + + bool false_block_is_merge = block.false_block == block.merge_block || + (false_block && merge_block && execution_is_noop(*false_block, *merge_block)); + + bool true_block_is_merge = block.true_block == block.merge_block || + (true_block && merge_block && execution_is_noop(*true_block, *merge_block)); + + bool positive_candidate = + block.true_block != block.merge_block && block.true_block != block.self && false_block_is_merge; + + bool negative_candidate = + block.false_block != block.merge_block && block.false_block != block.self && true_block_is_merge; + + bool ret = block.terminator == SPIRBlock::Select && block.merge == SPIRBlock::MergeLoop && + (positive_candidate || negative_candidate); + + if (ret && positive_candidate && method == SPIRBlock::MergeToSelectContinueForLoop) + ret = block.true_block == block.continue_block; + else if (ret && negative_candidate && method == SPIRBlock::MergeToSelectContinueForLoop) + ret = block.false_block == block.continue_block; + + // If we have OpPhi which depends on branches which came from our own block, + // we need to flush phi variables in else block instead of a trivial break, + // so we cannot assume this is a for loop candidate. + if (ret) + { + for (auto &phi : block.phi_variables) + if (phi.parent == block.self) + return false; + + auto *merge = maybe_get(block.merge_block); + if (merge) + for (auto &phi : merge->phi_variables) + if (phi.parent == block.self) + return false; + } + return ret; + } + else if (method == SPIRBlock::MergeToDirectForLoop) + { + // Empty loop header that just sets up merge target + // and branches to loop body. 
+ bool ret = block.terminator == SPIRBlock::Direct && block.merge == SPIRBlock::MergeLoop && block_is_noop(block); + + if (!ret) + return false; + + auto &child = get(block.next_block); + + const auto *false_block = maybe_get(child.false_block); + const auto *true_block = maybe_get(child.true_block); + const auto *merge_block = maybe_get(block.merge_block); + + bool false_block_is_merge = child.false_block == block.merge_block || + (false_block && merge_block && execution_is_noop(*false_block, *merge_block)); + + bool true_block_is_merge = child.true_block == block.merge_block || + (true_block && merge_block && execution_is_noop(*true_block, *merge_block)); + + bool positive_candidate = + child.true_block != block.merge_block && child.true_block != block.self && false_block_is_merge; + + bool negative_candidate = + child.false_block != block.merge_block && child.false_block != block.self && true_block_is_merge; + + ret = child.terminator == SPIRBlock::Select && child.merge == SPIRBlock::MergeNone && + (positive_candidate || negative_candidate); + + if (ret) + { + auto *merge = maybe_get(block.merge_block); + if (merge) + for (auto &phi : merge->phi_variables) + if (phi.parent == block.self || phi.parent == child.false_block) + return false; + } + + return ret; + } + else + return false; +} + +bool Compiler::execution_is_noop(const SPIRBlock &from, const SPIRBlock &to) const +{ + if (!execution_is_branchless(from, to)) + return false; + + auto *start = &from; + for (;;) + { + if (start->self == to.self) + return true; + + if (!block_is_noop(*start)) + return false; + + auto &next = get(start->next_block); + start = &next; + } +} + +bool Compiler::execution_is_branchless(const SPIRBlock &from, const SPIRBlock &to) const +{ + auto *start = &from; + for (;;) + { + if (start->self == to.self) + return true; + + if (start->terminator == SPIRBlock::Direct && start->merge == SPIRBlock::MergeNone) + start = &get(start->next_block); + else + return false; + } +} + +bool Compiler::execution_is_direct_branch(const SPIRBlock &from, const SPIRBlock &to) const +{ + return from.terminator == SPIRBlock::Direct && from.merge == SPIRBlock::MergeNone && from.next_block == to.self; +} + +SPIRBlock::ContinueBlockType Compiler::continue_block_type(const SPIRBlock &block) const +{ + // The block was deemed too complex during code emit, pick conservative fallback paths. + if (block.complex_continue) + return SPIRBlock::ComplexLoop; + + // In older glslang output continue block can be equal to the loop header. + // In this case, execution is clearly branchless, so just assume a while loop header here. + if (block.merge == SPIRBlock::MergeLoop) + return SPIRBlock::WhileLoop; + + if (block.loop_dominator == BlockID(SPIRBlock::NoDominator)) + { + // Continue block is never reached from CFG. + return SPIRBlock::ComplexLoop; + } + + auto &dominator = get(block.loop_dominator); + + if (execution_is_noop(block, dominator)) + return SPIRBlock::WhileLoop; + else if (execution_is_branchless(block, dominator)) + return SPIRBlock::ForLoop; + else + { + const auto *false_block = maybe_get(block.false_block); + const auto *true_block = maybe_get(block.true_block); + const auto *merge_block = maybe_get(dominator.merge_block); + + // If we need to flush Phi in this block, we cannot have a DoWhile loop. 
+ bool flush_phi_to_false = false_block && flush_phi_required(block.self, block.false_block); + bool flush_phi_to_true = true_block && flush_phi_required(block.self, block.true_block); + if (flush_phi_to_false || flush_phi_to_true) + return SPIRBlock::ComplexLoop; + + bool positive_do_while = block.true_block == dominator.self && + (block.false_block == dominator.merge_block || + (false_block && merge_block && execution_is_noop(*false_block, *merge_block))); + + bool negative_do_while = block.false_block == dominator.self && + (block.true_block == dominator.merge_block || + (true_block && merge_block && execution_is_noop(*true_block, *merge_block))); + + if (block.merge == SPIRBlock::MergeNone && block.terminator == SPIRBlock::Select && + (positive_do_while || negative_do_while)) + { + return SPIRBlock::DoWhileLoop; + } + else + return SPIRBlock::ComplexLoop; + } +} + +const SmallVector &Compiler::get_case_list(const SPIRBlock &block) const +{ + uint32_t width = 0; + + // First we check if we can get the type directly from the block.condition + // since it can be a SPIRConstant or a SPIRVariable. + if (const auto *constant = maybe_get(block.condition)) + { + const auto &type = get(constant->constant_type); + width = type.width; + } + else if (const auto *var = maybe_get(block.condition)) + { + const auto &type = get(var->basetype); + width = type.width; + } + else if (const auto *undef = maybe_get(block.condition)) + { + const auto &type = get(undef->basetype); + width = type.width; + } + else + { + auto search = ir.load_type_width.find(block.condition); + if (search == ir.load_type_width.end()) + { + SPIRV_CROSS_THROW("Use of undeclared variable on a switch statement."); + } + + width = search->second; + } + + if (width > 32) + return block.cases_64bit; + + return block.cases_32bit; +} + +bool Compiler::traverse_all_reachable_opcodes(const SPIRBlock &block, OpcodeHandler &handler) const +{ + handler.set_current_block(block); + handler.rearm_current_block(block); + + // Ideally, perhaps traverse the CFG instead of all blocks in order to eliminate dead blocks, + // but this shouldn't be a problem in practice unless the SPIR-V is doing insane things like recursing + // inside dead blocks ... + for (auto &i : block.ops) + { + auto ops = stream(i); + auto op = static_cast(i.op); + + if (!handler.handle(op, ops, i.length)) + return false; + + if (op == OpFunctionCall) + { + auto &func = get(ops[2]); + if (handler.follow_function_call(func)) + { + if (!handler.begin_function_scope(ops, i.length)) + return false; + if (!traverse_all_reachable_opcodes(get(ops[2]), handler)) + return false; + if (!handler.end_function_scope(ops, i.length)) + return false; + + handler.rearm_current_block(block); + } + } + } + + if (!handler.handle_terminator(block)) + return false; + + return true; +} + +bool Compiler::traverse_all_reachable_opcodes(const SPIRFunction &func, OpcodeHandler &handler) const +{ + for (auto block : func.blocks) + if (!traverse_all_reachable_opcodes(get(block), handler)) + return false; + + return true; +} + +uint32_t Compiler::type_struct_member_offset(const SPIRType &type, uint32_t index) const +{ + auto *type_meta = ir.find_meta(type.self); + if (type_meta) + { + // Decoration must be set in valid SPIR-V, otherwise throw. 
+ auto &dec = type_meta->members[index]; + if (dec.decoration_flags.get(DecorationOffset)) + return dec.offset; + else + SPIRV_CROSS_THROW("Struct member does not have Offset set."); + } + else + SPIRV_CROSS_THROW("Struct member does not have Offset set."); +} + +uint32_t Compiler::type_struct_member_array_stride(const SPIRType &type, uint32_t index) const +{ + auto *type_meta = ir.find_meta(type.member_types[index]); + if (type_meta) + { + // Decoration must be set in valid SPIR-V, otherwise throw. + // ArrayStride is part of the array type not OpMemberDecorate. + auto &dec = type_meta->decoration; + if (dec.decoration_flags.get(DecorationArrayStride)) + return dec.array_stride; + else + SPIRV_CROSS_THROW("Struct member does not have ArrayStride set."); + } + else + SPIRV_CROSS_THROW("Struct member does not have ArrayStride set."); +} + +uint32_t Compiler::type_struct_member_matrix_stride(const SPIRType &type, uint32_t index) const +{ + auto *type_meta = ir.find_meta(type.self); + if (type_meta) + { + // Decoration must be set in valid SPIR-V, otherwise throw. + // MatrixStride is part of OpMemberDecorate. + auto &dec = type_meta->members[index]; + if (dec.decoration_flags.get(DecorationMatrixStride)) + return dec.matrix_stride; + else + SPIRV_CROSS_THROW("Struct member does not have MatrixStride set."); + } + else + SPIRV_CROSS_THROW("Struct member does not have MatrixStride set."); +} + +size_t Compiler::get_declared_struct_size(const SPIRType &type) const +{ + if (type.member_types.empty()) + SPIRV_CROSS_THROW("Declared struct in block cannot be empty."); + + // Offsets can be declared out of order, so we need to deduce the actual size + // based on last member instead. + uint32_t member_index = 0; + size_t highest_offset = 0; + for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++) + { + size_t offset = type_struct_member_offset(type, i); + if (offset > highest_offset) + { + highest_offset = offset; + member_index = i; + } + } + + size_t size = get_declared_struct_member_size(type, member_index); + return highest_offset + size; +} + +size_t Compiler::get_declared_struct_size_runtime_array(const SPIRType &type, size_t array_size) const +{ + if (type.member_types.empty()) + SPIRV_CROSS_THROW("Declared struct in block cannot be empty."); + + size_t size = get_declared_struct_size(type); + auto &last_type = get(type.member_types.back()); + if (!last_type.array.empty() && last_type.array_size_literal[0] && last_type.array[0] == 0) // Runtime array + size += array_size * type_struct_member_array_stride(type, uint32_t(type.member_types.size() - 1)); + + return size; +} + +uint32_t Compiler::evaluate_spec_constant_u32(const SPIRConstantOp &spec) const +{ + auto &result_type = get(spec.basetype); + if (result_type.basetype != SPIRType::UInt && result_type.basetype != SPIRType::Int && + result_type.basetype != SPIRType::Boolean) + { + SPIRV_CROSS_THROW( + "Only 32-bit integers and booleans are currently supported when evaluating specialization constants.\n"); + } + + if (!is_scalar(result_type)) + SPIRV_CROSS_THROW("Spec constant evaluation must be a scalar.\n"); + + uint32_t value = 0; + + const auto eval_u32 = [&](uint32_t id) -> uint32_t { + auto &type = expression_type(id); + if (type.basetype != SPIRType::UInt && type.basetype != SPIRType::Int && type.basetype != SPIRType::Boolean) + { + SPIRV_CROSS_THROW("Only 32-bit integers and booleans are currently supported when evaluating " + "specialization constants.\n"); + } + + if (!is_scalar(type)) + SPIRV_CROSS_THROW("Spec constant 
evaluation must be a scalar.\n"); + if (const auto *c = this->maybe_get(id)) + return c->scalar(); + else + return evaluate_spec_constant_u32(this->get(id)); + }; + +#define binary_spec_op(op, binary_op) \ + case Op##op: \ + value = eval_u32(spec.arguments[0]) binary_op eval_u32(spec.arguments[1]); \ + break +#define binary_spec_op_cast(op, binary_op, type) \ + case Op##op: \ + value = uint32_t(type(eval_u32(spec.arguments[0])) binary_op type(eval_u32(spec.arguments[1]))); \ + break + + // Support the basic opcodes which are typically used when computing array sizes. + switch (spec.opcode) + { + binary_spec_op(IAdd, +); + binary_spec_op(ISub, -); + binary_spec_op(IMul, *); + binary_spec_op(BitwiseAnd, &); + binary_spec_op(BitwiseOr, |); + binary_spec_op(BitwiseXor, ^); + binary_spec_op(LogicalAnd, &); + binary_spec_op(LogicalOr, |); + binary_spec_op(ShiftLeftLogical, <<); + binary_spec_op(ShiftRightLogical, >>); + binary_spec_op_cast(ShiftRightArithmetic, >>, int32_t); + binary_spec_op(LogicalEqual, ==); + binary_spec_op(LogicalNotEqual, !=); + binary_spec_op(IEqual, ==); + binary_spec_op(INotEqual, !=); + binary_spec_op(ULessThan, <); + binary_spec_op(ULessThanEqual, <=); + binary_spec_op(UGreaterThan, >); + binary_spec_op(UGreaterThanEqual, >=); + binary_spec_op_cast(SLessThan, <, int32_t); + binary_spec_op_cast(SLessThanEqual, <=, int32_t); + binary_spec_op_cast(SGreaterThan, >, int32_t); + binary_spec_op_cast(SGreaterThanEqual, >=, int32_t); +#undef binary_spec_op +#undef binary_spec_op_cast + + case OpLogicalNot: + value = uint32_t(!eval_u32(spec.arguments[0])); + break; + + case OpNot: + value = ~eval_u32(spec.arguments[0]); + break; + + case OpSNegate: + value = uint32_t(-int32_t(eval_u32(spec.arguments[0]))); + break; + + case OpSelect: + value = eval_u32(spec.arguments[0]) ? eval_u32(spec.arguments[1]) : eval_u32(spec.arguments[2]); + break; + + case OpUMod: + { + uint32_t a = eval_u32(spec.arguments[0]); + uint32_t b = eval_u32(spec.arguments[1]); + if (b == 0) + SPIRV_CROSS_THROW("Undefined behavior in UMod, b == 0.\n"); + value = a % b; + break; + } + + case OpSRem: + { + auto a = int32_t(eval_u32(spec.arguments[0])); + auto b = int32_t(eval_u32(spec.arguments[1])); + if (b == 0) + SPIRV_CROSS_THROW("Undefined behavior in SRem, b == 0.\n"); + value = a % b; + break; + } + + case OpSMod: + { + auto a = int32_t(eval_u32(spec.arguments[0])); + auto b = int32_t(eval_u32(spec.arguments[1])); + if (b == 0) + SPIRV_CROSS_THROW("Undefined behavior in SMod, b == 0.\n"); + auto v = a % b; + + // Makes sure we match the sign of b, not a. 
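+ // Worked example of the SRem/SMod distinction: with a = -7 and b = 3, the C++ '%' operator yields
+ // -1, which is what OpSRem requires (the result takes the sign of a). OpSMod must instead return 2
+ // (the sign of b), which is exactly what the sign correction below produces:
+ //
+ //     int32_t a = -7, b = 3;
+ //     int32_t v = a % b; // -1 == SRem; the fixup below turns this into 2 == SMod.
+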
+ if ((b < 0 && v > 0) || (b > 0 && v < 0)) + v += b; + value = v; + break; + } + + case OpUDiv: + { + uint32_t a = eval_u32(spec.arguments[0]); + uint32_t b = eval_u32(spec.arguments[1]); + if (b == 0) + SPIRV_CROSS_THROW("Undefined behavior in UDiv, b == 0.\n"); + value = a / b; + break; + } + + case OpSDiv: + { + auto a = int32_t(eval_u32(spec.arguments[0])); + auto b = int32_t(eval_u32(spec.arguments[1])); + if (b == 0) + SPIRV_CROSS_THROW("Undefined behavior in SDiv, b == 0.\n"); + value = a / b; + break; + } + + default: + SPIRV_CROSS_THROW("Unsupported spec constant opcode for evaluation.\n"); + } + + return value; +} + +uint32_t Compiler::evaluate_constant_u32(uint32_t id) const +{ + if (const auto *c = maybe_get(id)) + return c->scalar(); + else + return evaluate_spec_constant_u32(get(id)); +} + +size_t Compiler::get_declared_struct_member_size(const SPIRType &struct_type, uint32_t index) const +{ + if (struct_type.member_types.empty()) + SPIRV_CROSS_THROW("Declared struct in block cannot be empty."); + + auto &flags = get_member_decoration_bitset(struct_type.self, index); + auto &type = get(struct_type.member_types[index]); + + switch (type.basetype) + { + case SPIRType::Unknown: + case SPIRType::Void: + case SPIRType::Boolean: // Bools are purely logical, and cannot be used for externally visible types. + case SPIRType::AtomicCounter: + case SPIRType::Image: + case SPIRType::SampledImage: + case SPIRType::Sampler: + SPIRV_CROSS_THROW("Querying size for object with opaque size."); + + default: + break; + } + + if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer) + { + // Check if this is a top-level pointer type, and not an array of pointers. + if (type.pointer_depth > get(type.parent_type).pointer_depth) + return 8; + } + + if (!type.array.empty()) + { + // For arrays, we can use ArrayStride to get an easy check. + bool array_size_literal = type.array_size_literal.back(); + uint32_t array_size = array_size_literal ? type.array.back() : evaluate_constant_u32(type.array.back()); + return type_struct_member_array_stride(struct_type, index) * array_size; + } + else if (type.basetype == SPIRType::Struct) + { + return get_declared_struct_size(type); + } + else + { + unsigned vecsize = type.vecsize; + unsigned columns = type.columns; + + // Vectors. + if (columns == 1) + { + size_t component_size = type.width / 8; + return vecsize * component_size; + } + else + { + uint32_t matrix_stride = type_struct_member_matrix_stride(struct_type, index); + + // Per SPIR-V spec, matrices must be tightly packed and aligned up for vec3 accesses. + if (flags.get(DecorationRowMajor)) + return matrix_stride * vecsize; + else if (flags.get(DecorationColMajor)) + return matrix_stride * columns; + else + SPIRV_CROSS_THROW("Either row-major or column-major must be declared for matrices."); + } + } +} + +bool Compiler::BufferAccessHandler::handle(Op opcode, const uint32_t *args, uint32_t length) +{ + if (opcode != OpAccessChain && opcode != OpInBoundsAccessChain && opcode != OpPtrAccessChain) + return true; + + bool ptr_chain = (opcode == OpPtrAccessChain); + + // Invalid SPIR-V. + if (length < (ptr_chain ? 5u : 4u)) + return false; + + if (args[2] != id) + return true; + + // Don't bother traversing the entire access chain tree yet. + // If we access a struct member, assume we access the entire member. + uint32_t index = compiler.get(args[ptr_chain ? 4 : 3]).scalar(); + + // Seen this index already. 
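+ // The ranges gathered by this handler back the public get_active_buffer_ranges() query; a
+ // caller-side sketch (`comp` is an assumed Compiler instance, `buf` a reflected buffer resource):
+ //
+ //     auto ranges = comp.get_active_buffer_ranges(buf.id);
+ //     for (const spirv_cross::BufferRange &r : ranges)
+ //     {
+ //         // r.index is the accessed struct member, r.offset / r.range the byte window it covers.
+ //     }
+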
+ if (seen.find(index) != end(seen)) + return true; + seen.insert(index); + + auto &type = compiler.expression_type(id); + uint32_t offset = compiler.type_struct_member_offset(type, index); + + size_t range; + // If we have another member in the struct, deduce the range by looking at the next member. + // This is okay since structs in SPIR-V can have padding, but Offset decoration must be + // monotonically increasing. + // Of course, this doesn't take into account if the SPIR-V for some reason decided to add + // very large amounts of padding, but that's not really a big deal. + if (index + 1 < type.member_types.size()) + { + range = compiler.type_struct_member_offset(type, index + 1) - offset; + } + else + { + // No padding, so just deduce it from the size of the member directly. + range = compiler.get_declared_struct_member_size(type, index); + } + + ranges.push_back({ index, offset, range }); + return true; +} + +SmallVector Compiler::get_active_buffer_ranges(VariableID id) const +{ + SmallVector ranges; + BufferAccessHandler handler(*this, ranges, id); + traverse_all_reachable_opcodes(get(ir.default_entry_point), handler); + return ranges; +} + +bool Compiler::types_are_logically_equivalent(const SPIRType &a, const SPIRType &b) const +{ + if (a.basetype != b.basetype) + return false; + if (a.width != b.width) + return false; + if (a.vecsize != b.vecsize) + return false; + if (a.columns != b.columns) + return false; + if (a.array.size() != b.array.size()) + return false; + + size_t array_count = a.array.size(); + if (array_count && memcmp(a.array.data(), b.array.data(), array_count * sizeof(uint32_t)) != 0) + return false; + + if (a.basetype == SPIRType::Image || a.basetype == SPIRType::SampledImage) + { + if (memcmp(&a.image, &b.image, sizeof(SPIRType::Image)) != 0) + return false; + } + + if (a.member_types.size() != b.member_types.size()) + return false; + + size_t member_types = a.member_types.size(); + for (size_t i = 0; i < member_types; i++) + { + if (!types_are_logically_equivalent(get(a.member_types[i]), get(b.member_types[i]))) + return false; + } + + return true; +} + +const Bitset &Compiler::get_execution_mode_bitset() const +{ + return get_entry_point().flags; +} + +void Compiler::set_execution_mode(ExecutionMode mode, uint32_t arg0, uint32_t arg1, uint32_t arg2) +{ + auto &execution = get_entry_point(); + + execution.flags.set(mode); + switch (mode) + { + case ExecutionModeLocalSize: + execution.workgroup_size.x = arg0; + execution.workgroup_size.y = arg1; + execution.workgroup_size.z = arg2; + break; + + case ExecutionModeLocalSizeId: + execution.workgroup_size.id_x = arg0; + execution.workgroup_size.id_y = arg1; + execution.workgroup_size.id_z = arg2; + break; + + case ExecutionModeInvocations: + execution.invocations = arg0; + break; + + case ExecutionModeOutputVertices: + execution.output_vertices = arg0; + break; + + case ExecutionModeOutputPrimitivesEXT: + execution.output_primitives = arg0; + break; + + default: + break; + } +} + +void Compiler::unset_execution_mode(ExecutionMode mode) +{ + auto &execution = get_entry_point(); + execution.flags.clear(mode); +} + +uint32_t Compiler::get_work_group_size_specialization_constants(SpecializationConstant &x, SpecializationConstant &y, + SpecializationConstant &z) const +{ + auto &execution = get_entry_point(); + x = { 0, 0 }; + y = { 0, 0 }; + z = { 0, 0 }; + + // WorkgroupSize builtin takes precedence over LocalSize / LocalSizeId. 
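+ // Related caller-side sketch: applications enumerate specialization constants in general through
+ // get_specialization_constants() (`comp` is an assumed Compiler instance):
+ //
+ //     for (const spirv_cross::SpecializationConstant &sc : comp.get_specialization_constants())
+ //     {
+ //         const spirv_cross::SPIRConstant &value = comp.get_constant(sc.id);
+ //         // sc.constant_id is the SpecId visible to the API user; value.scalar() is the default payload.
+ //     }
+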
+ if (execution.workgroup_size.constant != 0) + { + auto &c = get(execution.workgroup_size.constant); + + if (c.m.c[0].id[0] != ID(0)) + { + x.id = c.m.c[0].id[0]; + x.constant_id = get_decoration(c.m.c[0].id[0], DecorationSpecId); + } + + if (c.m.c[0].id[1] != ID(0)) + { + y.id = c.m.c[0].id[1]; + y.constant_id = get_decoration(c.m.c[0].id[1], DecorationSpecId); + } + + if (c.m.c[0].id[2] != ID(0)) + { + z.id = c.m.c[0].id[2]; + z.constant_id = get_decoration(c.m.c[0].id[2], DecorationSpecId); + } + } + else if (execution.flags.get(ExecutionModeLocalSizeId)) + { + auto &cx = get(execution.workgroup_size.id_x); + if (cx.specialization) + { + x.id = execution.workgroup_size.id_x; + x.constant_id = get_decoration(execution.workgroup_size.id_x, DecorationSpecId); + } + + auto &cy = get(execution.workgroup_size.id_y); + if (cy.specialization) + { + y.id = execution.workgroup_size.id_y; + y.constant_id = get_decoration(execution.workgroup_size.id_y, DecorationSpecId); + } + + auto &cz = get(execution.workgroup_size.id_z); + if (cz.specialization) + { + z.id = execution.workgroup_size.id_z; + z.constant_id = get_decoration(execution.workgroup_size.id_z, DecorationSpecId); + } + } + + return execution.workgroup_size.constant; +} + +uint32_t Compiler::get_execution_mode_argument(spv::ExecutionMode mode, uint32_t index) const +{ + auto &execution = get_entry_point(); + switch (mode) + { + case ExecutionModeLocalSizeId: + if (execution.flags.get(ExecutionModeLocalSizeId)) + { + switch (index) + { + case 0: + return execution.workgroup_size.id_x; + case 1: + return execution.workgroup_size.id_y; + case 2: + return execution.workgroup_size.id_z; + default: + return 0; + } + } + else + return 0; + + case ExecutionModeLocalSize: + switch (index) + { + case 0: + if (execution.flags.get(ExecutionModeLocalSizeId) && execution.workgroup_size.id_x != 0) + return get(execution.workgroup_size.id_x).scalar(); + else + return execution.workgroup_size.x; + case 1: + if (execution.flags.get(ExecutionModeLocalSizeId) && execution.workgroup_size.id_y != 0) + return get(execution.workgroup_size.id_y).scalar(); + else + return execution.workgroup_size.y; + case 2: + if (execution.flags.get(ExecutionModeLocalSizeId) && execution.workgroup_size.id_z != 0) + return get(execution.workgroup_size.id_z).scalar(); + else + return execution.workgroup_size.z; + default: + return 0; + } + + case ExecutionModeInvocations: + return execution.invocations; + + case ExecutionModeOutputVertices: + return execution.output_vertices; + + case ExecutionModeOutputPrimitivesEXT: + return execution.output_primitives; + + default: + return 0; + } +} + +ExecutionModel Compiler::get_execution_model() const +{ + auto &execution = get_entry_point(); + return execution.model; +} + +bool Compiler::is_tessellation_shader(ExecutionModel model) +{ + return model == ExecutionModelTessellationControl || model == ExecutionModelTessellationEvaluation; +} + +bool Compiler::is_vertex_like_shader() const +{ + auto model = get_execution_model(); + return model == ExecutionModelVertex || model == ExecutionModelGeometry || + model == ExecutionModelTessellationControl || model == ExecutionModelTessellationEvaluation; +} + +bool Compiler::is_tessellation_shader() const +{ + return is_tessellation_shader(get_execution_model()); +} + +bool Compiler::is_tessellating_triangles() const +{ + return get_execution_mode_bitset().get(ExecutionModeTriangles); +} + +void Compiler::set_remapped_variable_state(VariableID id, bool remap_enable) +{ + get(id).remapped_variable = 
remap_enable;
+}
+
+bool Compiler::get_remapped_variable_state(VariableID id) const
+{
+ return get<SPIRVariable>(id).remapped_variable;
+}
+
+void Compiler::set_subpass_input_remapped_components(VariableID id, uint32_t components)
+{
+ get<SPIRVariable>(id).remapped_components = components;
+}
+
+uint32_t Compiler::get_subpass_input_remapped_components(VariableID id) const
+{
+ return get<SPIRVariable>(id).remapped_components;
+}
+
+void Compiler::add_implied_read_expression(SPIRExpression &e, uint32_t source)
+{
+ auto itr = find(begin(e.implied_read_expressions), end(e.implied_read_expressions), ID(source));
+ if (itr == end(e.implied_read_expressions))
+ e.implied_read_expressions.push_back(source);
+}
+
+void Compiler::add_implied_read_expression(SPIRAccessChain &e, uint32_t source)
+{
+ auto itr = find(begin(e.implied_read_expressions), end(e.implied_read_expressions), ID(source));
+ if (itr == end(e.implied_read_expressions))
+ e.implied_read_expressions.push_back(source);
+}
+
+void Compiler::add_active_interface_variable(uint32_t var_id)
+{
+ active_interface_variables.insert(var_id);
+
+ // In SPIR-V 1.4 and up we must also track the interface variable in the entry point.
+ if (ir.get_spirv_version() >= 0x10400)
+ {
+ auto &vars = get_entry_point().interface_variables;
+ if (find(begin(vars), end(vars), VariableID(var_id)) == end(vars))
+ vars.push_back(var_id);
+ }
+}
+
+void Compiler::inherit_expression_dependencies(uint32_t dst, uint32_t source_expression)
+{
+ // Don't inherit any expression dependencies if the expression in dst
+ // is not a forwarded temporary.
+ if (forwarded_temporaries.find(dst) == end(forwarded_temporaries) ||
+ forced_temporaries.find(dst) != end(forced_temporaries))
+ {
+ return;
+ }
+
+ auto &e = get<SPIRExpression>(dst);
+ auto *phi = maybe_get<SPIRVariable>(source_expression);
+ if (phi && phi->phi_variable)
+ {
+ // We have used a phi variable, which can change at the end of the block,
+ // so make sure we take a dependency on this phi variable.
+ phi->dependees.push_back(dst);
+ }
+
+ auto *s = maybe_get<SPIRExpression>(source_expression);
+ if (!s)
+ return;
+
+ auto &e_deps = e.expression_dependencies;
+ auto &s_deps = s->expression_dependencies;
+
+ // If we depend on an expression, we also depend on all sub-dependencies from source.
+ e_deps.push_back(source_expression);
+ e_deps.insert(end(e_deps), begin(s_deps), end(s_deps));
+
+ // Eliminate duplicated dependencies.
+ sort(begin(e_deps), end(e_deps));
+ e_deps.erase(unique(begin(e_deps), end(e_deps)), end(e_deps));
+}
+
+SmallVector<EntryPoint> Compiler::get_entry_points_and_stages() const
+{
+ SmallVector<EntryPoint> entries;
+ for (auto &entry : ir.entry_points)
+ entries.push_back({ entry.second.orig_name, entry.second.model });
+ return entries;
+}
+
+void Compiler::rename_entry_point(const std::string &old_name, const std::string &new_name, spv::ExecutionModel model)
+{
+ auto &entry = get_entry_point(old_name, model);
+ entry.orig_name = new_name;
+ entry.name = new_name;
+}
+
+void Compiler::set_entry_point(const std::string &name, spv::ExecutionModel model)
+{
+ auto &entry = get_entry_point(name, model);
+ ir.default_entry_point = entry.self;
+}
+
+SPIREntryPoint &Compiler::get_first_entry_point(const std::string &name)
+{
+ auto itr = find_if(
+ begin(ir.entry_points), end(ir.entry_points),
+ [&](const std::pair<uint32_t, SPIREntryPoint> &entry) -> bool { return entry.second.orig_name == name; });
+
+ if (itr == end(ir.entry_points))
+ SPIRV_CROSS_THROW("Entry point does not exist.");
+
+ return itr->second;
+}
+
+const SPIREntryPoint &Compiler::get_first_entry_point(const std::string &name) const
+{
+ auto itr = find_if(
+ begin(ir.entry_points), end(ir.entry_points),
+ [&](const std::pair<uint32_t, SPIREntryPoint> &entry) -> bool { return entry.second.orig_name == name; });
+
+ if (itr == end(ir.entry_points))
+ SPIRV_CROSS_THROW("Entry point does not exist.");
+
+ return itr->second;
+}
+
+SPIREntryPoint &Compiler::get_entry_point(const std::string &name, ExecutionModel model)
+{
+ auto itr = find_if(begin(ir.entry_points), end(ir.entry_points),
+ [&](const std::pair<uint32_t, SPIREntryPoint> &entry) -> bool {
+ return entry.second.orig_name == name && entry.second.model == model;
+ });
+
+ if (itr == end(ir.entry_points))
+ SPIRV_CROSS_THROW("Entry point does not exist.");
+
+ return itr->second;
+}
+
+const SPIREntryPoint &Compiler::get_entry_point(const std::string &name, ExecutionModel model) const
+{
+ auto itr = find_if(begin(ir.entry_points), end(ir.entry_points),
+ [&](const std::pair<uint32_t, SPIREntryPoint> &entry) -> bool {
+ return entry.second.orig_name == name && entry.second.model == model;
+ });
+
+ if (itr == end(ir.entry_points))
+ SPIRV_CROSS_THROW("Entry point does not exist.");
+
+ return itr->second;
+}
+
+const string &Compiler::get_cleansed_entry_point_name(const std::string &name, ExecutionModel model) const
+{
+ return get_entry_point(name, model).name;
+}
+
+const SPIREntryPoint &Compiler::get_entry_point() const
+{
+ return ir.entry_points.find(ir.default_entry_point)->second;
+}
+
+SPIREntryPoint &Compiler::get_entry_point()
+{
+ return ir.entry_points.find(ir.default_entry_point)->second;
+}
+
+bool Compiler::interface_variable_exists_in_entry_point(uint32_t id) const
+{
+ auto &var = get<SPIRVariable>(id);
+
+ if (ir.get_spirv_version() < 0x10400)
+ {
+ if (var.storage != StorageClassInput && var.storage != StorageClassOutput &&
+ var.storage != StorageClassUniformConstant)
+ SPIRV_CROSS_THROW("Only Input, Output variables and Uniform constants are part of a shader linking interface.");
+
+ // This is to avoid potential problems with very old glslang versions which did
+ // not emit input/output interfaces properly.
+ // We can assume they only had a single entry point, and single entry point
+ // shaders could easily be assumed to use every interface variable anyways.
+ if (ir.entry_points.size() <= 1)
+ return true;
+ }
+
+ // In SPIR-V 1.4 and later, all global resource variables must be present.
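+ // Caller-side sketch of entry-point selection (illustrative; `comp` is an assumed Compiler
+ // instance, EntryPoint fields per spirv_cross.hpp):
+ //
+ //     auto entry_points = comp.get_entry_points_and_stages();
+ //     if (!entry_points.empty())
+ //         comp.set_entry_point(entry_points.front().name, entry_points.front().execution_model);
+ //
+ // Multi-entry-point modules should pick an entry point this way before reflection queries, since
+ // the active entry point decides which interface list is consulted below.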
+ + auto &execution = get_entry_point(); + return find(begin(execution.interface_variables), end(execution.interface_variables), VariableID(id)) != + end(execution.interface_variables); +} + +void Compiler::CombinedImageSamplerHandler::push_remap_parameters(const SPIRFunction &func, const uint32_t *args, + uint32_t length) +{ + // If possible, pipe through a remapping table so that parameters know + // which variables they actually bind to in this scope. + unordered_map remapping; + for (uint32_t i = 0; i < length; i++) + remapping[func.arguments[i].id] = remap_parameter(args[i]); + parameter_remapping.push(std::move(remapping)); +} + +void Compiler::CombinedImageSamplerHandler::pop_remap_parameters() +{ + parameter_remapping.pop(); +} + +uint32_t Compiler::CombinedImageSamplerHandler::remap_parameter(uint32_t id) +{ + auto *var = compiler.maybe_get_backing_variable(id); + if (var) + id = var->self; + + if (parameter_remapping.empty()) + return id; + + auto &remapping = parameter_remapping.top(); + auto itr = remapping.find(id); + if (itr != end(remapping)) + return itr->second; + else + return id; +} + +bool Compiler::CombinedImageSamplerHandler::begin_function_scope(const uint32_t *args, uint32_t length) +{ + if (length < 3) + return false; + + auto &callee = compiler.get(args[2]); + args += 3; + length -= 3; + push_remap_parameters(callee, args, length); + functions.push(&callee); + return true; +} + +bool Compiler::CombinedImageSamplerHandler::end_function_scope(const uint32_t *args, uint32_t length) +{ + if (length < 3) + return false; + + auto &callee = compiler.get(args[2]); + args += 3; + + // There are two types of cases we have to handle, + // a callee might call sampler2D(texture2D, sampler) directly where + // one or more parameters originate from parameters. + // Alternatively, we need to provide combined image samplers to our callees, + // and in this case we need to add those as well. + + pop_remap_parameters(); + + // Our callee has now been processed at least once. + // No point in doing it again. + callee.do_combined_parameters = false; + + auto ¶ms = functions.top()->combined_parameters; + functions.pop(); + if (functions.empty()) + return true; + + auto &caller = *functions.top(); + if (caller.do_combined_parameters) + { + for (auto ¶m : params) + { + VariableID image_id = param.global_image ? param.image_id : VariableID(args[param.image_id]); + VariableID sampler_id = param.global_sampler ? param.sampler_id : VariableID(args[param.sampler_id]); + + auto *i = compiler.maybe_get_backing_variable(image_id); + auto *s = compiler.maybe_get_backing_variable(sampler_id); + if (i) + image_id = i->self; + if (s) + sampler_id = s->self; + + register_combined_image_sampler(caller, 0, image_id, sampler_id, param.depth); + } + } + + return true; +} + +void Compiler::CombinedImageSamplerHandler::register_combined_image_sampler(SPIRFunction &caller, + VariableID combined_module_id, + VariableID image_id, VariableID sampler_id, + bool depth) +{ + // We now have a texture ID and a sampler ID which will either be found as a global + // or a parameter in our own function. If both are global, they will not need a parameter, + // otherwise, add it to our list. 
+ SPIRFunction::CombinedImageSamplerParameter param = { + 0u, image_id, sampler_id, true, true, depth, + }; + + auto texture_itr = find_if(begin(caller.arguments), end(caller.arguments), + [image_id](const SPIRFunction::Parameter &p) { return p.id == image_id; }); + auto sampler_itr = find_if(begin(caller.arguments), end(caller.arguments), + [sampler_id](const SPIRFunction::Parameter &p) { return p.id == sampler_id; }); + + if (texture_itr != end(caller.arguments)) + { + param.global_image = false; + param.image_id = uint32_t(texture_itr - begin(caller.arguments)); + } + + if (sampler_itr != end(caller.arguments)) + { + param.global_sampler = false; + param.sampler_id = uint32_t(sampler_itr - begin(caller.arguments)); + } + + if (param.global_image && param.global_sampler) + return; + + auto itr = find_if(begin(caller.combined_parameters), end(caller.combined_parameters), + [¶m](const SPIRFunction::CombinedImageSamplerParameter &p) { + return param.image_id == p.image_id && param.sampler_id == p.sampler_id && + param.global_image == p.global_image && param.global_sampler == p.global_sampler; + }); + + if (itr == end(caller.combined_parameters)) + { + uint32_t id = compiler.ir.increase_bound_by(3); + auto type_id = id + 0; + auto ptr_type_id = id + 1; + auto combined_id = id + 2; + auto &base = compiler.expression_type(image_id); + auto &type = compiler.set(type_id, OpTypeSampledImage); + auto &ptr_type = compiler.set(ptr_type_id, OpTypePointer); + + type = base; + type.self = type_id; + type.basetype = SPIRType::SampledImage; + type.pointer = false; + type.storage = StorageClassGeneric; + type.image.depth = depth; + + ptr_type = type; + ptr_type.pointer = true; + ptr_type.storage = StorageClassUniformConstant; + ptr_type.parent_type = type_id; + + // Build new variable. + compiler.set(combined_id, ptr_type_id, StorageClassFunction, 0); + + // Inherit RelaxedPrecision. + // If any of OpSampledImage, underlying image or sampler are marked, inherit the decoration. + bool relaxed_precision = + compiler.has_decoration(sampler_id, DecorationRelaxedPrecision) || + compiler.has_decoration(image_id, DecorationRelaxedPrecision) || + (combined_module_id && compiler.has_decoration(combined_module_id, DecorationRelaxedPrecision)); + + if (relaxed_precision) + compiler.set_decoration(combined_id, DecorationRelaxedPrecision); + + param.id = combined_id; + + compiler.set_name(combined_id, + join("SPIRV_Cross_Combined", compiler.to_name(image_id), compiler.to_name(sampler_id))); + + caller.combined_parameters.push_back(param); + caller.shadow_arguments.push_back({ ptr_type_id, combined_id, 0u, 0u, true }); + } +} + +bool Compiler::DummySamplerForCombinedImageHandler::handle(Op opcode, const uint32_t *args, uint32_t length) +{ + if (need_dummy_sampler) + { + // No need to traverse further, we know the result. + return false; + } + + switch (opcode) + { + case OpLoad: + { + if (length < 3) + return false; + + uint32_t result_type = args[0]; + + auto &type = compiler.get(result_type); + bool separate_image = + type.basetype == SPIRType::Image && type.image.sampled == 1 && type.image.dim != DimBuffer; + + // If not separate image, don't bother. 
+ if (!separate_image) + return true; + + uint32_t id = args[1]; + uint32_t ptr = args[2]; + compiler.set(id, "", result_type, true); + compiler.register_read(id, ptr, true); + break; + } + + case OpImageFetch: + case OpImageQuerySizeLod: + case OpImageQuerySize: + case OpImageQueryLevels: + case OpImageQuerySamples: + { + // If we are fetching or querying LOD from a plain OpTypeImage, we must pre-combine with our dummy sampler. + auto *var = compiler.maybe_get_backing_variable(args[2]); + if (var) + { + auto &type = compiler.get(var->basetype); + if (type.basetype == SPIRType::Image && type.image.sampled == 1 && type.image.dim != DimBuffer) + need_dummy_sampler = true; + } + + break; + } + + case OpInBoundsAccessChain: + case OpAccessChain: + case OpPtrAccessChain: + { + if (length < 3) + return false; + + uint32_t result_type = args[0]; + auto &type = compiler.get(result_type); + bool separate_image = + type.basetype == SPIRType::Image && type.image.sampled == 1 && type.image.dim != DimBuffer; + if (!separate_image) + return true; + + uint32_t id = args[1]; + uint32_t ptr = args[2]; + compiler.set(id, "", result_type, true); + compiler.register_read(id, ptr, true); + + // Other backends might use SPIRAccessChain for this later. + compiler.ir.ids[id].set_allow_type_rewrite(); + break; + } + + default: + break; + } + + return true; +} + +bool Compiler::CombinedImageSamplerHandler::handle(Op opcode, const uint32_t *args, uint32_t length) +{ + // We need to figure out where samplers and images are loaded from, so do only the bare bones compilation we need. + bool is_fetch = false; + + switch (opcode) + { + case OpLoad: + { + if (length < 3) + return false; + + uint32_t result_type = args[0]; + + auto &type = compiler.get(result_type); + bool separate_image = type.basetype == SPIRType::Image && type.image.sampled == 1; + bool separate_sampler = type.basetype == SPIRType::Sampler; + + // If not separate image or sampler, don't bother. + if (!separate_image && !separate_sampler) + return true; + + uint32_t id = args[1]; + uint32_t ptr = args[2]; + compiler.set(id, "", result_type, true); + compiler.register_read(id, ptr, true); + return true; + } + + case OpInBoundsAccessChain: + case OpAccessChain: + case OpPtrAccessChain: + { + if (length < 3) + return false; + + // Technically, it is possible to have arrays of textures and arrays of samplers and combine them, but this becomes essentially + // impossible to implement, since we don't know which concrete sampler we are accessing. + // One potential way is to create a combinatorial explosion where N textures and M samplers are combined into N * M sampler2Ds, + // but this seems ridiculously complicated for a problem which is easy to work around. + // Checking access chains like this assumes we don't have samplers or textures inside uniform structs, but this makes no sense. + + uint32_t result_type = args[0]; + + auto &type = compiler.get(result_type); + bool separate_image = type.basetype == SPIRType::Image && type.image.sampled == 1; + bool separate_sampler = type.basetype == SPIRType::Sampler; + if (separate_sampler) + SPIRV_CROSS_THROW( + "Attempting to use arrays or structs of separate samplers. 
This is not possible to statically " + "remap to plain GLSL."); + + if (separate_image) + { + uint32_t id = args[1]; + uint32_t ptr = args[2]; + compiler.set(id, "", result_type, true); + compiler.register_read(id, ptr, true); + } + return true; + } + + case OpImageFetch: + case OpImageQuerySizeLod: + case OpImageQuerySize: + case OpImageQueryLevels: + case OpImageQuerySamples: + { + // If we are fetching from a plain OpTypeImage or querying LOD, we must pre-combine with our dummy sampler. + auto *var = compiler.maybe_get_backing_variable(args[2]); + if (!var) + return true; + + auto &type = compiler.get(var->basetype); + if (type.basetype == SPIRType::Image && type.image.sampled == 1 && type.image.dim != DimBuffer) + { + if (compiler.dummy_sampler_id == 0) + SPIRV_CROSS_THROW("texelFetch without sampler was found, but no dummy sampler has been created with " + "build_dummy_sampler_for_combined_images()."); + + // Do it outside. + is_fetch = true; + break; + } + + return true; + } + + case OpSampledImage: + // Do it outside. + break; + + default: + return true; + } + + // Registers sampler2D calls used in case they are parameters so + // that their callees know which combined image samplers to propagate down the call stack. + if (!functions.empty()) + { + auto &callee = *functions.top(); + if (callee.do_combined_parameters) + { + uint32_t image_id = args[2]; + + auto *image = compiler.maybe_get_backing_variable(image_id); + if (image) + image_id = image->self; + + uint32_t sampler_id = is_fetch ? compiler.dummy_sampler_id : args[3]; + auto *sampler = compiler.maybe_get_backing_variable(sampler_id); + if (sampler) + sampler_id = sampler->self; + + uint32_t combined_id = args[1]; + + auto &combined_type = compiler.get(args[0]); + register_combined_image_sampler(callee, combined_id, image_id, sampler_id, combined_type.image.depth); + } + } + + // For function calls, we need to remap IDs which are function parameters into global variables. + // This information is statically known from the current place in the call stack. + // Function parameters are not necessarily pointers, so if we don't have a backing variable, remapping will know + // which backing variable the image/sample came from. + VariableID image_id = remap_parameter(args[2]); + VariableID sampler_id = is_fetch ? compiler.dummy_sampler_id : remap_parameter(args[3]); + + auto itr = find_if(begin(compiler.combined_image_samplers), end(compiler.combined_image_samplers), + [image_id, sampler_id](const CombinedImageSampler &combined) { + return combined.image_id == image_id && combined.sampler_id == sampler_id; + }); + + if (itr == end(compiler.combined_image_samplers)) + { + uint32_t sampled_type; + uint32_t combined_module_id; + if (is_fetch) + { + // Have to invent the sampled image type. + sampled_type = compiler.ir.increase_bound_by(1); + auto &type = compiler.set(sampled_type, OpTypeSampledImage); + type = compiler.expression_type(args[2]); + type.self = sampled_type; + type.basetype = SPIRType::SampledImage; + type.image.depth = false; + combined_module_id = 0; + } + else + { + sampled_type = args[0]; + combined_module_id = args[1]; + } + + auto id = compiler.ir.increase_bound_by(2); + auto type_id = id + 0; + auto combined_id = id + 1; + + // Make a new type, pointer to OpTypeSampledImage, so we can make a variable of this type. + // We will probably have this type lying around, but it doesn't hurt to make duplicates for internal purposes. 
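+ // The variable created here is what ends up in combined_image_samplers; API users typically
+ // enumerate these through get_combined_image_samplers() after build_combined_image_samplers()
+ // to assign names and binding decorations.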
+ auto &type = compiler.set(type_id, OpTypePointer); + auto &base = compiler.get(sampled_type); + type = base; + type.pointer = true; + type.storage = StorageClassUniformConstant; + type.parent_type = type_id; + + // Build new variable. + compiler.set(combined_id, type_id, StorageClassUniformConstant, 0); + + // Inherit RelaxedPrecision (and potentially other useful flags if deemed relevant). + // If any of OpSampledImage, underlying image or sampler are marked, inherit the decoration. + bool relaxed_precision = + (sampler_id && compiler.has_decoration(sampler_id, DecorationRelaxedPrecision)) || + (image_id && compiler.has_decoration(image_id, DecorationRelaxedPrecision)) || + (combined_module_id && compiler.has_decoration(combined_module_id, DecorationRelaxedPrecision)); + + if (relaxed_precision) + compiler.set_decoration(combined_id, DecorationRelaxedPrecision); + + // Propagate the array type for the original image as well. + auto *var = compiler.maybe_get_backing_variable(image_id); + if (var) + { + auto &parent_type = compiler.get(var->basetype); + type.array = parent_type.array; + type.array_size_literal = parent_type.array_size_literal; + } + + compiler.combined_image_samplers.push_back({ combined_id, image_id, sampler_id }); + } + + return true; +} + +VariableID Compiler::build_dummy_sampler_for_combined_images() +{ + DummySamplerForCombinedImageHandler handler(*this); + traverse_all_reachable_opcodes(get(ir.default_entry_point), handler); + if (handler.need_dummy_sampler) + { + uint32_t offset = ir.increase_bound_by(3); + auto type_id = offset + 0; + auto ptr_type_id = offset + 1; + auto var_id = offset + 2; + + auto &sampler = set(type_id, OpTypeSampler); + sampler.basetype = SPIRType::Sampler; + + auto &ptr_sampler = set(ptr_type_id, OpTypePointer); + ptr_sampler = sampler; + ptr_sampler.self = type_id; + ptr_sampler.storage = StorageClassUniformConstant; + ptr_sampler.pointer = true; + ptr_sampler.parent_type = type_id; + + set(var_id, ptr_type_id, StorageClassUniformConstant, 0); + set_name(var_id, "SPIRV_Cross_DummySampler"); + dummy_sampler_id = var_id; + return var_id; + } + else + return 0; +} + +void Compiler::build_combined_image_samplers() +{ + ir.for_each_typed_id([&](uint32_t, SPIRFunction &func) { + func.combined_parameters.clear(); + func.shadow_arguments.clear(); + func.do_combined_parameters = true; + }); + + combined_image_samplers.clear(); + CombinedImageSamplerHandler handler(*this); + traverse_all_reachable_opcodes(get(ir.default_entry_point), handler); +} + +SmallVector Compiler::get_specialization_constants() const +{ + SmallVector spec_consts; + ir.for_each_typed_id([&](uint32_t, const SPIRConstant &c) { + if (c.specialization && has_decoration(c.self, DecorationSpecId)) + spec_consts.push_back({ c.self, get_decoration(c.self, DecorationSpecId) }); + }); + return spec_consts; +} + +SPIRConstant &Compiler::get_constant(ConstantID id) +{ + return get(id); +} + +const SPIRConstant &Compiler::get_constant(ConstantID id) const +{ + return get(id); +} + +static bool exists_unaccessed_path_to_return(const CFG &cfg, uint32_t block, const unordered_set &blocks, + unordered_set &visit_cache) +{ + // This block accesses the variable. + if (blocks.find(block) != end(blocks)) + return false; + + // We are at the end of the CFG. + if (cfg.get_succeeding_edges(block).empty()) + return true; + + // If any of our successors have a path to the end, there exists a path from block. 
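+ // Depth-first walk with memoization: a successor is only added to visit_cache once its entire
+ // subtree has been explored without finding such a path, so the cache can never hide a
+ // positive result.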
+ for (auto &succ : cfg.get_succeeding_edges(block)) + { + if (visit_cache.count(succ) == 0) + { + if (exists_unaccessed_path_to_return(cfg, succ, blocks, visit_cache)) + return true; + visit_cache.insert(succ); + } + } + + return false; +} + +void Compiler::analyze_parameter_preservation( + SPIRFunction &entry, const CFG &cfg, const unordered_map> &variable_to_blocks, + const unordered_map> &complete_write_blocks) +{ + for (auto &arg : entry.arguments) + { + // Non-pointers are always inputs. + auto &type = get(arg.type); + if (!type.pointer) + continue; + + // Opaque argument types are always in + bool potential_preserve; + switch (type.basetype) + { + case SPIRType::Sampler: + case SPIRType::Image: + case SPIRType::SampledImage: + case SPIRType::AtomicCounter: + potential_preserve = false; + break; + + default: + potential_preserve = true; + break; + } + + if (!potential_preserve) + continue; + + auto itr = variable_to_blocks.find(arg.id); + if (itr == end(variable_to_blocks)) + { + // Variable is never accessed. + continue; + } + + // We have accessed a variable, but there was no complete writes to that variable. + // We deduce that we must preserve the argument. + itr = complete_write_blocks.find(arg.id); + if (itr == end(complete_write_blocks)) + { + arg.read_count++; + continue; + } + + // If there is a path through the CFG where no block completely writes to the variable, the variable will be in an undefined state + // when the function returns. We therefore need to implicitly preserve the variable in case there are writers in the function. + // Major case here is if a function is + // void foo(int &var) { if (cond) var = 10; } + // Using read/write counts, we will think it's just an out variable, but it really needs to be inout, + // because if we don't write anything whatever we put into the function must return back to the caller. + unordered_set visit_cache; + if (exists_unaccessed_path_to_return(cfg, entry.entry_block, itr->second, visit_cache)) + arg.read_count++; + } +} + +Compiler::AnalyzeVariableScopeAccessHandler::AnalyzeVariableScopeAccessHandler(Compiler &compiler_, + SPIRFunction &entry_) + : compiler(compiler_) + , entry(entry_) +{ +} + +bool Compiler::AnalyzeVariableScopeAccessHandler::follow_function_call(const SPIRFunction &) +{ + // Only analyze within this function. + return false; +} + +void Compiler::AnalyzeVariableScopeAccessHandler::set_current_block(const SPIRBlock &block) +{ + current_block = █ + + // If we're branching to a block which uses OpPhi, in GLSL + // this will be a variable write when we branch, + // so we need to track access to these variables as well to + // have a complete picture. + const auto test_phi = [this, &block](uint32_t to) { + auto &next = compiler.get(to); + for (auto &phi : next.phi_variables) + { + if (phi.parent == block.self) + { + accessed_variables_to_block[phi.function_variable].insert(block.self); + // Phi variables are also accessed in our target branch block. 
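+ // E.g. branching to a block containing "OpPhi %x ... %val %this_block" acts, in structured
+ // output, like storing %val to the phi's backing variable at the end of this block, so both the
+ // source block and the target block count as accesses of that variable.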
+ accessed_variables_to_block[phi.function_variable].insert(next.self); + + notify_variable_access(phi.local_variable, block.self); + } + } + }; + + switch (block.terminator) + { + case SPIRBlock::Direct: + notify_variable_access(block.condition, block.self); + test_phi(block.next_block); + break; + + case SPIRBlock::Select: + notify_variable_access(block.condition, block.self); + test_phi(block.true_block); + test_phi(block.false_block); + break; + + case SPIRBlock::MultiSelect: + { + notify_variable_access(block.condition, block.self); + auto &cases = compiler.get_case_list(block); + for (auto &target : cases) + test_phi(target.block); + if (block.default_block) + test_phi(block.default_block); + break; + } + + default: + break; + } +} + +void Compiler::AnalyzeVariableScopeAccessHandler::notify_variable_access(uint32_t id, uint32_t block) +{ + if (id == 0) + return; + + // Access chains used in multiple blocks mean hoisting all the variables used to construct the access chain as not all backends can use pointers. + auto itr = rvalue_forward_children.find(id); + if (itr != end(rvalue_forward_children)) + for (auto child_id : itr->second) + notify_variable_access(child_id, block); + + if (id_is_phi_variable(id)) + accessed_variables_to_block[id].insert(block); + else if (id_is_potential_temporary(id)) + accessed_temporaries_to_block[id].insert(block); +} + +bool Compiler::AnalyzeVariableScopeAccessHandler::id_is_phi_variable(uint32_t id) const +{ + if (id >= compiler.get_current_id_bound()) + return false; + auto *var = compiler.maybe_get(id); + return var && var->phi_variable; +} + +bool Compiler::AnalyzeVariableScopeAccessHandler::id_is_potential_temporary(uint32_t id) const +{ + if (id >= compiler.get_current_id_bound()) + return false; + + // Temporaries are not created before we start emitting code. + return compiler.ir.ids[id].empty() || (compiler.ir.ids[id].get_type() == TypeExpression); +} + +bool Compiler::AnalyzeVariableScopeAccessHandler::handle_terminator(const SPIRBlock &block) +{ + switch (block.terminator) + { + case SPIRBlock::Return: + if (block.return_value) + notify_variable_access(block.return_value, block.self); + break; + + case SPIRBlock::Select: + case SPIRBlock::MultiSelect: + notify_variable_access(block.condition, block.self); + break; + + default: + break; + } + + return true; +} + +bool Compiler::AnalyzeVariableScopeAccessHandler::handle(spv::Op op, const uint32_t *args, uint32_t length) +{ + // Keep track of the types of temporaries, so we can hoist them out as necessary. + uint32_t result_type = 0, result_id = 0; + if (compiler.instruction_to_result_type(result_type, result_id, op, args, length)) + { + // For some opcodes, we will need to override the result id. + // If we need to hoist the temporary, the temporary type is the input, not the result. + if (op == OpConvertUToAccelerationStructureKHR) + { + auto itr = result_id_to_type.find(args[2]); + if (itr != result_id_to_type.end()) + result_type = itr->second; + } + + result_id_to_type[result_id] = result_type; + } + + switch (op) + { + case OpStore: + { + if (length < 2) + return false; + + ID ptr = args[0]; + auto *var = compiler.maybe_get_backing_variable(ptr); + + // If we store through an access chain, we have a partial write. 
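+ // (ptr == var->self means we store directly to the OpVariable and replace the whole value;
+ // a store through an access chain only overwrites a member or element, so it must not be
+ // treated as a complete write by the dominance/preservation analysis.)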
+ if (var) + { + accessed_variables_to_block[var->self].insert(current_block->self); + if (var->self == ptr) + complete_write_variables_to_block[var->self].insert(current_block->self); + else + partial_write_variables_to_block[var->self].insert(current_block->self); + } + + // args[0] might be an access chain we have to track use of. + notify_variable_access(args[0], current_block->self); + // Might try to store a Phi variable here. + notify_variable_access(args[1], current_block->self); + break; + } + + case OpAccessChain: + case OpInBoundsAccessChain: + case OpPtrAccessChain: + { + if (length < 3) + return false; + + // Access chains used in multiple blocks mean hoisting all the variables used to construct the access chain as not all backends can use pointers. + uint32_t ptr = args[2]; + auto *var = compiler.maybe_get(ptr); + if (var) + { + accessed_variables_to_block[var->self].insert(current_block->self); + rvalue_forward_children[args[1]].insert(var->self); + } + + // args[2] might be another access chain we have to track use of. + for (uint32_t i = 2; i < length; i++) + { + notify_variable_access(args[i], current_block->self); + rvalue_forward_children[args[1]].insert(args[i]); + } + + // Also keep track of the access chain pointer itself. + // In exceptionally rare cases, we can end up with a case where + // the access chain is generated in the loop body, but is consumed in continue block. + // This means we need complex loop workarounds, and we must detect this via CFG analysis. + notify_variable_access(args[1], current_block->self); + + // The result of an access chain is a fixed expression and is not really considered a temporary. + auto &e = compiler.set(args[1], "", args[0], true); + auto *backing_variable = compiler.maybe_get_backing_variable(ptr); + e.loaded_from = backing_variable ? VariableID(backing_variable->self) : VariableID(0); + + // Other backends might use SPIRAccessChain for this later. + compiler.ir.ids[args[1]].set_allow_type_rewrite(); + access_chain_expressions.insert(args[1]); + break; + } + + case OpCopyMemory: + { + if (length < 2) + return false; + + ID lhs = args[0]; + ID rhs = args[1]; + auto *var = compiler.maybe_get_backing_variable(lhs); + + // If we store through an access chain, we have a partial write. + if (var) + { + accessed_variables_to_block[var->self].insert(current_block->self); + if (var->self == lhs) + complete_write_variables_to_block[var->self].insert(current_block->self); + else + partial_write_variables_to_block[var->self].insert(current_block->self); + } + + // args[0:1] might be access chains we have to track use of. + for (uint32_t i = 0; i < 2; i++) + notify_variable_access(args[i], current_block->self); + + var = compiler.maybe_get_backing_variable(rhs); + if (var) + accessed_variables_to_block[var->self].insert(current_block->self); + break; + } + + case OpCopyObject: + { + // OpCopyObject copies the underlying non-pointer type, + // so any temp variable should be declared using the underlying type. + // If the type is a pointer, get its base type and overwrite the result type mapping. + auto &type = compiler.get(result_type); + if (type.pointer) + result_id_to_type[result_id] = type.parent_type; + + if (length < 3) + return false; + + auto *var = compiler.maybe_get_backing_variable(args[2]); + if (var) + accessed_variables_to_block[var->self].insert(current_block->self); + + // Might be an access chain which we have to keep track of. 
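+ // A copy of an access chain is itself recorded as an access chain so the hoisting logic in
+ // analyze_variable_scope() treats the copied expression the same way as the original chain.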
+ notify_variable_access(args[1], current_block->self); + if (access_chain_expressions.count(args[2])) + access_chain_expressions.insert(args[1]); + + // Might try to copy a Phi variable here. + notify_variable_access(args[2], current_block->self); + break; + } + + case OpLoad: + { + if (length < 3) + return false; + uint32_t ptr = args[2]; + auto *var = compiler.maybe_get_backing_variable(ptr); + if (var) + accessed_variables_to_block[var->self].insert(current_block->self); + + // Loaded value is a temporary. + notify_variable_access(args[1], current_block->self); + + // Might be an access chain we have to track use of. + notify_variable_access(args[2], current_block->self); + + // If we're loading an opaque type we cannot lower it to a temporary, + // we must defer access of args[2] until it's used. + auto &type = compiler.get(args[0]); + if (compiler.type_is_opaque_value(type)) + rvalue_forward_children[args[1]].insert(args[2]); + break; + } + + case OpFunctionCall: + { + if (length < 3) + return false; + + // Return value may be a temporary. + if (compiler.get_type(args[0]).basetype != SPIRType::Void) + notify_variable_access(args[1], current_block->self); + + length -= 3; + args += 3; + + for (uint32_t i = 0; i < length; i++) + { + auto *var = compiler.maybe_get_backing_variable(args[i]); + if (var) + { + accessed_variables_to_block[var->self].insert(current_block->self); + // Assume we can get partial writes to this variable. + partial_write_variables_to_block[var->self].insert(current_block->self); + } + + // Cannot easily prove if argument we pass to a function is completely written. + // Usually, functions write to a dummy variable, + // which is then copied to in full to the real argument. + + // Might try to copy a Phi variable here. + notify_variable_access(args[i], current_block->self); + } + break; + } + + case OpSelect: + { + // In case of variable pointers, we might access a variable here. + // We cannot prove anything about these accesses however. + for (uint32_t i = 1; i < length; i++) + { + if (i >= 3) + { + auto *var = compiler.maybe_get_backing_variable(args[i]); + if (var) + { + accessed_variables_to_block[var->self].insert(current_block->self); + // Assume we can get partial writes to this variable. + partial_write_variables_to_block[var->self].insert(current_block->self); + } + } + + // Might try to copy a Phi variable here. + notify_variable_access(args[i], current_block->self); + } + break; + } + + case OpExtInst: + { + for (uint32_t i = 4; i < length; i++) + notify_variable_access(args[i], current_block->self); + notify_variable_access(args[1], current_block->self); + + uint32_t extension_set = args[2]; + if (compiler.get(extension_set).ext == SPIRExtension::GLSL) + { + auto op_450 = static_cast(args[3]); + switch (op_450) + { + case GLSLstd450Modf: + case GLSLstd450Frexp: + { + uint32_t ptr = args[5]; + auto *var = compiler.maybe_get_backing_variable(ptr); + if (var) + { + accessed_variables_to_block[var->self].insert(current_block->self); + if (var->self == ptr) + complete_write_variables_to_block[var->self].insert(current_block->self); + else + partial_write_variables_to_block[var->self].insert(current_block->self); + } + break; + } + + default: + break; + } + } + break; + } + + case OpArrayLength: + // Only result is a temporary. + notify_variable_access(args[1], current_block->self); + break; + + case OpLine: + case OpNoLine: + // Uses literals, but cannot be a phi variable or temporary, so ignore. 
+ break; + + // Atomics shouldn't be able to access function-local variables. + // Some GLSL builtins access a pointer. + + case OpCompositeInsert: + case OpVectorShuffle: + // Specialize for opcode which contains literals. + for (uint32_t i = 1; i < 4; i++) + notify_variable_access(args[i], current_block->self); + break; + + case OpCompositeExtract: + // Specialize for opcode which contains literals. + for (uint32_t i = 1; i < 3; i++) + notify_variable_access(args[i], current_block->self); + break; + + case OpImageWrite: + for (uint32_t i = 0; i < length; i++) + { + // Argument 3 is a literal. + if (i != 3) + notify_variable_access(args[i], current_block->self); + } + break; + + case OpImageSampleImplicitLod: + case OpImageSampleExplicitLod: + case OpImageSparseSampleImplicitLod: + case OpImageSparseSampleExplicitLod: + case OpImageSampleProjImplicitLod: + case OpImageSampleProjExplicitLod: + case OpImageSparseSampleProjImplicitLod: + case OpImageSparseSampleProjExplicitLod: + case OpImageFetch: + case OpImageSparseFetch: + case OpImageRead: + case OpImageSparseRead: + for (uint32_t i = 1; i < length; i++) + { + // Argument 4 is a literal. + if (i != 4) + notify_variable_access(args[i], current_block->self); + } + break; + + case OpImageSampleDrefImplicitLod: + case OpImageSampleDrefExplicitLod: + case OpImageSparseSampleDrefImplicitLod: + case OpImageSparseSampleDrefExplicitLod: + case OpImageSampleProjDrefImplicitLod: + case OpImageSampleProjDrefExplicitLod: + case OpImageSparseSampleProjDrefImplicitLod: + case OpImageSparseSampleProjDrefExplicitLod: + case OpImageGather: + case OpImageSparseGather: + case OpImageDrefGather: + case OpImageSparseDrefGather: + for (uint32_t i = 1; i < length; i++) + { + // Argument 5 is a literal. + if (i != 5) + notify_variable_access(args[i], current_block->self); + } + break; + + default: + { + // Rather dirty way of figuring out where Phi variables are used. + // As long as only IDs are used, we can scan through instructions and try to find any evidence that + // the ID of a variable has been used. + // There are potential false positives here where a literal is used in-place of an ID, + // but worst case, it does not affect the correctness of the compile. + // Exhaustive analysis would be better here, but it's not worth it for now. + for (uint32_t i = 0; i < length; i++) + notify_variable_access(args[i], current_block->self); + break; + } + } + return true; +} + +Compiler::StaticExpressionAccessHandler::StaticExpressionAccessHandler(Compiler &compiler_, uint32_t variable_id_) + : compiler(compiler_) + , variable_id(variable_id_) +{ +} + +bool Compiler::StaticExpressionAccessHandler::follow_function_call(const SPIRFunction &) +{ + return false; +} + +bool Compiler::StaticExpressionAccessHandler::handle(spv::Op op, const uint32_t *args, uint32_t length) +{ + switch (op) + { + case OpStore: + if (length < 2) + return false; + if (args[0] == variable_id) + { + static_expression = args[1]; + write_count++; + } + break; + + case OpLoad: + if (length < 3) + return false; + if (args[2] == variable_id && static_expression == 0) // Tried to read from variable before it was initialized. + return false; + break; + + case OpAccessChain: + case OpInBoundsAccessChain: + case OpPtrAccessChain: + if (length < 3) + return false; + if (args[2] == variable_id) // If we try to access chain our candidate variable before we store to it, bail. 
+ return false; + break; + + default: + break; + } + + return true; +} + +void Compiler::find_function_local_luts(SPIRFunction &entry, const AnalyzeVariableScopeAccessHandler &handler, + bool single_function) +{ + auto &cfg = *function_cfgs.find(entry.self)->second; + + // For each variable which is statically accessed. + for (auto &accessed_var : handler.accessed_variables_to_block) + { + auto &blocks = accessed_var.second; + auto &var = get(accessed_var.first); + auto &type = expression_type(accessed_var.first); + + // First check if there are writes to the variable. Later, if there are none, we'll + // reconsider it as globally accessed LUT. + if (!var.is_written_to) + { + var.is_written_to = handler.complete_write_variables_to_block.count(var.self) != 0 || + handler.partial_write_variables_to_block.count(var.self) != 0; + } + + // Only consider function local variables here. + // If we only have a single function in our CFG, private storage is also fine, + // since it behaves like a function local variable. + bool allow_lut = var.storage == StorageClassFunction || (single_function && var.storage == StorageClassPrivate); + if (!allow_lut) + continue; + + // We cannot be a phi variable. + if (var.phi_variable) + continue; + + // Only consider arrays here. + if (type.array.empty()) + continue; + + // If the variable has an initializer, make sure it is a constant expression. + uint32_t static_constant_expression = 0; + if (var.initializer) + { + if (ir.ids[var.initializer].get_type() != TypeConstant) + continue; + static_constant_expression = var.initializer; + + // There can be no stores to this variable, we have now proved we have a LUT. + if (var.is_written_to) + continue; + } + else + { + // We can have one, and only one write to the variable, and that write needs to be a constant. + + // No partial writes allowed. + if (handler.partial_write_variables_to_block.count(var.self) != 0) + continue; + + auto itr = handler.complete_write_variables_to_block.find(var.self); + + // No writes? + if (itr == end(handler.complete_write_variables_to_block)) + continue; + + // We write to the variable in more than one block. + auto &write_blocks = itr->second; + if (write_blocks.size() != 1) + continue; + + // The write needs to happen in the dominating block. + DominatorBuilder builder(cfg); + for (auto &block : blocks) + builder.add_block(block); + uint32_t dominator = builder.get_dominator(); + + // The complete write happened in a branch or similar, cannot deduce static expression. + if (write_blocks.count(dominator) == 0) + continue; + + // Find the static expression for this variable. + StaticExpressionAccessHandler static_expression_handler(*this, var.self); + traverse_all_reachable_opcodes(get(dominator), static_expression_handler); + + // We want one, and exactly one write + if (static_expression_handler.write_count != 1 || static_expression_handler.static_expression == 0) + continue; + + // Is it a constant expression? + if (ir.ids[static_expression_handler.static_expression].get_type() != TypeConstant) + continue; + + // We found a LUT! + static_constant_expression = static_expression_handler.static_expression; + } + + get(static_constant_expression).is_used_as_lut = true; + var.static_expression = static_constant_expression; + var.statically_assigned = true; + var.remapped_variable = true; + } +} + +void Compiler::analyze_variable_scope(SPIRFunction &entry, AnalyzeVariableScopeAccessHandler &handler) +{ + // First, we map out all variable access within a function. 
+ // Essentially a map of block -> { variables accessed in the basic block } + traverse_all_reachable_opcodes(entry, handler); + + auto &cfg = *function_cfgs.find(entry.self)->second; + + // Analyze if there are parameters which need to be implicitly preserved with an "in" qualifier. + analyze_parameter_preservation(entry, cfg, handler.accessed_variables_to_block, + handler.complete_write_variables_to_block); + + unordered_map potential_loop_variables; + + // Find the loop dominator block for each block. + for (auto &block_id : entry.blocks) + { + auto &block = get(block_id); + + auto itr = ir.continue_block_to_loop_header.find(block_id); + if (itr != end(ir.continue_block_to_loop_header) && itr->second != block_id) + { + // Continue block might be unreachable in the CFG, but we still like to know the loop dominator. + // Edge case is when continue block is also the loop header, don't set the dominator in this case. + block.loop_dominator = itr->second; + } + else + { + uint32_t loop_dominator = cfg.find_loop_dominator(block_id); + if (loop_dominator != block_id) + block.loop_dominator = loop_dominator; + else + block.loop_dominator = SPIRBlock::NoDominator; + } + } + + // For each variable which is statically accessed. + for (auto &var : handler.accessed_variables_to_block) + { + // Only deal with variables which are considered local variables in this function. + if (find(begin(entry.local_variables), end(entry.local_variables), VariableID(var.first)) == + end(entry.local_variables)) + continue; + + DominatorBuilder builder(cfg); + auto &blocks = var.second; + auto &type = expression_type(var.first); + BlockID potential_continue_block = 0; + + // Figure out which block is dominating all accesses of those variables. + for (auto &block : blocks) + { + // If we're accessing a variable inside a continue block, this variable might be a loop variable. + // We can only use loop variables with scalars, as we cannot track static expressions for vectors. + if (is_continue(block)) + { + // Potentially awkward case to check for. + // We might have a variable inside a loop, which is touched by the continue block, + // but is not actually a loop variable. + // The continue block is dominated by the inner part of the loop, which does not make sense in high-level + // language output because it will be declared before the body, + // so we will have to lift the dominator up to the relevant loop header instead. + builder.add_block(ir.continue_block_to_loop_header[block]); + + // Arrays or structs cannot be loop variables. + if (type.vecsize == 1 && type.columns == 1 && type.basetype != SPIRType::Struct && type.array.empty()) + { + // The variable is used in multiple continue blocks, this is not a loop + // candidate, signal that by setting block to -1u. + if (potential_continue_block == 0) + potential_continue_block = block; + else + potential_continue_block = ~(0u); + } + } + + builder.add_block(block); + } + + builder.lift_continue_block_dominator(); + + // Add it to a per-block list of variables. + BlockID dominating_block = builder.get_dominator(); + + if (dominating_block && potential_continue_block != 0 && potential_continue_block != ~0u) + { + auto &inner_block = get(dominating_block); + + BlockID merge_candidate = 0; + + // Analyze the dominator. If it lives in a different loop scope than the candidate continue + // block, reject the loop variable candidate. 
+ if (inner_block.merge == SPIRBlock::MergeLoop) + merge_candidate = inner_block.merge_block; + else if (inner_block.loop_dominator != SPIRBlock::NoDominator) + merge_candidate = get(inner_block.loop_dominator).merge_block; + + if (merge_candidate != 0 && cfg.is_reachable(merge_candidate)) + { + // If the merge block has a higher post-visit order, we know that continue candidate + // cannot reach the merge block, and we have two separate scopes. + if (!cfg.is_reachable(potential_continue_block) || + cfg.get_visit_order(merge_candidate) > cfg.get_visit_order(potential_continue_block)) + { + potential_continue_block = 0; + } + } + } + + if (potential_continue_block != 0 && potential_continue_block != ~0u) + potential_loop_variables[var.first] = potential_continue_block; + + // For variables whose dominating block is inside a loop, there is a risk that these variables + // actually need to be preserved across loop iterations. We can express this by adding + // a "read" access to the loop header. + // In the dominating block, we must see an OpStore or equivalent as the first access of an OpVariable. + // Should that fail, we look for the outermost loop header and tack on an access there. + // Phi nodes cannot have this problem. + if (dominating_block) + { + auto &variable = get(var.first); + if (!variable.phi_variable) + { + auto *block = &get(dominating_block); + bool preserve = may_read_undefined_variable_in_block(*block, var.first); + if (preserve) + { + // Find the outermost loop scope. + while (block->loop_dominator != BlockID(SPIRBlock::NoDominator)) + block = &get(block->loop_dominator); + + if (block->self != dominating_block) + { + builder.add_block(block->self); + dominating_block = builder.get_dominator(); + } + } + } + } + + // If all blocks here are dead code, this will be 0, so the variable in question + // will be completely eliminated. + if (dominating_block) + { + auto &block = get(dominating_block); + block.dominated_variables.push_back(var.first); + get(var.first).dominator = dominating_block; + } + } + + for (auto &var : handler.accessed_temporaries_to_block) + { + auto itr = handler.result_id_to_type.find(var.first); + + if (itr == end(handler.result_id_to_type)) + { + // We found a false positive ID being used, ignore. + // This should probably be an assert. + continue; + } + + // There is no point in doing domination analysis for opaque types. + auto &type = get(itr->second); + if (type_is_opaque_value(type)) + continue; + + DominatorBuilder builder(cfg); + bool force_temporary = false; + bool used_in_header_hoisted_continue_block = false; + + // Figure out which block is dominating all accesses of those temporaries. + auto &blocks = var.second; + for (auto &block : blocks) + { + builder.add_block(block); + + if (blocks.size() != 1 && is_continue(block)) + { + // The risk here is that inner loop can dominate the continue block. + // Any temporary we access in the continue block must be declared before the loop. + // This is moot for complex loops however. + auto &loop_header_block = get(ir.continue_block_to_loop_header[block]); + assert(loop_header_block.merge == SPIRBlock::MergeLoop); + builder.add_block(loop_header_block.self); + used_in_header_hoisted_continue_block = true; + } + } + + uint32_t dominating_block = builder.get_dominator(); + + if (blocks.size() != 1 && is_single_block_loop(dominating_block)) + { + // Awkward case, because the loop header is also the continue block, + // so hoisting to loop header does not help. 
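+ // This is the single-block loop shape where the header is its own continue target; the code
+ // further below compensates by forcing a proper hoisted temporary (or a complex loop in the
+ // access-chain case) instead.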
+ force_temporary = true; + } + + if (dominating_block) + { + // If we touch a variable in the dominating block, this is the expected setup. + // SPIR-V normally mandates this, but we have extra cases for temporary use inside loops. + bool first_use_is_dominator = blocks.count(dominating_block) != 0; + + if (!first_use_is_dominator || force_temporary) + { + if (handler.access_chain_expressions.count(var.first)) + { + // Exceptionally rare case. + // We cannot declare temporaries of access chains (except on MSL perhaps with pointers). + // Rather than do that, we force the indexing expressions to be declared in the right scope by + // tracking their usage to that end. There is no temporary to hoist. + // However, we still need to observe declaration order of the access chain. + + if (used_in_header_hoisted_continue_block) + { + // For this scenario, we used an access chain inside a continue block where we also registered an access to header block. + // This is a problem as we need to declare an access chain properly first with full definition. + // We cannot use temporaries for these expressions, + // so we must make sure the access chain is declared ahead of time. + // Force a complex for loop to deal with this. + // TODO: Out-of-order declaring for loops where continue blocks are emitted last might be another option. + auto &loop_header_block = get(dominating_block); + assert(loop_header_block.merge == SPIRBlock::MergeLoop); + loop_header_block.complex_continue = true; + } + } + else + { + // This should be very rare, but if we try to declare a temporary inside a loop, + // and that temporary is used outside the loop as well (spirv-opt inliner likes this) + // we should actually emit the temporary outside the loop. + hoisted_temporaries.insert(var.first); + forced_temporaries.insert(var.first); + + auto &block_temporaries = get(dominating_block).declare_temporary; + block_temporaries.emplace_back(handler.result_id_to_type[var.first], var.first); + } + } + else if (blocks.size() > 1) + { + // Keep track of the temporary as we might have to declare this temporary. + // This can happen if the loop header dominates a temporary, but we have a complex fallback loop. + // In this case, the header is actually inside the for (;;) {} block, and we have problems. + // What we need to do is hoist the temporaries outside the for (;;) {} block in case the header block + // declares the temporary. + auto &block_temporaries = get(dominating_block).potential_declare_temporary; + block_temporaries.emplace_back(handler.result_id_to_type[var.first], var.first); + } + } + } + + unordered_set seen_blocks; + + // Now, try to analyze whether or not these variables are actually loop variables. + for (auto &loop_variable : potential_loop_variables) + { + auto &var = get(loop_variable.first); + auto dominator = var.dominator; + BlockID block = loop_variable.second; + + // The variable was accessed in multiple continue blocks, ignore. + if (block == BlockID(~(0u)) || block == BlockID(0)) + continue; + + // Dead code. + if (dominator == ID(0)) + continue; + + BlockID header = 0; + + // Find the loop header for this block if we are a continue block. + { + auto itr = ir.continue_block_to_loop_header.find(block); + if (itr != end(ir.continue_block_to_loop_header)) + { + header = itr->second; + } + else if (get(block).continue_block == block) + { + // Also check for self-referential continue block. 
+ header = block; + } + } + + assert(header); + auto &header_block = get(header); + auto &blocks = handler.accessed_variables_to_block[loop_variable.first]; + + // If a loop variable is not used before the loop, it's probably not a loop variable. + bool has_accessed_variable = blocks.count(header) != 0; + + // Now, there are two conditions we need to meet for the variable to be a loop variable. + // 1. The dominating block must have a branch-free path to the loop header, + // this way we statically know which expression should be part of the loop variable initializer. + + // Walk from the dominator, if there is one straight edge connecting + // dominator and loop header, we statically know the loop initializer. + bool static_loop_init = true; + while (dominator != header) + { + if (blocks.count(dominator) != 0) + has_accessed_variable = true; + + auto &succ = cfg.get_succeeding_edges(dominator); + if (succ.size() != 1) + { + static_loop_init = false; + break; + } + + auto &pred = cfg.get_preceding_edges(succ.front()); + if (pred.size() != 1 || pred.front() != dominator) + { + static_loop_init = false; + break; + } + + dominator = succ.front(); + } + + if (!static_loop_init || !has_accessed_variable) + continue; + + // The second condition we need to meet is that no access after the loop + // merge can occur. Walk the CFG to see if we find anything. + + seen_blocks.clear(); + cfg.walk_from(seen_blocks, header_block.merge_block, [&](uint32_t walk_block) -> bool { + // We found a block which accesses the variable outside the loop. + if (blocks.find(walk_block) != end(blocks)) + static_loop_init = false; + return true; + }); + + if (!static_loop_init) + continue; + + // We have a loop variable. + header_block.loop_variables.push_back(loop_variable.first); + // Need to sort here as variables come from an unordered container, and pushing stuff in wrong order + // will break reproducability in regression runs. + sort(begin(header_block.loop_variables), end(header_block.loop_variables)); + get(loop_variable.first).loop_variable = true; + } +} + +bool Compiler::may_read_undefined_variable_in_block(const SPIRBlock &block, uint32_t var) +{ + for (auto &op : block.ops) + { + auto *ops = stream(op); + switch (op.op) + { + case OpStore: + case OpCopyMemory: + if (ops[0] == var) + return false; + break; + + case OpAccessChain: + case OpInBoundsAccessChain: + case OpPtrAccessChain: + // Access chains are generally used to partially read and write. It's too hard to analyze + // if all constituents are written fully before continuing, so just assume it's preserved. + // This is the same as the parameter preservation analysis. + if (ops[2] == var) + return true; + break; + + case OpSelect: + // Variable pointers. + // We might read before writing. + if (ops[3] == var || ops[4] == var) + return true; + break; + + case OpPhi: + { + // Variable pointers. + // We might read before writing. + if (op.length < 2) + break; + + uint32_t count = op.length - 2; + for (uint32_t i = 0; i < count; i += 2) + if (ops[i + 2] == var) + return true; + break; + } + + case OpCopyObject: + case OpLoad: + if (ops[2] == var) + return true; + break; + + case OpFunctionCall: + { + if (op.length < 3) + break; + + // May read before writing. + uint32_t count = op.length - 3; + for (uint32_t i = 0; i < count; i++) + if (ops[i + 3] == var) + return true; + break; + } + + default: + break; + } + } + + // Not accessed somehow, at least not in a usual fashion. + // It's likely accessed in a branch, so assume we must preserve. 
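+ // Conservative default. Illustrative example (not from this file):
+ //   int v; for (;;) { if (cond) v = 10; use(v); }
+ // The store happens in a nested branch rather than in the dominating block itself, so the
+ // value may need to survive from a previous iteration and must be preserved.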
+ return true; +} + +Bitset Compiler::get_buffer_block_flags(VariableID id) const +{ + return ir.get_buffer_block_flags(get(id)); +} + +bool Compiler::get_common_basic_type(const SPIRType &type, SPIRType::BaseType &base_type) +{ + if (type.basetype == SPIRType::Struct) + { + base_type = SPIRType::Unknown; + for (auto &member_type : type.member_types) + { + SPIRType::BaseType member_base; + if (!get_common_basic_type(get(member_type), member_base)) + return false; + + if (base_type == SPIRType::Unknown) + base_type = member_base; + else if (base_type != member_base) + return false; + } + return true; + } + else + { + base_type = type.basetype; + return true; + } +} + +void Compiler::ActiveBuiltinHandler::handle_builtin(const SPIRType &type, BuiltIn builtin, + const Bitset &decoration_flags) +{ + // If used, we will need to explicitly declare a new array size for these builtins. + + if (builtin == BuiltInClipDistance) + { + if (!type.array_size_literal[0]) + SPIRV_CROSS_THROW("Array size for ClipDistance must be a literal."); + uint32_t array_size = type.array[0]; + if (array_size == 0) + SPIRV_CROSS_THROW("Array size for ClipDistance must not be unsized."); + compiler.clip_distance_count = array_size; + } + else if (builtin == BuiltInCullDistance) + { + if (!type.array_size_literal[0]) + SPIRV_CROSS_THROW("Array size for CullDistance must be a literal."); + uint32_t array_size = type.array[0]; + if (array_size == 0) + SPIRV_CROSS_THROW("Array size for CullDistance must not be unsized."); + compiler.cull_distance_count = array_size; + } + else if (builtin == BuiltInPosition) + { + if (decoration_flags.get(DecorationInvariant)) + compiler.position_invariant = true; + } +} + +void Compiler::ActiveBuiltinHandler::add_if_builtin(uint32_t id, bool allow_blocks) +{ + // Only handle plain variables here. + // Builtins which are part of a block are handled in AccessChain. + // If allow_blocks is used however, this is to handle initializers of blocks, + // which implies that all members are written to. + + auto *var = compiler.maybe_get(id); + auto *m = compiler.ir.find_meta(id); + if (var && m) + { + auto &type = compiler.get(var->basetype); + auto &decorations = m->decoration; + auto &flags = type.storage == StorageClassInput ? 
+ compiler.active_input_builtins : compiler.active_output_builtins; + if (decorations.builtin) + { + flags.set(decorations.builtin_type); + handle_builtin(type, decorations.builtin_type, decorations.decoration_flags); + } + else if (allow_blocks && compiler.has_decoration(type.self, DecorationBlock)) + { + uint32_t member_count = uint32_t(type.member_types.size()); + for (uint32_t i = 0; i < member_count; i++) + { + if (compiler.has_member_decoration(type.self, i, DecorationBuiltIn)) + { + auto &member_type = compiler.get(type.member_types[i]); + BuiltIn builtin = BuiltIn(compiler.get_member_decoration(type.self, i, DecorationBuiltIn)); + flags.set(builtin); + handle_builtin(member_type, builtin, compiler.get_member_decoration_bitset(type.self, i)); + } + } + } + } +} + +void Compiler::ActiveBuiltinHandler::add_if_builtin(uint32_t id) +{ + add_if_builtin(id, false); +} + +void Compiler::ActiveBuiltinHandler::add_if_builtin_or_block(uint32_t id) +{ + add_if_builtin(id, true); +} + +bool Compiler::ActiveBuiltinHandler::handle(spv::Op opcode, const uint32_t *args, uint32_t length) +{ + switch (opcode) + { + case OpStore: + if (length < 1) + return false; + + add_if_builtin(args[0]); + break; + + case OpCopyMemory: + if (length < 2) + return false; + + add_if_builtin(args[0]); + add_if_builtin(args[1]); + break; + + case OpCopyObject: + case OpLoad: + if (length < 3) + return false; + + add_if_builtin(args[2]); + break; + + case OpSelect: + if (length < 5) + return false; + + add_if_builtin(args[3]); + add_if_builtin(args[4]); + break; + + case OpPhi: + { + if (length < 2) + return false; + + uint32_t count = length - 2; + args += 2; + for (uint32_t i = 0; i < count; i += 2) + add_if_builtin(args[i]); + break; + } + + case OpFunctionCall: + { + if (length < 3) + return false; + + uint32_t count = length - 3; + args += 3; + for (uint32_t i = 0; i < count; i++) + add_if_builtin(args[i]); + break; + } + + case OpAccessChain: + case OpInBoundsAccessChain: + case OpPtrAccessChain: + { + if (length < 4) + return false; + + // Only consider global variables, cannot consider variables in functions yet, or other + // access chains as they have not been created yet. + auto *var = compiler.maybe_get(args[2]); + if (!var) + break; + + // Required if we access chain into builtins like gl_GlobalInvocationID. + add_if_builtin(args[2]); + + // Start traversing type hierarchy at the proper non-pointer types. + auto *type = &compiler.get_variable_data_type(*var); + + auto &flags = + var->storage == StorageClassInput ? compiler.active_input_builtins : compiler.active_output_builtins; + + uint32_t count = length - 3; + args += 3; + for (uint32_t i = 0; i < count; i++) + { + // Pointers + // PtrAccessChain functions more like a pointer offset. Type remains the same. + if (opcode == OpPtrAccessChain && i == 0) + continue; + + // Arrays + if (!type->array.empty()) + { + type = &compiler.get(type->parent_type); + } + // Structs + else if (type->basetype == SPIRType::Struct) + { + uint32_t index = compiler.get(args[i]).scalar(); + + if (index < uint32_t(compiler.ir.meta[type->self].members.size())) + { + auto &decorations = compiler.ir.meta[type->self].members[index]; + if (decorations.builtin) + { + flags.set(decorations.builtin_type); + handle_builtin(compiler.get(type->member_types[index]), decorations.builtin_type, + decorations.decoration_flags); + } + } + + type = &compiler.get(type->member_types[index]); + } + else + { + // No point in traversing further. We won't find any extra builtins. 
+ break; + } + } + break; + } + + default: + break; + } + + return true; +} + +void Compiler::update_active_builtins() +{ + active_input_builtins.reset(); + active_output_builtins.reset(); + cull_distance_count = 0; + clip_distance_count = 0; + ActiveBuiltinHandler handler(*this); + traverse_all_reachable_opcodes(get(ir.default_entry_point), handler); + + ir.for_each_typed_id([&](uint32_t, const SPIRVariable &var) { + if (var.storage != StorageClassOutput) + return; + if (!interface_variable_exists_in_entry_point(var.self)) + return; + + // Also, make sure we preserve output variables which are only initialized, but never accessed by any code. + if (var.initializer != ID(0)) + handler.add_if_builtin_or_block(var.self); + }); +} + +// Returns whether this shader uses a builtin of the storage class +bool Compiler::has_active_builtin(BuiltIn builtin, StorageClass storage) const +{ + const Bitset *flags; + switch (storage) + { + case StorageClassInput: + flags = &active_input_builtins; + break; + case StorageClassOutput: + flags = &active_output_builtins; + break; + + default: + return false; + } + return flags->get(builtin); +} + +void Compiler::analyze_image_and_sampler_usage() +{ + CombinedImageSamplerDrefHandler dref_handler(*this); + traverse_all_reachable_opcodes(get(ir.default_entry_point), dref_handler); + + CombinedImageSamplerUsageHandler handler(*this, dref_handler.dref_combined_samplers); + traverse_all_reachable_opcodes(get(ir.default_entry_point), handler); + + // Need to run this traversal twice. First time, we propagate any comparison sampler usage from leaf functions + // down to main(). + // In the second pass, we can propagate up forced depth state coming from main() up into leaf functions. + handler.dependency_hierarchy.clear(); + traverse_all_reachable_opcodes(get(ir.default_entry_point), handler); + + comparison_ids = std::move(handler.comparison_ids); + need_subpass_input = handler.need_subpass_input; + need_subpass_input_ms = handler.need_subpass_input_ms; + + // Forward information from separate images and samplers into combined image samplers. + for (auto &combined : combined_image_samplers) + if (comparison_ids.count(combined.sampler_id)) + comparison_ids.insert(combined.combined_id); +} + +bool Compiler::CombinedImageSamplerDrefHandler::handle(spv::Op opcode, const uint32_t *args, uint32_t) +{ + // Mark all sampled images which are used with Dref. 
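+ // The IDs collected here seed CombinedImageSamplerUsageHandler, which propagates the
+ // comparison state down to the underlying image and sampler variables; backends can then emit
+ // depth/comparison sampler types where the target language distinguishes them.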
+ switch (opcode) + { + case OpImageSampleDrefExplicitLod: + case OpImageSampleDrefImplicitLod: + case OpImageSampleProjDrefExplicitLod: + case OpImageSampleProjDrefImplicitLod: + case OpImageSparseSampleProjDrefImplicitLod: + case OpImageSparseSampleDrefImplicitLod: + case OpImageSparseSampleProjDrefExplicitLod: + case OpImageSparseSampleDrefExplicitLod: + case OpImageDrefGather: + case OpImageSparseDrefGather: + dref_combined_samplers.insert(args[2]); + return true; + + default: + break; + } + + return true; +} + +const CFG &Compiler::get_cfg_for_current_function() const +{ + assert(current_function); + return get_cfg_for_function(current_function->self); +} + +const CFG &Compiler::get_cfg_for_function(uint32_t id) const +{ + auto cfg_itr = function_cfgs.find(id); + assert(cfg_itr != end(function_cfgs)); + assert(cfg_itr->second); + return *cfg_itr->second; +} + +void Compiler::build_function_control_flow_graphs_and_analyze() +{ + CFGBuilder handler(*this); + handler.function_cfgs[ir.default_entry_point].reset(new CFG(*this, get(ir.default_entry_point))); + traverse_all_reachable_opcodes(get(ir.default_entry_point), handler); + function_cfgs = std::move(handler.function_cfgs); + bool single_function = function_cfgs.size() <= 1; + + for (auto &f : function_cfgs) + { + auto &func = get(f.first); + AnalyzeVariableScopeAccessHandler scope_handler(*this, func); + analyze_variable_scope(func, scope_handler); + find_function_local_luts(func, scope_handler, single_function); + + // Check if we can actually use the loop variables we found in analyze_variable_scope. + // To use multiple initializers, we need the same type and qualifiers. + for (auto block : func.blocks) + { + auto &b = get(block); + if (b.loop_variables.size() < 2) + continue; + + auto &flags = get_decoration_bitset(b.loop_variables.front()); + uint32_t type = get(b.loop_variables.front()).basetype; + bool invalid_initializers = false; + for (auto loop_variable : b.loop_variables) + { + if (flags != get_decoration_bitset(loop_variable) || + type != get(b.loop_variables.front()).basetype) + { + invalid_initializers = true; + break; + } + } + + if (invalid_initializers) + { + for (auto loop_variable : b.loop_variables) + get(loop_variable).loop_variable = false; + b.loop_variables.clear(); + } + } + } + + // Find LUTs which are not function local. Only consider this case if the CFG is multi-function, + // otherwise we treat Private as Function trivially. + // Needs to be analyzed from the outside since we have to block the LUT optimization if at least + // one function writes to it. 
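+ // Typical shape (illustrative GLSL): "const int lut[4] = int[](1, 2, 3, 4);" at global scope,
+ // which typically reaches SPIR-V as a Private-storage array with a constant initializer. As
+ // long as no function writes to it, it is flagged below as a lookup table and can stay constant.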
+ if (!single_function) + { + for (auto &id : global_variables) + { + auto &var = get(id); + auto &type = get_variable_data_type(var); + + if (is_array(type) && var.storage == StorageClassPrivate && + var.initializer && !var.is_written_to && + ir.ids[var.initializer].get_type() == TypeConstant) + { + get(var.initializer).is_used_as_lut = true; + var.static_expression = var.initializer; + var.statically_assigned = true; + var.remapped_variable = true; + } + } + } +} + +Compiler::CFGBuilder::CFGBuilder(Compiler &compiler_) + : compiler(compiler_) +{ +} + +bool Compiler::CFGBuilder::handle(spv::Op, const uint32_t *, uint32_t) +{ + return true; +} + +bool Compiler::CFGBuilder::follow_function_call(const SPIRFunction &func) +{ + if (function_cfgs.find(func.self) == end(function_cfgs)) + { + function_cfgs[func.self].reset(new CFG(compiler, func)); + return true; + } + else + return false; +} + +void Compiler::CombinedImageSamplerUsageHandler::add_dependency(uint32_t dst, uint32_t src) +{ + dependency_hierarchy[dst].insert(src); + // Propagate up any comparison state if we're loading from one such variable. + if (comparison_ids.count(src)) + comparison_ids.insert(dst); +} + +bool Compiler::CombinedImageSamplerUsageHandler::begin_function_scope(const uint32_t *args, uint32_t length) +{ + if (length < 3) + return false; + + auto &func = compiler.get(args[2]); + const auto *arg = &args[3]; + length -= 3; + + for (uint32_t i = 0; i < length; i++) + { + auto &argument = func.arguments[i]; + add_dependency(argument.id, arg[i]); + } + + return true; +} + +void Compiler::CombinedImageSamplerUsageHandler::add_hierarchy_to_comparison_ids(uint32_t id) +{ + // Traverse the variable dependency hierarchy and tag everything in its path with comparison ids. + comparison_ids.insert(id); + + for (auto &dep_id : dependency_hierarchy[id]) + add_hierarchy_to_comparison_ids(dep_id); +} + +bool Compiler::CombinedImageSamplerUsageHandler::handle(Op opcode, const uint32_t *args, uint32_t length) +{ + switch (opcode) + { + case OpAccessChain: + case OpInBoundsAccessChain: + case OpPtrAccessChain: + case OpLoad: + { + if (length < 3) + return false; + + add_dependency(args[1], args[2]); + + // Ideally defer this to OpImageRead, but then we'd need to track loaded IDs. + // If we load an image, we're going to use it and there is little harm in declaring an unused gl_FragCoord. + auto &type = compiler.get(args[0]); + if (type.image.dim == DimSubpassData) + { + need_subpass_input = true; + if (type.image.ms) + need_subpass_input_ms = true; + } + + // If we load a SampledImage and it will be used with Dref, propagate the state up. + if (dref_combined_samplers.count(args[1]) != 0) + add_hierarchy_to_comparison_ids(args[1]); + break; + } + + case OpSampledImage: + { + if (length < 4) + return false; + + // If the underlying resource has been used for comparison then duplicate loads of that resource must be too. + // This image must be a depth image. + uint32_t result_id = args[1]; + uint32_t image = args[2]; + uint32_t sampler = args[3]; + + if (dref_combined_samplers.count(result_id) != 0) + { + add_hierarchy_to_comparison_ids(image); + + // This sampler must be a SamplerComparisonState, and not a regular SamplerState. + add_hierarchy_to_comparison_ids(sampler); + + // Mark the OpSampledImage itself as being comparison state. 
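+ // Backends can then query comparison state for the combined handle itself (see is_depth_image),
+ // not just for the underlying image and sampler variables.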
+ comparison_ids.insert(result_id); + } + return true; + } + + default: + break; + } + + return true; +} + +bool Compiler::buffer_is_hlsl_counter_buffer(VariableID id) const +{ + auto *m = ir.find_meta(id); + return m && m->hlsl_is_magic_counter_buffer; +} + +bool Compiler::buffer_get_hlsl_counter_buffer(VariableID id, uint32_t &counter_id) const +{ + auto *m = ir.find_meta(id); + + // First, check for the proper decoration. + if (m && m->hlsl_magic_counter_buffer != 0) + { + counter_id = m->hlsl_magic_counter_buffer; + return true; + } + else + return false; +} + +void Compiler::make_constant_null(uint32_t id, uint32_t type) +{ + auto &constant_type = get(type); + + if (constant_type.pointer) + { + auto &constant = set(id, type); + constant.make_null(constant_type); + } + else if (!constant_type.array.empty()) + { + assert(constant_type.parent_type); + uint32_t parent_id = ir.increase_bound_by(1); + make_constant_null(parent_id, constant_type.parent_type); + + if (!constant_type.array_size_literal.back()) + SPIRV_CROSS_THROW("Array size of OpConstantNull must be a literal."); + + SmallVector elements(constant_type.array.back()); + for (uint32_t i = 0; i < constant_type.array.back(); i++) + elements[i] = parent_id; + set(id, type, elements.data(), uint32_t(elements.size()), false); + } + else if (!constant_type.member_types.empty()) + { + uint32_t member_ids = ir.increase_bound_by(uint32_t(constant_type.member_types.size())); + SmallVector elements(constant_type.member_types.size()); + for (uint32_t i = 0; i < constant_type.member_types.size(); i++) + { + make_constant_null(member_ids + i, constant_type.member_types[i]); + elements[i] = member_ids + i; + } + set(id, type, elements.data(), uint32_t(elements.size()), false); + } + else + { + auto &constant = set(id, type); + constant.make_null(constant_type); + } +} + +const SmallVector &Compiler::get_declared_capabilities() const +{ + return ir.declared_capabilities; +} + +const SmallVector &Compiler::get_declared_extensions() const +{ + return ir.declared_extensions; +} + +std::string Compiler::get_remapped_declared_block_name(VariableID id) const +{ + return get_remapped_declared_block_name(id, false); +} + +std::string Compiler::get_remapped_declared_block_name(uint32_t id, bool fallback_prefer_instance_name) const +{ + auto itr = declared_block_names.find(id); + if (itr != end(declared_block_names)) + { + return itr->second; + } + else + { + auto &var = get(id); + + if (fallback_prefer_instance_name) + { + return to_name(var.self); + } + else + { + auto &type = get(var.basetype); + auto *type_meta = ir.find_meta(type.self); + auto *block_name = type_meta ? &type_meta->decoration.alias : nullptr; + return (!block_name || block_name->empty()) ? get_block_fallback_name(id) : *block_name; + } + } +} + +bool Compiler::reflection_ssbo_instance_name_is_significant() const +{ + if (ir.source.known) + { + // UAVs from HLSL source tend to be declared in a way where the type is reused + // but the instance name is significant, and that's the name we should report. + // For GLSL, SSBOs each have their own block type as that's how GLSL is written. + return ir.source.hlsl; + } + + unordered_set ssbo_type_ids; + bool aliased_ssbo_types = false; + + // If we don't have any OpSource information, we need to perform some shaky heuristics. 
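+ // The heuristic below: if two or more SSBO variables alias the same block type, the module looks
+ // HLSL-like (one struct type reused for several UAVs), so the instance name is significant.
+ // If every SSBO has its own block type, it looks like hand-written GLSL and the block name is used.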
+ ir.for_each_typed_id([&](uint32_t, const SPIRVariable &var) { + auto &type = this->get(var.basetype); + if (!type.pointer || var.storage == StorageClassFunction) + return; + + bool ssbo = var.storage == StorageClassStorageBuffer || + (var.storage == StorageClassUniform && has_decoration(type.self, DecorationBufferBlock)); + + if (ssbo) + { + if (ssbo_type_ids.count(type.self)) + aliased_ssbo_types = true; + else + ssbo_type_ids.insert(type.self); + } + }); + + // If the block name is aliased, assume we have HLSL-style UAV declarations. + return aliased_ssbo_types; +} + +bool Compiler::instruction_to_result_type(uint32_t &result_type, uint32_t &result_id, spv::Op op, + const uint32_t *args, uint32_t length) +{ + if (length < 2) + return false; + + bool has_result_id = false, has_result_type = false; + HasResultAndType(op, &has_result_id, &has_result_type); + if (has_result_id && has_result_type) + { + result_type = args[0]; + result_id = args[1]; + return true; + } + else + return false; +} + +Bitset Compiler::combined_decoration_for_member(const SPIRType &type, uint32_t index) const +{ + Bitset flags; + auto *type_meta = ir.find_meta(type.self); + + if (type_meta) + { + auto &members = type_meta->members; + if (index >= members.size()) + return flags; + auto &dec = members[index]; + + flags.merge_or(dec.decoration_flags); + + auto &member_type = get(type.member_types[index]); + + // If our member type is a struct, traverse all the child members as well recursively. + auto &member_childs = member_type.member_types; + for (uint32_t i = 0; i < member_childs.size(); i++) + { + auto &child_member_type = get(member_childs[i]); + if (!child_member_type.pointer) + flags.merge_or(combined_decoration_for_member(member_type, i)); + } + } + + return flags; +} + +bool Compiler::is_desktop_only_format(spv::ImageFormat format) +{ + switch (format) + { + // Desktop-only formats + case ImageFormatR11fG11fB10f: + case ImageFormatR16f: + case ImageFormatRgb10A2: + case ImageFormatR8: + case ImageFormatRg8: + case ImageFormatR16: + case ImageFormatRg16: + case ImageFormatRgba16: + case ImageFormatR16Snorm: + case ImageFormatRg16Snorm: + case ImageFormatRgba16Snorm: + case ImageFormatR8Snorm: + case ImageFormatRg8Snorm: + case ImageFormatR8ui: + case ImageFormatRg8ui: + case ImageFormatR16ui: + case ImageFormatRgb10a2ui: + case ImageFormatR8i: + case ImageFormatRg8i: + case ImageFormatR16i: + return true; + default: + break; + } + + return false; +} + +// An image is determined to be a depth image if it is marked as a depth image and is not also +// explicitly marked with a color format, or if there are any sample/gather compare operations on it. +bool Compiler::is_depth_image(const SPIRType &type, uint32_t id) const +{ + return (type.image.depth && type.image.format == ImageFormatUnknown) || comparison_ids.count(id); +} + +bool Compiler::type_is_opaque_value(const SPIRType &type) const +{ + return !type.pointer && (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Image || + type.basetype == SPIRType::Sampler); +} + +// Make these member functions so we can easily break on any force_recompile events. 
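+// Rough usage sketch (an assumption based on how these flags are consumed, not a quote of any
+// backend): a subclass' compile() loop re-emits until no pass requests a recompile, e.g.
+//   do { /* reset per-pass state */ emit_everything(); } while (is_forcing_recompilation());
+// so any point in emission that discovers a missing helper can call force_recompile() and retry.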
+void Compiler::force_recompile() +{ + is_force_recompile = true; +} + +void Compiler::force_recompile_guarantee_forward_progress() +{ + force_recompile(); + is_force_recompile_forward_progress = true; +} + +bool Compiler::is_forcing_recompilation() const +{ + return is_force_recompile; +} + +void Compiler::clear_force_recompile() +{ + is_force_recompile = false; + is_force_recompile_forward_progress = false; +} + +Compiler::PhysicalStorageBufferPointerHandler::PhysicalStorageBufferPointerHandler(Compiler &compiler_) + : compiler(compiler_) +{ +} + +Compiler::PhysicalBlockMeta *Compiler::PhysicalStorageBufferPointerHandler::find_block_meta(uint32_t id) const +{ + auto chain_itr = access_chain_to_physical_block.find(id); + if (chain_itr != access_chain_to_physical_block.end()) + return chain_itr->second; + else + return nullptr; +} + +void Compiler::PhysicalStorageBufferPointerHandler::mark_aligned_access(uint32_t id, const uint32_t *args, uint32_t length) +{ + uint32_t mask = *args; + args++; + length--; + if (length && (mask & MemoryAccessVolatileMask) != 0) + { + args++; + length--; + } + + if (length && (mask & MemoryAccessAlignedMask) != 0) + { + uint32_t alignment = *args; + auto *meta = find_block_meta(id); + + // This makes the assumption that the application does not rely on insane edge cases like: + // Bind buffer with ADDR = 8, use block offset of 8 bytes, load/store with 16 byte alignment. + // If we emit the buffer with alignment = 16 here, the first element at offset = 0 should + // actually have alignment of 8 bytes, but this is too theoretical and awkward to support. + // We could potentially keep track of any offset in the access chain, but it's + // practically impossible for high level compilers to emit code like that, + // so deducing overall alignment requirement based on maximum observed Alignment value is probably fine. + if (meta && alignment > meta->alignment) + meta->alignment = alignment; + } +} + +bool Compiler::PhysicalStorageBufferPointerHandler::type_is_bda_block_entry(uint32_t type_id) const +{ + auto &type = compiler.get(type_id); + return compiler.is_physical_pointer(type); +} + +uint32_t Compiler::PhysicalStorageBufferPointerHandler::get_minimum_scalar_alignment(const SPIRType &type) const +{ + if (type.storage == spv::StorageClassPhysicalStorageBufferEXT) + return 8; + else if (type.basetype == SPIRType::Struct) + { + uint32_t alignment = 0; + for (auto &member_type : type.member_types) + { + uint32_t member_align = get_minimum_scalar_alignment(compiler.get(member_type)); + if (member_align > alignment) + alignment = member_align; + } + return alignment; + } + else + return type.width / 8; +} + +void Compiler::PhysicalStorageBufferPointerHandler::setup_meta_chain(uint32_t type_id, uint32_t var_id) +{ + if (type_is_bda_block_entry(type_id)) + { + auto &meta = physical_block_type_meta[type_id]; + access_chain_to_physical_block[var_id] = &meta; + + auto &type = compiler.get(type_id); + + if (!compiler.is_physical_pointer_to_buffer_block(type)) + non_block_types.insert(type_id); + + if (meta.alignment == 0) + meta.alignment = get_minimum_scalar_alignment(compiler.get_pointee_type(type)); + } +} + +bool Compiler::PhysicalStorageBufferPointerHandler::handle(Op op, const uint32_t *args, uint32_t length) +{ + // When a BDA pointer comes to life, we need to keep a mapping of SSA ID -> type ID for the pointer type. + // For every load and store, we'll need to be able to look up the type ID being accessed and mark any alignment + // requirements. 
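+ // For reference, the memory operands parsed below look like this in SPIR-V disassembly:
+ //   %v = OpLoad %T %ptr Aligned 16   ; trailing words: the MemoryAccess mask, then the literal 16
+ // and the per-block alignment we deduce is simply the maximum Aligned literal ever observed.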
+ switch (op) + { + case OpConvertUToPtr: + case OpBitcast: + case OpCompositeExtract: + // Extract can begin a new chain if we had a struct or array of pointers as input. + // We don't begin chains before we have a pure scalar pointer. + setup_meta_chain(args[0], args[1]); + break; + + case OpAccessChain: + case OpInBoundsAccessChain: + case OpPtrAccessChain: + case OpCopyObject: + { + auto itr = access_chain_to_physical_block.find(args[2]); + if (itr != access_chain_to_physical_block.end()) + access_chain_to_physical_block[args[1]] = itr->second; + break; + } + + case OpLoad: + { + setup_meta_chain(args[0], args[1]); + if (length >= 4) + mark_aligned_access(args[2], args + 3, length - 3); + break; + } + + case OpStore: + { + if (length >= 3) + mark_aligned_access(args[0], args + 2, length - 2); + break; + } + + default: + break; + } + + return true; +} + +uint32_t Compiler::PhysicalStorageBufferPointerHandler::get_base_non_block_type_id(uint32_t type_id) const +{ + auto *type = &compiler.get(type_id); + while (compiler.is_physical_pointer(*type) && !type_is_bda_block_entry(type_id)) + { + type_id = type->parent_type; + type = &compiler.get(type_id); + } + + assert(type_is_bda_block_entry(type_id)); + return type_id; +} + +void Compiler::PhysicalStorageBufferPointerHandler::analyze_non_block_types_from_block(const SPIRType &type) +{ + for (auto &member : type.member_types) + { + auto &subtype = compiler.get(member); + + if (compiler.is_physical_pointer(subtype) && !compiler.is_physical_pointer_to_buffer_block(subtype)) + non_block_types.insert(get_base_non_block_type_id(member)); + else if (subtype.basetype == SPIRType::Struct && !compiler.is_pointer(subtype)) + analyze_non_block_types_from_block(subtype); + } +} + +void Compiler::analyze_non_block_pointer_types() +{ + PhysicalStorageBufferPointerHandler handler(*this); + traverse_all_reachable_opcodes(get(ir.default_entry_point), handler); + + // Analyze any block declaration we have to make. It might contain + // physical pointers to POD types which we never used, and thus never added to the list. + // We'll need to add those pointer types to the set of types we declare. + ir.for_each_typed_id([&](uint32_t id, SPIRType &type) { + // Only analyze the raw block struct, not any pointer-to-struct, since that's just redundant. + if (type.self == id && + (has_decoration(type.self, DecorationBlock) || + has_decoration(type.self, DecorationBufferBlock))) + { + handler.analyze_non_block_types_from_block(type); + } + }); + + physical_storage_non_block_pointer_types.reserve(handler.non_block_types.size()); + for (auto type : handler.non_block_types) + physical_storage_non_block_pointer_types.push_back(type); + sort(begin(physical_storage_non_block_pointer_types), end(physical_storage_non_block_pointer_types)); + physical_storage_type_to_alignment = std::move(handler.physical_block_type_meta); +} + +bool Compiler::InterlockedResourceAccessPrepassHandler::handle(Op op, const uint32_t *, uint32_t) +{ + if (op == OpBeginInvocationInterlockEXT || op == OpEndInvocationInterlockEXT) + { + if (interlock_function_id != 0 && interlock_function_id != call_stack.back()) + { + // Most complex case, we have no sensible way of dealing with this + // other than taking the 100% conservative approach, exit early. + split_function_case = true; + return false; + } + else + { + interlock_function_id = call_stack.back(); + // If this call is performed inside control flow we have a problem. 
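+ // e.g. (illustrative GLSL): "if (cond) beginInvocationInterlockARB();" means we cannot
+ // prove the critical section brackets every interlocked access, so the analysis below
+ // falls back to tagging the individual resources instead of relying on the section.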
+ auto &cfg = compiler.get_cfg_for_function(interlock_function_id); + + uint32_t from_block_id = compiler.get(interlock_function_id).entry_block; + bool outside_control_flow = cfg.node_terminates_control_flow_in_sub_graph(from_block_id, current_block_id); + if (!outside_control_flow) + control_flow_interlock = true; + } + } + return true; +} + +void Compiler::InterlockedResourceAccessPrepassHandler::rearm_current_block(const SPIRBlock &block) +{ + current_block_id = block.self; +} + +bool Compiler::InterlockedResourceAccessPrepassHandler::begin_function_scope(const uint32_t *args, uint32_t length) +{ + if (length < 3) + return false; + call_stack.push_back(args[2]); + return true; +} + +bool Compiler::InterlockedResourceAccessPrepassHandler::end_function_scope(const uint32_t *, uint32_t) +{ + call_stack.pop_back(); + return true; +} + +bool Compiler::InterlockedResourceAccessHandler::begin_function_scope(const uint32_t *args, uint32_t length) +{ + if (length < 3) + return false; + + if (args[2] == interlock_function_id) + call_stack_is_interlocked = true; + + call_stack.push_back(args[2]); + return true; +} + +bool Compiler::InterlockedResourceAccessHandler::end_function_scope(const uint32_t *, uint32_t) +{ + if (call_stack.back() == interlock_function_id) + call_stack_is_interlocked = false; + + call_stack.pop_back(); + return true; +} + +void Compiler::InterlockedResourceAccessHandler::access_potential_resource(uint32_t id) +{ + if ((use_critical_section && in_crit_sec) || (control_flow_interlock && call_stack_is_interlocked) || + split_function_case) + { + compiler.interlocked_resources.insert(id); + } +} + +bool Compiler::InterlockedResourceAccessHandler::handle(Op opcode, const uint32_t *args, uint32_t length) +{ + // Only care about critical section analysis if we have simple case. + if (use_critical_section) + { + if (opcode == OpBeginInvocationInterlockEXT) + { + in_crit_sec = true; + return true; + } + + if (opcode == OpEndInvocationInterlockEXT) + { + // End critical section--nothing more to do. + return false; + } + } + + // We need to figure out where images and buffers are loaded from, so do only the bare bones compilation we need. + switch (opcode) + { + case OpLoad: + { + if (length < 3) + return false; + + uint32_t ptr = args[2]; + auto *var = compiler.maybe_get_backing_variable(ptr); + + // We're only concerned with buffer and image memory here. + if (!var) + break; + + switch (var->storage) + { + default: + break; + + case StorageClassUniformConstant: + { + uint32_t result_type = args[0]; + uint32_t id = args[1]; + compiler.set(id, "", result_type, true); + compiler.register_read(id, ptr, true); + break; + } + + case StorageClassUniform: + // Must have BufferBlock; we only care about SSBOs. 
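+ // (BufferBlock on a Uniform-class variable is the pre-SPIR-V-1.3 encoding of an SSBO;
+ // newer modules use StorageClassStorageBuffer directly, handled by the case below.)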
+ if (!compiler.has_decoration(compiler.get(var->basetype).self, DecorationBufferBlock)) + break; + // fallthrough + case StorageClassStorageBuffer: + access_potential_resource(var->self); + break; + } + break; + } + + case OpInBoundsAccessChain: + case OpAccessChain: + case OpPtrAccessChain: + { + if (length < 3) + return false; + + uint32_t result_type = args[0]; + + auto &type = compiler.get(result_type); + if (type.storage == StorageClassUniform || type.storage == StorageClassUniformConstant || + type.storage == StorageClassStorageBuffer) + { + uint32_t id = args[1]; + uint32_t ptr = args[2]; + compiler.set(id, "", result_type, true); + compiler.register_read(id, ptr, true); + compiler.ir.ids[id].set_allow_type_rewrite(); + } + break; + } + + case OpImageTexelPointer: + { + if (length < 3) + return false; + + uint32_t result_type = args[0]; + uint32_t id = args[1]; + uint32_t ptr = args[2]; + auto &e = compiler.set(id, "", result_type, true); + auto *var = compiler.maybe_get_backing_variable(ptr); + if (var) + e.loaded_from = var->self; + break; + } + + case OpStore: + case OpImageWrite: + case OpAtomicStore: + { + if (length < 1) + return false; + + uint32_t ptr = args[0]; + auto *var = compiler.maybe_get_backing_variable(ptr); + if (var && (var->storage == StorageClassUniform || var->storage == StorageClassUniformConstant || + var->storage == StorageClassStorageBuffer)) + { + access_potential_resource(var->self); + } + + break; + } + + case OpCopyMemory: + { + if (length < 2) + return false; + + uint32_t dst = args[0]; + uint32_t src = args[1]; + auto *dst_var = compiler.maybe_get_backing_variable(dst); + auto *src_var = compiler.maybe_get_backing_variable(src); + + if (dst_var && (dst_var->storage == StorageClassUniform || dst_var->storage == StorageClassStorageBuffer)) + access_potential_resource(dst_var->self); + + if (src_var) + { + if (src_var->storage != StorageClassUniform && src_var->storage != StorageClassStorageBuffer) + break; + + if (src_var->storage == StorageClassUniform && + !compiler.has_decoration(compiler.get(src_var->basetype).self, DecorationBufferBlock)) + { + break; + } + + access_potential_resource(src_var->self); + } + + break; + } + + case OpImageRead: + case OpAtomicLoad: + { + if (length < 3) + return false; + + uint32_t ptr = args[2]; + auto *var = compiler.maybe_get_backing_variable(ptr); + + // We're only concerned with buffer and image memory here. + if (!var) + break; + + switch (var->storage) + { + default: + break; + + case StorageClassUniform: + // Must have BufferBlock; we only care about SSBOs. 
+ if (!compiler.has_decoration(compiler.get(var->basetype).self, DecorationBufferBlock)) + break; + // fallthrough + case StorageClassUniformConstant: + case StorageClassStorageBuffer: + access_potential_resource(var->self); + break; + } + break; + } + + case OpAtomicExchange: + case OpAtomicCompareExchange: + case OpAtomicIIncrement: + case OpAtomicIDecrement: + case OpAtomicIAdd: + case OpAtomicISub: + case OpAtomicSMin: + case OpAtomicUMin: + case OpAtomicSMax: + case OpAtomicUMax: + case OpAtomicAnd: + case OpAtomicOr: + case OpAtomicXor: + { + if (length < 3) + return false; + + uint32_t ptr = args[2]; + auto *var = compiler.maybe_get_backing_variable(ptr); + if (var && (var->storage == StorageClassUniform || var->storage == StorageClassUniformConstant || + var->storage == StorageClassStorageBuffer)) + { + access_potential_resource(var->self); + } + + break; + } + + default: + break; + } + + return true; +} + +void Compiler::analyze_interlocked_resource_usage() +{ + if (get_execution_model() == ExecutionModelFragment && + (get_entry_point().flags.get(ExecutionModePixelInterlockOrderedEXT) || + get_entry_point().flags.get(ExecutionModePixelInterlockUnorderedEXT) || + get_entry_point().flags.get(ExecutionModeSampleInterlockOrderedEXT) || + get_entry_point().flags.get(ExecutionModeSampleInterlockUnorderedEXT))) + { + InterlockedResourceAccessPrepassHandler prepass_handler(*this, ir.default_entry_point); + traverse_all_reachable_opcodes(get(ir.default_entry_point), prepass_handler); + + InterlockedResourceAccessHandler handler(*this, ir.default_entry_point); + handler.interlock_function_id = prepass_handler.interlock_function_id; + handler.split_function_case = prepass_handler.split_function_case; + handler.control_flow_interlock = prepass_handler.control_flow_interlock; + handler.use_critical_section = !handler.split_function_case && !handler.control_flow_interlock; + + traverse_all_reachable_opcodes(get(ir.default_entry_point), handler); + + // For GLSL. If we hit any of these cases, we have to fall back to conservative approach. + interlocked_is_complex = + !handler.use_critical_section || handler.interlock_function_id != ir.default_entry_point; + } +} + +// Helper function +bool Compiler::check_internal_recursion(const SPIRType &type, std::unordered_set &checked_ids) +{ + if (type.basetype != SPIRType::Struct) + return false; + + if (checked_ids.count(type.self)) + return true; + + // Recurse into struct members + bool is_recursive = false; + checked_ids.insert(type.self); + uint32_t mbr_cnt = uint32_t(type.member_types.size()); + for (uint32_t mbr_idx = 0; !is_recursive && mbr_idx < mbr_cnt; mbr_idx++) + { + uint32_t mbr_type_id = type.member_types[mbr_idx]; + auto &mbr_type = get(mbr_type_id); + is_recursive |= check_internal_recursion(mbr_type, checked_ids); + } + checked_ids.erase(type.self); + return is_recursive; +} + +// Return whether the struct type contains a structural recursion nested somewhere within its content. +bool Compiler::type_contains_recursion(const SPIRType &type) +{ + std::unordered_set checked_ids; + return check_internal_recursion(type, checked_ids); +} + +bool Compiler::type_is_array_of_pointers(const SPIRType &type) const +{ + if (!is_array(type)) + return false; + + // BDA types must have parent type hierarchy. + if (!type.parent_type) + return false; + + // Punch through all array layers. 
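+ // e.g. for an array-of-arrays of physical storage pointers, the parent chain is
+ // array -> array -> pointer, so keep following parent_type until the element type is reached.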
+ auto *parent = &get(type.parent_type); + while (is_array(*parent)) + parent = &get(parent->parent_type); + + return is_pointer(*parent); +} + +bool Compiler::flush_phi_required(BlockID from, BlockID to) const +{ + auto &child = get(to); + for (auto &phi : child.phi_variables) + if (phi.parent == from) + return true; + return false; +} + +void Compiler::add_loop_level() +{ + current_loop_level++; +} diff --git a/thirdparty/spirv-cross/spirv_cross.hpp b/thirdparty/spirv-cross/spirv_cross.hpp new file mode 100644 index 00000000000..e9062b485c6 --- /dev/null +++ b/thirdparty/spirv-cross/spirv_cross.hpp @@ -0,0 +1,1182 @@ +/* + * Copyright 2015-2021 Arm Limited + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . + */ + +#ifndef SPIRV_CROSS_HPP +#define SPIRV_CROSS_HPP + +#ifndef SPV_ENABLE_UTILITY_CODE +#define SPV_ENABLE_UTILITY_CODE +#endif +#include "spirv.hpp" +#include "spirv_cfg.hpp" +#include "spirv_cross_parsed_ir.hpp" + +namespace SPIRV_CROSS_NAMESPACE +{ +struct Resource +{ + // Resources are identified with their SPIR-V ID. + // This is the ID of the OpVariable. + ID id; + + // The type ID of the variable which includes arrays and all type modifications. + // This type ID is not suitable for parsing OpMemberDecoration of a struct and other decorations in general + // since these modifications typically happen on the base_type_id. + TypeID type_id; + + // The base type of the declared resource. + // This type is the base type which ignores pointers and arrays of the type_id. + // This is mostly useful to parse decorations of the underlying type. + // base_type_id can also be obtained with get_type(get_type(type_id).self). + TypeID base_type_id; + + // The declared name (OpName) of the resource. + // For Buffer blocks, the name actually reflects the externally + // visible Block name. + // + // This name can be retrieved again by using either + // get_name(id) or get_name(base_type_id) depending if it's a buffer block or not. + // + // This name can be an empty string in which case get_fallback_name(id) can be + // used which obtains a suitable fallback identifier for an ID. + std::string name; +}; + +struct BuiltInResource +{ + // This is mostly here to support reflection of builtins such as Position/PointSize/CullDistance/ClipDistance. + // This needs to be different from Resource since we can collect builtins from blocks. + // A builtin present here does not necessarily mean it's considered an active builtin, + // since variable ID "activeness" is only tracked on OpVariable level, not Block members. + // For that, update_active_builtins() -> has_active_builtin() can be used to further refine the reflection. + spv::BuiltIn builtin; + + // This is the actual value type of the builtin. + // Typically float4, float, array for the gl_PerVertex builtins. 
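+ // (For instance, gl_Position reflects as a float4 here while gl_ClipDistance reflects as a float array.)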
+ // If the builtin is a control point, the control point array type will be stripped away here as appropriate. + TypeID value_type_id; + + // This refers to the base resource which contains the builtin. + // If resource is a Block, it can hold multiple builtins, or it might not be a block. + // For advanced reflection scenarios, all information in builtin/value_type_id can be deduced, + // it's just more convenient this way. + Resource resource; +}; + +struct ShaderResources +{ + SmallVector uniform_buffers; + SmallVector storage_buffers; + SmallVector stage_inputs; + SmallVector stage_outputs; + SmallVector subpass_inputs; + SmallVector storage_images; + SmallVector sampled_images; + SmallVector atomic_counters; + SmallVector acceleration_structures; + SmallVector gl_plain_uniforms; + + // There can only be one push constant block, + // but keep the vector in case this restriction is lifted in the future. + SmallVector push_constant_buffers; + + SmallVector shader_record_buffers; + + // For Vulkan GLSL and HLSL source, + // these correspond to separate texture2D and samplers respectively. + SmallVector separate_images; + SmallVector separate_samplers; + + SmallVector builtin_inputs; + SmallVector builtin_outputs; +}; + +struct CombinedImageSampler +{ + // The ID of the sampler2D variable. + VariableID combined_id; + // The ID of the texture2D variable. + VariableID image_id; + // The ID of the sampler variable. + VariableID sampler_id; +}; + +struct SpecializationConstant +{ + // The ID of the specialization constant. + ConstantID id; + // The constant ID of the constant, used in Vulkan during pipeline creation. + uint32_t constant_id; +}; + +struct BufferRange +{ + unsigned index; + size_t offset; + size_t range; +}; + +enum BufferPackingStandard +{ + BufferPackingStd140, + BufferPackingStd430, + BufferPackingStd140EnhancedLayout, + BufferPackingStd430EnhancedLayout, + BufferPackingHLSLCbuffer, + BufferPackingHLSLCbufferPackOffset, + BufferPackingScalar, + BufferPackingScalarEnhancedLayout +}; + +struct EntryPoint +{ + std::string name; + spv::ExecutionModel execution_model; +}; + +class Compiler +{ +public: + friend class CFG; + friend class DominatorBuilder; + + // The constructor takes a buffer of SPIR-V words and parses it. + // It will create its own parser, parse the SPIR-V and move the parsed IR + // as if you had called the constructors taking ParsedIR directly. + explicit Compiler(std::vector ir); + Compiler(const uint32_t *ir, size_t word_count); + + // This is more modular. We can also consume a ParsedIR structure directly, either as a move, or copy. + // With copy, we can reuse the same parsed IR for multiple Compiler instances. + explicit Compiler(const ParsedIR &ir); + explicit Compiler(ParsedIR &&ir); + + virtual ~Compiler() = default; + + // After parsing, API users can modify the SPIR-V via reflection and call this + // to disassemble the SPIR-V into the desired langauage. + // Sub-classes actually implement this. + virtual std::string compile(); + + // Gets the identifier (OpName) of an ID. If not defined, an empty string will be returned. + const std::string &get_name(ID id) const; + + // Applies a decoration to an ID. Effectively injects OpDecorate. + void set_decoration(ID id, spv::Decoration decoration, uint32_t argument = 0); + void set_decoration_string(ID id, spv::Decoration decoration, const std::string &argument); + + // Overrides the identifier OpName of an ID. 
+ // Identifiers beginning with underscores or identifiers which contain double underscores + // are reserved by the implementation. + void set_name(ID id, const std::string &name); + + // Gets a bitmask for the decorations which are applied to ID. + // I.e. (1ull << spv::DecorationFoo) | (1ull << spv::DecorationBar) + const Bitset &get_decoration_bitset(ID id) const; + + // Returns whether the decoration has been applied to the ID. + bool has_decoration(ID id, spv::Decoration decoration) const; + + // Gets the value for decorations which take arguments. + // If the decoration is a boolean (i.e. spv::DecorationNonWritable), + // 1 will be returned. + // If decoration doesn't exist or decoration is not recognized, + // 0 will be returned. + uint32_t get_decoration(ID id, spv::Decoration decoration) const; + const std::string &get_decoration_string(ID id, spv::Decoration decoration) const; + + // Removes the decoration for an ID. + void unset_decoration(ID id, spv::Decoration decoration); + + // Gets the SPIR-V type associated with ID. + // Mostly used with Resource::type_id and Resource::base_type_id to parse the underlying type of a resource. + const SPIRType &get_type(TypeID id) const; + + // Gets the SPIR-V type of a variable. + const SPIRType &get_type_from_variable(VariableID id) const; + + // Gets the underlying storage class for an OpVariable. + spv::StorageClass get_storage_class(VariableID id) const; + + // If get_name() is an empty string, get the fallback name which will be used + // instead in the disassembled source. + virtual const std::string get_fallback_name(ID id) const; + + // If get_name() of a Block struct is an empty string, get the fallback name. + // This needs to be per-variable as multiple variables can use the same block type. + virtual const std::string get_block_fallback_name(VariableID id) const; + + // Given an OpTypeStruct in ID, obtain the identifier for member number "index". + // This may be an empty string. + const std::string &get_member_name(TypeID id, uint32_t index) const; + + // Given an OpTypeStruct in ID, obtain the OpMemberDecoration for member number "index". + uint32_t get_member_decoration(TypeID id, uint32_t index, spv::Decoration decoration) const; + const std::string &get_member_decoration_string(TypeID id, uint32_t index, spv::Decoration decoration) const; + + // Sets the member identifier for OpTypeStruct ID, member number "index". + void set_member_name(TypeID id, uint32_t index, const std::string &name); + + // Returns the qualified member identifier for OpTypeStruct ID, member number "index", + // or an empty string if no qualified alias exists + const std::string &get_member_qualified_name(TypeID type_id, uint32_t index) const; + + // Gets the decoration mask for a member of a struct, similar to get_decoration_mask. + const Bitset &get_member_decoration_bitset(TypeID id, uint32_t index) const; + + // Returns whether the decoration has been applied to a member of a struct. + bool has_member_decoration(TypeID id, uint32_t index, spv::Decoration decoration) const; + + // Similar to set_decoration, but for struct members. + void set_member_decoration(TypeID id, uint32_t index, spv::Decoration decoration, uint32_t argument = 0); + void set_member_decoration_string(TypeID id, uint32_t index, spv::Decoration decoration, + const std::string &argument); + + // Unsets a member decoration, similar to unset_decoration. 
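+ // A typical reflection sequence built on the decoration queries above (illustrative sketch;
+ // remap() stands in for whatever binding policy the application uses):
+ //   for (auto &res : compiler.get_shader_resources().sampled_images)
+ //   {
+ //       uint32_t set     = compiler.get_decoration(res.id, spv::DecorationDescriptorSet);
+ //       uint32_t binding = compiler.get_decoration(res.id, spv::DecorationBinding);
+ //       compiler.set_decoration(res.id, spv::DecorationBinding, remap(set, binding));
+ //   }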
+ void unset_member_decoration(TypeID id, uint32_t index, spv::Decoration decoration); + + // Gets the fallback name for a member, similar to get_fallback_name. + virtual const std::string get_fallback_member_name(uint32_t index) const + { + return join("_", index); + } + + // Returns a vector of which members of a struct are potentially in use by a + // SPIR-V shader. The granularity of this analysis is per-member of a struct. + // This can be used for Buffer (UBO), BufferBlock/StorageBuffer (SSBO) and PushConstant blocks. + // ID is the Resource::id obtained from get_shader_resources(). + SmallVector get_active_buffer_ranges(VariableID id) const; + + // Returns the effective size of a buffer block. + size_t get_declared_struct_size(const SPIRType &struct_type) const; + + // Returns the effective size of a buffer block, with a given array size + // for a runtime array. + // SSBOs are typically declared as runtime arrays. get_declared_struct_size() will return 0 for the size. + // This is not very helpful for applications which might need to know the array stride of its last member. + // This can be done through the API, but it is not very intuitive how to accomplish this, so here we provide a helper function + // to query the size of the buffer, assuming that the last member has a certain size. + // If the buffer does not contain a runtime array, array_size is ignored, and the function will behave as + // get_declared_struct_size(). + // To get the array stride of the last member, something like: + // get_declared_struct_size_runtime_array(type, 1) - get_declared_struct_size_runtime_array(type, 0) will work. + size_t get_declared_struct_size_runtime_array(const SPIRType &struct_type, size_t array_size) const; + + // Returns the effective size of a buffer block struct member. + size_t get_declared_struct_member_size(const SPIRType &struct_type, uint32_t index) const; + + // Returns a set of all global variables which are statically accessed + // by the control flow graph from the current entry point. + // Only variables which change the interface for a shader are returned, that is, + // variables with storage class of Input, Output, Uniform, UniformConstant, PushConstant and AtomicCounter + // storage classes are returned. + // + // To use the returned set as the filter for which variables are used during compilation, + // this set can be moved to set_enabled_interface_variables(). + std::unordered_set get_active_interface_variables() const; + + // Sets the interface variables which are used during compilation. + // By default, all variables are used. + // Once set, compile() will only consider the set in active_variables. + void set_enabled_interface_variables(std::unordered_set active_variables); + + // Query shader resources, use ids with reflection interface to modify or query binding points, etc. + ShaderResources get_shader_resources() const; + + // Query shader resources, but only return the variables which are part of active_variables. + // E.g.: get_shader_resources(get_active_variables()) to only return the variables which are statically + // accessed. + ShaderResources get_shader_resources(const std::unordered_set &active_variables) const; + + // Remapped variables are considered built-in variables and a backend will + // not emit a declaration for this variable. + // This is mostly useful for making use of builtins which are dependent on extensions. 
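+ // A minimal sketch (illustrative, assuming a framebuffer-fetch style use case): after remapping a
+ // subpass input to an extension builtin such as gl_LastFragDepthARM, one might call
+ //   compiler.set_remapped_variable_state(id, true);         // backend will not re-declare it
+ //   compiler.set_subpass_input_remapped_components(id, 1);  // the remapped type is a plain float
+ // before compile().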
+ void set_remapped_variable_state(VariableID id, bool remap_enable); + bool get_remapped_variable_state(VariableID id) const; + + // For subpassInput variables which are remapped to plain variables, + // the number of components in the remapped + // variable must be specified as the backing type of subpass inputs are opaque. + void set_subpass_input_remapped_components(VariableID id, uint32_t components); + uint32_t get_subpass_input_remapped_components(VariableID id) const; + + // All operations work on the current entry point. + // Entry points can be swapped out with set_entry_point(). + // Entry points should be set right after the constructor completes as some reflection functions traverse the graph from the entry point. + // Resource reflection also depends on the entry point. + // By default, the current entry point is set to the first OpEntryPoint which appears in the SPIR-V module. + + // Some shader languages restrict the names that can be given to entry points, and the + // corresponding backend will automatically rename an entry point name, during the call + // to compile() if it is illegal. For example, the common entry point name main() is + // illegal in MSL, and is renamed to an alternate name by the MSL backend. + // Given the original entry point name contained in the SPIR-V, this function returns + // the name, as updated by the backend during the call to compile(). If the name is not + // illegal, and has not been renamed, or if this function is called before compile(), + // this function will simply return the same name. + + // New variants of entry point query and reflection. + // Names for entry points in the SPIR-V module may alias if they belong to different execution models. + // To disambiguate, we must pass along with the entry point names the execution model. + SmallVector get_entry_points_and_stages() const; + void set_entry_point(const std::string &entry, spv::ExecutionModel execution_model); + + // Renames an entry point from old_name to new_name. + // If old_name is currently selected as the current entry point, it will continue to be the current entry point, + // albeit with a new name. + // get_entry_points() is essentially invalidated at this point. + void rename_entry_point(const std::string &old_name, const std::string &new_name, + spv::ExecutionModel execution_model); + const SPIREntryPoint &get_entry_point(const std::string &name, spv::ExecutionModel execution_model) const; + SPIREntryPoint &get_entry_point(const std::string &name, spv::ExecutionModel execution_model); + const std::string &get_cleansed_entry_point_name(const std::string &name, + spv::ExecutionModel execution_model) const; + + // Traverses all reachable opcodes and sets active_builtins to a bitmask of all builtin variables which are accessed in the shader. + void update_active_builtins(); + bool has_active_builtin(spv::BuiltIn builtin, spv::StorageClass storage) const; + + // Query and modify OpExecutionMode. + const Bitset &get_execution_mode_bitset() const; + + void unset_execution_mode(spv::ExecutionMode mode); + void set_execution_mode(spv::ExecutionMode mode, uint32_t arg0 = 0, uint32_t arg1 = 0, uint32_t arg2 = 0); + + // Gets argument for an execution mode (LocalSize, Invocations, OutputVertices). + // For LocalSize or LocalSizeId, the index argument is used to select the dimension (X = 0, Y = 1, Z = 2). + // For execution modes which do not have arguments, 0 is returned. + // LocalSizeId query returns an ID. If LocalSizeId execution mode is not used, it returns 0. 
+ // LocalSize always returns a literal. If execution mode is LocalSizeId, + // the literal (spec constant or not) is still returned. + uint32_t get_execution_mode_argument(spv::ExecutionMode mode, uint32_t index = 0) const; + spv::ExecutionModel get_execution_model() const; + + bool is_tessellation_shader() const; + bool is_tessellating_triangles() const; + + // In SPIR-V, the compute work group size can be represented by a constant vector, in which case + // the LocalSize execution mode is ignored. + // + // This constant vector can be a constant vector, specialization constant vector, or partly specialized constant vector. + // To modify and query work group dimensions which are specialization constants, SPIRConstant values must be modified + // directly via get_constant() rather than using LocalSize directly. This function will return which constants should be modified. + // + // To modify dimensions which are *not* specialization constants, set_execution_mode should be used directly. + // Arguments to set_execution_mode which are specialization constants are effectively ignored during compilation. + // NOTE: This is somewhat different from how SPIR-V works. In SPIR-V, the constant vector will completely replace LocalSize, + // while in this interface, LocalSize is only ignored for specialization constants. + // + // The specialization constant will be written to x, y and z arguments. + // If the component is not a specialization constant, a zeroed out struct will be written. + // The return value is the constant ID of the builtin WorkGroupSize, but this is not expected to be useful + // for most use cases. + // If LocalSizeId is used, there is no uvec3 value representing the workgroup size, so the return value is 0, + // but x, y and z are written as normal if the components are specialization constants. + uint32_t get_work_group_size_specialization_constants(SpecializationConstant &x, SpecializationConstant &y, + SpecializationConstant &z) const; + + // Analyzes all OpImageFetch (texelFetch) opcodes and checks if there are instances where + // said instruction is used without a combined image sampler. + // GLSL targets do not support the use of texelFetch without a sampler. + // To workaround this, we must inject a dummy sampler which can be used to form a sampler2D at the call-site of + // texelFetch as necessary. + // + // This must be called before build_combined_image_samplers(). + // build_combined_image_samplers() may refer to the ID returned by this method if the returned ID is non-zero. + // The return value will be the ID of a sampler object if a dummy sampler is necessary, or 0 if no sampler object + // is required. + // + // If the returned ID is non-zero, it can be decorated with set/bindings as desired before calling compile(). + // Calling this function also invalidates get_active_interface_variables(), so this should be called + // before that function. + VariableID build_dummy_sampler_for_combined_images(); + + // Analyzes all separate image and samplers used from the currently selected entry point, + // and re-routes them all to a combined image sampler instead. + // This is required to "support" separate image samplers in targets which do not natively support + // this feature, like GLSL/ESSL. + // + // This must be called before compile() if such remapping is desired. + // This call will add new sampled images to the SPIR-V, + // so it will appear in reflection if get_shader_resources() is called after build_combined_image_samplers. 
+ // + // If any image/sampler remapping was found, no separate image/samplers will appear in the decompiled output, + // but will still appear in reflection. + // + // The resulting samplers will be void of any decorations like name, descriptor sets and binding points, + // so this can be added before compile() if desired. + // + // Combined image samplers originating from this set are always considered active variables. + // Arrays of separate samplers are not supported, but arrays of separate images are supported. + // Array of images + sampler -> Array of combined image samplers. + void build_combined_image_samplers(); + + // Gets a remapping for the combined image samplers. + const SmallVector &get_combined_image_samplers() const + { + return combined_image_samplers; + } + + // Set a new variable type remap callback. + // The type remapping is designed to allow global interface variable to assume more special types. + // A typical example here is to remap sampler2D into samplerExternalOES, which currently isn't supported + // directly by SPIR-V. + // + // In compile() while emitting code, + // for every variable that is declared, including function parameters, the callback will be called + // and the API user has a chance to change the textual representation of the type used to declare the variable. + // The API user can detect special patterns in names to guide the remapping. + void set_variable_type_remap_callback(VariableTypeRemapCallback cb) + { + variable_remap_callback = std::move(cb); + } + + // API for querying which specialization constants exist. + // To modify a specialization constant before compile(), use get_constant(constant.id), + // then update constants directly in the SPIRConstant data structure. + // For composite types, the subconstants can be iterated over and modified. + // constant_type is the SPIRType for the specialization constant, + // which can be queried to determine which fields in the unions should be poked at. + SmallVector get_specialization_constants() const; + SPIRConstant &get_constant(ConstantID id); + const SPIRConstant &get_constant(ConstantID id) const; + + uint32_t get_current_id_bound() const + { + return uint32_t(ir.ids.size()); + } + + // API for querying buffer objects. + // The type passed in here should be the base type of a resource, i.e. + // get_type(resource.base_type_id) + // as decorations are set in the basic Block type. + // The type passed in here must have these decorations set, or an exception is raised. + // Only UBOs and SSBOs or sub-structs which are part of these buffer types will have these decorations set. + uint32_t type_struct_member_offset(const SPIRType &type, uint32_t index) const; + uint32_t type_struct_member_array_stride(const SPIRType &type, uint32_t index) const; + uint32_t type_struct_member_matrix_stride(const SPIRType &type, uint32_t index) const; + + // Gets the offset in SPIR-V words (uint32_t) for a decoration which was originally declared in the SPIR-V binary. + // The offset will point to one or more uint32_t literals which can be modified in-place before using the SPIR-V binary. + // Note that adding or removing decorations using the reflection API will not change the behavior of this function. + // If the decoration was declared, sets the word_offset to an offset into the provided SPIR-V binary buffer and returns true, + // otherwise, returns false. + // If the decoration does not have any value attached to it (e.g. DecorationRelaxedPrecision), this function will also return false. 
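+ // Usage sketch (illustrative): patching a binding in-place in the original binary, e.g.
+ //   uint32_t word_offset;
+ //   if (compiler.get_binary_offset_for_decoration(res.id, spv::DecorationBinding, word_offset))
+ //       spirv_words[word_offset] = new_binding;
+ // where spirv_words is the same uint32_t buffer the Compiler was constructed from.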
+ bool get_binary_offset_for_decoration(VariableID id, spv::Decoration decoration, uint32_t &word_offset) const; + + // HLSL counter buffer reflection interface. + // Append/Consume/Increment/Decrement in HLSL is implemented as two "neighbor" buffer objects where + // one buffer implements the storage, and a single buffer containing just a lone "int" implements the counter. + // To SPIR-V these will be exposed as two separate buffers, but glslang HLSL frontend emits a special indentifier + // which lets us link the two buffers together. + + // Queries if a variable ID is a counter buffer which "belongs" to a regular buffer object. + + // If SPV_GOOGLE_hlsl_functionality1 is used, this can be used even with a stripped SPIR-V module. + // Otherwise, this query is purely based on OpName identifiers as found in the SPIR-V module, and will + // only return true if OpSource was reported HLSL. + // To rely on this functionality, ensure that the SPIR-V module is not stripped. + + bool buffer_is_hlsl_counter_buffer(VariableID id) const; + + // Queries if a buffer object has a neighbor "counter" buffer. + // If so, the ID of that counter buffer will be returned in counter_id. + // If SPV_GOOGLE_hlsl_functionality1 is used, this can be used even with a stripped SPIR-V module. + // Otherwise, this query is purely based on OpName identifiers as found in the SPIR-V module, and will + // only return true if OpSource was reported HLSL. + // To rely on this functionality, ensure that the SPIR-V module is not stripped. + bool buffer_get_hlsl_counter_buffer(VariableID id, uint32_t &counter_id) const; + + // Gets the list of all SPIR-V Capabilities which were declared in the SPIR-V module. + const SmallVector &get_declared_capabilities() const; + + // Gets the list of all SPIR-V extensions which were declared in the SPIR-V module. + const SmallVector &get_declared_extensions() const; + + // When declaring buffer blocks in GLSL, the name declared in the GLSL source + // might not be the same as the name declared in the SPIR-V module due to naming conflicts. + // In this case, SPIRV-Cross needs to find a fallback-name, and it might only + // be possible to know this name after compiling to GLSL. + // This is particularly important for HLSL input and UAVs which tends to reuse the same block type + // for multiple distinct blocks. For these cases it is not possible to modify the name of the type itself + // because it might be unique. Instead, you can use this interface to check after compilation which + // name was actually used if your input SPIR-V tends to have this problem. + // For other names like remapped names for variables, etc, it's generally enough to query the name of the variables + // after compiling, block names are an exception to this rule. + // ID is the name of a variable as returned by Resource::id, and must be a variable with a Block-like type. + // + // This also applies to HLSL cbuffers. + std::string get_remapped_declared_block_name(VariableID id) const; + + // For buffer block variables, get the decorations for that variable. + // Sometimes, decorations for buffer blocks are found in member decorations instead + // of direct decorations on the variable itself. + // The most common use here is to check if a buffer is readonly or writeonly. 
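+ // e.g. (illustrative):
+ //   Bitset flags = compiler.get_buffer_block_flags(ssbo.id);
+ //   bool readonly  = flags.get(spv::DecorationNonWritable);
+ //   bool writeonly = flags.get(spv::DecorationNonReadable);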
+ Bitset get_buffer_block_flags(VariableID id) const; + + // Returns whether the position output is invariant + bool is_position_invariant() const + { + return position_invariant; + } + +protected: + const uint32_t *stream(const Instruction &instr) const + { + // If we're not going to use any arguments, just return nullptr. + // We want to avoid case where we return an out of range pointer + // that trips debug assertions on some platforms. + if (!instr.length) + return nullptr; + + if (instr.is_embedded()) + { + auto &embedded = static_cast(instr); + assert(embedded.ops.size() == instr.length); + return embedded.ops.data(); + } + else + { + if (instr.offset + instr.length > ir.spirv.size()) + SPIRV_CROSS_THROW("Compiler::stream() out of range."); + return &ir.spirv[instr.offset]; + } + } + + uint32_t *stream_mutable(const Instruction &instr) const + { + return const_cast(stream(instr)); + } + + ParsedIR ir; + // Marks variables which have global scope and variables which can alias with other variables + // (SSBO, image load store, etc) + SmallVector global_variables; + SmallVector aliased_variables; + + SPIRFunction *current_function = nullptr; + SPIRBlock *current_block = nullptr; + uint32_t current_loop_level = 0; + std::unordered_set active_interface_variables; + bool check_active_interface_variables = false; + + void add_loop_level(); + + void set_initializers(SPIRExpression &e) + { + e.emitted_loop_level = current_loop_level; + } + + template + void set_initializers(const T &) + { + } + + // If our IDs are out of range here as part of opcodes, throw instead of + // undefined behavior. + template + T &set(uint32_t id, P &&... args) + { + ir.add_typed_id(static_cast(T::type), id); + auto &var = variant_set(ir.ids[id], std::forward
(args)...); + var.self = id; + set_initializers(var); + return var; + } + + template + T &get(uint32_t id) + { + return variant_get(ir.ids[id]); + } + + template + T *maybe_get(uint32_t id) + { + if (id >= ir.ids.size()) + return nullptr; + else if (ir.ids[id].get_type() == static_cast(T::type)) + return &get(id); + else + return nullptr; + } + + template + const T &get(uint32_t id) const + { + return variant_get(ir.ids[id]); + } + + template + const T *maybe_get(uint32_t id) const + { + if (id >= ir.ids.size()) + return nullptr; + else if (ir.ids[id].get_type() == static_cast(T::type)) + return &get(id); + else + return nullptr; + } + + // Gets the id of SPIR-V type underlying the given type_id, which might be a pointer. + uint32_t get_pointee_type_id(uint32_t type_id) const; + + // Gets the SPIR-V type underlying the given type, which might be a pointer. + const SPIRType &get_pointee_type(const SPIRType &type) const; + + // Gets the SPIR-V type underlying the given type_id, which might be a pointer. + const SPIRType &get_pointee_type(uint32_t type_id) const; + + // Gets the ID of the SPIR-V type underlying a variable. + uint32_t get_variable_data_type_id(const SPIRVariable &var) const; + + // Gets the SPIR-V type underlying a variable. + SPIRType &get_variable_data_type(const SPIRVariable &var); + + // Gets the SPIR-V type underlying a variable. + const SPIRType &get_variable_data_type(const SPIRVariable &var) const; + + // Gets the SPIR-V element type underlying an array variable. + SPIRType &get_variable_element_type(const SPIRVariable &var); + + // Gets the SPIR-V element type underlying an array variable. + const SPIRType &get_variable_element_type(const SPIRVariable &var) const; + + // Sets the qualified member identifier for OpTypeStruct ID, member number "index". + void set_member_qualified_name(uint32_t type_id, uint32_t index, const std::string &name); + void set_qualified_name(uint32_t id, const std::string &name); + + // Returns if the given type refers to a sampled image. 
+ bool is_sampled_image_type(const SPIRType &type); + + const SPIREntryPoint &get_entry_point() const; + SPIREntryPoint &get_entry_point(); + static bool is_tessellation_shader(spv::ExecutionModel model); + + virtual std::string to_name(uint32_t id, bool allow_alias = true) const; + bool is_builtin_variable(const SPIRVariable &var) const; + bool is_builtin_type(const SPIRType &type) const; + bool is_hidden_variable(const SPIRVariable &var, bool include_builtins = false) const; + bool is_immutable(uint32_t id) const; + bool is_member_builtin(const SPIRType &type, uint32_t index, spv::BuiltIn *builtin) const; + bool is_scalar(const SPIRType &type) const; + bool is_vector(const SPIRType &type) const; + bool is_matrix(const SPIRType &type) const; + bool is_array(const SPIRType &type) const; + bool is_pointer(const SPIRType &type) const; + bool is_physical_pointer(const SPIRType &type) const; + bool is_physical_pointer_to_buffer_block(const SPIRType &type) const; + static bool is_runtime_size_array(const SPIRType &type); + uint32_t expression_type_id(uint32_t id) const; + const SPIRType &expression_type(uint32_t id) const; + bool expression_is_lvalue(uint32_t id) const; + bool variable_storage_is_aliased(const SPIRVariable &var); + SPIRVariable *maybe_get_backing_variable(uint32_t chain); + + void register_read(uint32_t expr, uint32_t chain, bool forwarded); + void register_write(uint32_t chain); + + inline bool is_continue(uint32_t next) const + { + return (ir.block_meta[next] & ParsedIR::BLOCK_META_CONTINUE_BIT) != 0; + } + + inline bool is_single_block_loop(uint32_t next) const + { + auto &block = get(next); + return block.merge == SPIRBlock::MergeLoop && block.continue_block == ID(next); + } + + inline bool is_break(uint32_t next) const + { + return (ir.block_meta[next] & + (ParsedIR::BLOCK_META_LOOP_MERGE_BIT | ParsedIR::BLOCK_META_MULTISELECT_MERGE_BIT)) != 0; + } + + inline bool is_loop_break(uint32_t next) const + { + return (ir.block_meta[next] & ParsedIR::BLOCK_META_LOOP_MERGE_BIT) != 0; + } + + inline bool is_conditional(uint32_t next) const + { + return (ir.block_meta[next] & + (ParsedIR::BLOCK_META_SELECTION_MERGE_BIT | ParsedIR::BLOCK_META_MULTISELECT_MERGE_BIT)) != 0; + } + + // Dependency tracking for temporaries read from variables. + void flush_dependees(SPIRVariable &var); + void flush_all_active_variables(); + void flush_control_dependent_expressions(uint32_t block); + void flush_all_atomic_capable_variables(); + void flush_all_aliased_variables(); + void register_global_read_dependencies(const SPIRBlock &func, uint32_t id); + void register_global_read_dependencies(const SPIRFunction &func, uint32_t id); + std::unordered_set invalid_expressions; + + void update_name_cache(std::unordered_set &cache, std::string &name); + + // A variant which takes two sets of names. The secondary is only used to verify there are no collisions, + // but the set is not updated when we have found a new name. + // Used primarily when adding block interface names. 
+ void update_name_cache(std::unordered_set &cache_primary, + const std::unordered_set &cache_secondary, std::string &name); + + bool function_is_pure(const SPIRFunction &func); + bool block_is_pure(const SPIRBlock &block); + bool function_is_control_dependent(const SPIRFunction &func); + bool block_is_control_dependent(const SPIRBlock &block); + + bool execution_is_branchless(const SPIRBlock &from, const SPIRBlock &to) const; + bool execution_is_direct_branch(const SPIRBlock &from, const SPIRBlock &to) const; + bool execution_is_noop(const SPIRBlock &from, const SPIRBlock &to) const; + SPIRBlock::ContinueBlockType continue_block_type(const SPIRBlock &continue_block) const; + + void force_recompile(); + void force_recompile_guarantee_forward_progress(); + void clear_force_recompile(); + bool is_forcing_recompilation() const; + bool is_force_recompile = false; + bool is_force_recompile_forward_progress = false; + + bool block_is_noop(const SPIRBlock &block) const; + bool block_is_loop_candidate(const SPIRBlock &block, SPIRBlock::Method method) const; + + bool types_are_logically_equivalent(const SPIRType &a, const SPIRType &b) const; + void inherit_expression_dependencies(uint32_t dst, uint32_t source); + void add_implied_read_expression(SPIRExpression &e, uint32_t source); + void add_implied_read_expression(SPIRAccessChain &e, uint32_t source); + void add_active_interface_variable(uint32_t var_id); + + // For proper multiple entry point support, allow querying if an Input or Output + // variable is part of that entry points interface. + bool interface_variable_exists_in_entry_point(uint32_t id) const; + + SmallVector combined_image_samplers; + + void remap_variable_type_name(const SPIRType &type, const std::string &var_name, std::string &type_name) const + { + if (variable_remap_callback) + variable_remap_callback(type, var_name, type_name); + } + + void set_ir(const ParsedIR &parsed); + void set_ir(ParsedIR &&parsed); + void parse_fixup(); + + // Used internally to implement various traversals for queries. + struct OpcodeHandler + { + virtual ~OpcodeHandler() = default; + + // Return true if traversal should continue. + // If false, traversal will end immediately. + virtual bool handle(spv::Op opcode, const uint32_t *args, uint32_t length) = 0; + virtual bool handle_terminator(const SPIRBlock &) + { + return true; + } + + virtual bool follow_function_call(const SPIRFunction &) + { + return true; + } + + virtual void set_current_block(const SPIRBlock &) + { + } + + // Called after returning from a function or when entering a block, + // can be called multiple times per block, + // while set_current_block is only called on block entry. 
+ virtual void rearm_current_block(const SPIRBlock &) + { + } + + virtual bool begin_function_scope(const uint32_t *, uint32_t) + { + return true; + } + + virtual bool end_function_scope(const uint32_t *, uint32_t) + { + return true; + } + }; + + struct BufferAccessHandler : OpcodeHandler + { + BufferAccessHandler(const Compiler &compiler_, SmallVector &ranges_, uint32_t id_) + : compiler(compiler_) + , ranges(ranges_) + , id(id_) + { + } + + bool handle(spv::Op opcode, const uint32_t *args, uint32_t length) override; + + const Compiler &compiler; + SmallVector &ranges; + uint32_t id; + + std::unordered_set seen; + }; + + struct InterfaceVariableAccessHandler : OpcodeHandler + { + InterfaceVariableAccessHandler(const Compiler &compiler_, std::unordered_set &variables_) + : compiler(compiler_) + , variables(variables_) + { + } + + bool handle(spv::Op opcode, const uint32_t *args, uint32_t length) override; + + const Compiler &compiler; + std::unordered_set &variables; + }; + + struct CombinedImageSamplerHandler : OpcodeHandler + { + CombinedImageSamplerHandler(Compiler &compiler_) + : compiler(compiler_) + { + } + bool handle(spv::Op opcode, const uint32_t *args, uint32_t length) override; + bool begin_function_scope(const uint32_t *args, uint32_t length) override; + bool end_function_scope(const uint32_t *args, uint32_t length) override; + + Compiler &compiler; + + // Each function in the call stack needs its own remapping for parameters so we can deduce which global variable each texture/sampler the parameter is statically bound to. + std::stack> parameter_remapping; + std::stack functions; + + uint32_t remap_parameter(uint32_t id); + void push_remap_parameters(const SPIRFunction &func, const uint32_t *args, uint32_t length); + void pop_remap_parameters(); + void register_combined_image_sampler(SPIRFunction &caller, VariableID combined_id, VariableID texture_id, + VariableID sampler_id, bool depth); + }; + + struct DummySamplerForCombinedImageHandler : OpcodeHandler + { + DummySamplerForCombinedImageHandler(Compiler &compiler_) + : compiler(compiler_) + { + } + bool handle(spv::Op opcode, const uint32_t *args, uint32_t length) override; + + Compiler &compiler; + bool need_dummy_sampler = false; + }; + + struct ActiveBuiltinHandler : OpcodeHandler + { + ActiveBuiltinHandler(Compiler &compiler_) + : compiler(compiler_) + { + } + + bool handle(spv::Op opcode, const uint32_t *args, uint32_t length) override; + Compiler &compiler; + + void handle_builtin(const SPIRType &type, spv::BuiltIn builtin, const Bitset &decoration_flags); + void add_if_builtin(uint32_t id); + void add_if_builtin_or_block(uint32_t id); + void add_if_builtin(uint32_t id, bool allow_blocks); + }; + + bool traverse_all_reachable_opcodes(const SPIRBlock &block, OpcodeHandler &handler) const; + bool traverse_all_reachable_opcodes(const SPIRFunction &block, OpcodeHandler &handler) const; + // This must be an ordered data structure so we always pick the same type aliases. 
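As a rough sketch of how these opcode traversals are driven (illustrative only, not part of the upstream header; the handler name and the counting logic are made up, and it assumes the struct sits inside Compiler next to the handlers above):

	struct OpLoadCountHandler : OpcodeHandler
	{
		explicit OpLoadCountHandler(const Compiler &compiler_)
		    : compiler(compiler_)
		{
		}

		// Return true to keep walking; returning false aborts the traversal immediately.
		bool handle(spv::Op opcode, const uint32_t *, uint32_t) override
		{
			if (opcode == spv::OpLoad)
				load_count++;
			return true;
		}

		const Compiler &compiler;
		uint32_t load_count = 0;
	};

	// A query would then run it over the default entry point, e.g.:
	//   OpLoadCountHandler handler(*this);
	//   traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), handler);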
+ SmallVector global_struct_cache; + + ShaderResources get_shader_resources(const std::unordered_set *active_variables) const; + + VariableTypeRemapCallback variable_remap_callback; + + bool get_common_basic_type(const SPIRType &type, SPIRType::BaseType &base_type); + + std::unordered_set forced_temporaries; + std::unordered_set forwarded_temporaries; + std::unordered_set suppressed_usage_tracking; + std::unordered_set hoisted_temporaries; + std::unordered_set forced_invariant_temporaries; + + Bitset active_input_builtins; + Bitset active_output_builtins; + uint32_t clip_distance_count = 0; + uint32_t cull_distance_count = 0; + bool position_invariant = false; + + void analyze_parameter_preservation( + SPIRFunction &entry, const CFG &cfg, + const std::unordered_map> &variable_to_blocks, + const std::unordered_map> &complete_write_blocks); + + // If a variable ID or parameter ID is found in this set, a sampler is actually a shadow/comparison sampler. + // SPIR-V does not support this distinction, so we must keep track of this information outside the type system. + // There might be unrelated IDs found in this set which do not correspond to actual variables. + // This set should only be queried for the existence of samplers which are already known to be variables or parameter IDs. + // Similar is implemented for images, as well as if subpass inputs are needed. + std::unordered_set comparison_ids; + bool need_subpass_input = false; + bool need_subpass_input_ms = false; + + // In certain backends, we will need to use a dummy sampler to be able to emit code. + // GLSL does not support texelFetch on texture2D objects, but SPIR-V does, + // so we need to workaround by having the application inject a dummy sampler. + uint32_t dummy_sampler_id = 0; + + void analyze_image_and_sampler_usage(); + + struct CombinedImageSamplerDrefHandler : OpcodeHandler + { + CombinedImageSamplerDrefHandler(Compiler &compiler_) + : compiler(compiler_) + { + } + bool handle(spv::Op opcode, const uint32_t *args, uint32_t length) override; + + Compiler &compiler; + std::unordered_set dref_combined_samplers; + }; + + struct CombinedImageSamplerUsageHandler : OpcodeHandler + { + CombinedImageSamplerUsageHandler(Compiler &compiler_, + const std::unordered_set &dref_combined_samplers_) + : compiler(compiler_) + , dref_combined_samplers(dref_combined_samplers_) + { + } + + bool begin_function_scope(const uint32_t *args, uint32_t length) override; + bool handle(spv::Op opcode, const uint32_t *args, uint32_t length) override; + Compiler &compiler; + const std::unordered_set &dref_combined_samplers; + + std::unordered_map> dependency_hierarchy; + std::unordered_set comparison_ids; + + void add_hierarchy_to_comparison_ids(uint32_t ids); + bool need_subpass_input = false; + bool need_subpass_input_ms = false; + void add_dependency(uint32_t dst, uint32_t src); + }; + + void build_function_control_flow_graphs_and_analyze(); + std::unordered_map> function_cfgs; + const CFG &get_cfg_for_current_function() const; + const CFG &get_cfg_for_function(uint32_t id) const; + + struct CFGBuilder : OpcodeHandler + { + explicit CFGBuilder(Compiler &compiler_); + + bool follow_function_call(const SPIRFunction &func) override; + bool handle(spv::Op op, const uint32_t *args, uint32_t length) override; + Compiler &compiler; + std::unordered_map> function_cfgs; + }; + + struct AnalyzeVariableScopeAccessHandler : OpcodeHandler + { + AnalyzeVariableScopeAccessHandler(Compiler &compiler_, SPIRFunction &entry_); + + bool follow_function_call(const 
SPIRFunction &) override; + void set_current_block(const SPIRBlock &block) override; + + void notify_variable_access(uint32_t id, uint32_t block); + bool id_is_phi_variable(uint32_t id) const; + bool id_is_potential_temporary(uint32_t id) const; + bool handle(spv::Op op, const uint32_t *args, uint32_t length) override; + bool handle_terminator(const SPIRBlock &block) override; + + Compiler &compiler; + SPIRFunction &entry; + std::unordered_map> accessed_variables_to_block; + std::unordered_map> accessed_temporaries_to_block; + std::unordered_map result_id_to_type; + std::unordered_map> complete_write_variables_to_block; + std::unordered_map> partial_write_variables_to_block; + std::unordered_set access_chain_expressions; + // Access chains used in multiple blocks mean hoisting all the variables used to construct the access chain as not all backends can use pointers. + // This is also relevant when forwarding opaque objects since we cannot lower these to temporaries. + std::unordered_map> rvalue_forward_children; + const SPIRBlock *current_block = nullptr; + }; + + struct StaticExpressionAccessHandler : OpcodeHandler + { + StaticExpressionAccessHandler(Compiler &compiler_, uint32_t variable_id_); + bool follow_function_call(const SPIRFunction &) override; + bool handle(spv::Op op, const uint32_t *args, uint32_t length) override; + + Compiler &compiler; + uint32_t variable_id; + uint32_t static_expression = 0; + uint32_t write_count = 0; + }; + + struct PhysicalBlockMeta + { + uint32_t alignment = 0; + }; + + struct PhysicalStorageBufferPointerHandler : OpcodeHandler + { + explicit PhysicalStorageBufferPointerHandler(Compiler &compiler_); + bool handle(spv::Op op, const uint32_t *args, uint32_t length) override; + Compiler &compiler; + + std::unordered_set non_block_types; + std::unordered_map physical_block_type_meta; + std::unordered_map access_chain_to_physical_block; + + void mark_aligned_access(uint32_t id, const uint32_t *args, uint32_t length); + PhysicalBlockMeta *find_block_meta(uint32_t id) const; + bool type_is_bda_block_entry(uint32_t type_id) const; + void setup_meta_chain(uint32_t type_id, uint32_t var_id); + uint32_t get_minimum_scalar_alignment(const SPIRType &type) const; + void analyze_non_block_types_from_block(const SPIRType &type); + uint32_t get_base_non_block_type_id(uint32_t type_id) const; + }; + void analyze_non_block_pointer_types(); + SmallVector physical_storage_non_block_pointer_types; + std::unordered_map physical_storage_type_to_alignment; + + void analyze_variable_scope(SPIRFunction &function, AnalyzeVariableScopeAccessHandler &handler); + void find_function_local_luts(SPIRFunction &function, const AnalyzeVariableScopeAccessHandler &handler, + bool single_function); + bool may_read_undefined_variable_in_block(const SPIRBlock &block, uint32_t var); + + // Finds all resources that are written to from inside the critical section, if present. + // The critical section is delimited by OpBeginInvocationInterlockEXT and + // OpEndInvocationInterlockEXT instructions. In MSL and HLSL, any resources written + // while inside the critical section must be placed in a raster order group. 
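	// For illustration (an assumed GLSL shape, not taken from the upstream sources), the critical
	// section in a fragment shader using GL_ARB_fragment_shader_interlock looks roughly like:
	//
	//   beginInvocationInterlockARB();
	//   imageStore(rov_image, coord, value); // rov_image would be recorded by the handlers below
	//   endInvocationInterlockARB();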
+ struct InterlockedResourceAccessHandler : OpcodeHandler + { + InterlockedResourceAccessHandler(Compiler &compiler_, uint32_t entry_point_id) + : compiler(compiler_) + { + call_stack.push_back(entry_point_id); + } + + bool handle(spv::Op op, const uint32_t *args, uint32_t length) override; + bool begin_function_scope(const uint32_t *args, uint32_t length) override; + bool end_function_scope(const uint32_t *args, uint32_t length) override; + + Compiler &compiler; + bool in_crit_sec = false; + + uint32_t interlock_function_id = 0; + bool split_function_case = false; + bool control_flow_interlock = false; + bool use_critical_section = false; + bool call_stack_is_interlocked = false; + SmallVector call_stack; + + void access_potential_resource(uint32_t id); + }; + + struct InterlockedResourceAccessPrepassHandler : OpcodeHandler + { + InterlockedResourceAccessPrepassHandler(Compiler &compiler_, uint32_t entry_point_id) + : compiler(compiler_) + { + call_stack.push_back(entry_point_id); + } + + void rearm_current_block(const SPIRBlock &block) override; + bool handle(spv::Op op, const uint32_t *args, uint32_t length) override; + bool begin_function_scope(const uint32_t *args, uint32_t length) override; + bool end_function_scope(const uint32_t *args, uint32_t length) override; + + Compiler &compiler; + uint32_t interlock_function_id = 0; + uint32_t current_block_id = 0; + bool split_function_case = false; + bool control_flow_interlock = false; + SmallVector call_stack; + }; + + void analyze_interlocked_resource_usage(); + // The set of all resources written while inside the critical section, if present. + std::unordered_set interlocked_resources; + bool interlocked_is_complex = false; + + void make_constant_null(uint32_t id, uint32_t type); + + std::unordered_map declared_block_names; + + bool instruction_to_result_type(uint32_t &result_type, uint32_t &result_id, spv::Op op, const uint32_t *args, + uint32_t length); + + Bitset combined_decoration_for_member(const SPIRType &type, uint32_t index) const; + static bool is_desktop_only_format(spv::ImageFormat format); + + bool is_depth_image(const SPIRType &type, uint32_t id) const; + + void set_extended_decoration(uint32_t id, ExtendedDecorations decoration, uint32_t value = 0); + uint32_t get_extended_decoration(uint32_t id, ExtendedDecorations decoration) const; + bool has_extended_decoration(uint32_t id, ExtendedDecorations decoration) const; + void unset_extended_decoration(uint32_t id, ExtendedDecorations decoration); + + void set_extended_member_decoration(uint32_t type, uint32_t index, ExtendedDecorations decoration, + uint32_t value = 0); + uint32_t get_extended_member_decoration(uint32_t type, uint32_t index, ExtendedDecorations decoration) const; + bool has_extended_member_decoration(uint32_t type, uint32_t index, ExtendedDecorations decoration) const; + void unset_extended_member_decoration(uint32_t type, uint32_t index, ExtendedDecorations decoration); + + bool check_internal_recursion(const SPIRType &type, std::unordered_set &checked_ids); + bool type_contains_recursion(const SPIRType &type); + bool type_is_array_of_pointers(const SPIRType &type) const; + bool type_is_block_like(const SPIRType &type) const; + bool type_is_top_level_block(const SPIRType &type) const; + bool type_is_opaque_value(const SPIRType &type) const; + + bool reflection_ssbo_instance_name_is_significant() const; + std::string get_remapped_declared_block_name(uint32_t id, bool fallback_prefer_instance_name) const; + + bool flush_phi_required(BlockID from, BlockID to) 
const; + + uint32_t evaluate_spec_constant_u32(const SPIRConstantOp &spec) const; + uint32_t evaluate_constant_u32(uint32_t id) const; + + bool is_vertex_like_shader() const; + + // Get the correct case list for the OpSwitch, since it can be either a + // 32 bit wide condition or a 64 bit, but the type is not embedded in the + // instruction itself. + const SmallVector &get_case_list(const SPIRBlock &block) const; + +private: + // Used only to implement the old deprecated get_entry_point() interface. + const SPIREntryPoint &get_first_entry_point(const std::string &name) const; + SPIREntryPoint &get_first_entry_point(const std::string &name); +}; +} // namespace SPIRV_CROSS_NAMESPACE + +#endif diff --git a/thirdparty/spirv-cross/spirv_cross_containers.hpp b/thirdparty/spirv-cross/spirv_cross_containers.hpp new file mode 100644 index 00000000000..c496cb75be5 --- /dev/null +++ b/thirdparty/spirv-cross/spirv_cross_containers.hpp @@ -0,0 +1,756 @@ +/* + * Copyright 2019-2021 Hans-Kristian Arntzen + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . + */ + +#ifndef SPIRV_CROSS_CONTAINERS_HPP +#define SPIRV_CROSS_CONTAINERS_HPP + +#include "spirv_cross_error_handling.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef SPIRV_CROSS_NAMESPACE_OVERRIDE +#define SPIRV_CROSS_NAMESPACE SPIRV_CROSS_NAMESPACE_OVERRIDE +#else +#define SPIRV_CROSS_NAMESPACE spirv_cross +#endif + +namespace SPIRV_CROSS_NAMESPACE +{ +#ifndef SPIRV_CROSS_FORCE_STL_TYPES +// std::aligned_storage does not support size == 0, so roll our own. +template +class AlignedBuffer +{ +public: + T *data() + { +#if defined(_MSC_VER) && _MSC_VER < 1900 + // MSVC 2013 workarounds, sigh ... + // Only use this workaround on MSVC 2013 due to some confusion around default initialized unions. + // Spec seems to suggest the memory will be zero-initialized, which is *not* what we want. + return reinterpret_cast(u.aligned_char); +#else + return reinterpret_cast(aligned_char); +#endif + } + +private: +#if defined(_MSC_VER) && _MSC_VER < 1900 + // MSVC 2013 workarounds, sigh ... + union + { + char aligned_char[sizeof(T) * N]; + double dummy_aligner; + } u; +#else + alignas(T) char aligned_char[sizeof(T) * N]; +#endif +}; + +template +class AlignedBuffer +{ +public: + T *data() + { + return nullptr; + } +}; + +// An immutable version of SmallVector which erases type information about storage. 
+template +class VectorView +{ +public: + T &operator[](size_t i) SPIRV_CROSS_NOEXCEPT + { + return ptr[i]; + } + + const T &operator[](size_t i) const SPIRV_CROSS_NOEXCEPT + { + return ptr[i]; + } + + bool empty() const SPIRV_CROSS_NOEXCEPT + { + return buffer_size == 0; + } + + size_t size() const SPIRV_CROSS_NOEXCEPT + { + return buffer_size; + } + + T *data() SPIRV_CROSS_NOEXCEPT + { + return ptr; + } + + const T *data() const SPIRV_CROSS_NOEXCEPT + { + return ptr; + } + + T *begin() SPIRV_CROSS_NOEXCEPT + { + return ptr; + } + + T *end() SPIRV_CROSS_NOEXCEPT + { + return ptr + buffer_size; + } + + const T *begin() const SPIRV_CROSS_NOEXCEPT + { + return ptr; + } + + const T *end() const SPIRV_CROSS_NOEXCEPT + { + return ptr + buffer_size; + } + + T &front() SPIRV_CROSS_NOEXCEPT + { + return ptr[0]; + } + + const T &front() const SPIRV_CROSS_NOEXCEPT + { + return ptr[0]; + } + + T &back() SPIRV_CROSS_NOEXCEPT + { + return ptr[buffer_size - 1]; + } + + const T &back() const SPIRV_CROSS_NOEXCEPT + { + return ptr[buffer_size - 1]; + } + + // Makes it easier to consume SmallVector. +#if defined(_MSC_VER) && _MSC_VER < 1900 + explicit operator std::vector() const + { + // Another MSVC 2013 workaround. It does not understand lvalue/rvalue qualified operations. + return std::vector(ptr, ptr + buffer_size); + } +#else + // Makes it easier to consume SmallVector. + explicit operator std::vector() const & + { + return std::vector(ptr, ptr + buffer_size); + } + + // If we are converting as an r-value, we can pilfer our elements. + explicit operator std::vector() && + { + return std::vector(std::make_move_iterator(ptr), std::make_move_iterator(ptr + buffer_size)); + } +#endif + + // Avoid sliced copies. Base class should only be read as a reference. + VectorView(const VectorView &) = delete; + void operator=(const VectorView &) = delete; + +protected: + VectorView() = default; + T *ptr = nullptr; + size_t buffer_size = 0; +}; + +// Simple vector which supports up to N elements inline, without malloc/free. +// We use a lot of throwaway vectors all over the place which triggers allocations. +// This class only implements the subset of std::vector we need in SPIRV-Cross. +// It is *NOT* a drop-in replacement in general projects. +template +class SmallVector : public VectorView +{ +public: + SmallVector() SPIRV_CROSS_NOEXCEPT + { + this->ptr = stack_storage.data(); + buffer_capacity = N; + } + + template + SmallVector(const U *arg_list_begin, const U *arg_list_end) SPIRV_CROSS_NOEXCEPT : SmallVector() + { + auto count = size_t(arg_list_end - arg_list_begin); + reserve(count); + for (size_t i = 0; i < count; i++, arg_list_begin++) + new (&this->ptr[i]) T(*arg_list_begin); + this->buffer_size = count; + } + + template + SmallVector(std::initializer_list init) SPIRV_CROSS_NOEXCEPT : SmallVector(init.begin(), init.end()) + { + } + + template + explicit SmallVector(const U (&init)[M]) SPIRV_CROSS_NOEXCEPT : SmallVector(init, init + M) + { + } + + SmallVector(SmallVector &&other) SPIRV_CROSS_NOEXCEPT : SmallVector() + { + *this = std::move(other); + } + + SmallVector &operator=(SmallVector &&other) SPIRV_CROSS_NOEXCEPT + { + clear(); + if (other.ptr != other.stack_storage.data()) + { + // Pilfer allocated pointer. 
+ if (this->ptr != stack_storage.data()) + free(this->ptr); + this->ptr = other.ptr; + this->buffer_size = other.buffer_size; + buffer_capacity = other.buffer_capacity; + other.ptr = nullptr; + other.buffer_size = 0; + other.buffer_capacity = 0; + } + else + { + // Need to move the stack contents individually. + reserve(other.buffer_size); + for (size_t i = 0; i < other.buffer_size; i++) + { + new (&this->ptr[i]) T(std::move(other.ptr[i])); + other.ptr[i].~T(); + } + this->buffer_size = other.buffer_size; + other.buffer_size = 0; + } + return *this; + } + + SmallVector(const SmallVector &other) SPIRV_CROSS_NOEXCEPT : SmallVector() + { + *this = other; + } + + SmallVector &operator=(const SmallVector &other) SPIRV_CROSS_NOEXCEPT + { + if (this == &other) + return *this; + + clear(); + reserve(other.buffer_size); + for (size_t i = 0; i < other.buffer_size; i++) + new (&this->ptr[i]) T(other.ptr[i]); + this->buffer_size = other.buffer_size; + return *this; + } + + explicit SmallVector(size_t count) SPIRV_CROSS_NOEXCEPT : SmallVector() + { + resize(count); + } + + ~SmallVector() + { + clear(); + if (this->ptr != stack_storage.data()) + free(this->ptr); + } + + void clear() SPIRV_CROSS_NOEXCEPT + { + for (size_t i = 0; i < this->buffer_size; i++) + this->ptr[i].~T(); + this->buffer_size = 0; + } + + void push_back(const T &t) SPIRV_CROSS_NOEXCEPT + { + reserve(this->buffer_size + 1); + new (&this->ptr[this->buffer_size]) T(t); + this->buffer_size++; + } + + void push_back(T &&t) SPIRV_CROSS_NOEXCEPT + { + reserve(this->buffer_size + 1); + new (&this->ptr[this->buffer_size]) T(std::move(t)); + this->buffer_size++; + } + + void pop_back() SPIRV_CROSS_NOEXCEPT + { + // Work around false positive warning on GCC 8.3. + // Calling pop_back on empty vector is undefined. + if (!this->empty()) + resize(this->buffer_size - 1); + } + + template + void emplace_back(Ts &&... ts) SPIRV_CROSS_NOEXCEPT + { + reserve(this->buffer_size + 1); + new (&this->ptr[this->buffer_size]) T(std::forward(ts)...); + this->buffer_size++; + } + + void reserve(size_t count) SPIRV_CROSS_NOEXCEPT + { + if ((count > (std::numeric_limits::max)() / sizeof(T)) || + (count > (std::numeric_limits::max)() / 2)) + { + // Only way this should ever happen is with garbage input, terminate. + std::terminate(); + } + + if (count > buffer_capacity) + { + size_t target_capacity = buffer_capacity; + if (target_capacity == 0) + target_capacity = 1; + + // Weird parens works around macro issues on Windows if NOMINMAX is not used. + target_capacity = (std::max)(target_capacity, N); + + // Need to ensure there is a POT value of target capacity which is larger than count, + // otherwise this will overflow. + while (target_capacity < count) + target_capacity <<= 1u; + + T *new_buffer = + target_capacity > N ? static_cast(malloc(target_capacity * sizeof(T))) : stack_storage.data(); + + // If we actually fail this malloc, we are hosed anyways, there is no reason to attempt recovery. + if (!new_buffer) + std::terminate(); + + // In case for some reason two allocations both come from same stack. + if (new_buffer != this->ptr) + { + // We don't deal with types which can throw in move constructor. 
+ for (size_t i = 0; i < this->buffer_size; i++) + { + new (&new_buffer[i]) T(std::move(this->ptr[i])); + this->ptr[i].~T(); + } + } + + if (this->ptr != stack_storage.data()) + free(this->ptr); + this->ptr = new_buffer; + buffer_capacity = target_capacity; + } + } + + void insert(T *itr, const T *insert_begin, const T *insert_end) SPIRV_CROSS_NOEXCEPT + { + auto count = size_t(insert_end - insert_begin); + if (itr == this->end()) + { + reserve(this->buffer_size + count); + for (size_t i = 0; i < count; i++, insert_begin++) + new (&this->ptr[this->buffer_size + i]) T(*insert_begin); + this->buffer_size += count; + } + else + { + if (this->buffer_size + count > buffer_capacity) + { + auto target_capacity = this->buffer_size + count; + if (target_capacity == 0) + target_capacity = 1; + if (target_capacity < N) + target_capacity = N; + + while (target_capacity < count) + target_capacity <<= 1u; + + // Need to allocate new buffer. Move everything to a new buffer. + T *new_buffer = + target_capacity > N ? static_cast(malloc(target_capacity * sizeof(T))) : stack_storage.data(); + + // If we actually fail this malloc, we are hosed anyways, there is no reason to attempt recovery. + if (!new_buffer) + std::terminate(); + + // First, move elements from source buffer to new buffer. + // We don't deal with types which can throw in move constructor. + auto *target_itr = new_buffer; + auto *original_source_itr = this->begin(); + + if (new_buffer != this->ptr) + { + while (original_source_itr != itr) + { + new (target_itr) T(std::move(*original_source_itr)); + original_source_itr->~T(); + ++original_source_itr; + ++target_itr; + } + } + + // Copy-construct new elements. + for (auto *source_itr = insert_begin; source_itr != insert_end; ++source_itr, ++target_itr) + new (target_itr) T(*source_itr); + + // Move over the other half. + if (new_buffer != this->ptr || insert_begin != insert_end) + { + while (original_source_itr != this->end()) + { + new (target_itr) T(std::move(*original_source_itr)); + original_source_itr->~T(); + ++original_source_itr; + ++target_itr; + } + } + + if (this->ptr != stack_storage.data()) + free(this->ptr); + this->ptr = new_buffer; + buffer_capacity = target_capacity; + } + else + { + // Move in place, need to be a bit careful about which elements are constructed and which are not. + // Move the end and construct the new elements. + auto *target_itr = this->end() + count; + auto *source_itr = this->end(); + while (target_itr != this->end() && source_itr != itr) + { + --target_itr; + --source_itr; + new (target_itr) T(std::move(*source_itr)); + } + + // For already constructed elements we can move-assign. + std::move_backward(itr, source_itr, target_itr); + + // For the inserts which go to already constructed elements, we can do a plain copy. + while (itr != this->end() && insert_begin != insert_end) + *itr++ = *insert_begin++; + + // For inserts into newly allocated memory, we must copy-construct instead. 
+ while (insert_begin != insert_end) + { + new (itr) T(*insert_begin); + ++itr; + ++insert_begin; + } + } + + this->buffer_size += count; + } + } + + void insert(T *itr, const T &value) SPIRV_CROSS_NOEXCEPT + { + insert(itr, &value, &value + 1); + } + + T *erase(T *itr) SPIRV_CROSS_NOEXCEPT + { + std::move(itr + 1, this->end(), itr); + this->ptr[--this->buffer_size].~T(); + return itr; + } + + void erase(T *start_erase, T *end_erase) SPIRV_CROSS_NOEXCEPT + { + if (end_erase == this->end()) + { + resize(size_t(start_erase - this->begin())); + } + else + { + auto new_size = this->buffer_size - (end_erase - start_erase); + std::move(end_erase, this->end(), start_erase); + resize(new_size); + } + } + + void resize(size_t new_size) SPIRV_CROSS_NOEXCEPT + { + if (new_size < this->buffer_size) + { + for (size_t i = new_size; i < this->buffer_size; i++) + this->ptr[i].~T(); + } + else if (new_size > this->buffer_size) + { + reserve(new_size); + for (size_t i = this->buffer_size; i < new_size; i++) + new (&this->ptr[i]) T(); + } + + this->buffer_size = new_size; + } + +private: + size_t buffer_capacity = 0; + AlignedBuffer stack_storage; +}; + +// A vector without stack storage. +// Could also be a typedef-ed to std::vector, +// but might as well use the one we have. +template +using Vector = SmallVector; + +#else // SPIRV_CROSS_FORCE_STL_TYPES + +template +using SmallVector = std::vector; +template +using Vector = std::vector; +template +using VectorView = std::vector; + +#endif // SPIRV_CROSS_FORCE_STL_TYPES + +// An object pool which we use for allocating IVariant-derived objects. +// We know we are going to allocate a bunch of objects of each type, +// so amortize the mallocs. +class ObjectPoolBase +{ +public: + virtual ~ObjectPoolBase() = default; + virtual void deallocate_opaque(void *ptr) = 0; +}; + +template +class ObjectPool : public ObjectPoolBase +{ +public: + explicit ObjectPool(unsigned start_object_count_ = 16) + : start_object_count(start_object_count_) + { + } + + template + T *allocate(P &&... p) + { + if (vacants.empty()) + { + unsigned num_objects = start_object_count << memory.size(); + T *ptr = static_cast(malloc(num_objects * sizeof(T))); + if (!ptr) + return nullptr; + + vacants.reserve(num_objects); + for (unsigned i = 0; i < num_objects; i++) + vacants.push_back(&ptr[i]); + + memory.emplace_back(ptr); + } + + T *ptr = vacants.back(); + vacants.pop_back(); + new (ptr) T(std::forward
(p)...); + return ptr; + } + + void deallocate(T *ptr) + { + ptr->~T(); + vacants.push_back(ptr); + } + + void deallocate_opaque(void *ptr) override + { + deallocate(static_cast(ptr)); + } + + void clear() + { + vacants.clear(); + memory.clear(); + } + +protected: + Vector vacants; + + struct MallocDeleter + { + void operator()(T *ptr) + { + ::free(ptr); + } + }; + + SmallVector> memory; + unsigned start_object_count; +}; + +template +class StringStream +{ +public: + StringStream() + { + reset(); + } + + ~StringStream() + { + reset(); + } + + // Disable copies and moves. Makes it easier to implement, and we don't need it. + StringStream(const StringStream &) = delete; + void operator=(const StringStream &) = delete; + + template ::value, int>::type = 0> + StringStream &operator<<(const T &t) + { + auto s = std::to_string(t); + append(s.data(), s.size()); + return *this; + } + + // Only overload this to make float/double conversions ambiguous. + StringStream &operator<<(uint32_t v) + { + auto s = std::to_string(v); + append(s.data(), s.size()); + return *this; + } + + StringStream &operator<<(char c) + { + append(&c, 1); + return *this; + } + + StringStream &operator<<(const std::string &s) + { + append(s.data(), s.size()); + return *this; + } + + StringStream &operator<<(const char *s) + { + append(s, strlen(s)); + return *this; + } + + template + StringStream &operator<<(const char (&s)[N]) + { + append(s, strlen(s)); + return *this; + } + + std::string str() const + { + std::string ret; + size_t target_size = 0; + for (auto &saved : saved_buffers) + target_size += saved.offset; + target_size += current_buffer.offset; + ret.reserve(target_size); + + for (auto &saved : saved_buffers) + ret.insert(ret.end(), saved.buffer, saved.buffer + saved.offset); + ret.insert(ret.end(), current_buffer.buffer, current_buffer.buffer + current_buffer.offset); + return ret; + } + + void reset() + { + for (auto &saved : saved_buffers) + if (saved.buffer != stack_buffer) + free(saved.buffer); + if (current_buffer.buffer != stack_buffer) + free(current_buffer.buffer); + + saved_buffers.clear(); + current_buffer.buffer = stack_buffer; + current_buffer.offset = 0; + current_buffer.size = sizeof(stack_buffer); + } + +private: + struct Buffer + { + char *buffer = nullptr; + size_t offset = 0; + size_t size = 0; + }; + Buffer current_buffer; + char stack_buffer[StackSize]; + SmallVector saved_buffers; + + void append(const char *s, size_t len) + { + size_t avail = current_buffer.size - current_buffer.offset; + if (avail < len) + { + if (avail > 0) + { + memcpy(current_buffer.buffer + current_buffer.offset, s, avail); + s += avail; + len -= avail; + current_buffer.offset += avail; + } + + saved_buffers.push_back(current_buffer); + size_t target_size = len > BlockSize ? 
len : BlockSize; + current_buffer.buffer = static_cast(malloc(target_size)); + if (!current_buffer.buffer) + SPIRV_CROSS_THROW("Out of memory."); + + memcpy(current_buffer.buffer, s, len); + current_buffer.offset = len; + current_buffer.size = target_size; + } + else + { + memcpy(current_buffer.buffer + current_buffer.offset, s, len); + current_buffer.offset += len; + } + } +}; + +} // namespace SPIRV_CROSS_NAMESPACE + +#endif diff --git a/thirdparty/spirv-cross/spirv_cross_error_handling.hpp b/thirdparty/spirv-cross/spirv_cross_error_handling.hpp new file mode 100644 index 00000000000..91e6cf4f86b --- /dev/null +++ b/thirdparty/spirv-cross/spirv_cross_error_handling.hpp @@ -0,0 +1,99 @@ +/* + * Copyright 2015-2021 Arm Limited + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . + */ + +#ifndef SPIRV_CROSS_ERROR_HANDLING +#define SPIRV_CROSS_ERROR_HANDLING + +#include +#include +#include +#ifndef SPIRV_CROSS_EXCEPTIONS_TO_ASSERTIONS +#include +#endif + +#ifdef SPIRV_CROSS_NAMESPACE_OVERRIDE +#define SPIRV_CROSS_NAMESPACE SPIRV_CROSS_NAMESPACE_OVERRIDE +#else +#define SPIRV_CROSS_NAMESPACE spirv_cross +#endif + +namespace SPIRV_CROSS_NAMESPACE +{ +#ifdef SPIRV_CROSS_EXCEPTIONS_TO_ASSERTIONS +#if !defined(_MSC_VER) || defined(__clang__) +[[noreturn]] +#elif defined(_MSC_VER) +__declspec(noreturn) +#endif +inline void +report_and_abort(const std::string &msg) +{ +#ifdef NDEBUG + (void)msg; +#else + fprintf(stderr, "There was a compiler error: %s\n", msg.c_str()); +#endif + fflush(stderr); + abort(); +} + +#define SPIRV_CROSS_THROW(x) report_and_abort(x) +#else +class CompilerError : public std::runtime_error +{ +public: + explicit CompilerError(const std::string &str) + : std::runtime_error(str) + { + } + + explicit CompilerError(const char *str) + : std::runtime_error(str) + { + } +}; + +#define SPIRV_CROSS_THROW(x) throw CompilerError(x) +#endif + +// MSVC 2013 does not have noexcept. We need this for Variant to get move constructor to work correctly +// instead of copy constructor. +// MSVC 2013 ignores that move constructors cannot throw in std::vector, so just don't define it. 
+#if defined(_MSC_VER) && _MSC_VER < 1900 +#define SPIRV_CROSS_NOEXCEPT +#else +#define SPIRV_CROSS_NOEXCEPT noexcept +#endif + +#if __cplusplus >= 201402l +#define SPIRV_CROSS_DEPRECATED(reason) [[deprecated(reason)]] +#elif defined(__GNUC__) +#define SPIRV_CROSS_DEPRECATED(reason) __attribute__((deprecated)) +#elif defined(_MSC_VER) +#define SPIRV_CROSS_DEPRECATED(reason) __declspec(deprecated(reason)) +#else +#define SPIRV_CROSS_DEPRECATED(reason) +#endif +} // namespace SPIRV_CROSS_NAMESPACE + +#endif diff --git a/thirdparty/spirv-cross/spirv_cross_parsed_ir.cpp b/thirdparty/spirv-cross/spirv_cross_parsed_ir.cpp new file mode 100644 index 00000000000..3072cd8abb0 --- /dev/null +++ b/thirdparty/spirv-cross/spirv_cross_parsed_ir.cpp @@ -0,0 +1,1083 @@ +/* + * Copyright 2018-2021 Arm Limited + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . + */ + +#include "spirv_cross_parsed_ir.hpp" +#include +#include + +using namespace std; +using namespace spv; + +namespace SPIRV_CROSS_NAMESPACE +{ +ParsedIR::ParsedIR() +{ + // If we move ParsedIR, we need to make sure the pointer stays fixed since the child Variant objects consume a pointer to this group, + // so need an extra pointer here. + pool_group.reset(new ObjectPoolGroup); + + pool_group->pools[TypeType].reset(new ObjectPool); + pool_group->pools[TypeVariable].reset(new ObjectPool); + pool_group->pools[TypeConstant].reset(new ObjectPool); + pool_group->pools[TypeFunction].reset(new ObjectPool); + pool_group->pools[TypeFunctionPrototype].reset(new ObjectPool); + pool_group->pools[TypeBlock].reset(new ObjectPool); + pool_group->pools[TypeExtension].reset(new ObjectPool); + pool_group->pools[TypeExpression].reset(new ObjectPool); + pool_group->pools[TypeConstantOp].reset(new ObjectPool); + pool_group->pools[TypeCombinedImageSampler].reset(new ObjectPool); + pool_group->pools[TypeAccessChain].reset(new ObjectPool); + pool_group->pools[TypeUndef].reset(new ObjectPool); + pool_group->pools[TypeString].reset(new ObjectPool); +} + +// Should have been default-implemented, but need this on MSVC 2013. 
+ParsedIR::ParsedIR(ParsedIR &&other) SPIRV_CROSS_NOEXCEPT +{ + *this = std::move(other); +} + +ParsedIR &ParsedIR::operator=(ParsedIR &&other) SPIRV_CROSS_NOEXCEPT +{ + if (this != &other) + { + pool_group = std::move(other.pool_group); + spirv = std::move(other.spirv); + meta = std::move(other.meta); + for (int i = 0; i < TypeCount; i++) + ids_for_type[i] = std::move(other.ids_for_type[i]); + ids_for_constant_undef_or_type = std::move(other.ids_for_constant_undef_or_type); + ids_for_constant_or_variable = std::move(other.ids_for_constant_or_variable); + declared_capabilities = std::move(other.declared_capabilities); + declared_extensions = std::move(other.declared_extensions); + block_meta = std::move(other.block_meta); + continue_block_to_loop_header = std::move(other.continue_block_to_loop_header); + entry_points = std::move(other.entry_points); + ids = std::move(other.ids); + addressing_model = other.addressing_model; + memory_model = other.memory_model; + + default_entry_point = other.default_entry_point; + source = other.source; + loop_iteration_depth_hard = other.loop_iteration_depth_hard; + loop_iteration_depth_soft = other.loop_iteration_depth_soft; + + meta_needing_name_fixup = std::move(other.meta_needing_name_fixup); + load_type_width = std::move(other.load_type_width); + } + return *this; +} + +ParsedIR::ParsedIR(const ParsedIR &other) + : ParsedIR() +{ + *this = other; +} + +ParsedIR &ParsedIR::operator=(const ParsedIR &other) +{ + if (this != &other) + { + spirv = other.spirv; + meta = other.meta; + for (int i = 0; i < TypeCount; i++) + ids_for_type[i] = other.ids_for_type[i]; + ids_for_constant_undef_or_type = other.ids_for_constant_undef_or_type; + ids_for_constant_or_variable = other.ids_for_constant_or_variable; + declared_capabilities = other.declared_capabilities; + declared_extensions = other.declared_extensions; + block_meta = other.block_meta; + continue_block_to_loop_header = other.continue_block_to_loop_header; + entry_points = other.entry_points; + default_entry_point = other.default_entry_point; + source = other.source; + loop_iteration_depth_hard = other.loop_iteration_depth_hard; + loop_iteration_depth_soft = other.loop_iteration_depth_soft; + addressing_model = other.addressing_model; + memory_model = other.memory_model; + + + meta_needing_name_fixup = other.meta_needing_name_fixup; + load_type_width = other.load_type_width; + + // Very deliberate copying of IDs. There is no default copy constructor, nor a simple default constructor. + // Construct object first so we have the correct allocator set-up, then we can copy object into our new pool group. + ids.clear(); + ids.reserve(other.ids.size()); + for (size_t i = 0; i < other.ids.size(); i++) + { + ids.emplace_back(pool_group.get()); + ids.back() = other.ids[i]; + } + } + return *this; +} + +void ParsedIR::set_id_bounds(uint32_t bounds) +{ + ids.reserve(bounds); + while (ids.size() < bounds) + ids.emplace_back(pool_group.get()); + + block_meta.resize(bounds); +} + +// Roll our own versions of these functions to avoid potential locale shenanigans. 
+static bool is_alpha(char c) +{ + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); +} + +static bool is_numeric(char c) +{ + return c >= '0' && c <= '9'; +} + +static bool is_alphanumeric(char c) +{ + return is_alpha(c) || is_numeric(c); +} + +static bool is_valid_identifier(const string &name) +{ + if (name.empty()) + return true; + + if (is_numeric(name[0])) + return false; + + for (auto c : name) + if (!is_alphanumeric(c) && c != '_') + return false; + + bool saw_underscore = false; + // Two underscores in a row is not a valid identifier either. + // Technically reserved, but it's easier to treat it as invalid. + for (auto c : name) + { + bool is_underscore = c == '_'; + if (is_underscore && saw_underscore) + return false; + saw_underscore = is_underscore; + } + + return true; +} + +static bool is_reserved_prefix(const string &name) +{ + // Generic reserved identifiers used by the implementation. + return name.compare(0, 3, "gl_", 3) == 0 || + // Ignore this case for now, might rewrite internal code to always use spv prefix. + //name.compare(0, 11, "SPIRV_Cross", 11) == 0 || + name.compare(0, 3, "spv", 3) == 0; +} + +static bool is_reserved_identifier(const string &name, bool member, bool allow_reserved_prefixes) +{ + if (!allow_reserved_prefixes && is_reserved_prefix(name)) + return true; + + if (member) + { + // Reserved member identifiers come in one form: + // _m[0-9]+$. + if (name.size() < 3) + return false; + + if (name.compare(0, 2, "_m", 2) != 0) + return false; + + size_t index = 2; + while (index < name.size() && is_numeric(name[index])) + index++; + + return index == name.size(); + } + else + { + // Reserved non-member identifiers come in two forms: + // _[0-9]+$, used for temporaries which map directly to a SPIR-V ID. + // _[0-9]+_, used for auxillary temporaries which derived from a SPIR-V ID. + if (name.size() < 2) + return false; + + if (name[0] != '_' || !is_numeric(name[1])) + return false; + + size_t index = 2; + while (index < name.size() && is_numeric(name[index])) + index++; + + return index == name.size() || (index < name.size() && name[index] == '_'); + } +} + +bool ParsedIR::is_globally_reserved_identifier(std::string &str, bool allow_reserved_prefixes) +{ + return is_reserved_identifier(str, false, allow_reserved_prefixes); +} + +uint32_t ParsedIR::get_spirv_version() const +{ + return spirv[1]; +} + +static string make_unreserved_identifier(const string &name) +{ + if (is_reserved_prefix(name)) + return "_RESERVED_IDENTIFIER_FIXUP_" + name; + else + return "_RESERVED_IDENTIFIER_FIXUP" + name; +} + +void ParsedIR::sanitize_underscores(std::string &str) +{ + // Compact adjacent underscores to make it valid. + auto dst = str.begin(); + auto src = dst; + bool saw_underscore = false; + while (src != str.end()) + { + bool is_underscore = *src == '_'; + if (saw_underscore && is_underscore) + { + src++; + } + else + { + if (dst != src) + *dst = *src; + dst++; + src++; + saw_underscore = is_underscore; + } + } + str.erase(dst, str.end()); +} + +static string ensure_valid_identifier(const string &name) +{ + // Functions in glslangValidator are mangled with name( stuff. + // Normally, we would never see '(' in any legal identifiers, so just strip them out. 
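	// For example, a mangled name such as "foo(vf3;" (illustrative input) is reduced to "foo" below.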
+ auto str = name.substr(0, name.find('(')); + + if (str.empty()) + return str; + + if (is_numeric(str[0])) + str[0] = '_'; + + for (auto &c : str) + if (!is_alphanumeric(c) && c != '_') + c = '_'; + + ParsedIR::sanitize_underscores(str); + return str; +} + +const string &ParsedIR::get_name(ID id) const +{ + auto *m = find_meta(id); + if (m) + return m->decoration.alias; + else + return empty_string; +} + +const string &ParsedIR::get_member_name(TypeID id, uint32_t index) const +{ + auto *m = find_meta(id); + if (m) + { + if (index >= m->members.size()) + return empty_string; + return m->members[index].alias; + } + else + return empty_string; +} + +void ParsedIR::sanitize_identifier(std::string &name, bool member, bool allow_reserved_prefixes) +{ + if (!is_valid_identifier(name)) + name = ensure_valid_identifier(name); + if (is_reserved_identifier(name, member, allow_reserved_prefixes)) + name = make_unreserved_identifier(name); +} + +void ParsedIR::fixup_reserved_names() +{ + for (uint32_t id : meta_needing_name_fixup) + { + // Don't rename remapped variables like 'gl_LastFragDepthARM'. + if (ids[id].get_type() == TypeVariable && get(id).remapped_variable) + continue; + + auto &m = meta[id]; + sanitize_identifier(m.decoration.alias, false, false); + for (auto &memb : m.members) + sanitize_identifier(memb.alias, true, false); + } + meta_needing_name_fixup.clear(); +} + +void ParsedIR::set_name(ID id, const string &name) +{ + auto &m = meta[id]; + m.decoration.alias = name; + if (!is_valid_identifier(name) || is_reserved_identifier(name, false, false)) + meta_needing_name_fixup.insert(id); +} + +void ParsedIR::set_member_name(TypeID id, uint32_t index, const string &name) +{ + auto &m = meta[id]; + m.members.resize(max(m.members.size(), size_t(index) + 1)); + m.members[index].alias = name; + if (!is_valid_identifier(name) || is_reserved_identifier(name, true, false)) + meta_needing_name_fixup.insert(id); +} + +void ParsedIR::set_decoration_string(ID id, Decoration decoration, const string &argument) +{ + auto &dec = meta[id].decoration; + dec.decoration_flags.set(decoration); + + switch (decoration) + { + case DecorationHlslSemanticGOOGLE: + dec.hlsl_semantic = argument; + break; + + case DecorationUserTypeGOOGLE: + dec.user_type = argument; + break; + + default: + break; + } +} + +void ParsedIR::set_decoration(ID id, Decoration decoration, uint32_t argument) +{ + auto &dec = meta[id].decoration; + dec.decoration_flags.set(decoration); + + switch (decoration) + { + case DecorationBuiltIn: + dec.builtin = true; + dec.builtin_type = static_cast(argument); + break; + + case DecorationLocation: + dec.location = argument; + break; + + case DecorationComponent: + dec.component = argument; + break; + + case DecorationOffset: + dec.offset = argument; + break; + + case DecorationXfbBuffer: + dec.xfb_buffer = argument; + break; + + case DecorationXfbStride: + dec.xfb_stride = argument; + break; + + case DecorationStream: + dec.stream = argument; + break; + + case DecorationArrayStride: + dec.array_stride = argument; + break; + + case DecorationMatrixStride: + dec.matrix_stride = argument; + break; + + case DecorationBinding: + dec.binding = argument; + break; + + case DecorationDescriptorSet: + dec.set = argument; + break; + + case DecorationInputAttachmentIndex: + dec.input_attachment = argument; + break; + + case DecorationSpecId: + dec.spec_id = argument; + break; + + case DecorationIndex: + dec.index = argument; + break; + + case DecorationHlslCounterBufferGOOGLE: + 
meta[id].hlsl_magic_counter_buffer = argument; + meta[argument].hlsl_is_magic_counter_buffer = true; + break; + + case DecorationFPRoundingMode: + dec.fp_rounding_mode = static_cast(argument); + break; + + default: + break; + } +} + +void ParsedIR::set_member_decoration(TypeID id, uint32_t index, Decoration decoration, uint32_t argument) +{ + auto &m = meta[id]; + m.members.resize(max(m.members.size(), size_t(index) + 1)); + auto &dec = m.members[index]; + dec.decoration_flags.set(decoration); + + switch (decoration) + { + case DecorationBuiltIn: + dec.builtin = true; + dec.builtin_type = static_cast(argument); + break; + + case DecorationLocation: + dec.location = argument; + break; + + case DecorationComponent: + dec.component = argument; + break; + + case DecorationBinding: + dec.binding = argument; + break; + + case DecorationOffset: + dec.offset = argument; + break; + + case DecorationXfbBuffer: + dec.xfb_buffer = argument; + break; + + case DecorationXfbStride: + dec.xfb_stride = argument; + break; + + case DecorationStream: + dec.stream = argument; + break; + + case DecorationSpecId: + dec.spec_id = argument; + break; + + case DecorationMatrixStride: + dec.matrix_stride = argument; + break; + + case DecorationIndex: + dec.index = argument; + break; + + default: + break; + } +} + +// Recursively marks any constants referenced by the specified constant instruction as being used +// as an array length. The id must be a constant instruction (SPIRConstant or SPIRConstantOp). +void ParsedIR::mark_used_as_array_length(ID id) +{ + switch (ids[id].get_type()) + { + case TypeConstant: + get(id).is_used_as_array_length = true; + break; + + case TypeConstantOp: + { + auto &cop = get(id); + if (cop.opcode == OpCompositeExtract) + mark_used_as_array_length(cop.arguments[0]); + else if (cop.opcode == OpCompositeInsert) + { + mark_used_as_array_length(cop.arguments[0]); + mark_used_as_array_length(cop.arguments[1]); + } + else + for (uint32_t arg_id : cop.arguments) + mark_used_as_array_length(arg_id); + break; + } + + case TypeUndef: + break; + + default: + assert(0); + } +} + +Bitset ParsedIR::get_buffer_block_type_flags(const SPIRType &type) const +{ + if (type.member_types.empty()) + return {}; + + Bitset all_members_flags = get_member_decoration_bitset(type.self, 0); + for (uint32_t i = 1; i < uint32_t(type.member_types.size()); i++) + all_members_flags.merge_and(get_member_decoration_bitset(type.self, i)); + return all_members_flags; +} + +Bitset ParsedIR::get_buffer_block_flags(const SPIRVariable &var) const +{ + auto &type = get(var.basetype); + assert(type.basetype == SPIRType::Struct); + + // Some flags like non-writable, non-readable are actually found + // as member decorations. If all members have a decoration set, propagate + // the decoration up as a regular variable decoration. 
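	// (For example: if every member of an SSBO block is decorated NonWritable, the block as a whole
	// ends up reported as NonWritable via the merge performed here.)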
+ Bitset base_flags; + auto *m = find_meta(var.self); + if (m) + base_flags = m->decoration.decoration_flags; + + if (type.member_types.empty()) + return base_flags; + + auto all_members_flags = get_buffer_block_type_flags(type); + base_flags.merge_or(all_members_flags); + return base_flags; +} + +const Bitset &ParsedIR::get_member_decoration_bitset(TypeID id, uint32_t index) const +{ + auto *m = find_meta(id); + if (m) + { + if (index >= m->members.size()) + return cleared_bitset; + return m->members[index].decoration_flags; + } + else + return cleared_bitset; +} + +bool ParsedIR::has_decoration(ID id, Decoration decoration) const +{ + return get_decoration_bitset(id).get(decoration); +} + +uint32_t ParsedIR::get_decoration(ID id, Decoration decoration) const +{ + auto *m = find_meta(id); + if (!m) + return 0; + + auto &dec = m->decoration; + if (!dec.decoration_flags.get(decoration)) + return 0; + + switch (decoration) + { + case DecorationBuiltIn: + return dec.builtin_type; + case DecorationLocation: + return dec.location; + case DecorationComponent: + return dec.component; + case DecorationOffset: + return dec.offset; + case DecorationXfbBuffer: + return dec.xfb_buffer; + case DecorationXfbStride: + return dec.xfb_stride; + case DecorationStream: + return dec.stream; + case DecorationBinding: + return dec.binding; + case DecorationDescriptorSet: + return dec.set; + case DecorationInputAttachmentIndex: + return dec.input_attachment; + case DecorationSpecId: + return dec.spec_id; + case DecorationArrayStride: + return dec.array_stride; + case DecorationMatrixStride: + return dec.matrix_stride; + case DecorationIndex: + return dec.index; + case DecorationFPRoundingMode: + return dec.fp_rounding_mode; + default: + return 1; + } +} + +const string &ParsedIR::get_decoration_string(ID id, Decoration decoration) const +{ + auto *m = find_meta(id); + if (!m) + return empty_string; + + auto &dec = m->decoration; + + if (!dec.decoration_flags.get(decoration)) + return empty_string; + + switch (decoration) + { + case DecorationHlslSemanticGOOGLE: + return dec.hlsl_semantic; + + case DecorationUserTypeGOOGLE: + return dec.user_type; + + default: + return empty_string; + } +} + +void ParsedIR::unset_decoration(ID id, Decoration decoration) +{ + auto &dec = meta[id].decoration; + dec.decoration_flags.clear(decoration); + switch (decoration) + { + case DecorationBuiltIn: + dec.builtin = false; + break; + + case DecorationLocation: + dec.location = 0; + break; + + case DecorationComponent: + dec.component = 0; + break; + + case DecorationOffset: + dec.offset = 0; + break; + + case DecorationXfbBuffer: + dec.xfb_buffer = 0; + break; + + case DecorationXfbStride: + dec.xfb_stride = 0; + break; + + case DecorationStream: + dec.stream = 0; + break; + + case DecorationBinding: + dec.binding = 0; + break; + + case DecorationDescriptorSet: + dec.set = 0; + break; + + case DecorationInputAttachmentIndex: + dec.input_attachment = 0; + break; + + case DecorationSpecId: + dec.spec_id = 0; + break; + + case DecorationHlslSemanticGOOGLE: + dec.hlsl_semantic.clear(); + break; + + case DecorationFPRoundingMode: + dec.fp_rounding_mode = FPRoundingModeMax; + break; + + case DecorationHlslCounterBufferGOOGLE: + { + auto &counter = meta[id].hlsl_magic_counter_buffer; + if (counter) + { + meta[counter].hlsl_is_magic_counter_buffer = false; + counter = 0; + } + break; + } + + default: + break; + } +} + +bool ParsedIR::has_member_decoration(TypeID id, uint32_t index, Decoration decoration) const +{ + return 
get_member_decoration_bitset(id, index).get(decoration); +} + +uint32_t ParsedIR::get_member_decoration(TypeID id, uint32_t index, Decoration decoration) const +{ + auto *m = find_meta(id); + if (!m) + return 0; + + if (index >= m->members.size()) + return 0; + + auto &dec = m->members[index]; + if (!dec.decoration_flags.get(decoration)) + return 0; + + switch (decoration) + { + case DecorationBuiltIn: + return dec.builtin_type; + case DecorationLocation: + return dec.location; + case DecorationComponent: + return dec.component; + case DecorationBinding: + return dec.binding; + case DecorationOffset: + return dec.offset; + case DecorationXfbBuffer: + return dec.xfb_buffer; + case DecorationXfbStride: + return dec.xfb_stride; + case DecorationStream: + return dec.stream; + case DecorationSpecId: + return dec.spec_id; + case DecorationMatrixStride: + return dec.matrix_stride; + case DecorationIndex: + return dec.index; + default: + return 1; + } +} + +const Bitset &ParsedIR::get_decoration_bitset(ID id) const +{ + auto *m = find_meta(id); + if (m) + { + auto &dec = m->decoration; + return dec.decoration_flags; + } + else + return cleared_bitset; +} + +void ParsedIR::set_member_decoration_string(TypeID id, uint32_t index, Decoration decoration, const string &argument) +{ + auto &m = meta[id]; + m.members.resize(max(m.members.size(), size_t(index) + 1)); + auto &dec = meta[id].members[index]; + dec.decoration_flags.set(decoration); + + switch (decoration) + { + case DecorationHlslSemanticGOOGLE: + dec.hlsl_semantic = argument; + break; + + default: + break; + } +} + +const string &ParsedIR::get_member_decoration_string(TypeID id, uint32_t index, Decoration decoration) const +{ + auto *m = find_meta(id); + if (m) + { + if (!has_member_decoration(id, index, decoration)) + return empty_string; + + auto &dec = m->members[index]; + + switch (decoration) + { + case DecorationHlslSemanticGOOGLE: + return dec.hlsl_semantic; + + default: + return empty_string; + } + } + else + return empty_string; +} + +void ParsedIR::unset_member_decoration(TypeID id, uint32_t index, Decoration decoration) +{ + auto &m = meta[id]; + if (index >= m.members.size()) + return; + + auto &dec = m.members[index]; + + dec.decoration_flags.clear(decoration); + switch (decoration) + { + case DecorationBuiltIn: + dec.builtin = false; + break; + + case DecorationLocation: + dec.location = 0; + break; + + case DecorationComponent: + dec.component = 0; + break; + + case DecorationOffset: + dec.offset = 0; + break; + + case DecorationXfbBuffer: + dec.xfb_buffer = 0; + break; + + case DecorationXfbStride: + dec.xfb_stride = 0; + break; + + case DecorationStream: + dec.stream = 0; + break; + + case DecorationSpecId: + dec.spec_id = 0; + break; + + case DecorationHlslSemanticGOOGLE: + dec.hlsl_semantic.clear(); + break; + + default: + break; + } +} + +uint32_t ParsedIR::increase_bound_by(uint32_t incr_amount) +{ + auto curr_bound = ids.size(); + auto new_bound = curr_bound + incr_amount; + + ids.reserve(ids.size() + incr_amount); + for (uint32_t i = 0; i < incr_amount; i++) + ids.emplace_back(pool_group.get()); + + block_meta.resize(new_bound); + return uint32_t(curr_bound); +} + +void ParsedIR::remove_typed_id(Types type, ID id) +{ + auto &type_ids = ids_for_type[type]; + type_ids.erase(remove(begin(type_ids), end(type_ids), id), end(type_ids)); +} + +void ParsedIR::reset_all_of_type(Types type) +{ + for (auto &id : ids_for_type[type]) + if (ids[id].get_type() == type) + ids[id].reset(); + + ids_for_type[type].clear(); +} + +void 
ParsedIR::add_typed_id(Types type, ID id) +{ + if (loop_iteration_depth_hard != 0) + SPIRV_CROSS_THROW("Cannot add typed ID while looping over it."); + + if (loop_iteration_depth_soft != 0) + { + if (!ids[id].empty()) + SPIRV_CROSS_THROW("Cannot override IDs when loop is soft locked."); + return; + } + + if (ids[id].empty() || ids[id].get_type() != type) + { + switch (type) + { + case TypeConstant: + ids_for_constant_or_variable.push_back(id); + ids_for_constant_undef_or_type.push_back(id); + break; + + case TypeVariable: + ids_for_constant_or_variable.push_back(id); + break; + + case TypeType: + case TypeConstantOp: + case TypeUndef: + ids_for_constant_undef_or_type.push_back(id); + break; + + default: + break; + } + } + + if (ids[id].empty()) + { + ids_for_type[type].push_back(id); + } + else if (ids[id].get_type() != type) + { + remove_typed_id(ids[id].get_type(), id); + ids_for_type[type].push_back(id); + } +} + +const Meta *ParsedIR::find_meta(ID id) const +{ + auto itr = meta.find(id); + if (itr != end(meta)) + return &itr->second; + else + return nullptr; +} + +Meta *ParsedIR::find_meta(ID id) +{ + auto itr = meta.find(id); + if (itr != end(meta)) + return &itr->second; + else + return nullptr; +} + +ParsedIR::LoopLock ParsedIR::create_loop_hard_lock() const +{ + return ParsedIR::LoopLock(&loop_iteration_depth_hard); +} + +ParsedIR::LoopLock ParsedIR::create_loop_soft_lock() const +{ + return ParsedIR::LoopLock(&loop_iteration_depth_soft); +} + +ParsedIR::LoopLock::~LoopLock() +{ + if (lock) + (*lock)--; +} + +ParsedIR::LoopLock::LoopLock(uint32_t *lock_) + : lock(lock_) +{ + if (lock) + (*lock)++; +} + +ParsedIR::LoopLock::LoopLock(LoopLock &&other) SPIRV_CROSS_NOEXCEPT +{ + *this = std::move(other); +} + +ParsedIR::LoopLock &ParsedIR::LoopLock::operator=(LoopLock &&other) SPIRV_CROSS_NOEXCEPT +{ + if (lock) + (*lock)--; + lock = other.lock; + other.lock = nullptr; + return *this; +} + +void ParsedIR::make_constant_null(uint32_t id, uint32_t type, bool add_to_typed_id_set) +{ + auto &constant_type = get(type); + + if (constant_type.pointer) + { + if (add_to_typed_id_set) + add_typed_id(TypeConstant, id); + auto &constant = variant_set(ids[id], type); + constant.self = id; + constant.make_null(constant_type); + } + else if (!constant_type.array.empty()) + { + assert(constant_type.parent_type); + uint32_t parent_id = increase_bound_by(1); + make_constant_null(parent_id, constant_type.parent_type, add_to_typed_id_set); + + if (!constant_type.array_size_literal.back()) + SPIRV_CROSS_THROW("Array size of OpConstantNull must be a literal."); + + SmallVector elements(constant_type.array.back()); + for (uint32_t i = 0; i < constant_type.array.back(); i++) + elements[i] = parent_id; + + if (add_to_typed_id_set) + add_typed_id(TypeConstant, id); + variant_set(ids[id], type, elements.data(), uint32_t(elements.size()), false).self = id; + } + else if (!constant_type.member_types.empty()) + { + uint32_t member_ids = increase_bound_by(uint32_t(constant_type.member_types.size())); + SmallVector elements(constant_type.member_types.size()); + for (uint32_t i = 0; i < constant_type.member_types.size(); i++) + { + make_constant_null(member_ids + i, constant_type.member_types[i], add_to_typed_id_set); + elements[i] = member_ids + i; + } + + if (add_to_typed_id_set) + add_typed_id(TypeConstant, id); + variant_set(ids[id], type, elements.data(), uint32_t(elements.size()), false).self = id; + } + else + { + if (add_to_typed_id_set) + add_typed_id(TypeConstant, id); + auto &constant = variant_set(ids[id], 
type); + constant.self = id; + constant.make_null(constant_type); + } +} + +} // namespace SPIRV_CROSS_NAMESPACE diff --git a/thirdparty/spirv-cross/spirv_cross_parsed_ir.hpp b/thirdparty/spirv-cross/spirv_cross_parsed_ir.hpp new file mode 100644 index 00000000000..3892248aaad --- /dev/null +++ b/thirdparty/spirv-cross/spirv_cross_parsed_ir.hpp @@ -0,0 +1,256 @@ +/* + * Copyright 2018-2021 Arm Limited + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . + */ + +#ifndef SPIRV_CROSS_PARSED_IR_HPP +#define SPIRV_CROSS_PARSED_IR_HPP + +#include "spirv_common.hpp" +#include +#include + +namespace SPIRV_CROSS_NAMESPACE +{ + +// This data structure holds all information needed to perform cross-compilation and reflection. +// It is the output of the Parser, but any implementation could create this structure. +// It is intentionally very "open" and struct-like with some helper functions to deal with decorations. +// Parser is the reference implementation of how this data structure should be filled in. + +class ParsedIR +{ +private: + // This must be destroyed after the "ids" vector. + std::unique_ptr pool_group; + +public: + ParsedIR(); + + // Due to custom allocations from object pools, we cannot use a default copy constructor. + ParsedIR(const ParsedIR &other); + ParsedIR &operator=(const ParsedIR &other); + + // Moves are unproblematic, but we need to implement it anyways, since MSVC 2013 does not understand + // how to default-implement these. + ParsedIR(ParsedIR &&other) SPIRV_CROSS_NOEXCEPT; + ParsedIR &operator=(ParsedIR &&other) SPIRV_CROSS_NOEXCEPT; + + // Resizes ids, meta and block_meta. + void set_id_bounds(uint32_t bounds); + + // The raw SPIR-V, instructions and opcodes refer to this by offset + count. + std::vector spirv; + + // Holds various data structures which inherit from IVariant. + SmallVector ids; + + // Various meta data for IDs, decorations, names, etc. + std::unordered_map meta; + + // Holds all IDs which have a certain type. + // This is needed so we can iterate through a specific kind of resource quickly, + // and in-order of module declaration. + SmallVector ids_for_type[TypeCount]; + + // Special purpose lists which contain a union of types. + // This is needed so we can declare specialization constants and structs in an interleaved fashion, + // among other things. + // Constants can be undef or of struct type, and struct array sizes can use specialization constants. + SmallVector ids_for_constant_undef_or_type; + SmallVector ids_for_constant_or_variable; + + // We need to keep track of the width the Ops that contains a type for the + // OpSwitch instruction, since this one doesn't contains the type in the + // instruction itself. And in some case we need to cast the condition to + // wider types. 
We only need the width to do the branch fixup since the + // type check itself can be done at runtime + std::unordered_map load_type_width; + + // Declared capabilities and extensions in the SPIR-V module. + // Not really used except for reflection at the moment. + SmallVector declared_capabilities; + SmallVector declared_extensions; + + // Meta data about blocks. The cross-compiler needs to query if a block is either of these types. + // It is a bitset as there can be more than one tag per block. + enum BlockMetaFlagBits + { + BLOCK_META_LOOP_HEADER_BIT = 1 << 0, + BLOCK_META_CONTINUE_BIT = 1 << 1, + BLOCK_META_LOOP_MERGE_BIT = 1 << 2, + BLOCK_META_SELECTION_MERGE_BIT = 1 << 3, + BLOCK_META_MULTISELECT_MERGE_BIT = 1 << 4 + }; + using BlockMetaFlags = uint8_t; + SmallVector block_meta; + std::unordered_map continue_block_to_loop_header; + + // Normally, we'd stick SPIREntryPoint in ids array, but it conflicts with SPIRFunction. + // Entry points can therefore be seen as some sort of meta structure. + std::unordered_map entry_points; + FunctionID default_entry_point = 0; + + struct Source + { + uint32_t version = 0; + bool es = false; + bool known = false; + bool hlsl = false; + + Source() = default; + }; + + Source source; + + spv::AddressingModel addressing_model = spv::AddressingModelMax; + spv::MemoryModel memory_model = spv::MemoryModelMax; + + // Decoration handling methods. + // Can be useful for simple "raw" reflection. + // However, most members are here because the Parser needs most of these, + // and might as well just have the whole suite of decoration/name handling in one place. + void set_name(ID id, const std::string &name); + const std::string &get_name(ID id) const; + void set_decoration(ID id, spv::Decoration decoration, uint32_t argument = 0); + void set_decoration_string(ID id, spv::Decoration decoration, const std::string &argument); + bool has_decoration(ID id, spv::Decoration decoration) const; + uint32_t get_decoration(ID id, spv::Decoration decoration) const; + const std::string &get_decoration_string(ID id, spv::Decoration decoration) const; + const Bitset &get_decoration_bitset(ID id) const; + void unset_decoration(ID id, spv::Decoration decoration); + + // Decoration handling methods (for members of a struct). 
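For reference (not part of the patch): the member-name and member-decoration setters declared just below are the storage that backs the public reflection calls on `Compiler`. A consumer normally reads these decorations back through that public API rather than touching `ParsedIR` directly; the sketch below shows that read path, assuming a valid SPIR-V binary is supplied by the caller, with `dump_ubo_bindings` and the output format being purely illustrative.

```cpp
#include "spirv_cross.hpp"

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Print descriptor set / binding decorations for every uniform buffer in a
// module, plus per-member names and offsets, via the public Compiler API.
static void dump_ubo_bindings(std::vector<uint32_t> spirv)
{
	spirv_cross::Compiler comp(std::move(spirv));
	spirv_cross::ShaderResources res = comp.get_shader_resources();

	for (auto &ubo : res.uniform_buffers)
	{
		uint32_t set = comp.get_decoration(ubo.id, spv::DecorationDescriptorSet);
		uint32_t binding = comp.get_decoration(ubo.id, spv::DecorationBinding);
		std::printf("%s: set=%u binding=%u\n", ubo.name.c_str(), set, binding);

		auto &type = comp.get_type(ubo.base_type_id);
		for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
		{
			uint32_t offset = comp.get_member_decoration(ubo.base_type_id, i, spv::DecorationOffset);
			std::printf("  %s @ offset %u\n", comp.get_member_name(ubo.base_type_id, i).c_str(), offset);
		}
	}
}
```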
+ void set_member_name(TypeID id, uint32_t index, const std::string &name); + const std::string &get_member_name(TypeID id, uint32_t index) const; + void set_member_decoration(TypeID id, uint32_t index, spv::Decoration decoration, uint32_t argument = 0); + void set_member_decoration_string(TypeID id, uint32_t index, spv::Decoration decoration, + const std::string &argument); + uint32_t get_member_decoration(TypeID id, uint32_t index, spv::Decoration decoration) const; + const std::string &get_member_decoration_string(TypeID id, uint32_t index, spv::Decoration decoration) const; + bool has_member_decoration(TypeID id, uint32_t index, spv::Decoration decoration) const; + const Bitset &get_member_decoration_bitset(TypeID id, uint32_t index) const; + void unset_member_decoration(TypeID id, uint32_t index, spv::Decoration decoration); + + void mark_used_as_array_length(ID id); + uint32_t increase_bound_by(uint32_t count); + Bitset get_buffer_block_flags(const SPIRVariable &var) const; + Bitset get_buffer_block_type_flags(const SPIRType &type) const; + + void add_typed_id(Types type, ID id); + void remove_typed_id(Types type, ID id); + + class LoopLock + { + public: + explicit LoopLock(uint32_t *counter); + LoopLock(const LoopLock &) = delete; + void operator=(const LoopLock &) = delete; + LoopLock(LoopLock &&other) SPIRV_CROSS_NOEXCEPT; + LoopLock &operator=(LoopLock &&other) SPIRV_CROSS_NOEXCEPT; + ~LoopLock(); + + private: + uint32_t *lock = nullptr; + }; + + // This must be held while iterating over a type ID array. + // It is undefined if someone calls set<>() while we're iterating over a data structure, so we must + // make sure that this case is avoided. + + // If we have a hard lock, it is an error to call set<>(), and an exception is thrown. + // If we have a soft lock, we silently ignore any additions to the typed arrays. + // This should only be used for physical ID remapping where we need to create an ID, but we will never + // care about iterating over them. 
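For reference (not part of the patch): the comment above is the contract for the two lock factories declared next. The counter-based RAII idiom it describes is easy to get wrong, so here is a minimal standalone sketch of the hard-lock variant; `IterationGuard` and `TypedIDList` are invented names, and the `assert` stands in for `SPIRV_CROSS_THROW`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// RAII guard over an iteration counter: while any guard is alive, mutating
// the guarded container is treated as a programming error (the "hard lock").
class IterationGuard
{
public:
	explicit IterationGuard(uint32_t *counter) : lock(counter) { ++*lock; }
	~IterationGuard() { --*lock; }
	IterationGuard(const IterationGuard &) = delete;
	IterationGuard &operator=(const IterationGuard &) = delete;

private:
	uint32_t *lock;
};

struct TypedIDList
{
	std::vector<uint32_t> ids;
	// mutable so a const for_each() can still bump the depth counter.
	mutable uint32_t iteration_depth = 0;

	void add(uint32_t id)
	{
		assert(iteration_depth == 0 && "Cannot add IDs while looping over them.");
		ids.push_back(id);
	}

	template <typename Op>
	void for_each(const Op &op) const
	{
		IterationGuard guard(&iteration_depth);
		for (uint32_t id : ids)
			op(id);
	}
};
```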
+ LoopLock create_loop_hard_lock() const; + LoopLock create_loop_soft_lock() const; + + template + void for_each_typed_id(const Op &op) + { + auto loop_lock = create_loop_hard_lock(); + for (auto &id : ids_for_type[T::type]) + { + if (ids[id].get_type() == static_cast(T::type)) + op(id, get(id)); + } + } + + template + void for_each_typed_id(const Op &op) const + { + auto loop_lock = create_loop_hard_lock(); + for (auto &id : ids_for_type[T::type]) + { + if (ids[id].get_type() == static_cast(T::type)) + op(id, get(id)); + } + } + + template + void reset_all_of_type() + { + reset_all_of_type(static_cast(T::type)); + } + + void reset_all_of_type(Types type); + + Meta *find_meta(ID id); + const Meta *find_meta(ID id) const; + + const std::string &get_empty_string() const + { + return empty_string; + } + + void make_constant_null(uint32_t id, uint32_t type, bool add_to_typed_id_set); + + void fixup_reserved_names(); + + static void sanitize_underscores(std::string &str); + static void sanitize_identifier(std::string &str, bool member, bool allow_reserved_prefixes); + static bool is_globally_reserved_identifier(std::string &str, bool allow_reserved_prefixes); + + uint32_t get_spirv_version() const; + +private: + template + T &get(uint32_t id) + { + return variant_get(ids[id]); + } + + template + const T &get(uint32_t id) const + { + return variant_get(ids[id]); + } + + mutable uint32_t loop_iteration_depth_hard = 0; + mutable uint32_t loop_iteration_depth_soft = 0; + std::string empty_string; + Bitset cleared_bitset; + + std::unordered_set meta_needing_name_fixup; +}; +} // namespace SPIRV_CROSS_NAMESPACE + +#endif diff --git a/thirdparty/spirv-cross/spirv_cross_util.cpp b/thirdparty/spirv-cross/spirv_cross_util.cpp new file mode 100644 index 00000000000..7cff010d1c1 --- /dev/null +++ b/thirdparty/spirv-cross/spirv_cross_util.cpp @@ -0,0 +1,77 @@ +/* + * Copyright 2015-2021 Arm Limited + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . + */ + +#include "spirv_cross_util.hpp" +#include "spirv_common.hpp" + +using namespace spv; +using namespace SPIRV_CROSS_NAMESPACE; + +namespace spirv_cross_util +{ +void rename_interface_variable(Compiler &compiler, const SmallVector &resources, uint32_t location, + const std::string &name) +{ + for (auto &v : resources) + { + if (!compiler.has_decoration(v.id, spv::DecorationLocation)) + continue; + + auto loc = compiler.get_decoration(v.id, spv::DecorationLocation); + if (loc != location) + continue; + + auto &type = compiler.get_type(v.base_type_id); + + // This is more of a friendly variant. If we need to rename interface variables, we might have to rename + // structs as well and make sure all the names match up. 
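For reference (not part of the patch): this helper is normally driven from the reflection API rather than called with a hand-built resource list. A minimal sketch of a caller follows, assuming the module actually has a stage input at location 0; the location value and the replacement name are arbitrary.

```cpp
#include "spirv_cross_util.hpp"
#include "spirv_glsl.hpp"

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Rename whatever vertex input sits at location 0 to a stable name before
// emitting GLSL, so the output is predictable for later string lookup.
static std::string compile_with_renamed_input(std::vector<uint32_t> spirv)
{
	spirv_cross::CompilerGLSL glsl(std::move(spirv));
	spirv_cross::ShaderResources res = glsl.get_shader_resources();

	spirv_cross_util::rename_interface_variable(glsl, res.stage_inputs, 0, "vertex_position");
	return glsl.compile();
}
```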
+ if (type.basetype == SPIRType::Struct) + { + compiler.set_name(v.base_type_id, join("SPIRV_Cross_Interface_Location", location)); + for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++) + compiler.set_member_name(v.base_type_id, i, join("InterfaceMember", i)); + } + + compiler.set_name(v.id, name); + } +} + +void inherit_combined_sampler_bindings(Compiler &compiler) +{ + auto &samplers = compiler.get_combined_image_samplers(); + for (auto &s : samplers) + { + if (compiler.has_decoration(s.image_id, spv::DecorationDescriptorSet)) + { + uint32_t set = compiler.get_decoration(s.image_id, spv::DecorationDescriptorSet); + compiler.set_decoration(s.combined_id, spv::DecorationDescriptorSet, set); + } + + if (compiler.has_decoration(s.image_id, spv::DecorationBinding)) + { + uint32_t binding = compiler.get_decoration(s.image_id, spv::DecorationBinding); + compiler.set_decoration(s.combined_id, spv::DecorationBinding, binding); + } + } +} +} // namespace spirv_cross_util diff --git a/thirdparty/spirv-cross/spirv_cross_util.hpp b/thirdparty/spirv-cross/spirv_cross_util.hpp new file mode 100644 index 00000000000..e6e3fcdb634 --- /dev/null +++ b/thirdparty/spirv-cross/spirv_cross_util.hpp @@ -0,0 +1,37 @@ +/* + * Copyright 2015-2021 Arm Limited + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . + */ + +#ifndef SPIRV_CROSS_UTIL_HPP +#define SPIRV_CROSS_UTIL_HPP + +#include "spirv_cross.hpp" + +namespace spirv_cross_util +{ +void rename_interface_variable(SPIRV_CROSS_NAMESPACE::Compiler &compiler, + const SPIRV_CROSS_NAMESPACE::SmallVector &resources, + uint32_t location, const std::string &name); +void inherit_combined_sampler_bindings(SPIRV_CROSS_NAMESPACE::Compiler &compiler); +} // namespace spirv_cross_util + +#endif diff --git a/thirdparty/spirv-cross/spirv_glsl.cpp b/thirdparty/spirv-cross/spirv_glsl.cpp new file mode 100644 index 00000000000..fad1132e82a --- /dev/null +++ b/thirdparty/spirv-cross/spirv_glsl.cpp @@ -0,0 +1,19109 @@ +/* + * Copyright 2015-2021 Arm Limited + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . 
+ */ + +#include "spirv_glsl.hpp" +#include "GLSL.std.450.h" +#include "spirv_common.hpp" +#include +#include +#include +#include +#include +#include +#include + +#ifndef _WIN32 +#include +#endif +#include + +using namespace spv; +using namespace SPIRV_CROSS_NAMESPACE; +using namespace std; + +enum ExtraSubExpressionType +{ + // Create masks above any legal ID range to allow multiple address spaces into the extra_sub_expressions map. + EXTRA_SUB_EXPRESSION_TYPE_STREAM_OFFSET = 0x10000000, + EXTRA_SUB_EXPRESSION_TYPE_AUX = 0x20000000 +}; + +static bool is_unsigned_opcode(Op op) +{ + // Don't have to be exhaustive, only relevant for legacy target checking ... + switch (op) + { + case OpShiftRightLogical: + case OpUGreaterThan: + case OpUGreaterThanEqual: + case OpULessThan: + case OpULessThanEqual: + case OpUConvert: + case OpUDiv: + case OpUMod: + case OpUMulExtended: + case OpConvertUToF: + case OpConvertFToU: + return true; + + default: + return false; + } +} + +static bool is_unsigned_glsl_opcode(GLSLstd450 op) +{ + // Don't have to be exhaustive, only relevant for legacy target checking ... + switch (op) + { + case GLSLstd450UClamp: + case GLSLstd450UMin: + case GLSLstd450UMax: + case GLSLstd450FindUMsb: + return true; + + default: + return false; + } +} + +static bool packing_is_vec4_padded(BufferPackingStandard packing) +{ + switch (packing) + { + case BufferPackingHLSLCbuffer: + case BufferPackingHLSLCbufferPackOffset: + case BufferPackingStd140: + case BufferPackingStd140EnhancedLayout: + return true; + + default: + return false; + } +} + +static bool packing_is_hlsl(BufferPackingStandard packing) +{ + switch (packing) + { + case BufferPackingHLSLCbuffer: + case BufferPackingHLSLCbufferPackOffset: + return true; + + default: + return false; + } +} + +static bool packing_has_flexible_offset(BufferPackingStandard packing) +{ + switch (packing) + { + case BufferPackingStd140: + case BufferPackingStd430: + case BufferPackingScalar: + case BufferPackingHLSLCbuffer: + return false; + + default: + return true; + } +} + +static bool packing_is_scalar(BufferPackingStandard packing) +{ + switch (packing) + { + case BufferPackingScalar: + case BufferPackingScalarEnhancedLayout: + return true; + + default: + return false; + } +} + +static BufferPackingStandard packing_to_substruct_packing(BufferPackingStandard packing) +{ + switch (packing) + { + case BufferPackingStd140EnhancedLayout: + return BufferPackingStd140; + case BufferPackingStd430EnhancedLayout: + return BufferPackingStd430; + case BufferPackingHLSLCbufferPackOffset: + return BufferPackingHLSLCbuffer; + case BufferPackingScalarEnhancedLayout: + return BufferPackingScalar; + default: + return packing; + } +} + +void CompilerGLSL::init() +{ + if (ir.source.known) + { + options.es = ir.source.es; + options.version = ir.source.version; + } + + // Query the locale to see what the decimal point is. + // We'll rely on fixing it up ourselves in the rare case we have a comma-as-decimal locale + // rather than setting locales ourselves. Settings locales in a safe and isolated way is rather + // tricky. +#ifdef _WIN32 + // On Windows, localeconv uses thread-local storage, so it should be fine. + const struct lconv *conv = localeconv(); + if (conv && conv->decimal_point) + current_locale_radix_character = *conv->decimal_point; +#elif defined(__ANDROID__) && __ANDROID_API__ < 26 + // nl_langinfo is not supported on this platform, fall back to the worse alternative. 
+ const struct lconv *conv = localeconv(); + if (conv && conv->decimal_point) + current_locale_radix_character = *conv->decimal_point; +#else + // localeconv, the portable function is not MT safe ... + const char *decimal_point = nl_langinfo(RADIXCHAR); + if (decimal_point && *decimal_point != '\0') + current_locale_radix_character = *decimal_point; +#endif +} + +static const char *to_pls_layout(PlsFormat format) +{ + switch (format) + { + case PlsR11FG11FB10F: + return "layout(r11f_g11f_b10f) "; + case PlsR32F: + return "layout(r32f) "; + case PlsRG16F: + return "layout(rg16f) "; + case PlsRGB10A2: + return "layout(rgb10_a2) "; + case PlsRGBA8: + return "layout(rgba8) "; + case PlsRG16: + return "layout(rg16) "; + case PlsRGBA8I: + return "layout(rgba8i)"; + case PlsRG16I: + return "layout(rg16i) "; + case PlsRGB10A2UI: + return "layout(rgb10_a2ui) "; + case PlsRGBA8UI: + return "layout(rgba8ui) "; + case PlsRG16UI: + return "layout(rg16ui) "; + case PlsR32UI: + return "layout(r32ui) "; + default: + return ""; + } +} + +static std::pair pls_format_to_basetype(PlsFormat format) +{ + switch (format) + { + default: + case PlsR11FG11FB10F: + case PlsR32F: + case PlsRG16F: + case PlsRGB10A2: + case PlsRGBA8: + case PlsRG16: + return std::make_pair(spv::OpTypeFloat, SPIRType::Float); + + case PlsRGBA8I: + case PlsRG16I: + return std::make_pair(spv::OpTypeInt, SPIRType::Int); + + case PlsRGB10A2UI: + case PlsRGBA8UI: + case PlsRG16UI: + case PlsR32UI: + return std::make_pair(spv::OpTypeInt, SPIRType::UInt); + } +} + +static uint32_t pls_format_to_components(PlsFormat format) +{ + switch (format) + { + default: + case PlsR32F: + case PlsR32UI: + return 1; + + case PlsRG16F: + case PlsRG16: + case PlsRG16UI: + case PlsRG16I: + return 2; + + case PlsR11FG11FB10F: + return 3; + + case PlsRGB10A2: + case PlsRGBA8: + case PlsRGBA8I: + case PlsRGB10A2UI: + case PlsRGBA8UI: + return 4; + } +} + +const char *CompilerGLSL::vector_swizzle(int vecsize, int index) +{ + static const char *const swizzle[4][4] = { + { ".x", ".y", ".z", ".w" }, + { ".xy", ".yz", ".zw", nullptr }, + { ".xyz", ".yzw", nullptr, nullptr }, +#if defined(__GNUC__) && (__GNUC__ == 9) + // This works around a GCC 9 bug, see details in https://gcc.gnu.org/bugzilla/show_bug.cgi?id=90947. + // This array ends up being compiled as all nullptrs, tripping the assertions below. + { "", nullptr, nullptr, "$" }, +#else + { "", nullptr, nullptr, nullptr }, +#endif + }; + + assert(vecsize >= 1 && vecsize <= 4); + assert(index >= 0 && index < 4); + assert(swizzle[vecsize - 1][index]); + + return swizzle[vecsize - 1][index]; +} + +void CompilerGLSL::reset(uint32_t iteration_count) +{ + // Sanity check the iteration count to be robust against a certain class of bugs where + // we keep forcing recompilations without making clear forward progress. + // In buggy situations we will loop forever, or loop for an unbounded number of iterations. + // Certain types of recompilations are considered to make forward progress, + // but in almost all situations, we'll never see more than 3 iterations. + // It is highly context-sensitive when we need to force recompilation, + // and it is not practical with the current architecture + // to resolve everything up front. + if (iteration_count >= options.force_recompile_max_debug_iterations && !is_force_recompile_forward_progress) + SPIRV_CROSS_THROW("Maximum compilation loops detected and no forward progress was made. 
Must be a SPIRV-Cross bug!"); + + // We do some speculative optimizations which should pretty much always work out, + // but just in case the SPIR-V is rather weird, recompile until it's happy. + // This typically only means one extra pass. + clear_force_recompile(); + + // Clear invalid expression tracking. + invalid_expressions.clear(); + composite_insert_overwritten.clear(); + current_function = nullptr; + + // Clear temporary usage tracking. + expression_usage_counts.clear(); + forwarded_temporaries.clear(); + suppressed_usage_tracking.clear(); + + // Ensure that we declare phi-variable copies even if the original declaration isn't deferred + flushed_phi_variables.clear(); + + current_emitting_switch_stack.clear(); + + reset_name_caches(); + + ir.for_each_typed_id([&](uint32_t, SPIRFunction &func) { + func.active = false; + func.flush_undeclared = true; + }); + + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { var.dependees.clear(); }); + + ir.reset_all_of_type(); + ir.reset_all_of_type(); + + statement_count = 0; + indent = 0; + current_loop_level = 0; +} + +void CompilerGLSL::remap_pls_variables() +{ + for (auto &input : pls_inputs) + { + auto &var = get(input.id); + + bool input_is_target = false; + if (var.storage == StorageClassUniformConstant) + { + auto &type = get(var.basetype); + input_is_target = type.image.dim == DimSubpassData; + } + + if (var.storage != StorageClassInput && !input_is_target) + SPIRV_CROSS_THROW("Can only use in and target variables for PLS inputs."); + var.remapped_variable = true; + } + + for (auto &output : pls_outputs) + { + auto &var = get(output.id); + if (var.storage != StorageClassOutput) + SPIRV_CROSS_THROW("Can only use out variables for PLS outputs."); + var.remapped_variable = true; + } +} + +void CompilerGLSL::remap_ext_framebuffer_fetch(uint32_t input_attachment_index, uint32_t color_location, bool coherent) +{ + subpass_to_framebuffer_fetch_attachment.push_back({ input_attachment_index, color_location }); + inout_color_attachments.push_back({ color_location, coherent }); +} + +bool CompilerGLSL::location_is_framebuffer_fetch(uint32_t location) const +{ + return std::find_if(begin(inout_color_attachments), end(inout_color_attachments), + [&](const std::pair &elem) { + return elem.first == location; + }) != end(inout_color_attachments); +} + +bool CompilerGLSL::location_is_non_coherent_framebuffer_fetch(uint32_t location) const +{ + return std::find_if(begin(inout_color_attachments), end(inout_color_attachments), + [&](const std::pair &elem) { + return elem.first == location && !elem.second; + }) != end(inout_color_attachments); +} + +void CompilerGLSL::find_static_extensions() +{ + ir.for_each_typed_id([&](uint32_t, const SPIRType &type) { + if (type.basetype == SPIRType::Double) + { + if (options.es) + SPIRV_CROSS_THROW("FP64 not supported in ES profile."); + if (!options.es && options.version < 400) + require_extension_internal("GL_ARB_gpu_shader_fp64"); + } + else if (type.basetype == SPIRType::Int64 || type.basetype == SPIRType::UInt64) + { + if (options.es && options.version < 310) // GL_NV_gpu_shader5 fallback requires 310. 
+ SPIRV_CROSS_THROW("64-bit integers not supported in ES profile before version 310."); + require_extension_internal("GL_ARB_gpu_shader_int64"); + } + else if (type.basetype == SPIRType::Half) + { + require_extension_internal("GL_EXT_shader_explicit_arithmetic_types_float16"); + if (options.vulkan_semantics) + require_extension_internal("GL_EXT_shader_16bit_storage"); + } + else if (type.basetype == SPIRType::SByte || type.basetype == SPIRType::UByte) + { + require_extension_internal("GL_EXT_shader_explicit_arithmetic_types_int8"); + if (options.vulkan_semantics) + require_extension_internal("GL_EXT_shader_8bit_storage"); + } + else if (type.basetype == SPIRType::Short || type.basetype == SPIRType::UShort) + { + require_extension_internal("GL_EXT_shader_explicit_arithmetic_types_int16"); + if (options.vulkan_semantics) + require_extension_internal("GL_EXT_shader_16bit_storage"); + } + }); + + auto &execution = get_entry_point(); + switch (execution.model) + { + case ExecutionModelGLCompute: + if (!options.es && options.version < 430) + require_extension_internal("GL_ARB_compute_shader"); + if (options.es && options.version < 310) + SPIRV_CROSS_THROW("At least ESSL 3.10 required for compute shaders."); + break; + + case ExecutionModelGeometry: + if (options.es && options.version < 320) + require_extension_internal("GL_EXT_geometry_shader"); + if (!options.es && options.version < 150) + require_extension_internal("GL_ARB_geometry_shader4"); + + if (execution.flags.get(ExecutionModeInvocations) && execution.invocations != 1) + { + // Instanced GS is part of 400 core or this extension. + if (!options.es && options.version < 400) + require_extension_internal("GL_ARB_gpu_shader5"); + } + break; + + case ExecutionModelTessellationEvaluation: + case ExecutionModelTessellationControl: + if (options.es && options.version < 320) + require_extension_internal("GL_EXT_tessellation_shader"); + if (!options.es && options.version < 400) + require_extension_internal("GL_ARB_tessellation_shader"); + break; + + case ExecutionModelRayGenerationKHR: + case ExecutionModelIntersectionKHR: + case ExecutionModelAnyHitKHR: + case ExecutionModelClosestHitKHR: + case ExecutionModelMissKHR: + case ExecutionModelCallableKHR: + // NV enums are aliases. + if (options.es || options.version < 460) + SPIRV_CROSS_THROW("Ray tracing shaders require non-es profile with version 460 or above."); + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Ray tracing requires Vulkan semantics."); + + // Need to figure out if we should target KHR or NV extension based on capabilities. + for (auto &cap : ir.declared_capabilities) + { + if (cap == CapabilityRayTracingKHR || cap == CapabilityRayQueryKHR || + cap == CapabilityRayTraversalPrimitiveCullingKHR) + { + ray_tracing_is_khr = true; + break; + } + } + + if (ray_tracing_is_khr) + { + // In KHR ray tracing we pass payloads by pointer instead of location, + // so make sure we assign locations properly. 
+ ray_tracing_khr_fixup_locations(); + require_extension_internal("GL_EXT_ray_tracing"); + } + else + require_extension_internal("GL_NV_ray_tracing"); + break; + + case ExecutionModelMeshEXT: + case ExecutionModelTaskEXT: + if (options.es || options.version < 450) + SPIRV_CROSS_THROW("Mesh shaders require GLSL 450 or above."); + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Mesh shaders require Vulkan semantics."); + require_extension_internal("GL_EXT_mesh_shader"); + break; + + default: + break; + } + + if (!pls_inputs.empty() || !pls_outputs.empty()) + { + if (execution.model != ExecutionModelFragment) + SPIRV_CROSS_THROW("Can only use GL_EXT_shader_pixel_local_storage in fragment shaders."); + require_extension_internal("GL_EXT_shader_pixel_local_storage"); + } + + if (!inout_color_attachments.empty()) + { + if (execution.model != ExecutionModelFragment) + SPIRV_CROSS_THROW("Can only use GL_EXT_shader_framebuffer_fetch in fragment shaders."); + if (options.vulkan_semantics) + SPIRV_CROSS_THROW("Cannot use EXT_shader_framebuffer_fetch in Vulkan GLSL."); + + bool has_coherent = false; + bool has_incoherent = false; + + for (auto &att : inout_color_attachments) + { + if (att.second) + has_coherent = true; + else + has_incoherent = true; + } + + if (has_coherent) + require_extension_internal("GL_EXT_shader_framebuffer_fetch"); + if (has_incoherent) + require_extension_internal("GL_EXT_shader_framebuffer_fetch_non_coherent"); + } + + if (options.separate_shader_objects && !options.es && options.version < 410) + require_extension_internal("GL_ARB_separate_shader_objects"); + + if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT) + { + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("GL_EXT_buffer_reference is only supported in Vulkan GLSL."); + if (options.es && options.version < 320) + SPIRV_CROSS_THROW("GL_EXT_buffer_reference requires ESSL 320."); + else if (!options.es && options.version < 450) + SPIRV_CROSS_THROW("GL_EXT_buffer_reference requires GLSL 450."); + require_extension_internal("GL_EXT_buffer_reference2"); + } + else if (ir.addressing_model != AddressingModelLogical) + { + SPIRV_CROSS_THROW("Only Logical and PhysicalStorageBuffer64EXT addressing models are supported."); + } + + // Check for nonuniform qualifier and passthrough. + // Instead of looping over all decorations to find this, just look at capabilities. 
+ for (auto &cap : ir.declared_capabilities) + { + switch (cap) + { + case CapabilityShaderNonUniformEXT: + if (!options.vulkan_semantics) + require_extension_internal("GL_NV_gpu_shader5"); + else + require_extension_internal("GL_EXT_nonuniform_qualifier"); + break; + case CapabilityRuntimeDescriptorArrayEXT: + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("GL_EXT_nonuniform_qualifier is only supported in Vulkan GLSL."); + require_extension_internal("GL_EXT_nonuniform_qualifier"); + break; + + case CapabilityGeometryShaderPassthroughNV: + if (execution.model == ExecutionModelGeometry) + { + require_extension_internal("GL_NV_geometry_shader_passthrough"); + execution.geometry_passthrough = true; + } + break; + + case CapabilityVariablePointers: + case CapabilityVariablePointersStorageBuffer: + SPIRV_CROSS_THROW("VariablePointers capability is not supported in GLSL."); + + case CapabilityMultiView: + if (options.vulkan_semantics) + require_extension_internal("GL_EXT_multiview"); + else + { + require_extension_internal("GL_OVR_multiview2"); + if (options.ovr_multiview_view_count == 0) + SPIRV_CROSS_THROW("ovr_multiview_view_count must be non-zero when using GL_OVR_multiview2."); + if (get_execution_model() != ExecutionModelVertex) + SPIRV_CROSS_THROW("OVR_multiview2 can only be used with Vertex shaders."); + } + break; + + case CapabilityRayQueryKHR: + if (options.es || options.version < 460 || !options.vulkan_semantics) + SPIRV_CROSS_THROW("RayQuery requires Vulkan GLSL 460."); + require_extension_internal("GL_EXT_ray_query"); + ray_tracing_is_khr = true; + break; + + case CapabilityRayTraversalPrimitiveCullingKHR: + if (options.es || options.version < 460 || !options.vulkan_semantics) + SPIRV_CROSS_THROW("RayQuery requires Vulkan GLSL 460."); + require_extension_internal("GL_EXT_ray_flags_primitive_culling"); + ray_tracing_is_khr = true; + break; + + default: + break; + } + } + + if (options.ovr_multiview_view_count) + { + if (options.vulkan_semantics) + SPIRV_CROSS_THROW("OVR_multiview2 cannot be used with Vulkan semantics."); + if (get_execution_model() != ExecutionModelVertex) + SPIRV_CROSS_THROW("OVR_multiview2 can only be used with Vertex shaders."); + require_extension_internal("GL_OVR_multiview2"); + } + + // KHR one is likely to get promoted at some point, so if we don't see an explicit SPIR-V extension, assume KHR. + for (auto &ext : ir.declared_extensions) + if (ext == "SPV_NV_fragment_shader_barycentric") + barycentric_is_nv = true; +} + +void CompilerGLSL::require_polyfill(Polyfill polyfill, bool relaxed) +{ + uint32_t &polyfills = (relaxed && (options.es || options.vulkan_semantics)) ? + required_polyfills_relaxed : required_polyfills; + + if ((polyfills & polyfill) == 0) + { + polyfills |= polyfill; + force_recompile(); + } +} + +void CompilerGLSL::ray_tracing_khr_fixup_locations() +{ + uint32_t location = 0; + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + // Incoming payload storage can also be used for tracing. 
+ if (var.storage != StorageClassRayPayloadKHR && var.storage != StorageClassCallableDataKHR && + var.storage != StorageClassIncomingRayPayloadKHR && var.storage != StorageClassIncomingCallableDataKHR) + return; + if (is_hidden_variable(var)) + return; + set_decoration(var.self, DecorationLocation, location++); + }); +} + +string CompilerGLSL::compile() +{ + ir.fixup_reserved_names(); + + if (!options.vulkan_semantics) + { + // only NV_gpu_shader5 supports divergent indexing on OpenGL, and it does so without extra qualifiers + backend.nonuniform_qualifier = ""; + backend.needs_row_major_load_workaround = options.enable_row_major_load_workaround; + } + backend.allow_precision_qualifiers = options.vulkan_semantics || options.es; + backend.force_gl_in_out_block = true; + backend.supports_extensions = true; + backend.use_array_constructor = true; + backend.workgroup_size_is_hidden = true; + backend.requires_relaxed_precision_analysis = options.es || options.vulkan_semantics; + backend.support_precise_qualifier = + (!options.es && options.version >= 400) || (options.es && options.version >= 320); + + if (is_legacy_es()) + backend.support_case_fallthrough = false; + + // Scan the SPIR-V to find trivial uses of extensions. + fixup_anonymous_struct_names(); + fixup_type_alias(); + reorder_type_alias(); + build_function_control_flow_graphs_and_analyze(); + find_static_extensions(); + fixup_image_load_store_access(); + update_active_builtins(); + analyze_image_and_sampler_usage(); + analyze_interlocked_resource_usage(); + if (!inout_color_attachments.empty()) + emit_inout_fragment_outputs_copy_to_subpass_inputs(); + + // Shaders might cast unrelated data to pointers of non-block types. + // Find all such instances and make sure we can cast the pointers to a synthesized block type. + if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT) + analyze_non_block_pointer_types(); + + uint32_t pass_count = 0; + do + { + reset(pass_count); + + buffer.reset(); + + emit_header(); + emit_resources(); + emit_extension_workarounds(get_execution_model()); + + if (required_polyfills != 0) + emit_polyfills(required_polyfills, false); + if ((options.es || options.vulkan_semantics) && required_polyfills_relaxed != 0) + emit_polyfills(required_polyfills_relaxed, true); + + emit_function(get(ir.default_entry_point), Bitset()); + + pass_count++; + } while (is_forcing_recompilation()); + + // Implement the interlocked wrapper function at the end. + // The body was implemented in lieu of main(). + if (interlocked_is_complex) + { + statement("void main()"); + begin_scope(); + statement("// Interlocks were used in a way not compatible with GLSL, this is very slow."); + statement("SPIRV_Cross_beginInvocationInterlock();"); + statement("spvMainInterlockedBody();"); + statement("SPIRV_Cross_endInvocationInterlock();"); + end_scope(); + } + + // Entry point in GLSL is always main(). 
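For reference (not part of the patch): `compile()` above is the main entry point consumers hit. A minimal sketch of end-to-end use follows, with the dialect settings (ESSL 3.10, no Vulkan semantics) chosen purely as an example.

```cpp
#include "spirv_glsl.hpp"

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Parse a SPIR-V binary, pick a GLSL dialect, and emit source text.
static std::string spirv_to_essl310(std::vector<uint32_t> spirv)
{
	spirv_cross::CompilerGLSL glsl(std::move(spirv));

	spirv_cross::CompilerGLSL::Options opts = glsl.get_common_options();
	opts.version = 310;
	opts.es = true;
	opts.vulkan_semantics = false;
	glsl.set_common_options(opts);

	// Unsupported input is reported through SPIRV_CROSS_THROW; with exceptions
	// enabled this surfaces as spirv_cross::CompilerError.
	return glsl.compile();
}
```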
+ get_entry_point().name = "main"; + + return buffer.str(); +} + +std::string CompilerGLSL::get_partial_source() +{ + return buffer.str(); +} + +void CompilerGLSL::build_workgroup_size(SmallVector &arguments, const SpecializationConstant &wg_x, + const SpecializationConstant &wg_y, const SpecializationConstant &wg_z) +{ + auto &execution = get_entry_point(); + bool builtin_workgroup = execution.workgroup_size.constant != 0; + bool use_local_size_id = !builtin_workgroup && execution.flags.get(ExecutionModeLocalSizeId); + + if (wg_x.id) + { + if (options.vulkan_semantics) + arguments.push_back(join("local_size_x_id = ", wg_x.constant_id)); + else + arguments.push_back(join("local_size_x = ", get(wg_x.id).specialization_constant_macro_name)); + } + else if (use_local_size_id && execution.workgroup_size.id_x) + arguments.push_back(join("local_size_x = ", get(execution.workgroup_size.id_x).scalar())); + else + arguments.push_back(join("local_size_x = ", execution.workgroup_size.x)); + + if (wg_y.id) + { + if (options.vulkan_semantics) + arguments.push_back(join("local_size_y_id = ", wg_y.constant_id)); + else + arguments.push_back(join("local_size_y = ", get(wg_y.id).specialization_constant_macro_name)); + } + else if (use_local_size_id && execution.workgroup_size.id_y) + arguments.push_back(join("local_size_y = ", get(execution.workgroup_size.id_y).scalar())); + else + arguments.push_back(join("local_size_y = ", execution.workgroup_size.y)); + + if (wg_z.id) + { + if (options.vulkan_semantics) + arguments.push_back(join("local_size_z_id = ", wg_z.constant_id)); + else + arguments.push_back(join("local_size_z = ", get(wg_z.id).specialization_constant_macro_name)); + } + else if (use_local_size_id && execution.workgroup_size.id_z) + arguments.push_back(join("local_size_z = ", get(execution.workgroup_size.id_z).scalar())); + else + arguments.push_back(join("local_size_z = ", execution.workgroup_size.z)); +} + +void CompilerGLSL::request_subgroup_feature(ShaderSubgroupSupportHelper::Feature feature) +{ + if (options.vulkan_semantics) + { + auto khr_extension = ShaderSubgroupSupportHelper::get_KHR_extension_for_feature(feature); + require_extension_internal(ShaderSubgroupSupportHelper::get_extension_name(khr_extension)); + } + else + { + if (!shader_subgroup_supporter.is_feature_requested(feature)) + force_recompile(); + shader_subgroup_supporter.request_feature(feature); + } +} + +void CompilerGLSL::emit_header() +{ + auto &execution = get_entry_point(); + statement("#version ", options.version, options.es && options.version > 100 ? " es" : ""); + + if (!options.es && options.version < 420) + { + // Needed for binding = # on UBOs, etc. 
+ if (options.enable_420pack_extension) + { + statement("#ifdef GL_ARB_shading_language_420pack"); + statement("#extension GL_ARB_shading_language_420pack : require"); + statement("#endif"); + } + // Needed for: layout(early_fragment_tests) in; + if (execution.flags.get(ExecutionModeEarlyFragmentTests)) + require_extension_internal("GL_ARB_shader_image_load_store"); + } + + // Needed for: layout(post_depth_coverage) in; + if (execution.flags.get(ExecutionModePostDepthCoverage)) + require_extension_internal("GL_ARB_post_depth_coverage"); + + // Needed for: layout({pixel,sample}_interlock_[un]ordered) in; + bool interlock_used = execution.flags.get(ExecutionModePixelInterlockOrderedEXT) || + execution.flags.get(ExecutionModePixelInterlockUnorderedEXT) || + execution.flags.get(ExecutionModeSampleInterlockOrderedEXT) || + execution.flags.get(ExecutionModeSampleInterlockUnorderedEXT); + + if (interlock_used) + { + if (options.es) + { + if (options.version < 310) + SPIRV_CROSS_THROW("At least ESSL 3.10 required for fragment shader interlock."); + require_extension_internal("GL_NV_fragment_shader_interlock"); + } + else + { + if (options.version < 420) + require_extension_internal("GL_ARB_shader_image_load_store"); + require_extension_internal("GL_ARB_fragment_shader_interlock"); + } + } + + for (auto &ext : forced_extensions) + { + if (ext == "GL_ARB_gpu_shader_int64") + { + statement("#if defined(GL_ARB_gpu_shader_int64)"); + statement("#extension GL_ARB_gpu_shader_int64 : require"); + if (!options.vulkan_semantics || options.es) + { + statement("#elif defined(GL_NV_gpu_shader5)"); + statement("#extension GL_NV_gpu_shader5 : require"); + } + statement("#else"); + statement("#error No extension available for 64-bit integers."); + statement("#endif"); + } + else if (ext == "GL_EXT_shader_explicit_arithmetic_types_float16") + { + // Special case, this extension has a potential fallback to another vendor extension in normal GLSL. + // GL_AMD_gpu_shader_half_float is a superset, so try that first. 
+ statement("#if defined(GL_AMD_gpu_shader_half_float)"); + statement("#extension GL_AMD_gpu_shader_half_float : require"); + if (!options.vulkan_semantics) + { + statement("#elif defined(GL_NV_gpu_shader5)"); + statement("#extension GL_NV_gpu_shader5 : require"); + } + else + { + statement("#elif defined(GL_EXT_shader_explicit_arithmetic_types_float16)"); + statement("#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require"); + } + statement("#else"); + statement("#error No extension available for FP16."); + statement("#endif"); + } + else if (ext == "GL_EXT_shader_explicit_arithmetic_types_int8") + { + if (options.vulkan_semantics) + statement("#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require"); + else + { + statement("#if defined(GL_EXT_shader_explicit_arithmetic_types_int8)"); + statement("#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require"); + statement("#elif defined(GL_NV_gpu_shader5)"); + statement("#extension GL_NV_gpu_shader5 : require"); + statement("#else"); + statement("#error No extension available for Int8."); + statement("#endif"); + } + } + else if (ext == "GL_EXT_shader_explicit_arithmetic_types_int16") + { + if (options.vulkan_semantics) + statement("#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require"); + else + { + statement("#if defined(GL_EXT_shader_explicit_arithmetic_types_int16)"); + statement("#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require"); + statement("#elif defined(GL_AMD_gpu_shader_int16)"); + statement("#extension GL_AMD_gpu_shader_int16 : require"); + statement("#elif defined(GL_NV_gpu_shader5)"); + statement("#extension GL_NV_gpu_shader5 : require"); + statement("#else"); + statement("#error No extension available for Int16."); + statement("#endif"); + } + } + else if (ext == "GL_ARB_post_depth_coverage") + { + if (options.es) + statement("#extension GL_EXT_post_depth_coverage : require"); + else + { + statement("#if defined(GL_ARB_post_depth_coverge)"); + statement("#extension GL_ARB_post_depth_coverage : require"); + statement("#else"); + statement("#extension GL_EXT_post_depth_coverage : require"); + statement("#endif"); + } + } + else if (!options.vulkan_semantics && ext == "GL_ARB_shader_draw_parameters") + { + // Soft-enable this extension on plain GLSL. + statement("#ifdef ", ext); + statement("#extension ", ext, " : enable"); + statement("#endif"); + } + else if (ext == "GL_EXT_control_flow_attributes") + { + // These are just hints so we can conditionally enable and fallback in the shader. 
+ statement("#if defined(GL_EXT_control_flow_attributes)"); + statement("#extension GL_EXT_control_flow_attributes : require"); + statement("#define SPIRV_CROSS_FLATTEN [[flatten]]"); + statement("#define SPIRV_CROSS_BRANCH [[dont_flatten]]"); + statement("#define SPIRV_CROSS_UNROLL [[unroll]]"); + statement("#define SPIRV_CROSS_LOOP [[dont_unroll]]"); + statement("#else"); + statement("#define SPIRV_CROSS_FLATTEN"); + statement("#define SPIRV_CROSS_BRANCH"); + statement("#define SPIRV_CROSS_UNROLL"); + statement("#define SPIRV_CROSS_LOOP"); + statement("#endif"); + } + else if (ext == "GL_NV_fragment_shader_interlock") + { + statement("#extension GL_NV_fragment_shader_interlock : require"); + statement("#define SPIRV_Cross_beginInvocationInterlock() beginInvocationInterlockNV()"); + statement("#define SPIRV_Cross_endInvocationInterlock() endInvocationInterlockNV()"); + } + else if (ext == "GL_ARB_fragment_shader_interlock") + { + statement("#ifdef GL_ARB_fragment_shader_interlock"); + statement("#extension GL_ARB_fragment_shader_interlock : enable"); + statement("#define SPIRV_Cross_beginInvocationInterlock() beginInvocationInterlockARB()"); + statement("#define SPIRV_Cross_endInvocationInterlock() endInvocationInterlockARB()"); + statement("#elif defined(GL_INTEL_fragment_shader_ordering)"); + statement("#extension GL_INTEL_fragment_shader_ordering : enable"); + statement("#define SPIRV_Cross_beginInvocationInterlock() beginFragmentShaderOrderingINTEL()"); + statement("#define SPIRV_Cross_endInvocationInterlock()"); + statement("#endif"); + } + else + statement("#extension ", ext, " : require"); + } + + if (!options.vulkan_semantics) + { + using Supp = ShaderSubgroupSupportHelper; + auto result = shader_subgroup_supporter.resolve(); + + for (uint32_t feature_index = 0; feature_index < Supp::FeatureCount; feature_index++) + { + auto feature = static_cast(feature_index); + if (!shader_subgroup_supporter.is_feature_requested(feature)) + continue; + + auto exts = Supp::get_candidates_for_feature(feature, result); + if (exts.empty()) + continue; + + statement(""); + + for (auto &ext : exts) + { + const char *name = Supp::get_extension_name(ext); + const char *extra_predicate = Supp::get_extra_required_extension_predicate(ext); + auto extra_names = Supp::get_extra_required_extension_names(ext); + statement(&ext != &exts.front() ? "#elif" : "#if", " defined(", name, ")", + (*extra_predicate != '\0' ? 
" && " : ""), extra_predicate); + for (const auto &e : extra_names) + statement("#extension ", e, " : enable"); + statement("#extension ", name, " : require"); + } + + if (!Supp::can_feature_be_implemented_without_extensions(feature)) + { + statement("#else"); + statement("#error No extensions available to emulate requested subgroup feature."); + } + + statement("#endif"); + } + } + + for (auto &header : header_lines) + statement(header); + + SmallVector inputs; + SmallVector outputs; + + switch (execution.model) + { + case ExecutionModelVertex: + if (options.ovr_multiview_view_count) + inputs.push_back(join("num_views = ", options.ovr_multiview_view_count)); + break; + case ExecutionModelGeometry: + if ((execution.flags.get(ExecutionModeInvocations)) && execution.invocations != 1) + inputs.push_back(join("invocations = ", execution.invocations)); + if (execution.flags.get(ExecutionModeInputPoints)) + inputs.push_back("points"); + if (execution.flags.get(ExecutionModeInputLines)) + inputs.push_back("lines"); + if (execution.flags.get(ExecutionModeInputLinesAdjacency)) + inputs.push_back("lines_adjacency"); + if (execution.flags.get(ExecutionModeTriangles)) + inputs.push_back("triangles"); + if (execution.flags.get(ExecutionModeInputTrianglesAdjacency)) + inputs.push_back("triangles_adjacency"); + + if (!execution.geometry_passthrough) + { + // For passthrough, these are implies and cannot be declared in shader. + outputs.push_back(join("max_vertices = ", execution.output_vertices)); + if (execution.flags.get(ExecutionModeOutputTriangleStrip)) + outputs.push_back("triangle_strip"); + if (execution.flags.get(ExecutionModeOutputPoints)) + outputs.push_back("points"); + if (execution.flags.get(ExecutionModeOutputLineStrip)) + outputs.push_back("line_strip"); + } + break; + + case ExecutionModelTessellationControl: + if (execution.flags.get(ExecutionModeOutputVertices)) + outputs.push_back(join("vertices = ", execution.output_vertices)); + break; + + case ExecutionModelTessellationEvaluation: + if (execution.flags.get(ExecutionModeQuads)) + inputs.push_back("quads"); + if (execution.flags.get(ExecutionModeTriangles)) + inputs.push_back("triangles"); + if (execution.flags.get(ExecutionModeIsolines)) + inputs.push_back("isolines"); + if (execution.flags.get(ExecutionModePointMode)) + inputs.push_back("point_mode"); + + if (!execution.flags.get(ExecutionModeIsolines)) + { + if (execution.flags.get(ExecutionModeVertexOrderCw)) + inputs.push_back("cw"); + if (execution.flags.get(ExecutionModeVertexOrderCcw)) + inputs.push_back("ccw"); + } + + if (execution.flags.get(ExecutionModeSpacingFractionalEven)) + inputs.push_back("fractional_even_spacing"); + if (execution.flags.get(ExecutionModeSpacingFractionalOdd)) + inputs.push_back("fractional_odd_spacing"); + if (execution.flags.get(ExecutionModeSpacingEqual)) + inputs.push_back("equal_spacing"); + break; + + case ExecutionModelGLCompute: + case ExecutionModelTaskEXT: + case ExecutionModelMeshEXT: + { + if (execution.workgroup_size.constant != 0 || execution.flags.get(ExecutionModeLocalSizeId)) + { + SpecializationConstant wg_x, wg_y, wg_z; + get_work_group_size_specialization_constants(wg_x, wg_y, wg_z); + + // If there are any spec constants on legacy GLSL, defer declaration, we need to set up macro + // declarations before we can emit the work group size. 
+ if (options.vulkan_semantics || + ((wg_x.id == ConstantID(0)) && (wg_y.id == ConstantID(0)) && (wg_z.id == ConstantID(0)))) + build_workgroup_size(inputs, wg_x, wg_y, wg_z); + } + else + { + inputs.push_back(join("local_size_x = ", execution.workgroup_size.x)); + inputs.push_back(join("local_size_y = ", execution.workgroup_size.y)); + inputs.push_back(join("local_size_z = ", execution.workgroup_size.z)); + } + + if (execution.model == ExecutionModelMeshEXT) + { + outputs.push_back(join("max_vertices = ", execution.output_vertices)); + outputs.push_back(join("max_primitives = ", execution.output_primitives)); + if (execution.flags.get(ExecutionModeOutputTrianglesEXT)) + outputs.push_back("triangles"); + else if (execution.flags.get(ExecutionModeOutputLinesEXT)) + outputs.push_back("lines"); + else if (execution.flags.get(ExecutionModeOutputPoints)) + outputs.push_back("points"); + } + break; + } + + case ExecutionModelFragment: + if (options.es) + { + switch (options.fragment.default_float_precision) + { + case Options::Lowp: + statement("precision lowp float;"); + break; + + case Options::Mediump: + statement("precision mediump float;"); + break; + + case Options::Highp: + statement("precision highp float;"); + break; + + default: + break; + } + + switch (options.fragment.default_int_precision) + { + case Options::Lowp: + statement("precision lowp int;"); + break; + + case Options::Mediump: + statement("precision mediump int;"); + break; + + case Options::Highp: + statement("precision highp int;"); + break; + + default: + break; + } + } + + if (execution.flags.get(ExecutionModeEarlyFragmentTests)) + inputs.push_back("early_fragment_tests"); + if (execution.flags.get(ExecutionModePostDepthCoverage)) + inputs.push_back("post_depth_coverage"); + + if (interlock_used) + statement("#if defined(GL_ARB_fragment_shader_interlock)"); + + if (execution.flags.get(ExecutionModePixelInterlockOrderedEXT)) + statement("layout(pixel_interlock_ordered) in;"); + else if (execution.flags.get(ExecutionModePixelInterlockUnorderedEXT)) + statement("layout(pixel_interlock_unordered) in;"); + else if (execution.flags.get(ExecutionModeSampleInterlockOrderedEXT)) + statement("layout(sample_interlock_ordered) in;"); + else if (execution.flags.get(ExecutionModeSampleInterlockUnorderedEXT)) + statement("layout(sample_interlock_unordered) in;"); + + if (interlock_used) + { + statement("#elif !defined(GL_INTEL_fragment_shader_ordering)"); + statement("#error Fragment Shader Interlock/Ordering extension missing!"); + statement("#endif"); + } + + if (!options.es && execution.flags.get(ExecutionModeDepthGreater)) + statement("layout(depth_greater) out float gl_FragDepth;"); + else if (!options.es && execution.flags.get(ExecutionModeDepthLess)) + statement("layout(depth_less) out float gl_FragDepth;"); + + break; + + default: + break; + } + + for (auto &cap : ir.declared_capabilities) + if (cap == CapabilityRayTraversalPrimitiveCullingKHR) + statement("layout(primitive_culling);"); + + if (!inputs.empty()) + statement("layout(", merge(inputs), ") in;"); + if (!outputs.empty()) + statement("layout(", merge(outputs), ") out;"); + + statement(""); +} + +bool CompilerGLSL::type_is_empty(const SPIRType &type) +{ + return type.basetype == SPIRType::Struct && type.member_types.empty(); +} + +void CompilerGLSL::emit_struct(SPIRType &type) +{ + // Struct types can be stamped out multiple times + // with just different offsets, matrix layouts, etc ... 
+ // Type-punning with these types is legal, which complicates things + // when we are storing struct and array types in an SSBO for example. + // If the type master is packed however, we can no longer assume that the struct declaration will be redundant. + if (type.type_alias != TypeID(0) && + !has_extended_decoration(type.type_alias, SPIRVCrossDecorationBufferBlockRepacked)) + return; + + add_resource_name(type.self); + auto name = type_to_glsl(type); + + statement(!backend.explicit_struct_type ? "struct " : "", name); + begin_scope(); + + type.member_name_cache.clear(); + + uint32_t i = 0; + bool emitted = false; + for (auto &member : type.member_types) + { + add_member_name(type, i); + emit_struct_member(type, member, i); + i++; + emitted = true; + } + + // Don't declare empty structs in GLSL, this is not allowed. + if (type_is_empty(type) && !backend.supports_empty_struct) + { + statement("int empty_struct_member;"); + emitted = true; + } + + if (has_extended_decoration(type.self, SPIRVCrossDecorationPaddingTarget)) + emit_struct_padding_target(type); + + end_scope_decl(); + + if (emitted) + statement(""); +} + +string CompilerGLSL::to_interpolation_qualifiers(const Bitset &flags) +{ + string res; + //if (flags & (1ull << DecorationSmooth)) + // res += "smooth "; + if (flags.get(DecorationFlat)) + res += "flat "; + if (flags.get(DecorationNoPerspective)) + { + if (options.es) + { + if (options.version < 300) + SPIRV_CROSS_THROW("noperspective requires ESSL 300."); + require_extension_internal("GL_NV_shader_noperspective_interpolation"); + } + else if (is_legacy_desktop()) + require_extension_internal("GL_EXT_gpu_shader4"); + res += "noperspective "; + } + if (flags.get(DecorationCentroid)) + res += "centroid "; + if (flags.get(DecorationPatch)) + res += "patch "; + if (flags.get(DecorationSample)) + { + if (options.es) + { + if (options.version < 300) + SPIRV_CROSS_THROW("sample requires ESSL 300."); + else if (options.version < 320) + require_extension_internal("GL_OES_shader_multisample_interpolation"); + } + res += "sample "; + } + if (flags.get(DecorationInvariant) && (options.es || options.version >= 120)) + res += "invariant "; + if (flags.get(DecorationPerPrimitiveEXT)) + { + res += "perprimitiveEXT "; + require_extension_internal("GL_EXT_mesh_shader"); + } + + if (flags.get(DecorationExplicitInterpAMD)) + { + require_extension_internal("GL_AMD_shader_explicit_vertex_parameter"); + res += "__explicitInterpAMD "; + } + + if (flags.get(DecorationPerVertexKHR)) + { + if (options.es && options.version < 320) + SPIRV_CROSS_THROW("pervertexEXT requires ESSL 320."); + else if (!options.es && options.version < 450) + SPIRV_CROSS_THROW("pervertexEXT requires GLSL 450."); + + if (barycentric_is_nv) + { + require_extension_internal("GL_NV_fragment_shader_barycentric"); + res += "pervertexNV "; + } + else + { + require_extension_internal("GL_EXT_fragment_shader_barycentric"); + res += "pervertexEXT "; + } + } + + return res; +} + +string CompilerGLSL::layout_for_member(const SPIRType &type, uint32_t index) +{ + if (is_legacy()) + return ""; + + bool is_block = has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock); + if (!is_block) + return ""; + + auto &memb = ir.meta[type.self].members; + if (index >= memb.size()) + return ""; + auto &dec = memb[index]; + + SmallVector attr; + + if (has_member_decoration(type.self, index, DecorationPassthroughNV)) + attr.push_back("passthrough"); + + // We can only apply layouts on members in block interfaces. 
+ // This is a bit problematic because in SPIR-V decorations are applied on the struct types directly. + // This is not supported on GLSL, so we have to make the assumption that if a struct within our buffer block struct + // has a decoration, it was originally caused by a top-level layout() qualifier in GLSL. + // + // We would like to go from (SPIR-V style): + // + // struct Foo { layout(row_major) mat4 matrix; }; + // buffer UBO { Foo foo; }; + // + // to + // + // struct Foo { mat4 matrix; }; // GLSL doesn't support any layout shenanigans in raw struct declarations. + // buffer UBO { layout(row_major) Foo foo; }; // Apply the layout on top-level. + auto flags = combined_decoration_for_member(type, index); + + if (flags.get(DecorationRowMajor)) + attr.push_back("row_major"); + // We don't emit any global layouts, so column_major is default. + //if (flags & (1ull << DecorationColMajor)) + // attr.push_back("column_major"); + + if (dec.decoration_flags.get(DecorationLocation) && can_use_io_location(type.storage, true)) + attr.push_back(join("location = ", dec.location)); + + // Can only declare component if we can declare location. + if (dec.decoration_flags.get(DecorationComponent) && can_use_io_location(type.storage, true)) + { + if (!options.es) + { + if (options.version < 440 && options.version >= 140) + require_extension_internal("GL_ARB_enhanced_layouts"); + else if (options.version < 140) + SPIRV_CROSS_THROW("Component decoration is not supported in targets below GLSL 1.40."); + attr.push_back(join("component = ", dec.component)); + } + else + SPIRV_CROSS_THROW("Component decoration is not supported in ES targets."); + } + + // SPIRVCrossDecorationPacked is set by layout_for_variable earlier to mark that we need to emit offset qualifiers. + // This is only done selectively in GLSL as needed. 
+ if (has_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset) && + dec.decoration_flags.get(DecorationOffset)) + attr.push_back(join("offset = ", dec.offset)); + else if (type.storage == StorageClassOutput && dec.decoration_flags.get(DecorationOffset)) + attr.push_back(join("xfb_offset = ", dec.offset)); + + if (attr.empty()) + return ""; + + string res = "layout("; + res += merge(attr); + res += ") "; + return res; +} + +const char *CompilerGLSL::format_to_glsl(spv::ImageFormat format) +{ + if (options.es && is_desktop_only_format(format)) + SPIRV_CROSS_THROW("Attempting to use image format not supported in ES profile."); + + switch (format) + { + case ImageFormatRgba32f: + return "rgba32f"; + case ImageFormatRgba16f: + return "rgba16f"; + case ImageFormatR32f: + return "r32f"; + case ImageFormatRgba8: + return "rgba8"; + case ImageFormatRgba8Snorm: + return "rgba8_snorm"; + case ImageFormatRg32f: + return "rg32f"; + case ImageFormatRg16f: + return "rg16f"; + case ImageFormatRgba32i: + return "rgba32i"; + case ImageFormatRgba16i: + return "rgba16i"; + case ImageFormatR32i: + return "r32i"; + case ImageFormatRgba8i: + return "rgba8i"; + case ImageFormatRg32i: + return "rg32i"; + case ImageFormatRg16i: + return "rg16i"; + case ImageFormatRgba32ui: + return "rgba32ui"; + case ImageFormatRgba16ui: + return "rgba16ui"; + case ImageFormatR32ui: + return "r32ui"; + case ImageFormatRgba8ui: + return "rgba8ui"; + case ImageFormatRg32ui: + return "rg32ui"; + case ImageFormatRg16ui: + return "rg16ui"; + case ImageFormatR11fG11fB10f: + return "r11f_g11f_b10f"; + case ImageFormatR16f: + return "r16f"; + case ImageFormatRgb10A2: + return "rgb10_a2"; + case ImageFormatR8: + return "r8"; + case ImageFormatRg8: + return "rg8"; + case ImageFormatR16: + return "r16"; + case ImageFormatRg16: + return "rg16"; + case ImageFormatRgba16: + return "rgba16"; + case ImageFormatR16Snorm: + return "r16_snorm"; + case ImageFormatRg16Snorm: + return "rg16_snorm"; + case ImageFormatRgba16Snorm: + return "rgba16_snorm"; + case ImageFormatR8Snorm: + return "r8_snorm"; + case ImageFormatRg8Snorm: + return "rg8_snorm"; + case ImageFormatR8ui: + return "r8ui"; + case ImageFormatRg8ui: + return "rg8ui"; + case ImageFormatR16ui: + return "r16ui"; + case ImageFormatRgb10a2ui: + return "rgb10_a2ui"; + case ImageFormatR8i: + return "r8i"; + case ImageFormatRg8i: + return "rg8i"; + case ImageFormatR16i: + return "r16i"; + case ImageFormatR64i: + return "r64i"; + case ImageFormatR64ui: + return "r64ui"; + default: + case ImageFormatUnknown: + return nullptr; + } +} + +uint32_t CompilerGLSL::type_to_packed_base_size(const SPIRType &type, BufferPackingStandard) +{ + switch (type.basetype) + { + case SPIRType::Double: + case SPIRType::Int64: + case SPIRType::UInt64: + return 8; + case SPIRType::Float: + case SPIRType::Int: + case SPIRType::UInt: + return 4; + case SPIRType::Half: + case SPIRType::Short: + case SPIRType::UShort: + return 2; + case SPIRType::SByte: + case SPIRType::UByte: + return 1; + + default: + SPIRV_CROSS_THROW("Unrecognized type in type_to_packed_base_size."); + } +} + +uint32_t CompilerGLSL::type_to_packed_alignment(const SPIRType &type, const Bitset &flags, + BufferPackingStandard packing) +{ + // If using PhysicalStorageBufferEXT storage class, this is a pointer, + // and is 64-bit. 
+ if (is_physical_pointer(type)) + { + if (!type.pointer) + SPIRV_CROSS_THROW("Types in PhysicalStorageBufferEXT must be pointers."); + + if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT) + { + if (packing_is_vec4_padded(packing) && type_is_array_of_pointers(type)) + return 16; + else + return 8; + } + else + SPIRV_CROSS_THROW("AddressingModelPhysicalStorageBuffer64EXT must be used for PhysicalStorageBufferEXT."); + } + else if (is_array(type)) + { + uint32_t minimum_alignment = 1; + if (packing_is_vec4_padded(packing)) + minimum_alignment = 16; + + auto *tmp = &get(type.parent_type); + while (!tmp->array.empty()) + tmp = &get(tmp->parent_type); + + // Get the alignment of the base type, then maybe round up. + return max(minimum_alignment, type_to_packed_alignment(*tmp, flags, packing)); + } + + if (type.basetype == SPIRType::Struct) + { + // Rule 9. Structs alignments are maximum alignment of its members. + uint32_t alignment = 1; + for (uint32_t i = 0; i < type.member_types.size(); i++) + { + auto member_flags = ir.meta[type.self].members[i].decoration_flags; + alignment = + max(alignment, type_to_packed_alignment(get(type.member_types[i]), member_flags, packing)); + } + + // In std140, struct alignment is rounded up to 16. + if (packing_is_vec4_padded(packing)) + alignment = max(alignment, 16u); + + return alignment; + } + else + { + const uint32_t base_alignment = type_to_packed_base_size(type, packing); + + // Alignment requirement for scalar block layout is always the alignment for the most basic component. + if (packing_is_scalar(packing)) + return base_alignment; + + // Vectors are *not* aligned in HLSL, but there's an extra rule where vectors cannot straddle + // a vec4, this is handled outside since that part knows our current offset. + if (type.columns == 1 && packing_is_hlsl(packing)) + return base_alignment; + + // From 7.6.2.2 in GL 4.5 core spec. + // Rule 1 + if (type.vecsize == 1 && type.columns == 1) + return base_alignment; + + // Rule 2 + if ((type.vecsize == 2 || type.vecsize == 4) && type.columns == 1) + return type.vecsize * base_alignment; + + // Rule 3 + if (type.vecsize == 3 && type.columns == 1) + return 4 * base_alignment; + + // Rule 4 implied. Alignment does not change in std430. + + // Rule 5. Column-major matrices are stored as arrays of + // vectors. + if (flags.get(DecorationColMajor) && type.columns > 1) + { + if (packing_is_vec4_padded(packing)) + return 4 * base_alignment; + else if (type.vecsize == 3) + return 4 * base_alignment; + else + return type.vecsize * base_alignment; + } + + // Rule 6 implied. + + // Rule 7. + if (flags.get(DecorationRowMajor) && type.vecsize > 1) + { + if (packing_is_vec4_padded(packing)) + return 4 * base_alignment; + else if (type.columns == 3) + return 4 * base_alignment; + else + return type.columns * base_alignment; + } + + // Rule 8 implied. + } + + SPIRV_CROSS_THROW("Did not find suitable rule for type. Bogus decorations?"); +} + +uint32_t CompilerGLSL::type_to_packed_array_stride(const SPIRType &type, const Bitset &flags, + BufferPackingStandard packing) +{ + // Array stride is equal to aligned size of the underlying type. 
+ uint32_t parent = type.parent_type; + assert(parent); + + auto &tmp = get(parent); + + uint32_t size = type_to_packed_size(tmp, flags, packing); + uint32_t alignment = type_to_packed_alignment(type, flags, packing); + return (size + alignment - 1) & ~(alignment - 1); +} + +uint32_t CompilerGLSL::type_to_packed_size(const SPIRType &type, const Bitset &flags, BufferPackingStandard packing) +{ + // If using PhysicalStorageBufferEXT storage class, this is a pointer, + // and is 64-bit. + if (is_physical_pointer(type)) + { + if (!type.pointer) + SPIRV_CROSS_THROW("Types in PhysicalStorageBufferEXT must be pointers."); + + if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT) + return 8; + else + SPIRV_CROSS_THROW("AddressingModelPhysicalStorageBuffer64EXT must be used for PhysicalStorageBufferEXT."); + } + else if (is_array(type)) + { + uint32_t packed_size = to_array_size_literal(type) * type_to_packed_array_stride(type, flags, packing); + + // For arrays of vectors and matrices in HLSL, the last element has a size which depends on its vector size, + // so that it is possible to pack other vectors into the last element. + if (packing_is_hlsl(packing) && type.basetype != SPIRType::Struct) + packed_size -= (4 - type.vecsize) * (type.width / 8); + + return packed_size; + } + + uint32_t size = 0; + + if (type.basetype == SPIRType::Struct) + { + uint32_t pad_alignment = 1; + + for (uint32_t i = 0; i < type.member_types.size(); i++) + { + auto member_flags = ir.meta[type.self].members[i].decoration_flags; + auto &member_type = get(type.member_types[i]); + + uint32_t packed_alignment = type_to_packed_alignment(member_type, member_flags, packing); + uint32_t alignment = max(packed_alignment, pad_alignment); + + // The next member following a struct member is aligned to the base alignment of the struct that came before. + // GL 4.5 spec, 7.6.2.2. + if (member_type.basetype == SPIRType::Struct) + pad_alignment = packed_alignment; + else + pad_alignment = 1; + + size = (size + alignment - 1) & ~(alignment - 1); + size += type_to_packed_size(member_type, member_flags, packing); + } + } + else + { + const uint32_t base_alignment = type_to_packed_base_size(type, packing); + + if (packing_is_scalar(packing)) + { + size = type.vecsize * type.columns * base_alignment; + } + else + { + if (type.columns == 1) + size = type.vecsize * base_alignment; + + if (flags.get(DecorationColMajor) && type.columns > 1) + { + if (packing_is_vec4_padded(packing)) + size = type.columns * 4 * base_alignment; + else if (type.vecsize == 3) + size = type.columns * 4 * base_alignment; + else + size = type.columns * type.vecsize * base_alignment; + } + + if (flags.get(DecorationRowMajor) && type.vecsize > 1) + { + if (packing_is_vec4_padded(packing)) + size = type.vecsize * 4 * base_alignment; + else if (type.columns == 3) + size = type.vecsize * 4 * base_alignment; + else + size = type.vecsize * type.columns * base_alignment; + } + + // For matrices in HLSL, the last element has a size which depends on its vector size, + // so that it is possible to pack other vectors into the last element. + if (packing_is_hlsl(packing) && type.columns > 1) + size -= (4 - type.vecsize) * (type.width / 8); + } + } + + return size; +} + +bool CompilerGLSL::buffer_is_packing_standard(const SPIRType &type, BufferPackingStandard packing, + uint32_t *failed_validation_index, uint32_t start_offset, + uint32_t end_offset) +{ + // This is very tricky and error prone, but try to be exhaustive and correct here. 
+ // SPIR-V doesn't directly say if we're using std430 or std140. + // SPIR-V communicates this using Offset and ArrayStride decorations (which is what really matters), + // so we have to try to infer whether or not the original GLSL source was std140 or std430 based on this information. + // We do not have to consider shared or packed since these layouts are not allowed in Vulkan SPIR-V (they are useless anyways, and custom offsets would do the same thing). + // + // It is almost certain that we're using std430, but it gets tricky with arrays in particular. + // We will assume std430, but infer std140 if we can prove the struct is not compliant with std430. + // + // The only two differences between std140 and std430 are related to padding alignment/array stride + // in arrays and structs. In std140 they take minimum vec4 alignment. + // std430 only removes the vec4 requirement. + + uint32_t offset = 0; + uint32_t pad_alignment = 1; + + bool is_top_level_block = + has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock); + + for (uint32_t i = 0; i < type.member_types.size(); i++) + { + auto &memb_type = get(type.member_types[i]); + + auto *type_meta = ir.find_meta(type.self); + auto member_flags = type_meta ? type_meta->members[i].decoration_flags : Bitset{}; + + // Verify alignment rules. + uint32_t packed_alignment = type_to_packed_alignment(memb_type, member_flags, packing); + + // This is a rather dirty workaround to deal with some cases of OpSpecConstantOp used as array size, e.g: + // layout(constant_id = 0) const int s = 10; + // const int S = s + 5; // SpecConstantOp + // buffer Foo { int data[S]; }; // <-- Very hard for us to deduce a fixed value here, + // we would need full implementation of compile-time constant folding. :( + // If we are the last member of a struct, there might be cases where the actual size of that member is irrelevant + // for our analysis (e.g. unsized arrays). + // This lets us simply ignore that there are spec constant op sized arrays in our buffers. + // Querying size of this member will fail, so just don't call it unless we have to. + // + // This is likely "best effort" we can support without going into unacceptably complicated workarounds. + bool member_can_be_unsized = + is_top_level_block && size_t(i + 1) == type.member_types.size() && !memb_type.array.empty(); + + uint32_t packed_size = 0; + if (!member_can_be_unsized || packing_is_hlsl(packing)) + packed_size = type_to_packed_size(memb_type, member_flags, packing); + + // We only need to care about this if we have non-array types which can straddle the vec4 boundary. + uint32_t actual_offset = type_struct_member_offset(type, i); + + if (packing_is_hlsl(packing)) + { + // If a member straddles across a vec4 boundary, alignment is actually vec4. + uint32_t target_offset; + + // If we intend to use explicit packing, we must check for improper straddle with that offset. + // In implicit packing, we must check with implicit offset, since the explicit offset + // might have already accounted for the straddle, and we'd miss the alignment promotion to vec4. + // This is important when packing sub-structs that don't support packoffset(). 
+ if (packing_has_flexible_offset(packing)) + target_offset = actual_offset; + else + target_offset = offset; + + uint32_t begin_word = target_offset / 16; + uint32_t end_word = (target_offset + packed_size - 1) / 16; + + if (begin_word != end_word) + packed_alignment = max(packed_alignment, 16u); + } + + // Field is not in the specified range anymore and we can ignore any further fields. + if (actual_offset >= end_offset) + break; + + uint32_t alignment = max(packed_alignment, pad_alignment); + offset = (offset + alignment - 1) & ~(alignment - 1); + + // The next member following a struct member is aligned to the base alignment of the struct that came before. + // GL 4.5 spec, 7.6.2.2. + if (memb_type.basetype == SPIRType::Struct && !memb_type.pointer) + pad_alignment = packed_alignment; + else + pad_alignment = 1; + + // Only care about packing if we are in the given range + if (actual_offset >= start_offset) + { + // We only care about offsets in std140, std430, etc ... + // For EnhancedLayout variants, we have the flexibility to choose our own offsets. + if (!packing_has_flexible_offset(packing)) + { + if (actual_offset != offset) // This cannot be the packing we're looking for. + { + if (failed_validation_index) + *failed_validation_index = i; + return false; + } + } + else if ((actual_offset & (alignment - 1)) != 0) + { + // We still need to verify that alignment rules are observed, even if we have explicit offset. + if (failed_validation_index) + *failed_validation_index = i; + return false; + } + + // Verify array stride rules. + if (is_array(memb_type) && + type_to_packed_array_stride(memb_type, member_flags, packing) != + type_struct_member_array_stride(type, i)) + { + if (failed_validation_index) + *failed_validation_index = i; + return false; + } + + // Verify that sub-structs also follow packing rules. + // We cannot use enhanced layouts on substructs, so they better be up to spec. + auto substruct_packing = packing_to_substruct_packing(packing); + + if (!memb_type.pointer && !memb_type.member_types.empty() && + !buffer_is_packing_standard(memb_type, substruct_packing)) + { + if (failed_validation_index) + *failed_validation_index = i; + return false; + } + } + + // Bump size. + offset = actual_offset + packed_size; + } + + return true; +} + +bool CompilerGLSL::can_use_io_location(StorageClass storage, bool block) +{ + // Location specifiers are must have in SPIR-V, but they aren't really supported in earlier versions of GLSL. + // Be very explicit here about how to solve the issue. + if ((get_execution_model() != ExecutionModelVertex && storage == StorageClassInput) || + (get_execution_model() != ExecutionModelFragment && storage == StorageClassOutput)) + { + uint32_t minimum_desktop_version = block ? 440 : 410; + // ARB_enhanced_layouts vs ARB_separate_shader_objects ... 
+ + if (!options.es && options.version < minimum_desktop_version && !options.separate_shader_objects) + return false; + else if (options.es && options.version < 310) + return false; + } + + if ((get_execution_model() == ExecutionModelVertex && storage == StorageClassInput) || + (get_execution_model() == ExecutionModelFragment && storage == StorageClassOutput)) + { + if (options.es && options.version < 300) + return false; + else if (!options.es && options.version < 330) + return false; + } + + if (storage == StorageClassUniform || storage == StorageClassUniformConstant || storage == StorageClassPushConstant) + { + if (options.es && options.version < 310) + return false; + else if (!options.es && options.version < 430) + return false; + } + + return true; +} + +string CompilerGLSL::layout_for_variable(const SPIRVariable &var) +{ + // FIXME: Come up with a better solution for when to disable layouts. + // Having layouts depend on extensions as well as which types + // of layouts are used. For now, the simple solution is to just disable + // layouts for legacy versions. + if (is_legacy()) + return ""; + + if (subpass_input_is_framebuffer_fetch(var.self)) + return ""; + + SmallVector attr; + + auto &type = get(var.basetype); + auto &flags = get_decoration_bitset(var.self); + auto &typeflags = get_decoration_bitset(type.self); + + if (flags.get(DecorationPassthroughNV)) + attr.push_back("passthrough"); + + if (options.vulkan_semantics && var.storage == StorageClassPushConstant) + attr.push_back("push_constant"); + else if (var.storage == StorageClassShaderRecordBufferKHR) + attr.push_back(ray_tracing_is_khr ? "shaderRecordEXT" : "shaderRecordNV"); + + if (flags.get(DecorationRowMajor)) + attr.push_back("row_major"); + if (flags.get(DecorationColMajor)) + attr.push_back("column_major"); + + if (options.vulkan_semantics) + { + if (flags.get(DecorationInputAttachmentIndex)) + attr.push_back(join("input_attachment_index = ", get_decoration(var.self, DecorationInputAttachmentIndex))); + } + + bool is_block = has_decoration(type.self, DecorationBlock); + if (flags.get(DecorationLocation) && can_use_io_location(var.storage, is_block)) + { + Bitset combined_decoration; + for (uint32_t i = 0; i < ir.meta[type.self].members.size(); i++) + combined_decoration.merge_or(combined_decoration_for_member(type, i)); + + // If our members have location decorations, we don't need to + // emit location decorations at the top as well (looks weird). + if (!combined_decoration.get(DecorationLocation)) + attr.push_back(join("location = ", get_decoration(var.self, DecorationLocation))); + } + + if (get_execution_model() == ExecutionModelFragment && var.storage == StorageClassOutput && + location_is_non_coherent_framebuffer_fetch(get_decoration(var.self, DecorationLocation))) + { + attr.push_back("noncoherent"); + } + + // Transform feedback + bool uses_enhanced_layouts = false; + if (is_block && var.storage == StorageClassOutput) + { + // For blocks, there is a restriction where xfb_stride/xfb_buffer must only be declared on the block itself, + // since all members must match the same xfb_buffer. The only thing we will declare for members of the block + // is the xfb_offset. 
+ uint32_t member_count = uint32_t(type.member_types.size()); + bool have_xfb_buffer_stride = false; + bool have_any_xfb_offset = false; + bool have_geom_stream = false; + uint32_t xfb_stride = 0, xfb_buffer = 0, geom_stream = 0; + + if (flags.get(DecorationXfbBuffer) && flags.get(DecorationXfbStride)) + { + have_xfb_buffer_stride = true; + xfb_buffer = get_decoration(var.self, DecorationXfbBuffer); + xfb_stride = get_decoration(var.self, DecorationXfbStride); + } + + if (flags.get(DecorationStream)) + { + have_geom_stream = true; + geom_stream = get_decoration(var.self, DecorationStream); + } + + // Verify that none of the members violate our assumption. + for (uint32_t i = 0; i < member_count; i++) + { + if (has_member_decoration(type.self, i, DecorationStream)) + { + uint32_t member_geom_stream = get_member_decoration(type.self, i, DecorationStream); + if (have_geom_stream && member_geom_stream != geom_stream) + SPIRV_CROSS_THROW("IO block member Stream mismatch."); + have_geom_stream = true; + geom_stream = member_geom_stream; + } + + // Only members with an Offset decoration participate in XFB. + if (!has_member_decoration(type.self, i, DecorationOffset)) + continue; + have_any_xfb_offset = true; + + if (has_member_decoration(type.self, i, DecorationXfbBuffer)) + { + uint32_t buffer_index = get_member_decoration(type.self, i, DecorationXfbBuffer); + if (have_xfb_buffer_stride && buffer_index != xfb_buffer) + SPIRV_CROSS_THROW("IO block member XfbBuffer mismatch."); + have_xfb_buffer_stride = true; + xfb_buffer = buffer_index; + } + + if (has_member_decoration(type.self, i, DecorationXfbStride)) + { + uint32_t stride = get_member_decoration(type.self, i, DecorationXfbStride); + if (have_xfb_buffer_stride && stride != xfb_stride) + SPIRV_CROSS_THROW("IO block member XfbStride mismatch."); + have_xfb_buffer_stride = true; + xfb_stride = stride; + } + } + + if (have_xfb_buffer_stride && have_any_xfb_offset) + { + attr.push_back(join("xfb_buffer = ", xfb_buffer)); + attr.push_back(join("xfb_stride = ", xfb_stride)); + uses_enhanced_layouts = true; + } + + if (have_geom_stream) + { + if (get_execution_model() != ExecutionModelGeometry) + SPIRV_CROSS_THROW("Geometry streams can only be used in geometry shaders."); + if (options.es) + SPIRV_CROSS_THROW("Multiple geometry streams not supported in ESSL."); + if (options.version < 400) + require_extension_internal("GL_ARB_transform_feedback3"); + attr.push_back(join("stream = ", get_decoration(var.self, DecorationStream))); + } + } + else if (var.storage == StorageClassOutput) + { + if (flags.get(DecorationXfbBuffer) && flags.get(DecorationXfbStride) && flags.get(DecorationOffset)) + { + // XFB for standalone variables, we can emit all decorations. + attr.push_back(join("xfb_buffer = ", get_decoration(var.self, DecorationXfbBuffer))); + attr.push_back(join("xfb_stride = ", get_decoration(var.self, DecorationXfbStride))); + attr.push_back(join("xfb_offset = ", get_decoration(var.self, DecorationOffset))); + uses_enhanced_layouts = true; + } + + if (flags.get(DecorationStream)) + { + if (get_execution_model() != ExecutionModelGeometry) + SPIRV_CROSS_THROW("Geometry streams can only be used in geometry shaders."); + if (options.es) + SPIRV_CROSS_THROW("Multiple geometry streams not supported in ESSL."); + if (options.version < 400) + require_extension_internal("GL_ARB_transform_feedback3"); + attr.push_back(join("stream = ", get_decoration(var.self, DecorationStream))); + } + } + + // Can only declare Component if we can declare location. 
+ if (flags.get(DecorationComponent) && can_use_io_location(var.storage, is_block)) + { + uses_enhanced_layouts = true; + attr.push_back(join("component = ", get_decoration(var.self, DecorationComponent))); + } + + if (uses_enhanced_layouts) + { + if (!options.es) + { + if (options.version < 440 && options.version >= 140) + require_extension_internal("GL_ARB_enhanced_layouts"); + else if (options.version < 140) + SPIRV_CROSS_THROW("GL_ARB_enhanced_layouts is not supported in targets below GLSL 1.40."); + if (!options.es && options.version < 440) + require_extension_internal("GL_ARB_enhanced_layouts"); + } + else if (options.es) + SPIRV_CROSS_THROW("GL_ARB_enhanced_layouts is not supported in ESSL."); + } + + if (flags.get(DecorationIndex)) + attr.push_back(join("index = ", get_decoration(var.self, DecorationIndex))); + + // Do not emit set = decoration in regular GLSL output, but + // we need to preserve it in Vulkan GLSL mode. + if (var.storage != StorageClassPushConstant && var.storage != StorageClassShaderRecordBufferKHR) + { + if (flags.get(DecorationDescriptorSet) && options.vulkan_semantics) + attr.push_back(join("set = ", get_decoration(var.self, DecorationDescriptorSet))); + } + + bool push_constant_block = options.vulkan_semantics && var.storage == StorageClassPushConstant; + bool ssbo_block = var.storage == StorageClassStorageBuffer || var.storage == StorageClassShaderRecordBufferKHR || + (var.storage == StorageClassUniform && typeflags.get(DecorationBufferBlock)); + bool emulated_ubo = var.storage == StorageClassPushConstant && options.emit_push_constant_as_uniform_buffer; + bool ubo_block = var.storage == StorageClassUniform && typeflags.get(DecorationBlock); + + // GL 3.0/GLSL 1.30 is not considered legacy, but it doesn't have UBOs ... + bool can_use_buffer_blocks = (options.es && options.version >= 300) || (!options.es && options.version >= 140); + + // pretend no UBOs when options say so + if (ubo_block && options.emit_uniform_buffer_as_plain_uniforms) + can_use_buffer_blocks = false; + + bool can_use_binding; + if (options.es) + can_use_binding = options.version >= 310; + else + can_use_binding = options.enable_420pack_extension || (options.version >= 420); + + // Make sure we don't emit binding layout for a classic uniform on GLSL 1.30. + if (!can_use_buffer_blocks && var.storage == StorageClassUniform) + can_use_binding = false; + + if (var.storage == StorageClassShaderRecordBufferKHR) + can_use_binding = false; + + if (can_use_binding && flags.get(DecorationBinding)) + attr.push_back(join("binding = ", get_decoration(var.self, DecorationBinding))); + + if (var.storage != StorageClassOutput && flags.get(DecorationOffset)) + attr.push_back(join("offset = ", get_decoration(var.self, DecorationOffset))); + + // Instead of adding explicit offsets for every element here, just assume we're using std140 or std430. + // If SPIR-V does not comply with either layout, we cannot really work around it. + if (can_use_buffer_blocks && (ubo_block || emulated_ubo)) + { + attr.push_back(buffer_to_packing_standard(type, false, true)); + } + else if (can_use_buffer_blocks && (push_constant_block || ssbo_block)) + { + attr.push_back(buffer_to_packing_standard(type, true, true)); + } + + // For images, the type itself adds a layout qualifer. + // Only emit the format for storage images. 
+ if (type.basetype == SPIRType::Image && type.image.sampled == 2) + { + const char *fmt = format_to_glsl(type.image.format); + if (fmt) + attr.push_back(fmt); + } + + if (attr.empty()) + return ""; + + string res = "layout("; + res += merge(attr); + res += ") "; + return res; +} + +string CompilerGLSL::buffer_to_packing_standard(const SPIRType &type, + bool support_std430_without_scalar_layout, + bool support_enhanced_layouts) +{ + if (support_std430_without_scalar_layout && buffer_is_packing_standard(type, BufferPackingStd430)) + return "std430"; + else if (buffer_is_packing_standard(type, BufferPackingStd140)) + return "std140"; + else if (options.vulkan_semantics && buffer_is_packing_standard(type, BufferPackingScalar)) + { + require_extension_internal("GL_EXT_scalar_block_layout"); + return "scalar"; + } + else if (support_std430_without_scalar_layout && + support_enhanced_layouts && + buffer_is_packing_standard(type, BufferPackingStd430EnhancedLayout)) + { + if (options.es && !options.vulkan_semantics) + SPIRV_CROSS_THROW("Push constant block cannot be expressed as neither std430 nor std140. ES-targets do " + "not support GL_ARB_enhanced_layouts."); + if (!options.es && !options.vulkan_semantics && options.version < 440) + require_extension_internal("GL_ARB_enhanced_layouts"); + + set_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset); + return "std430"; + } + else if (support_enhanced_layouts && + buffer_is_packing_standard(type, BufferPackingStd140EnhancedLayout)) + { + // Fallback time. We might be able to use the ARB_enhanced_layouts to deal with this difference, + // however, we can only use layout(offset) on the block itself, not any substructs, so the substructs better be the appropriate layout. + // Enhanced layouts seem to always work in Vulkan GLSL, so no need for extensions there. + if (options.es && !options.vulkan_semantics) + SPIRV_CROSS_THROW("Push constant block cannot be expressed as neither std430 nor std140. ES-targets do " + "not support GL_ARB_enhanced_layouts."); + if (!options.es && !options.vulkan_semantics && options.version < 440) + require_extension_internal("GL_ARB_enhanced_layouts"); + + set_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset); + return "std140"; + } + else if (options.vulkan_semantics && + support_enhanced_layouts && + buffer_is_packing_standard(type, BufferPackingScalarEnhancedLayout)) + { + set_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset); + require_extension_internal("GL_EXT_scalar_block_layout"); + return "scalar"; + } + else if (!support_std430_without_scalar_layout && options.vulkan_semantics && + buffer_is_packing_standard(type, BufferPackingStd430)) + { + // UBOs can support std430 with GL_EXT_scalar_block_layout. + require_extension_internal("GL_EXT_scalar_block_layout"); + return "std430"; + } + else if (!support_std430_without_scalar_layout && options.vulkan_semantics && + support_enhanced_layouts && + buffer_is_packing_standard(type, BufferPackingStd430EnhancedLayout)) + { + // UBOs can support std430 with GL_EXT_scalar_block_layout. + set_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset); + require_extension_internal("GL_EXT_scalar_block_layout"); + return "std430"; + } + else + { + SPIRV_CROSS_THROW("Buffer block cannot be expressed as any of std430, std140, scalar, even with enhanced " + "layouts. 
You can try flattening this block to support a more flexible layout."); + } +} + +void CompilerGLSL::emit_push_constant_block(const SPIRVariable &var) +{ + if (flattened_buffer_blocks.count(var.self)) + emit_buffer_block_flattened(var); + else if (options.vulkan_semantics) + emit_push_constant_block_vulkan(var); + else if (options.emit_push_constant_as_uniform_buffer) + emit_buffer_block_native(var); + else + emit_push_constant_block_glsl(var); +} + +void CompilerGLSL::emit_push_constant_block_vulkan(const SPIRVariable &var) +{ + emit_buffer_block(var); +} + +void CompilerGLSL::emit_push_constant_block_glsl(const SPIRVariable &var) +{ + // OpenGL has no concept of push constant blocks, implement it as a uniform struct. + auto &type = get(var.basetype); + + unset_decoration(var.self, DecorationBinding); + unset_decoration(var.self, DecorationDescriptorSet); + +#if 0 + if (flags & ((1ull << DecorationBinding) | (1ull << DecorationDescriptorSet))) + SPIRV_CROSS_THROW("Push constant blocks cannot be compiled to GLSL with Binding or Set syntax. " + "Remap to location with reflection API first or disable these decorations."); +#endif + + // We're emitting the push constant block as a regular struct, so disable the block qualifier temporarily. + // Otherwise, we will end up emitting layout() qualifiers on naked structs which is not allowed. + bool block_flag = has_decoration(type.self, DecorationBlock); + unset_decoration(type.self, DecorationBlock); + + emit_struct(type); + + if (block_flag) + set_decoration(type.self, DecorationBlock); + + emit_uniform(var); + statement(""); +} + +void CompilerGLSL::emit_buffer_block(const SPIRVariable &var) +{ + auto &type = get(var.basetype); + bool ubo_block = var.storage == StorageClassUniform && has_decoration(type.self, DecorationBlock); + + if (flattened_buffer_blocks.count(var.self)) + emit_buffer_block_flattened(var); + else if (is_legacy() || (!options.es && options.version == 130) || + (ubo_block && options.emit_uniform_buffer_as_plain_uniforms)) + emit_buffer_block_legacy(var); + else + emit_buffer_block_native(var); +} + +void CompilerGLSL::emit_buffer_block_legacy(const SPIRVariable &var) +{ + auto &type = get(var.basetype); + bool ssbo = var.storage == StorageClassStorageBuffer || + ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock); + if (ssbo) + SPIRV_CROSS_THROW("SSBOs not supported in legacy targets."); + + // We're emitting the push constant block as a regular struct, so disable the block qualifier temporarily. + // Otherwise, we will end up emitting layout() qualifiers on naked structs which is not allowed. + auto &block_flags = ir.meta[type.self].decoration.decoration_flags; + bool block_flag = block_flags.get(DecorationBlock); + block_flags.clear(DecorationBlock); + emit_struct(type); + if (block_flag) + block_flags.set(DecorationBlock); + emit_uniform(var); + statement(""); +} + +void CompilerGLSL::emit_buffer_reference_block(uint32_t type_id, bool forward_declaration) +{ + auto &type = get(type_id); + string buffer_name; + + if (forward_declaration && is_physical_pointer_to_buffer_block(type)) + { + // Block names should never alias, but from HLSL input they kind of can because block types are reused for UAVs ... + // Allow aliased name since we might be declaring the block twice. Once with buffer reference (forward declared) and one proper declaration. + // The names must match up. 
+ buffer_name = to_name(type.self, false); + + // Shaders never use the block by interface name, so we don't + // have to track this other than updating name caches. + // If we have a collision for any reason, just fallback immediately. + if (ir.meta[type.self].decoration.alias.empty() || + block_ssbo_names.find(buffer_name) != end(block_ssbo_names) || + resource_names.find(buffer_name) != end(resource_names)) + { + buffer_name = join("_", type.self); + } + + // Make sure we get something unique for both global name scope and block name scope. + // See GLSL 4.5 spec: section 4.3.9 for details. + add_variable(block_ssbo_names, resource_names, buffer_name); + + // If for some reason buffer_name is an illegal name, make a final fallback to a workaround name. + // This cannot conflict with anything else, so we're safe now. + // We cannot reuse this fallback name in neither global scope (blocked by block_names) nor block name scope. + if (buffer_name.empty()) + buffer_name = join("_", type.self); + + block_names.insert(buffer_name); + block_ssbo_names.insert(buffer_name); + + // Ensure we emit the correct name when emitting non-forward pointer type. + ir.meta[type.self].decoration.alias = buffer_name; + } + else + { + buffer_name = type_to_glsl(type); + } + + if (!forward_declaration) + { + auto itr = physical_storage_type_to_alignment.find(type_id); + uint32_t alignment = 0; + if (itr != physical_storage_type_to_alignment.end()) + alignment = itr->second.alignment; + + if (is_physical_pointer_to_buffer_block(type)) + { + SmallVector attributes; + attributes.push_back("buffer_reference"); + if (alignment) + attributes.push_back(join("buffer_reference_align = ", alignment)); + attributes.push_back(buffer_to_packing_standard(type, true, true)); + + auto flags = ir.get_buffer_block_type_flags(type); + string decorations; + if (flags.get(DecorationRestrict)) + decorations += " restrict"; + if (flags.get(DecorationCoherent)) + decorations += " coherent"; + if (flags.get(DecorationNonReadable)) + decorations += " writeonly"; + if (flags.get(DecorationNonWritable)) + decorations += " readonly"; + + statement("layout(", merge(attributes), ")", decorations, " buffer ", buffer_name); + } + else + { + string packing_standard; + if (type.basetype == SPIRType::Struct) + { + // The non-block type is embedded in a block, so we cannot use enhanced layouts :( + packing_standard = buffer_to_packing_standard(type, true, false) + ", "; + } + else if (is_array(get_pointee_type(type))) + { + SPIRType wrap_type{OpTypeStruct}; + wrap_type.self = ir.increase_bound_by(1); + wrap_type.member_types.push_back(get_pointee_type_id(type_id)); + ir.set_member_decoration(wrap_type.self, 0, DecorationOffset, 0); + packing_standard = buffer_to_packing_standard(wrap_type, true, false) + ", "; + } + + if (alignment) + statement("layout(", packing_standard, "buffer_reference, buffer_reference_align = ", alignment, ") buffer ", buffer_name); + else + statement("layout(", packing_standard, "buffer_reference) buffer ", buffer_name); + } + + begin_scope(); + + if (is_physical_pointer_to_buffer_block(type)) + { + type.member_name_cache.clear(); + + uint32_t i = 0; + for (auto &member : type.member_types) + { + add_member_name(type, i); + emit_struct_member(type, member, i); + i++; + } + } + else + { + auto &pointee_type = get_pointee_type(type); + statement(type_to_glsl(pointee_type), " value", type_to_array_glsl(pointee_type, 0), ";"); + } + + end_scope_decl(); + statement(""); + } + else + { + statement("layout(buffer_reference) buffer 
", buffer_name, ";"); + } +} + +void CompilerGLSL::emit_buffer_block_native(const SPIRVariable &var) +{ + auto &type = get(var.basetype); + + Bitset flags = ir.get_buffer_block_flags(var); + bool ssbo = var.storage == StorageClassStorageBuffer || var.storage == StorageClassShaderRecordBufferKHR || + ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock); + bool is_restrict = ssbo && flags.get(DecorationRestrict); + bool is_writeonly = ssbo && flags.get(DecorationNonReadable); + bool is_readonly = ssbo && flags.get(DecorationNonWritable); + bool is_coherent = ssbo && flags.get(DecorationCoherent); + + // Block names should never alias, but from HLSL input they kind of can because block types are reused for UAVs ... + auto buffer_name = to_name(type.self, false); + + auto &block_namespace = ssbo ? block_ssbo_names : block_ubo_names; + + // Shaders never use the block by interface name, so we don't + // have to track this other than updating name caches. + // If we have a collision for any reason, just fallback immediately. + if (ir.meta[type.self].decoration.alias.empty() || block_namespace.find(buffer_name) != end(block_namespace) || + resource_names.find(buffer_name) != end(resource_names)) + { + buffer_name = get_block_fallback_name(var.self); + } + + // Make sure we get something unique for both global name scope and block name scope. + // See GLSL 4.5 spec: section 4.3.9 for details. + add_variable(block_namespace, resource_names, buffer_name); + + // If for some reason buffer_name is an illegal name, make a final fallback to a workaround name. + // This cannot conflict with anything else, so we're safe now. + // We cannot reuse this fallback name in neither global scope (blocked by block_names) nor block name scope. + if (buffer_name.empty()) + buffer_name = join("_", get(var.basetype).self, "_", var.self); + + block_names.insert(buffer_name); + block_namespace.insert(buffer_name); + + // Save for post-reflection later. + declared_block_names[var.self] = buffer_name; + + statement(layout_for_variable(var), is_coherent ? "coherent " : "", is_restrict ? "restrict " : "", + is_writeonly ? "writeonly " : "", is_readonly ? "readonly " : "", ssbo ? "buffer " : "uniform ", + buffer_name); + + begin_scope(); + + type.member_name_cache.clear(); + + uint32_t i = 0; + for (auto &member : type.member_types) + { + add_member_name(type, i); + emit_struct_member(type, member, i); + i++; + } + + // Don't declare empty blocks in GLSL, this is not allowed. + if (type_is_empty(type) && !backend.supports_empty_struct) + statement("int empty_struct_member;"); + + // var.self can be used as a backup name for the block name, + // so we need to make sure we don't disturb the name here on a recompile. + // It will need to be reset if we have to recompile. + preserve_alias_on_reset(var.self); + add_resource_name(var.self); + end_scope_decl(to_name(var.self) + type_to_array_glsl(type, var.self)); + statement(""); +} + +void CompilerGLSL::emit_buffer_block_flattened(const SPIRVariable &var) +{ + auto &type = get(var.basetype); + + // Block names should never alias. 
+ auto buffer_name = to_name(type.self, false); + size_t buffer_size = (get_declared_struct_size(type) + 15) / 16; + + SPIRType::BaseType basic_type; + if (get_common_basic_type(type, basic_type)) + { + SPIRType tmp { OpTypeVector }; + tmp.basetype = basic_type; + tmp.vecsize = 4; + if (basic_type != SPIRType::Float && basic_type != SPIRType::Int && basic_type != SPIRType::UInt) + SPIRV_CROSS_THROW("Basic types in a flattened UBO must be float, int or uint."); + + auto flags = ir.get_buffer_block_flags(var); + statement("uniform ", flags_to_qualifiers_glsl(tmp, flags), type_to_glsl(tmp), " ", buffer_name, "[", + buffer_size, "];"); + } + else + SPIRV_CROSS_THROW("All basic types in a flattened block must be the same."); +} + +const char *CompilerGLSL::to_storage_qualifiers_glsl(const SPIRVariable &var) +{ + auto &execution = get_entry_point(); + + if (subpass_input_is_framebuffer_fetch(var.self)) + return ""; + + if (var.storage == StorageClassInput || var.storage == StorageClassOutput) + { + if (is_legacy() && execution.model == ExecutionModelVertex) + return var.storage == StorageClassInput ? "attribute " : "varying "; + else if (is_legacy() && execution.model == ExecutionModelFragment) + return "varying "; // Fragment outputs are renamed so they never hit this case. + else if (execution.model == ExecutionModelFragment && var.storage == StorageClassOutput) + { + uint32_t loc = get_decoration(var.self, DecorationLocation); + bool is_inout = location_is_framebuffer_fetch(loc); + if (is_inout) + return "inout "; + else + return "out "; + } + else + return var.storage == StorageClassInput ? "in " : "out "; + } + else if (var.storage == StorageClassUniformConstant || var.storage == StorageClassUniform || + var.storage == StorageClassPushConstant || var.storage == StorageClassAtomicCounter) + { + return "uniform "; + } + else if (var.storage == StorageClassRayPayloadKHR) + { + return ray_tracing_is_khr ? "rayPayloadEXT " : "rayPayloadNV "; + } + else if (var.storage == StorageClassIncomingRayPayloadKHR) + { + return ray_tracing_is_khr ? "rayPayloadInEXT " : "rayPayloadInNV "; + } + else if (var.storage == StorageClassHitAttributeKHR) + { + return ray_tracing_is_khr ? "hitAttributeEXT " : "hitAttributeNV "; + } + else if (var.storage == StorageClassCallableDataKHR) + { + return ray_tracing_is_khr ? "callableDataEXT " : "callableDataNV "; + } + else if (var.storage == StorageClassIncomingCallableDataKHR) + { + return ray_tracing_is_khr ? "callableDataInEXT " : "callableDataInNV "; + } + + return ""; +} + +void CompilerGLSL::emit_flattened_io_block_member(const std::string &basename, const SPIRType &type, const char *qual, + const SmallVector &indices) +{ + uint32_t member_type_id = type.self; + const SPIRType *member_type = &type; + const SPIRType *parent_type = nullptr; + auto flattened_name = basename; + for (auto &index : indices) + { + flattened_name += "_"; + flattened_name += to_member_name(*member_type, index); + parent_type = member_type; + member_type_id = member_type->member_types[index]; + member_type = &get(member_type_id); + } + + assert(member_type->basetype != SPIRType::Struct); + + // We're overriding struct member names, so ensure we do so on the primary type. + if (parent_type->type_alias) + parent_type = &get(parent_type->type_alias); + + // Sanitize underscores because joining the two identifiers might create more than 1 underscore in a row, + // which is not allowed. 
+	ParsedIR::sanitize_underscores(flattened_name);
+
+	uint32_t last_index = indices.back();
+
+	// Pass in the varying qualifier here so it will appear in the correct declaration order.
+	// Replace member name while emitting it so it encodes both struct name and member name.
+	auto backup_name = get_member_name(parent_type->self, last_index);
+	auto member_name = to_member_name(*parent_type, last_index);
+	set_member_name(parent_type->self, last_index, flattened_name);
+	emit_struct_member(*parent_type, member_type_id, last_index, qual);
+	// Restore member name.
+	set_member_name(parent_type->self, last_index, member_name);
+}
+
+void CompilerGLSL::emit_flattened_io_block_struct(const std::string &basename, const SPIRType &type, const char *qual,
+                                                  const SmallVector<uint32_t> &indices)
+{
+	auto sub_indices = indices;
+	sub_indices.push_back(0);
+
+	const SPIRType *member_type = &type;
+	for (auto &index : indices)
+		member_type = &get<SPIRType>(member_type->member_types[index]);
+
+	assert(member_type->basetype == SPIRType::Struct);
+
+	if (!member_type->array.empty())
+		SPIRV_CROSS_THROW("Cannot flatten array of structs in I/O blocks.");
+
+	for (uint32_t i = 0; i < uint32_t(member_type->member_types.size()); i++)
+	{
+		sub_indices.back() = i;
+		if (get<SPIRType>(member_type->member_types[i]).basetype == SPIRType::Struct)
+			emit_flattened_io_block_struct(basename, type, qual, sub_indices);
+		else
+			emit_flattened_io_block_member(basename, type, qual, sub_indices);
+	}
+}
+
+void CompilerGLSL::emit_flattened_io_block(const SPIRVariable &var, const char *qual)
+{
+	auto &var_type = get<SPIRType>(var.basetype);
+	if (!var_type.array.empty())
+		SPIRV_CROSS_THROW("Array of varying structs cannot be flattened to legacy-compatible varyings.");
+
+	// Emit flattened types based on the type alias. Normally, we are never supposed to emit
+	// struct declarations for aliased types.
+	auto &type = var_type.type_alias ? get<SPIRType>(var_type.type_alias) : var_type;
+
+	auto old_flags = ir.meta[type.self].decoration.decoration_flags;
+	// Emit the members as if they are part of a block to get all qualifiers.
+	ir.meta[type.self].decoration.decoration_flags.set(DecorationBlock);
+
+	type.member_name_cache.clear();
+
+	SmallVector<uint32_t> member_indices;
+	member_indices.push_back(0);
+	auto basename = to_name(var.self);
+
+	uint32_t i = 0;
+	for (auto &member : type.member_types)
+	{
+		add_member_name(type, i);
+		auto &membertype = get<SPIRType>(member);
+
+		member_indices.back() = i;
+		if (membertype.basetype == SPIRType::Struct)
+			emit_flattened_io_block_struct(basename, type, qual, member_indices);
+		else
+			emit_flattened_io_block_member(basename, type, qual, member_indices);
+		i++;
+	}
+
+	ir.meta[type.self].decoration.decoration_flags = old_flags;
+
+	// Treat this variable as fully flattened from now on.
+	flattened_structs[var.self] = true;
+}
+
+void CompilerGLSL::emit_interface_block(const SPIRVariable &var)
+{
+	auto &type = get<SPIRType>(var.basetype);
+
+	if (var.storage == StorageClassInput && type.basetype == SPIRType::Double &&
+	    !options.es && options.version < 410)
+	{
+		require_extension_internal("GL_ARB_vertex_attrib_64bit");
+	}
+
+	// Either make it plain in/out or in/out blocks depending on what shader is doing ...
+	bool block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock);
+	const char *qual = to_storage_qualifiers_glsl(var);
+
+	if (block)
+	{
+		// ESSL earlier than 310 and GLSL earlier than 150 did not support
+		// I/O variables which are struct types.
+		// To support this, flatten the struct into separate varyings instead.
+ if (options.force_flattened_io_blocks || (options.es && options.version < 310) || + (!options.es && options.version < 150)) + { + // I/O blocks on ES require version 310 with Android Extension Pack extensions, or core version 320. + // On desktop, I/O blocks were introduced with geometry shaders in GL 3.2 (GLSL 150). + emit_flattened_io_block(var, qual); + } + else + { + if (options.es && options.version < 320) + { + // Geometry and tessellation extensions imply this extension. + if (!has_extension("GL_EXT_geometry_shader") && !has_extension("GL_EXT_tessellation_shader")) + require_extension_internal("GL_EXT_shader_io_blocks"); + } + + // Workaround to make sure we can emit "patch in/out" correctly. + fixup_io_block_patch_primitive_qualifiers(var); + + // Block names should never alias. + auto block_name = to_name(type.self, false); + + // The namespace for I/O blocks is separate from other variables in GLSL. + auto &block_namespace = type.storage == StorageClassInput ? block_input_names : block_output_names; + + // Shaders never use the block by interface name, so we don't + // have to track this other than updating name caches. + if (block_name.empty() || block_namespace.find(block_name) != end(block_namespace)) + block_name = get_fallback_name(type.self); + else + block_namespace.insert(block_name); + + // If for some reason buffer_name is an illegal name, make a final fallback to a workaround name. + // This cannot conflict with anything else, so we're safe now. + if (block_name.empty()) + block_name = join("_", get(var.basetype).self, "_", var.self); + + // Instance names cannot alias block names. + resource_names.insert(block_name); + + const char *block_qualifier; + if (has_decoration(var.self, DecorationPatch)) + block_qualifier = "patch "; + else if (has_decoration(var.self, DecorationPerPrimitiveEXT)) + block_qualifier = "perprimitiveEXT "; + else + block_qualifier = ""; + + statement(layout_for_variable(var), block_qualifier, qual, block_name); + begin_scope(); + + type.member_name_cache.clear(); + + uint32_t i = 0; + for (auto &member : type.member_types) + { + add_member_name(type, i); + emit_struct_member(type, member, i); + i++; + } + + add_resource_name(var.self); + end_scope_decl(join(to_name(var.self), type_to_array_glsl(type, var.self))); + statement(""); + } + } + else + { + // ESSL earlier than 310 and GLSL earlier than 150 did not support + // I/O variables which are struct types. + // To support this, flatten the struct into separate varyings instead. + if (type.basetype == SPIRType::Struct && + (options.force_flattened_io_blocks || (options.es && options.version < 310) || + (!options.es && options.version < 150))) + { + emit_flattened_io_block(var, qual); + } + else + { + add_resource_name(var.self); + + // Legacy GLSL did not support int attributes, we automatically + // declare them as float and cast them on load/store + SPIRType newtype = type; + if (is_legacy() && var.storage == StorageClassInput && type.basetype == SPIRType::Int) + newtype.basetype = SPIRType::Float; + + // Tessellation control and evaluation shaders must have either + // gl_MaxPatchVertices or unsized arrays for input arrays. + // Opt for unsized as it's the more "correct" variant to use. 
+ if (type.storage == StorageClassInput && !type.array.empty() && + !has_decoration(var.self, DecorationPatch) && + (get_entry_point().model == ExecutionModelTessellationControl || + get_entry_point().model == ExecutionModelTessellationEvaluation)) + { + newtype.array.back() = 0; + newtype.array_size_literal.back() = true; + } + + statement(layout_for_variable(var), to_qualifiers_glsl(var.self), + variable_decl(newtype, to_name(var.self), var.self), ";"); + } + } +} + +void CompilerGLSL::emit_uniform(const SPIRVariable &var) +{ + auto &type = get(var.basetype); + if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData) + { + if (!options.es && options.version < 420) + require_extension_internal("GL_ARB_shader_image_load_store"); + else if (options.es && options.version < 310) + SPIRV_CROSS_THROW("At least ESSL 3.10 required for shader image load store."); + } + + add_resource_name(var.self); + statement(layout_for_variable(var), variable_decl(var), ";"); +} + +string CompilerGLSL::constant_value_macro_name(uint32_t id) +{ + return join("SPIRV_CROSS_CONSTANT_ID_", id); +} + +void CompilerGLSL::emit_specialization_constant_op(const SPIRConstantOp &constant) +{ + auto &type = get(constant.basetype); + // This will break. It is bogus and should not be legal. + if (type_is_top_level_block(type)) + return; + add_resource_name(constant.self); + auto name = to_name(constant.self); + statement("const ", variable_decl(type, name), " = ", constant_op_expression(constant), ";"); +} + +int CompilerGLSL::get_constant_mapping_to_workgroup_component(const SPIRConstant &c) const +{ + auto &entry_point = get_entry_point(); + int index = -1; + + // Need to redirect specialization constants which are used as WorkGroupSize to the builtin, + // since the spec constant declarations are never explicitly declared. + if (entry_point.workgroup_size.constant == 0 && entry_point.flags.get(ExecutionModeLocalSizeId)) + { + if (c.self == entry_point.workgroup_size.id_x) + index = 0; + else if (c.self == entry_point.workgroup_size.id_y) + index = 1; + else if (c.self == entry_point.workgroup_size.id_z) + index = 2; + } + + return index; +} + +void CompilerGLSL::emit_constant(const SPIRConstant &constant) +{ + auto &type = get(constant.constant_type); + + // This will break. It is bogus and should not be legal. + if (type_is_top_level_block(type)) + return; + + SpecializationConstant wg_x, wg_y, wg_z; + ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z); + + // This specialization constant is implicitly declared by emitting layout() in; + if (constant.self == workgroup_size_id) + return; + + // These specialization constants are implicitly declared by emitting layout() in; + // In legacy GLSL, we will still need to emit macros for these, so a layout() in; declaration + // later can use macro overrides for work group size. + bool is_workgroup_size_constant = ConstantID(constant.self) == wg_x.id || ConstantID(constant.self) == wg_y.id || + ConstantID(constant.self) == wg_z.id; + + if (options.vulkan_semantics && is_workgroup_size_constant) + { + // Vulkan GLSL does not need to declare workgroup spec constants explicitly, it is handled in layout(). + return; + } + else if (!options.vulkan_semantics && is_workgroup_size_constant && + !has_decoration(constant.self, DecorationSpecId)) + { + // Only bother declaring a workgroup size if it is actually a specialization constant, because we need macros. 
+		return;
+	}
+
+	add_resource_name(constant.self);
+	auto name = to_name(constant.self);
+
+	// Only scalars have constant IDs.
+	if (has_decoration(constant.self, DecorationSpecId))
+	{
+		if (options.vulkan_semantics)
+		{
+			statement("layout(constant_id = ", get_decoration(constant.self, DecorationSpecId), ") const ",
+			          variable_decl(type, name), " = ", constant_expression(constant), ";");
+		}
+		else
+		{
+			const string &macro_name = constant.specialization_constant_macro_name;
+			statement("#ifndef ", macro_name);
+			statement("#define ", macro_name, " ", constant_expression(constant));
+			statement("#endif");
+
+			// For workgroup size constants, only emit the macros.
+			if (!is_workgroup_size_constant)
+				statement("const ", variable_decl(type, name), " = ", macro_name, ";");
+		}
+	}
+	else
+	{
+		statement("const ", variable_decl(type, name), " = ", constant_expression(constant), ";");
+	}
+}
+
+void CompilerGLSL::emit_entry_point_declarations()
+{
+}
+
+void CompilerGLSL::replace_illegal_names(const unordered_set<string> &keywords)
+{
+	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, const SPIRVariable &var) {
+		if (is_hidden_variable(var))
+			return;
+
+		auto *meta = ir.find_meta(var.self);
+		if (!meta)
+			return;
+
+		auto &m = meta->decoration;
+		if (keywords.find(m.alias) != end(keywords))
+			m.alias = join("_", m.alias);
+	});
+
+	ir.for_each_typed_id<SPIRFunction>([&](uint32_t, const SPIRFunction &func) {
+		auto *meta = ir.find_meta(func.self);
+		if (!meta)
+			return;
+
+		auto &m = meta->decoration;
+		if (keywords.find(m.alias) != end(keywords))
+			m.alias = join("_", m.alias);
+	});
+
+	ir.for_each_typed_id<SPIRType>([&](uint32_t, const SPIRType &type) {
+		auto *meta = ir.find_meta(type.self);
+		if (!meta)
+			return;
+
+		auto &m = meta->decoration;
+		if (keywords.find(m.alias) != end(keywords))
+			m.alias = join("_", m.alias);
+
+		for (auto &memb : meta->members)
+			if (keywords.find(memb.alias) != end(keywords))
+				memb.alias = join("_", memb.alias);
+	});
+}
+
+void CompilerGLSL::replace_illegal_names()
+{
+	// clang-format off
+	static const unordered_set<string> keywords = {
+		"abs", "acos", "acosh", "all", "any", "asin", "asinh", "atan", "atanh",
+		"atomicAdd", "atomicCompSwap", "atomicCounter", "atomicCounterDecrement", "atomicCounterIncrement",
+		"atomicExchange", "atomicMax", "atomicMin", "atomicOr", "atomicXor",
+		"bitCount", "bitfieldExtract", "bitfieldInsert", "bitfieldReverse",
+		"ceil", "cos", "cosh", "cross", "degrees",
+		"dFdx", "dFdxCoarse", "dFdxFine",
+		"dFdy", "dFdyCoarse", "dFdyFine",
+		"distance", "dot", "EmitStreamVertex", "EmitVertex", "EndPrimitive", "EndStreamPrimitive", "equal", "exp", "exp2",
+		"faceforward", "findLSB", "findMSB", "float16BitsToInt16", "float16BitsToUint16", "floatBitsToInt", "floatBitsToUint", "floor", "fma", "fract",
+		"frexp", "fwidth", "fwidthCoarse", "fwidthFine",
+		"greaterThan", "greaterThanEqual", "groupMemoryBarrier",
+		"imageAtomicAdd", "imageAtomicAnd", "imageAtomicCompSwap", "imageAtomicExchange", "imageAtomicMax", "imageAtomicMin", "imageAtomicOr", "imageAtomicXor",
+		"imageLoad", "imageSamples", "imageSize", "imageStore", "imulExtended", "int16BitsToFloat16", "intBitsToFloat", "interpolateAtOffset", "interpolateAtCentroid", "interpolateAtSample",
+		"inverse", "inversesqrt", "isinf", "isnan", "ldexp", "length", "lessThan", "lessThanEqual", "log", "log2",
+		"matrixCompMult", "max", "memoryBarrier", "memoryBarrierAtomicCounter", "memoryBarrierBuffer", "memoryBarrierImage", "memoryBarrierShared",
+		"min", "mix", "mod", "modf", "noise", "noise1", "noise2", "noise3", "noise4", "normalize", "not",
"notEqual", + "outerProduct", "packDouble2x32", "packHalf2x16", "packInt2x16", "packInt4x16", "packSnorm2x16", "packSnorm4x8", + "packUint2x16", "packUint4x16", "packUnorm2x16", "packUnorm4x8", "pow", + "radians", "reflect", "refract", "round", "roundEven", "sign", "sin", "sinh", "smoothstep", "sqrt", "step", + "tan", "tanh", "texelFetch", "texelFetchOffset", "texture", "textureGather", "textureGatherOffset", "textureGatherOffsets", + "textureGrad", "textureGradOffset", "textureLod", "textureLodOffset", "textureOffset", "textureProj", "textureProjGrad", + "textureProjGradOffset", "textureProjLod", "textureProjLodOffset", "textureProjOffset", "textureQueryLevels", "textureQueryLod", "textureSamples", "textureSize", + "transpose", "trunc", "uaddCarry", "uint16BitsToFloat16", "uintBitsToFloat", "umulExtended", "unpackDouble2x32", "unpackHalf2x16", "unpackInt2x16", "unpackInt4x16", + "unpackSnorm2x16", "unpackSnorm4x8", "unpackUint2x16", "unpackUint4x16", "unpackUnorm2x16", "unpackUnorm4x8", "usubBorrow", + + "active", "asm", "atomic_uint", "attribute", "bool", "break", "buffer", + "bvec2", "bvec3", "bvec4", "case", "cast", "centroid", "class", "coherent", "common", "const", "continue", "default", "discard", + "dmat2", "dmat2x2", "dmat2x3", "dmat2x4", "dmat3", "dmat3x2", "dmat3x3", "dmat3x4", "dmat4", "dmat4x2", "dmat4x3", "dmat4x4", + "do", "double", "dvec2", "dvec3", "dvec4", "else", "enum", "extern", "external", "false", "filter", "fixed", "flat", "float", + "for", "fvec2", "fvec3", "fvec4", "goto", "half", "highp", "hvec2", "hvec3", "hvec4", "if", "iimage1D", "iimage1DArray", + "iimage2D", "iimage2DArray", "iimage2DMS", "iimage2DMSArray", "iimage2DRect", "iimage3D", "iimageBuffer", "iimageCube", + "iimageCubeArray", "image1D", "image1DArray", "image2D", "image2DArray", "image2DMS", "image2DMSArray", "image2DRect", + "image3D", "imageBuffer", "imageCube", "imageCubeArray", "in", "inline", "inout", "input", "int", "interface", "invariant", + "isampler1D", "isampler1DArray", "isampler2D", "isampler2DArray", "isampler2DMS", "isampler2DMSArray", "isampler2DRect", + "isampler3D", "isamplerBuffer", "isamplerCube", "isamplerCubeArray", "ivec2", "ivec3", "ivec4", "layout", "long", "lowp", + "mat2", "mat2x2", "mat2x3", "mat2x4", "mat3", "mat3x2", "mat3x3", "mat3x4", "mat4", "mat4x2", "mat4x3", "mat4x4", "mediump", + "namespace", "noinline", "noperspective", "out", "output", "packed", "partition", "patch", "precise", "precision", "public", "readonly", + "resource", "restrict", "return", "sample", "sampler1D", "sampler1DArray", "sampler1DArrayShadow", + "sampler1DShadow", "sampler2D", "sampler2DArray", "sampler2DArrayShadow", "sampler2DMS", "sampler2DMSArray", + "sampler2DRect", "sampler2DRectShadow", "sampler2DShadow", "sampler3D", "sampler3DRect", "samplerBuffer", + "samplerCube", "samplerCubeArray", "samplerCubeArrayShadow", "samplerCubeShadow", "shared", "short", "sizeof", "smooth", "static", + "struct", "subroutine", "superp", "switch", "template", "this", "true", "typedef", "uimage1D", "uimage1DArray", "uimage2D", + "uimage2DArray", "uimage2DMS", "uimage2DMSArray", "uimage2DRect", "uimage3D", "uimageBuffer", "uimageCube", + "uimageCubeArray", "uint", "uniform", "union", "unsigned", "usampler1D", "usampler1DArray", "usampler2D", "usampler2DArray", + "usampler2DMS", "usampler2DMSArray", "usampler2DRect", "usampler3D", "usamplerBuffer", "usamplerCube", + "usamplerCubeArray", "using", "uvec2", "uvec3", "uvec4", "varying", "vec2", "vec3", "vec4", "void", "volatile", + "while", "writeonly", + }; + // 
clang-format on + + replace_illegal_names(keywords); +} + +void CompilerGLSL::replace_fragment_output(SPIRVariable &var) +{ + auto &m = ir.meta[var.self].decoration; + uint32_t location = 0; + if (m.decoration_flags.get(DecorationLocation)) + location = m.location; + + // If our variable is arrayed, we must not emit the array part of this as the SPIR-V will + // do the access chain part of this for us. + auto &type = get(var.basetype); + + if (type.array.empty()) + { + // Redirect the write to a specific render target in legacy GLSL. + m.alias = join("gl_FragData[", location, "]"); + + if (is_legacy_es() && location != 0) + require_extension_internal("GL_EXT_draw_buffers"); + } + else if (type.array.size() == 1) + { + // If location is non-zero, we probably have to add an offset. + // This gets really tricky since we'd have to inject an offset in the access chain. + // FIXME: This seems like an extremely odd-ball case, so it's probably fine to leave it like this for now. + m.alias = "gl_FragData"; + if (location != 0) + SPIRV_CROSS_THROW("Arrayed output variable used, but location is not 0. " + "This is unimplemented in SPIRV-Cross."); + + if (is_legacy_es()) + require_extension_internal("GL_EXT_draw_buffers"); + } + else + SPIRV_CROSS_THROW("Array-of-array output variable used. This cannot be implemented in legacy GLSL."); + + var.compat_builtin = true; // We don't want to declare this variable, but use the name as-is. +} + +void CompilerGLSL::replace_fragment_outputs() +{ + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + auto &type = this->get(var.basetype); + + if (!is_builtin_variable(var) && !var.remapped_variable && type.pointer && var.storage == StorageClassOutput) + replace_fragment_output(var); + }); +} + +string CompilerGLSL::remap_swizzle(const SPIRType &out_type, uint32_t input_components, const string &expr) +{ + if (out_type.vecsize == input_components) + return expr; + else if (input_components == 1 && !backend.can_swizzle_scalar) + return join(type_to_glsl(out_type), "(", expr, ")"); + else + { + // FIXME: This will not work with packed expressions. + auto e = enclose_expression(expr) + "."; + // Just clamp the swizzle index if we have more outputs than inputs. 
+ for (uint32_t c = 0; c < out_type.vecsize; c++) + e += index_to_swizzle(min(c, input_components - 1)); + if (backend.swizzle_is_function && out_type.vecsize > 1) + e += "()"; + + remove_duplicate_swizzle(e); + return e; + } +} + +void CompilerGLSL::emit_pls() +{ + auto &execution = get_entry_point(); + if (execution.model != ExecutionModelFragment) + SPIRV_CROSS_THROW("Pixel local storage only supported in fragment shaders."); + + if (!options.es) + SPIRV_CROSS_THROW("Pixel local storage only supported in OpenGL ES."); + + if (options.version < 300) + SPIRV_CROSS_THROW("Pixel local storage only supported in ESSL 3.0 and above."); + + if (!pls_inputs.empty()) + { + statement("__pixel_local_inEXT _PLSIn"); + begin_scope(); + for (auto &input : pls_inputs) + statement(pls_decl(input), ";"); + end_scope_decl(); + statement(""); + } + + if (!pls_outputs.empty()) + { + statement("__pixel_local_outEXT _PLSOut"); + begin_scope(); + for (auto &output : pls_outputs) + statement(pls_decl(output), ";"); + end_scope_decl(); + statement(""); + } +} + +void CompilerGLSL::fixup_image_load_store_access() +{ + if (!options.enable_storage_image_qualifier_deduction) + return; + + ir.for_each_typed_id([&](uint32_t var, const SPIRVariable &) { + auto &vartype = expression_type(var); + if (vartype.basetype == SPIRType::Image && vartype.image.sampled == 2) + { + // Very old glslangValidator and HLSL compilers do not emit required qualifiers here. + // Solve this by making the image access as restricted as possible and loosen up if we need to. + // If any no-read/no-write flags are actually set, assume that the compiler knows what it's doing. + + if (!has_decoration(var, DecorationNonWritable) && !has_decoration(var, DecorationNonReadable)) + { + set_decoration(var, DecorationNonWritable); + set_decoration(var, DecorationNonReadable); + } + } + }); +} + +static bool is_block_builtin(BuiltIn builtin) +{ + return builtin == BuiltInPosition || builtin == BuiltInPointSize || builtin == BuiltInClipDistance || + builtin == BuiltInCullDistance; +} + +bool CompilerGLSL::should_force_emit_builtin_block(StorageClass storage) +{ + // If the builtin block uses XFB, we need to force explicit redeclaration of the builtin block. + + if (storage != StorageClassOutput) + return false; + bool should_force = false; + + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + if (should_force) + return; + + auto &type = this->get(var.basetype); + bool block = has_decoration(type.self, DecorationBlock); + if (var.storage == storage && block && is_builtin_variable(var)) + { + uint32_t member_count = uint32_t(type.member_types.size()); + for (uint32_t i = 0; i < member_count; i++) + { + if (has_member_decoration(type.self, i, DecorationBuiltIn) && + is_block_builtin(BuiltIn(get_member_decoration(type.self, i, DecorationBuiltIn))) && + has_member_decoration(type.self, i, DecorationOffset)) + { + should_force = true; + } + } + } + else if (var.storage == storage && !block && is_builtin_variable(var)) + { + if (is_block_builtin(BuiltIn(get_decoration(type.self, DecorationBuiltIn))) && + has_decoration(var.self, DecorationOffset)) + { + should_force = true; + } + } + }); + + // If we're declaring clip/cull planes with control points we need to force block declaration. 
+ if ((get_execution_model() == ExecutionModelTessellationControl || + get_execution_model() == ExecutionModelMeshEXT) && + (clip_distance_count || cull_distance_count)) + { + should_force = true; + } + + // Either glslang bug or oversight, but global invariant position does not work in mesh shaders. + if (get_execution_model() == ExecutionModelMeshEXT && position_invariant) + should_force = true; + + return should_force; +} + +void CompilerGLSL::fixup_implicit_builtin_block_names(ExecutionModel model) +{ + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + auto &type = this->get(var.basetype); + bool block = has_decoration(type.self, DecorationBlock); + if ((var.storage == StorageClassOutput || var.storage == StorageClassInput) && block && + is_builtin_variable(var)) + { + if (model != ExecutionModelMeshEXT) + { + // Make sure the array has a supported name in the code. + if (var.storage == StorageClassOutput) + set_name(var.self, "gl_out"); + else if (var.storage == StorageClassInput) + set_name(var.self, "gl_in"); + } + else + { + auto flags = get_buffer_block_flags(var.self); + if (flags.get(DecorationPerPrimitiveEXT)) + { + set_name(var.self, "gl_MeshPrimitivesEXT"); + set_name(type.self, "gl_MeshPerPrimitiveEXT"); + } + else + { + set_name(var.self, "gl_MeshVerticesEXT"); + set_name(type.self, "gl_MeshPerVertexEXT"); + } + } + } + + if (model == ExecutionModelMeshEXT && var.storage == StorageClassOutput && !block) + { + auto *m = ir.find_meta(var.self); + if (m && m->decoration.builtin) + { + auto builtin_type = m->decoration.builtin_type; + if (builtin_type == BuiltInPrimitivePointIndicesEXT) + set_name(var.self, "gl_PrimitivePointIndicesEXT"); + else if (builtin_type == BuiltInPrimitiveLineIndicesEXT) + set_name(var.self, "gl_PrimitiveLineIndicesEXT"); + else if (builtin_type == BuiltInPrimitiveTriangleIndicesEXT) + set_name(var.self, "gl_PrimitiveTriangleIndicesEXT"); + } + } + }); +} + +void CompilerGLSL::emit_declared_builtin_block(StorageClass storage, ExecutionModel model) +{ + Bitset emitted_builtins; + Bitset global_builtins; + const SPIRVariable *block_var = nullptr; + bool emitted_block = false; + + // Need to use declared size in the type. + // These variables might have been declared, but not statically used, so we haven't deduced their size yet. 
+ uint32_t cull_distance_size = 0; + uint32_t clip_distance_size = 0; + + bool have_xfb_buffer_stride = false; + bool have_geom_stream = false; + bool have_any_xfb_offset = false; + uint32_t xfb_stride = 0, xfb_buffer = 0, geom_stream = 0; + std::unordered_map builtin_xfb_offsets; + + const auto builtin_is_per_vertex_set = [](BuiltIn builtin) -> bool { + return builtin == BuiltInPosition || builtin == BuiltInPointSize || + builtin == BuiltInClipDistance || builtin == BuiltInCullDistance; + }; + + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + auto &type = this->get(var.basetype); + bool block = has_decoration(type.self, DecorationBlock); + Bitset builtins; + + if (var.storage == storage && block && is_builtin_variable(var)) + { + uint32_t index = 0; + for (auto &m : ir.meta[type.self].members) + { + if (m.builtin && builtin_is_per_vertex_set(m.builtin_type)) + { + builtins.set(m.builtin_type); + if (m.builtin_type == BuiltInCullDistance) + cull_distance_size = to_array_size_literal(this->get(type.member_types[index])); + else if (m.builtin_type == BuiltInClipDistance) + clip_distance_size = to_array_size_literal(this->get(type.member_types[index])); + + if (is_block_builtin(m.builtin_type) && m.decoration_flags.get(DecorationOffset)) + { + have_any_xfb_offset = true; + builtin_xfb_offsets[m.builtin_type] = m.offset; + } + + if (is_block_builtin(m.builtin_type) && m.decoration_flags.get(DecorationStream)) + { + uint32_t stream = m.stream; + if (have_geom_stream && geom_stream != stream) + SPIRV_CROSS_THROW("IO block member Stream mismatch."); + have_geom_stream = true; + geom_stream = stream; + } + } + index++; + } + + if (storage == StorageClassOutput && has_decoration(var.self, DecorationXfbBuffer) && + has_decoration(var.self, DecorationXfbStride)) + { + uint32_t buffer_index = get_decoration(var.self, DecorationXfbBuffer); + uint32_t stride = get_decoration(var.self, DecorationXfbStride); + if (have_xfb_buffer_stride && buffer_index != xfb_buffer) + SPIRV_CROSS_THROW("IO block member XfbBuffer mismatch."); + if (have_xfb_buffer_stride && stride != xfb_stride) + SPIRV_CROSS_THROW("IO block member XfbBuffer mismatch."); + have_xfb_buffer_stride = true; + xfb_buffer = buffer_index; + xfb_stride = stride; + } + + if (storage == StorageClassOutput && has_decoration(var.self, DecorationStream)) + { + uint32_t stream = get_decoration(var.self, DecorationStream); + if (have_geom_stream && geom_stream != stream) + SPIRV_CROSS_THROW("IO block member Stream mismatch."); + have_geom_stream = true; + geom_stream = stream; + } + } + else if (var.storage == storage && !block && is_builtin_variable(var)) + { + // While we're at it, collect all declared global builtins (HLSL mostly ...). + auto &m = ir.meta[var.self].decoration; + if (m.builtin && builtin_is_per_vertex_set(m.builtin_type)) + { + // For mesh/tesc output, Clip/Cull is an array-of-array. Look at innermost array type + // for correct result. 
+ global_builtins.set(m.builtin_type); + if (m.builtin_type == BuiltInCullDistance) + cull_distance_size = to_array_size_literal(type, 0); + else if (m.builtin_type == BuiltInClipDistance) + clip_distance_size = to_array_size_literal(type, 0); + + if (is_block_builtin(m.builtin_type) && m.decoration_flags.get(DecorationXfbStride) && + m.decoration_flags.get(DecorationXfbBuffer) && m.decoration_flags.get(DecorationOffset)) + { + have_any_xfb_offset = true; + builtin_xfb_offsets[m.builtin_type] = m.offset; + uint32_t buffer_index = m.xfb_buffer; + uint32_t stride = m.xfb_stride; + if (have_xfb_buffer_stride && buffer_index != xfb_buffer) + SPIRV_CROSS_THROW("IO block member XfbBuffer mismatch."); + if (have_xfb_buffer_stride && stride != xfb_stride) + SPIRV_CROSS_THROW("IO block member XfbBuffer mismatch."); + have_xfb_buffer_stride = true; + xfb_buffer = buffer_index; + xfb_stride = stride; + } + + if (is_block_builtin(m.builtin_type) && m.decoration_flags.get(DecorationStream)) + { + uint32_t stream = get_decoration(var.self, DecorationStream); + if (have_geom_stream && geom_stream != stream) + SPIRV_CROSS_THROW("IO block member Stream mismatch."); + have_geom_stream = true; + geom_stream = stream; + } + } + } + + if (builtins.empty()) + return; + + if (emitted_block) + SPIRV_CROSS_THROW("Cannot use more than one builtin I/O block."); + + emitted_builtins = builtins; + emitted_block = true; + block_var = &var; + }); + + global_builtins = + Bitset(global_builtins.get_lower() & ((1ull << BuiltInPosition) | (1ull << BuiltInPointSize) | + (1ull << BuiltInClipDistance) | (1ull << BuiltInCullDistance))); + + // Try to collect all other declared builtins. + if (!emitted_block) + emitted_builtins = global_builtins; + + // Can't declare an empty interface block. + if (emitted_builtins.empty()) + return; + + if (storage == StorageClassOutput) + { + SmallVector attr; + if (have_xfb_buffer_stride && have_any_xfb_offset) + { + if (!options.es) + { + if (options.version < 440 && options.version >= 140) + require_extension_internal("GL_ARB_enhanced_layouts"); + else if (options.version < 140) + SPIRV_CROSS_THROW("Component decoration is not supported in targets below GLSL 1.40."); + if (!options.es && options.version < 440) + require_extension_internal("GL_ARB_enhanced_layouts"); + } + else if (options.es) + SPIRV_CROSS_THROW("Need GL_ARB_enhanced_layouts for xfb_stride or xfb_buffer."); + attr.push_back(join("xfb_buffer = ", xfb_buffer, ", xfb_stride = ", xfb_stride)); + } + + if (have_geom_stream) + { + if (get_execution_model() != ExecutionModelGeometry) + SPIRV_CROSS_THROW("Geometry streams can only be used in geometry shaders."); + if (options.es) + SPIRV_CROSS_THROW("Multiple geometry streams not supported in ESSL."); + if (options.version < 400) + require_extension_internal("GL_ARB_transform_feedback3"); + attr.push_back(join("stream = ", geom_stream)); + } + + if (model == ExecutionModelMeshEXT) + statement("out gl_MeshPerVertexEXT"); + else if (!attr.empty()) + statement("layout(", merge(attr), ") out gl_PerVertex"); + else + statement("out gl_PerVertex"); + } + else + { + // If we have passthrough, there is no way PerVertex cannot be passthrough. 
+ if (get_entry_point().geometry_passthrough) + statement("layout(passthrough) in gl_PerVertex"); + else + statement("in gl_PerVertex"); + } + + begin_scope(); + if (emitted_builtins.get(BuiltInPosition)) + { + auto itr = builtin_xfb_offsets.find(BuiltInPosition); + if (itr != end(builtin_xfb_offsets)) + statement("layout(xfb_offset = ", itr->second, ") vec4 gl_Position;"); + else if (position_invariant) + statement("invariant vec4 gl_Position;"); + else + statement("vec4 gl_Position;"); + } + + if (emitted_builtins.get(BuiltInPointSize)) + { + auto itr = builtin_xfb_offsets.find(BuiltInPointSize); + if (itr != end(builtin_xfb_offsets)) + statement("layout(xfb_offset = ", itr->second, ") float gl_PointSize;"); + else + statement("float gl_PointSize;"); + } + + if (emitted_builtins.get(BuiltInClipDistance)) + { + auto itr = builtin_xfb_offsets.find(BuiltInClipDistance); + if (itr != end(builtin_xfb_offsets)) + statement("layout(xfb_offset = ", itr->second, ") float gl_ClipDistance[", clip_distance_size, "];"); + else + statement("float gl_ClipDistance[", clip_distance_size, "];"); + } + + if (emitted_builtins.get(BuiltInCullDistance)) + { + auto itr = builtin_xfb_offsets.find(BuiltInCullDistance); + if (itr != end(builtin_xfb_offsets)) + statement("layout(xfb_offset = ", itr->second, ") float gl_CullDistance[", cull_distance_size, "];"); + else + statement("float gl_CullDistance[", cull_distance_size, "];"); + } + + bool builtin_array = model == ExecutionModelTessellationControl || + (model == ExecutionModelMeshEXT && storage == StorageClassOutput) || + (model == ExecutionModelGeometry && storage == StorageClassInput) || + (model == ExecutionModelTessellationEvaluation && storage == StorageClassInput); + + if (builtin_array) + { + const char *instance_name; + if (model == ExecutionModelMeshEXT) + instance_name = "gl_MeshVerticesEXT"; // Per primitive is never synthesized. + else + instance_name = storage == StorageClassInput ? "gl_in" : "gl_out"; + + if (model == ExecutionModelTessellationControl && storage == StorageClassOutput) + end_scope_decl(join(instance_name, "[", get_entry_point().output_vertices, "]")); + else + end_scope_decl(join(instance_name, "[]")); + } + else + end_scope_decl(); + statement(""); +} + +bool CompilerGLSL::variable_is_lut(const SPIRVariable &var) const +{ + bool statically_assigned = var.statically_assigned && var.static_expression != ID(0) && var.remapped_variable; + + if (statically_assigned) + { + auto *constant = maybe_get(var.static_expression); + if (constant && constant->is_used_as_lut) + return true; + } + + return false; +} + +void CompilerGLSL::emit_resources() +{ + auto &execution = get_entry_point(); + + replace_illegal_names(); + + // Legacy GL uses gl_FragData[], redeclare all fragment outputs + // with builtins. + if (execution.model == ExecutionModelFragment && is_legacy()) + replace_fragment_outputs(); + + // Emit PLS blocks if we have such variables. + if (!pls_inputs.empty() || !pls_outputs.empty()) + emit_pls(); + + switch (execution.model) + { + case ExecutionModelGeometry: + case ExecutionModelTessellationControl: + case ExecutionModelTessellationEvaluation: + case ExecutionModelMeshEXT: + fixup_implicit_builtin_block_names(execution.model); + break; + + default: + break; + } + + bool global_invariant_position = position_invariant && (options.es || options.version >= 120); + + // Emit custom gl_PerVertex for SSO compatibility. 
+ if (options.separate_shader_objects && !options.es && execution.model != ExecutionModelFragment) + { + switch (execution.model) + { + case ExecutionModelGeometry: + case ExecutionModelTessellationControl: + case ExecutionModelTessellationEvaluation: + emit_declared_builtin_block(StorageClassInput, execution.model); + emit_declared_builtin_block(StorageClassOutput, execution.model); + global_invariant_position = false; + break; + + case ExecutionModelVertex: + case ExecutionModelMeshEXT: + emit_declared_builtin_block(StorageClassOutput, execution.model); + global_invariant_position = false; + break; + + default: + break; + } + } + else if (should_force_emit_builtin_block(StorageClassOutput)) + { + emit_declared_builtin_block(StorageClassOutput, execution.model); + global_invariant_position = false; + } + else if (execution.geometry_passthrough) + { + // Need to declare gl_in with Passthrough. + // If we're doing passthrough, we cannot emit an output block, so the output block test above will never pass. + emit_declared_builtin_block(StorageClassInput, execution.model); + } + else + { + // Need to redeclare clip/cull distance with explicit size to use them. + // SPIR-V mandates these builtins have a size declared. + const char *storage = execution.model == ExecutionModelFragment ? "in" : "out"; + if (clip_distance_count != 0) + statement(storage, " float gl_ClipDistance[", clip_distance_count, "];"); + if (cull_distance_count != 0) + statement(storage, " float gl_CullDistance[", cull_distance_count, "];"); + if (clip_distance_count != 0 || cull_distance_count != 0) + statement(""); + } + + if (global_invariant_position) + { + statement("invariant gl_Position;"); + statement(""); + } + + bool emitted = false; + + // If emitted Vulkan GLSL, + // emit specialization constants as actual floats, + // spec op expressions will redirect to the constant name. + // + { + auto loop_lock = ir.create_loop_hard_lock(); + for (auto &id_ : ir.ids_for_constant_undef_or_type) + { + auto &id = ir.ids[id_]; + + // Skip declaring any bogus constants or undefs which use block types. + // We don't declare block types directly, so this will never work. + // Should not be legal SPIR-V, so this is considered a workaround. + + if (id.get_type() == TypeConstant) + { + auto &c = id.get(); + + bool needs_declaration = c.specialization || c.is_used_as_lut; + + if (needs_declaration) + { + if (!options.vulkan_semantics && c.specialization) + { + c.specialization_constant_macro_name = + constant_value_macro_name(get_decoration(c.self, DecorationSpecId)); + } + emit_constant(c); + emitted = true; + } + } + else if (id.get_type() == TypeConstantOp) + { + emit_specialization_constant_op(id.get()); + emitted = true; + } + else if (id.get_type() == TypeType) + { + auto *type = &id.get(); + + bool is_natural_struct = type->basetype == SPIRType::Struct && type->array.empty() && !type->pointer && + (!has_decoration(type->self, DecorationBlock) && + !has_decoration(type->self, DecorationBufferBlock)); + + // Special case, ray payload and hit attribute blocks are not really blocks, just regular structs. 
+ if (type->basetype == SPIRType::Struct && type->pointer && + has_decoration(type->self, DecorationBlock) && + (type->storage == StorageClassRayPayloadKHR || type->storage == StorageClassIncomingRayPayloadKHR || + type->storage == StorageClassHitAttributeKHR)) + { + type = &get(type->parent_type); + is_natural_struct = true; + } + + if (is_natural_struct) + { + if (emitted) + statement(""); + emitted = false; + + emit_struct(*type); + } + } + else if (id.get_type() == TypeUndef) + { + auto &undef = id.get(); + auto &type = this->get(undef.basetype); + // OpUndef can be void for some reason ... + if (type.basetype == SPIRType::Void) + return; + + // This will break. It is bogus and should not be legal. + if (type_is_top_level_block(type)) + return; + + string initializer; + if (options.force_zero_initialized_variables && type_can_zero_initialize(type)) + initializer = join(" = ", to_zero_initialized_expression(undef.basetype)); + + // FIXME: If used in a constant, we must declare it as one. + statement(variable_decl(type, to_name(undef.self), undef.self), initializer, ";"); + emitted = true; + } + } + } + + if (emitted) + statement(""); + + // If we needed to declare work group size late, check here. + // If the work group size depends on a specialization constant, we need to declare the layout() block + // after constants (and their macros) have been declared. + if (execution.model == ExecutionModelGLCompute && !options.vulkan_semantics && + (execution.workgroup_size.constant != 0 || execution.flags.get(ExecutionModeLocalSizeId))) + { + SpecializationConstant wg_x, wg_y, wg_z; + get_work_group_size_specialization_constants(wg_x, wg_y, wg_z); + + if ((wg_x.id != ConstantID(0)) || (wg_y.id != ConstantID(0)) || (wg_z.id != ConstantID(0))) + { + SmallVector inputs; + build_workgroup_size(inputs, wg_x, wg_y, wg_z); + statement("layout(", merge(inputs), ") in;"); + statement(""); + } + } + + emitted = false; + + if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT) + { + // Output buffer reference blocks. + // Do this in two stages, one with forward declaration, + // and one without. Buffer reference blocks can reference themselves + // to support things like linked lists. + ir.for_each_typed_id([&](uint32_t id, SPIRType &type) { + if (is_physical_pointer(type)) + { + bool emit_type = true; + if (!is_physical_pointer_to_buffer_block(type)) + { + // Only forward-declare if we intend to emit it in the non_block_pointer types. + // Otherwise, these are just "benign" pointer types that exist as a result of access chains. 
+ emit_type = std::find(physical_storage_non_block_pointer_types.begin(), + physical_storage_non_block_pointer_types.end(), + id) != physical_storage_non_block_pointer_types.end(); + } + + if (emit_type) + emit_buffer_reference_block(id, true); + } + }); + + for (auto type : physical_storage_non_block_pointer_types) + emit_buffer_reference_block(type, false); + + ir.for_each_typed_id([&](uint32_t id, SPIRType &type) { + if (is_physical_pointer_to_buffer_block(type)) + emit_buffer_reference_block(id, false); + }); + } + + // Output UBOs and SSBOs + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + auto &type = this->get(var.basetype); + + bool is_block_storage = type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform || + type.storage == StorageClassShaderRecordBufferKHR; + bool has_block_flags = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock) || + ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock); + + if (var.storage != StorageClassFunction && type.pointer && is_block_storage && !is_hidden_variable(var) && + has_block_flags) + { + emit_buffer_block(var); + } + }); + + // Output push constant blocks + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + auto &type = this->get(var.basetype); + if (var.storage != StorageClassFunction && type.pointer && type.storage == StorageClassPushConstant && + !is_hidden_variable(var)) + { + emit_push_constant_block(var); + } + }); + + bool skip_separate_image_sampler = !combined_image_samplers.empty() || !options.vulkan_semantics; + + // Output Uniform Constants (values, samplers, images, etc). + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + auto &type = this->get(var.basetype); + + // If we're remapping separate samplers and images, only emit the combined samplers. + if (skip_separate_image_sampler) + { + // Sampler buffers are always used without a sampler, and they will also work in regular GL. + bool sampler_buffer = type.basetype == SPIRType::Image && type.image.dim == DimBuffer; + bool separate_image = type.basetype == SPIRType::Image && type.image.sampled == 1; + bool separate_sampler = type.basetype == SPIRType::Sampler; + if (!sampler_buffer && (separate_image || separate_sampler)) + return; + } + + if (var.storage != StorageClassFunction && type.pointer && + (type.storage == StorageClassUniformConstant || type.storage == StorageClassAtomicCounter || + type.storage == StorageClassRayPayloadKHR || type.storage == StorageClassIncomingRayPayloadKHR || + type.storage == StorageClassCallableDataKHR || type.storage == StorageClassIncomingCallableDataKHR || + type.storage == StorageClassHitAttributeKHR) && + !is_hidden_variable(var)) + { + emit_uniform(var); + emitted = true; + } + }); + + if (emitted) + statement(""); + emitted = false; + + bool emitted_base_instance = false; + + // Output in/out interfaces. + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + auto &type = this->get(var.basetype); + + bool is_hidden = is_hidden_variable(var); + + // Unused output I/O variables might still be required to implement framebuffer fetch. 
+ if (var.storage == StorageClassOutput && !is_legacy() && + location_is_framebuffer_fetch(get_decoration(var.self, DecorationLocation)) != 0) + { + is_hidden = false; + } + + if (var.storage != StorageClassFunction && type.pointer && + (var.storage == StorageClassInput || var.storage == StorageClassOutput) && + interface_variable_exists_in_entry_point(var.self) && !is_hidden) + { + if (options.es && get_execution_model() == ExecutionModelVertex && var.storage == StorageClassInput && + type.array.size() == 1) + { + SPIRV_CROSS_THROW("OpenGL ES doesn't support array input variables in vertex shader."); + } + emit_interface_block(var); + emitted = true; + } + else if (is_builtin_variable(var)) + { + auto builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn)); + // For gl_InstanceIndex emulation on GLES, the API user needs to + // supply this uniform. + + // The draw parameter extension is soft-enabled on GL with some fallbacks. + if (!options.vulkan_semantics) + { + if (!emitted_base_instance && + ((options.vertex.support_nonzero_base_instance && builtin == BuiltInInstanceIndex) || + (builtin == BuiltInBaseInstance))) + { + statement("#ifdef GL_ARB_shader_draw_parameters"); + statement("#define SPIRV_Cross_BaseInstance gl_BaseInstanceARB"); + statement("#else"); + // A crude, but simple workaround which should be good enough for non-indirect draws. + statement("uniform int SPIRV_Cross_BaseInstance;"); + statement("#endif"); + emitted = true; + emitted_base_instance = true; + } + else if (builtin == BuiltInBaseVertex) + { + statement("#ifdef GL_ARB_shader_draw_parameters"); + statement("#define SPIRV_Cross_BaseVertex gl_BaseVertexARB"); + statement("#else"); + // A crude, but simple workaround which should be good enough for non-indirect draws. + statement("uniform int SPIRV_Cross_BaseVertex;"); + statement("#endif"); + } + else if (builtin == BuiltInDrawIndex) + { + statement("#ifndef GL_ARB_shader_draw_parameters"); + // Cannot really be worked around. + statement("#error GL_ARB_shader_draw_parameters is not supported."); + statement("#endif"); + } + } + } + }); + + // Global variables. + for (auto global : global_variables) + { + auto &var = get(global); + if (is_hidden_variable(var, true)) + continue; + + if (var.storage != StorageClassOutput) + { + if (!variable_is_lut(var)) + { + add_resource_name(var.self); + + string initializer; + if (options.force_zero_initialized_variables && var.storage == StorageClassPrivate && + !var.initializer && !var.static_expression && type_can_zero_initialize(get_variable_data_type(var))) + { + initializer = join(" = ", to_zero_initialized_expression(get_variable_data_type_id(var))); + } + + statement(variable_decl(var), initializer, ";"); + emitted = true; + } + } + else if (var.initializer && maybe_get(var.initializer) != nullptr) + { + emit_output_variable_initializer(var); + } + } + + if (emitted) + statement(""); +} + +void CompilerGLSL::emit_output_variable_initializer(const SPIRVariable &var) +{ + // If a StorageClassOutput variable has an initializer, we need to initialize it in main(). 
+ auto &entry_func = this->get(ir.default_entry_point); + auto &type = get(var.basetype); + bool is_patch = has_decoration(var.self, DecorationPatch); + bool is_block = has_decoration(type.self, DecorationBlock); + bool is_control_point = get_execution_model() == ExecutionModelTessellationControl && !is_patch; + + if (is_block) + { + uint32_t member_count = uint32_t(type.member_types.size()); + bool type_is_array = type.array.size() == 1; + uint32_t array_size = 1; + if (type_is_array) + array_size = to_array_size_literal(type); + uint32_t iteration_count = is_control_point ? 1 : array_size; + + // If the initializer is a block, we must initialize each block member one at a time. + for (uint32_t i = 0; i < member_count; i++) + { + // These outputs might not have been properly declared, so don't initialize them in that case. + if (has_member_decoration(type.self, i, DecorationBuiltIn)) + { + if (get_member_decoration(type.self, i, DecorationBuiltIn) == BuiltInCullDistance && + !cull_distance_count) + continue; + + if (get_member_decoration(type.self, i, DecorationBuiltIn) == BuiltInClipDistance && + !clip_distance_count) + continue; + } + + // We need to build a per-member array first, essentially transposing from AoS to SoA. + // This code path hits when we have an array of blocks. + string lut_name; + if (type_is_array) + { + lut_name = join("_", var.self, "_", i, "_init"); + uint32_t member_type_id = get(var.basetype).member_types[i]; + auto &member_type = get(member_type_id); + auto array_type = member_type; + array_type.parent_type = member_type_id; + array_type.op = OpTypeArray; + array_type.array.push_back(array_size); + array_type.array_size_literal.push_back(true); + + SmallVector exprs; + exprs.reserve(array_size); + auto &c = get(var.initializer); + for (uint32_t j = 0; j < array_size; j++) + exprs.push_back(to_expression(get(c.subconstants[j]).subconstants[i])); + statement("const ", type_to_glsl(array_type), " ", lut_name, type_to_array_glsl(array_type, 0), " = ", + type_to_glsl_constructor(array_type), "(", merge(exprs, ", "), ");"); + } + + for (uint32_t j = 0; j < iteration_count; j++) + { + entry_func.fixup_hooks_in.push_back([=, &var]() { + AccessChainMeta meta; + auto &c = this->get(var.initializer); + + uint32_t invocation_id = 0; + uint32_t member_index_id = 0; + if (is_control_point) + { + uint32_t ids = ir.increase_bound_by(3); + auto &uint_type = set(ids, OpTypeInt); + uint_type.basetype = SPIRType::UInt; + uint_type.width = 32; + set(ids + 1, builtin_to_glsl(BuiltInInvocationId, StorageClassInput), ids, true); + set(ids + 2, ids, i, false); + invocation_id = ids + 1; + member_index_id = ids + 2; + } + + if (is_patch) + { + statement("if (gl_InvocationID == 0)"); + begin_scope(); + } + + if (type_is_array && !is_control_point) + { + uint32_t indices[2] = { j, i }; + auto chain = access_chain_internal(var.self, indices, 2, ACCESS_CHAIN_INDEX_IS_LITERAL_BIT, &meta); + statement(chain, " = ", lut_name, "[", j, "];"); + } + else if (is_control_point) + { + uint32_t indices[2] = { invocation_id, member_index_id }; + auto chain = access_chain_internal(var.self, indices, 2, 0, &meta); + statement(chain, " = ", lut_name, "[", builtin_to_glsl(BuiltInInvocationId, StorageClassInput), "];"); + } + else + { + auto chain = + access_chain_internal(var.self, &i, 1, ACCESS_CHAIN_INDEX_IS_LITERAL_BIT, &meta); + statement(chain, " = ", to_expression(c.subconstants[i]), ";"); + } + + if (is_patch) + end_scope(); + }); + } + } + } + else if (is_control_point) + { + auto lut_name = 
join("_", var.self, "_init"); + statement("const ", type_to_glsl(type), " ", lut_name, type_to_array_glsl(type, 0), + " = ", to_expression(var.initializer), ";"); + entry_func.fixup_hooks_in.push_back([&, lut_name]() { + statement(to_expression(var.self), "[gl_InvocationID] = ", lut_name, "[gl_InvocationID];"); + }); + } + else if (has_decoration(var.self, DecorationBuiltIn) && + BuiltIn(get_decoration(var.self, DecorationBuiltIn)) == BuiltInSampleMask) + { + // We cannot copy the array since gl_SampleMask is unsized in GLSL. Unroll time! <_< + entry_func.fixup_hooks_in.push_back([&] { + auto &c = this->get(var.initializer); + uint32_t num_constants = uint32_t(c.subconstants.size()); + for (uint32_t i = 0; i < num_constants; i++) + { + // Don't use to_expression on constant since it might be uint, just fish out the raw int. + statement(to_expression(var.self), "[", i, "] = ", + convert_to_string(this->get(c.subconstants[i]).scalar_i32()), ";"); + } + }); + } + else + { + auto lut_name = join("_", var.self, "_init"); + statement("const ", type_to_glsl(type), " ", lut_name, + type_to_array_glsl(type, var.self), " = ", to_expression(var.initializer), ";"); + entry_func.fixup_hooks_in.push_back([&, lut_name, is_patch]() { + if (is_patch) + { + statement("if (gl_InvocationID == 0)"); + begin_scope(); + } + statement(to_expression(var.self), " = ", lut_name, ";"); + if (is_patch) + end_scope(); + }); + } +} + +void CompilerGLSL::emit_subgroup_arithmetic_workaround(const std::string &func, Op op, GroupOperation group_op) +{ + std::string result; + switch (group_op) + { + case GroupOperationReduce: + result = "reduction"; + break; + + case GroupOperationExclusiveScan: + result = "excl_scan"; + break; + + case GroupOperationInclusiveScan: + result = "incl_scan"; + break; + + default: + SPIRV_CROSS_THROW("Unsupported workaround for arithmetic group operation"); + } + + struct TypeInfo + { + std::string type; + std::string identity; + }; + + std::vector type_infos; + switch (op) + { + case OpGroupNonUniformIAdd: + { + type_infos.emplace_back(TypeInfo{ "uint", "0u" }); + type_infos.emplace_back(TypeInfo{ "uvec2", "uvec2(0u)" }); + type_infos.emplace_back(TypeInfo{ "uvec3", "uvec3(0u)" }); + type_infos.emplace_back(TypeInfo{ "uvec4", "uvec4(0u)" }); + type_infos.emplace_back(TypeInfo{ "int", "0" }); + type_infos.emplace_back(TypeInfo{ "ivec2", "ivec2(0)" }); + type_infos.emplace_back(TypeInfo{ "ivec3", "ivec3(0)" }); + type_infos.emplace_back(TypeInfo{ "ivec4", "ivec4(0)" }); + break; + } + + case OpGroupNonUniformFAdd: + { + type_infos.emplace_back(TypeInfo{ "float", "0.0f" }); + type_infos.emplace_back(TypeInfo{ "vec2", "vec2(0.0f)" }); + type_infos.emplace_back(TypeInfo{ "vec3", "vec3(0.0f)" }); + type_infos.emplace_back(TypeInfo{ "vec4", "vec4(0.0f)" }); + // ARB_gpu_shader_fp64 is required in GL4.0 which in turn is required by NV_thread_shuffle + type_infos.emplace_back(TypeInfo{ "double", "0.0LF" }); + type_infos.emplace_back(TypeInfo{ "dvec2", "dvec2(0.0LF)" }); + type_infos.emplace_back(TypeInfo{ "dvec3", "dvec3(0.0LF)" }); + type_infos.emplace_back(TypeInfo{ "dvec4", "dvec4(0.0LF)" }); + break; + } + + case OpGroupNonUniformIMul: + { + type_infos.emplace_back(TypeInfo{ "uint", "1u" }); + type_infos.emplace_back(TypeInfo{ "uvec2", "uvec2(1u)" }); + type_infos.emplace_back(TypeInfo{ "uvec3", "uvec3(1u)" }); + type_infos.emplace_back(TypeInfo{ "uvec4", "uvec4(1u)" }); + type_infos.emplace_back(TypeInfo{ "int", "1" }); + type_infos.emplace_back(TypeInfo{ "ivec2", "ivec2(1)" }); + 
type_infos.emplace_back(TypeInfo{ "ivec3", "ivec3(1)" }); + type_infos.emplace_back(TypeInfo{ "ivec4", "ivec4(1)" }); + break; + } + + case OpGroupNonUniformFMul: + { + type_infos.emplace_back(TypeInfo{ "float", "1.0f" }); + type_infos.emplace_back(TypeInfo{ "vec2", "vec2(1.0f)" }); + type_infos.emplace_back(TypeInfo{ "vec3", "vec3(1.0f)" }); + type_infos.emplace_back(TypeInfo{ "vec4", "vec4(1.0f)" }); + type_infos.emplace_back(TypeInfo{ "double", "0.0LF" }); + type_infos.emplace_back(TypeInfo{ "dvec2", "dvec2(1.0LF)" }); + type_infos.emplace_back(TypeInfo{ "dvec3", "dvec3(1.0LF)" }); + type_infos.emplace_back(TypeInfo{ "dvec4", "dvec4(1.0LF)" }); + break; + } + + default: + SPIRV_CROSS_THROW("Unsupported workaround for arithmetic group operation"); + } + + const bool op_is_addition = op == OpGroupNonUniformIAdd || op == OpGroupNonUniformFAdd; + const bool op_is_multiplication = op == OpGroupNonUniformIMul || op == OpGroupNonUniformFMul; + std::string op_symbol; + if (op_is_addition) + { + op_symbol = "+="; + } + else if (op_is_multiplication) + { + op_symbol = "*="; + } + + for (const TypeInfo &t : type_infos) + { + statement(t.type, " ", func, "(", t.type, " v)"); + begin_scope(); + statement(t.type, " ", result, " = ", t.identity, ";"); + statement("uvec4 active_threads = subgroupBallot(true);"); + statement("if (subgroupBallotBitCount(active_threads) == gl_SubgroupSize)"); + begin_scope(); + statement("uint total = gl_SubgroupSize / 2u;"); + statement(result, " = v;"); + statement("for (uint i = 1u; i <= total; i <<= 1u)"); + begin_scope(); + statement("bool valid;"); + if (group_op == GroupOperationReduce) + { + statement(t.type, " s = shuffleXorNV(", result, ", i, gl_SubgroupSize, valid);"); + } + else if (group_op == GroupOperationExclusiveScan || group_op == GroupOperationInclusiveScan) + { + statement(t.type, " s = shuffleUpNV(", result, ", i, gl_SubgroupSize, valid);"); + } + if (op_is_addition || op_is_multiplication) + { + statement(result, " ", op_symbol, " valid ? s : ", t.identity, ";"); + } + end_scope(); + if (group_op == GroupOperationExclusiveScan) + { + statement(result, " = shuffleUpNV(", result, ", 1u, gl_SubgroupSize);"); + statement("if (subgroupElect())"); + begin_scope(); + statement(result, " = ", t.identity, ";"); + end_scope(); + } + end_scope(); + statement("else"); + begin_scope(); + if (group_op == GroupOperationExclusiveScan) + { + statement("uint total = subgroupBallotBitCount(gl_SubgroupLtMask);"); + } + else if (group_op == GroupOperationInclusiveScan) + { + statement("uint total = subgroupBallotBitCount(gl_SubgroupLeMask);"); + } + statement("for (uint i = 0u; i < gl_SubgroupSize; ++i)"); + begin_scope(); + statement("bool valid = subgroupBallotBitExtract(active_threads, i);"); + statement(t.type, " s = shuffleNV(v, i, gl_SubgroupSize);"); + if (group_op == GroupOperationExclusiveScan || group_op == GroupOperationInclusiveScan) + { + statement("valid = valid && (i < total);"); + } + if (op_is_addition || op_is_multiplication) + { + statement(result, " ", op_symbol, " valid ? 
s : ", t.identity, ";"); + } + end_scope(); + end_scope(); + statement("return ", result, ";"); + end_scope(); + } +} + +void CompilerGLSL::emit_extension_workarounds(spv::ExecutionModel model) +{ + static const char *workaround_types[] = { "int", "ivec2", "ivec3", "ivec4", "uint", "uvec2", "uvec3", "uvec4", + "float", "vec2", "vec3", "vec4", "double", "dvec2", "dvec3", "dvec4" }; + + if (!options.vulkan_semantics) + { + using Supp = ShaderSubgroupSupportHelper; + auto result = shader_subgroup_supporter.resolve(); + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupMask)) + { + auto exts = Supp::get_candidates_for_feature(Supp::SubgroupMask, result); + + for (auto &e : exts) + { + const char *name = Supp::get_extension_name(e); + statement(&e == &exts.front() ? "#if" : "#elif", " defined(", name, ")"); + + switch (e) + { + case Supp::NV_shader_thread_group: + statement("#define gl_SubgroupEqMask uvec4(gl_ThreadEqMaskNV, 0u, 0u, 0u)"); + statement("#define gl_SubgroupGeMask uvec4(gl_ThreadGeMaskNV, 0u, 0u, 0u)"); + statement("#define gl_SubgroupGtMask uvec4(gl_ThreadGtMaskNV, 0u, 0u, 0u)"); + statement("#define gl_SubgroupLeMask uvec4(gl_ThreadLeMaskNV, 0u, 0u, 0u)"); + statement("#define gl_SubgroupLtMask uvec4(gl_ThreadLtMaskNV, 0u, 0u, 0u)"); + break; + case Supp::ARB_shader_ballot: + statement("#define gl_SubgroupEqMask uvec4(unpackUint2x32(gl_SubGroupEqMaskARB), 0u, 0u)"); + statement("#define gl_SubgroupGeMask uvec4(unpackUint2x32(gl_SubGroupGeMaskARB), 0u, 0u)"); + statement("#define gl_SubgroupGtMask uvec4(unpackUint2x32(gl_SubGroupGtMaskARB), 0u, 0u)"); + statement("#define gl_SubgroupLeMask uvec4(unpackUint2x32(gl_SubGroupLeMaskARB), 0u, 0u)"); + statement("#define gl_SubgroupLtMask uvec4(unpackUint2x32(gl_SubGroupLtMaskARB), 0u, 0u)"); + break; + default: + break; + } + } + statement("#endif"); + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupSize)) + { + auto exts = Supp::get_candidates_for_feature(Supp::SubgroupSize, result); + + for (auto &e : exts) + { + const char *name = Supp::get_extension_name(e); + statement(&e == &exts.front() ? "#if" : "#elif", " defined(", name, ")"); + + switch (e) + { + case Supp::NV_shader_thread_group: + statement("#define gl_SubgroupSize gl_WarpSizeNV"); + break; + case Supp::ARB_shader_ballot: + statement("#define gl_SubgroupSize gl_SubGroupSizeARB"); + break; + case Supp::AMD_gcn_shader: + statement("#define gl_SubgroupSize uint(gl_SIMDGroupSizeAMD)"); + break; + default: + break; + } + } + statement("#endif"); + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupInvocationID)) + { + auto exts = Supp::get_candidates_for_feature(Supp::SubgroupInvocationID, result); + + for (auto &e : exts) + { + const char *name = Supp::get_extension_name(e); + statement(&e == &exts.front() ? "#if" : "#elif", " defined(", name, ")"); + + switch (e) + { + case Supp::NV_shader_thread_group: + statement("#define gl_SubgroupInvocationID gl_ThreadInWarpNV"); + break; + case Supp::ARB_shader_ballot: + statement("#define gl_SubgroupInvocationID gl_SubGroupInvocationARB"); + break; + default: + break; + } + } + statement("#endif"); + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupID)) + { + auto exts = Supp::get_candidates_for_feature(Supp::SubgroupID, result); + + for (auto &e : exts) + { + const char *name = Supp::get_extension_name(e); + statement(&e == &exts.front() ? 
"#if" : "#elif", " defined(", name, ")"); + + switch (e) + { + case Supp::NV_shader_thread_group: + statement("#define gl_SubgroupID gl_WarpIDNV"); + break; + default: + break; + } + } + statement("#endif"); + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::NumSubgroups)) + { + auto exts = Supp::get_candidates_for_feature(Supp::NumSubgroups, result); + + for (auto &e : exts) + { + const char *name = Supp::get_extension_name(e); + statement(&e == &exts.front() ? "#if" : "#elif", " defined(", name, ")"); + + switch (e) + { + case Supp::NV_shader_thread_group: + statement("#define gl_NumSubgroups gl_WarpsPerSMNV"); + break; + default: + break; + } + } + statement("#endif"); + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupBroadcast_First)) + { + auto exts = Supp::get_candidates_for_feature(Supp::SubgroupBroadcast_First, result); + + for (auto &e : exts) + { + const char *name = Supp::get_extension_name(e); + statement(&e == &exts.front() ? "#if" : "#elif", " defined(", name, ")"); + + switch (e) + { + case Supp::NV_shader_thread_shuffle: + for (const char *t : workaround_types) + { + statement(t, " subgroupBroadcastFirst(", t, + " value) { return shuffleNV(value, findLSB(ballotThreadNV(true)), gl_WarpSizeNV); }"); + } + for (const char *t : workaround_types) + { + statement(t, " subgroupBroadcast(", t, + " value, uint id) { return shuffleNV(value, id, gl_WarpSizeNV); }"); + } + break; + case Supp::ARB_shader_ballot: + for (const char *t : workaround_types) + { + statement(t, " subgroupBroadcastFirst(", t, + " value) { return readFirstInvocationARB(value); }"); + } + for (const char *t : workaround_types) + { + statement(t, " subgroupBroadcast(", t, + " value, uint id) { return readInvocationARB(value, id); }"); + } + break; + default: + break; + } + } + statement("#endif"); + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupBallotFindLSB_MSB)) + { + auto exts = Supp::get_candidates_for_feature(Supp::SubgroupBallotFindLSB_MSB, result); + + for (auto &e : exts) + { + const char *name = Supp::get_extension_name(e); + statement(&e == &exts.front() ? "#if" : "#elif", " defined(", name, ")"); + + switch (e) + { + case Supp::NV_shader_thread_group: + statement("uint subgroupBallotFindLSB(uvec4 value) { return findLSB(value.x); }"); + statement("uint subgroupBallotFindMSB(uvec4 value) { return findMSB(value.x); }"); + break; + default: + break; + } + } + statement("#else"); + statement("uint subgroupBallotFindLSB(uvec4 value)"); + begin_scope(); + statement("int firstLive = findLSB(value.x);"); + statement("return uint(firstLive != -1 ? firstLive : (findLSB(value.y) + 32));"); + end_scope(); + statement("uint subgroupBallotFindMSB(uvec4 value)"); + begin_scope(); + statement("int firstLive = findMSB(value.y);"); + statement("return uint(firstLive != -1 ? (firstLive + 32) : findMSB(value.x));"); + end_scope(); + statement("#endif"); + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupAll_Any_AllEqualBool)) + { + auto exts = Supp::get_candidates_for_feature(Supp::SubgroupAll_Any_AllEqualBool, result); + + for (auto &e : exts) + { + const char *name = Supp::get_extension_name(e); + statement(&e == &exts.front() ? 
"#if" : "#elif", " defined(", name, ")"); + + switch (e) + { + case Supp::NV_gpu_shader_5: + statement("bool subgroupAll(bool value) { return allThreadsNV(value); }"); + statement("bool subgroupAny(bool value) { return anyThreadNV(value); }"); + statement("bool subgroupAllEqual(bool value) { return allThreadsEqualNV(value); }"); + break; + case Supp::ARB_shader_group_vote: + statement("bool subgroupAll(bool v) { return allInvocationsARB(v); }"); + statement("bool subgroupAny(bool v) { return anyInvocationARB(v); }"); + statement("bool subgroupAllEqual(bool v) { return allInvocationsEqualARB(v); }"); + break; + case Supp::AMD_gcn_shader: + statement("bool subgroupAll(bool value) { return ballotAMD(value) == ballotAMD(true); }"); + statement("bool subgroupAny(bool value) { return ballotAMD(value) != 0ull; }"); + statement("bool subgroupAllEqual(bool value) { uint64_t b = ballotAMD(value); return b == 0ull || " + "b == ballotAMD(true); }"); + break; + default: + break; + } + } + statement("#endif"); + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupAllEqualT)) + { + statement("#ifndef GL_KHR_shader_subgroup_vote"); + statement( + "#define _SPIRV_CROSS_SUBGROUP_ALL_EQUAL_WORKAROUND(type) bool subgroupAllEqual(type value) { return " + "subgroupAllEqual(subgroupBroadcastFirst(value) == value); }"); + for (const char *t : workaround_types) + statement("_SPIRV_CROSS_SUBGROUP_ALL_EQUAL_WORKAROUND(", t, ")"); + statement("#undef _SPIRV_CROSS_SUBGROUP_ALL_EQUAL_WORKAROUND"); + statement("#endif"); + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupBallot)) + { + auto exts = Supp::get_candidates_for_feature(Supp::SubgroupBallot, result); + + for (auto &e : exts) + { + const char *name = Supp::get_extension_name(e); + statement(&e == &exts.front() ? "#if" : "#elif", " defined(", name, ")"); + + switch (e) + { + case Supp::NV_shader_thread_group: + statement("uvec4 subgroupBallot(bool v) { return uvec4(ballotThreadNV(v), 0u, 0u, 0u); }"); + break; + case Supp::ARB_shader_ballot: + statement("uvec4 subgroupBallot(bool v) { return uvec4(unpackUint2x32(ballotARB(v)), 0u, 0u); }"); + break; + default: + break; + } + } + statement("#endif"); + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupElect)) + { + statement("#ifndef GL_KHR_shader_subgroup_basic"); + statement("bool subgroupElect()"); + begin_scope(); + statement("uvec4 activeMask = subgroupBallot(true);"); + statement("uint firstLive = subgroupBallotFindLSB(activeMask);"); + statement("return gl_SubgroupInvocationID == firstLive;"); + end_scope(); + statement("#endif"); + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupBarrier)) + { + // Extensions we're using in place of GL_KHR_shader_subgroup_basic state + // that subgroup execute in lockstep so this barrier is implicit. + // However the GL 4.6 spec also states that `barrier` implies a shared memory barrier, + // and a specific test of optimizing scans by leveraging lock-step invocation execution, + // has shown that a `memoryBarrierShared` is needed in place of a `subgroupBarrier`. 
+ // https://github.com/buildaworldnet/IrrlichtBAW/commit/d8536857991b89a30a6b65d29441e51b64c2c7ad#diff-9f898d27be1ea6fc79b03d9b361e299334c1a347b6e4dc344ee66110c6aa596aR19 + statement("#ifndef GL_KHR_shader_subgroup_basic"); + statement("void subgroupBarrier() { memoryBarrierShared(); }"); + statement("#endif"); + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupMemBarrier)) + { + if (model == spv::ExecutionModelGLCompute) + { + statement("#ifndef GL_KHR_shader_subgroup_basic"); + statement("void subgroupMemoryBarrier() { groupMemoryBarrier(); }"); + statement("void subgroupMemoryBarrierBuffer() { groupMemoryBarrier(); }"); + statement("void subgroupMemoryBarrierShared() { memoryBarrierShared(); }"); + statement("void subgroupMemoryBarrierImage() { groupMemoryBarrier(); }"); + statement("#endif"); + } + else + { + statement("#ifndef GL_KHR_shader_subgroup_basic"); + statement("void subgroupMemoryBarrier() { memoryBarrier(); }"); + statement("void subgroupMemoryBarrierBuffer() { memoryBarrierBuffer(); }"); + statement("void subgroupMemoryBarrierImage() { memoryBarrierImage(); }"); + statement("#endif"); + } + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupInverseBallot_InclBitCount_ExclBitCout)) + { + statement("#ifndef GL_KHR_shader_subgroup_ballot"); + statement("bool subgroupInverseBallot(uvec4 value)"); + begin_scope(); + statement("return any(notEqual(value.xy & gl_SubgroupEqMask.xy, uvec2(0u)));"); + end_scope(); + + statement("uint subgroupBallotInclusiveBitCount(uvec4 value)"); + begin_scope(); + statement("uvec2 v = value.xy & gl_SubgroupLeMask.xy;"); + statement("ivec2 c = bitCount(v);"); + statement_no_indent("#ifdef GL_NV_shader_thread_group"); + statement("return uint(c.x);"); + statement_no_indent("#else"); + statement("return uint(c.x + c.y);"); + statement_no_indent("#endif"); + end_scope(); + + statement("uint subgroupBallotExclusiveBitCount(uvec4 value)"); + begin_scope(); + statement("uvec2 v = value.xy & gl_SubgroupLtMask.xy;"); + statement("ivec2 c = bitCount(v);"); + statement_no_indent("#ifdef GL_NV_shader_thread_group"); + statement("return uint(c.x);"); + statement_no_indent("#else"); + statement("return uint(c.x + c.y);"); + statement_no_indent("#endif"); + end_scope(); + statement("#endif"); + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupBallotBitCount)) + { + statement("#ifndef GL_KHR_shader_subgroup_ballot"); + statement("uint subgroupBallotBitCount(uvec4 value)"); + begin_scope(); + statement("ivec2 c = bitCount(value.xy);"); + statement_no_indent("#ifdef GL_NV_shader_thread_group"); + statement("return uint(c.x);"); + statement_no_indent("#else"); + statement("return uint(c.x + c.y);"); + statement_no_indent("#endif"); + end_scope(); + statement("#endif"); + statement(""); + } + + if (shader_subgroup_supporter.is_feature_requested(Supp::SubgroupBallotBitExtract)) + { + statement("#ifndef GL_KHR_shader_subgroup_ballot"); + statement("bool subgroupBallotBitExtract(uvec4 value, uint index)"); + begin_scope(); + statement_no_indent("#ifdef GL_NV_shader_thread_group"); + statement("uint shifted = value.x >> index;"); + statement_no_indent("#else"); + statement("uint shifted = value[index >> 5u] >> (index & 0x1fu);"); + statement_no_indent("#endif"); + statement("return (shifted & 1u) != 0u;"); + end_scope(); + statement("#endif"); + statement(""); + } + + auto arithmetic_feature_helper = + [&](Supp::Feature feat, std::string func_name, spv::Op 
op, spv::GroupOperation group_op) + { + if (shader_subgroup_supporter.is_feature_requested(feat)) + { + auto exts = Supp::get_candidates_for_feature(feat, result); + for (auto &e : exts) + { + const char *name = Supp::get_extension_name(e); + statement(&e == &exts.front() ? "#if" : "#elif", " defined(", name, ")"); + + switch (e) + { + case Supp::NV_shader_thread_shuffle: + emit_subgroup_arithmetic_workaround(func_name, op, group_op); + break; + default: + break; + } + } + statement("#endif"); + statement(""); + } + }; + + arithmetic_feature_helper(Supp::SubgroupArithmeticIAddReduce, "subgroupAdd", OpGroupNonUniformIAdd, + GroupOperationReduce); + arithmetic_feature_helper(Supp::SubgroupArithmeticIAddExclusiveScan, "subgroupExclusiveAdd", + OpGroupNonUniformIAdd, GroupOperationExclusiveScan); + arithmetic_feature_helper(Supp::SubgroupArithmeticIAddInclusiveScan, "subgroupInclusiveAdd", + OpGroupNonUniformIAdd, GroupOperationInclusiveScan); + arithmetic_feature_helper(Supp::SubgroupArithmeticFAddReduce, "subgroupAdd", OpGroupNonUniformFAdd, + GroupOperationReduce); + arithmetic_feature_helper(Supp::SubgroupArithmeticFAddExclusiveScan, "subgroupExclusiveAdd", + OpGroupNonUniformFAdd, GroupOperationExclusiveScan); + arithmetic_feature_helper(Supp::SubgroupArithmeticFAddInclusiveScan, "subgroupInclusiveAdd", + OpGroupNonUniformFAdd, GroupOperationInclusiveScan); + + arithmetic_feature_helper(Supp::SubgroupArithmeticIMulReduce, "subgroupMul", OpGroupNonUniformIMul, + GroupOperationReduce); + arithmetic_feature_helper(Supp::SubgroupArithmeticIMulExclusiveScan, "subgroupExclusiveMul", + OpGroupNonUniformIMul, GroupOperationExclusiveScan); + arithmetic_feature_helper(Supp::SubgroupArithmeticIMulInclusiveScan, "subgroupInclusiveMul", + OpGroupNonUniformIMul, GroupOperationInclusiveScan); + arithmetic_feature_helper(Supp::SubgroupArithmeticFMulReduce, "subgroupMul", OpGroupNonUniformFMul, + GroupOperationReduce); + arithmetic_feature_helper(Supp::SubgroupArithmeticFMulExclusiveScan, "subgroupExclusiveMul", + OpGroupNonUniformFMul, GroupOperationExclusiveScan); + arithmetic_feature_helper(Supp::SubgroupArithmeticFMulInclusiveScan, "subgroupInclusiveMul", + OpGroupNonUniformFMul, GroupOperationInclusiveScan); + } + + if (!workaround_ubo_load_overload_types.empty()) + { + for (auto &type_id : workaround_ubo_load_overload_types) + { + auto &type = get(type_id); + + if (options.es && is_matrix(type)) + { + // Need both variants. + // GLSL cannot overload on precision, so need to dispatch appropriately. + statement("highp ", type_to_glsl(type), " spvWorkaroundRowMajor(highp ", type_to_glsl(type), " wrap) { return wrap; }"); + statement("mediump ", type_to_glsl(type), " spvWorkaroundRowMajorMP(mediump ", type_to_glsl(type), " wrap) { return wrap; }"); + } + else + { + statement(type_to_glsl(type), " spvWorkaroundRowMajor(", type_to_glsl(type), " wrap) { return wrap; }"); + } + } + statement(""); + } +} + +void CompilerGLSL::emit_polyfills(uint32_t polyfills, bool relaxed) +{ + const char *qual = ""; + const char *suffix = (options.es && relaxed) ? "MP" : ""; + if (options.es) + qual = relaxed ? 
"mediump " : "highp "; + + if (polyfills & PolyfillTranspose2x2) + { + statement(qual, "mat2 spvTranspose", suffix, "(", qual, "mat2 m)"); + begin_scope(); + statement("return mat2(m[0][0], m[1][0], m[0][1], m[1][1]);"); + end_scope(); + statement(""); + } + + if (polyfills & PolyfillTranspose3x3) + { + statement(qual, "mat3 spvTranspose", suffix, "(", qual, "mat3 m)"); + begin_scope(); + statement("return mat3(m[0][0], m[1][0], m[2][0], m[0][1], m[1][1], m[2][1], m[0][2], m[1][2], m[2][2]);"); + end_scope(); + statement(""); + } + + if (polyfills & PolyfillTranspose4x4) + { + statement(qual, "mat4 spvTranspose", suffix, "(", qual, "mat4 m)"); + begin_scope(); + statement("return mat4(m[0][0], m[1][0], m[2][0], m[3][0], m[0][1], m[1][1], m[2][1], m[3][1], m[0][2], " + "m[1][2], m[2][2], m[3][2], m[0][3], m[1][3], m[2][3], m[3][3]);"); + end_scope(); + statement(""); + } + + if (polyfills & PolyfillDeterminant2x2) + { + statement(qual, "float spvDeterminant", suffix, "(", qual, "mat2 m)"); + begin_scope(); + statement("return m[0][0] * m[1][1] - m[0][1] * m[1][0];"); + end_scope(); + statement(""); + } + + if (polyfills & PolyfillDeterminant3x3) + { + statement(qual, "float spvDeterminant", suffix, "(", qual, "mat3 m)"); + begin_scope(); + statement("return dot(m[0], vec3(m[1][1] * m[2][2] - m[1][2] * m[2][1], " + "m[1][2] * m[2][0] - m[1][0] * m[2][2], " + "m[1][0] * m[2][1] - m[1][1] * m[2][0]));"); + end_scope(); + statement(""); + } + + if (polyfills & PolyfillDeterminant4x4) + { + statement(qual, "float spvDeterminant", suffix, "(", qual, "mat4 m)"); + begin_scope(); + statement("return dot(m[0], vec4(" + "m[2][1] * m[3][2] * m[1][3] - m[3][1] * m[2][2] * m[1][3] + m[3][1] * m[1][2] * m[2][3] - m[1][1] * m[3][2] * m[2][3] - m[2][1] * m[1][2] * m[3][3] + m[1][1] * m[2][2] * m[3][3], " + "m[3][0] * m[2][2] * m[1][3] - m[2][0] * m[3][2] * m[1][3] - m[3][0] * m[1][2] * m[2][3] + m[1][0] * m[3][2] * m[2][3] + m[2][0] * m[1][2] * m[3][3] - m[1][0] * m[2][2] * m[3][3], " + "m[2][0] * m[3][1] * m[1][3] - m[3][0] * m[2][1] * m[1][3] + m[3][0] * m[1][1] * m[2][3] - m[1][0] * m[3][1] * m[2][3] - m[2][0] * m[1][1] * m[3][3] + m[1][0] * m[2][1] * m[3][3], " + "m[3][0] * m[2][1] * m[1][2] - m[2][0] * m[3][1] * m[1][2] - m[3][0] * m[1][1] * m[2][2] + m[1][0] * m[3][1] * m[2][2] + m[2][0] * m[1][1] * m[3][2] - m[1][0] * m[2][1] * m[3][2]));"); + end_scope(); + statement(""); + } + + if (polyfills & PolyfillMatrixInverse2x2) + { + statement(qual, "mat2 spvInverse", suffix, "(", qual, "mat2 m)"); + begin_scope(); + statement("return mat2(m[1][1], -m[0][1], -m[1][0], m[0][0]) " + "* (1.0 / (m[0][0] * m[1][1] - m[1][0] * m[0][1]));"); + end_scope(); + statement(""); + } + + if (polyfills & PolyfillMatrixInverse3x3) + { + statement(qual, "mat3 spvInverse", suffix, "(", qual, "mat3 m)"); + begin_scope(); + statement(qual, "vec3 t = vec3(m[1][1] * m[2][2] - m[1][2] * m[2][1], m[1][2] * m[2][0] - m[1][0] * m[2][2], m[1][0] * m[2][1] - m[1][1] * m[2][0]);"); + statement("return mat3(t[0], " + "m[0][2] * m[2][1] - m[0][1] * m[2][2], " + "m[0][1] * m[1][2] - m[0][2] * m[1][1], " + "t[1], " + "m[0][0] * m[2][2] - m[0][2] * m[2][0], " + "m[0][2] * m[1][0] - m[0][0] * m[1][2], " + "t[2], " + "m[0][1] * m[2][0] - m[0][0] * m[2][1], " + "m[0][0] * m[1][1] - m[0][1] * m[1][0]) " + "* (1.0 / dot(m[0], t));"); + end_scope(); + statement(""); + } + + if (polyfills & PolyfillMatrixInverse4x4) + { + statement(qual, "mat4 spvInverse", suffix, "(", qual, "mat4 m)"); + begin_scope(); + statement(qual, "vec4 t = vec4(" + 
"m[2][1] * m[3][2] * m[1][3] - m[3][1] * m[2][2] * m[1][3] + m[3][1] * m[1][2] * m[2][3] - m[1][1] * m[3][2] * m[2][3] - m[2][1] * m[1][2] * m[3][3] + m[1][1] * m[2][2] * m[3][3], " + "m[3][0] * m[2][2] * m[1][3] - m[2][0] * m[3][2] * m[1][3] - m[3][0] * m[1][2] * m[2][3] + m[1][0] * m[3][2] * m[2][3] + m[2][0] * m[1][2] * m[3][3] - m[1][0] * m[2][2] * m[3][3], " + "m[2][0] * m[3][1] * m[1][3] - m[3][0] * m[2][1] * m[1][3] + m[3][0] * m[1][1] * m[2][3] - m[1][0] * m[3][1] * m[2][3] - m[2][0] * m[1][1] * m[3][3] + m[1][0] * m[2][1] * m[3][3], " + "m[3][0] * m[2][1] * m[1][2] - m[2][0] * m[3][1] * m[1][2] - m[3][0] * m[1][1] * m[2][2] + m[1][0] * m[3][1] * m[2][2] + m[2][0] * m[1][1] * m[3][2] - m[1][0] * m[2][1] * m[3][2]);"); + statement("return mat4(" + "t[0], " + "m[3][1] * m[2][2] * m[0][3] - m[2][1] * m[3][2] * m[0][3] - m[3][1] * m[0][2] * m[2][3] + m[0][1] * m[3][2] * m[2][3] + m[2][1] * m[0][2] * m[3][3] - m[0][1] * m[2][2] * m[3][3], " + "m[1][1] * m[3][2] * m[0][3] - m[3][1] * m[1][2] * m[0][3] + m[3][1] * m[0][2] * m[1][3] - m[0][1] * m[3][2] * m[1][3] - m[1][1] * m[0][2] * m[3][3] + m[0][1] * m[1][2] * m[3][3], " + "m[2][1] * m[1][2] * m[0][3] - m[1][1] * m[2][2] * m[0][3] - m[2][1] * m[0][2] * m[1][3] + m[0][1] * m[2][2] * m[1][3] + m[1][1] * m[0][2] * m[2][3] - m[0][1] * m[1][2] * m[2][3], " + "t[1], " + "m[2][0] * m[3][2] * m[0][3] - m[3][0] * m[2][2] * m[0][3] + m[3][0] * m[0][2] * m[2][3] - m[0][0] * m[3][2] * m[2][3] - m[2][0] * m[0][2] * m[3][3] + m[0][0] * m[2][2] * m[3][3], " + "m[3][0] * m[1][2] * m[0][3] - m[1][0] * m[3][2] * m[0][3] - m[3][0] * m[0][2] * m[1][3] + m[0][0] * m[3][2] * m[1][3] + m[1][0] * m[0][2] * m[3][3] - m[0][0] * m[1][2] * m[3][3], " + "m[1][0] * m[2][2] * m[0][3] - m[2][0] * m[1][2] * m[0][3] + m[2][0] * m[0][2] * m[1][3] - m[0][0] * m[2][2] * m[1][3] - m[1][0] * m[0][2] * m[2][3] + m[0][0] * m[1][2] * m[2][3], " + "t[2], " + "m[3][0] * m[2][1] * m[0][3] - m[2][0] * m[3][1] * m[0][3] - m[3][0] * m[0][1] * m[2][3] + m[0][0] * m[3][1] * m[2][3] + m[2][0] * m[0][1] * m[3][3] - m[0][0] * m[2][1] * m[3][3], " + "m[1][0] * m[3][1] * m[0][3] - m[3][0] * m[1][1] * m[0][3] + m[3][0] * m[0][1] * m[1][3] - m[0][0] * m[3][1] * m[1][3] - m[1][0] * m[0][1] * m[3][3] + m[0][0] * m[1][1] * m[3][3], " + "m[2][0] * m[1][1] * m[0][3] - m[1][0] * m[2][1] * m[0][3] - m[2][0] * m[0][1] * m[1][3] + m[0][0] * m[2][1] * m[1][3] + m[1][0] * m[0][1] * m[2][3] - m[0][0] * m[1][1] * m[2][3], " + "t[3], " + "m[2][0] * m[3][1] * m[0][2] - m[3][0] * m[2][1] * m[0][2] + m[3][0] * m[0][1] * m[2][2] - m[0][0] * m[3][1] * m[2][2] - m[2][0] * m[0][1] * m[3][2] + m[0][0] * m[2][1] * m[3][2], " + "m[3][0] * m[1][1] * m[0][2] - m[1][0] * m[3][1] * m[0][2] - m[3][0] * m[0][1] * m[1][2] + m[0][0] * m[3][1] * m[1][2] + m[1][0] * m[0][1] * m[3][2] - m[0][0] * m[1][1] * m[3][2], " + "m[1][0] * m[2][1] * m[0][2] - m[2][0] * m[1][1] * m[0][2] + m[2][0] * m[0][1] * m[1][2] - m[0][0] * m[2][1] * m[1][2] - m[1][0] * m[0][1] * m[2][2] + m[0][0] * m[1][1] * m[2][2]) " + "* (1.0 / dot(m[0], t));"); + end_scope(); + statement(""); + } + + if (!relaxed) + { + static const Polyfill polys[3][3] = { + { PolyfillNMin16, PolyfillNMin32, PolyfillNMin64 }, + { PolyfillNMax16, PolyfillNMax32, PolyfillNMax64 }, + { PolyfillNClamp16, PolyfillNClamp32, PolyfillNClamp64 }, + }; + + static const GLSLstd450 glsl_ops[] = { GLSLstd450NMin, GLSLstd450NMax, GLSLstd450NClamp }; + static const char *spv_ops[] = { "spvNMin", "spvNMax", "spvNClamp" }; + bool has_poly = false; + + for (uint32_t i = 0; i < 3; i++) + { + for 
(uint32_t j = 0; j < 3; j++) + { + if ((polyfills & polys[i][j]) == 0) + continue; + + const char *types[3][4] = { + { "float16_t", "f16vec2", "f16vec3", "f16vec4" }, + { "float", "vec2", "vec3", "vec4" }, + { "double", "dvec2", "dvec3", "dvec4" }, + }; + + for (uint32_t k = 0; k < 4; k++) + { + auto *type = types[j][k]; + + if (i < 2) + { + statement("spirv_instruction(set = \"GLSL.std.450\", id = ", glsl_ops[i], ") ", + type, " ", spv_ops[i], "(", type, ", ", type, ");"); + } + else + { + statement("spirv_instruction(set = \"GLSL.std.450\", id = ", glsl_ops[i], ") ", + type, " ", spv_ops[i], "(", type, ", ", type, ", ", type, ");"); + } + + has_poly = true; + } + } + } + + if (has_poly) + statement(""); + } + else + { + // Mediump intrinsics don't work correctly, so wrap the intrinsic in an outer shell that ensures mediump + // propagation. + + static const Polyfill polys[3][3] = { + { PolyfillNMin16, PolyfillNMin32, PolyfillNMin64 }, + { PolyfillNMax16, PolyfillNMax32, PolyfillNMax64 }, + { PolyfillNClamp16, PolyfillNClamp32, PolyfillNClamp64 }, + }; + + static const char *spv_ops[] = { "spvNMin", "spvNMax", "spvNClamp" }; + + for (uint32_t i = 0; i < 3; i++) + { + for (uint32_t j = 0; j < 3; j++) + { + if ((polyfills & polys[i][j]) == 0) + continue; + + const char *types[3][4] = { + { "float16_t", "f16vec2", "f16vec3", "f16vec4" }, + { "float", "vec2", "vec3", "vec4" }, + { "double", "dvec2", "dvec3", "dvec4" }, + }; + + for (uint32_t k = 0; k < 4; k++) + { + auto *type = types[j][k]; + + if (i < 2) + { + statement("mediump ", type, " ", spv_ops[i], "Relaxed(", + "mediump ", type, " a, mediump ", type, " b)"); + begin_scope(); + statement("mediump ", type, " res = ", spv_ops[i], "(a, b);"); + statement("return res;"); + end_scope(); + statement(""); + } + else + { + statement("mediump ", type, " ", spv_ops[i], "Relaxed(", + "mediump ", type, " a, mediump ", type, " b, mediump ", type, " c)"); + begin_scope(); + statement("mediump ", type, " res = ", spv_ops[i], "(a, b, c);"); + statement("return res;"); + end_scope(); + statement(""); + } + } + } + } + } +} + +// Returns a string representation of the ID, usable as a function arg. +// Default is to simply return the expression representation fo the arg ID. +// Subclasses may override to modify the return value. +string CompilerGLSL::to_func_call_arg(const SPIRFunction::Parameter &, uint32_t id) +{ + // Make sure that we use the name of the original variable, and not the parameter alias. + uint32_t name_id = id; + auto *var = maybe_get(id); + if (var && var->basevariable) + name_id = var->basevariable; + return to_expression(name_id); +} + +void CompilerGLSL::force_temporary_and_recompile(uint32_t id) +{ + auto res = forced_temporaries.insert(id); + + // Forcing new temporaries guarantees forward progress. + if (res.second) + force_recompile_guarantee_forward_progress(); + else + force_recompile(); +} + +uint32_t CompilerGLSL::consume_temporary_in_precision_context(uint32_t type_id, uint32_t id, Options::Precision precision) +{ + // Constants do not have innate precision. + auto handle_type = ir.ids[id].get_type(); + if (handle_type == TypeConstant || handle_type == TypeConstantOp || handle_type == TypeUndef) + return id; + + // Ignore anything that isn't 32-bit values. 
+ auto &type = get(type_id); + if (type.pointer) + return id; + if (type.basetype != SPIRType::Float && type.basetype != SPIRType::UInt && type.basetype != SPIRType::Int) + return id; + + if (precision == Options::DontCare) + { + // If precision is consumed as don't care (operations only consisting of constants), + // we need to bind the expression to a temporary, + // otherwise we have no way of controlling the precision later. + auto itr = forced_temporaries.insert(id); + if (itr.second) + force_recompile_guarantee_forward_progress(); + return id; + } + + auto current_precision = has_decoration(id, DecorationRelaxedPrecision) ? Options::Mediump : Options::Highp; + if (current_precision == precision) + return id; + + auto itr = temporary_to_mirror_precision_alias.find(id); + if (itr == temporary_to_mirror_precision_alias.end()) + { + uint32_t alias_id = ir.increase_bound_by(1); + auto &m = ir.meta[alias_id]; + if (auto *input_m = ir.find_meta(id)) + m = *input_m; + + const char *prefix; + if (precision == Options::Mediump) + { + set_decoration(alias_id, DecorationRelaxedPrecision); + prefix = "mp_copy_"; + } + else + { + unset_decoration(alias_id, DecorationRelaxedPrecision); + prefix = "hp_copy_"; + } + + auto alias_name = join(prefix, to_name(id)); + ParsedIR::sanitize_underscores(alias_name); + set_name(alias_id, alias_name); + + emit_op(type_id, alias_id, to_expression(id), true); + temporary_to_mirror_precision_alias[id] = alias_id; + forced_temporaries.insert(id); + forced_temporaries.insert(alias_id); + force_recompile_guarantee_forward_progress(); + id = alias_id; + } + else + { + id = itr->second; + } + + return id; +} + +void CompilerGLSL::handle_invalid_expression(uint32_t id) +{ + // We tried to read an invalidated expression. + // This means we need another pass at compilation, but next time, + // force temporary variables so that they cannot be invalidated. + force_temporary_and_recompile(id); + + // If the invalid expression happened as a result of a CompositeInsert + // overwrite, we must block this from happening next iteration. + if (composite_insert_overwritten.count(id)) + block_composite_insert_overwrite.insert(id); +} + +// Converts the format of the current expression from packed to unpacked, +// by wrapping the expression in a constructor of the appropriate type. +// GLSL does not support packed formats, so simply return the expression. +// Subclasses that do will override. +string CompilerGLSL::unpack_expression_type(string expr_str, const SPIRType &, uint32_t, bool, bool) +{ + return expr_str; +} + +// Sometimes we proactively enclosed an expression where it turns out we might have not needed it after all. +void CompilerGLSL::strip_enclosed_expression(string &expr) +{ + if (expr.size() < 2 || expr.front() != '(' || expr.back() != ')') + return; + + // Have to make sure that our first and last parens actually enclose everything inside it. + uint32_t paren_count = 0; + for (auto &c : expr) + { + if (c == '(') + paren_count++; + else if (c == ')') + { + paren_count--; + + // If we hit 0 and this is not the final char, our first and final parens actually don't + // enclose the expression, and we cannot strip, e.g.: (a + b) * (c + d). 
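The stripping rule spelled out in the comment above can be exercised in isolation. A minimal standalone sketch, with an illustrative function name rather than the class member used here:

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// True when the first '(' and the last ')' enclose the whole string, i.e. the
// parenthesis depth never returns to zero before the final character.
static bool outer_parens_enclose_everything(const std::string &expr)
{
	if (expr.size() < 2 || expr.front() != '(' || expr.back() != ')')
		return false;

	int depth = 0;
	for (std::size_t i = 0; i < expr.size(); i++)
	{
		if (expr[i] == '(')
			depth++;
		else if (expr[i] == ')')
		{
			depth--;
			if (depth == 0 && i + 1 != expr.size())
				return false; // e.g. "(a + b) * (c + d)" cannot be stripped.
		}
	}
	return true;
}

int main()
{
	assert(outer_parens_enclose_everything("(a + b)"));
	assert(!outer_parens_enclose_everything("(a + b) * (c + d)"));
}
```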
+ if (paren_count == 0 && &c != &expr.back()) + return; + } + } + expr.erase(expr.size() - 1, 1); + expr.erase(begin(expr)); +} + +bool CompilerGLSL::needs_enclose_expression(const std::string &expr) +{ + bool need_parens = false; + + // If the expression starts with a unary we need to enclose to deal with cases where we have back-to-back + // unary expressions. + if (!expr.empty()) + { + auto c = expr.front(); + if (c == '-' || c == '+' || c == '!' || c == '~' || c == '&' || c == '*') + need_parens = true; + } + + if (!need_parens) + { + uint32_t paren_count = 0; + for (auto c : expr) + { + if (c == '(' || c == '[') + paren_count++; + else if (c == ')' || c == ']') + { + assert(paren_count); + paren_count--; + } + else if (c == ' ' && paren_count == 0) + { + need_parens = true; + break; + } + } + assert(paren_count == 0); + } + + return need_parens; +} + +string CompilerGLSL::enclose_expression(const string &expr) +{ + // If this expression contains any spaces which are not enclosed by parentheses, + // we need to enclose it so we can treat the whole string as an expression. + // This happens when two expressions have been part of a binary op earlier. + if (needs_enclose_expression(expr)) + return join('(', expr, ')'); + else + return expr; +} + +string CompilerGLSL::dereference_expression(const SPIRType &expr_type, const std::string &expr) +{ + // If this expression starts with an address-of operator ('&'), then + // just return the part after the operator. + // TODO: Strip parens if unnecessary? + if (expr.front() == '&') + return expr.substr(1); + else if (backend.native_pointers) + return join('*', expr); + else if (is_physical_pointer(expr_type) && !is_physical_pointer_to_buffer_block(expr_type)) + return join(enclose_expression(expr), ".value"); + else + return expr; +} + +string CompilerGLSL::address_of_expression(const std::string &expr) +{ + if (expr.size() > 3 && expr[0] == '(' && expr[1] == '*' && expr.back() == ')') + { + // If we have an expression which looks like (*foo), taking the address of it is the same as stripping + // the first two and last characters. We might have to enclose the expression. + // This doesn't work for cases like (*foo + 10), + // but this is an r-value expression which we cannot take the address of anyways. + return enclose_expression(expr.substr(2, expr.size() - 3)); + } + else if (expr.front() == '*') + { + // If this expression starts with a dereference operator ('*'), then + // just return the part after the operator. + return expr.substr(1); + } + else + return join('&', enclose_expression(expr)); +} + +// Just like to_expression except that we enclose the expression inside parentheses if needed. +string CompilerGLSL::to_enclosed_expression(uint32_t id, bool register_expression_read) +{ + return enclose_expression(to_expression(id, register_expression_read)); +} + +// Used explicitly when we want to read a row-major expression, but without any transpose shenanigans. +// need_transpose must be forced to false. +string CompilerGLSL::to_unpacked_row_major_matrix_expression(uint32_t id) +{ + return unpack_expression_type(to_expression(id), expression_type(id), + get_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID), + has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked), true); +} + +string CompilerGLSL::to_unpacked_expression(uint32_t id, bool register_expression_read) +{ + // If we need to transpose, it will also take care of unpacking rules. 
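The address-of and dereference helpers above work purely on the expression text. A simplified sketch of the address-of case, using the assumed name address_of and skipping the re-enclosing step the real helper performs for compound expressions:

```cpp
#include <cassert>
#include <string>

// "(*foo)" and "*foo" lose the dereference; anything else gains a leading '&'.
static std::string address_of(const std::string &expr)
{
	if (expr.size() > 3 && expr[0] == '(' && expr[1] == '*' && expr.back() == ')')
		return expr.substr(2, expr.size() - 3);
	if (!expr.empty() && expr.front() == '*')
		return expr.substr(1);
	return "&" + expr;
}

int main()
{
	assert(address_of("(*foo)") == "foo");
	assert(address_of("*foo") == "foo");
	assert(address_of("bar") == "&bar");
}
```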
+ auto *e = maybe_get(id); + bool need_transpose = e && e->need_transpose; + bool is_remapped = has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID); + bool is_packed = has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked); + + if (!need_transpose && (is_remapped || is_packed)) + { + return unpack_expression_type(to_expression(id, register_expression_read), + get_pointee_type(expression_type_id(id)), + get_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID), + has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked), false); + } + else + return to_expression(id, register_expression_read); +} + +string CompilerGLSL::to_enclosed_unpacked_expression(uint32_t id, bool register_expression_read) +{ + return enclose_expression(to_unpacked_expression(id, register_expression_read)); +} + +string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expression_read) +{ + auto &type = expression_type(id); + + if (is_pointer(type) && should_dereference(id)) + return dereference_expression(type, to_enclosed_expression(id, register_expression_read)); + else + return to_expression(id, register_expression_read); +} + +string CompilerGLSL::to_pointer_expression(uint32_t id, bool register_expression_read) +{ + auto &type = expression_type(id); + if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id)) + return address_of_expression(to_enclosed_expression(id, register_expression_read)); + else + return to_unpacked_expression(id, register_expression_read); +} + +string CompilerGLSL::to_enclosed_pointer_expression(uint32_t id, bool register_expression_read) +{ + auto &type = expression_type(id); + if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id)) + return address_of_expression(to_enclosed_expression(id, register_expression_read)); + else + return to_enclosed_unpacked_expression(id, register_expression_read); +} + +string CompilerGLSL::to_extract_component_expression(uint32_t id, uint32_t index) +{ + auto expr = to_enclosed_expression(id); + if (has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked)) + return join(expr, "[", index, "]"); + else + return join(expr, ".", index_to_swizzle(index)); +} + +string CompilerGLSL::to_extract_constant_composite_expression(uint32_t result_type, const SPIRConstant &c, + const uint32_t *chain, uint32_t length) +{ + // It is kinda silly if application actually enter this path since they know the constant up front. + // It is useful here to extract the plain constant directly. 
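The component-extraction helper above picks between an array subscript and a swizzle. A small sketch under an assumed free-function name:

```cpp
#include <cassert>
#include <string>

// Physically packed vectors are indexed with [], everything else uses .xyzw.
static std::string extract_component(const std::string &expr, unsigned index, bool physically_packed)
{
	static const char swizzle[4] = { 'x', 'y', 'z', 'w' };
	if (physically_packed)
		return expr + "[" + std::to_string(index) + "]";
	return expr + "." + swizzle[index & 3u];
}

int main()
{
	assert(extract_component("v", 2, false) == "v.z");
	assert(extract_component("v", 2, true) == "v[2]");
}
```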
+ SPIRConstant tmp; + tmp.constant_type = result_type; + auto &composite_type = get(c.constant_type); + assert(composite_type.basetype != SPIRType::Struct && composite_type.array.empty()); + assert(!c.specialization); + + if (is_matrix(composite_type)) + { + if (length == 2) + { + tmp.m.c[0].vecsize = 1; + tmp.m.columns = 1; + tmp.m.c[0].r[0] = c.m.c[chain[0]].r[chain[1]]; + } + else + { + assert(length == 1); + tmp.m.c[0].vecsize = composite_type.vecsize; + tmp.m.columns = 1; + tmp.m.c[0] = c.m.c[chain[0]]; + } + } + else + { + assert(length == 1); + tmp.m.c[0].vecsize = 1; + tmp.m.columns = 1; + tmp.m.c[0].r[0] = c.m.c[0].r[chain[0]]; + } + + return constant_expression(tmp); +} + +string CompilerGLSL::to_rerolled_array_expression(const SPIRType &parent_type, + const string &base_expr, const SPIRType &type) +{ + bool remapped_boolean = parent_type.basetype == SPIRType::Struct && + type.basetype == SPIRType::Boolean && + backend.boolean_in_struct_remapped_type != SPIRType::Boolean; + + SPIRType tmp_type { OpNop }; + if (remapped_boolean) + { + tmp_type = get(type.parent_type); + tmp_type.basetype = backend.boolean_in_struct_remapped_type; + } + else if (type.basetype == SPIRType::Boolean && backend.boolean_in_struct_remapped_type != SPIRType::Boolean) + { + // It's possible that we have an r-value expression that was OpLoaded from a struct. + // We have to reroll this and explicitly cast the input to bool, because the r-value is short. + tmp_type = get(type.parent_type); + remapped_boolean = true; + } + + uint32_t size = to_array_size_literal(type); + auto &parent = get(type.parent_type); + string expr = "{ "; + + for (uint32_t i = 0; i < size; i++) + { + auto subexpr = join(base_expr, "[", convert_to_string(i), "]"); + if (!is_array(parent)) + { + if (remapped_boolean) + subexpr = join(type_to_glsl(tmp_type), "(", subexpr, ")"); + expr += subexpr; + } + else + expr += to_rerolled_array_expression(parent_type, subexpr, parent); + + if (i + 1 < size) + expr += ", "; + } + + expr += " }"; + return expr; +} + +string CompilerGLSL::to_composite_constructor_expression(const SPIRType &parent_type, uint32_t id, bool block_like_type) +{ + auto &type = expression_type(id); + + bool reroll_array = false; + bool remapped_boolean = parent_type.basetype == SPIRType::Struct && + type.basetype == SPIRType::Boolean && + backend.boolean_in_struct_remapped_type != SPIRType::Boolean; + + if (is_array(type)) + { + reroll_array = !backend.array_is_value_type || + (block_like_type && !backend.array_is_value_type_in_buffer_blocks); + + if (remapped_boolean) + { + // Forced to reroll if we have to change bool[] to short[]. + reroll_array = true; + } + } + + if (reroll_array) + { + // For this case, we need to "re-roll" an array initializer from a temporary. + // We cannot simply pass the array directly, since it decays to a pointer and it cannot + // participate in a struct initializer. E.g. + // float arr[2] = { 1.0, 2.0 }; + // Foo foo = { arr }; must be transformed to + // Foo foo = { { arr[0], arr[1] } }; + // The array sizes cannot be deduced from specialization constants since we cannot use any loops. + + // We're only triggering one read of the array expression, but this is fine since arrays have to be declared + // as temporaries anyways. 
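The re-rolling described in the comment above amounts to spelling out each element of the array expression. A one-dimensional sketch, ignoring the nested-array and remapped-boolean cases the real helper handles:

```cpp
#include <cassert>
#include <string>

// reroll("arr", 2) -> "{ arr[0], arr[1] }", usable inside a struct initializer.
static std::string reroll(const std::string &base_expr, unsigned size)
{
	std::string expr = "{ ";
	for (unsigned i = 0; i < size; i++)
	{
		expr += base_expr + "[" + std::to_string(i) + "]";
		if (i + 1 < size)
			expr += ", ";
	}
	return expr + " }";
}

int main()
{
	assert(reroll("arr", 2) == "{ arr[0], arr[1] }");
}
```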
+ return to_rerolled_array_expression(parent_type, to_enclosed_expression(id), type); + } + else + { + auto expr = to_unpacked_expression(id); + if (remapped_boolean) + { + auto tmp_type = type; + tmp_type.basetype = backend.boolean_in_struct_remapped_type; + expr = join(type_to_glsl(tmp_type), "(", expr, ")"); + } + + return expr; + } +} + +string CompilerGLSL::to_non_uniform_aware_expression(uint32_t id) +{ + string expr = to_expression(id); + + if (has_decoration(id, DecorationNonUniform)) + convert_non_uniform_expression(expr, id); + + return expr; +} + +string CompilerGLSL::to_expression(uint32_t id, bool register_expression_read) +{ + auto itr = invalid_expressions.find(id); + if (itr != end(invalid_expressions)) + handle_invalid_expression(id); + + if (ir.ids[id].get_type() == TypeExpression) + { + // We might have a more complex chain of dependencies. + // A possible scenario is that we + // + // %1 = OpLoad + // %2 = OpDoSomething %1 %1. here %2 will have a dependency on %1. + // %3 = OpDoSomethingAgain %2 %2. Here %3 will lose the link to %1 since we don't propagate the dependencies like that. + // OpStore %1 %foo // Here we can invalidate %1, and hence all expressions which depend on %1. Only %2 will know since it's part of invalid_expressions. + // %4 = OpDoSomethingAnotherTime %3 %3 // If we forward all expressions we will see %1 expression after store, not before. + // + // However, we can propagate up a list of depended expressions when we used %2, so we can check if %2 is invalid when reading %3 after the store, + // and see that we should not forward reads of the original variable. + auto &expr = get(id); + for (uint32_t dep : expr.expression_dependencies) + if (invalid_expressions.find(dep) != end(invalid_expressions)) + handle_invalid_expression(dep); + } + + if (register_expression_read) + track_expression_read(id); + + switch (ir.ids[id].get_type()) + { + case TypeExpression: + { + auto &e = get(id); + if (e.base_expression) + return to_enclosed_expression(e.base_expression) + e.expression; + else if (e.need_transpose) + { + // This should not be reached for access chains, since we always deal explicitly with transpose state + // when consuming an access chain expression. + uint32_t physical_type_id = get_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID); + bool is_packed = has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked); + bool relaxed = has_decoration(id, DecorationRelaxedPrecision); + return convert_row_major_matrix(e.expression, get(e.expression_type), physical_type_id, + is_packed, relaxed); + } + else if (flattened_structs.count(id)) + { + return load_flattened_struct(e.expression, get(e.expression_type)); + } + else + { + if (is_forcing_recompilation()) + { + // During first compilation phase, certain expression patterns can trigger exponential growth of memory. + // Avoid this by returning dummy expressions during this phase. + // Do not use empty expressions here, because those are sentinels for other cases. + return "_"; + } + else + return e.expression; + } + } + + case TypeConstant: + { + auto &c = get(id); + auto &type = get(c.constant_type); + + // WorkGroupSize may be a constant. 
+ if (has_decoration(c.self, DecorationBuiltIn)) + return builtin_to_glsl(BuiltIn(get_decoration(c.self, DecorationBuiltIn)), StorageClassGeneric); + else if (c.specialization) + { + if (backend.workgroup_size_is_hidden) + { + int wg_index = get_constant_mapping_to_workgroup_component(c); + if (wg_index >= 0) + { + auto wg_size = join(builtin_to_glsl(BuiltInWorkgroupSize, StorageClassInput), vector_swizzle(1, wg_index)); + if (type.basetype != SPIRType::UInt) + wg_size = bitcast_expression(type, SPIRType::UInt, wg_size); + return wg_size; + } + } + + if (expression_is_forwarded(id)) + return constant_expression(c); + + return to_name(id); + } + else if (c.is_used_as_lut) + return to_name(id); + else if (type.basetype == SPIRType::Struct && !backend.can_declare_struct_inline) + return to_name(id); + else if (!type.array.empty() && !backend.can_declare_arrays_inline) + return to_name(id); + else + return constant_expression(c); + } + + case TypeConstantOp: + return to_name(id); + + case TypeVariable: + { + auto &var = get(id); + // If we try to use a loop variable before the loop header, we have to redirect it to the static expression, + // the variable has not been declared yet. + if (var.statically_assigned || (var.loop_variable && !var.loop_variable_enable)) + { + // We might try to load from a loop variable before it has been initialized. + // Prefer static expression and fallback to initializer. + if (var.static_expression) + return to_expression(var.static_expression); + else if (var.initializer) + return to_expression(var.initializer); + else + { + // We cannot declare the variable yet, so have to fake it. + uint32_t undef_id = ir.increase_bound_by(1); + return emit_uninitialized_temporary_expression(get_variable_data_type_id(var), undef_id).expression; + } + } + else if (var.deferred_declaration) + { + var.deferred_declaration = false; + return variable_decl(var); + } + else if (flattened_structs.count(id)) + { + return load_flattened_struct(to_name(id), get(var.basetype)); + } + else + { + auto &dec = ir.meta[var.self].decoration; + if (dec.builtin) + return builtin_to_glsl(dec.builtin_type, var.storage); + else + return to_name(id); + } + } + + case TypeCombinedImageSampler: + // This type should never be taken the expression of directly. + // The intention is that texture sampling functions will extract the image and samplers + // separately and take their expressions as needed. + // GLSL does not use this type because OpSampledImage immediately creates a combined image sampler + // expression ala sampler2D(texture, sampler). + SPIRV_CROSS_THROW("Combined image samplers have no default expression representation."); + + case TypeAccessChain: + // We cannot express this type. They only have meaning in other OpAccessChains, OpStore or OpLoad. 
+ SPIRV_CROSS_THROW("Access chains have no default expression representation."); + + default: + return to_name(id); + } +} + +SmallVector CompilerGLSL::get_composite_constant_ids(ConstantID const_id) +{ + if (auto *constant = maybe_get(const_id)) + { + const auto &type = get(constant->constant_type); + if (is_array(type) || type.basetype == SPIRType::Struct) + return constant->subconstants; + if (is_matrix(type)) + return SmallVector(constant->m.id); + if (is_vector(type)) + return SmallVector(constant->m.c[0].id); + SPIRV_CROSS_THROW("Unexpected scalar constant!"); + } + if (!const_composite_insert_ids.count(const_id)) + SPIRV_CROSS_THROW("Unimplemented for this OpSpecConstantOp!"); + return const_composite_insert_ids[const_id]; +} + +void CompilerGLSL::fill_composite_constant(SPIRConstant &constant, TypeID type_id, + const SmallVector &initializers) +{ + auto &type = get(type_id); + constant.specialization = true; + if (is_array(type) || type.basetype == SPIRType::Struct) + { + constant.subconstants = initializers; + } + else if (is_matrix(type)) + { + constant.m.columns = type.columns; + for (uint32_t i = 0; i < type.columns; ++i) + { + constant.m.id[i] = initializers[i]; + constant.m.c[i].vecsize = type.vecsize; + } + } + else if (is_vector(type)) + { + constant.m.c[0].vecsize = type.vecsize; + for (uint32_t i = 0; i < type.vecsize; ++i) + constant.m.c[0].id[i] = initializers[i]; + } + else + SPIRV_CROSS_THROW("Unexpected scalar in SpecConstantOp CompositeInsert!"); +} + +void CompilerGLSL::set_composite_constant(ConstantID const_id, TypeID type_id, + const SmallVector &initializers) +{ + if (maybe_get(const_id)) + { + const_composite_insert_ids[const_id] = initializers; + return; + } + + auto &constant = set(const_id, type_id); + fill_composite_constant(constant, type_id, initializers); + forwarded_temporaries.insert(const_id); +} + +TypeID CompilerGLSL::get_composite_member_type(TypeID type_id, uint32_t member_idx) +{ + auto &type = get(type_id); + if (is_array(type)) + return type.parent_type; + if (type.basetype == SPIRType::Struct) + return type.member_types[member_idx]; + if (is_matrix(type)) + return type.parent_type; + if (is_vector(type)) + return type.parent_type; + SPIRV_CROSS_THROW("Shouldn't reach lower than vector handling OpSpecConstantOp CompositeInsert!"); +} + +string CompilerGLSL::constant_op_expression(const SPIRConstantOp &cop) +{ + auto &type = get(cop.basetype); + bool binary = false; + bool unary = false; + string op; + + if (is_legacy() && is_unsigned_opcode(cop.opcode)) + SPIRV_CROSS_THROW("Unsigned integers are not supported on legacy targets."); + + // TODO: Find a clean way to reuse emit_instruction. 
+ switch (cop.opcode) + { + case OpSConvert: + case OpUConvert: + case OpFConvert: + op = type_to_glsl_constructor(type); + break; + +#define GLSL_BOP(opname, x) \ + case Op##opname: \ + binary = true; \ + op = x; \ + break + +#define GLSL_UOP(opname, x) \ + case Op##opname: \ + unary = true; \ + op = x; \ + break + + GLSL_UOP(SNegate, "-"); + GLSL_UOP(Not, "~"); + GLSL_BOP(IAdd, "+"); + GLSL_BOP(ISub, "-"); + GLSL_BOP(IMul, "*"); + GLSL_BOP(SDiv, "/"); + GLSL_BOP(UDiv, "/"); + GLSL_BOP(UMod, "%"); + GLSL_BOP(SMod, "%"); + GLSL_BOP(ShiftRightLogical, ">>"); + GLSL_BOP(ShiftRightArithmetic, ">>"); + GLSL_BOP(ShiftLeftLogical, "<<"); + GLSL_BOP(BitwiseOr, "|"); + GLSL_BOP(BitwiseXor, "^"); + GLSL_BOP(BitwiseAnd, "&"); + GLSL_BOP(LogicalOr, "||"); + GLSL_BOP(LogicalAnd, "&&"); + GLSL_UOP(LogicalNot, "!"); + GLSL_BOP(LogicalEqual, "=="); + GLSL_BOP(LogicalNotEqual, "!="); + GLSL_BOP(IEqual, "=="); + GLSL_BOP(INotEqual, "!="); + GLSL_BOP(ULessThan, "<"); + GLSL_BOP(SLessThan, "<"); + GLSL_BOP(ULessThanEqual, "<="); + GLSL_BOP(SLessThanEqual, "<="); + GLSL_BOP(UGreaterThan, ">"); + GLSL_BOP(SGreaterThan, ">"); + GLSL_BOP(UGreaterThanEqual, ">="); + GLSL_BOP(SGreaterThanEqual, ">="); + + case OpSRem: + { + uint32_t op0 = cop.arguments[0]; + uint32_t op1 = cop.arguments[1]; + return join(to_enclosed_expression(op0), " - ", to_enclosed_expression(op1), " * ", "(", + to_enclosed_expression(op0), " / ", to_enclosed_expression(op1), ")"); + } + + case OpSelect: + { + if (cop.arguments.size() < 3) + SPIRV_CROSS_THROW("Not enough arguments to OpSpecConstantOp."); + + // This one is pretty annoying. It's triggered from + // uint(bool), int(bool) from spec constants. + // In order to preserve its compile-time constness in Vulkan GLSL, + // we need to reduce the OpSelect expression back to this simplified model. + // If we cannot, fail. + if (to_trivial_mix_op(type, op, cop.arguments[2], cop.arguments[1], cop.arguments[0])) + { + // Implement as a simple cast down below. + } + else + { + // Implement a ternary and pray the compiler understands it :) + return to_ternary_expression(type, cop.arguments[0], cop.arguments[1], cop.arguments[2]); + } + break; + } + + case OpVectorShuffle: + { + string expr = type_to_glsl_constructor(type); + expr += "("; + + uint32_t left_components = expression_type(cop.arguments[0]).vecsize; + string left_arg = to_enclosed_expression(cop.arguments[0]); + string right_arg = to_enclosed_expression(cop.arguments[1]); + + for (uint32_t i = 2; i < uint32_t(cop.arguments.size()); i++) + { + uint32_t index = cop.arguments[i]; + if (index == 0xFFFFFFFF) + { + SPIRConstant c; + c.constant_type = type.parent_type; + assert(type.parent_type != ID(0)); + expr += constant_expression(c); + } + else if (index >= left_components) + { + expr += right_arg + "." + "xyzw"[index - left_components]; + } + else + { + expr += left_arg + "." + "xyzw"[index]; + } + + if (i + 1 < uint32_t(cop.arguments.size())) + expr += ", "; + } + + expr += ")"; + return expr; + } + + case OpCompositeExtract: + { + auto expr = access_chain_internal(cop.arguments[0], &cop.arguments[1], uint32_t(cop.arguments.size() - 1), + ACCESS_CHAIN_INDEX_IS_LITERAL_BIT, nullptr); + return expr; + } + + case OpCompositeInsert: + { + SmallVector new_init = get_composite_constant_ids(cop.arguments[1]); + uint32_t idx; + uint32_t target_id = cop.self; + uint32_t target_type_id = cop.basetype; + // We have to drill down to the part we want to modify, and create new + // constants for each containing part. 
+ for (idx = 2; idx < cop.arguments.size() - 1; ++idx) + { + uint32_t new_const = ir.increase_bound_by(1); + uint32_t old_const = new_init[cop.arguments[idx]]; + new_init[cop.arguments[idx]] = new_const; + set_composite_constant(target_id, target_type_id, new_init); + new_init = get_composite_constant_ids(old_const); + target_id = new_const; + target_type_id = get_composite_member_type(target_type_id, cop.arguments[idx]); + } + // Now replace the initializer with the one from this instruction. + new_init[cop.arguments[idx]] = cop.arguments[0]; + set_composite_constant(target_id, target_type_id, new_init); + SPIRConstant tmp_const(cop.basetype); + fill_composite_constant(tmp_const, cop.basetype, const_composite_insert_ids[cop.self]); + return constant_expression(tmp_const); + } + + default: + // Some opcodes are unimplemented here, these are currently not possible to test from glslang. + SPIRV_CROSS_THROW("Unimplemented spec constant op."); + } + + uint32_t bit_width = 0; + if (unary || binary || cop.opcode == OpSConvert || cop.opcode == OpUConvert) + bit_width = expression_type(cop.arguments[0]).width; + + SPIRType::BaseType input_type; + bool skip_cast_if_equal_type = opcode_is_sign_invariant(cop.opcode); + + switch (cop.opcode) + { + case OpIEqual: + case OpINotEqual: + input_type = to_signed_basetype(bit_width); + break; + + case OpSLessThan: + case OpSLessThanEqual: + case OpSGreaterThan: + case OpSGreaterThanEqual: + case OpSMod: + case OpSDiv: + case OpShiftRightArithmetic: + case OpSConvert: + case OpSNegate: + input_type = to_signed_basetype(bit_width); + break; + + case OpULessThan: + case OpULessThanEqual: + case OpUGreaterThan: + case OpUGreaterThanEqual: + case OpUMod: + case OpUDiv: + case OpShiftRightLogical: + case OpUConvert: + input_type = to_unsigned_basetype(bit_width); + break; + + default: + input_type = type.basetype; + break; + } + +#undef GLSL_BOP +#undef GLSL_UOP + if (binary) + { + if (cop.arguments.size() < 2) + SPIRV_CROSS_THROW("Not enough arguments to OpSpecConstantOp."); + + string cast_op0; + string cast_op1; + auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, input_type, cop.arguments[0], + cop.arguments[1], skip_cast_if_equal_type); + + if (type.basetype != input_type && type.basetype != SPIRType::Boolean) + { + expected_type.basetype = input_type; + auto expr = bitcast_glsl_op(type, expected_type); + expr += '('; + expr += join(cast_op0, " ", op, " ", cast_op1); + expr += ')'; + return expr; + } + else + return join("(", cast_op0, " ", op, " ", cast_op1, ")"); + } + else if (unary) + { + if (cop.arguments.size() < 1) + SPIRV_CROSS_THROW("Not enough arguments to OpSpecConstantOp."); + + // Auto-bitcast to result type as needed. + // Works around various casting scenarios in glslang as there is no OpBitcast for specialization constants. 
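The signed/unsigned input-type selection above exists because opcodes such as OpShiftRightLogical fix their semantics regardless of the result type; assuming the bitcast helpers reduce to plain constructor casts for same-width integers, the emitted text ends up roughly as int(uint(a) >> b) rather than (a >> b). A small demonstration of why the distinction matters:

```cpp
#include <cassert>
#include <cstdint>

// OpShiftRightLogical must shift in zero bits even for a signed result type.
int main()
{
	int32_t a = -8;
	uint32_t logical = uint32_t(a) >> 1; // zero-filled: 0x7ffffffc
	int32_t arithmetic = a >> 1;         // sign-filled: -4
	assert(int32_t(logical) != arithmetic);
}
```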
+ return join("(", op, bitcast_glsl(type, cop.arguments[0]), ")"); + } + else if (cop.opcode == OpSConvert || cop.opcode == OpUConvert) + { + if (cop.arguments.size() < 1) + SPIRV_CROSS_THROW("Not enough arguments to OpSpecConstantOp."); + + auto &arg_type = expression_type(cop.arguments[0]); + if (arg_type.width < type.width && input_type != arg_type.basetype) + { + auto expected = arg_type; + expected.basetype = input_type; + return join(op, "(", bitcast_glsl(expected, cop.arguments[0]), ")"); + } + else + return join(op, "(", to_expression(cop.arguments[0]), ")"); + } + else + { + if (cop.arguments.size() < 1) + SPIRV_CROSS_THROW("Not enough arguments to OpSpecConstantOp."); + return join(op, "(", to_expression(cop.arguments[0]), ")"); + } +} + +string CompilerGLSL::constant_expression(const SPIRConstant &c, + bool inside_block_like_struct_scope, + bool inside_struct_scope) +{ + auto &type = get(c.constant_type); + + if (is_pointer(type)) + { + return backend.null_pointer_literal; + } + else if (!c.subconstants.empty()) + { + // Handles Arrays and structures. + string res; + + // Only consider the decay if we are inside a struct scope where we are emitting a member with Offset decoration. + // Outside a block-like struct declaration, we can always bind to a constant array with templated type. + // Should look at ArrayStride here as well, but it's possible to declare a constant struct + // with Offset = 0, using no ArrayStride on the enclosed array type. + // A particular CTS test hits this scenario. + bool array_type_decays = inside_block_like_struct_scope && + is_array(type) && + !backend.array_is_value_type_in_buffer_blocks; + + // Allow Metal to use the array template to make arrays a value type + bool needs_trailing_tracket = false; + if (backend.use_initializer_list && backend.use_typed_initializer_list && type.basetype == SPIRType::Struct && + !is_array(type)) + { + res = type_to_glsl_constructor(type) + "{ "; + } + else if (backend.use_initializer_list && backend.use_typed_initializer_list && backend.array_is_value_type && + is_array(type) && !array_type_decays) + { + const auto *p_type = &type; + SPIRType tmp_type { OpNop }; + + if (inside_struct_scope && + backend.boolean_in_struct_remapped_type != SPIRType::Boolean && + type.basetype == SPIRType::Boolean) + { + tmp_type = type; + tmp_type.basetype = backend.boolean_in_struct_remapped_type; + p_type = &tmp_type; + } + + res = type_to_glsl_constructor(*p_type) + "({ "; + needs_trailing_tracket = true; + } + else if (backend.use_initializer_list) + { + res = "{ "; + } + else + { + res = type_to_glsl_constructor(type) + "("; + } + + uint32_t subconstant_index = 0; + for (auto &elem : c.subconstants) + { + if (auto *op = maybe_get(elem)) + { + res += constant_op_expression(*op); + } + else if (maybe_get(elem) != nullptr) + { + res += to_name(elem); + } + else + { + auto &subc = get(elem); + if (subc.specialization && !expression_is_forwarded(elem)) + res += to_name(elem); + else + { + if (!is_array(type) && type.basetype == SPIRType::Struct) + { + // When we get down to emitting struct members, override the block-like information. + // For constants, we can freely mix and match block-like state. 
+ inside_block_like_struct_scope = + has_member_decoration(type.self, subconstant_index, DecorationOffset); + } + + if (type.basetype == SPIRType::Struct) + inside_struct_scope = true; + + res += constant_expression(subc, inside_block_like_struct_scope, inside_struct_scope); + } + } + + if (&elem != &c.subconstants.back()) + res += ", "; + + subconstant_index++; + } + + res += backend.use_initializer_list ? " }" : ")"; + if (needs_trailing_tracket) + res += ")"; + + return res; + } + else if (type.basetype == SPIRType::Struct && type.member_types.size() == 0) + { + // Metal tessellation likes empty structs which are then constant expressions. + if (backend.supports_empty_struct) + return "{ }"; + else if (backend.use_typed_initializer_list) + return join(type_to_glsl(type), "{ 0 }"); + else if (backend.use_initializer_list) + return "{ 0 }"; + else + return join(type_to_glsl(type), "(0)"); + } + else if (c.columns() == 1) + { + auto res = constant_expression_vector(c, 0); + + if (inside_struct_scope && + backend.boolean_in_struct_remapped_type != SPIRType::Boolean && + type.basetype == SPIRType::Boolean) + { + SPIRType tmp_type = type; + tmp_type.basetype = backend.boolean_in_struct_remapped_type; + res = join(type_to_glsl(tmp_type), "(", res, ")"); + } + + return res; + } + else + { + string res = type_to_glsl(type) + "("; + for (uint32_t col = 0; col < c.columns(); col++) + { + if (c.specialization_constant_id(col) != 0) + res += to_name(c.specialization_constant_id(col)); + else + res += constant_expression_vector(c, col); + + if (col + 1 < c.columns()) + res += ", "; + } + res += ")"; + + if (inside_struct_scope && + backend.boolean_in_struct_remapped_type != SPIRType::Boolean && + type.basetype == SPIRType::Boolean) + { + SPIRType tmp_type = type; + tmp_type.basetype = backend.boolean_in_struct_remapped_type; + res = join(type_to_glsl(tmp_type), "(", res, ")"); + } + + return res; + } +} + +#ifdef _MSC_VER +// snprintf does not exist or is buggy on older MSVC versions, some of them +// being used by MinGW. Use sprintf instead and disable corresponding warning. +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + +string CompilerGLSL::convert_half_to_string(const SPIRConstant &c, uint32_t col, uint32_t row) +{ + string res; + float float_value = c.scalar_f16(col, row); + + // There is no literal "hf" in GL_NV_gpu_shader5, so to avoid lots + // of complicated workarounds, just value-cast to the half type always. + if (std::isnan(float_value) || std::isinf(float_value)) + { + SPIRType type { OpTypeFloat }; + type.basetype = SPIRType::Half; + type.vecsize = 1; + type.columns = 1; + + if (float_value == numeric_limits::infinity()) + res = join(type_to_glsl(type), "(1.0 / 0.0)"); + else if (float_value == -numeric_limits::infinity()) + res = join(type_to_glsl(type), "(-1.0 / 0.0)"); + else if (std::isnan(float_value)) + res = join(type_to_glsl(type), "(0.0 / 0.0)"); + else + SPIRV_CROSS_THROW("Cannot represent non-finite floating point constant."); + } + else + { + SPIRType type { OpTypeFloat }; + type.basetype = SPIRType::Half; + type.vecsize = 1; + type.columns = 1; + res = join(type_to_glsl(type), "(", format_float(float_value), ")"); + } + + return res; +} + +string CompilerGLSL::convert_float_to_string(const SPIRConstant &c, uint32_t col, uint32_t row) +{ + string res; + float float_value = c.scalar_f32(col, row); + + if (std::isnan(float_value) || std::isinf(float_value)) + { + // Use special representation. 
+ if (!is_legacy()) + { + SPIRType out_type { OpTypeFloat }; + SPIRType in_type { OpTypeInt }; + out_type.basetype = SPIRType::Float; + in_type.basetype = SPIRType::UInt; + out_type.vecsize = 1; + in_type.vecsize = 1; + out_type.width = 32; + in_type.width = 32; + + char print_buffer[32]; +#ifdef _WIN32 + sprintf(print_buffer, "0x%xu", c.scalar(col, row)); +#else + snprintf(print_buffer, sizeof(print_buffer), "0x%xu", c.scalar(col, row)); +#endif + + const char *comment = "inf"; + if (float_value == -numeric_limits::infinity()) + comment = "-inf"; + else if (std::isnan(float_value)) + comment = "nan"; + res = join(bitcast_glsl_op(out_type, in_type), "(", print_buffer, " /* ", comment, " */)"); + } + else + { + if (float_value == numeric_limits::infinity()) + { + if (backend.float_literal_suffix) + res = "(1.0f / 0.0f)"; + else + res = "(1.0 / 0.0)"; + } + else if (float_value == -numeric_limits::infinity()) + { + if (backend.float_literal_suffix) + res = "(-1.0f / 0.0f)"; + else + res = "(-1.0 / 0.0)"; + } + else if (std::isnan(float_value)) + { + if (backend.float_literal_suffix) + res = "(0.0f / 0.0f)"; + else + res = "(0.0 / 0.0)"; + } + else + SPIRV_CROSS_THROW("Cannot represent non-finite floating point constant."); + } + } + else + { + res = format_float(float_value); + if (backend.float_literal_suffix) + res += "f"; + } + + return res; +} + +std::string CompilerGLSL::convert_double_to_string(const SPIRConstant &c, uint32_t col, uint32_t row) +{ + string res; + double double_value = c.scalar_f64(col, row); + + if (std::isnan(double_value) || std::isinf(double_value)) + { + // Use special representation. + if (!is_legacy()) + { + SPIRType out_type { OpTypeFloat }; + SPIRType in_type { OpTypeInt }; + out_type.basetype = SPIRType::Double; + in_type.basetype = SPIRType::UInt64; + out_type.vecsize = 1; + in_type.vecsize = 1; + out_type.width = 64; + in_type.width = 64; + + uint64_t u64_value = c.scalar_u64(col, row); + + if (options.es && options.version < 310) // GL_NV_gpu_shader5 fallback requires 310. + SPIRV_CROSS_THROW("64-bit integers not supported in ES profile before version 310."); + require_extension_internal("GL_ARB_gpu_shader_int64"); + + char print_buffer[64]; +#ifdef _WIN32 + sprintf(print_buffer, "0x%llx%s", static_cast(u64_value), + backend.long_long_literal_suffix ? "ull" : "ul"); +#else + snprintf(print_buffer, sizeof(print_buffer), "0x%llx%s", static_cast(u64_value), + backend.long_long_literal_suffix ? 
"ull" : "ul"); +#endif + + const char *comment = "inf"; + if (double_value == -numeric_limits::infinity()) + comment = "-inf"; + else if (std::isnan(double_value)) + comment = "nan"; + res = join(bitcast_glsl_op(out_type, in_type), "(", print_buffer, " /* ", comment, " */)"); + } + else + { + if (options.es) + SPIRV_CROSS_THROW("FP64 not supported in ES profile."); + if (options.version < 400) + require_extension_internal("GL_ARB_gpu_shader_fp64"); + + if (double_value == numeric_limits::infinity()) + { + if (backend.double_literal_suffix) + res = "(1.0lf / 0.0lf)"; + else + res = "(1.0 / 0.0)"; + } + else if (double_value == -numeric_limits::infinity()) + { + if (backend.double_literal_suffix) + res = "(-1.0lf / 0.0lf)"; + else + res = "(-1.0 / 0.0)"; + } + else if (std::isnan(double_value)) + { + if (backend.double_literal_suffix) + res = "(0.0lf / 0.0lf)"; + else + res = "(0.0 / 0.0)"; + } + else + SPIRV_CROSS_THROW("Cannot represent non-finite floating point constant."); + } + } + else + { + res = format_double(double_value); + if (backend.double_literal_suffix) + res += "lf"; + } + + return res; +} + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +string CompilerGLSL::constant_expression_vector(const SPIRConstant &c, uint32_t vector) +{ + auto type = get(c.constant_type); + type.columns = 1; + + auto scalar_type = type; + scalar_type.vecsize = 1; + + string res; + bool splat = backend.use_constructor_splatting && c.vector_size() > 1; + bool swizzle_splat = backend.can_swizzle_scalar && c.vector_size() > 1; + + if (!type_is_floating_point(type)) + { + // Cannot swizzle literal integers as a special case. + swizzle_splat = false; + } + + if (splat || swizzle_splat) + { + // Cannot use constant splatting if we have specialization constants somewhere in the vector. 
+ for (uint32_t i = 0; i < c.vector_size(); i++) + { + if (c.specialization_constant_id(vector, i) != 0) + { + splat = false; + swizzle_splat = false; + break; + } + } + } + + if (splat || swizzle_splat) + { + if (type.width == 64) + { + uint64_t ident = c.scalar_u64(vector, 0); + for (uint32_t i = 1; i < c.vector_size(); i++) + { + if (ident != c.scalar_u64(vector, i)) + { + splat = false; + swizzle_splat = false; + break; + } + } + } + else + { + uint32_t ident = c.scalar(vector, 0); + for (uint32_t i = 1; i < c.vector_size(); i++) + { + if (ident != c.scalar(vector, i)) + { + splat = false; + swizzle_splat = false; + } + } + } + } + + if (c.vector_size() > 1 && !swizzle_splat) + res += type_to_glsl(type) + "("; + + switch (type.basetype) + { + case SPIRType::Half: + if (splat || swizzle_splat) + { + res += convert_half_to_string(c, vector, 0); + if (swizzle_splat) + res = remap_swizzle(get(c.constant_type), 1, res); + } + else + { + for (uint32_t i = 0; i < c.vector_size(); i++) + { + if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0) + res += to_expression(c.specialization_constant_id(vector, i)); + else + res += convert_half_to_string(c, vector, i); + + if (i + 1 < c.vector_size()) + res += ", "; + } + } + break; + + case SPIRType::Float: + if (splat || swizzle_splat) + { + res += convert_float_to_string(c, vector, 0); + if (swizzle_splat) + res = remap_swizzle(get(c.constant_type), 1, res); + } + else + { + for (uint32_t i = 0; i < c.vector_size(); i++) + { + if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0) + res += to_expression(c.specialization_constant_id(vector, i)); + else + res += convert_float_to_string(c, vector, i); + + if (i + 1 < c.vector_size()) + res += ", "; + } + } + break; + + case SPIRType::Double: + if (splat || swizzle_splat) + { + res += convert_double_to_string(c, vector, 0); + if (swizzle_splat) + res = remap_swizzle(get(c.constant_type), 1, res); + } + else + { + for (uint32_t i = 0; i < c.vector_size(); i++) + { + if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0) + res += to_expression(c.specialization_constant_id(vector, i)); + else + res += convert_double_to_string(c, vector, i); + + if (i + 1 < c.vector_size()) + res += ", "; + } + } + break; + + case SPIRType::Int64: + { + auto tmp = type; + tmp.vecsize = 1; + tmp.columns = 1; + auto int64_type = type_to_glsl(tmp); + + if (splat) + { + res += convert_to_string(c.scalar_i64(vector, 0), int64_type, backend.long_long_literal_suffix); + } + else + { + for (uint32_t i = 0; i < c.vector_size(); i++) + { + if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0) + res += to_expression(c.specialization_constant_id(vector, i)); + else + res += convert_to_string(c.scalar_i64(vector, i), int64_type, backend.long_long_literal_suffix); + + if (i + 1 < c.vector_size()) + res += ", "; + } + } + break; + } + + case SPIRType::UInt64: + if (splat) + { + res += convert_to_string(c.scalar_u64(vector, 0)); + if (backend.long_long_literal_suffix) + res += "ull"; + else + res += "ul"; + } + else + { + for (uint32_t i = 0; i < c.vector_size(); i++) + { + if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0) + res += to_expression(c.specialization_constant_id(vector, i)); + else + { + res += convert_to_string(c.scalar_u64(vector, i)); + if (backend.long_long_literal_suffix) + res += "ull"; + else + res += "ul"; + } + + if (i + 1 < c.vector_size()) + res += ", "; + } + } + break; + + case SPIRType::UInt: + if (splat) + { + res += 
convert_to_string(c.scalar(vector, 0)); + if (is_legacy()) + { + // Fake unsigned constant literals with signed ones if possible. + // Things like array sizes, etc, tend to be unsigned even though they could just as easily be signed. + if (c.scalar_i32(vector, 0) < 0) + SPIRV_CROSS_THROW("Tried to convert uint literal into int, but this made the literal negative."); + } + else if (backend.uint32_t_literal_suffix) + res += "u"; + } + else + { + for (uint32_t i = 0; i < c.vector_size(); i++) + { + if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0) + res += to_expression(c.specialization_constant_id(vector, i)); + else + { + res += convert_to_string(c.scalar(vector, i)); + if (is_legacy()) + { + // Fake unsigned constant literals with signed ones if possible. + // Things like array sizes, etc, tend to be unsigned even though they could just as easily be signed. + if (c.scalar_i32(vector, i) < 0) + SPIRV_CROSS_THROW("Tried to convert uint literal into int, but this made " + "the literal negative."); + } + else if (backend.uint32_t_literal_suffix) + res += "u"; + } + + if (i + 1 < c.vector_size()) + res += ", "; + } + } + break; + + case SPIRType::Int: + if (splat) + res += convert_to_string(c.scalar_i32(vector, 0)); + else + { + for (uint32_t i = 0; i < c.vector_size(); i++) + { + if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0) + res += to_expression(c.specialization_constant_id(vector, i)); + else + res += convert_to_string(c.scalar_i32(vector, i)); + if (i + 1 < c.vector_size()) + res += ", "; + } + } + break; + + case SPIRType::UShort: + if (splat) + { + res += convert_to_string(c.scalar(vector, 0)); + } + else + { + for (uint32_t i = 0; i < c.vector_size(); i++) + { + if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0) + res += to_expression(c.specialization_constant_id(vector, i)); + else + { + if (*backend.uint16_t_literal_suffix) + { + res += convert_to_string(c.scalar_u16(vector, i)); + res += backend.uint16_t_literal_suffix; + } + else + { + // If backend doesn't have a literal suffix, we need to value cast. + res += type_to_glsl(scalar_type); + res += "("; + res += convert_to_string(c.scalar_u16(vector, i)); + res += ")"; + } + } + + if (i + 1 < c.vector_size()) + res += ", "; + } + } + break; + + case SPIRType::Short: + if (splat) + { + res += convert_to_string(c.scalar_i16(vector, 0)); + } + else + { + for (uint32_t i = 0; i < c.vector_size(); i++) + { + if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0) + res += to_expression(c.specialization_constant_id(vector, i)); + else + { + if (*backend.int16_t_literal_suffix) + { + res += convert_to_string(c.scalar_i16(vector, i)); + res += backend.int16_t_literal_suffix; + } + else + { + // If backend doesn't have a literal suffix, we need to value cast. 
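The suffix-or-cast choice described in the comment above applies to all the narrow integer cases: append the backend's literal suffix when it defines one, otherwise route the literal through a value cast of the scalar type. A sketch with an assumed helper name and example suffixes:

```cpp
#include <cassert>
#include <string>

static std::string small_int_literal(const std::string &value,
                                     const std::string &suffix,
                                     const std::string &scalar_type)
{
	if (!suffix.empty())
		return value + suffix;              // e.g. "123us"
	return scalar_type + "(" + value + ")"; // e.g. "uint16_t(123)"
}

int main()
{
	assert(small_int_literal("123", "us", "uint16_t") == "123us");
	assert(small_int_literal("123", "", "uint16_t") == "uint16_t(123)");
}
```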
+ res += type_to_glsl(scalar_type); + res += "("; + res += convert_to_string(c.scalar_i16(vector, i)); + res += ")"; + } + } + + if (i + 1 < c.vector_size()) + res += ", "; + } + } + break; + + case SPIRType::UByte: + if (splat) + { + res += convert_to_string(c.scalar_u8(vector, 0)); + } + else + { + for (uint32_t i = 0; i < c.vector_size(); i++) + { + if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0) + res += to_expression(c.specialization_constant_id(vector, i)); + else + { + res += type_to_glsl(scalar_type); + res += "("; + res += convert_to_string(c.scalar_u8(vector, i)); + res += ")"; + } + + if (i + 1 < c.vector_size()) + res += ", "; + } + } + break; + + case SPIRType::SByte: + if (splat) + { + res += convert_to_string(c.scalar_i8(vector, 0)); + } + else + { + for (uint32_t i = 0; i < c.vector_size(); i++) + { + if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0) + res += to_expression(c.specialization_constant_id(vector, i)); + else + { + res += type_to_glsl(scalar_type); + res += "("; + res += convert_to_string(c.scalar_i8(vector, i)); + res += ")"; + } + + if (i + 1 < c.vector_size()) + res += ", "; + } + } + break; + + case SPIRType::Boolean: + if (splat) + res += c.scalar(vector, 0) ? "true" : "false"; + else + { + for (uint32_t i = 0; i < c.vector_size(); i++) + { + if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0) + res += to_expression(c.specialization_constant_id(vector, i)); + else + res += c.scalar(vector, i) ? "true" : "false"; + + if (i + 1 < c.vector_size()) + res += ", "; + } + } + break; + + default: + SPIRV_CROSS_THROW("Invalid constant expression basetype."); + } + + if (c.vector_size() > 1 && !swizzle_splat) + res += ")"; + + return res; +} + +SPIRExpression &CompilerGLSL::emit_uninitialized_temporary_expression(uint32_t type, uint32_t id) +{ + forced_temporaries.insert(id); + emit_uninitialized_temporary(type, id); + return set(id, to_name(id), type, true); +} + +void CompilerGLSL::emit_uninitialized_temporary(uint32_t result_type, uint32_t result_id) +{ + // If we're declaring temporaries inside continue blocks, + // we must declare the temporary in the loop header so that the continue block can avoid declaring new variables. + if (!block_temporary_hoisting && current_continue_block && !hoisted_temporaries.count(result_id)) + { + auto &header = get(current_continue_block->loop_dominator); + if (find_if(begin(header.declare_temporary), end(header.declare_temporary), + [result_type, result_id](const pair &tmp) { + return tmp.first == result_type && tmp.second == result_id; + }) == end(header.declare_temporary)) + { + header.declare_temporary.emplace_back(result_type, result_id); + hoisted_temporaries.insert(result_id); + force_recompile(); + } + } + else if (hoisted_temporaries.count(result_id) == 0) + { + auto &type = get(result_type); + auto &flags = get_decoration_bitset(result_id); + + // The result_id has not been made into an expression yet, so use flags interface. 
+ add_local_variable_name(result_id); + + string initializer; + if (options.force_zero_initialized_variables && type_can_zero_initialize(type)) + initializer = join(" = ", to_zero_initialized_expression(result_type)); + + statement(flags_to_qualifiers_glsl(type, flags), variable_decl(type, to_name(result_id)), initializer, ";"); + } +} + +string CompilerGLSL::declare_temporary(uint32_t result_type, uint32_t result_id) +{ + auto &type = get(result_type); + + // If we're declaring temporaries inside continue blocks, + // we must declare the temporary in the loop header so that the continue block can avoid declaring new variables. + if (!block_temporary_hoisting && current_continue_block && !hoisted_temporaries.count(result_id)) + { + auto &header = get(current_continue_block->loop_dominator); + if (find_if(begin(header.declare_temporary), end(header.declare_temporary), + [result_type, result_id](const pair &tmp) { + return tmp.first == result_type && tmp.second == result_id; + }) == end(header.declare_temporary)) + { + header.declare_temporary.emplace_back(result_type, result_id); + hoisted_temporaries.insert(result_id); + force_recompile_guarantee_forward_progress(); + } + + return join(to_name(result_id), " = "); + } + else if (hoisted_temporaries.count(result_id)) + { + // The temporary has already been declared earlier, so just "declare" the temporary by writing to it. + return join(to_name(result_id), " = "); + } + else + { + // The result_id has not been made into an expression yet, so use flags interface. + add_local_variable_name(result_id); + auto &flags = get_decoration_bitset(result_id); + return join(flags_to_qualifiers_glsl(type, flags), variable_decl(type, to_name(result_id)), " = "); + } +} + +bool CompilerGLSL::expression_is_forwarded(uint32_t id) const +{ + return forwarded_temporaries.count(id) != 0; +} + +bool CompilerGLSL::expression_suppresses_usage_tracking(uint32_t id) const +{ + return suppressed_usage_tracking.count(id) != 0; +} + +bool CompilerGLSL::expression_read_implies_multiple_reads(uint32_t id) const +{ + auto *expr = maybe_get(id); + if (!expr) + return false; + + // If we're emitting code at a deeper loop level than when we emitted the expression, + // we're probably reading the same expression over and over. + return current_loop_level > expr->emitted_loop_level; +} + +SPIRExpression &CompilerGLSL::emit_op(uint32_t result_type, uint32_t result_id, const string &rhs, bool forwarding, + bool suppress_usage_tracking) +{ + if (forwarding && (forced_temporaries.find(result_id) == end(forced_temporaries))) + { + // Just forward it without temporary. + // If the forward is trivial, we do not force flushing to temporary for this expression. + forwarded_temporaries.insert(result_id); + if (suppress_usage_tracking) + suppressed_usage_tracking.insert(result_id); + + return set(result_id, rhs, result_type, true); + } + else + { + // If expression isn't immutable, bind it to a temporary and make the new temporary immutable (they always are). 
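+ // Illustrative difference: a forwarded expression such as "a + b" is substituted textually
+ // into later expressions, while this path emits e.g. "float _45 = a + b;" once and later
+ // uses refer to "_45" (placeholder name).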
+ statement(declare_temporary(result_type, result_id), rhs, ";"); + return set(result_id, to_name(result_id), result_type, true); + } +} + +void CompilerGLSL::emit_unary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op) +{ + bool forward = should_forward(op0); + emit_op(result_type, result_id, join(op, to_enclosed_unpacked_expression(op0)), forward); + inherit_expression_dependencies(result_id, op0); +} + +void CompilerGLSL::emit_unary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op) +{ + auto &type = get(result_type); + bool forward = should_forward(op0); + emit_op(result_type, result_id, join(type_to_glsl(type), "(", op, to_enclosed_unpacked_expression(op0), ")"), forward); + inherit_expression_dependencies(result_id, op0); +} + +void CompilerGLSL::emit_mesh_tasks(SPIRBlock &block) +{ + statement("EmitMeshTasksEXT(", + to_unpacked_expression(block.mesh.groups[0]), ", ", + to_unpacked_expression(block.mesh.groups[1]), ", ", + to_unpacked_expression(block.mesh.groups[2]), ");"); +} + +void CompilerGLSL::emit_binary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op) +{ + // Various FP arithmetic opcodes such as add, sub, mul will hit this. + bool force_temporary_precise = backend.support_precise_qualifier && + has_decoration(result_id, DecorationNoContraction) && + type_is_floating_point(get(result_type)); + bool forward = should_forward(op0) && should_forward(op1) && !force_temporary_precise; + + emit_op(result_type, result_id, + join(to_enclosed_unpacked_expression(op0), " ", op, " ", to_enclosed_unpacked_expression(op1)), forward); + + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); +} + +void CompilerGLSL::emit_unrolled_unary_op(uint32_t result_type, uint32_t result_id, uint32_t operand, const char *op) +{ + auto &type = get(result_type); + auto expr = type_to_glsl_constructor(type); + expr += '('; + for (uint32_t i = 0; i < type.vecsize; i++) + { + // Make sure to call to_expression multiple times to ensure + // that these expressions are properly flushed to temporaries if needed. + expr += op; + expr += to_extract_component_expression(operand, i); + + if (i + 1 < type.vecsize) + expr += ", "; + } + expr += ')'; + emit_op(result_type, result_id, expr, should_forward(operand)); + + inherit_expression_dependencies(result_id, operand); +} + +void CompilerGLSL::emit_unrolled_binary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + const char *op, bool negate, SPIRType::BaseType expected_type) +{ + auto &type0 = expression_type(op0); + auto &type1 = expression_type(op1); + + SPIRType target_type0 = type0; + SPIRType target_type1 = type1; + target_type0.basetype = expected_type; + target_type1.basetype = expected_type; + target_type0.vecsize = 1; + target_type1.vecsize = 1; + + auto &type = get(result_type); + auto expr = type_to_glsl_constructor(type); + expr += '('; + for (uint32_t i = 0; i < type.vecsize; i++) + { + // Make sure to call to_expression multiple times to ensure + // that these expressions are properly flushed to temporaries if needed. 
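+ // Illustrative unrolled form: comparing two ivec2 values a and b with "==" yields
+ // "bvec2(a.x == b.x, a.y == b.y)", or "bvec2(!(a.x == b.x), !(a.y == b.y))" when negate is set.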
+ if (negate) + expr += "!("; + + if (expected_type != SPIRType::Unknown && type0.basetype != expected_type) + expr += bitcast_expression(target_type0, type0.basetype, to_extract_component_expression(op0, i)); + else + expr += to_extract_component_expression(op0, i); + + expr += ' '; + expr += op; + expr += ' '; + + if (expected_type != SPIRType::Unknown && type1.basetype != expected_type) + expr += bitcast_expression(target_type1, type1.basetype, to_extract_component_expression(op1, i)); + else + expr += to_extract_component_expression(op1, i); + + if (negate) + expr += ")"; + + if (i + 1 < type.vecsize) + expr += ", "; + } + expr += ')'; + emit_op(result_type, result_id, expr, should_forward(op0) && should_forward(op1)); + + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); +} + +SPIRType CompilerGLSL::binary_op_bitcast_helper(string &cast_op0, string &cast_op1, SPIRType::BaseType &input_type, + uint32_t op0, uint32_t op1, bool skip_cast_if_equal_type) +{ + auto &type0 = expression_type(op0); + auto &type1 = expression_type(op1); + + // We have to bitcast if our inputs are of different type, or if our types are not equal to expected inputs. + // For some functions like OpIEqual and INotEqual, we don't care if inputs are of different types than expected + // since equality test is exactly the same. + bool cast = (type0.basetype != type1.basetype) || (!skip_cast_if_equal_type && type0.basetype != input_type); + + // Create a fake type so we can bitcast to it. + // We only deal with regular arithmetic types here like int, uints and so on. + SPIRType expected_type{type0.op}; + expected_type.basetype = input_type; + expected_type.vecsize = type0.vecsize; + expected_type.columns = type0.columns; + expected_type.width = type0.width; + + if (cast) + { + cast_op0 = bitcast_glsl(expected_type, op0); + cast_op1 = bitcast_glsl(expected_type, op1); + } + else + { + // If we don't cast, our actual input type is that of the first (or second) argument. + cast_op0 = to_enclosed_unpacked_expression(op0); + cast_op1 = to_enclosed_unpacked_expression(op1); + input_type = type0.basetype; + } + + return expected_type; +} + +bool CompilerGLSL::emit_complex_bitcast(uint32_t result_type, uint32_t id, uint32_t op0) +{ + // Some bitcasts may require complex casting sequences, and are implemented here. + // Otherwise a simply unary function will do with bitcast_glsl_op. + + auto &output_type = get(result_type); + auto &input_type = expression_type(op0); + string expr; + + if (output_type.basetype == SPIRType::Half && input_type.basetype == SPIRType::Float && input_type.vecsize == 1) + expr = join("unpackFloat2x16(floatBitsToUint(", to_unpacked_expression(op0), "))"); + else if (output_type.basetype == SPIRType::Float && input_type.basetype == SPIRType::Half && + input_type.vecsize == 2) + expr = join("uintBitsToFloat(packFloat2x16(", to_unpacked_expression(op0), "))"); + else + return false; + + emit_op(result_type, id, expr, should_forward(op0)); + return true; +} + +void CompilerGLSL::emit_binary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + const char *op, SPIRType::BaseType input_type, + bool skip_cast_if_equal_type, + bool implicit_integer_promotion) +{ + string cast_op0, cast_op1; + auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, skip_cast_if_equal_type); + auto &out_type = get(result_type); + + // We might have casted away from the result type, so bitcast again. 
+ // For example, arithmetic right shift with uint inputs. + // Special case boolean outputs since relational opcodes output booleans instead of int/uint. + auto bitop = join(cast_op0, " ", op, " ", cast_op1); + string expr; + + if (implicit_integer_promotion) + { + // Simple value cast. + expr = join(type_to_glsl(out_type), '(', bitop, ')'); + } + else if (out_type.basetype != input_type && out_type.basetype != SPIRType::Boolean) + { + expected_type.basetype = input_type; + expr = join(bitcast_glsl_op(out_type, expected_type), '(', bitop, ')'); + } + else + { + expr = std::move(bitop); + } + + emit_op(result_type, result_id, expr, should_forward(op0) && should_forward(op1)); + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); +} + +void CompilerGLSL::emit_unary_func_op(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op) +{ + bool forward = should_forward(op0); + emit_op(result_type, result_id, join(op, "(", to_unpacked_expression(op0), ")"), forward); + inherit_expression_dependencies(result_id, op0); +} + +void CompilerGLSL::emit_binary_func_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + const char *op) +{ + // Opaque types (e.g. OpTypeSampledImage) must always be forwarded in GLSL + const auto &type = get_type(result_type); + bool must_forward = type_is_opaque_value(type); + bool forward = must_forward || (should_forward(op0) && should_forward(op1)); + emit_op(result_type, result_id, join(op, "(", to_unpacked_expression(op0), ", ", to_unpacked_expression(op1), ")"), + forward); + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); +} + +void CompilerGLSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + const char *op) +{ + auto &type = get(result_type); + if (type_is_floating_point(type)) + { + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Floating point atomics requires Vulkan semantics."); + if (options.es) + SPIRV_CROSS_THROW("Floating point atomics requires desktop GLSL."); + require_extension_internal("GL_EXT_shader_atomic_float"); + } + + forced_temporaries.insert(result_id); + emit_op(result_type, result_id, + join(op, "(", to_non_uniform_aware_expression(op0), ", ", + to_unpacked_expression(op1), ")"), false); + flush_all_atomic_capable_variables(); +} + +void CompilerGLSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, + uint32_t op0, uint32_t op1, uint32_t op2, + const char *op) +{ + forced_temporaries.insert(result_id); + emit_op(result_type, result_id, + join(op, "(", to_non_uniform_aware_expression(op0), ", ", + to_unpacked_expression(op1), ", ", to_unpacked_expression(op2), ")"), false); + flush_all_atomic_capable_variables(); +} + +void CompilerGLSL::emit_unary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op, + SPIRType::BaseType input_type, SPIRType::BaseType expected_result_type) +{ + auto &out_type = get(result_type); + auto &expr_type = expression_type(op0); + auto expected_type = out_type; + + // Bit-widths might be different in unary cases because we use it for SConvert/UConvert and friends. 
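+ // Illustrative outcome: GLSLstd450SAbs applied to an operand x that was declared as uint
+ // (with a uint result) comes out roughly as "uint(abs(int(x)))": bitcast in, apply the op,
+ // bitcast back.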
+ expected_type.basetype = input_type; + expected_type.width = expr_type.width; + + string cast_op; + if (expr_type.basetype != input_type) + { + if (expr_type.basetype == SPIRType::Boolean) + cast_op = join(type_to_glsl(expected_type), "(", to_unpacked_expression(op0), ")"); + else + cast_op = bitcast_glsl(expected_type, op0); + } + else + cast_op = to_unpacked_expression(op0); + + string expr; + if (out_type.basetype != expected_result_type) + { + expected_type.basetype = expected_result_type; + expected_type.width = out_type.width; + if (out_type.basetype == SPIRType::Boolean) + expr = type_to_glsl(out_type); + else + expr = bitcast_glsl_op(out_type, expected_type); + expr += '('; + expr += join(op, "(", cast_op, ")"); + expr += ')'; + } + else + { + expr += join(op, "(", cast_op, ")"); + } + + emit_op(result_type, result_id, expr, should_forward(op0)); + inherit_expression_dependencies(result_id, op0); +} + +// Very special case. Handling bitfieldExtract requires us to deal with different bitcasts of different signs +// and different vector sizes all at once. Need a special purpose method here. +void CompilerGLSL::emit_trinary_func_op_bitextract(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + uint32_t op2, const char *op, + SPIRType::BaseType expected_result_type, + SPIRType::BaseType input_type0, SPIRType::BaseType input_type1, + SPIRType::BaseType input_type2) +{ + auto &out_type = get(result_type); + auto expected_type = out_type; + expected_type.basetype = input_type0; + + string cast_op0 = + expression_type(op0).basetype != input_type0 ? bitcast_glsl(expected_type, op0) : to_unpacked_expression(op0); + + auto op1_expr = to_unpacked_expression(op1); + auto op2_expr = to_unpacked_expression(op2); + + // Use value casts here instead. Input must be exactly int or uint, but SPIR-V might be 16-bit. + expected_type.basetype = input_type1; + expected_type.vecsize = 1; + string cast_op1 = expression_type(op1).basetype != input_type1 ? + join(type_to_glsl_constructor(expected_type), "(", op1_expr, ")") : + op1_expr; + + expected_type.basetype = input_type2; + expected_type.vecsize = 1; + string cast_op2 = expression_type(op2).basetype != input_type2 ? + join(type_to_glsl_constructor(expected_type), "(", op2_expr, ")") : + op2_expr; + + string expr; + if (out_type.basetype != expected_result_type) + { + expected_type.vecsize = out_type.vecsize; + expected_type.basetype = expected_result_type; + expr = bitcast_glsl_op(out_type, expected_type); + expr += '('; + expr += join(op, "(", cast_op0, ", ", cast_op1, ", ", cast_op2, ")"); + expr += ')'; + } + else + { + expr += join(op, "(", cast_op0, ", ", cast_op1, ", ", cast_op2, ")"); + } + + emit_op(result_type, result_id, expr, should_forward(op0) && should_forward(op1) && should_forward(op2)); + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); + inherit_expression_dependencies(result_id, op2); +} + +void CompilerGLSL::emit_trinary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + uint32_t op2, const char *op, SPIRType::BaseType input_type) +{ + auto &out_type = get(result_type); + auto expected_type = out_type; + expected_type.basetype = input_type; + string cast_op0 = + expression_type(op0).basetype != input_type ? bitcast_glsl(expected_type, op0) : to_unpacked_expression(op0); + string cast_op1 = + expression_type(op1).basetype != input_type ? 
bitcast_glsl(expected_type, op1) : to_unpacked_expression(op1); + string cast_op2 = + expression_type(op2).basetype != input_type ? bitcast_glsl(expected_type, op2) : to_unpacked_expression(op2); + + string expr; + if (out_type.basetype != input_type) + { + expr = bitcast_glsl_op(out_type, expected_type); + expr += '('; + expr += join(op, "(", cast_op0, ", ", cast_op1, ", ", cast_op2, ")"); + expr += ')'; + } + else + { + expr += join(op, "(", cast_op0, ", ", cast_op1, ", ", cast_op2, ")"); + } + + emit_op(result_type, result_id, expr, should_forward(op0) && should_forward(op1) && should_forward(op2)); + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); + inherit_expression_dependencies(result_id, op2); +} + +void CompilerGLSL::emit_binary_func_op_cast_clustered(uint32_t result_type, uint32_t result_id, uint32_t op0, + uint32_t op1, const char *op, SPIRType::BaseType input_type) +{ + // Special purpose method for implementing clustered subgroup opcodes. + // Main difference is that op1 does not participate in any casting, it needs to be a literal. + auto &out_type = get(result_type); + auto expected_type = out_type; + expected_type.basetype = input_type; + string cast_op0 = + expression_type(op0).basetype != input_type ? bitcast_glsl(expected_type, op0) : to_unpacked_expression(op0); + + string expr; + if (out_type.basetype != input_type) + { + expr = bitcast_glsl_op(out_type, expected_type); + expr += '('; + expr += join(op, "(", cast_op0, ", ", to_expression(op1), ")"); + expr += ')'; + } + else + { + expr += join(op, "(", cast_op0, ", ", to_expression(op1), ")"); + } + + emit_op(result_type, result_id, expr, should_forward(op0)); + inherit_expression_dependencies(result_id, op0); +} + +void CompilerGLSL::emit_binary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + const char *op, SPIRType::BaseType input_type, bool skip_cast_if_equal_type) +{ + string cast_op0, cast_op1; + auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, skip_cast_if_equal_type); + auto &out_type = get(result_type); + + // Special case boolean outputs since relational opcodes output booleans instead of int/uint. 
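+ // Illustrative case: OpULessThan on operands that were declared as ivec2 becomes
+ // "lessThan(uvec2(a), uvec2(b))"; the bvec2 result needs no output bitcast, hence the
+ // Boolean exception below.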
+ string expr; + if (out_type.basetype != input_type && out_type.basetype != SPIRType::Boolean) + { + expected_type.basetype = input_type; + expr = bitcast_glsl_op(out_type, expected_type); + expr += '('; + expr += join(op, "(", cast_op0, ", ", cast_op1, ")"); + expr += ')'; + } + else + { + expr += join(op, "(", cast_op0, ", ", cast_op1, ")"); + } + + emit_op(result_type, result_id, expr, should_forward(op0) && should_forward(op1)); + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); +} + +void CompilerGLSL::emit_trinary_func_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + uint32_t op2, const char *op) +{ + bool forward = should_forward(op0) && should_forward(op1) && should_forward(op2); + emit_op(result_type, result_id, + join(op, "(", to_unpacked_expression(op0), ", ", to_unpacked_expression(op1), ", ", + to_unpacked_expression(op2), ")"), + forward); + + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); + inherit_expression_dependencies(result_id, op2); +} + +void CompilerGLSL::emit_quaternary_func_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + uint32_t op2, uint32_t op3, const char *op) +{ + bool forward = should_forward(op0) && should_forward(op1) && should_forward(op2) && should_forward(op3); + emit_op(result_type, result_id, + join(op, "(", to_unpacked_expression(op0), ", ", to_unpacked_expression(op1), ", ", + to_unpacked_expression(op2), ", ", to_unpacked_expression(op3), ")"), + forward); + + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); + inherit_expression_dependencies(result_id, op2); + inherit_expression_dependencies(result_id, op3); +} + +void CompilerGLSL::emit_bitfield_insert_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + uint32_t op2, uint32_t op3, const char *op, + SPIRType::BaseType offset_count_type) +{ + // Only need to cast offset/count arguments. Types of base/insert must be same as result type, + // and bitfieldInsert is sign invariant. + bool forward = should_forward(op0) && should_forward(op1) && should_forward(op2) && should_forward(op3); + + auto op0_expr = to_unpacked_expression(op0); + auto op1_expr = to_unpacked_expression(op1); + auto op2_expr = to_unpacked_expression(op2); + auto op3_expr = to_unpacked_expression(op3); + + assert(offset_count_type == SPIRType::UInt || offset_count_type == SPIRType::Int); + SPIRType target_type { OpTypeInt }; + target_type.width = 32; + target_type.vecsize = 1; + target_type.basetype = offset_count_type; + + if (expression_type(op2).basetype != offset_count_type) + { + // Value-cast here. Input might be 16-bit. GLSL requires int. + op2_expr = join(type_to_glsl_constructor(target_type), "(", op2_expr, ")"); + } + + if (expression_type(op3).basetype != offset_count_type) + { + // Value-cast here. Input might be 16-bit. GLSL requires int. 
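+ // Once offset/count have been value-cast, the call emitted below looks roughly like
+ // "bitfieldInsert(base, insert, int(offset), int(count))" (operand names are placeholders).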
+ op3_expr = join(type_to_glsl_constructor(target_type), "(", op3_expr, ")"); + } + + emit_op(result_type, result_id, join(op, "(", op0_expr, ", ", op1_expr, ", ", op2_expr, ", ", op3_expr, ")"), + forward); + + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); + inherit_expression_dependencies(result_id, op2); + inherit_expression_dependencies(result_id, op3); +} + +string CompilerGLSL::legacy_tex_op(const std::string &op, const SPIRType &imgtype, uint32_t tex) +{ + const char *type; + switch (imgtype.image.dim) + { + case spv::Dim1D: + // Force 2D path for ES. + if (options.es) + type = (imgtype.image.arrayed && !options.es) ? "2DArray" : "2D"; + else + type = (imgtype.image.arrayed && !options.es) ? "1DArray" : "1D"; + break; + case spv::Dim2D: + type = (imgtype.image.arrayed && !options.es) ? "2DArray" : "2D"; + break; + case spv::Dim3D: + type = "3D"; + break; + case spv::DimCube: + type = "Cube"; + break; + case spv::DimRect: + type = "2DRect"; + break; + case spv::DimBuffer: + type = "Buffer"; + break; + case spv::DimSubpassData: + type = "2D"; + break; + default: + type = ""; + break; + } + + // In legacy GLSL, an extension is required for textureLod in the fragment + // shader or textureGrad anywhere. + bool legacy_lod_ext = false; + auto &execution = get_entry_point(); + if (op == "textureGrad" || op == "textureProjGrad" || + ((op == "textureLod" || op == "textureProjLod") && execution.model != ExecutionModelVertex)) + { + if (is_legacy_es()) + { + legacy_lod_ext = true; + require_extension_internal("GL_EXT_shader_texture_lod"); + } + else if (is_legacy_desktop()) + require_extension_internal("GL_ARB_shader_texture_lod"); + } + + if (op == "textureLodOffset" || op == "textureProjLodOffset") + { + if (is_legacy_es()) + SPIRV_CROSS_THROW(join(op, " not allowed in legacy ES")); + + require_extension_internal("GL_EXT_gpu_shader4"); + } + + // GLES has very limited support for shadow samplers. + // Basically shadow2D and shadow2DProj work through EXT_shadow_samplers, + // everything else can just throw + bool is_comparison = is_depth_image(imgtype, tex); + if (is_comparison && is_legacy_es()) + { + if (op == "texture" || op == "textureProj") + require_extension_internal("GL_EXT_shadow_samplers"); + else + SPIRV_CROSS_THROW(join(op, " not allowed on depth samplers in legacy ES")); + + if (imgtype.image.dim == spv::DimCube) + return "shadowCubeNV"; + } + + if (op == "textureSize") + { + if (is_legacy_es()) + SPIRV_CROSS_THROW("textureSize not supported in legacy ES"); + if (is_comparison) + SPIRV_CROSS_THROW("textureSize not supported on shadow sampler in legacy GLSL"); + require_extension_internal("GL_EXT_gpu_shader4"); + } + + if (op == "texelFetch" && is_legacy_es()) + SPIRV_CROSS_THROW("texelFetch not supported in legacy ES"); + + bool is_es_and_depth = is_legacy_es() && is_comparison; + std::string type_prefix = is_comparison ? "shadow" : "texture"; + + if (op == "texture") + return is_es_and_depth ? join(type_prefix, type, "EXT") : join(type_prefix, type); + else if (op == "textureLod") + return join(type_prefix, type, legacy_lod_ext ? "LodEXT" : "Lod"); + else if (op == "textureProj") + return join(type_prefix, type, is_es_and_depth ? "ProjEXT" : "Proj"); + else if (op == "textureGrad") + return join(type_prefix, type, is_legacy_es() ? "GradEXT" : is_legacy_desktop() ? "GradARB" : "Grad"); + else if (op == "textureProjLod") + return join(type_prefix, type, legacy_lod_ext ? 
"ProjLodEXT" : "ProjLod"); + else if (op == "textureLodOffset") + return join(type_prefix, type, "LodOffset"); + else if (op == "textureProjGrad") + return join(type_prefix, type, + is_legacy_es() ? "ProjGradEXT" : is_legacy_desktop() ? "ProjGradARB" : "ProjGrad"); + else if (op == "textureProjLodOffset") + return join(type_prefix, type, "ProjLodOffset"); + else if (op == "textureSize") + return join("textureSize", type); + else if (op == "texelFetch") + return join("texelFetch", type); + else + { + SPIRV_CROSS_THROW(join("Unsupported legacy texture op: ", op)); + } +} + +bool CompilerGLSL::to_trivial_mix_op(const SPIRType &type, string &op, uint32_t left, uint32_t right, uint32_t lerp) +{ + auto *cleft = maybe_get(left); + auto *cright = maybe_get(right); + auto &lerptype = expression_type(lerp); + + // If our targets aren't constants, we cannot use construction. + if (!cleft || !cright) + return false; + + // If our targets are spec constants, we cannot use construction. + if (cleft->specialization || cright->specialization) + return false; + + auto &value_type = get(cleft->constant_type); + + if (lerptype.basetype != SPIRType::Boolean) + return false; + if (value_type.basetype == SPIRType::Struct || is_array(value_type)) + return false; + if (!backend.use_constructor_splatting && value_type.vecsize != lerptype.vecsize) + return false; + + // Only valid way in SPIR-V 1.4 to use matrices in select is a scalar select. + // matrix(scalar) constructor fills in diagnonals, so gets messy very quickly. + // Just avoid this case. + if (value_type.columns > 1) + return false; + + // If our bool selects between 0 and 1, we can cast from bool instead, making our trivial constructor. + bool ret = true; + for (uint32_t row = 0; ret && row < value_type.vecsize; row++) + { + switch (type.basetype) + { + case SPIRType::Short: + case SPIRType::UShort: + ret = cleft->scalar_u16(0, row) == 0 && cright->scalar_u16(0, row) == 1; + break; + + case SPIRType::Int: + case SPIRType::UInt: + ret = cleft->scalar(0, row) == 0 && cright->scalar(0, row) == 1; + break; + + case SPIRType::Half: + ret = cleft->scalar_f16(0, row) == 0.0f && cright->scalar_f16(0, row) == 1.0f; + break; + + case SPIRType::Float: + ret = cleft->scalar_f32(0, row) == 0.0f && cright->scalar_f32(0, row) == 1.0f; + break; + + case SPIRType::Double: + ret = cleft->scalar_f64(0, row) == 0.0 && cright->scalar_f64(0, row) == 1.0; + break; + + case SPIRType::Int64: + case SPIRType::UInt64: + ret = cleft->scalar_u64(0, row) == 0 && cright->scalar_u64(0, row) == 1; + break; + + default: + ret = false; + break; + } + } + + if (ret) + op = type_to_glsl_constructor(type); + return ret; +} + +string CompilerGLSL::to_ternary_expression(const SPIRType &restype, uint32_t select, uint32_t true_value, + uint32_t false_value) +{ + string expr; + auto &lerptype = expression_type(select); + + if (lerptype.vecsize == 1) + expr = join(to_enclosed_expression(select), " ? ", to_enclosed_pointer_expression(true_value), " : ", + to_enclosed_pointer_expression(false_value)); + else + { + auto swiz = [this](uint32_t expression, uint32_t i) { return to_extract_component_expression(expression, i); }; + + expr = type_to_glsl_constructor(restype); + expr += "("; + for (uint32_t i = 0; i < restype.vecsize; i++) + { + expr += swiz(select, i); + expr += " ? 
"; + expr += swiz(true_value, i); + expr += " : "; + expr += swiz(false_value, i); + if (i + 1 < restype.vecsize) + expr += ", "; + } + expr += ")"; + } + + return expr; +} + +void CompilerGLSL::emit_mix_op(uint32_t result_type, uint32_t id, uint32_t left, uint32_t right, uint32_t lerp) +{ + auto &lerptype = expression_type(lerp); + auto &restype = get(result_type); + + // If this results in a variable pointer, assume it may be written through. + if (restype.pointer) + { + register_write(left); + register_write(right); + } + + string mix_op; + bool has_boolean_mix = *backend.boolean_mix_function && + ((options.es && options.version >= 310) || (!options.es && options.version >= 450)); + bool trivial_mix = to_trivial_mix_op(restype, mix_op, left, right, lerp); + + // Cannot use boolean mix when the lerp argument is just one boolean, + // fall back to regular trinary statements. + if (lerptype.vecsize == 1) + has_boolean_mix = false; + + // If we can reduce the mix to a simple cast, do so. + // This helps for cases like int(bool), uint(bool) which is implemented with + // OpSelect bool 1 0. + if (trivial_mix) + { + emit_unary_func_op(result_type, id, lerp, mix_op.c_str()); + } + else if (!has_boolean_mix && lerptype.basetype == SPIRType::Boolean) + { + // Boolean mix not supported on desktop without extension. + // Was added in OpenGL 4.5 with ES 3.1 compat. + // + // Could use GL_EXT_shader_integer_mix on desktop at least, + // but Apple doesn't support it. :( + // Just implement it as ternary expressions. + auto expr = to_ternary_expression(get(result_type), lerp, right, left); + emit_op(result_type, id, expr, should_forward(left) && should_forward(right) && should_forward(lerp)); + inherit_expression_dependencies(id, left); + inherit_expression_dependencies(id, right); + inherit_expression_dependencies(id, lerp); + } + else if (lerptype.basetype == SPIRType::Boolean) + emit_trinary_func_op(result_type, id, left, right, lerp, backend.boolean_mix_function); + else + emit_trinary_func_op(result_type, id, left, right, lerp, "mix"); +} + +string CompilerGLSL::to_combined_image_sampler(VariableID image_id, VariableID samp_id) +{ + // Keep track of the array indices we have used to load the image. + // We'll need to use the same array index into the combined image sampler array. + auto image_expr = to_non_uniform_aware_expression(image_id); + string array_expr; + auto array_index = image_expr.find_first_of('['); + if (array_index != string::npos) + array_expr = image_expr.substr(array_index, string::npos); + + auto &args = current_function->arguments; + + // For GLSL and ESSL targets, we must enumerate all possible combinations for sampler2D(texture2D, sampler) and redirect + // all possible combinations into new sampler2D uniforms. + auto *image = maybe_get_backing_variable(image_id); + auto *samp = maybe_get_backing_variable(samp_id); + if (image) + image_id = image->self; + if (samp) + samp_id = samp->self; + + auto image_itr = find_if(begin(args), end(args), + [image_id](const SPIRFunction::Parameter ¶m) { return image_id == param.id; }); + + auto sampler_itr = find_if(begin(args), end(args), + [samp_id](const SPIRFunction::Parameter ¶m) { return samp_id == param.id; }); + + if (image_itr != end(args) || sampler_itr != end(args)) + { + // If any parameter originates from a parameter, we will find it in our argument list. + bool global_image = image_itr == end(args); + bool global_sampler = sampler_itr == end(args); + VariableID iid = global_image ? 
image_id : VariableID(uint32_t(image_itr - begin(args))); + VariableID sid = global_sampler ? samp_id : VariableID(uint32_t(sampler_itr - begin(args))); + + auto &combined = current_function->combined_parameters; + auto itr = find_if(begin(combined), end(combined), [=](const SPIRFunction::CombinedImageSamplerParameter &p) { + return p.global_image == global_image && p.global_sampler == global_sampler && p.image_id == iid && + p.sampler_id == sid; + }); + + if (itr != end(combined)) + return to_expression(itr->id) + array_expr; + else + { + SPIRV_CROSS_THROW("Cannot find mapping for combined sampler parameter, was " + "build_combined_image_samplers() used " + "before compile() was called?"); + } + } + else + { + // For global sampler2D, look directly at the global remapping table. + auto &mapping = combined_image_samplers; + auto itr = find_if(begin(mapping), end(mapping), [image_id, samp_id](const CombinedImageSampler &combined) { + return combined.image_id == image_id && combined.sampler_id == samp_id; + }); + + if (itr != end(combined_image_samplers)) + return to_expression(itr->combined_id) + array_expr; + else + { + SPIRV_CROSS_THROW("Cannot find mapping for combined sampler, was build_combined_image_samplers() used " + "before compile() was called?"); + } + } +} + +bool CompilerGLSL::is_supported_subgroup_op_in_opengl(spv::Op op, const uint32_t *ops) +{ + switch (op) + { + case OpGroupNonUniformElect: + case OpGroupNonUniformBallot: + case OpGroupNonUniformBallotFindLSB: + case OpGroupNonUniformBallotFindMSB: + case OpGroupNonUniformBroadcast: + case OpGroupNonUniformBroadcastFirst: + case OpGroupNonUniformAll: + case OpGroupNonUniformAny: + case OpGroupNonUniformAllEqual: + case OpControlBarrier: + case OpMemoryBarrier: + case OpGroupNonUniformBallotBitCount: + case OpGroupNonUniformBallotBitExtract: + case OpGroupNonUniformInverseBallot: + return true; + case OpGroupNonUniformIAdd: + case OpGroupNonUniformFAdd: + case OpGroupNonUniformIMul: + case OpGroupNonUniformFMul: + { + const GroupOperation operation = static_cast(ops[3]); + if (operation == GroupOperationReduce || operation == GroupOperationInclusiveScan || + operation == GroupOperationExclusiveScan) + { + return true; + } + else + { + return false; + } + } + default: + return false; + } +} + +void CompilerGLSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id) +{ + if (options.vulkan_semantics && combined_image_samplers.empty()) + { + emit_binary_func_op(result_type, result_id, image_id, samp_id, + type_to_glsl(get(result_type), result_id).c_str()); + } + else + { + // Make sure to suppress usage tracking. It is illegal to create temporaries of opaque types. + emit_op(result_type, result_id, to_combined_image_sampler(image_id, samp_id), true, true); + } + + // Make sure to suppress usage tracking and any expression invalidation. + // It is illegal to create temporaries of opaque types. 
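+ // Illustrative result: with Vulkan semantics an OpSampledImage is forwarded as a constructor
+ // expression such as "sampler2D(uTexture, uSampler)" directly into the texture call that
+ // consumes it, never stored in a temporary (names are placeholders).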
+ forwarded_temporaries.erase(result_id); +} + +static inline bool image_opcode_is_sample_no_dref(Op op) +{ + switch (op) + { + case OpImageSampleExplicitLod: + case OpImageSampleImplicitLod: + case OpImageSampleProjExplicitLod: + case OpImageSampleProjImplicitLod: + case OpImageFetch: + case OpImageRead: + case OpImageSparseSampleExplicitLod: + case OpImageSparseSampleImplicitLod: + case OpImageSparseSampleProjExplicitLod: + case OpImageSparseSampleProjImplicitLod: + case OpImageSparseFetch: + case OpImageSparseRead: + return true; + + default: + return false; + } +} + +void CompilerGLSL::emit_sparse_feedback_temporaries(uint32_t result_type_id, uint32_t id, uint32_t &feedback_id, + uint32_t &texel_id) +{ + // Need to allocate two temporaries. + if (options.es) + SPIRV_CROSS_THROW("Sparse texture feedback is not supported on ESSL."); + require_extension_internal("GL_ARB_sparse_texture2"); + + auto &temps = extra_sub_expressions[id]; + if (temps == 0) + temps = ir.increase_bound_by(2); + + feedback_id = temps + 0; + texel_id = temps + 1; + + auto &return_type = get(result_type_id); + if (return_type.basetype != SPIRType::Struct || return_type.member_types.size() != 2) + SPIRV_CROSS_THROW("Invalid return type for sparse feedback."); + emit_uninitialized_temporary(return_type.member_types[0], feedback_id); + emit_uninitialized_temporary(return_type.member_types[1], texel_id); +} + +uint32_t CompilerGLSL::get_sparse_feedback_texel_id(uint32_t id) const +{ + auto itr = extra_sub_expressions.find(id); + if (itr == extra_sub_expressions.end()) + return 0; + else + return itr->second + 1; +} + +void CompilerGLSL::emit_texture_op(const Instruction &i, bool sparse) +{ + auto *ops = stream(i); + auto op = static_cast(i.op); + + SmallVector inherited_expressions; + + uint32_t result_type_id = ops[0]; + uint32_t id = ops[1]; + auto &return_type = get(result_type_id); + + uint32_t sparse_code_id = 0; + uint32_t sparse_texel_id = 0; + if (sparse) + emit_sparse_feedback_temporaries(result_type_id, id, sparse_code_id, sparse_texel_id); + + bool forward = false; + string expr = to_texture_op(i, sparse, &forward, inherited_expressions); + + if (sparse) + { + statement(to_expression(sparse_code_id), " = ", expr, ";"); + expr = join(type_to_glsl(return_type), "(", to_expression(sparse_code_id), ", ", to_expression(sparse_texel_id), + ")"); + forward = true; + inherited_expressions.clear(); + } + + emit_op(result_type_id, id, expr, forward); + for (auto &inherit : inherited_expressions) + inherit_expression_dependencies(id, inherit); + + // Do not register sparse ops as control dependent as they are always lowered to a temporary. 
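+ // Rationale: implicit-LOD sampling relies on screen-space derivatives, so these expressions
+ // must not be hoisted across non-uniform control flow; hence they are marked control dependent.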
+ switch (op) + { + case OpImageSampleDrefImplicitLod: + case OpImageSampleImplicitLod: + case OpImageSampleProjImplicitLod: + case OpImageSampleProjDrefImplicitLod: + register_control_dependent_expression(id); + break; + + default: + break; + } +} + +std::string CompilerGLSL::to_texture_op(const Instruction &i, bool sparse, bool *forward, + SmallVector &inherited_expressions) +{ + auto *ops = stream(i); + auto op = static_cast(i.op); + uint32_t length = i.length; + + uint32_t result_type_id = ops[0]; + VariableID img = ops[2]; + uint32_t coord = ops[3]; + uint32_t dref = 0; + uint32_t comp = 0; + bool gather = false; + bool proj = false; + bool fetch = false; + bool nonuniform_expression = false; + const uint32_t *opt = nullptr; + + auto &result_type = get(result_type_id); + + inherited_expressions.push_back(coord); + if (has_decoration(img, DecorationNonUniform) && !maybe_get_backing_variable(img)) + nonuniform_expression = true; + + switch (op) + { + case OpImageSampleDrefImplicitLod: + case OpImageSampleDrefExplicitLod: + case OpImageSparseSampleDrefImplicitLod: + case OpImageSparseSampleDrefExplicitLod: + dref = ops[4]; + opt = &ops[5]; + length -= 5; + break; + + case OpImageSampleProjDrefImplicitLod: + case OpImageSampleProjDrefExplicitLod: + case OpImageSparseSampleProjDrefImplicitLod: + case OpImageSparseSampleProjDrefExplicitLod: + dref = ops[4]; + opt = &ops[5]; + length -= 5; + proj = true; + break; + + case OpImageDrefGather: + case OpImageSparseDrefGather: + dref = ops[4]; + opt = &ops[5]; + length -= 5; + gather = true; + if (options.es && options.version < 310) + SPIRV_CROSS_THROW("textureGather requires ESSL 310."); + else if (!options.es && options.version < 400) + SPIRV_CROSS_THROW("textureGather with depth compare requires GLSL 400."); + break; + + case OpImageGather: + case OpImageSparseGather: + comp = ops[4]; + opt = &ops[5]; + length -= 5; + gather = true; + if (options.es && options.version < 310) + SPIRV_CROSS_THROW("textureGather requires ESSL 310."); + else if (!options.es && options.version < 400) + { + if (!expression_is_constant_null(comp)) + SPIRV_CROSS_THROW("textureGather with component requires GLSL 400."); + require_extension_internal("GL_ARB_texture_gather"); + } + break; + + case OpImageFetch: + case OpImageSparseFetch: + case OpImageRead: // Reads == fetches in Metal (other langs will not get here) + opt = &ops[4]; + length -= 4; + fetch = true; + break; + + case OpImageSampleProjImplicitLod: + case OpImageSampleProjExplicitLod: + case OpImageSparseSampleProjImplicitLod: + case OpImageSparseSampleProjExplicitLod: + opt = &ops[4]; + length -= 4; + proj = true; + break; + + default: + opt = &ops[4]; + length -= 4; + break; + } + + // Bypass pointers because we need the real image struct + auto &type = expression_type(img); + auto &imgtype = get(type.self); + + uint32_t coord_components = 0; + switch (imgtype.image.dim) + { + case spv::Dim1D: + coord_components = 1; + break; + case spv::Dim2D: + coord_components = 2; + break; + case spv::Dim3D: + coord_components = 3; + break; + case spv::DimCube: + coord_components = 3; + break; + case spv::DimBuffer: + coord_components = 1; + break; + default: + coord_components = 2; + break; + } + + if (dref) + inherited_expressions.push_back(dref); + + if (proj) + coord_components++; + if (imgtype.image.arrayed) + coord_components++; + + uint32_t bias = 0; + uint32_t lod = 0; + uint32_t grad_x = 0; + uint32_t grad_y = 0; + uint32_t coffset = 0; + uint32_t offset = 0; + uint32_t coffsets = 0; + uint32_t sample = 0; + 
uint32_t minlod = 0; + uint32_t flags = 0; + + if (length) + { + flags = *opt++; + length--; + } + + auto test = [&](uint32_t &v, uint32_t flag) { + if (length && (flags & flag)) + { + v = *opt++; + inherited_expressions.push_back(v); + length--; + } + }; + + test(bias, ImageOperandsBiasMask); + test(lod, ImageOperandsLodMask); + test(grad_x, ImageOperandsGradMask); + test(grad_y, ImageOperandsGradMask); + test(coffset, ImageOperandsConstOffsetMask); + test(offset, ImageOperandsOffsetMask); + test(coffsets, ImageOperandsConstOffsetsMask); + test(sample, ImageOperandsSampleMask); + test(minlod, ImageOperandsMinLodMask); + + TextureFunctionBaseArguments base_args = {}; + base_args.img = img; + base_args.imgtype = &imgtype; + base_args.is_fetch = fetch != 0; + base_args.is_gather = gather != 0; + base_args.is_proj = proj != 0; + + string expr; + TextureFunctionNameArguments name_args = {}; + + name_args.base = base_args; + name_args.has_array_offsets = coffsets != 0; + name_args.has_offset = coffset != 0 || offset != 0; + name_args.has_grad = grad_x != 0 || grad_y != 0; + name_args.has_dref = dref != 0; + name_args.is_sparse_feedback = sparse; + name_args.has_min_lod = minlod != 0; + name_args.lod = lod; + expr += to_function_name(name_args); + expr += "("; + + uint32_t sparse_texel_id = 0; + if (sparse) + sparse_texel_id = get_sparse_feedback_texel_id(ops[1]); + + TextureFunctionArguments args = {}; + args.base = base_args; + args.coord = coord; + args.coord_components = coord_components; + args.dref = dref; + args.grad_x = grad_x; + args.grad_y = grad_y; + args.lod = lod; + args.has_array_offsets = coffsets != 0; + + if (coffsets) + args.offset = coffsets; + else if (coffset) + args.offset = coffset; + else + args.offset = offset; + + args.bias = bias; + args.component = comp; + args.sample = sample; + args.sparse_texel = sparse_texel_id; + args.min_lod = minlod; + args.nonuniform_expression = nonuniform_expression; + expr += to_function_args(args, forward); + expr += ")"; + + // texture(samplerXShadow) returns float. shadowX() returns vec4, but only in desktop GLSL. Swizzle here. + if (is_legacy() && !options.es && is_depth_image(imgtype, img)) + expr += ".r"; + + // Sampling from a texture which was deduced to be a depth image, might actually return 1 component here. + // Remap back to 4 components as sampling opcodes expect. + if (backend.comparison_image_samples_scalar && image_opcode_is_sample_no_dref(op)) + { + bool image_is_depth = false; + const auto *combined = maybe_get(img); + VariableID image_id = combined ? combined->image : img; + + if (combined && is_depth_image(imgtype, combined->image)) + image_is_depth = true; + else if (is_depth_image(imgtype, img)) + image_is_depth = true; + + // We must also check the backing variable for the image. + // We might have loaded an OpImage, and used that handle for two different purposes. + // Once with comparison, once without. + auto *image_variable = maybe_get_backing_variable(image_id); + if (image_variable && is_depth_image(get(image_variable->basetype), image_variable->self)) + image_is_depth = true; + + if (image_is_depth) + expr = remap_swizzle(result_type, 1, expr); + } + + if (!sparse && !backend.support_small_type_sampling_result && result_type.width < 32) + { + // Just value cast (narrowing) to expected type since we cannot rely on narrowing to work automatically. + // Hopefully compiler picks this up and converts the texturing instruction to the appropriate precision. 
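+ // E.g. a 16-bit float result would wrap the call roughly as "f16vec4(texture(uTex, uv))",
+ // assuming the target exposes such a type; this is only a narrowing value cast.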
+ expr = join(type_to_glsl_constructor(result_type), "(", expr, ")"); + } + + // Deals with reads from MSL. We might need to downconvert to fewer components. + if (op == OpImageRead) + expr = remap_swizzle(result_type, 4, expr); + + return expr; +} + +bool CompilerGLSL::expression_is_constant_null(uint32_t id) const +{ + auto *c = maybe_get(id); + if (!c) + return false; + return c->constant_is_null(); +} + +bool CompilerGLSL::expression_is_non_value_type_array(uint32_t ptr) +{ + auto &type = expression_type(ptr); + if (!is_array(get_pointee_type(type))) + return false; + + if (!backend.array_is_value_type) + return true; + + auto *var = maybe_get_backing_variable(ptr); + if (!var) + return false; + + auto &backed_type = get(var->basetype); + return !backend.array_is_value_type_in_buffer_blocks && backed_type.basetype == SPIRType::Struct && + has_member_decoration(backed_type.self, 0, DecorationOffset); +} + +// Returns the function name for a texture sampling function for the specified image and sampling characteristics. +// For some subclasses, the function is a method on the specified image. +string CompilerGLSL::to_function_name(const TextureFunctionNameArguments &args) +{ + if (args.has_min_lod) + { + if (options.es) + SPIRV_CROSS_THROW("Sparse residency is not supported in ESSL."); + require_extension_internal("GL_ARB_sparse_texture_clamp"); + } + + string fname; + auto &imgtype = *args.base.imgtype; + VariableID tex = args.base.img; + + // textureLod on sampler2DArrayShadow and samplerCubeShadow does not exist in GLSL for some reason. + // To emulate this, we will have to use textureGrad with a constant gradient of 0. + // The workaround will assert that the LOD is in fact constant 0, or we cannot emit correct code. + // This happens for HLSL SampleCmpLevelZero on Texture2DArray and TextureCube. + bool workaround_lod_array_shadow_as_grad = false; + if (((imgtype.image.arrayed && imgtype.image.dim == Dim2D) || imgtype.image.dim == DimCube) && + is_depth_image(imgtype, tex) && args.lod && !args.base.is_fetch) + { + if (!expression_is_constant_null(args.lod)) + { + SPIRV_CROSS_THROW("textureLod on sampler2DArrayShadow is not constant 0.0. This cannot be " + "expressed in GLSL."); + } + workaround_lod_array_shadow_as_grad = true; + } + + if (args.is_sparse_feedback) + fname += "sparse"; + + if (args.base.is_fetch) + fname += args.is_sparse_feedback ? "TexelFetch" : "texelFetch"; + else + { + fname += args.is_sparse_feedback ? "Texture" : "texture"; + + if (args.base.is_gather) + fname += "Gather"; + if (args.has_array_offsets) + fname += "Offsets"; + if (args.base.is_proj) + fname += "Proj"; + if (args.has_grad || workaround_lod_array_shadow_as_grad) + fname += "Grad"; + if (args.lod != 0 && !workaround_lod_array_shadow_as_grad) + fname += "Lod"; + } + + if (args.has_offset) + fname += "Offset"; + + if (args.has_min_lod) + fname += "Clamp"; + + if (args.is_sparse_feedback || args.has_min_lod) + fname += "ARB"; + + return (is_legacy() && !args.base.is_gather) ? legacy_tex_op(fname, imgtype, tex) : fname; +} + +std::string CompilerGLSL::convert_separate_image_to_expression(uint32_t id) +{ + auto *var = maybe_get_backing_variable(id); + + // If we are fetching from a plain OpTypeImage, we must combine with a dummy sampler in GLSL. + // In Vulkan GLSL, we can make use of the newer GL_EXT_samplerless_texture_functions. 
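+ // Illustrative expansion: with a dummy sampler bound, a texelFetch on a separate texture2D
+ // becomes e.g. "texelFetch(sampler2D(uTexture, uDummySampler), coord, 0)"; with
+ // GL_EXT_samplerless_texture_functions the separate texture is passed through as-is
+ // (names are placeholders).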
+ if (var) + { + auto &type = get(var->basetype); + if (type.basetype == SPIRType::Image && type.image.sampled == 1 && type.image.dim != DimBuffer) + { + if (options.vulkan_semantics) + { + if (dummy_sampler_id) + { + // Don't need to consider Shadow state since the dummy sampler is always non-shadow. + auto sampled_type = type; + sampled_type.basetype = SPIRType::SampledImage; + return join(type_to_glsl(sampled_type), "(", to_non_uniform_aware_expression(id), ", ", + to_expression(dummy_sampler_id), ")"); + } + else + { + // Newer glslang supports this extension to deal with texture2D as argument to texture functions. + require_extension_internal("GL_EXT_samplerless_texture_functions"); + } + } + else + { + if (!dummy_sampler_id) + SPIRV_CROSS_THROW("Cannot find dummy sampler ID. Was " + "build_dummy_sampler_for_combined_images() called?"); + + return to_combined_image_sampler(id, dummy_sampler_id); + } + } + } + + return to_non_uniform_aware_expression(id); +} + +// Returns the function args for a texture sampling function for the specified image and sampling characteristics. +string CompilerGLSL::to_function_args(const TextureFunctionArguments &args, bool *p_forward) +{ + VariableID img = args.base.img; + auto &imgtype = *args.base.imgtype; + + string farg_str; + if (args.base.is_fetch) + farg_str = convert_separate_image_to_expression(img); + else + farg_str = to_non_uniform_aware_expression(img); + + if (args.nonuniform_expression && farg_str.find_first_of('[') != string::npos) + { + // Only emit nonuniformEXT() wrapper if the underlying expression is arrayed in some way. + farg_str = join(backend.nonuniform_qualifier, "(", farg_str, ")"); + } + + bool swizz_func = backend.swizzle_is_function; + auto swizzle = [swizz_func](uint32_t comps, uint32_t in_comps) -> const char * { + if (comps == in_comps) + return ""; + + switch (comps) + { + case 1: + return ".x"; + case 2: + return swizz_func ? ".xy()" : ".xy"; + case 3: + return swizz_func ? ".xyz()" : ".xyz"; + default: + return ""; + } + }; + + bool forward = should_forward(args.coord); + + // The IR can give us more components than we need, so chop them off as needed. + auto swizzle_expr = swizzle(args.coord_components, expression_type(args.coord).vecsize); + // Only enclose the UV expression if needed. + auto coord_expr = + (*swizzle_expr == '\0') ? to_expression(args.coord) : (to_enclosed_expression(args.coord) + swizzle_expr); + + // texelFetch only takes int, not uint. + auto &coord_type = expression_type(args.coord); + if (coord_type.basetype == SPIRType::UInt) + { + auto expected_type = coord_type; + expected_type.vecsize = args.coord_components; + expected_type.basetype = SPIRType::Int; + coord_expr = bitcast_expression(expected_type, coord_type.basetype, coord_expr); + } + + // textureLod on sampler2DArrayShadow and samplerCubeShadow does not exist in GLSL for some reason. + // To emulate this, we will have to use textureGrad with a constant gradient of 0. + // The workaround will assert that the LOD is in fact constant 0, or we cannot emit correct code. + // This happens for HLSL SampleCmpLevelZero on Texture2DArray and TextureCube. + bool workaround_lod_array_shadow_as_grad = + ((imgtype.image.arrayed && imgtype.image.dim == Dim2D) || imgtype.image.dim == DimCube) && + is_depth_image(imgtype, img) && args.lod != 0 && !args.base.is_fetch; + + if (args.dref) + { + forward = forward && should_forward(args.dref); + + // SPIR-V splits dref and coordinate. 
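+ // GLSL mostly wants them merged: e.g. sampler2DShadow is sampled as "texture(s, vec3(uv, dref))",
+ // which is what the final else branch below composes; gather and 4-component coordinates keep
+ // the split form.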
+ if (args.base.is_gather || + args.coord_components == 4) // GLSL also splits the arguments in two. Same for textureGather. + { + farg_str += ", "; + farg_str += to_expression(args.coord); + farg_str += ", "; + farg_str += to_expression(args.dref); + } + else if (args.base.is_proj) + { + // Have to reshuffle so we get vec4(coord, dref, proj), special case. + // Other shading languages splits up the arguments for coord and compare value like SPIR-V. + // The coordinate type for textureProj shadow is always vec4 even for sampler1DShadow. + farg_str += ", vec4("; + + if (imgtype.image.dim == Dim1D) + { + // Could reuse coord_expr, but we will mess up the temporary usage checking. + farg_str += to_enclosed_expression(args.coord) + ".x"; + farg_str += ", "; + farg_str += "0.0, "; + farg_str += to_expression(args.dref); + farg_str += ", "; + farg_str += to_enclosed_expression(args.coord) + ".y)"; + } + else if (imgtype.image.dim == Dim2D) + { + // Could reuse coord_expr, but we will mess up the temporary usage checking. + farg_str += to_enclosed_expression(args.coord) + (swizz_func ? ".xy()" : ".xy"); + farg_str += ", "; + farg_str += to_expression(args.dref); + farg_str += ", "; + farg_str += to_enclosed_expression(args.coord) + ".z)"; + } + else + SPIRV_CROSS_THROW("Invalid type for textureProj with shadow."); + } + else + { + // Create a composite which merges coord/dref into a single vector. + auto type = expression_type(args.coord); + type.vecsize = args.coord_components + 1; + if (imgtype.image.dim == Dim1D && options.es) + type.vecsize++; + farg_str += ", "; + farg_str += type_to_glsl_constructor(type); + farg_str += "("; + + if (imgtype.image.dim == Dim1D && options.es) + { + if (imgtype.image.arrayed) + { + farg_str += enclose_expression(coord_expr) + ".x"; + farg_str += ", 0.0, "; + farg_str += enclose_expression(coord_expr) + ".y"; + } + else + { + farg_str += coord_expr; + farg_str += ", 0.0"; + } + } + else + farg_str += coord_expr; + + farg_str += ", "; + farg_str += to_expression(args.dref); + farg_str += ")"; + } + } + else + { + if (imgtype.image.dim == Dim1D && options.es) + { + // Have to fake a second coordinate. + if (type_is_floating_point(coord_type)) + { + // Cannot mix proj and array. + if (imgtype.image.arrayed || args.base.is_proj) + { + coord_expr = join("vec3(", enclose_expression(coord_expr), ".x, 0.0, ", + enclose_expression(coord_expr), ".y)"); + } + else + coord_expr = join("vec2(", coord_expr, ", 0.0)"); + } + else + { + if (imgtype.image.arrayed) + { + coord_expr = join("ivec3(", enclose_expression(coord_expr), + ".x, 0, ", + enclose_expression(coord_expr), ".y)"); + } + else + coord_expr = join("ivec2(", coord_expr, ", 0)"); + } + } + + farg_str += ", "; + farg_str += coord_expr; + } + + if (args.grad_x || args.grad_y) + { + forward = forward && should_forward(args.grad_x); + forward = forward && should_forward(args.grad_y); + farg_str += ", "; + farg_str += to_expression(args.grad_x); + farg_str += ", "; + farg_str += to_expression(args.grad_y); + } + + if (args.lod) + { + if (workaround_lod_array_shadow_as_grad) + { + // Implement textureGrad() instead. LOD == 0.0 is implemented as gradient of 0.0. + // Implementing this as plain texture() is not safe on some implementations. 
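+ // E.g. an HLSL SampleCmpLevelZero on a Texture2DArray ends up roughly as
+ // "textureGrad(uShadowArray, vec4(uv, layer, dref), vec2(0.0), vec2(0.0))" instead of a
+ // textureLod call that does not exist for sampler2DArrayShadow (names are placeholders).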
+ if (imgtype.image.dim == Dim2D) + farg_str += ", vec2(0.0), vec2(0.0)"; + else if (imgtype.image.dim == DimCube) + farg_str += ", vec3(0.0), vec3(0.0)"; + } + else + { + forward = forward && should_forward(args.lod); + farg_str += ", "; + + // Lod expression for TexelFetch in GLSL must be int, and only int. + if (args.base.is_fetch && imgtype.image.dim != DimBuffer && !imgtype.image.ms) + farg_str += bitcast_expression(SPIRType::Int, args.lod); + else + farg_str += to_expression(args.lod); + } + } + else if (args.base.is_fetch && imgtype.image.dim != DimBuffer && !imgtype.image.ms) + { + // Lod argument is optional in OpImageFetch, but we require a LOD value, pick 0 as the default. + farg_str += ", 0"; + } + + if (args.offset) + { + forward = forward && should_forward(args.offset); + farg_str += ", "; + farg_str += bitcast_expression(SPIRType::Int, args.offset); + } + + if (args.sample) + { + farg_str += ", "; + farg_str += bitcast_expression(SPIRType::Int, args.sample); + } + + if (args.min_lod) + { + farg_str += ", "; + farg_str += to_expression(args.min_lod); + } + + if (args.sparse_texel) + { + // Sparse texel output parameter comes after everything else, except it's before the optional, component/bias arguments. + farg_str += ", "; + farg_str += to_expression(args.sparse_texel); + } + + if (args.bias) + { + forward = forward && should_forward(args.bias); + farg_str += ", "; + farg_str += to_expression(args.bias); + } + + if (args.component && !expression_is_constant_null(args.component)) + { + forward = forward && should_forward(args.component); + farg_str += ", "; + farg_str += bitcast_expression(SPIRType::Int, args.component); + } + + *p_forward = forward; + + return farg_str; +} + +Op CompilerGLSL::get_remapped_spirv_op(Op op) const +{ + if (options.relax_nan_checks) + { + switch (op) + { + case OpFUnordLessThan: + op = OpFOrdLessThan; + break; + case OpFUnordLessThanEqual: + op = OpFOrdLessThanEqual; + break; + case OpFUnordGreaterThan: + op = OpFOrdGreaterThan; + break; + case OpFUnordGreaterThanEqual: + op = OpFOrdGreaterThanEqual; + break; + case OpFUnordEqual: + op = OpFOrdEqual; + break; + case OpFOrdNotEqual: + op = OpFUnordNotEqual; + break; + + default: + break; + } + } + + return op; +} + +GLSLstd450 CompilerGLSL::get_remapped_glsl_op(GLSLstd450 std450_op) const +{ + // Relax to non-NaN aware opcodes. + if (options.relax_nan_checks) + { + switch (std450_op) + { + case GLSLstd450NClamp: + std450_op = GLSLstd450FClamp; + break; + case GLSLstd450NMin: + std450_op = GLSLstd450FMin; + break; + case GLSLstd450NMax: + std450_op = GLSLstd450FMax; + break; + default: + break; + } + } + + return std450_op; +} + +void CompilerGLSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t length) +{ + auto op = static_cast(eop); + + if (is_legacy() && is_unsigned_glsl_opcode(op)) + SPIRV_CROSS_THROW("Unsigned integers are not supported on legacy GLSL targets."); + + // If we need to do implicit bitcasts, make sure we do it with the correct type. 
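+ // Illustrative case: GLSLstd450SMax on operands that were declared as uint is emitted roughly
+ // as "uint(max(int(a), int(b)))" so the opcode's signed semantics are preserved.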
+ uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, length); + auto int_type = to_signed_basetype(integer_width); + auto uint_type = to_unsigned_basetype(integer_width); + + op = get_remapped_glsl_op(op); + + switch (op) + { + // FP fiddling + case GLSLstd450Round: + if (!is_legacy()) + emit_unary_func_op(result_type, id, args[0], "round"); + else + { + auto op0 = to_enclosed_expression(args[0]); + auto &op0_type = expression_type(args[0]); + auto expr = join("floor(", op0, " + ", type_to_glsl_constructor(op0_type), "(0.5))"); + bool forward = should_forward(args[0]); + emit_op(result_type, id, expr, forward); + inherit_expression_dependencies(id, args[0]); + } + break; + + case GLSLstd450RoundEven: + if (!is_legacy()) + emit_unary_func_op(result_type, id, args[0], "roundEven"); + else if (!options.es) + { + // This extension provides round() with round-to-even semantics. + require_extension_internal("GL_EXT_gpu_shader4"); + emit_unary_func_op(result_type, id, args[0], "round"); + } + else + SPIRV_CROSS_THROW("roundEven supported only in ESSL 300."); + break; + + case GLSLstd450Trunc: + if (!is_legacy()) + emit_unary_func_op(result_type, id, args[0], "trunc"); + else + { + // Implement by value-casting to int and back. + bool forward = should_forward(args[0]); + auto op0 = to_unpacked_expression(args[0]); + auto &op0_type = expression_type(args[0]); + auto via_type = op0_type; + via_type.basetype = SPIRType::Int; + auto expr = join(type_to_glsl(op0_type), "(", type_to_glsl(via_type), "(", op0, "))"); + emit_op(result_type, id, expr, forward); + inherit_expression_dependencies(id, args[0]); + } + break; + + case GLSLstd450SAbs: + emit_unary_func_op_cast(result_type, id, args[0], "abs", int_type, int_type); + break; + case GLSLstd450FAbs: + emit_unary_func_op(result_type, id, args[0], "abs"); + break; + case GLSLstd450SSign: + emit_unary_func_op_cast(result_type, id, args[0], "sign", int_type, int_type); + break; + case GLSLstd450FSign: + emit_unary_func_op(result_type, id, args[0], "sign"); + break; + case GLSLstd450Floor: + emit_unary_func_op(result_type, id, args[0], "floor"); + break; + case GLSLstd450Ceil: + emit_unary_func_op(result_type, id, args[0], "ceil"); + break; + case GLSLstd450Fract: + emit_unary_func_op(result_type, id, args[0], "fract"); + break; + case GLSLstd450Radians: + emit_unary_func_op(result_type, id, args[0], "radians"); + break; + case GLSLstd450Degrees: + emit_unary_func_op(result_type, id, args[0], "degrees"); + break; + case GLSLstd450Fma: + if ((!options.es && options.version < 400) || (options.es && options.version < 320)) + { + auto expr = join(to_enclosed_expression(args[0]), " * ", to_enclosed_expression(args[1]), " + ", + to_enclosed_expression(args[2])); + + emit_op(result_type, id, expr, + should_forward(args[0]) && should_forward(args[1]) && should_forward(args[2])); + for (uint32_t i = 0; i < 3; i++) + inherit_expression_dependencies(id, args[i]); + } + else + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "fma"); + break; + + case GLSLstd450Modf: + register_call_out_argument(args[1]); + if (!is_legacy()) + { + forced_temporaries.insert(id); + emit_binary_func_op(result_type, id, args[0], args[1], "modf"); + } + else + { + //NB. 
legacy GLSL doesn't have trunc() either, so we do a value cast + auto &op1_type = expression_type(args[1]); + auto via_type = op1_type; + via_type.basetype = SPIRType::Int; + statement(to_expression(args[1]), " = ", + type_to_glsl(op1_type), "(", type_to_glsl(via_type), + "(", to_expression(args[0]), "));"); + emit_binary_op(result_type, id, args[0], args[1], "-"); + } + break; + + case GLSLstd450ModfStruct: + { + auto &type = get(result_type); + emit_uninitialized_temporary_expression(result_type, id); + if (!is_legacy()) + { + statement(to_expression(id), ".", to_member_name(type, 0), " = ", "modf(", to_expression(args[0]), ", ", + to_expression(id), ".", to_member_name(type, 1), ");"); + } + else + { + //NB. legacy GLSL doesn't have trunc() either, so we do a value cast + auto &op0_type = expression_type(args[0]); + auto via_type = op0_type; + via_type.basetype = SPIRType::Int; + statement(to_expression(id), ".", to_member_name(type, 1), " = ", type_to_glsl(op0_type), + "(", type_to_glsl(via_type), "(", to_expression(args[0]), "));"); + statement(to_expression(id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(args[0]), " - ", + to_expression(id), ".", to_member_name(type, 1), ";"); + } + break; + } + + // Minmax + case GLSLstd450UMin: + emit_binary_func_op_cast(result_type, id, args[0], args[1], "min", uint_type, false); + break; + + case GLSLstd450SMin: + emit_binary_func_op_cast(result_type, id, args[0], args[1], "min", int_type, false); + break; + + case GLSLstd450FMin: + emit_binary_func_op(result_type, id, args[0], args[1], "min"); + break; + + case GLSLstd450FMax: + emit_binary_func_op(result_type, id, args[0], args[1], "max"); + break; + + case GLSLstd450UMax: + emit_binary_func_op_cast(result_type, id, args[0], args[1], "max", uint_type, false); + break; + + case GLSLstd450SMax: + emit_binary_func_op_cast(result_type, id, args[0], args[1], "max", int_type, false); + break; + + case GLSLstd450FClamp: + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp"); + break; + + case GLSLstd450UClamp: + emit_trinary_func_op_cast(result_type, id, args[0], args[1], args[2], "clamp", uint_type); + break; + + case GLSLstd450SClamp: + emit_trinary_func_op_cast(result_type, id, args[0], args[1], args[2], "clamp", int_type); + break; + + // Trig + case GLSLstd450Sin: + emit_unary_func_op(result_type, id, args[0], "sin"); + break; + case GLSLstd450Cos: + emit_unary_func_op(result_type, id, args[0], "cos"); + break; + case GLSLstd450Tan: + emit_unary_func_op(result_type, id, args[0], "tan"); + break; + case GLSLstd450Asin: + emit_unary_func_op(result_type, id, args[0], "asin"); + break; + case GLSLstd450Acos: + emit_unary_func_op(result_type, id, args[0], "acos"); + break; + case GLSLstd450Atan: + emit_unary_func_op(result_type, id, args[0], "atan"); + break; + case GLSLstd450Sinh: + if (!is_legacy()) + emit_unary_func_op(result_type, id, args[0], "sinh"); + else + { + bool forward = should_forward(args[0]); + auto expr = join("(exp(", to_expression(args[0]), ") - exp(-", to_enclosed_expression(args[0]), ")) * 0.5"); + emit_op(result_type, id, expr, forward); + inherit_expression_dependencies(id, args[0]); + } + break; + case GLSLstd450Cosh: + if (!is_legacy()) + emit_unary_func_op(result_type, id, args[0], "cosh"); + else + { + bool forward = should_forward(args[0]); + auto expr = join("(exp(", to_expression(args[0]), ") + exp(-", to_enclosed_expression(args[0]), ")) * 0.5"); + emit_op(result_type, id, expr, forward); + inherit_expression_dependencies(id, 
args[0]); + } + break; + case GLSLstd450Tanh: + if (!is_legacy()) + emit_unary_func_op(result_type, id, args[0], "tanh"); + else + { + // Create temporaries to store the result of exp(arg) and exp(-arg). + uint32_t &ids = extra_sub_expressions[id]; + if (!ids) + { + ids = ir.increase_bound_by(2); + + // Inherit precision qualifier (legacy has no NoContraction). + if (has_decoration(id, DecorationRelaxedPrecision)) + { + set_decoration(ids, DecorationRelaxedPrecision); + set_decoration(ids + 1, DecorationRelaxedPrecision); + } + } + uint32_t epos_id = ids; + uint32_t eneg_id = ids + 1; + + emit_op(result_type, epos_id, join("exp(", to_expression(args[0]), ")"), false); + emit_op(result_type, eneg_id, join("exp(-", to_enclosed_expression(args[0]), ")"), false); + inherit_expression_dependencies(epos_id, args[0]); + inherit_expression_dependencies(eneg_id, args[0]); + + auto expr = join("(", to_enclosed_expression(epos_id), " - ", to_enclosed_expression(eneg_id), ") / " + "(", to_enclosed_expression(epos_id), " + ", to_enclosed_expression(eneg_id), ")"); + emit_op(result_type, id, expr, true); + inherit_expression_dependencies(id, epos_id); + inherit_expression_dependencies(id, eneg_id); + } + break; + case GLSLstd450Asinh: + if (!is_legacy()) + emit_unary_func_op(result_type, id, args[0], "asinh"); + else + emit_emulated_ahyper_op(result_type, id, args[0], GLSLstd450Asinh); + break; + case GLSLstd450Acosh: + if (!is_legacy()) + emit_unary_func_op(result_type, id, args[0], "acosh"); + else + emit_emulated_ahyper_op(result_type, id, args[0], GLSLstd450Acosh); + break; + case GLSLstd450Atanh: + if (!is_legacy()) + emit_unary_func_op(result_type, id, args[0], "atanh"); + else + emit_emulated_ahyper_op(result_type, id, args[0], GLSLstd450Atanh); + break; + case GLSLstd450Atan2: + emit_binary_func_op(result_type, id, args[0], args[1], "atan"); + break; + + // Exponentials + case GLSLstd450Pow: + emit_binary_func_op(result_type, id, args[0], args[1], "pow"); + break; + case GLSLstd450Exp: + emit_unary_func_op(result_type, id, args[0], "exp"); + break; + case GLSLstd450Log: + emit_unary_func_op(result_type, id, args[0], "log"); + break; + case GLSLstd450Exp2: + emit_unary_func_op(result_type, id, args[0], "exp2"); + break; + case GLSLstd450Log2: + emit_unary_func_op(result_type, id, args[0], "log2"); + break; + case GLSLstd450Sqrt: + emit_unary_func_op(result_type, id, args[0], "sqrt"); + break; + case GLSLstd450InverseSqrt: + emit_unary_func_op(result_type, id, args[0], "inversesqrt"); + break; + + // Matrix math + case GLSLstd450Determinant: + { + // No need to transpose - it doesn't affect the determinant + auto *e = maybe_get(args[0]); + bool old_transpose = e && e->need_transpose; + if (old_transpose) + e->need_transpose = false; + + if (options.version < 150) // also matches ES 100 + { + auto &type = expression_type(args[0]); + assert(type.vecsize >= 2 && type.vecsize <= 4); + assert(type.vecsize == type.columns); + + // ARB_gpu_shader_fp64 needs GLSL 150, other types are not valid + if (type.basetype != SPIRType::Float) + SPIRV_CROSS_THROW("Unsupported type for matrix determinant"); + + bool relaxed = has_decoration(id, DecorationRelaxedPrecision); + require_polyfill(static_cast(PolyfillDeterminant2x2 << (type.vecsize - 2)), + relaxed); + emit_unary_func_op(result_type, id, args[0], + (options.es && relaxed) ? 
"spvDeterminantMP" : "spvDeterminant"); + } + else + emit_unary_func_op(result_type, id, args[0], "determinant"); + + if (old_transpose) + e->need_transpose = true; + break; + } + + case GLSLstd450MatrixInverse: + { + // The inverse of the transpose is the same as the transpose of + // the inverse, so we can just flip need_transpose of the result. + auto *a = maybe_get(args[0]); + bool old_transpose = a && a->need_transpose; + if (old_transpose) + a->need_transpose = false; + + const char *func = "inverse"; + if (options.version < 140) // also matches ES 100 + { + auto &type = get(result_type); + assert(type.vecsize >= 2 && type.vecsize <= 4); + assert(type.vecsize == type.columns); + + // ARB_gpu_shader_fp64 needs GLSL 150, other types are invalid + if (type.basetype != SPIRType::Float) + SPIRV_CROSS_THROW("Unsupported type for matrix inverse"); + + bool relaxed = has_decoration(id, DecorationRelaxedPrecision); + require_polyfill(static_cast(PolyfillMatrixInverse2x2 << (type.vecsize - 2)), + relaxed); + func = (options.es && relaxed) ? "spvInverseMP" : "spvInverse"; + } + + bool forward = should_forward(args[0]); + auto &e = emit_op(result_type, id, join(func, "(", to_unpacked_expression(args[0]), ")"), forward); + inherit_expression_dependencies(id, args[0]); + + if (old_transpose) + { + e.need_transpose = true; + a->need_transpose = true; + } + break; + } + + // Lerping + case GLSLstd450FMix: + case GLSLstd450IMix: + { + emit_mix_op(result_type, id, args[0], args[1], args[2]); + break; + } + case GLSLstd450Step: + emit_binary_func_op(result_type, id, args[0], args[1], "step"); + break; + case GLSLstd450SmoothStep: + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "smoothstep"); + break; + + // Packing + case GLSLstd450Frexp: + register_call_out_argument(args[1]); + forced_temporaries.insert(id); + emit_binary_func_op(result_type, id, args[0], args[1], "frexp"); + break; + + case GLSLstd450FrexpStruct: + { + auto &type = get(result_type); + emit_uninitialized_temporary_expression(result_type, id); + statement(to_expression(id), ".", to_member_name(type, 0), " = ", "frexp(", to_expression(args[0]), ", ", + to_expression(id), ".", to_member_name(type, 1), ");"); + break; + } + + case GLSLstd450Ldexp: + { + bool forward = should_forward(args[0]) && should_forward(args[1]); + + auto op0 = to_unpacked_expression(args[0]); + auto op1 = to_unpacked_expression(args[1]); + auto &op1_type = expression_type(args[1]); + if (op1_type.basetype != SPIRType::Int) + { + // Need a value cast here. 
+ auto target_type = op1_type; + target_type.basetype = SPIRType::Int; + op1 = join(type_to_glsl_constructor(target_type), "(", op1, ")"); + } + + auto expr = join("ldexp(", op0, ", ", op1, ")"); + + emit_op(result_type, id, expr, forward); + inherit_expression_dependencies(id, args[0]); + inherit_expression_dependencies(id, args[1]); + break; + } + + case GLSLstd450PackSnorm4x8: + emit_unary_func_op(result_type, id, args[0], "packSnorm4x8"); + break; + case GLSLstd450PackUnorm4x8: + emit_unary_func_op(result_type, id, args[0], "packUnorm4x8"); + break; + case GLSLstd450PackSnorm2x16: + emit_unary_func_op(result_type, id, args[0], "packSnorm2x16"); + break; + case GLSLstd450PackUnorm2x16: + emit_unary_func_op(result_type, id, args[0], "packUnorm2x16"); + break; + case GLSLstd450PackHalf2x16: + emit_unary_func_op(result_type, id, args[0], "packHalf2x16"); + break; + case GLSLstd450UnpackSnorm4x8: + emit_unary_func_op(result_type, id, args[0], "unpackSnorm4x8"); + break; + case GLSLstd450UnpackUnorm4x8: + emit_unary_func_op(result_type, id, args[0], "unpackUnorm4x8"); + break; + case GLSLstd450UnpackSnorm2x16: + emit_unary_func_op(result_type, id, args[0], "unpackSnorm2x16"); + break; + case GLSLstd450UnpackUnorm2x16: + emit_unary_func_op(result_type, id, args[0], "unpackUnorm2x16"); + break; + case GLSLstd450UnpackHalf2x16: + emit_unary_func_op(result_type, id, args[0], "unpackHalf2x16"); + break; + + case GLSLstd450PackDouble2x32: + emit_unary_func_op(result_type, id, args[0], "packDouble2x32"); + break; + case GLSLstd450UnpackDouble2x32: + emit_unary_func_op(result_type, id, args[0], "unpackDouble2x32"); + break; + + // Vector math + case GLSLstd450Length: + emit_unary_func_op(result_type, id, args[0], "length"); + break; + case GLSLstd450Distance: + emit_binary_func_op(result_type, id, args[0], args[1], "distance"); + break; + case GLSLstd450Cross: + emit_binary_func_op(result_type, id, args[0], args[1], "cross"); + break; + case GLSLstd450Normalize: + emit_unary_func_op(result_type, id, args[0], "normalize"); + break; + case GLSLstd450FaceForward: + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "faceforward"); + break; + case GLSLstd450Reflect: + emit_binary_func_op(result_type, id, args[0], args[1], "reflect"); + break; + case GLSLstd450Refract: + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "refract"); + break; + + // Bit-fiddling + case GLSLstd450FindILsb: + // findLSB always returns int. + emit_unary_func_op_cast(result_type, id, args[0], "findLSB", expression_type(args[0]).basetype, int_type); + break; + + case GLSLstd450FindSMsb: + emit_unary_func_op_cast(result_type, id, args[0], "findMSB", int_type, int_type); + break; + + case GLSLstd450FindUMsb: + emit_unary_func_op_cast(result_type, id, args[0], "findMSB", uint_type, + int_type); // findMSB always returns int. 
+ break; + + // Multisampled varying + case GLSLstd450InterpolateAtCentroid: + emit_unary_func_op(result_type, id, args[0], "interpolateAtCentroid"); + break; + case GLSLstd450InterpolateAtSample: + emit_binary_func_op(result_type, id, args[0], args[1], "interpolateAtSample"); + break; + case GLSLstd450InterpolateAtOffset: + emit_binary_func_op(result_type, id, args[0], args[1], "interpolateAtOffset"); + break; + + case GLSLstd450NMin: + case GLSLstd450NMax: + { + if (options.vulkan_semantics) + { + require_extension_internal("GL_EXT_spirv_intrinsics"); + bool relaxed = has_decoration(id, DecorationRelaxedPrecision); + Polyfill poly = {}; + switch (get(result_type).width) + { + case 16: + poly = op == GLSLstd450NMin ? PolyfillNMin16 : PolyfillNMax16; + break; + + case 32: + poly = op == GLSLstd450NMin ? PolyfillNMin32 : PolyfillNMax32; + break; + + case 64: + poly = op == GLSLstd450NMin ? PolyfillNMin64 : PolyfillNMax64; + break; + + default: + SPIRV_CROSS_THROW("Invalid bit width for NMin/NMax."); + } + + require_polyfill(poly, relaxed); + + // Function return decorations are broken, so need to do double polyfill. + if (relaxed) + require_polyfill(poly, false); + + const char *op_str; + if (relaxed) + op_str = op == GLSLstd450NMin ? "spvNMinRelaxed" : "spvNMaxRelaxed"; + else + op_str = op == GLSLstd450NMin ? "spvNMin" : "spvNMax"; + + emit_binary_func_op(result_type, id, args[0], args[1], op_str); + } + else + { + emit_nminmax_op(result_type, id, args[0], args[1], op); + } + break; + } + + case GLSLstd450NClamp: + { + if (options.vulkan_semantics) + { + require_extension_internal("GL_EXT_spirv_intrinsics"); + bool relaxed = has_decoration(id, DecorationRelaxedPrecision); + Polyfill poly = {}; + switch (get(result_type).width) + { + case 16: + poly = PolyfillNClamp16; + break; + + case 32: + poly = PolyfillNClamp32; + break; + + case 64: + poly = PolyfillNClamp64; + break; + + default: + SPIRV_CROSS_THROW("Invalid bit width for NMin/NMax."); + } + + require_polyfill(poly, relaxed); + + // Function return decorations are broken, so need to do double polyfill. + if (relaxed) + require_polyfill(poly, false); + + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], relaxed ? "spvNClampRelaxed" : "spvNClamp"); + } + else + { + // Make sure we have a unique ID here to avoid aliasing the extra sub-expressions between clamp and NMin sub-op. + // IDs cannot exceed 24 bits, so we can make use of the higher bits for some unique flags. + uint32_t &max_id = extra_sub_expressions[id | EXTRA_SUB_EXPRESSION_TYPE_AUX]; + if (!max_id) + max_id = ir.increase_bound_by(1); + + // Inherit precision qualifiers. + ir.meta[max_id] = ir.meta[id]; + + emit_nminmax_op(result_type, max_id, args[0], args[1], GLSLstd450NMax); + emit_nminmax_op(result_type, id, max_id, args[2], GLSLstd450NMin); + } + break; + } + + default: + statement("// unimplemented GLSL op ", eop); + break; + } +} + +void CompilerGLSL::emit_nminmax_op(uint32_t result_type, uint32_t id, uint32_t op0, uint32_t op1, GLSLstd450 op) +{ + // Need to emulate this call. + uint32_t &ids = extra_sub_expressions[id]; + if (!ids) + { + ids = ir.increase_bound_by(5); + auto btype = get(result_type); + btype.basetype = SPIRType::Boolean; + set(ids, btype); + } + + uint32_t btype_id = ids + 0; + uint32_t left_nan_id = ids + 1; + uint32_t right_nan_id = ids + 2; + uint32_t tmp_id = ids + 3; + uint32_t mixed_first_id = ids + 4; + + // Inherit precision qualifiers. 
+ ir.meta[tmp_id] = ir.meta[id]; + ir.meta[mixed_first_id] = ir.meta[id]; + + if (!is_legacy()) + { + emit_unary_func_op(btype_id, left_nan_id, op0, "isnan"); + emit_unary_func_op(btype_id, right_nan_id, op1, "isnan"); + } + else if (expression_type(op0).vecsize > 1) + { + // If the number doesn't equal itself, it must be NaN + emit_binary_func_op(btype_id, left_nan_id, op0, op0, "notEqual"); + emit_binary_func_op(btype_id, right_nan_id, op1, op1, "notEqual"); + } + else + { + emit_binary_op(btype_id, left_nan_id, op0, op0, "!="); + emit_binary_op(btype_id, right_nan_id, op1, op1, "!="); + } + emit_binary_func_op(result_type, tmp_id, op0, op1, op == GLSLstd450NMin ? "min" : "max"); + emit_mix_op(result_type, mixed_first_id, tmp_id, op1, left_nan_id); + emit_mix_op(result_type, id, mixed_first_id, op0, right_nan_id); +} + +void CompilerGLSL::emit_emulated_ahyper_op(uint32_t result_type, uint32_t id, uint32_t op0, GLSLstd450 op) +{ + const char *one = backend.float_literal_suffix ? "1.0f" : "1.0"; + std::string expr; + bool forward = should_forward(op0); + + switch (op) + { + case GLSLstd450Asinh: + expr = join("log(", to_enclosed_expression(op0), " + sqrt(", + to_enclosed_expression(op0), " * ", to_enclosed_expression(op0), " + ", one, "))"); + emit_op(result_type, id, expr, forward); + break; + + case GLSLstd450Acosh: + expr = join("log(", to_enclosed_expression(op0), " + sqrt(", + to_enclosed_expression(op0), " * ", to_enclosed_expression(op0), " - ", one, "))"); + break; + + case GLSLstd450Atanh: + expr = join("log((", one, " + ", to_enclosed_expression(op0), ") / " + "(", one, " - ", to_enclosed_expression(op0), ")) * 0.5", + backend.float_literal_suffix ? "f" : ""); + break; + + default: + SPIRV_CROSS_THROW("Invalid op."); + } + + emit_op(result_type, id, expr, forward); + inherit_expression_dependencies(id, op0); +} + +void CompilerGLSL::emit_spv_amd_shader_ballot_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, + uint32_t) +{ + require_extension_internal("GL_AMD_shader_ballot"); + + enum AMDShaderBallot + { + SwizzleInvocationsAMD = 1, + SwizzleInvocationsMaskedAMD = 2, + WriteInvocationAMD = 3, + MbcntAMD = 4 + }; + + auto op = static_cast(eop); + + switch (op) + { + case SwizzleInvocationsAMD: + emit_binary_func_op(result_type, id, args[0], args[1], "swizzleInvocationsAMD"); + register_control_dependent_expression(id); + break; + + case SwizzleInvocationsMaskedAMD: + emit_binary_func_op(result_type, id, args[0], args[1], "swizzleInvocationsMaskedAMD"); + register_control_dependent_expression(id); + break; + + case WriteInvocationAMD: + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "writeInvocationAMD"); + register_control_dependent_expression(id); + break; + + case MbcntAMD: + emit_unary_func_op(result_type, id, args[0], "mbcntAMD"); + register_control_dependent_expression(id); + break; + + default: + statement("// unimplemented SPV AMD shader ballot op ", eop); + break; + } +} + +void CompilerGLSL::emit_spv_amd_shader_explicit_vertex_parameter_op(uint32_t result_type, uint32_t id, uint32_t eop, + const uint32_t *args, uint32_t) +{ + require_extension_internal("GL_AMD_shader_explicit_vertex_parameter"); + + enum AMDShaderExplicitVertexParameter + { + InterpolateAtVertexAMD = 1 + }; + + auto op = static_cast(eop); + + switch (op) + { + case InterpolateAtVertexAMD: + emit_binary_func_op(result_type, id, args[0], args[1], "interpolateAtVertexAMD"); + break; + + default: + statement("// unimplemented SPV AMD shader explicit vertex parameter 
op ", eop); + break; + } +} + +void CompilerGLSL::emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type, uint32_t id, uint32_t eop, + const uint32_t *args, uint32_t) +{ + require_extension_internal("GL_AMD_shader_trinary_minmax"); + + enum AMDShaderTrinaryMinMax + { + FMin3AMD = 1, + UMin3AMD = 2, + SMin3AMD = 3, + FMax3AMD = 4, + UMax3AMD = 5, + SMax3AMD = 6, + FMid3AMD = 7, + UMid3AMD = 8, + SMid3AMD = 9 + }; + + auto op = static_cast(eop); + + switch (op) + { + case FMin3AMD: + case UMin3AMD: + case SMin3AMD: + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "min3"); + break; + + case FMax3AMD: + case UMax3AMD: + case SMax3AMD: + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "max3"); + break; + + case FMid3AMD: + case UMid3AMD: + case SMid3AMD: + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "mid3"); + break; + + default: + statement("// unimplemented SPV AMD shader trinary minmax op ", eop); + break; + } +} + +void CompilerGLSL::emit_spv_amd_gcn_shader_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, + uint32_t) +{ + require_extension_internal("GL_AMD_gcn_shader"); + + enum AMDGCNShader + { + CubeFaceIndexAMD = 1, + CubeFaceCoordAMD = 2, + TimeAMD = 3 + }; + + auto op = static_cast(eop); + + switch (op) + { + case CubeFaceIndexAMD: + emit_unary_func_op(result_type, id, args[0], "cubeFaceIndexAMD"); + break; + case CubeFaceCoordAMD: + emit_unary_func_op(result_type, id, args[0], "cubeFaceCoordAMD"); + break; + case TimeAMD: + { + string expr = "timeAMD()"; + emit_op(result_type, id, expr, true); + register_control_dependent_expression(id); + break; + } + + default: + statement("// unimplemented SPV AMD gcn shader op ", eop); + break; + } +} + +void CompilerGLSL::emit_subgroup_op(const Instruction &i) +{ + const uint32_t *ops = stream(i); + auto op = static_cast(i.op); + + if (!options.vulkan_semantics && !is_supported_subgroup_op_in_opengl(op, ops)) + SPIRV_CROSS_THROW("This subgroup operation is only supported in Vulkan semantics."); + + // If we need to do implicit bitcasts, make sure we do it with the correct type. 
+ uint32_t integer_width = get_integer_width_for_instruction(i);
+ auto int_type = to_signed_basetype(integer_width);
+ auto uint_type = to_unsigned_basetype(integer_width);
+
+ switch (op)
+ {
+ case OpGroupNonUniformElect:
+ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupElect);
+ break;
+
+ case OpGroupNonUniformBallotBitCount:
+ {
+ const GroupOperation operation = static_cast<GroupOperation>(ops[3]);
+ if (operation == GroupOperationReduce)
+ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupBallotBitCount);
+ else if (operation == GroupOperationInclusiveScan || operation == GroupOperationExclusiveScan)
+ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupInverseBallot_InclBitCount_ExclBitCout);
+ }
+ break;
+
+ case OpGroupNonUniformBallotBitExtract:
+ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupBallotBitExtract);
+ break;
+
+ case OpGroupNonUniformInverseBallot:
+ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupInverseBallot_InclBitCount_ExclBitCout);
+ break;
+
+ case OpGroupNonUniformBallot:
+ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupBallot);
+ break;
+
+ case OpGroupNonUniformBallotFindLSB:
+ case OpGroupNonUniformBallotFindMSB:
+ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupBallotFindLSB_MSB);
+ break;
+
+ case OpGroupNonUniformBroadcast:
+ case OpGroupNonUniformBroadcastFirst:
+ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupBroadcast_First);
+ break;
+
+ case OpGroupNonUniformShuffle:
+ case OpGroupNonUniformShuffleXor:
+ require_extension_internal("GL_KHR_shader_subgroup_shuffle");
+ break;
+
+ case OpGroupNonUniformShuffleUp:
+ case OpGroupNonUniformShuffleDown:
+ require_extension_internal("GL_KHR_shader_subgroup_shuffle_relative");
+ break;
+
+ case OpGroupNonUniformAll:
+ case OpGroupNonUniformAny:
+ case OpGroupNonUniformAllEqual:
+ {
+ const SPIRType &type = expression_type(ops[3]);
+ if (type.basetype == SPIRType::BaseType::Boolean && type.vecsize == 1u)
+ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupAll_Any_AllEqualBool);
+ else
+ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupAllEqualT);
+ }
+ break;
+
+ // clang-format off
+#define GLSL_GROUP_OP(OP)\
+ case OpGroupNonUniform##OP:\
+ {\
+ auto operation = static_cast<GroupOperation>(ops[3]);\
+ if (operation == GroupOperationClusteredReduce)\
+ require_extension_internal("GL_KHR_shader_subgroup_clustered");\
+ else if (operation == GroupOperationReduce)\
+ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmetic##OP##Reduce);\
+ else if (operation == GroupOperationExclusiveScan)\
+ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmetic##OP##ExclusiveScan);\
+ else if (operation == GroupOperationInclusiveScan)\
+ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmetic##OP##InclusiveScan);\
+ else\
+ SPIRV_CROSS_THROW("Invalid group operation.");\
+ break;\
+ }
+
+ GLSL_GROUP_OP(IAdd)
+ GLSL_GROUP_OP(FAdd)
+ GLSL_GROUP_OP(IMul)
+ GLSL_GROUP_OP(FMul)
+
+#undef GLSL_GROUP_OP
+ // clang-format on
+
+ case OpGroupNonUniformFMin:
+ case OpGroupNonUniformFMax:
+ case OpGroupNonUniformSMin:
+ case OpGroupNonUniformSMax:
+ case OpGroupNonUniformUMin:
+ case OpGroupNonUniformUMax:
+ case OpGroupNonUniformBitwiseAnd:
+ case OpGroupNonUniformBitwiseOr:
+ case OpGroupNonUniformBitwiseXor:
+ case OpGroupNonUniformLogicalAnd:
+ case OpGroupNonUniformLogicalOr:
+ case OpGroupNonUniformLogicalXor:
+ {
+ auto operation = static_cast<GroupOperation>(ops[3]);
+ if (operation == 
GroupOperationClusteredReduce) + { + require_extension_internal("GL_KHR_shader_subgroup_clustered"); + } + else if (operation == GroupOperationExclusiveScan || operation == GroupOperationInclusiveScan || + operation == GroupOperationReduce) + { + require_extension_internal("GL_KHR_shader_subgroup_arithmetic"); + } + else + SPIRV_CROSS_THROW("Invalid group operation."); + break; + } + + case OpGroupNonUniformQuadSwap: + case OpGroupNonUniformQuadBroadcast: + require_extension_internal("GL_KHR_shader_subgroup_quad"); + break; + + default: + SPIRV_CROSS_THROW("Invalid opcode for subgroup."); + } + + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + auto scope = static_cast(evaluate_constant_u32(ops[2])); + if (scope != ScopeSubgroup) + SPIRV_CROSS_THROW("Only subgroup scope is supported."); + + switch (op) + { + case OpGroupNonUniformElect: + emit_op(result_type, id, "subgroupElect()", true); + break; + + case OpGroupNonUniformBroadcast: + emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupBroadcast"); + break; + + case OpGroupNonUniformBroadcastFirst: + emit_unary_func_op(result_type, id, ops[3], "subgroupBroadcastFirst"); + break; + + case OpGroupNonUniformBallot: + emit_unary_func_op(result_type, id, ops[3], "subgroupBallot"); + break; + + case OpGroupNonUniformInverseBallot: + emit_unary_func_op(result_type, id, ops[3], "subgroupInverseBallot"); + break; + + case OpGroupNonUniformBallotBitExtract: + emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupBallotBitExtract"); + break; + + case OpGroupNonUniformBallotFindLSB: + emit_unary_func_op(result_type, id, ops[3], "subgroupBallotFindLSB"); + break; + + case OpGroupNonUniformBallotFindMSB: + emit_unary_func_op(result_type, id, ops[3], "subgroupBallotFindMSB"); + break; + + case OpGroupNonUniformBallotBitCount: + { + auto operation = static_cast(ops[3]); + if (operation == GroupOperationReduce) + emit_unary_func_op(result_type, id, ops[4], "subgroupBallotBitCount"); + else if (operation == GroupOperationInclusiveScan) + emit_unary_func_op(result_type, id, ops[4], "subgroupBallotInclusiveBitCount"); + else if (operation == GroupOperationExclusiveScan) + emit_unary_func_op(result_type, id, ops[4], "subgroupBallotExclusiveBitCount"); + else + SPIRV_CROSS_THROW("Invalid BitCount operation."); + break; + } + + case OpGroupNonUniformShuffle: + emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupShuffle"); + break; + + case OpGroupNonUniformShuffleXor: + emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupShuffleXor"); + break; + + case OpGroupNonUniformShuffleUp: + emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupShuffleUp"); + break; + + case OpGroupNonUniformShuffleDown: + emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupShuffleDown"); + break; + + case OpGroupNonUniformAll: + emit_unary_func_op(result_type, id, ops[3], "subgroupAll"); + break; + + case OpGroupNonUniformAny: + emit_unary_func_op(result_type, id, ops[3], "subgroupAny"); + break; + + case OpGroupNonUniformAllEqual: + emit_unary_func_op(result_type, id, ops[3], "subgroupAllEqual"); + break; + + // clang-format off +#define GLSL_GROUP_OP(op, glsl_op) \ +case OpGroupNonUniform##op: \ + { \ + auto operation = static_cast(ops[3]); \ + if (operation == GroupOperationReduce) \ + emit_unary_func_op(result_type, id, ops[4], "subgroup" #glsl_op); \ + else if (operation == GroupOperationInclusiveScan) \ + emit_unary_func_op(result_type, id, ops[4], "subgroupInclusive" #glsl_op); \ + else if (operation == 
GroupOperationExclusiveScan) \ + emit_unary_func_op(result_type, id, ops[4], "subgroupExclusive" #glsl_op); \ + else if (operation == GroupOperationClusteredReduce) \ + emit_binary_func_op(result_type, id, ops[4], ops[5], "subgroupClustered" #glsl_op); \ + else \ + SPIRV_CROSS_THROW("Invalid group operation."); \ + break; \ + } + +#define GLSL_GROUP_OP_CAST(op, glsl_op, type) \ +case OpGroupNonUniform##op: \ + { \ + auto operation = static_cast(ops[3]); \ + if (operation == GroupOperationReduce) \ + emit_unary_func_op_cast(result_type, id, ops[4], "subgroup" #glsl_op, type, type); \ + else if (operation == GroupOperationInclusiveScan) \ + emit_unary_func_op_cast(result_type, id, ops[4], "subgroupInclusive" #glsl_op, type, type); \ + else if (operation == GroupOperationExclusiveScan) \ + emit_unary_func_op_cast(result_type, id, ops[4], "subgroupExclusive" #glsl_op, type, type); \ + else if (operation == GroupOperationClusteredReduce) \ + emit_binary_func_op_cast_clustered(result_type, id, ops[4], ops[5], "subgroupClustered" #glsl_op, type); \ + else \ + SPIRV_CROSS_THROW("Invalid group operation."); \ + break; \ + } + + GLSL_GROUP_OP(FAdd, Add) + GLSL_GROUP_OP(FMul, Mul) + GLSL_GROUP_OP(FMin, Min) + GLSL_GROUP_OP(FMax, Max) + GLSL_GROUP_OP(IAdd, Add) + GLSL_GROUP_OP(IMul, Mul) + GLSL_GROUP_OP_CAST(SMin, Min, int_type) + GLSL_GROUP_OP_CAST(SMax, Max, int_type) + GLSL_GROUP_OP_CAST(UMin, Min, uint_type) + GLSL_GROUP_OP_CAST(UMax, Max, uint_type) + GLSL_GROUP_OP(BitwiseAnd, And) + GLSL_GROUP_OP(BitwiseOr, Or) + GLSL_GROUP_OP(BitwiseXor, Xor) + GLSL_GROUP_OP(LogicalAnd, And) + GLSL_GROUP_OP(LogicalOr, Or) + GLSL_GROUP_OP(LogicalXor, Xor) +#undef GLSL_GROUP_OP +#undef GLSL_GROUP_OP_CAST + // clang-format on + + case OpGroupNonUniformQuadSwap: + { + uint32_t direction = evaluate_constant_u32(ops[4]); + if (direction == 0) + emit_unary_func_op(result_type, id, ops[3], "subgroupQuadSwapHorizontal"); + else if (direction == 1) + emit_unary_func_op(result_type, id, ops[3], "subgroupQuadSwapVertical"); + else if (direction == 2) + emit_unary_func_op(result_type, id, ops[3], "subgroupQuadSwapDiagonal"); + else + SPIRV_CROSS_THROW("Invalid quad swap direction."); + break; + } + + case OpGroupNonUniformQuadBroadcast: + { + emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupQuadBroadcast"); + break; + } + + default: + SPIRV_CROSS_THROW("Invalid opcode for subgroup."); + } + + register_control_dependent_expression(id); +} + +string CompilerGLSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type) +{ + // OpBitcast can deal with pointers. + if (out_type.pointer || in_type.pointer) + { + if (out_type.vecsize == 2 || in_type.vecsize == 2) + require_extension_internal("GL_EXT_buffer_reference_uvec2"); + return type_to_glsl(out_type); + } + + if (out_type.basetype == in_type.basetype) + return ""; + + assert(out_type.basetype != SPIRType::Boolean); + assert(in_type.basetype != SPIRType::Boolean); + + bool integral_cast = type_is_integral(out_type) && type_is_integral(in_type); + bool same_size_cast = out_type.width == in_type.width; + + // Trivial bitcast case, casts between integers. + if (integral_cast && same_size_cast) + return type_to_glsl(out_type); + + // Catch-all 8-bit arithmetic casts (GL_EXT_shader_explicit_arithmetic_types). 
+ if (out_type.width == 8 && in_type.width >= 16 && integral_cast && in_type.vecsize == 1) + return "unpack8"; + else if (in_type.width == 8 && out_type.width == 16 && integral_cast && out_type.vecsize == 1) + return "pack16"; + else if (in_type.width == 8 && out_type.width == 32 && integral_cast && out_type.vecsize == 1) + return "pack32"; + + // Floating <-> Integer special casts. Just have to enumerate all cases. :( + // 16-bit, 32-bit and 64-bit floats. + if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Float) + { + if (is_legacy_es()) + SPIRV_CROSS_THROW("Float -> Uint bitcast not supported on legacy ESSL."); + else if (!options.es && options.version < 330) + require_extension_internal("GL_ARB_shader_bit_encoding"); + return "floatBitsToUint"; + } + else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::Float) + { + if (is_legacy_es()) + SPIRV_CROSS_THROW("Float -> Int bitcast not supported on legacy ESSL."); + else if (!options.es && options.version < 330) + require_extension_internal("GL_ARB_shader_bit_encoding"); + return "floatBitsToInt"; + } + else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::UInt) + { + if (is_legacy_es()) + SPIRV_CROSS_THROW("Uint -> Float bitcast not supported on legacy ESSL."); + else if (!options.es && options.version < 330) + require_extension_internal("GL_ARB_shader_bit_encoding"); + return "uintBitsToFloat"; + } + else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::Int) + { + if (is_legacy_es()) + SPIRV_CROSS_THROW("Int -> Float bitcast not supported on legacy ESSL."); + else if (!options.es && options.version < 330) + require_extension_internal("GL_ARB_shader_bit_encoding"); + return "intBitsToFloat"; + } + + else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::Double) + return "doubleBitsToInt64"; + else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Double) + return "doubleBitsToUint64"; + else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::Int64) + return "int64BitsToDouble"; + else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::UInt64) + return "uint64BitsToDouble"; + else if (out_type.basetype == SPIRType::Short && in_type.basetype == SPIRType::Half) + return "float16BitsToInt16"; + else if (out_type.basetype == SPIRType::UShort && in_type.basetype == SPIRType::Half) + return "float16BitsToUint16"; + else if (out_type.basetype == SPIRType::Half && in_type.basetype == SPIRType::Short) + return "int16BitsToFloat16"; + else if (out_type.basetype == SPIRType::Half && in_type.basetype == SPIRType::UShort) + return "uint16BitsToFloat16"; + + // And finally, some even more special purpose casts. 
+ if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::UInt && in_type.vecsize == 2) + return "packUint2x32"; + else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::UInt64 && out_type.vecsize == 2) + return "unpackUint2x32"; + else if (out_type.basetype == SPIRType::Half && in_type.basetype == SPIRType::UInt && in_type.vecsize == 1) + return "unpackFloat2x16"; + else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Half && in_type.vecsize == 2) + return "packFloat2x16"; + else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::Short && in_type.vecsize == 2) + return "packInt2x16"; + else if (out_type.basetype == SPIRType::Short && in_type.basetype == SPIRType::Int && in_type.vecsize == 1) + return "unpackInt2x16"; + else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::UShort && in_type.vecsize == 2) + return "packUint2x16"; + else if (out_type.basetype == SPIRType::UShort && in_type.basetype == SPIRType::UInt && in_type.vecsize == 1) + return "unpackUint2x16"; + else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::Short && in_type.vecsize == 4) + return "packInt4x16"; + else if (out_type.basetype == SPIRType::Short && in_type.basetype == SPIRType::Int64 && in_type.vecsize == 1) + return "unpackInt4x16"; + else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::UShort && in_type.vecsize == 4) + return "packUint4x16"; + else if (out_type.basetype == SPIRType::UShort && in_type.basetype == SPIRType::UInt64 && in_type.vecsize == 1) + return "unpackUint4x16"; + + return ""; +} + +string CompilerGLSL::bitcast_glsl(const SPIRType &result_type, uint32_t argument) +{ + auto op = bitcast_glsl_op(result_type, expression_type(argument)); + if (op.empty()) + return to_enclosed_unpacked_expression(argument); + else + return join(op, "(", to_unpacked_expression(argument), ")"); +} + +std::string CompilerGLSL::bitcast_expression(SPIRType::BaseType target_type, uint32_t arg) +{ + auto expr = to_expression(arg); + auto &src_type = expression_type(arg); + if (src_type.basetype != target_type) + { + auto target = src_type; + target.basetype = target_type; + expr = join(bitcast_glsl_op(target, src_type), "(", expr, ")"); + } + + return expr; +} + +std::string CompilerGLSL::bitcast_expression(const SPIRType &target_type, SPIRType::BaseType expr_type, + const std::string &expr) +{ + if (target_type.basetype == expr_type) + return expr; + + auto src_type = target_type; + src_type.basetype = expr_type; + return join(bitcast_glsl_op(target_type, src_type), "(", expr, ")"); +} + +string CompilerGLSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage) +{ + switch (builtin) + { + case BuiltInPosition: + return "gl_Position"; + case BuiltInPointSize: + return "gl_PointSize"; + case BuiltInClipDistance: + { + if (options.es) + require_extension_internal("GL_EXT_clip_cull_distance"); + return "gl_ClipDistance"; + } + case BuiltInCullDistance: + { + if (options.es) + require_extension_internal("GL_EXT_clip_cull_distance"); + return "gl_CullDistance"; + } + case BuiltInVertexId: + if (options.vulkan_semantics) + SPIRV_CROSS_THROW("Cannot implement gl_VertexID in Vulkan GLSL. 
This shader was created " + "with GL semantics."); + return "gl_VertexID"; + case BuiltInInstanceId: + if (options.vulkan_semantics) + { + auto model = get_entry_point().model; + switch (model) + { + case spv::ExecutionModelIntersectionKHR: + case spv::ExecutionModelAnyHitKHR: + case spv::ExecutionModelClosestHitKHR: + // gl_InstanceID is allowed in these shaders. + break; + + default: + SPIRV_CROSS_THROW("Cannot implement gl_InstanceID in Vulkan GLSL. This shader was " + "created with GL semantics."); + } + } + if (!options.es && options.version < 140) + { + require_extension_internal("GL_ARB_draw_instanced"); + } + return "gl_InstanceID"; + case BuiltInVertexIndex: + if (options.vulkan_semantics) + return "gl_VertexIndex"; + else + return "gl_VertexID"; // gl_VertexID already has the base offset applied. + case BuiltInInstanceIndex: + if (options.vulkan_semantics) + return "gl_InstanceIndex"; + + if (!options.es && options.version < 140) + { + require_extension_internal("GL_ARB_draw_instanced"); + } + + if (options.vertex.support_nonzero_base_instance) + { + if (!options.vulkan_semantics) + { + // This is a soft-enable. We will opt-in to using gl_BaseInstanceARB if supported. + require_extension_internal("GL_ARB_shader_draw_parameters"); + } + return "(gl_InstanceID + SPIRV_Cross_BaseInstance)"; // ... but not gl_InstanceID. + } + else + return "gl_InstanceID"; + case BuiltInPrimitiveId: + if (storage == StorageClassInput && get_entry_point().model == ExecutionModelGeometry) + return "gl_PrimitiveIDIn"; + else + return "gl_PrimitiveID"; + case BuiltInInvocationId: + return "gl_InvocationID"; + case BuiltInLayer: + return "gl_Layer"; + case BuiltInViewportIndex: + return "gl_ViewportIndex"; + case BuiltInTessLevelOuter: + return "gl_TessLevelOuter"; + case BuiltInTessLevelInner: + return "gl_TessLevelInner"; + case BuiltInTessCoord: + return "gl_TessCoord"; + case BuiltInPatchVertices: + return "gl_PatchVerticesIn"; + case BuiltInFragCoord: + return "gl_FragCoord"; + case BuiltInPointCoord: + return "gl_PointCoord"; + case BuiltInFrontFacing: + return "gl_FrontFacing"; + case BuiltInFragDepth: + return "gl_FragDepth"; + case BuiltInNumWorkgroups: + return "gl_NumWorkGroups"; + case BuiltInWorkgroupSize: + return "gl_WorkGroupSize"; + case BuiltInWorkgroupId: + return "gl_WorkGroupID"; + case BuiltInLocalInvocationId: + return "gl_LocalInvocationID"; + case BuiltInGlobalInvocationId: + return "gl_GlobalInvocationID"; + case BuiltInLocalInvocationIndex: + return "gl_LocalInvocationIndex"; + case BuiltInHelperInvocation: + return "gl_HelperInvocation"; + + case BuiltInBaseVertex: + if (options.es) + SPIRV_CROSS_THROW("BaseVertex not supported in ES profile."); + + if (options.vulkan_semantics) + { + if (options.version < 460) + { + require_extension_internal("GL_ARB_shader_draw_parameters"); + return "gl_BaseVertexARB"; + } + return "gl_BaseVertex"; + } + // On regular GL, this is soft-enabled and we emit ifdefs in code. + require_extension_internal("GL_ARB_shader_draw_parameters"); + return "SPIRV_Cross_BaseVertex"; + + case BuiltInBaseInstance: + if (options.es) + SPIRV_CROSS_THROW("BaseInstance not supported in ES profile."); + + if (options.vulkan_semantics) + { + if (options.version < 460) + { + require_extension_internal("GL_ARB_shader_draw_parameters"); + return "gl_BaseInstanceARB"; + } + return "gl_BaseInstance"; + } + // On regular GL, this is soft-enabled and we emit ifdefs in code. 
+ require_extension_internal("GL_ARB_shader_draw_parameters"); + return "SPIRV_Cross_BaseInstance"; + + case BuiltInDrawIndex: + if (options.es) + SPIRV_CROSS_THROW("DrawIndex not supported in ES profile."); + + if (options.vulkan_semantics) + { + if (options.version < 460) + { + require_extension_internal("GL_ARB_shader_draw_parameters"); + return "gl_DrawIDARB"; + } + return "gl_DrawID"; + } + // On regular GL, this is soft-enabled and we emit ifdefs in code. + require_extension_internal("GL_ARB_shader_draw_parameters"); + return "gl_DrawIDARB"; + + case BuiltInSampleId: + if (is_legacy()) + SPIRV_CROSS_THROW("Sample variables not supported in legacy GLSL."); + else if (options.es && options.version < 320) + require_extension_internal("GL_OES_sample_variables"); + else if (!options.es && options.version < 400) + require_extension_internal("GL_ARB_sample_shading"); + return "gl_SampleID"; + + case BuiltInSampleMask: + if (is_legacy()) + SPIRV_CROSS_THROW("Sample variables not supported in legacy GLSL."); + else if (options.es && options.version < 320) + require_extension_internal("GL_OES_sample_variables"); + else if (!options.es && options.version < 400) + require_extension_internal("GL_ARB_sample_shading"); + + if (storage == StorageClassInput) + return "gl_SampleMaskIn"; + else + return "gl_SampleMask"; + + case BuiltInSamplePosition: + if (is_legacy()) + SPIRV_CROSS_THROW("Sample variables not supported in legacy GLSL."); + else if (options.es && options.version < 320) + require_extension_internal("GL_OES_sample_variables"); + else if (!options.es && options.version < 400) + require_extension_internal("GL_ARB_sample_shading"); + return "gl_SamplePosition"; + + case BuiltInViewIndex: + if (options.vulkan_semantics) + return "gl_ViewIndex"; + else + return "gl_ViewID_OVR"; + + case BuiltInNumSubgroups: + request_subgroup_feature(ShaderSubgroupSupportHelper::NumSubgroups); + return "gl_NumSubgroups"; + + case BuiltInSubgroupId: + request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupID); + return "gl_SubgroupID"; + + case BuiltInSubgroupSize: + request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupSize); + return "gl_SubgroupSize"; + + case BuiltInSubgroupLocalInvocationId: + request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupInvocationID); + return "gl_SubgroupInvocationID"; + + case BuiltInSubgroupEqMask: + request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupMask); + return "gl_SubgroupEqMask"; + + case BuiltInSubgroupGeMask: + request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupMask); + return "gl_SubgroupGeMask"; + + case BuiltInSubgroupGtMask: + request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupMask); + return "gl_SubgroupGtMask"; + + case BuiltInSubgroupLeMask: + request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupMask); + return "gl_SubgroupLeMask"; + + case BuiltInSubgroupLtMask: + request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupMask); + return "gl_SubgroupLtMask"; + + case BuiltInLaunchIdKHR: + return ray_tracing_is_khr ? "gl_LaunchIDEXT" : "gl_LaunchIDNV"; + case BuiltInLaunchSizeKHR: + return ray_tracing_is_khr ? "gl_LaunchSizeEXT" : "gl_LaunchSizeNV"; + case BuiltInWorldRayOriginKHR: + return ray_tracing_is_khr ? "gl_WorldRayOriginEXT" : "gl_WorldRayOriginNV"; + case BuiltInWorldRayDirectionKHR: + return ray_tracing_is_khr ? "gl_WorldRayDirectionEXT" : "gl_WorldRayDirectionNV"; + case BuiltInObjectRayOriginKHR: + return ray_tracing_is_khr ? 
"gl_ObjectRayOriginEXT" : "gl_ObjectRayOriginNV"; + case BuiltInObjectRayDirectionKHR: + return ray_tracing_is_khr ? "gl_ObjectRayDirectionEXT" : "gl_ObjectRayDirectionNV"; + case BuiltInRayTminKHR: + return ray_tracing_is_khr ? "gl_RayTminEXT" : "gl_RayTminNV"; + case BuiltInRayTmaxKHR: + return ray_tracing_is_khr ? "gl_RayTmaxEXT" : "gl_RayTmaxNV"; + case BuiltInInstanceCustomIndexKHR: + return ray_tracing_is_khr ? "gl_InstanceCustomIndexEXT" : "gl_InstanceCustomIndexNV"; + case BuiltInObjectToWorldKHR: + return ray_tracing_is_khr ? "gl_ObjectToWorldEXT" : "gl_ObjectToWorldNV"; + case BuiltInWorldToObjectKHR: + return ray_tracing_is_khr ? "gl_WorldToObjectEXT" : "gl_WorldToObjectNV"; + case BuiltInHitTNV: + // gl_HitTEXT is an alias of RayTMax in KHR. + return "gl_HitTNV"; + case BuiltInHitKindKHR: + return ray_tracing_is_khr ? "gl_HitKindEXT" : "gl_HitKindNV"; + case BuiltInIncomingRayFlagsKHR: + return ray_tracing_is_khr ? "gl_IncomingRayFlagsEXT" : "gl_IncomingRayFlagsNV"; + + case BuiltInBaryCoordKHR: + { + if (options.es && options.version < 320) + SPIRV_CROSS_THROW("gl_BaryCoordEXT requires ESSL 320."); + else if (!options.es && options.version < 450) + SPIRV_CROSS_THROW("gl_BaryCoordEXT requires GLSL 450."); + + if (barycentric_is_nv) + { + require_extension_internal("GL_NV_fragment_shader_barycentric"); + return "gl_BaryCoordNV"; + } + else + { + require_extension_internal("GL_EXT_fragment_shader_barycentric"); + return "gl_BaryCoordEXT"; + } + } + + case BuiltInBaryCoordNoPerspNV: + { + if (options.es && options.version < 320) + SPIRV_CROSS_THROW("gl_BaryCoordNoPerspEXT requires ESSL 320."); + else if (!options.es && options.version < 450) + SPIRV_CROSS_THROW("gl_BaryCoordNoPerspEXT requires GLSL 450."); + + if (barycentric_is_nv) + { + require_extension_internal("GL_NV_fragment_shader_barycentric"); + return "gl_BaryCoordNoPerspNV"; + } + else + { + require_extension_internal("GL_EXT_fragment_shader_barycentric"); + return "gl_BaryCoordNoPerspEXT"; + } + } + + case BuiltInFragStencilRefEXT: + { + if (!options.es) + { + require_extension_internal("GL_ARB_shader_stencil_export"); + return "gl_FragStencilRefARB"; + } + else + SPIRV_CROSS_THROW("Stencil export not supported in GLES."); + } + + case BuiltInPrimitiveShadingRateKHR: + { + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Can only use PrimitiveShadingRateKHR in Vulkan GLSL."); + require_extension_internal("GL_EXT_fragment_shading_rate"); + return "gl_PrimitiveShadingRateEXT"; + } + + case BuiltInShadingRateKHR: + { + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Can only use ShadingRateKHR in Vulkan GLSL."); + require_extension_internal("GL_EXT_fragment_shading_rate"); + return "gl_ShadingRateEXT"; + } + + case BuiltInDeviceIndex: + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Need Vulkan semantics for device group support."); + require_extension_internal("GL_EXT_device_group"); + return "gl_DeviceIndex"; + + case BuiltInFullyCoveredEXT: + if (!options.es) + require_extension_internal("GL_NV_conservative_raster_underestimation"); + else + SPIRV_CROSS_THROW("Need desktop GL to use GL_NV_conservative_raster_underestimation."); + return "gl_FragFullyCoveredNV"; + + case BuiltInPrimitiveTriangleIndicesEXT: + return "gl_PrimitiveTriangleIndicesEXT"; + case BuiltInPrimitiveLineIndicesEXT: + return "gl_PrimitiveLineIndicesEXT"; + case BuiltInPrimitivePointIndicesEXT: + return "gl_PrimitivePointIndicesEXT"; + case BuiltInCullPrimitiveEXT: + return "gl_CullPrimitiveEXT"; + + default: + return join("gl_BuiltIn_", 
convert_to_string(builtin)); + } +} + +const char *CompilerGLSL::index_to_swizzle(uint32_t index) +{ + switch (index) + { + case 0: + return "x"; + case 1: + return "y"; + case 2: + return "z"; + case 3: + return "w"; + default: + return "x"; // Don't crash, but engage the "undefined behavior" described for out-of-bounds logical addressing in spec. + } +} + +void CompilerGLSL::access_chain_internal_append_index(std::string &expr, uint32_t /*base*/, const SPIRType * /*type*/, + AccessChainFlags flags, bool &access_chain_is_arrayed, + uint32_t index) +{ + bool index_is_literal = (flags & ACCESS_CHAIN_INDEX_IS_LITERAL_BIT) != 0; + bool ptr_chain = (flags & ACCESS_CHAIN_PTR_CHAIN_BIT) != 0; + bool register_expression_read = (flags & ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT) == 0; + + string idx_expr = index_is_literal ? convert_to_string(index) : to_unpacked_expression(index, register_expression_read); + + // For the case where the base of an OpPtrAccessChain already ends in [n], + // we need to use the index as an offset to the existing index, otherwise, + // we can just use the index directly. + if (ptr_chain && access_chain_is_arrayed) + { + size_t split_pos = expr.find_last_of(']'); + size_t enclose_split = expr.find_last_of(')'); + + // If we have already enclosed the expression, don't try to be clever, it will break. + if (split_pos > enclose_split || enclose_split == string::npos) + { + string expr_front = expr.substr(0, split_pos); + string expr_back = expr.substr(split_pos); + expr = expr_front + " + " + enclose_expression(idx_expr) + expr_back; + return; + } + } + + expr += "["; + expr += idx_expr; + expr += "]"; +} + +bool CompilerGLSL::access_chain_needs_stage_io_builtin_translation(uint32_t) +{ + return true; +} + +string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, + AccessChainFlags flags, AccessChainMeta *meta) +{ + string expr; + + bool index_is_literal = (flags & ACCESS_CHAIN_INDEX_IS_LITERAL_BIT) != 0; + bool msb_is_id = (flags & ACCESS_CHAIN_LITERAL_MSB_FORCE_ID) != 0; + bool chain_only = (flags & ACCESS_CHAIN_CHAIN_ONLY_BIT) != 0; + bool ptr_chain = (flags & ACCESS_CHAIN_PTR_CHAIN_BIT) != 0; + bool register_expression_read = (flags & ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT) == 0; + bool flatten_member_reference = (flags & ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT) != 0; + + if (!chain_only) + { + // We handle transpose explicitly, so don't resolve that here. + auto *e = maybe_get(base); + bool old_transpose = e && e->need_transpose; + if (e) + e->need_transpose = false; + expr = to_enclosed_expression(base, register_expression_read); + if (e) + e->need_transpose = old_transpose; + } + + // Start traversing type hierarchy at the proper non-pointer types, + // but keep type_id referencing the original pointer for use below. + uint32_t type_id = expression_type_id(base); + const auto *type = &get_pointee_type(type_id); + + if (!backend.native_pointers) + { + if (ptr_chain) + SPIRV_CROSS_THROW("Backend does not support native pointers and does not support OpPtrAccessChain."); + + // Wrapped buffer reference pointer types will need to poke into the internal "value" member before + // continuing the access chain. 
+ if (should_dereference(base))
+ expr = dereference_expression(get<SPIRType>(type_id), expr);
+ }
+ else if (should_dereference(base) && type->basetype != SPIRType::Struct && !ptr_chain)
+ expr = join("(", dereference_expression(*type, expr), ")");
+
+ bool access_chain_is_arrayed = expr.find_first_of('[') != string::npos;
+ bool row_major_matrix_needs_conversion = is_non_native_row_major_matrix(base);
+ bool is_packed = has_extended_decoration(base, SPIRVCrossDecorationPhysicalTypePacked);
+ uint32_t physical_type = get_extended_decoration(base, SPIRVCrossDecorationPhysicalTypeID);
+ bool is_invariant = has_decoration(base, DecorationInvariant);
+ bool relaxed_precision = has_decoration(base, DecorationRelaxedPrecision);
+ bool pending_array_enclose = false;
+ bool dimension_flatten = false;
+ bool access_meshlet_position_y = false;
+
+ if (auto *base_expr = maybe_get<SPIRExpression>(base))
+ {
+ access_meshlet_position_y = base_expr->access_meshlet_position_y;
+ }
+
+ // If we are translating access to a structured buffer, the first subscript '._m0' must be hidden
+ bool hide_first_subscript = count > 1 && is_user_type_structured(base);
+
+ const auto append_index = [&](uint32_t index, bool is_literal, bool is_ptr_chain = false) {
+ AccessChainFlags mod_flags = flags;
+ if (!is_literal)
+ mod_flags &= ~ACCESS_CHAIN_INDEX_IS_LITERAL_BIT;
+ if (!is_ptr_chain)
+ mod_flags &= ~ACCESS_CHAIN_PTR_CHAIN_BIT;
+ access_chain_internal_append_index(expr, base, type, mod_flags, access_chain_is_arrayed, index);
+ check_physical_type_cast(expr, type, physical_type);
+ };
+
+ for (uint32_t i = 0; i < count; i++)
+ {
+ uint32_t index = indices[i];
+
+ bool is_literal = index_is_literal;
+ if (is_literal && msb_is_id && (index >> 31u) != 0u)
+ {
+ is_literal = false;
+ index &= 0x7fffffffu;
+ }
+
+ bool ptr_chain_array_entry = ptr_chain && i == 0 && is_array(*type);
+
+ if (ptr_chain_array_entry)
+ {
+ // This is highly unusual code, since normally we'd use plain AccessChain, but it's still allowed.
+ // We are considered to have a pointer to array and one element shifts by one array at a time.
+ // If we use normal array indexing, we'll first decay to pointer, and lose the array-ness,
+ // so we have to take pointer to array explicitly.
+ if (!should_dereference(base))
+ expr = enclose_expression(address_of_expression(expr));
+ }
+
+ if (ptr_chain && i == 0)
+ {
+ // Pointer chains
+ // If we are flattening multidimensional arrays, only create opening bracket on first
+ // array index.
+ if (options.flatten_multidimensional_arrays)
+ {
+ dimension_flatten = type->array.size() >= 1;
+ pending_array_enclose = dimension_flatten;
+ if (pending_array_enclose)
+ expr += "[";
+ }
+
+ if (options.flatten_multidimensional_arrays && dimension_flatten)
+ {
+ // If we are flattening multidimensional arrays, do manual stride computation. 
+ if (is_literal) + expr += convert_to_string(index); + else + expr += to_enclosed_expression(index, register_expression_read); + + for (auto j = uint32_t(type->array.size()); j; j--) + { + expr += " * "; + expr += enclose_expression(to_array_size(*type, j - 1)); + } + + if (type->array.empty()) + pending_array_enclose = false; + else + expr += " + "; + + if (!pending_array_enclose) + expr += "]"; + } + else + { + if (flags & ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT) + { + SPIRType tmp_type(OpTypeInt); + tmp_type.basetype = SPIRType::UInt64; + tmp_type.width = 64; + tmp_type.vecsize = 1; + tmp_type.columns = 1; + + TypeID ptr_type_id = expression_type_id(base); + const SPIRType &ptr_type = get(ptr_type_id); + const SPIRType &pointee_type = get_pointee_type(ptr_type); + + // This only runs in native pointer backends. + // Can replace reinterpret_cast with a backend string if ever needed. + // We expect this to count as a de-reference. + // This leaks some MSL details, but feels slightly overkill to + // add yet another virtual interface just for this. + auto intptr_expr = join("reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")"); + intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ", + get_decoration(ptr_type_id, DecorationArrayStride)); + + if (flags & ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT) + { + is_packed = true; + expr = join("*reinterpret_cast(", intptr_expr, ")"); + } + else + { + expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")"); + } + } + else + append_index(index, is_literal, true); + } + + if (type->basetype == SPIRType::ControlPointArray) + { + type_id = type->parent_type; + type = &get(type_id); + } + + access_chain_is_arrayed = true; + + // Explicitly enclose the expression if this is one of the weird pointer-to-array cases. + // We don't want any future indexing to add to this array dereference. + // Enclosing the expression blocks that and avoids any shenanigans with operand priority. + if (ptr_chain_array_entry) + expr = join("(", expr, ")"); + } + // Arrays + else if (!type->array.empty()) + { + // If we are flattening multidimensional arrays, only create opening bracket on first + // array index. + if (options.flatten_multidimensional_arrays && !pending_array_enclose) + { + dimension_flatten = type->array.size() > 1; + pending_array_enclose = dimension_flatten; + if (pending_array_enclose) + expr += "["; + } + + assert(type->parent_type); + + auto *var = maybe_get(base); + if (backend.force_gl_in_out_block && i == 0 && var && is_builtin_variable(*var) && + !has_decoration(type->self, DecorationBlock)) + { + // This deals with scenarios for tesc/geom where arrays of gl_Position[] are declared. + // Normally, these variables live in blocks when compiled from GLSL, + // but HLSL seems to just emit straight arrays here. + // We must pretend this access goes through gl_in/gl_out arrays + // to be able to access certain builtins as arrays. + // Similar concerns apply for mesh shaders where we have to redirect to gl_MeshVerticesEXT or MeshPrimitivesEXT. + auto builtin = ir.meta[base].decoration.builtin_type; + bool mesh_shader = get_execution_model() == ExecutionModelMeshEXT; + + switch (builtin) + { + case BuiltInCullDistance: + case BuiltInClipDistance: + if (type->array.size() == 1) // Red herring. Only consider block IO for two-dimensional arrays here. 
+ { + append_index(index, is_literal); + break; + } + // fallthrough + case BuiltInPosition: + case BuiltInPointSize: + if (mesh_shader) + expr = join("gl_MeshVerticesEXT[", to_expression(index, register_expression_read), "].", expr); + else if (var->storage == StorageClassInput) + expr = join("gl_in[", to_expression(index, register_expression_read), "].", expr); + else if (var->storage == StorageClassOutput) + expr = join("gl_out[", to_expression(index, register_expression_read), "].", expr); + else + append_index(index, is_literal); + break; + + case BuiltInPrimitiveId: + case BuiltInLayer: + case BuiltInViewportIndex: + case BuiltInCullPrimitiveEXT: + case BuiltInPrimitiveShadingRateKHR: + if (mesh_shader) + expr = join("gl_MeshPrimitivesEXT[", to_expression(index, register_expression_read), "].", expr); + else + append_index(index, is_literal); + break; + + default: + append_index(index, is_literal); + break; + } + } + else if (backend.force_merged_mesh_block && i == 0 && var && + !is_builtin_variable(*var) && var->storage == StorageClassOutput) + { + if (is_per_primitive_variable(*var)) + expr = join("gl_MeshPrimitivesEXT[", to_expression(index, register_expression_read), "].", expr); + else + expr = join("gl_MeshVerticesEXT[", to_expression(index, register_expression_read), "].", expr); + } + else if (options.flatten_multidimensional_arrays && dimension_flatten) + { + // If we are flattening multidimensional arrays, do manual stride computation. + auto &parent_type = get(type->parent_type); + + if (is_literal) + expr += convert_to_string(index); + else + expr += to_enclosed_expression(index, register_expression_read); + + for (auto j = uint32_t(parent_type.array.size()); j; j--) + { + expr += " * "; + expr += enclose_expression(to_array_size(parent_type, j - 1)); + } + + if (parent_type.array.empty()) + pending_array_enclose = false; + else + expr += " + "; + + if (!pending_array_enclose) + expr += "]"; + } + else if (index_is_literal || !builtin_translates_to_nonarray(BuiltIn(get_decoration(base, DecorationBuiltIn)))) + { + // Some builtins are arrays in SPIR-V but not in other languages, e.g. gl_SampleMask[] is an array in SPIR-V but not in Metal. + // By throwing away the index, we imply the index was 0, which it must be for gl_SampleMask. + // For literal indices we are working on composites, so we ignore this since we have already converted to proper array. + append_index(index, is_literal); + } + + if (var && has_decoration(var->self, DecorationBuiltIn) && + get_decoration(var->self, DecorationBuiltIn) == BuiltInPosition && + get_execution_model() == ExecutionModelMeshEXT) + { + access_meshlet_position_y = true; + } + + type_id = type->parent_type; + type = &get(type_id); + + // If the physical type has an unnatural vecsize, + // we must assume it's a faked struct where the .data member + // is used for the real payload. + if (physical_type && (is_vector(*type) || is_scalar(*type))) + { + auto &phys = get(physical_type); + if (phys.vecsize > 4) + expr += ".data"; + } + + access_chain_is_arrayed = true; + } + // For structs, the index refers to a constant, which indexes into the members, possibly through a redirection mapping. + // We also check if this member is a builtin, since we then replace the entire expression with the builtin one. 
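+ // For example, indexing member 1 of a block normally appends ".member_name" (or "_member_name" when members are flattened), while a builtin member such as gl_Position replaces the entire expression.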
+ else if (type->basetype == SPIRType::Struct) + { + if (!is_literal) + index = evaluate_constant_u32(index); + + if (index < uint32_t(type->member_type_index_redirection.size())) + index = type->member_type_index_redirection[index]; + + if (index >= type->member_types.size()) + SPIRV_CROSS_THROW("Member index is out of bounds!"); + + if (hide_first_subscript) + { + // First "._m0" subscript has been hidden, subsequent fields must be emitted even for structured buffers + hide_first_subscript = false; + } + else + { + BuiltIn builtin = BuiltInMax; + if (is_member_builtin(*type, index, &builtin) && access_chain_needs_stage_io_builtin_translation(base)) + { + if (access_chain_is_arrayed) + { + expr += "."; + expr += builtin_to_glsl(builtin, type->storage); + } + else + expr = builtin_to_glsl(builtin, type->storage); + + if (builtin == BuiltInPosition && get_execution_model() == ExecutionModelMeshEXT) + { + access_meshlet_position_y = true; + } + } + else + { + // If the member has a qualified name, use it as the entire chain + string qual_mbr_name = get_member_qualified_name(type_id, index); + if (!qual_mbr_name.empty()) + expr = qual_mbr_name; + else if (flatten_member_reference) + expr += join("_", to_member_name(*type, index)); + else + { + // Any pointer de-refences for values are handled in the first access chain. + // For pointer chains, the pointer-ness is resolved through an array access. + // The only time this is not true is when accessing array of SSBO/UBO. + // This case is explicitly handled. + expr += to_member_reference(base, *type, index, ptr_chain || i != 0); + } + } + } + + if (has_member_decoration(type->self, index, DecorationInvariant)) + is_invariant = true; + if (has_member_decoration(type->self, index, DecorationRelaxedPrecision)) + relaxed_precision = true; + + is_packed = member_is_packed_physical_type(*type, index); + if (member_is_remapped_physical_type(*type, index)) + physical_type = get_extended_member_decoration(type->self, index, SPIRVCrossDecorationPhysicalTypeID); + else + physical_type = 0; + + row_major_matrix_needs_conversion = member_is_non_native_row_major_matrix(*type, index); + type = &get(type->member_types[index]); + } + // Matrix -> Vector + else if (type->columns > 1) + { + // If we have a row-major matrix here, we need to defer any transpose in case this access chain + // is used to store a column. We can resolve it right here and now if we access a scalar directly, + // by flipping indexing order of the matrix. + + expr += "["; + if (is_literal) + expr += convert_to_string(index); + else + expr += to_unpacked_expression(index, register_expression_read); + expr += "]"; + + // If the physical type has an unnatural vecsize, + // we must assume it's a faked struct where the .data member + // is used for the real payload. + if (physical_type) + { + auto &phys = get(physical_type); + if (phys.vecsize > 4 || phys.columns > 4) + expr += ".data"; + } + + type_id = type->parent_type; + type = &get(type_id); + } + // Vector -> Scalar + else if (type->vecsize > 1) + { + string deferred_index; + if (row_major_matrix_needs_conversion) + { + // Flip indexing order. + auto column_index = expr.find_last_of('['); + if (column_index != string::npos) + { + deferred_index = expr.substr(column_index); + + auto end_deferred_index = deferred_index.find_last_of(']'); + if (end_deferred_index != string::npos && end_deferred_index + 1 != deferred_index.size()) + { + // If we have any data member fixups, it must be transposed so that it refers to this index. + // E.g. 
[0].data followed by [1] would be shuffled to [1][0].data which is wrong, + // and needs to be [1].data[0] instead. + end_deferred_index++; + deferred_index = deferred_index.substr(end_deferred_index) + + deferred_index.substr(0, end_deferred_index); + } + + expr.resize(column_index); + } + } + + // Internally, access chain implementation can also be used on composites, + // ignore scalar access workarounds in this case. + StorageClass effective_storage = StorageClassGeneric; + bool ignore_potential_sliced_writes = false; + if ((flags & ACCESS_CHAIN_FORCE_COMPOSITE_BIT) == 0) + { + if (expression_type(base).pointer) + effective_storage = get_expression_effective_storage_class(base); + + // Special consideration for control points. + // Control points can only be written by InvocationID, so there is no need + // to consider scalar access chains here. + // Cleans up some cases where it's very painful to determine the accurate storage class + // since blocks can be partially masked ... + auto *var = maybe_get_backing_variable(base); + if (var && var->storage == StorageClassOutput && + get_execution_model() == ExecutionModelTessellationControl && + !has_decoration(var->self, DecorationPatch)) + { + ignore_potential_sliced_writes = true; + } + } + else + ignore_potential_sliced_writes = true; + + if (!row_major_matrix_needs_conversion && !ignore_potential_sliced_writes) + { + // On some backends, we might not be able to safely access individual scalars in a vector. + // To work around this, we might have to cast the access chain reference to something which can, + // like a pointer to scalar, which we can then index into. + prepare_access_chain_for_scalar_access(expr, get(type->parent_type), effective_storage, + is_packed); + } + + if (is_literal) + { + bool out_of_bounds = (index >= type->vecsize); + + if (!is_packed && !row_major_matrix_needs_conversion) + { + expr += "."; + expr += index_to_swizzle(out_of_bounds ? 0 : index); + } + else + { + // For packed vectors, we can only access them as an array, not by swizzle. + expr += join("[", out_of_bounds ? 0 : index, "]"); + } + } + else if (ir.ids[index].get_type() == TypeConstant && !is_packed && !row_major_matrix_needs_conversion) + { + auto &c = get(index); + bool out_of_bounds = (c.scalar() >= type->vecsize); + + if (c.specialization) + { + // If the index is a spec constant, we cannot turn extract into a swizzle. + expr += join("[", out_of_bounds ? "0" : to_expression(index), "]"); + } + else + { + expr += "."; + expr += index_to_swizzle(out_of_bounds ? 0 : c.scalar()); + } + } + else + { + expr += "["; + expr += to_unpacked_expression(index, register_expression_read); + expr += "]"; + } + + if (row_major_matrix_needs_conversion && !ignore_potential_sliced_writes) + { + if (prepare_access_chain_for_scalar_access(expr, get(type->parent_type), effective_storage, + is_packed)) + { + // We're in a pointer context now, so just remove any member dereference. + auto first_index = deferred_index.find_first_of('['); + if (first_index != string::npos && first_index != 0) + deferred_index = deferred_index.substr(first_index); + } + } + + if (access_meshlet_position_y) + { + if (is_literal) + { + access_meshlet_position_y = index == 1; + } + else + { + const auto *c = maybe_get(index); + if (c) + access_meshlet_position_y = c->scalar() == 1; + else + { + // We don't know, but we have to assume no. + // Flip Y in mesh shaders is an opt-in horrible hack, so we'll have to assume shaders try to behave. 
+ access_meshlet_position_y = false; + } + } + } + + expr += deferred_index; + row_major_matrix_needs_conversion = false; + + is_packed = false; + physical_type = 0; + type_id = type->parent_type; + type = &get(type_id); + } + else if (!backend.allow_truncated_access_chain) + SPIRV_CROSS_THROW("Cannot subdivide a scalar value!"); + } + + if (pending_array_enclose) + { + SPIRV_CROSS_THROW("Flattening of multidimensional arrays were enabled, " + "but the access chain was terminated in the middle of a multidimensional array. " + "This is not supported."); + } + + if (meta) + { + meta->need_transpose = row_major_matrix_needs_conversion; + meta->storage_is_packed = is_packed; + meta->storage_is_invariant = is_invariant; + meta->storage_physical_type = physical_type; + meta->relaxed_precision = relaxed_precision; + meta->access_meshlet_position_y = access_meshlet_position_y; + } + + return expr; +} + +void CompilerGLSL::check_physical_type_cast(std::string &, const SPIRType *, uint32_t) +{ +} + +bool CompilerGLSL::prepare_access_chain_for_scalar_access(std::string &, const SPIRType &, spv::StorageClass, bool &) +{ + return false; +} + +string CompilerGLSL::to_flattened_struct_member(const string &basename, const SPIRType &type, uint32_t index) +{ + auto ret = join(basename, "_", to_member_name(type, index)); + ParsedIR::sanitize_underscores(ret); + return ret; +} + +uint32_t CompilerGLSL::get_physical_type_stride(const SPIRType &) const +{ + SPIRV_CROSS_THROW("Invalid to call get_physical_type_stride on a backend without native pointer support."); +} + +string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type, + AccessChainMeta *meta, bool ptr_chain) +{ + if (flattened_buffer_blocks.count(base)) + { + uint32_t matrix_stride = 0; + uint32_t array_stride = 0; + bool need_transpose = false; + flattened_access_chain_offset(expression_type(base), indices, count, 0, 16, &need_transpose, &matrix_stride, + &array_stride, ptr_chain); + + if (meta) + { + meta->need_transpose = target_type.columns > 1 && need_transpose; + meta->storage_is_packed = false; + } + + return flattened_access_chain(base, indices, count, target_type, 0, matrix_stride, array_stride, + need_transpose); + } + else if (flattened_structs.count(base) && count > 0) + { + AccessChainFlags flags = ACCESS_CHAIN_CHAIN_ONLY_BIT | ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT; + if (ptr_chain) + flags |= ACCESS_CHAIN_PTR_CHAIN_BIT; + + if (flattened_structs[base]) + { + flags |= ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT; + if (meta) + meta->flattened_struct = target_type.basetype == SPIRType::Struct; + } + + auto chain = access_chain_internal(base, indices, count, flags, nullptr).substr(1); + if (meta) + { + meta->need_transpose = false; + meta->storage_is_packed = false; + } + + auto basename = to_flattened_access_chain_expression(base); + auto ret = join(basename, "_", chain); + ParsedIR::sanitize_underscores(ret); + return ret; + } + else + { + AccessChainFlags flags = ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT; + if (ptr_chain) + { + flags |= ACCESS_CHAIN_PTR_CHAIN_BIT; + // PtrAccessChain could get complicated. + TypeID type_id = expression_type_id(base); + if (backend.native_pointers && has_decoration(type_id, DecorationArrayStride)) + { + // If there is a mismatch we have to go via 64-bit pointer arithmetic :'( + // Using packed hacks only gets us so far, and is not designed to deal with pointer to + // random values. It works for structs though. 
+ auto &pointee_type = get_pointee_type(get(type_id)); + uint32_t physical_stride = get_physical_type_stride(pointee_type); + uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride); + if (physical_stride != requested_stride) + { + flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT; + if (is_vector(pointee_type)) + flags |= ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT; + } + } + } + + return access_chain_internal(base, indices, count, flags, meta); + } +} + +string CompilerGLSL::load_flattened_struct(const string &basename, const SPIRType &type) +{ + auto expr = type_to_glsl_constructor(type); + expr += '('; + + for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++) + { + if (i) + expr += ", "; + + auto &member_type = get(type.member_types[i]); + if (member_type.basetype == SPIRType::Struct) + expr += load_flattened_struct(to_flattened_struct_member(basename, type, i), member_type); + else + expr += to_flattened_struct_member(basename, type, i); + } + expr += ')'; + return expr; +} + +std::string CompilerGLSL::to_flattened_access_chain_expression(uint32_t id) +{ + // Do not use to_expression as that will unflatten access chains. + string basename; + if (const auto *var = maybe_get(id)) + basename = to_name(var->self); + else if (const auto *expr = maybe_get(id)) + basename = expr->expression; + else + basename = to_expression(id); + + return basename; +} + +void CompilerGLSL::store_flattened_struct(const string &basename, uint32_t rhs_id, const SPIRType &type, + const SmallVector &indices) +{ + SmallVector sub_indices = indices; + sub_indices.push_back(0); + + auto *member_type = &type; + for (auto &index : indices) + member_type = &get(member_type->member_types[index]); + + for (uint32_t i = 0; i < uint32_t(member_type->member_types.size()); i++) + { + sub_indices.back() = i; + auto lhs = join(basename, "_", to_member_name(*member_type, i)); + ParsedIR::sanitize_underscores(lhs); + + if (get(member_type->member_types[i]).basetype == SPIRType::Struct) + { + store_flattened_struct(lhs, rhs_id, type, sub_indices); + } + else + { + auto rhs = to_expression(rhs_id) + to_multi_member_reference(type, sub_indices); + statement(lhs, " = ", rhs, ";"); + } + } +} + +void CompilerGLSL::store_flattened_struct(uint32_t lhs_id, uint32_t value) +{ + auto &type = expression_type(lhs_id); + auto basename = to_flattened_access_chain_expression(lhs_id); + store_flattened_struct(basename, value, type, {}); +} + +std::string CompilerGLSL::flattened_access_chain(uint32_t base, const uint32_t *indices, uint32_t count, + const SPIRType &target_type, uint32_t offset, uint32_t matrix_stride, + uint32_t /* array_stride */, bool need_transpose) +{ + if (!target_type.array.empty()) + SPIRV_CROSS_THROW("Access chains that result in an array can not be flattened"); + else if (target_type.basetype == SPIRType::Struct) + return flattened_access_chain_struct(base, indices, count, target_type, offset); + else if (target_type.columns > 1) + return flattened_access_chain_matrix(base, indices, count, target_type, offset, matrix_stride, need_transpose); + else + return flattened_access_chain_vector(base, indices, count, target_type, offset, matrix_stride, need_transpose); +} + +std::string CompilerGLSL::flattened_access_chain_struct(uint32_t base, const uint32_t *indices, uint32_t count, + const SPIRType &target_type, uint32_t offset) +{ + std::string expr; + + if (backend.can_declare_struct_inline) + { + expr += type_to_glsl_constructor(target_type); + expr += "("; + } + else + expr += "{"; + + for 
(uint32_t i = 0; i < uint32_t(target_type.member_types.size()); ++i) + { + if (i != 0) + expr += ", "; + + const SPIRType &member_type = get(target_type.member_types[i]); + uint32_t member_offset = type_struct_member_offset(target_type, i); + + // The access chain terminates at the struct, so we need to find matrix strides and row-major information + // ahead of time. + bool need_transpose = false; + bool relaxed = false; + uint32_t matrix_stride = 0; + if (member_type.columns > 1) + { + auto decorations = combined_decoration_for_member(target_type, i); + need_transpose = decorations.get(DecorationRowMajor); + relaxed = decorations.get(DecorationRelaxedPrecision); + matrix_stride = type_struct_member_matrix_stride(target_type, i); + } + + auto tmp = flattened_access_chain(base, indices, count, member_type, offset + member_offset, matrix_stride, + 0 /* array_stride */, need_transpose); + + // Cannot forward transpositions, so resolve them here. + if (need_transpose) + expr += convert_row_major_matrix(tmp, member_type, 0, false, relaxed); + else + expr += tmp; + } + + expr += backend.can_declare_struct_inline ? ")" : "}"; + + return expr; +} + +std::string CompilerGLSL::flattened_access_chain_matrix(uint32_t base, const uint32_t *indices, uint32_t count, + const SPIRType &target_type, uint32_t offset, + uint32_t matrix_stride, bool need_transpose) +{ + assert(matrix_stride); + SPIRType tmp_type = target_type; + if (need_transpose) + swap(tmp_type.vecsize, tmp_type.columns); + + std::string expr; + + expr += type_to_glsl_constructor(tmp_type); + expr += "("; + + for (uint32_t i = 0; i < tmp_type.columns; i++) + { + if (i != 0) + expr += ", "; + + expr += flattened_access_chain_vector(base, indices, count, tmp_type, offset + i * matrix_stride, matrix_stride, + /* need_transpose= */ false); + } + + expr += ")"; + + return expr; +} + +std::string CompilerGLSL::flattened_access_chain_vector(uint32_t base, const uint32_t *indices, uint32_t count, + const SPIRType &target_type, uint32_t offset, + uint32_t matrix_stride, bool need_transpose) +{ + auto result = flattened_access_chain_offset(expression_type(base), indices, count, offset, 16); + + auto buffer_name = to_name(expression_type(base).self); + + if (need_transpose) + { + std::string expr; + + if (target_type.vecsize > 1) + { + expr += type_to_glsl_constructor(target_type); + expr += "("; + } + + for (uint32_t i = 0; i < target_type.vecsize; ++i) + { + if (i != 0) + expr += ", "; + + uint32_t component_offset = result.second + i * matrix_stride; + + assert(component_offset % (target_type.width / 8) == 0); + uint32_t index = component_offset / (target_type.width / 8); + + expr += buffer_name; + expr += "["; + expr += result.first; // this is a series of N1 * k1 + N2 * k2 + ... that is either empty or ends with a + + expr += convert_to_string(index / 4); + expr += "]"; + + expr += vector_swizzle(1, index % 4); + } + + if (target_type.vecsize > 1) + { + expr += ")"; + } + + return expr; + } + else + { + assert(result.second % (target_type.width / 8) == 0); + uint32_t index = result.second / (target_type.width / 8); + + std::string expr; + + expr += buffer_name; + expr += "["; + expr += result.first; // this is a series of N1 * k1 + N2 * k2 + ... 
that is either empty or ends with a + + expr += convert_to_string(index / 4); + expr += "]"; + + expr += vector_swizzle(target_type.vecsize, index % 4); + + return expr; + } +} + +std::pair CompilerGLSL::flattened_access_chain_offset( + const SPIRType &basetype, const uint32_t *indices, uint32_t count, uint32_t offset, uint32_t word_stride, + bool *need_transpose, uint32_t *out_matrix_stride, uint32_t *out_array_stride, bool ptr_chain) +{ + // Start traversing type hierarchy at the proper non-pointer types. + const auto *type = &get_pointee_type(basetype); + + std::string expr; + + // Inherit matrix information in case we are access chaining a vector which might have come from a row major layout. + bool row_major_matrix_needs_conversion = need_transpose ? *need_transpose : false; + uint32_t matrix_stride = out_matrix_stride ? *out_matrix_stride : 0; + uint32_t array_stride = out_array_stride ? *out_array_stride : 0; + + for (uint32_t i = 0; i < count; i++) + { + uint32_t index = indices[i]; + + // Pointers + if (ptr_chain && i == 0) + { + // Here, the pointer type will be decorated with an array stride. + array_stride = get_decoration(basetype.self, DecorationArrayStride); + if (!array_stride) + SPIRV_CROSS_THROW("SPIR-V does not define ArrayStride for buffer block."); + + auto *constant = maybe_get(index); + if (constant) + { + // Constant array access. + offset += constant->scalar() * array_stride; + } + else + { + // Dynamic array access. + if (array_stride % word_stride) + { + SPIRV_CROSS_THROW("Array stride for dynamic indexing must be divisible by the size " + "of a 4-component vector. " + "Likely culprit here is a float or vec2 array inside a push " + "constant block which is std430. " + "This cannot be flattened. Try using std140 layout instead."); + } + + expr += to_enclosed_expression(index); + expr += " * "; + expr += convert_to_string(array_stride / word_stride); + expr += " + "; + } + } + // Arrays + else if (!type->array.empty()) + { + auto *constant = maybe_get(index); + if (constant) + { + // Constant array access. + offset += constant->scalar() * array_stride; + } + else + { + // Dynamic array access. + if (array_stride % word_stride) + { + SPIRV_CROSS_THROW("Array stride for dynamic indexing must be divisible by the size " + "of a 4-component vector. " + "Likely culprit here is a float or vec2 array inside a push " + "constant block which is std430. " + "This cannot be flattened. Try using std140 layout instead."); + } + + expr += to_enclosed_expression(index, false); + expr += " * "; + expr += convert_to_string(array_stride / word_stride); + expr += " + "; + } + + uint32_t parent_type = type->parent_type; + type = &get(parent_type); + + if (!type->array.empty()) + array_stride = get_decoration(parent_type, DecorationArrayStride); + } + // For structs, the index refers to a constant, which indexes into the members. + // We also check if this member is a builtin, since we then replace the entire expression with the builtin one. 
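+ // In this flattened path a struct member only contributes its byte Offset to the running offset (e.g. a member at Offset 16 adds 16); matrix members additionally record MatrixStride and RowMajor for later use.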
+ else if (type->basetype == SPIRType::Struct) + { + index = evaluate_constant_u32(index); + + if (index >= type->member_types.size()) + SPIRV_CROSS_THROW("Member index is out of bounds!"); + + offset += type_struct_member_offset(*type, index); + + auto &struct_type = *type; + type = &get(type->member_types[index]); + + if (type->columns > 1) + { + matrix_stride = type_struct_member_matrix_stride(struct_type, index); + row_major_matrix_needs_conversion = + combined_decoration_for_member(struct_type, index).get(DecorationRowMajor); + } + else + row_major_matrix_needs_conversion = false; + + if (!type->array.empty()) + array_stride = type_struct_member_array_stride(struct_type, index); + } + // Matrix -> Vector + else if (type->columns > 1) + { + auto *constant = maybe_get(index); + if (constant) + { + index = evaluate_constant_u32(index); + offset += index * (row_major_matrix_needs_conversion ? (type->width / 8) : matrix_stride); + } + else + { + uint32_t indexing_stride = row_major_matrix_needs_conversion ? (type->width / 8) : matrix_stride; + // Dynamic array access. + if (indexing_stride % word_stride) + { + SPIRV_CROSS_THROW("Matrix stride for dynamic indexing must be divisible by the size of a " + "4-component vector. " + "Likely culprit here is a row-major matrix being accessed dynamically. " + "This cannot be flattened. Try using std140 layout instead."); + } + + expr += to_enclosed_expression(index, false); + expr += " * "; + expr += convert_to_string(indexing_stride / word_stride); + expr += " + "; + } + + type = &get(type->parent_type); + } + // Vector -> Scalar + else if (type->vecsize > 1) + { + auto *constant = maybe_get(index); + if (constant) + { + index = evaluate_constant_u32(index); + offset += index * (row_major_matrix_needs_conversion ? matrix_stride : (type->width / 8)); + } + else + { + uint32_t indexing_stride = row_major_matrix_needs_conversion ? matrix_stride : (type->width / 8); + + // Dynamic array access. + if (indexing_stride % word_stride) + { + SPIRV_CROSS_THROW("Stride for dynamic vector indexing must be divisible by the " + "size of a 4-component vector. " + "This cannot be flattened in legacy targets."); + } + + expr += to_enclosed_expression(index, false); + expr += " * "; + expr += convert_to_string(indexing_stride / word_stride); + expr += " + "; + } + + type = &get(type->parent_type); + } + else + SPIRV_CROSS_THROW("Cannot subdivide a scalar value!"); + } + + if (need_transpose) + *need_transpose = row_major_matrix_needs_conversion; + if (out_matrix_stride) + *out_matrix_stride = matrix_stride; + if (out_array_stride) + *out_array_stride = array_stride; + + return std::make_pair(expr, offset); +} + +bool CompilerGLSL::should_dereference(uint32_t id) +{ + const auto &type = expression_type(id); + // Non-pointer expressions don't need to be dereferenced. + if (!type.pointer) + return false; + + // Handles shouldn't be dereferenced either. + if (!expression_is_lvalue(id)) + return false; + + // If id is a variable but not a phi variable, we should not dereference it. + if (auto *var = maybe_get(id)) + return var->phi_variable; + + if (auto *expr = maybe_get(id)) + { + // If id is an access chain, we should not dereference it. + if (expr->access_chain) + return false; + + // If id is a forwarded copy of a variable pointer, we should not dereference it. 
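+ // Walk the chain of forwarded loads via loaded_from back to the originating variable; if one is found, it only needs a dereference when it is a phi variable.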
+ SPIRVariable *var = nullptr; + while (expr->loaded_from && expression_is_forwarded(expr->self)) + { + auto &src_type = expression_type(expr->loaded_from); + // To be a copy, the pointer and its source expression must be the + // same type. Can't check type.self, because for some reason that's + // usually the base type with pointers stripped off. This check is + // complex enough that I've hoisted it out of the while condition. + if (src_type.pointer != type.pointer || src_type.pointer_depth != type.pointer_depth || + src_type.parent_type != type.parent_type) + break; + if ((var = maybe_get(expr->loaded_from))) + break; + if (!(expr = maybe_get(expr->loaded_from))) + break; + } + + return !var || var->phi_variable; + } + + // Otherwise, we should dereference this pointer expression. + return true; +} + +bool CompilerGLSL::should_forward(uint32_t id) const +{ + // If id is a variable we will try to forward it regardless of force_temporary check below + // This is important because otherwise we'll get local sampler copies (highp sampler2D foo = bar) that are invalid in OpenGL GLSL + + auto *var = maybe_get(id); + if (var) + { + // Never forward volatile builtin variables, e.g. SPIR-V 1.6 HelperInvocation. + return !(has_decoration(id, DecorationBuiltIn) && has_decoration(id, DecorationVolatile)); + } + + // For debugging emit temporary variables for all expressions + if (options.force_temporary) + return false; + + // If an expression carries enough dependencies we need to stop forwarding at some point, + // or we explode compilers. There are usually limits to how much we can nest expressions. + auto *expr = maybe_get(id); + const uint32_t max_expression_dependencies = 64; + if (expr && expr->expression_dependencies.size() >= max_expression_dependencies) + return false; + + if (expr && expr->loaded_from + && has_decoration(expr->loaded_from, DecorationBuiltIn) + && has_decoration(expr->loaded_from, DecorationVolatile)) + { + // Never forward volatile builtin variables, e.g. SPIR-V 1.6 HelperInvocation. + return false; + } + + // Immutable expression can always be forwarded. + if (is_immutable(id)) + return true; + + return false; +} + +bool CompilerGLSL::should_suppress_usage_tracking(uint32_t id) const +{ + // Used only by opcodes which don't do any real "work", they just swizzle data in some fashion. + return !expression_is_forwarded(id) || expression_suppresses_usage_tracking(id); +} + +void CompilerGLSL::track_expression_read(uint32_t id) +{ + switch (ir.ids[id].get_type()) + { + case TypeExpression: + { + auto &e = get(id); + for (auto implied_read : e.implied_read_expressions) + track_expression_read(implied_read); + break; + } + + case TypeAccessChain: + { + auto &e = get(id); + for (auto implied_read : e.implied_read_expressions) + track_expression_read(implied_read); + break; + } + + default: + break; + } + + // If we try to read a forwarded temporary more than once we will stamp out possibly complex code twice. + // In this case, it's better to just bind the complex expression to the temporary and read that temporary twice. + if (expression_is_forwarded(id) && !expression_suppresses_usage_tracking(id)) + { + auto &v = expression_usage_counts[id]; + v++; + + // If we create an expression outside a loop, + // but access it inside a loop, we're implicitly reading it multiple times. + // If the expression in question is expensive, we should hoist it out to avoid relying on loop-invariant code motion + // working inside the backend compiler. 
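+ // For example, a forwarded expression created before a loop but consumed inside it counts as multiple reads; once the usage count reaches 2 the expression is demoted to a temporary and a recompile is requested.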
+ if (expression_read_implies_multiple_reads(id)) + v++; + + if (v >= 2) + { + //if (v == 2) + // fprintf(stderr, "ID %u was forced to temporary due to more than 1 expression use!\n", id); + + // Force a recompile after this pass to avoid forwarding this variable. + force_temporary_and_recompile(id); + } + } +} + +bool CompilerGLSL::args_will_forward(uint32_t id, const uint32_t *args, uint32_t num_args, bool pure) +{ + if (forced_temporaries.find(id) != end(forced_temporaries)) + return false; + + for (uint32_t i = 0; i < num_args; i++) + if (!should_forward(args[i])) + return false; + + // We need to forward globals as well. + if (!pure) + { + for (auto global : global_variables) + if (!should_forward(global)) + return false; + for (auto aliased : aliased_variables) + if (!should_forward(aliased)) + return false; + } + + return true; +} + +void CompilerGLSL::register_impure_function_call() +{ + // Impure functions can modify globals and aliased variables, so invalidate them as well. + for (auto global : global_variables) + flush_dependees(get(global)); + for (auto aliased : aliased_variables) + flush_dependees(get(aliased)); +} + +void CompilerGLSL::register_call_out_argument(uint32_t id) +{ + register_write(id); + + auto *var = maybe_get(id); + if (var) + flush_variable_declaration(var->self); +} + +string CompilerGLSL::variable_decl_function_local(SPIRVariable &var) +{ + // These variables are always function local, + // so make sure we emit the variable without storage qualifiers. + // Some backends will inject custom variables locally in a function + // with a storage qualifier which is not function-local. + auto old_storage = var.storage; + var.storage = StorageClassFunction; + auto expr = variable_decl(var); + var.storage = old_storage; + return expr; +} + +void CompilerGLSL::emit_variable_temporary_copies(const SPIRVariable &var) +{ + // Ensure that we declare phi-variable copies even if the original declaration isn't deferred + if (var.allocate_temporary_copy && !flushed_phi_variables.count(var.self)) + { + auto &type = get(var.basetype); + auto &flags = get_decoration_bitset(var.self); + statement(flags_to_qualifiers_glsl(type, flags), variable_decl(type, join("_", var.self, "_copy")), ";"); + flushed_phi_variables.insert(var.self); + } +} + +void CompilerGLSL::flush_variable_declaration(uint32_t id) +{ + // Ensure that we declare phi-variable copies even if the original declaration isn't deferred + auto *var = maybe_get(id); + if (var && var->deferred_declaration) + { + string initializer; + if (options.force_zero_initialized_variables && + (var->storage == StorageClassFunction || var->storage == StorageClassGeneric || + var->storage == StorageClassPrivate) && + !var->initializer && type_can_zero_initialize(get_variable_data_type(*var))) + { + initializer = join(" = ", to_zero_initialized_expression(get_variable_data_type_id(*var))); + } + + statement(variable_decl_function_local(*var), initializer, ";"); + var->deferred_declaration = false; + } + if (var) + { + emit_variable_temporary_copies(*var); + } +} + +bool CompilerGLSL::remove_duplicate_swizzle(string &op) +{ + auto pos = op.find_last_of('.'); + if (pos == string::npos || pos == 0) + return false; + + string final_swiz = op.substr(pos + 1, string::npos); + + if (backend.swizzle_is_function) + { + if (final_swiz.size() < 2) + return false; + + if (final_swiz.substr(final_swiz.size() - 2, string::npos) == "()") + final_swiz.erase(final_swiz.size() - 2, string::npos); + else + return false; + } + + // Check if final swizzle 
is of form .x, .xy, .xyz, .xyzw or similar. + // If so, and previous swizzle is of same length, + // we can drop the final swizzle altogether. + for (uint32_t i = 0; i < final_swiz.size(); i++) + { + static const char expected[] = { 'x', 'y', 'z', 'w' }; + if (i >= 4 || final_swiz[i] != expected[i]) + return false; + } + + auto prevpos = op.find_last_of('.', pos - 1); + if (prevpos == string::npos) + return false; + + prevpos++; + + // Make sure there are only swizzles here ... + for (auto i = prevpos; i < pos; i++) + { + if (op[i] < 'w' || op[i] > 'z') + { + // If swizzles are foo.xyz() like in C++ backend for example, check for that. + if (backend.swizzle_is_function && i + 2 == pos && op[i] == '(' && op[i + 1] == ')') + break; + return false; + } + } + + // If original swizzle is large enough, just carve out the components we need. + // E.g. foobar.wyx.xy will turn into foobar.wy. + if (pos - prevpos >= final_swiz.size()) + { + op.erase(prevpos + final_swiz.size(), string::npos); + + // Add back the function call ... + if (backend.swizzle_is_function) + op += "()"; + } + return true; +} + +// Optimizes away vector swizzles where we have something like +// vec3 foo; +// foo.xyz <-- swizzle expression does nothing. +// This is a very common pattern after OpCompositeCombine. +bool CompilerGLSL::remove_unity_swizzle(uint32_t base, string &op) +{ + auto pos = op.find_last_of('.'); + if (pos == string::npos || pos == 0) + return false; + + string final_swiz = op.substr(pos + 1, string::npos); + + if (backend.swizzle_is_function) + { + if (final_swiz.size() < 2) + return false; + + if (final_swiz.substr(final_swiz.size() - 2, string::npos) == "()") + final_swiz.erase(final_swiz.size() - 2, string::npos); + else + return false; + } + + // Check if final swizzle is of form .x, .xy, .xyz, .xyzw or similar. + // If so, and previous swizzle is of same length, + // we can drop the final swizzle altogether. + for (uint32_t i = 0; i < final_swiz.size(); i++) + { + static const char expected[] = { 'x', 'y', 'z', 'w' }; + if (i >= 4 || final_swiz[i] != expected[i]) + return false; + } + + auto &type = expression_type(base); + + // Sanity checking ... + assert(type.columns == 1 && type.array.empty()); + + if (type.vecsize == final_swiz.size()) + op.erase(pos, string::npos); + return true; +} + +string CompilerGLSL::build_composite_combiner(uint32_t return_type, const uint32_t *elems, uint32_t length) +{ + ID base = 0; + string op; + string subop; + + // Can only merge swizzles for vectors. + auto &type = get(return_type); + bool can_apply_swizzle_opt = type.basetype != SPIRType::Struct && type.array.empty() && type.columns == 1; + bool swizzle_optimization = false; + + for (uint32_t i = 0; i < length; i++) + { + auto *e = maybe_get(elems[i]); + + // If we're merging another scalar which belongs to the same base + // object, just merge the swizzles to avoid triggering more than 1 expression read as much as possible! + if (can_apply_swizzle_opt && e && e->base_expression && e->base_expression == base) + { + // Only supposed to be used for vector swizzle -> scalar. + assert(!e->expression.empty() && e->expression.front() == '.'); + subop += e->expression.substr(1, string::npos); + swizzle_optimization = true; + } + else + { + // We'll likely end up with duplicated swizzles, e.g. + // foobar.xyz.xyz from patterns like + // OpVectorShuffle + // OpCompositeExtract x 3 + // OpCompositeConstruct 3x + other scalar. + // Just modify op in-place. 
+ if (swizzle_optimization) + { + if (backend.swizzle_is_function) + subop += "()"; + + // Don't attempt to remove unity swizzling if we managed to remove duplicate swizzles. + // The base "foo" might be vec4, while foo.xyz is vec3 (OpVectorShuffle) and looks like a vec3 due to the .xyz tacked on. + // We only want to remove the swizzles if we're certain that the resulting base will be the same vecsize. + // Essentially, we can only remove one set of swizzles, since that's what we have control over ... + // Case 1: + // foo.yxz.xyz: Duplicate swizzle kicks in, giving foo.yxz, we are done. + // foo.yxz was the result of OpVectorShuffle and we don't know the type of foo. + // Case 2: + // foo.xyz: Duplicate swizzle won't kick in. + // If foo is vec3, we can remove xyz, giving just foo. + if (!remove_duplicate_swizzle(subop)) + remove_unity_swizzle(base, subop); + + // Strips away redundant parens if we created them during component extraction. + strip_enclosed_expression(subop); + swizzle_optimization = false; + op += subop; + } + else + op += subop; + + if (i) + op += ", "; + + bool uses_buffer_offset = + type.basetype == SPIRType::Struct && has_member_decoration(type.self, i, DecorationOffset); + subop = to_composite_constructor_expression(type, elems[i], uses_buffer_offset); + } + + base = e ? e->base_expression : ID(0); + } + + if (swizzle_optimization) + { + if (backend.swizzle_is_function) + subop += "()"; + + if (!remove_duplicate_swizzle(subop)) + remove_unity_swizzle(base, subop); + // Strips away redundant parens if we created them during component extraction. + strip_enclosed_expression(subop); + } + + op += subop; + return op; +} + +bool CompilerGLSL::skip_argument(uint32_t id) const +{ + if (!combined_image_samplers.empty() || !options.vulkan_semantics) + { + auto &type = expression_type(id); + if (type.basetype == SPIRType::Sampler || (type.basetype == SPIRType::Image && type.image.sampled == 1)) + return true; + } + return false; +} + +bool CompilerGLSL::optimize_read_modify_write(const SPIRType &type, const string &lhs, const string &rhs) +{ + // Do this with strings because we have a very clear pattern we can check for and it avoids + // adding lots of special cases to the code emission. + if (rhs.size() < lhs.size() + 3) + return false; + + // Do not optimize matrices. They are a bit awkward to reason about in general + // (in which order does operation happen?), and it does not work on MSL anyways. + if (type.vecsize > 1 && type.columns > 1) + return false; + + auto index = rhs.find(lhs); + if (index != 0) + return false; + + // TODO: Shift operators, but it's not important for now. + auto op = rhs.find_first_of("+-/*%|&^", lhs.size() + 1); + if (op != lhs.size() + 1) + return false; + + // Check that the op is followed by space. This excludes && and ||. + if (rhs[op + 1] != ' ') + return false; + + char bop = rhs[op]; + auto expr = rhs.substr(lhs.size() + 3); + + // Avoids false positives where we get a = a * b + c. + // Normally, these expressions are always enclosed, but unexpected code paths may end up hitting this. + if (needs_enclose_expression(expr)) + return false; + + // Try to find increments and decrements. Makes it look neater as += 1, -= 1 is fairly rare to see in real code. + // Find some common patterns which are equivalent. 
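+ // For example, "a = a + 1;" is emitted as "a++;" and "a = a * b;" as "a *= b;".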
+ if ((bop == '+' || bop == '-') && (expr == "1" || expr == "uint(1)" || expr == "1u" || expr == "int(1u)")) + statement(lhs, bop, bop, ";"); + else + statement(lhs, " ", bop, "= ", expr, ";"); + return true; +} + +void CompilerGLSL::register_control_dependent_expression(uint32_t expr) +{ + if (forwarded_temporaries.find(expr) == end(forwarded_temporaries)) + return; + + assert(current_emitting_block); + current_emitting_block->invalidate_expressions.push_back(expr); +} + +void CompilerGLSL::emit_block_instructions(SPIRBlock &block) +{ + current_emitting_block = █ + + if (backend.requires_relaxed_precision_analysis) + { + // If PHI variables are consumed in unexpected precision contexts, copy them here. + for (size_t i = 0, n = block.phi_variables.size(); i < n; i++) + { + auto &phi = block.phi_variables[i]; + + // Ensure we only copy once. We know a-priori that this array will lay out + // the same function variables together. + if (i && block.phi_variables[i - 1].function_variable == phi.function_variable) + continue; + + auto itr = temporary_to_mirror_precision_alias.find(phi.function_variable); + if (itr != temporary_to_mirror_precision_alias.end()) + { + // Explicitly, we don't want to inherit RelaxedPrecision state in this CopyObject, + // so it helps to have handle_instruction_precision() on the outside of emit_instruction(). + EmbeddedInstruction inst; + inst.op = OpCopyObject; + inst.length = 3; + inst.ops.push_back(expression_type_id(itr->first)); + inst.ops.push_back(itr->second); + inst.ops.push_back(itr->first); + emit_instruction(inst); + } + } + } + + for (auto &op : block.ops) + { + auto temporary_copy = handle_instruction_precision(op); + emit_instruction(op); + if (temporary_copy.dst_id) + { + // Explicitly, we don't want to inherit RelaxedPrecision state in this CopyObject, + // so it helps to have handle_instruction_precision() on the outside of emit_instruction(). + EmbeddedInstruction inst; + inst.op = OpCopyObject; + inst.length = 3; + inst.ops.push_back(expression_type_id(temporary_copy.src_id)); + inst.ops.push_back(temporary_copy.dst_id); + inst.ops.push_back(temporary_copy.src_id); + + // Never attempt to hoist mirrored temporaries. + // They are hoisted in lock-step with their parents. + block_temporary_hoisting = true; + emit_instruction(inst); + block_temporary_hoisting = false; + } + } + + current_emitting_block = nullptr; +} + +void CompilerGLSL::disallow_forwarding_in_expression_chain(const SPIRExpression &expr) +{ + // Allow trivially forwarded expressions like OpLoad or trivial shuffles, + // these will be marked as having suppressed usage tracking. + // Our only concern is to make sure arithmetic operations are done in similar ways. + if (expression_is_forwarded(expr.self) && !expression_suppresses_usage_tracking(expr.self) && + forced_invariant_temporaries.count(expr.self) == 0) + { + force_temporary_and_recompile(expr.self); + forced_invariant_temporaries.insert(expr.self); + + for (auto &dependent : expr.expression_dependencies) + disallow_forwarding_in_expression_chain(get(dependent)); + } +} + +void CompilerGLSL::handle_store_to_invariant_variable(uint32_t store_id, uint32_t value_id) +{ + // Variables or access chains marked invariant are complicated. We will need to make sure the code-gen leading up to + // this variable is consistent. The failure case for SPIRV-Cross is when an expression is forced to a temporary + // in one translation unit, but not another, e.g. due to multiple use of an expression. 
+ // This causes variance despite the output variable being marked invariant, so the solution here is to force all dependent + // expressions to be temporaries. + // It is uncertain if this is enough to support invariant in all possible cases, but it should be good enough + // for all reasonable uses of invariant. + if (!has_decoration(store_id, DecorationInvariant)) + return; + + auto *expr = maybe_get(value_id); + if (!expr) + return; + + disallow_forwarding_in_expression_chain(*expr); +} + +void CompilerGLSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression) +{ + auto rhs = to_pointer_expression(rhs_expression); + + // Statements to OpStore may be empty if it is a struct with zero members. Just forward the store to /dev/null. + if (!rhs.empty()) + { + handle_store_to_invariant_variable(lhs_expression, rhs_expression); + + if (!unroll_array_to_complex_store(lhs_expression, rhs_expression)) + { + auto lhs = to_dereferenced_expression(lhs_expression); + if (has_decoration(lhs_expression, DecorationNonUniform)) + convert_non_uniform_expression(lhs, lhs_expression); + + // We might need to cast in order to store to a builtin. + cast_to_variable_store(lhs_expression, rhs, expression_type(rhs_expression)); + + // Tries to optimize assignments like " = op expr". + // While this is purely cosmetic, this is important for legacy ESSL where loop + // variable increments must be in either i++ or i += const-expr. + // Without this, we end up with i = i + 1, which is correct GLSL, but not correct GLES 2.0. + if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs)) + statement(lhs, " = ", rhs, ";"); + } + register_write(lhs_expression); + } +} + +uint32_t CompilerGLSL::get_integer_width_for_instruction(const Instruction &instr) const +{ + if (instr.length < 3) + return 32; + + auto *ops = stream(instr); + + switch (instr.op) + { + case OpSConvert: + case OpConvertSToF: + case OpUConvert: + case OpConvertUToF: + case OpIEqual: + case OpINotEqual: + case OpSLessThan: + case OpSLessThanEqual: + case OpSGreaterThan: + case OpSGreaterThanEqual: + case OpULessThan: + case OpULessThanEqual: + case OpUGreaterThan: + case OpUGreaterThanEqual: + return expression_type(ops[2]).width; + + case OpSMulExtended: + case OpUMulExtended: + return get(get(ops[0]).member_types[0]).width; + + default: + { + // We can look at result type which is more robust. + auto *type = maybe_get(ops[0]); + if (type && type_is_integral(*type)) + return type->width; + else + return 32; + } + } +} + +uint32_t CompilerGLSL::get_integer_width_for_glsl_instruction(GLSLstd450 op, const uint32_t *ops, uint32_t length) const +{ + if (length < 1) + return 32; + + switch (op) + { + case GLSLstd450SAbs: + case GLSLstd450SSign: + case GLSLstd450UMin: + case GLSLstd450SMin: + case GLSLstd450UMax: + case GLSLstd450SMax: + case GLSLstd450UClamp: + case GLSLstd450SClamp: + case GLSLstd450FindSMsb: + case GLSLstd450FindUMsb: + return expression_type(ops[0]).width; + + default: + { + // We don't need to care about other opcodes, just return 32. + return 32; + } + } +} + +void CompilerGLSL::forward_relaxed_precision(uint32_t dst_id, const uint32_t *args, uint32_t length) +{ + // Only GLSL supports RelaxedPrecision directly. + // We cannot implement this in HLSL or MSL because it is tied to the type system. + // In SPIR-V, everything must masquerade as 32-bit. 
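+ // Pure data-movement instructions (OpLoad, OpCompositeExtract, swizzles, ...) simply inherit mediump from their inputs here; precision-sensitive arithmetic is handled by analyze_precision_requirements instead.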
+ if (!backend.requires_relaxed_precision_analysis) + return; + + auto input_precision = analyze_expression_precision(args, length); + + // For expressions which are loaded or directly forwarded, we inherit mediump implicitly. + // For dst_id to be analyzed properly, it must inherit any relaxed precision decoration from src_id. + if (input_precision == Options::Mediump) + set_decoration(dst_id, DecorationRelaxedPrecision); +} + +CompilerGLSL::Options::Precision CompilerGLSL::analyze_expression_precision(const uint32_t *args, uint32_t length) const +{ + // Now, analyze the precision at which the arguments would run. + // GLSL rules are such that the precision used to evaluate an expression is equal to the highest precision + // for the inputs. Constants do not have inherent precision and do not contribute to this decision. + // If all inputs are constants, they inherit precision from outer expressions, including an l-value. + // In this case, we'll have to force a temporary for dst_id so that we can bind the constant expression with + // correct precision. + bool expression_has_highp = false; + bool expression_has_mediump = false; + + for (uint32_t i = 0; i < length; i++) + { + uint32_t arg = args[i]; + + auto handle_type = ir.ids[arg].get_type(); + if (handle_type == TypeConstant || handle_type == TypeConstantOp || handle_type == TypeUndef) + continue; + + if (has_decoration(arg, DecorationRelaxedPrecision)) + expression_has_mediump = true; + else + expression_has_highp = true; + } + + if (expression_has_highp) + return Options::Highp; + else if (expression_has_mediump) + return Options::Mediump; + else + return Options::DontCare; +} + +void CompilerGLSL::analyze_precision_requirements(uint32_t type_id, uint32_t dst_id, uint32_t *args, uint32_t length) +{ + if (!backend.requires_relaxed_precision_analysis) + return; + + auto &type = get(type_id); + + // RelaxedPrecision only applies to 32-bit values. + if (type.basetype != SPIRType::Float && type.basetype != SPIRType::Int && type.basetype != SPIRType::UInt) + return; + + bool operation_is_highp = !has_decoration(dst_id, DecorationRelaxedPrecision); + + auto input_precision = analyze_expression_precision(args, length); + if (input_precision == Options::DontCare) + { + consume_temporary_in_precision_context(type_id, dst_id, input_precision); + return; + } + + // In SPIR-V and GLSL, the semantics are flipped for how relaxed precision is determined. + // In SPIR-V, the operation itself marks RelaxedPrecision, meaning that inputs can be truncated to 16-bit. + // However, if the expression is not, inputs must be expanded to 32-bit first, + // since the operation must run at high precision. + // This is the awkward part, because if we have mediump inputs, or expressions which derived from mediump, + // we might have to forcefully bind the source IDs to highp temporaries. This is done by clearing decorations + // and forcing temporaries. Similarly for mediump operations. We bind highp expressions to mediump variables. + if ((operation_is_highp && input_precision == Options::Mediump) || + (!operation_is_highp && input_precision == Options::Highp)) + { + auto precision = operation_is_highp ? Options::Highp : Options::Mediump; + for (uint32_t i = 0; i < length; i++) + { + // Rewrites the opcode so that we consume an ID in correct precision context. + // This is pretty hacky, but it's the most straight forward way of implementing this without adding + // lots of extra passes to rewrite all code blocks. 
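+ // Concretely, a mediump argument feeding a highp operation is first bound to a highp temporary (and vice versa), and the argument ID below is rewritten to reference that copy.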
+ args[i] = consume_temporary_in_precision_context(expression_type_id(args[i]), args[i], precision); + } + } +} + +// This is probably not exhaustive ... +static bool opcode_is_precision_sensitive_operation(Op op) +{ + switch (op) + { + case OpFAdd: + case OpFSub: + case OpFMul: + case OpFNegate: + case OpIAdd: + case OpISub: + case OpIMul: + case OpSNegate: + case OpFMod: + case OpFDiv: + case OpFRem: + case OpSMod: + case OpSDiv: + case OpSRem: + case OpUMod: + case OpUDiv: + case OpVectorTimesMatrix: + case OpMatrixTimesVector: + case OpMatrixTimesMatrix: + case OpDPdx: + case OpDPdy: + case OpDPdxCoarse: + case OpDPdyCoarse: + case OpDPdxFine: + case OpDPdyFine: + case OpFwidth: + case OpFwidthCoarse: + case OpFwidthFine: + case OpVectorTimesScalar: + case OpMatrixTimesScalar: + case OpOuterProduct: + case OpFConvert: + case OpSConvert: + case OpUConvert: + case OpConvertSToF: + case OpConvertUToF: + case OpConvertFToU: + case OpConvertFToS: + return true; + + default: + return false; + } +} + +// Instructions which just load data but don't do any arithmetic operation should just inherit the decoration. +// SPIR-V doesn't require this, but it's somewhat implied it has to work this way, relaxed precision is only +// relevant when operating on the IDs, not when shuffling things around. +static bool opcode_is_precision_forwarding_instruction(Op op, uint32_t &arg_count) +{ + switch (op) + { + case OpLoad: + case OpAccessChain: + case OpInBoundsAccessChain: + case OpCompositeExtract: + case OpVectorExtractDynamic: + case OpSampledImage: + case OpImage: + case OpCopyObject: + + case OpImageRead: + case OpImageFetch: + case OpImageSampleImplicitLod: + case OpImageSampleProjImplicitLod: + case OpImageSampleDrefImplicitLod: + case OpImageSampleProjDrefImplicitLod: + case OpImageSampleExplicitLod: + case OpImageSampleProjExplicitLod: + case OpImageSampleDrefExplicitLod: + case OpImageSampleProjDrefExplicitLod: + case OpImageGather: + case OpImageDrefGather: + case OpImageSparseRead: + case OpImageSparseFetch: + case OpImageSparseSampleImplicitLod: + case OpImageSparseSampleProjImplicitLod: + case OpImageSparseSampleDrefImplicitLod: + case OpImageSparseSampleProjDrefImplicitLod: + case OpImageSparseSampleExplicitLod: + case OpImageSparseSampleProjExplicitLod: + case OpImageSparseSampleDrefExplicitLod: + case OpImageSparseSampleProjDrefExplicitLod: + case OpImageSparseGather: + case OpImageSparseDrefGather: + arg_count = 1; + return true; + + case OpVectorShuffle: + arg_count = 2; + return true; + + case OpCompositeConstruct: + return true; + + default: + break; + } + + return false; +} + +CompilerGLSL::TemporaryCopy CompilerGLSL::handle_instruction_precision(const Instruction &instruction) +{ + auto ops = stream_mutable(instruction); + auto opcode = static_cast(instruction.op); + uint32_t length = instruction.length; + + if (backend.requires_relaxed_precision_analysis) + { + if (length > 2) + { + uint32_t forwarding_length = length - 2; + + if (opcode_is_precision_sensitive_operation(opcode)) + analyze_precision_requirements(ops[0], ops[1], &ops[2], forwarding_length); + else if (opcode == OpExtInst && length >= 5 && get(ops[2]).ext == SPIRExtension::GLSL) + analyze_precision_requirements(ops[0], ops[1], &ops[4], forwarding_length - 2); + else if (opcode_is_precision_forwarding_instruction(opcode, forwarding_length)) + forward_relaxed_precision(ops[1], &ops[2], forwarding_length); + } + + uint32_t result_type = 0, result_id = 0; + if (instruction_to_result_type(result_type, result_id, opcode, 
ops, length)) + { + auto itr = temporary_to_mirror_precision_alias.find(ops[1]); + if (itr != temporary_to_mirror_precision_alias.end()) + return { itr->second, itr->first }; + } + } + + return {}; +} + +void CompilerGLSL::emit_instruction(const Instruction &instruction) +{ + auto ops = stream(instruction); + auto opcode = static_cast(instruction.op); + uint32_t length = instruction.length; + +#define GLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op) +#define GLSL_BOP_CAST(op, type) \ + emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, \ + opcode_is_sign_invariant(opcode), implicit_integer_promotion) +#define GLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op) +#define GLSL_UOP_CAST(op) emit_unary_op_cast(ops[0], ops[1], ops[2], #op) +#define GLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op) +#define GLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op) +#define GLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op) +#define GLSL_BFOP_CAST(op, type) \ + emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode)) +#define GLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op) +#define GLSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op) + + // If we need to do implicit bitcasts, make sure we do it with the correct type. + uint32_t integer_width = get_integer_width_for_instruction(instruction); + auto int_type = to_signed_basetype(integer_width); + auto uint_type = to_unsigned_basetype(integer_width); + + // Handle C implicit integer promotion rules. + // If we get implicit promotion to int, need to make sure we cast by value to intended return type, + // otherwise, future sign-dependent operations and bitcasts will break. + bool implicit_integer_promotion = integer_width < 32 && backend.implicit_c_integer_promotion_rules && + opcode_can_promote_integer_implicitly(opcode) && + get(ops[0]).vecsize == 1; + + opcode = get_remapped_spirv_op(opcode); + + switch (opcode) + { + // Dealing with memory + case OpLoad: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t ptr = ops[2]; + + flush_variable_declaration(ptr); + + // If we're loading from memory that cannot be changed by the shader, + // just forward the expression directly to avoid needless temporaries. + // If an expression is mutable and forwardable, we speculate that it is immutable. + bool forward = should_forward(ptr) && forced_temporaries.find(id) == end(forced_temporaries); + + // If loading a non-native row-major matrix, mark the expression as need_transpose. + bool need_transpose = false; + bool old_need_transpose = false; + + auto *ptr_expression = maybe_get(ptr); + + if (forward) + { + // If we're forwarding the load, we're also going to forward transpose state, so don't transpose while + // taking the expression. + if (ptr_expression && ptr_expression->need_transpose) + { + old_need_transpose = true; + ptr_expression->need_transpose = false; + need_transpose = true; + } + else if (is_non_native_row_major_matrix(ptr)) + need_transpose = true; + } + + // If we are forwarding this load, + // don't register the read to access chain here, defer that to when we actually use the expression, + // using the add_implied_read_expression mechanism. 
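+ // For illustration (names are hypothetical): a forwarded load such as
+ //   %20 = OpLoad %v4float %ubo_color_ptr
+ // usually emits no statement at all; "ubo.color" is simply substituted at each use site.
+ // Only when forwarding is not possible does this path materialize a temporary, roughly
+ //   vec4 _20 = ubo.color;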
+ string expr; + + bool is_packed = has_extended_decoration(ptr, SPIRVCrossDecorationPhysicalTypePacked); + bool is_remapped = has_extended_decoration(ptr, SPIRVCrossDecorationPhysicalTypeID); + if (forward || (!is_packed && !is_remapped)) + { + // For the simple case, we do not need to deal with repacking. + expr = to_dereferenced_expression(ptr, false); + } + else + { + // If we are not forwarding the expression, we need to unpack and resolve any physical type remapping here before + // storing the expression to a temporary. + expr = to_unpacked_expression(ptr); + } + + auto &type = get(result_type); + auto &expr_type = expression_type(ptr); + + // If the expression has more vector components than the result type, insert + // a swizzle. This shouldn't happen normally on valid SPIR-V, but it might + // happen with e.g. the MSL backend replacing the type of an input variable. + if (expr_type.vecsize > type.vecsize) + expr = enclose_expression(expr + vector_swizzle(type.vecsize, 0)); + + if (forward && ptr_expression) + ptr_expression->need_transpose = old_need_transpose; + + // We might need to cast in order to load from a builtin. + cast_from_variable_load(ptr, expr, type); + + if (forward && ptr_expression) + ptr_expression->need_transpose = false; + + // We might be trying to load a gl_Position[N], where we should be + // doing float4[](gl_in[i].gl_Position, ...) instead. + // Similar workarounds are required for input arrays in tessellation. + // Also, loading from gl_SampleMask array needs special unroll. + unroll_array_from_complex_load(id, ptr, expr); + + if (!type_is_opaque_value(type) && has_decoration(ptr, DecorationNonUniform)) + { + // If we're loading something non-opaque, we need to handle non-uniform descriptor access. + convert_non_uniform_expression(expr, ptr); + } + + if (forward && ptr_expression) + ptr_expression->need_transpose = old_need_transpose; + + bool flattened = ptr_expression && flattened_buffer_blocks.count(ptr_expression->loaded_from) != 0; + + if (backend.needs_row_major_load_workaround && !is_non_native_row_major_matrix(ptr) && !flattened) + rewrite_load_for_wrapped_row_major(expr, result_type, ptr); + + // By default, suppress usage tracking since using same expression multiple times does not imply any extra work. + // However, if we try to load a complex, composite object from a flattened buffer, + // we should avoid emitting the same code over and over and lower the result to a temporary. + bool usage_tracking = flattened && (type.basetype == SPIRType::Struct || (type.columns > 1)); + + SPIRExpression *e = nullptr; + if (!forward && expression_is_non_value_type_array(ptr)) + { + // Complicated load case where we need to make a copy of ptr, but we cannot, because + // it is an array, and our backend does not support arrays as value types. + // Emit the temporary, and copy it explicitly. + e = &emit_uninitialized_temporary_expression(result_type, id); + emit_array_copy(nullptr, id, ptr, StorageClassFunction, get_expression_effective_storage_class(ptr)); + } + else + e = &emit_op(result_type, id, expr, forward, !usage_tracking); + + e->need_transpose = need_transpose; + register_read(id, ptr, forward); + + if (forward) + { + // Pass through whether the result is of a packed type and the physical type ID. 
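+ // For illustration: "packed" here refers to physical-layout remapping such as MSL's packed_float3
+ // members in tightly packed buffer structs; the decoration is propagated so later uses know the
+ // value may need unpacking before arithmetic.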
+ if (has_extended_decoration(ptr, SPIRVCrossDecorationPhysicalTypePacked)) + set_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked); + if (has_extended_decoration(ptr, SPIRVCrossDecorationPhysicalTypeID)) + { + set_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID, + get_extended_decoration(ptr, SPIRVCrossDecorationPhysicalTypeID)); + } + } + else + { + // This might have been set on an earlier compilation iteration, force it to be unset. + unset_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked); + unset_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID); + } + + inherit_expression_dependencies(id, ptr); + if (forward) + add_implied_read_expression(*e, ptr); + break; + } + + case OpInBoundsAccessChain: + case OpAccessChain: + case OpPtrAccessChain: + { + auto *var = maybe_get(ops[2]); + if (var) + flush_variable_declaration(var->self); + + // If the base is immutable, the access chain pointer must also be. + // If an expression is mutable and forwardable, we speculate that it is immutable. + AccessChainMeta meta; + bool ptr_chain = opcode == OpPtrAccessChain; + auto &target_type = get(ops[0]); + auto e = access_chain(ops[2], &ops[3], length - 3, target_type, &meta, ptr_chain); + + // If the base is flattened UBO of struct type, the expression has to be a composite. + // In that case, backends which do not support inline syntax need it to be bound to a temporary. + // Otherwise, invalid expressions like ({UBO[0].xyz, UBO[0].w, UBO[1]}).member are emitted. + bool requires_temporary = false; + if (flattened_buffer_blocks.count(ops[2]) && target_type.basetype == SPIRType::Struct) + requires_temporary = !backend.can_declare_struct_inline; + + auto &expr = requires_temporary ? + emit_op(ops[0], ops[1], std::move(e), false) : + set(ops[1], std::move(e), ops[0], should_forward(ops[2])); + + auto *backing_variable = maybe_get_backing_variable(ops[2]); + expr.loaded_from = backing_variable ? backing_variable->self : ID(ops[2]); + expr.need_transpose = meta.need_transpose; + expr.access_chain = true; + expr.access_meshlet_position_y = meta.access_meshlet_position_y; + + // Mark the result as being packed. Some platforms handled packed vectors differently than non-packed. + if (meta.storage_is_packed) + set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypePacked); + if (meta.storage_physical_type != 0) + set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypeID, meta.storage_physical_type); + if (meta.storage_is_invariant) + set_decoration(ops[1], DecorationInvariant); + if (meta.flattened_struct) + flattened_structs[ops[1]] = true; + if (meta.relaxed_precision && backend.requires_relaxed_precision_analysis) + set_decoration(ops[1], DecorationRelaxedPrecision); + + // If we have some expression dependencies in our access chain, this access chain is technically a forwarded + // temporary which could be subject to invalidation. + // Need to assume we're forwarded while calling inherit_expression_depdendencies. + forwarded_temporaries.insert(ops[1]); + // The access chain itself is never forced to a temporary, but its dependencies might. + suppressed_usage_tracking.insert(ops[1]); + + for (uint32_t i = 2; i < length; i++) + { + inherit_expression_dependencies(ops[1], ops[i]); + add_implied_read_expression(expr, ops[i]); + } + + // If we have no dependencies after all, i.e., all indices in the access chain are immutable temporaries, + // we're not forwarded after all. 
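+ // For illustration (names are hypothetical): a chain like ssbo.data[i], where i is a forwarded
+ // temporary, depends on i and must be invalidated rather than forwarded across a later store to i.
+ // A chain indexed purely by literals, e.g. ssbo.data[2], carries no such dependencies.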
+ if (expr.expression_dependencies.empty()) + forwarded_temporaries.erase(ops[1]); + + break; + } + + case OpStore: + { + auto *var = maybe_get(ops[0]); + + if (var && var->statically_assigned) + var->static_expression = ops[1]; + else if (var && var->loop_variable && !var->loop_variable_enable) + var->static_expression = ops[1]; + else if (var && var->remapped_variable && var->static_expression) + { + // Skip the write. + } + else if (flattened_structs.count(ops[0])) + { + store_flattened_struct(ops[0], ops[1]); + register_write(ops[0]); + } + else + { + emit_store_statement(ops[0], ops[1]); + } + + // Storing a pointer results in a variable pointer, so we must conservatively assume + // we can write through it. + if (expression_type(ops[1]).pointer) + register_write(ops[1]); + break; + } + + case OpArrayLength: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + auto e = access_chain_internal(ops[2], &ops[3], length - 3, ACCESS_CHAIN_INDEX_IS_LITERAL_BIT, nullptr); + if (has_decoration(ops[2], DecorationNonUniform)) + convert_non_uniform_expression(e, ops[2]); + set(id, join(type_to_glsl(get(result_type)), "(", e, ".length())"), result_type, + true); + break; + } + + // Function calls + case OpFunctionCall: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t func = ops[2]; + const auto *arg = &ops[3]; + length -= 3; + + auto &callee = get(func); + auto &return_type = get(callee.return_type); + bool pure = function_is_pure(callee); + bool control_dependent = function_is_control_dependent(callee); + + bool callee_has_out_variables = false; + bool emit_return_value_as_argument = false; + + // Invalidate out variables passed to functions since they can be OpStore'd to. + for (uint32_t i = 0; i < length; i++) + { + if (callee.arguments[i].write_count) + { + register_call_out_argument(arg[i]); + callee_has_out_variables = true; + } + + flush_variable_declaration(arg[i]); + } + + if (!return_type.array.empty() && !backend.can_return_array) + { + callee_has_out_variables = true; + emit_return_value_as_argument = true; + } + + if (!pure) + register_impure_function_call(); + + string funexpr; + SmallVector arglist; + funexpr += to_name(func) + "("; + + if (emit_return_value_as_argument) + { + statement(type_to_glsl(return_type), " ", to_name(id), type_to_array_glsl(return_type, 0), ";"); + arglist.push_back(to_name(id)); + } + + for (uint32_t i = 0; i < length; i++) + { + // Do not pass in separate images or samplers if we're remapping + // to combined image samplers. + if (skip_argument(arg[i])) + continue; + + arglist.push_back(to_func_call_arg(callee.arguments[i], arg[i])); + } + + for (auto &combined : callee.combined_parameters) + { + auto image_id = combined.global_image ? combined.image_id : VariableID(arg[combined.image_id]); + auto sampler_id = combined.global_sampler ? combined.sampler_id : VariableID(arg[combined.sampler_id]); + arglist.push_back(to_combined_image_sampler(image_id, sampler_id)); + } + + append_global_func_args(callee, length, arglist); + + funexpr += merge(arglist); + funexpr += ")"; + + // Check for function call constraints. + check_function_call_constraints(arg, length); + + if (return_type.basetype != SPIRType::Void) + { + // If the function actually writes to an out variable, + // take the conservative route and do not forward. + // The problem is that we might not read the function + // result (and emit the function) before an out variable + // is read (common case when return value is ignored! 
+ // In order to avoid start tracking invalid variables, + // just avoid the forwarding problem altogether. + bool forward = args_will_forward(id, arg, length, pure) && !callee_has_out_variables && pure && + (forced_temporaries.find(id) == end(forced_temporaries)); + + if (emit_return_value_as_argument) + { + statement(funexpr, ";"); + set(id, to_name(id), result_type, true); + } + else + emit_op(result_type, id, funexpr, forward); + + // Function calls are implicit loads from all variables in question. + // Set dependencies for them. + for (uint32_t i = 0; i < length; i++) + register_read(id, arg[i], forward); + + // If we're going to forward the temporary result, + // put dependencies on every variable that must not change. + if (forward) + register_global_read_dependencies(callee, id); + } + else + statement(funexpr, ";"); + + if (control_dependent) + register_control_dependent_expression(id); + + break; + } + + // Composite munging + case OpCompositeConstruct: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + const auto *const elems = &ops[2]; + length -= 2; + + bool forward = true; + for (uint32_t i = 0; i < length; i++) + forward = forward && should_forward(elems[i]); + + auto &out_type = get(result_type); + auto *in_type = length > 0 ? &expression_type(elems[0]) : nullptr; + + // Only splat if we have vector constructors. + // Arrays and structs must be initialized properly in full. + bool composite = !out_type.array.empty() || out_type.basetype == SPIRType::Struct; + + bool splat = false; + bool swizzle_splat = false; + + if (in_type) + { + splat = in_type->vecsize == 1 && in_type->columns == 1 && !composite && backend.use_constructor_splatting; + swizzle_splat = in_type->vecsize == 1 && in_type->columns == 1 && backend.can_swizzle_scalar; + + if (ir.ids[elems[0]].get_type() == TypeConstant && !type_is_floating_point(*in_type)) + { + // Cannot swizzle literal integers as a special case. + swizzle_splat = false; + } + } + + if (splat || swizzle_splat) + { + uint32_t input = elems[0]; + for (uint32_t i = 0; i < length; i++) + { + if (input != elems[i]) + { + splat = false; + swizzle_splat = false; + } + } + } + + if (out_type.basetype == SPIRType::Struct && !backend.can_declare_struct_inline) + forward = false; + if (!out_type.array.empty() && !backend.can_declare_arrays_inline) + forward = false; + if (type_is_empty(out_type) && !backend.supports_empty_struct) + forward = false; + + string constructor_op; + if (backend.use_initializer_list && composite) + { + bool needs_trailing_tracket = false; + // Only use this path if we are building composites. + // This path cannot be used for arithmetic. + if (backend.use_typed_initializer_list && out_type.basetype == SPIRType::Struct && out_type.array.empty()) + constructor_op += type_to_glsl_constructor(get(result_type)); + else if (backend.use_typed_initializer_list && backend.array_is_value_type && !out_type.array.empty()) + { + // MSL path. Array constructor is baked into type here, do not use _constructor variant. 
+ constructor_op += type_to_glsl_constructor(get(result_type)) + "("; + needs_trailing_tracket = true; + } + constructor_op += "{ "; + + if (type_is_empty(out_type) && !backend.supports_empty_struct) + constructor_op += "0"; + else if (splat) + constructor_op += to_unpacked_expression(elems[0]); + else + constructor_op += build_composite_combiner(result_type, elems, length); + constructor_op += " }"; + if (needs_trailing_tracket) + constructor_op += ")"; + } + else if (swizzle_splat && !composite) + { + constructor_op = remap_swizzle(get(result_type), 1, to_unpacked_expression(elems[0])); + } + else + { + constructor_op = type_to_glsl_constructor(get(result_type)) + "("; + if (type_is_empty(out_type) && !backend.supports_empty_struct) + constructor_op += "0"; + else if (splat) + constructor_op += to_unpacked_expression(elems[0]); + else + constructor_op += build_composite_combiner(result_type, elems, length); + constructor_op += ")"; + } + + if (!constructor_op.empty()) + { + emit_op(result_type, id, constructor_op, forward); + for (uint32_t i = 0; i < length; i++) + inherit_expression_dependencies(id, elems[i]); + } + break; + } + + case OpVectorInsertDynamic: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t vec = ops[2]; + uint32_t comp = ops[3]; + uint32_t index = ops[4]; + + flush_variable_declaration(vec); + + // Make a copy, then use access chain to store the variable. + statement(declare_temporary(result_type, id), to_expression(vec), ";"); + set(id, to_name(id), result_type, true); + auto chain = access_chain_internal(id, &index, 1, 0, nullptr); + statement(chain, " = ", to_unpacked_expression(comp), ";"); + break; + } + + case OpVectorExtractDynamic: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + auto expr = access_chain_internal(ops[2], &ops[3], 1, 0, nullptr); + emit_op(result_type, id, expr, should_forward(ops[2])); + inherit_expression_dependencies(id, ops[2]); + inherit_expression_dependencies(id, ops[3]); + break; + } + + case OpCompositeExtract: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + length -= 3; + + auto &type = get(result_type); + + // We can only split the expression here if our expression is forwarded as a temporary. + bool allow_base_expression = forced_temporaries.find(id) == end(forced_temporaries); + + // Do not allow base expression for struct members. We risk doing "swizzle" optimizations in this case. + auto &composite_type = expression_type(ops[2]); + bool composite_type_is_complex = composite_type.basetype == SPIRType::Struct || !composite_type.array.empty(); + if (composite_type_is_complex) + allow_base_expression = false; + + // Packed expressions or physical ID mapped expressions cannot be split up. + if (has_extended_decoration(ops[2], SPIRVCrossDecorationPhysicalTypePacked) || + has_extended_decoration(ops[2], SPIRVCrossDecorationPhysicalTypeID)) + allow_base_expression = false; + + // Cannot use base expression for row-major matrix row-extraction since we need to interleave access pattern + // into the base expression. 
+ if (is_non_native_row_major_matrix(ops[2])) + allow_base_expression = false; + + AccessChainMeta meta; + SPIRExpression *e = nullptr; + auto *c = maybe_get(ops[2]); + + if (c && !c->specialization && !composite_type_is_complex) + { + auto expr = to_extract_constant_composite_expression(result_type, *c, ops + 3, length); + e = &emit_op(result_type, id, expr, true, true); + } + else if (allow_base_expression && should_forward(ops[2]) && type.vecsize == 1 && type.columns == 1 && length == 1) + { + // Only apply this optimization if result is scalar. + + // We want to split the access chain from the base. + // This is so we can later combine different CompositeExtract results + // with CompositeConstruct without emitting code like + // + // vec3 temp = texture(...).xyz + // vec4(temp.x, temp.y, temp.z, 1.0). + // + // when we actually wanted to emit this + // vec4(texture(...).xyz, 1.0). + // + // Including the base will prevent this and would trigger multiple reads + // from expression causing it to be forced to an actual temporary in GLSL. + auto expr = access_chain_internal(ops[2], &ops[3], length, + ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_CHAIN_ONLY_BIT | + ACCESS_CHAIN_FORCE_COMPOSITE_BIT, &meta); + e = &emit_op(result_type, id, expr, true, should_suppress_usage_tracking(ops[2])); + inherit_expression_dependencies(id, ops[2]); + e->base_expression = ops[2]; + + if (meta.relaxed_precision && backend.requires_relaxed_precision_analysis) + set_decoration(ops[1], DecorationRelaxedPrecision); + } + else + { + auto expr = access_chain_internal(ops[2], &ops[3], length, + ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_FORCE_COMPOSITE_BIT, &meta); + e = &emit_op(result_type, id, expr, should_forward(ops[2]), should_suppress_usage_tracking(ops[2])); + inherit_expression_dependencies(id, ops[2]); + } + + // Pass through some meta information to the loaded expression. + // We can still end up loading a buffer type to a variable, then CompositeExtract from it + // instead of loading everything through an access chain. + e->need_transpose = meta.need_transpose; + if (meta.storage_is_packed) + set_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked); + if (meta.storage_physical_type != 0) + set_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID, meta.storage_physical_type); + if (meta.storage_is_invariant) + set_decoration(id, DecorationInvariant); + + break; + } + + case OpCompositeInsert: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t obj = ops[2]; + uint32_t composite = ops[3]; + const auto *elems = &ops[4]; + length -= 4; + + flush_variable_declaration(composite); + + // CompositeInsert requires a copy + modification, but this is very awkward code in HLL. + // Speculate that the input composite is no longer used, and we can modify it in-place. + // There are various scenarios where this is not possible to satisfy. + bool can_modify_in_place = true; + forced_temporaries.insert(id); + + // Cannot safely RMW PHI variables since they have no way to be invalidated, + // forcing temporaries is not going to help. + // This is similar for Constant and Undef inputs. + // The only safe thing to RMW is SPIRExpression. + // If the expression has already been used (i.e. used in a continue block), we have to keep using + // that loop variable, since we won't be able to override the expression after the fact. + // If the composite is hoisted, we might never be able to properly invalidate any usage + // of that composite in a subsequent loop iteration. 
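+ // For illustration (names are hypothetical): when the old composite can be reused, the insert
+ // collapses to an in-place write, roughly
+ //   _15.y = x;
+ // otherwise a copy is made first:
+ //   vec4 _16 = _15;
+ //   _16.y = x;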
+ if (invalid_expressions.count(composite) || + block_composite_insert_overwrite.count(composite) || + hoisted_temporaries.count(id) || hoisted_temporaries.count(composite) || + maybe_get(composite) == nullptr) + { + can_modify_in_place = false; + } + else if (backend.requires_relaxed_precision_analysis && + has_decoration(composite, DecorationRelaxedPrecision) != + has_decoration(id, DecorationRelaxedPrecision) && + get(result_type).basetype != SPIRType::Struct) + { + // Similarly, if precision does not match for input and output, + // we cannot alias them. If we write a composite into a relaxed precision + // ID, we might get a false truncation. + can_modify_in_place = false; + } + + if (can_modify_in_place) + { + // Have to make sure the modified SSA value is bound to a temporary so we can modify it in-place. + if (!forced_temporaries.count(composite)) + force_temporary_and_recompile(composite); + + auto chain = access_chain_internal(composite, elems, length, ACCESS_CHAIN_INDEX_IS_LITERAL_BIT, nullptr); + statement(chain, " = ", to_unpacked_expression(obj), ";"); + set(id, to_expression(composite), result_type, true); + invalid_expressions.insert(composite); + composite_insert_overwritten.insert(composite); + } + else + { + if (maybe_get(composite) != nullptr) + { + emit_uninitialized_temporary_expression(result_type, id); + } + else + { + // Make a copy, then use access chain to store the variable. + statement(declare_temporary(result_type, id), to_expression(composite), ";"); + set(id, to_name(id), result_type, true); + } + + auto chain = access_chain_internal(id, elems, length, ACCESS_CHAIN_INDEX_IS_LITERAL_BIT, nullptr); + statement(chain, " = ", to_unpacked_expression(obj), ";"); + } + + break; + } + + case OpCopyMemory: + { + uint32_t lhs = ops[0]; + uint32_t rhs = ops[1]; + if (lhs != rhs) + { + uint32_t &tmp_id = extra_sub_expressions[instruction.offset | EXTRA_SUB_EXPRESSION_TYPE_STREAM_OFFSET]; + if (!tmp_id) + tmp_id = ir.increase_bound_by(1); + uint32_t tmp_type_id = expression_type(rhs).parent_type; + + EmbeddedInstruction fake_load, fake_store; + fake_load.op = OpLoad; + fake_load.length = 3; + fake_load.ops.push_back(tmp_type_id); + fake_load.ops.push_back(tmp_id); + fake_load.ops.push_back(rhs); + + fake_store.op = OpStore; + fake_store.length = 2; + fake_store.ops.push_back(lhs); + fake_store.ops.push_back(tmp_id); + + // Load and Store do a *lot* of workarounds, and we'd like to reuse them as much as possible. + // Synthesize a fake Load and Store pair for CopyMemory. + emit_instruction(fake_load); + emit_instruction(fake_store); + } + break; + } + + case OpCopyLogical: + { + // This is used for copying object of different types, arrays and structs. + // We need to unroll the copy, element-by-element. + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t rhs = ops[2]; + + emit_uninitialized_temporary_expression(result_type, id); + emit_copy_logical_type(id, result_type, rhs, expression_type_id(rhs), {}); + break; + } + + case OpCopyObject: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t rhs = ops[2]; + bool pointer = get(result_type).pointer; + + auto *chain = maybe_get(rhs); + auto *imgsamp = maybe_get(rhs); + if (chain) + { + // Cannot lower to a SPIRExpression, just copy the object. + auto &e = set(id, *chain); + e.self = id; + } + else if (imgsamp) + { + // Cannot lower to a SPIRExpression, just copy the object. + // GLSL does not currently use this type and will never get here, but MSL does. 
+ // Handled here instead of CompilerMSL for better integration and general handling, + // and in case GLSL or other subclasses require it in the future. + auto &e = set(id, *imgsamp); + e.self = id; + } + else if (expression_is_lvalue(rhs) && !pointer) + { + // Need a copy. + // For pointer types, we copy the pointer itself. + emit_op(result_type, id, to_unpacked_expression(rhs), false); + } + else + { + // RHS expression is immutable, so just forward it. + // Copying these things really make no sense, but + // seems to be allowed anyways. + auto &e = emit_op(result_type, id, to_expression(rhs), true, true); + if (pointer) + { + auto *var = maybe_get_backing_variable(rhs); + e.loaded_from = var ? var->self : ID(0); + } + + // If we're copying an access chain, need to inherit the read expressions. + auto *rhs_expr = maybe_get(rhs); + if (rhs_expr) + { + e.implied_read_expressions = rhs_expr->implied_read_expressions; + e.expression_dependencies = rhs_expr->expression_dependencies; + } + } + break; + } + + case OpVectorShuffle: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t vec0 = ops[2]; + uint32_t vec1 = ops[3]; + const auto *elems = &ops[4]; + length -= 4; + + auto &type0 = expression_type(vec0); + + // If we have the undefined swizzle index -1, we need to swizzle in undefined data, + // or in our case, T(0). + bool shuffle = false; + for (uint32_t i = 0; i < length; i++) + if (elems[i] >= type0.vecsize || elems[i] == 0xffffffffu) + shuffle = true; + + // Cannot use swizzles with packed expressions, force shuffle path. + if (!shuffle && has_extended_decoration(vec0, SPIRVCrossDecorationPhysicalTypePacked)) + shuffle = true; + + string expr; + bool should_fwd, trivial_forward; + + if (shuffle) + { + should_fwd = should_forward(vec0) && should_forward(vec1); + trivial_forward = should_suppress_usage_tracking(vec0) && should_suppress_usage_tracking(vec1); + + // Constructor style and shuffling from two different vectors. + SmallVector args; + for (uint32_t i = 0; i < length; i++) + { + if (elems[i] == 0xffffffffu) + { + // Use a constant 0 here. + // We could use the first component or similar, but then we risk propagating + // a value we might not need, and bog down codegen. + SPIRConstant c; + c.constant_type = type0.parent_type; + assert(type0.parent_type != ID(0)); + args.push_back(constant_expression(c)); + } + else if (elems[i] >= type0.vecsize) + args.push_back(to_extract_component_expression(vec1, elems[i] - type0.vecsize)); + else + args.push_back(to_extract_component_expression(vec0, elems[i])); + } + expr += join(type_to_glsl_constructor(get(result_type)), "(", merge(args), ")"); + } + else + { + should_fwd = should_forward(vec0); + trivial_forward = should_suppress_usage_tracking(vec0); + + // We only source from first vector, so can use swizzle. + // If the vector is packed, unpack it before applying a swizzle (needed for MSL) + expr += to_enclosed_unpacked_expression(vec0); + expr += "."; + for (uint32_t i = 0; i < length; i++) + { + assert(elems[i] != 0xffffffffu); + expr += index_to_swizzle(elems[i]); + } + + if (backend.swizzle_is_function && length > 1) + expr += "()"; + } + + // A shuffle is trivial in that it doesn't actually *do* anything. + // We inherit the forwardedness from our arguments to avoid flushing out to temporaries when it's not really needed. 
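+ // For illustration (names are hypothetical): a single-source shuffle becomes a plain swizzle,
+ // e.g. v0.xzy, while a two-source or undefined-index shuffle falls back to the constructor form,
+ // e.g. vec3(v0.x, v1.y, 0.0), where 0.0 stands in for an undefined (-1) component index.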
+ + emit_op(result_type, id, expr, should_fwd, trivial_forward); + + inherit_expression_dependencies(id, vec0); + if (vec0 != vec1) + inherit_expression_dependencies(id, vec1); + break; + } + + // ALU + case OpIsNan: + if (!is_legacy()) + GLSL_UFOP(isnan); + else + { + // Check if the number doesn't equal itself + auto &type = get(ops[0]); + if (type.vecsize > 1) + emit_binary_func_op(ops[0], ops[1], ops[2], ops[2], "notEqual"); + else + emit_binary_op(ops[0], ops[1], ops[2], ops[2], "!="); + } + break; + + case OpIsInf: + if (!is_legacy()) + GLSL_UFOP(isinf); + else + { + // inf * 2 == inf by IEEE 754 rules, note this also applies to 0.0 + // This is more reliable than checking if product with zero is NaN + uint32_t result_type = ops[0]; + uint32_t result_id = ops[1]; + uint32_t operand = ops[2]; + + auto &type = get(result_type); + std::string expr; + if (type.vecsize > 1) + { + expr = type_to_glsl_constructor(type); + expr += '('; + for (uint32_t i = 0; i < type.vecsize; i++) + { + auto comp = to_extract_component_expression(operand, i); + expr += join(comp, " != 0.0 && 2.0 * ", comp, " == ", comp); + + if (i + 1 < type.vecsize) + expr += ", "; + } + expr += ')'; + } + else + { + // Register an extra read to force writing out a temporary + auto oper = to_enclosed_expression(operand); + track_expression_read(operand); + expr += join(oper, " != 0.0 && 2.0 * ", oper, " == ", oper); + } + emit_op(result_type, result_id, expr, should_forward(operand)); + + inherit_expression_dependencies(result_id, operand); + } + break; + + case OpSNegate: + if (implicit_integer_promotion || expression_type_id(ops[2]) != ops[0]) + GLSL_UOP_CAST(-); + else + GLSL_UOP(-); + break; + + case OpFNegate: + GLSL_UOP(-); + break; + + case OpIAdd: + { + // For simple arith ops, prefer the output type if there's a mismatch to avoid extra bitcasts. + auto type = get(ops[0]).basetype; + GLSL_BOP_CAST(+, type); + break; + } + + case OpFAdd: + GLSL_BOP(+); + break; + + case OpISub: + { + auto type = get(ops[0]).basetype; + GLSL_BOP_CAST(-, type); + break; + } + + case OpFSub: + GLSL_BOP(-); + break; + + case OpIMul: + { + auto type = get(ops[0]).basetype; + GLSL_BOP_CAST(*, type); + break; + } + + case OpVectorTimesMatrix: + case OpMatrixTimesVector: + { + // If the matrix needs transpose, just flip the multiply order. + auto *e = maybe_get(ops[opcode == OpMatrixTimesVector ? 2 : 3]); + if (e && e->need_transpose) + { + e->need_transpose = false; + string expr; + + if (opcode == OpMatrixTimesVector) + expr = join(to_enclosed_unpacked_expression(ops[3]), " * ", + enclose_expression(to_unpacked_row_major_matrix_expression(ops[2]))); + else + expr = join(enclose_expression(to_unpacked_row_major_matrix_expression(ops[3])), " * ", + to_enclosed_unpacked_expression(ops[2])); + + bool forward = should_forward(ops[2]) && should_forward(ops[3]); + emit_op(ops[0], ops[1], expr, forward); + e->need_transpose = true; + inherit_expression_dependencies(ops[1], ops[2]); + inherit_expression_dependencies(ops[1], ops[3]); + } + else + GLSL_BOP(*); + break; + } + + case OpMatrixTimesMatrix: + { + auto *a = maybe_get(ops[2]); + auto *b = maybe_get(ops[3]); + + // If both matrices need transpose, we can multiply in flipped order and tag the expression as transposed. + // a^T * b^T = (b * a)^T. 
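+ // For illustration: instead of emitting transpose(a) * transpose(b), the flipped product
+ //   b * a
+ // is emitted and the result is tagged need_transpose, so no transpose() call has to be
+ // materialized unless a consumer actually requires it.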
+ if (a && b && a->need_transpose && b->need_transpose) + { + a->need_transpose = false; + b->need_transpose = false; + auto expr = join(enclose_expression(to_unpacked_row_major_matrix_expression(ops[3])), " * ", + enclose_expression(to_unpacked_row_major_matrix_expression(ops[2]))); + bool forward = should_forward(ops[2]) && should_forward(ops[3]); + auto &e = emit_op(ops[0], ops[1], expr, forward); + e.need_transpose = true; + a->need_transpose = true; + b->need_transpose = true; + inherit_expression_dependencies(ops[1], ops[2]); + inherit_expression_dependencies(ops[1], ops[3]); + } + else + GLSL_BOP(*); + + break; + } + + case OpMatrixTimesScalar: + { + auto *a = maybe_get(ops[2]); + + // If the matrix need transpose, just mark the result as needing so. + if (a && a->need_transpose) + { + a->need_transpose = false; + auto expr = join(enclose_expression(to_unpacked_row_major_matrix_expression(ops[2])), " * ", + to_enclosed_unpacked_expression(ops[3])); + bool forward = should_forward(ops[2]) && should_forward(ops[3]); + auto &e = emit_op(ops[0], ops[1], expr, forward); + e.need_transpose = true; + a->need_transpose = true; + inherit_expression_dependencies(ops[1], ops[2]); + inherit_expression_dependencies(ops[1], ops[3]); + } + else + GLSL_BOP(*); + break; + } + + case OpFMul: + case OpVectorTimesScalar: + GLSL_BOP(*); + break; + + case OpOuterProduct: + if (options.version < 120) // Matches GLSL 1.10 / ESSL 1.00 + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t a = ops[2]; + uint32_t b = ops[3]; + + auto &type = get(result_type); + string expr = type_to_glsl_constructor(type); + expr += "("; + for (uint32_t col = 0; col < type.columns; col++) + { + expr += to_enclosed_expression(a); + expr += " * "; + expr += to_extract_component_expression(b, col); + if (col + 1 < type.columns) + expr += ", "; + } + expr += ")"; + emit_op(result_type, id, expr, should_forward(a) && should_forward(b)); + inherit_expression_dependencies(id, a); + inherit_expression_dependencies(id, b); + } + else + GLSL_BFOP(outerProduct); + break; + + case OpDot: + GLSL_BFOP(dot); + break; + + case OpTranspose: + if (options.version < 120) // Matches GLSL 1.10 / ESSL 1.00 + { + // transpose() is not available, so instead, flip need_transpose, + // which can later be turned into an emulated transpose op by + // convert_row_major_matrix(), if necessary. + uint32_t result_type = ops[0]; + uint32_t result_id = ops[1]; + uint32_t input = ops[2]; + + // Force need_transpose to false temporarily to prevent + // to_expression() from doing the transpose. + bool need_transpose = false; + auto *input_e = maybe_get(input); + if (input_e) + swap(need_transpose, input_e->need_transpose); + + bool forward = should_forward(input); + auto &e = emit_op(result_type, result_id, to_expression(input), forward); + e.need_transpose = !need_transpose; + + // Restore the old need_transpose flag. + if (input_e) + input_e->need_transpose = need_transpose; + } + else + GLSL_UFOP(transpose); + break; + + case OpSRem: + { + uint32_t result_type = ops[0]; + uint32_t result_id = ops[1]; + uint32_t op0 = ops[2]; + uint32_t op1 = ops[3]; + + // Needs special handling. 
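+ // For illustration: OpSRem (remainder taking the sign of the dividend) has no direct GLSL
+ // operator, so it is expanded arithmetically, roughly
+ //   a - b * (a / b)
+ // relying on GLSL's truncating integer division.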
+ bool forward = should_forward(op0) && should_forward(op1); + auto expr = join(to_enclosed_expression(op0), " - ", to_enclosed_expression(op1), " * ", "(", + to_enclosed_expression(op0), " / ", to_enclosed_expression(op1), ")"); + + if (implicit_integer_promotion) + expr = join(type_to_glsl(get(result_type)), '(', expr, ')'); + + emit_op(result_type, result_id, expr, forward); + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); + break; + } + + case OpSDiv: + GLSL_BOP_CAST(/, int_type); + break; + + case OpUDiv: + GLSL_BOP_CAST(/, uint_type); + break; + + case OpIAddCarry: + case OpISubBorrow: + { + if (options.es && options.version < 310) + SPIRV_CROSS_THROW("Extended arithmetic is only available from ESSL 310."); + else if (!options.es && options.version < 400) + SPIRV_CROSS_THROW("Extended arithmetic is only available from GLSL 400."); + + uint32_t result_type = ops[0]; + uint32_t result_id = ops[1]; + uint32_t op0 = ops[2]; + uint32_t op1 = ops[3]; + auto &type = get(result_type); + emit_uninitialized_temporary_expression(result_type, result_id); + const char *op = opcode == OpIAddCarry ? "uaddCarry" : "usubBorrow"; + + statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", op, "(", to_expression(op0), ", ", + to_expression(op1), ", ", to_expression(result_id), ".", to_member_name(type, 1), ");"); + break; + } + + case OpUMulExtended: + case OpSMulExtended: + { + if (options.es && options.version < 310) + SPIRV_CROSS_THROW("Extended arithmetic is only available from ESSL 310."); + else if (!options.es && options.version < 400) + SPIRV_CROSS_THROW("Extended arithmetic is only available from GLSL 4000."); + + uint32_t result_type = ops[0]; + uint32_t result_id = ops[1]; + uint32_t op0 = ops[2]; + uint32_t op1 = ops[3]; + auto &type = get(result_type); + emit_uninitialized_temporary_expression(result_type, result_id); + const char *op = opcode == OpUMulExtended ? "umulExtended" : "imulExtended"; + + statement(op, "(", to_expression(op0), ", ", to_expression(op1), ", ", to_expression(result_id), ".", + to_member_name(type, 1), ", ", to_expression(result_id), ".", to_member_name(type, 0), ");"); + break; + } + + case OpFDiv: + GLSL_BOP(/); + break; + + case OpShiftRightLogical: + GLSL_BOP_CAST(>>, uint_type); + break; + + case OpShiftRightArithmetic: + GLSL_BOP_CAST(>>, int_type); + break; + + case OpShiftLeftLogical: + { + auto type = get(ops[0]).basetype; + GLSL_BOP_CAST(<<, type); + break; + } + + case OpBitwiseOr: + { + auto type = get(ops[0]).basetype; + GLSL_BOP_CAST(|, type); + break; + } + + case OpBitwiseXor: + { + auto type = get(ops[0]).basetype; + GLSL_BOP_CAST(^, type); + break; + } + + case OpBitwiseAnd: + { + auto type = get(ops[0]).basetype; + GLSL_BOP_CAST(&, type); + break; + } + + case OpNot: + if (implicit_integer_promotion || expression_type_id(ops[2]) != ops[0]) + GLSL_UOP_CAST(~); + else + GLSL_UOP(~); + break; + + case OpUMod: + GLSL_BOP_CAST(%, uint_type); + break; + + case OpSMod: + GLSL_BOP_CAST(%, int_type); + break; + + case OpFMod: + GLSL_BFOP(mod); + break; + + case OpFRem: + { + uint32_t result_type = ops[0]; + uint32_t result_id = ops[1]; + uint32_t op0 = ops[2]; + uint32_t op1 = ops[3]; + + // Needs special handling. 
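+ // For illustration: OpFRem takes the sign of the dividend, unlike GLSL mod() (used for OpFMod
+ // above), which follows the sign of the divisor. It is therefore expanded as, roughly,
+ //   a - b * trunc(a / b)
+ // with a round-trip through int standing in for trunc() on legacy targets that lack it.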
+ bool forward = should_forward(op0) && should_forward(op1); + std::string expr; + if (!is_legacy()) + { + expr = join(to_enclosed_expression(op0), " - ", to_enclosed_expression(op1), " * ", "trunc(", + to_enclosed_expression(op0), " / ", to_enclosed_expression(op1), ")"); + } + else + { + // Legacy GLSL has no trunc, emulate by casting to int and back + auto &op0_type = expression_type(op0); + auto via_type = op0_type; + via_type.basetype = SPIRType::Int; + expr = join(to_enclosed_expression(op0), " - ", to_enclosed_expression(op1), " * ", + type_to_glsl(op0_type), "(", type_to_glsl(via_type), "(", + to_enclosed_expression(op0), " / ", to_enclosed_expression(op1), "))"); + } + + emit_op(result_type, result_id, expr, forward); + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); + break; + } + + // Relational + case OpAny: + GLSL_UFOP(any); + break; + + case OpAll: + GLSL_UFOP(all); + break; + + case OpSelect: + emit_mix_op(ops[0], ops[1], ops[4], ops[3], ops[2]); + break; + + case OpLogicalOr: + { + // No vector variant in GLSL for logical OR. + auto result_type = ops[0]; + auto id = ops[1]; + auto &type = get(result_type); + + if (type.vecsize > 1) + emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "||", false, SPIRType::Unknown); + else + GLSL_BOP(||); + break; + } + + case OpLogicalAnd: + { + // No vector variant in GLSL for logical AND. + auto result_type = ops[0]; + auto id = ops[1]; + auto &type = get(result_type); + + if (type.vecsize > 1) + emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "&&", false, SPIRType::Unknown); + else + GLSL_BOP(&&); + break; + } + + case OpLogicalNot: + { + auto &type = get(ops[0]); + if (type.vecsize > 1) + GLSL_UFOP(not ); + else + GLSL_UOP(!); + break; + } + + case OpIEqual: + { + if (expression_type(ops[2]).vecsize > 1) + GLSL_BFOP_CAST(equal, int_type); + else + GLSL_BOP_CAST(==, int_type); + break; + } + + case OpLogicalEqual: + case OpFOrdEqual: + { + if (expression_type(ops[2]).vecsize > 1) + GLSL_BFOP(equal); + else + GLSL_BOP(==); + break; + } + + case OpINotEqual: + { + if (expression_type(ops[2]).vecsize > 1) + GLSL_BFOP_CAST(notEqual, int_type); + else + GLSL_BOP_CAST(!=, int_type); + break; + } + + case OpLogicalNotEqual: + case OpFOrdNotEqual: + case OpFUnordNotEqual: + { + // GLSL is fuzzy on what to do with ordered vs unordered not equal. + // glslang started emitting UnorderedNotEqual some time ago to harmonize with IEEE, + // but this means we have no easy way of implementing ordered not equal. + if (expression_type(ops[2]).vecsize > 1) + GLSL_BFOP(notEqual); + else + GLSL_BOP(!=); + break; + } + + case OpUGreaterThan: + case OpSGreaterThan: + { + auto type = opcode == OpUGreaterThan ? uint_type : int_type; + if (expression_type(ops[2]).vecsize > 1) + GLSL_BFOP_CAST(greaterThan, type); + else + GLSL_BOP_CAST(>, type); + break; + } + + case OpFOrdGreaterThan: + { + if (expression_type(ops[2]).vecsize > 1) + GLSL_BFOP(greaterThan); + else + GLSL_BOP(>); + break; + } + + case OpUGreaterThanEqual: + case OpSGreaterThanEqual: + { + auto type = opcode == OpUGreaterThanEqual ? uint_type : int_type; + if (expression_type(ops[2]).vecsize > 1) + GLSL_BFOP_CAST(greaterThanEqual, type); + else + GLSL_BOP_CAST(>=, type); + break; + } + + case OpFOrdGreaterThanEqual: + { + if (expression_type(ops[2]).vecsize > 1) + GLSL_BFOP(greaterThanEqual); + else + GLSL_BOP(>=); + break; + } + + case OpULessThan: + case OpSLessThan: + { + auto type = opcode == OpULessThan ? 
uint_type : int_type; + if (expression_type(ops[2]).vecsize > 1) + GLSL_BFOP_CAST(lessThan, type); + else + GLSL_BOP_CAST(<, type); + break; + } + + case OpFOrdLessThan: + { + if (expression_type(ops[2]).vecsize > 1) + GLSL_BFOP(lessThan); + else + GLSL_BOP(<); + break; + } + + case OpULessThanEqual: + case OpSLessThanEqual: + { + auto type = opcode == OpULessThanEqual ? uint_type : int_type; + if (expression_type(ops[2]).vecsize > 1) + GLSL_BFOP_CAST(lessThanEqual, type); + else + GLSL_BOP_CAST(<=, type); + break; + } + + case OpFOrdLessThanEqual: + { + if (expression_type(ops[2]).vecsize > 1) + GLSL_BFOP(lessThanEqual); + else + GLSL_BOP(<=); + break; + } + + // Conversion + case OpSConvert: + case OpConvertSToF: + case OpUConvert: + case OpConvertUToF: + { + auto input_type = opcode == OpSConvert || opcode == OpConvertSToF ? int_type : uint_type; + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + auto &type = get(result_type); + auto &arg_type = expression_type(ops[2]); + auto func = type_to_glsl_constructor(type); + + if (arg_type.width < type.width || type_is_floating_point(type)) + emit_unary_func_op_cast(result_type, id, ops[2], func.c_str(), input_type, type.basetype); + else + emit_unary_func_op(result_type, id, ops[2], func.c_str()); + break; + } + + case OpConvertFToU: + case OpConvertFToS: + { + // Cast to expected arithmetic type, then potentially bitcast away to desired signedness. + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + auto &type = get(result_type); + auto expected_type = type; + auto &float_type = expression_type(ops[2]); + expected_type.basetype = + opcode == OpConvertFToS ? to_signed_basetype(type.width) : to_unsigned_basetype(type.width); + + auto func = type_to_glsl_constructor(expected_type); + emit_unary_func_op_cast(result_type, id, ops[2], func.c_str(), float_type.basetype, expected_type.basetype); + break; + } + + case OpFConvert: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + auto func = type_to_glsl_constructor(get(result_type)); + emit_unary_func_op(result_type, id, ops[2], func.c_str()); + break; + } + + case OpBitcast: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t arg = ops[2]; + + if (!emit_complex_bitcast(result_type, id, arg)) + { + auto op = bitcast_glsl_op(get(result_type), expression_type(arg)); + emit_unary_func_op(result_type, id, arg, op.c_str()); + } + break; + } + + case OpQuantizeToF16: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t arg = ops[2]; + + string op; + auto &type = get(result_type); + + switch (type.vecsize) + { + case 1: + op = join("unpackHalf2x16(packHalf2x16(vec2(", to_expression(arg), "))).x"); + break; + case 2: + op = join("unpackHalf2x16(packHalf2x16(", to_expression(arg), "))"); + break; + case 3: + { + auto op0 = join("unpackHalf2x16(packHalf2x16(", to_expression(arg), ".xy))"); + auto op1 = join("unpackHalf2x16(packHalf2x16(", to_expression(arg), ".zz)).x"); + op = join("vec3(", op0, ", ", op1, ")"); + break; + } + case 4: + { + auto op0 = join("unpackHalf2x16(packHalf2x16(", to_expression(arg), ".xy))"); + auto op1 = join("unpackHalf2x16(packHalf2x16(", to_expression(arg), ".zw))"); + op = join("vec4(", op0, ", ", op1, ")"); + break; + } + default: + SPIRV_CROSS_THROW("Illegal argument to OpQuantizeToF16."); + } + + emit_op(result_type, id, op, should_forward(arg)); + inherit_expression_dependencies(id, arg); + break; + } + + // Derivatives + case OpDPdx: + GLSL_UFOP(dFdx); + if (is_legacy_es()) + 
require_extension_internal("GL_OES_standard_derivatives"); + register_control_dependent_expression(ops[1]); + break; + + case OpDPdy: + GLSL_UFOP(dFdy); + if (is_legacy_es()) + require_extension_internal("GL_OES_standard_derivatives"); + register_control_dependent_expression(ops[1]); + break; + + case OpDPdxFine: + GLSL_UFOP(dFdxFine); + if (options.es) + { + SPIRV_CROSS_THROW("GL_ARB_derivative_control is unavailable in OpenGL ES."); + } + if (options.version < 450) + require_extension_internal("GL_ARB_derivative_control"); + register_control_dependent_expression(ops[1]); + break; + + case OpDPdyFine: + GLSL_UFOP(dFdyFine); + if (options.es) + { + SPIRV_CROSS_THROW("GL_ARB_derivative_control is unavailable in OpenGL ES."); + } + if (options.version < 450) + require_extension_internal("GL_ARB_derivative_control"); + register_control_dependent_expression(ops[1]); + break; + + case OpDPdxCoarse: + if (options.es) + { + SPIRV_CROSS_THROW("GL_ARB_derivative_control is unavailable in OpenGL ES."); + } + GLSL_UFOP(dFdxCoarse); + if (options.version < 450) + require_extension_internal("GL_ARB_derivative_control"); + register_control_dependent_expression(ops[1]); + break; + + case OpDPdyCoarse: + GLSL_UFOP(dFdyCoarse); + if (options.es) + { + SPIRV_CROSS_THROW("GL_ARB_derivative_control is unavailable in OpenGL ES."); + } + if (options.version < 450) + require_extension_internal("GL_ARB_derivative_control"); + register_control_dependent_expression(ops[1]); + break; + + case OpFwidth: + GLSL_UFOP(fwidth); + if (is_legacy_es()) + require_extension_internal("GL_OES_standard_derivatives"); + register_control_dependent_expression(ops[1]); + break; + + case OpFwidthCoarse: + GLSL_UFOP(fwidthCoarse); + if (options.es) + { + SPIRV_CROSS_THROW("GL_ARB_derivative_control is unavailable in OpenGL ES."); + } + if (options.version < 450) + require_extension_internal("GL_ARB_derivative_control"); + register_control_dependent_expression(ops[1]); + break; + + case OpFwidthFine: + GLSL_UFOP(fwidthFine); + if (options.es) + { + SPIRV_CROSS_THROW("GL_ARB_derivative_control is unavailable in OpenGL ES."); + } + if (options.version < 450) + require_extension_internal("GL_ARB_derivative_control"); + register_control_dependent_expression(ops[1]); + break; + + // Bitfield + case OpBitFieldInsert: + { + emit_bitfield_insert_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], "bitfieldInsert", SPIRType::Int); + break; + } + + case OpBitFieldSExtract: + { + emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "bitfieldExtract", int_type, int_type, + SPIRType::Int, SPIRType::Int); + break; + } + + case OpBitFieldUExtract: + { + emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "bitfieldExtract", uint_type, uint_type, + SPIRType::Int, SPIRType::Int); + break; + } + + case OpBitReverse: + // BitReverse does not have issues with sign since result type must match input type. + GLSL_UFOP(bitfieldReverse); + break; + + case OpBitCount: + { + auto basetype = expression_type(ops[2]).basetype; + emit_unary_func_op_cast(ops[0], ops[1], ops[2], "bitCount", basetype, int_type); + break; + } + + // Atomics + case OpAtomicExchange: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t ptr = ops[2]; + // Ignore semantics for now, probably only relevant to CL. + uint32_t val = ops[5]; + const char *op = check_atomic_image(ptr) ? 
"imageAtomicExchange" : "atomicExchange"; + + emit_atomic_func_op(result_type, id, ptr, val, op); + break; + } + + case OpAtomicCompareExchange: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t ptr = ops[2]; + uint32_t val = ops[6]; + uint32_t comp = ops[7]; + const char *op = check_atomic_image(ptr) ? "imageAtomicCompSwap" : "atomicCompSwap"; + + emit_atomic_func_op(result_type, id, ptr, comp, val, op); + break; + } + + case OpAtomicLoad: + { + // In plain GLSL, we have no atomic loads, so emulate this by fetch adding by 0 and hope compiler figures it out. + // Alternatively, we could rely on KHR_memory_model, but that's not very helpful for GL. + auto &type = expression_type(ops[2]); + forced_temporaries.insert(ops[1]); + bool atomic_image = check_atomic_image(ops[2]); + bool unsigned_type = (type.basetype == SPIRType::UInt) || + (atomic_image && get(type.image.type).basetype == SPIRType::UInt); + const char *op = atomic_image ? "imageAtomicAdd" : "atomicAdd"; + const char *increment = unsigned_type ? "0u" : "0"; + emit_op(ops[0], ops[1], + join(op, "(", + to_non_uniform_aware_expression(ops[2]), ", ", increment, ")"), false); + flush_all_atomic_capable_variables(); + break; + } + + case OpAtomicStore: + { + // In plain GLSL, we have no atomic stores, so emulate this with an atomic exchange where we don't consume the result. + // Alternatively, we could rely on KHR_memory_model, but that's not very helpful for GL. + uint32_t ptr = ops[0]; + // Ignore semantics for now, probably only relevant to CL. + uint32_t val = ops[3]; + const char *op = check_atomic_image(ptr) ? "imageAtomicExchange" : "atomicExchange"; + statement(op, "(", to_non_uniform_aware_expression(ptr), ", ", to_expression(val), ");"); + flush_all_atomic_capable_variables(); + break; + } + + case OpAtomicIIncrement: + case OpAtomicIDecrement: + { + forced_temporaries.insert(ops[1]); + auto &type = expression_type(ops[2]); + if (type.storage == StorageClassAtomicCounter) + { + // Legacy GLSL stuff, not sure if this is relevant to support. + if (opcode == OpAtomicIIncrement) + GLSL_UFOP(atomicCounterIncrement); + else + GLSL_UFOP(atomicCounterDecrement); + } + else + { + bool atomic_image = check_atomic_image(ops[2]); + bool unsigned_type = (type.basetype == SPIRType::UInt) || + (atomic_image && get(type.image.type).basetype == SPIRType::UInt); + const char *op = atomic_image ? "imageAtomicAdd" : "atomicAdd"; + + const char *increment = nullptr; + if (opcode == OpAtomicIIncrement && unsigned_type) + increment = "1u"; + else if (opcode == OpAtomicIIncrement) + increment = "1"; + else if (unsigned_type) + increment = "uint(-1)"; + else + increment = "-1"; + + emit_op(ops[0], ops[1], + join(op, "(", to_non_uniform_aware_expression(ops[2]), ", ", increment, ")"), false); + } + + flush_all_atomic_capable_variables(); + break; + } + + case OpAtomicIAdd: + case OpAtomicFAddEXT: + { + const char *op = check_atomic_image(ops[2]) ? "imageAtomicAdd" : "atomicAdd"; + emit_atomic_func_op(ops[0], ops[1], ops[2], ops[5], op); + break; + } + + case OpAtomicISub: + { + const char *op = check_atomic_image(ops[2]) ? "imageAtomicAdd" : "atomicAdd"; + forced_temporaries.insert(ops[1]); + auto expr = join(op, "(", to_non_uniform_aware_expression(ops[2]), ", -", to_enclosed_expression(ops[5]), ")"); + emit_op(ops[0], ops[1], expr, should_forward(ops[2]) && should_forward(ops[5])); + flush_all_atomic_capable_variables(); + break; + } + + case OpAtomicSMin: + case OpAtomicUMin: + { + const char *op = check_atomic_image(ops[2]) ? 
"imageAtomicMin" : "atomicMin"; + emit_atomic_func_op(ops[0], ops[1], ops[2], ops[5], op); + break; + } + + case OpAtomicSMax: + case OpAtomicUMax: + { + const char *op = check_atomic_image(ops[2]) ? "imageAtomicMax" : "atomicMax"; + emit_atomic_func_op(ops[0], ops[1], ops[2], ops[5], op); + break; + } + + case OpAtomicAnd: + { + const char *op = check_atomic_image(ops[2]) ? "imageAtomicAnd" : "atomicAnd"; + emit_atomic_func_op(ops[0], ops[1], ops[2], ops[5], op); + break; + } + + case OpAtomicOr: + { + const char *op = check_atomic_image(ops[2]) ? "imageAtomicOr" : "atomicOr"; + emit_atomic_func_op(ops[0], ops[1], ops[2], ops[5], op); + break; + } + + case OpAtomicXor: + { + const char *op = check_atomic_image(ops[2]) ? "imageAtomicXor" : "atomicXor"; + emit_atomic_func_op(ops[0], ops[1], ops[2], ops[5], op); + break; + } + + // Geometry shaders + case OpEmitVertex: + statement("EmitVertex();"); + break; + + case OpEndPrimitive: + statement("EndPrimitive();"); + break; + + case OpEmitStreamVertex: + { + if (options.es) + SPIRV_CROSS_THROW("Multi-stream geometry shaders not supported in ES."); + else if (!options.es && options.version < 400) + SPIRV_CROSS_THROW("Multi-stream geometry shaders only supported in GLSL 400."); + + auto stream_expr = to_expression(ops[0]); + if (expression_type(ops[0]).basetype != SPIRType::Int) + stream_expr = join("int(", stream_expr, ")"); + statement("EmitStreamVertex(", stream_expr, ");"); + break; + } + + case OpEndStreamPrimitive: + { + if (options.es) + SPIRV_CROSS_THROW("Multi-stream geometry shaders not supported in ES."); + else if (!options.es && options.version < 400) + SPIRV_CROSS_THROW("Multi-stream geometry shaders only supported in GLSL 400."); + + auto stream_expr = to_expression(ops[0]); + if (expression_type(ops[0]).basetype != SPIRType::Int) + stream_expr = join("int(", stream_expr, ")"); + statement("EndStreamPrimitive(", stream_expr, ");"); + break; + } + + // Textures + case OpImageSampleExplicitLod: + case OpImageSampleProjExplicitLod: + case OpImageSampleDrefExplicitLod: + case OpImageSampleProjDrefExplicitLod: + case OpImageSampleImplicitLod: + case OpImageSampleProjImplicitLod: + case OpImageSampleDrefImplicitLod: + case OpImageSampleProjDrefImplicitLod: + case OpImageFetch: + case OpImageGather: + case OpImageDrefGather: + // Gets a bit hairy, so move this to a separate instruction. + emit_texture_op(instruction, false); + break; + + case OpImageSparseSampleExplicitLod: + case OpImageSparseSampleProjExplicitLod: + case OpImageSparseSampleDrefExplicitLod: + case OpImageSparseSampleProjDrefExplicitLod: + case OpImageSparseSampleImplicitLod: + case OpImageSparseSampleProjImplicitLod: + case OpImageSparseSampleDrefImplicitLod: + case OpImageSparseSampleProjDrefImplicitLod: + case OpImageSparseFetch: + case OpImageSparseGather: + case OpImageSparseDrefGather: + // Gets a bit hairy, so move this to a separate instruction. + emit_texture_op(instruction, true); + break; + + case OpImageSparseTexelsResident: + if (options.es) + SPIRV_CROSS_THROW("Sparse feedback is not supported in GLSL."); + require_extension_internal("GL_ARB_sparse_texture2"); + emit_unary_func_op_cast(ops[0], ops[1], ops[2], "sparseTexelsResidentARB", int_type, SPIRType::Boolean); + break; + + case OpImage: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + // Suppress usage tracking. + auto &e = emit_op(result_type, id, to_expression(ops[2]), true, true); + + // When using the image, we need to know which variable it is actually loaded from. 
+ auto *var = maybe_get_backing_variable(ops[2]); + e.loaded_from = var ? var->self : ID(0); + break; + } + + case OpImageQueryLod: + { + const char *op = nullptr; + if (!options.es && options.version < 400) + { + require_extension_internal("GL_ARB_texture_query_lod"); + // For some reason, the ARB spec is all-caps. + op = "textureQueryLOD"; + } + else if (options.es) + { + if (options.version < 300) + SPIRV_CROSS_THROW("textureQueryLod not supported in legacy ES"); + require_extension_internal("GL_EXT_texture_query_lod"); + op = "textureQueryLOD"; + } + else + op = "textureQueryLod"; + + auto sampler_expr = to_expression(ops[2]); + if (has_decoration(ops[2], DecorationNonUniform)) + { + if (maybe_get_backing_variable(ops[2])) + convert_non_uniform_expression(sampler_expr, ops[2]); + else if (*backend.nonuniform_qualifier != '\0') + sampler_expr = join(backend.nonuniform_qualifier, "(", sampler_expr, ")"); + } + + bool forward = should_forward(ops[3]); + emit_op(ops[0], ops[1], + join(op, "(", sampler_expr, ", ", to_unpacked_expression(ops[3]), ")"), + forward); + inherit_expression_dependencies(ops[1], ops[2]); + inherit_expression_dependencies(ops[1], ops[3]); + register_control_dependent_expression(ops[1]); + break; + } + + case OpImageQueryLevels: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + if (!options.es && options.version < 430) + require_extension_internal("GL_ARB_texture_query_levels"); + if (options.es) + SPIRV_CROSS_THROW("textureQueryLevels not supported in ES profile."); + + auto expr = join("textureQueryLevels(", convert_separate_image_to_expression(ops[2]), ")"); + auto &restype = get(ops[0]); + expr = bitcast_expression(restype, SPIRType::Int, expr); + emit_op(result_type, id, expr, true); + break; + } + + case OpImageQuerySamples: + { + auto &type = expression_type(ops[2]); + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + if (options.es) + SPIRV_CROSS_THROW("textureSamples and imageSamples not supported in ES profile."); + else if (options.version < 450) + require_extension_internal("GL_ARB_texture_query_samples"); + + string expr; + if (type.image.sampled == 2) + expr = join("imageSamples(", to_non_uniform_aware_expression(ops[2]), ")"); + else + expr = join("textureSamples(", convert_separate_image_to_expression(ops[2]), ")"); + + auto &restype = get(ops[0]); + expr = bitcast_expression(restype, SPIRType::Int, expr); + emit_op(result_type, id, expr, true); + break; + } + + case OpSampledImage: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + emit_sampled_image_op(result_type, id, ops[2], ops[3]); + inherit_expression_dependencies(id, ops[2]); + inherit_expression_dependencies(id, ops[3]); + break; + } + + case OpImageQuerySizeLod: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t img = ops[2]; + auto &type = expression_type(img); + auto &imgtype = get(type.self); + + std::string fname = "textureSize"; + if (is_legacy_desktop()) + { + fname = legacy_tex_op(fname, imgtype, img); + } + else if (is_legacy_es()) + SPIRV_CROSS_THROW("textureSize is not supported in ESSL 100."); + + auto expr = join(fname, "(", convert_separate_image_to_expression(img), ", ", + bitcast_expression(SPIRType::Int, ops[3]), ")"); + + // ES needs to emulate 1D images as 2D. 
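+ // e.g. on ES a 1D texture is declared as a sampler2D, so the ivec2 returned by textureSize() is narrowed back to its .x component below (illustrative note).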
+ if (type.image.dim == Dim1D && options.es) + expr = join(expr, ".x"); + + auto &restype = get(ops[0]); + expr = bitcast_expression(restype, SPIRType::Int, expr); + emit_op(result_type, id, expr, true); + break; + } + + // Image load/store + case OpImageRead: + case OpImageSparseRead: + { + // We added Nonreadable speculatively to the OpImage variable due to glslangValidator + // not adding the proper qualifiers. + // If it turns out we need to read the image after all, remove the qualifier and recompile. + auto *var = maybe_get_backing_variable(ops[2]); + if (var) + { + auto &flags = get_decoration_bitset(var->self); + if (flags.get(DecorationNonReadable)) + { + unset_decoration(var->self, DecorationNonReadable); + force_recompile(); + } + } + + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + bool pure; + string imgexpr; + auto &type = expression_type(ops[2]); + + if (var && var->remapped_variable) // Remapped input, just read as-is without any op-code + { + if (type.image.ms) + SPIRV_CROSS_THROW("Trying to remap multisampled image to variable, this is not possible."); + + auto itr = + find_if(begin(pls_inputs), end(pls_inputs), [var](const PlsRemap &pls) { return pls.id == var->self; }); + + if (itr == end(pls_inputs)) + { + // For non-PLS inputs, we rely on subpass type remapping information to get it right + // since ImageRead always returns 4-component vectors and the backing type is opaque. + if (!var->remapped_components) + SPIRV_CROSS_THROW("subpassInput was remapped, but remap_components is not set correctly."); + imgexpr = remap_swizzle(get(result_type), var->remapped_components, to_expression(ops[2])); + } + else + { + // PLS input could have different number of components than what the SPIR expects, swizzle to + // the appropriate vector size. + uint32_t components = pls_format_to_components(itr->format); + imgexpr = remap_swizzle(get(result_type), components, to_expression(ops[2])); + } + pure = true; + } + else if (type.image.dim == DimSubpassData) + { + if (var && subpass_input_is_framebuffer_fetch(var->self)) + { + imgexpr = to_expression(var->self); + } + else if (options.vulkan_semantics) + { + // With Vulkan semantics, use the proper Vulkan GLSL construct. + if (type.image.ms) + { + uint32_t operands = ops[4]; + if (operands != ImageOperandsSampleMask || length != 6) + SPIRV_CROSS_THROW("Multisampled image used in OpImageRead, but unexpected " + "operand mask was used."); + + uint32_t samples = ops[5]; + imgexpr = join("subpassLoad(", to_non_uniform_aware_expression(ops[2]), ", ", to_expression(samples), ")"); + } + else + imgexpr = join("subpassLoad(", to_non_uniform_aware_expression(ops[2]), ")"); + } + else + { + if (type.image.ms) + { + uint32_t operands = ops[4]; + if (operands != ImageOperandsSampleMask || length != 6) + SPIRV_CROSS_THROW("Multisampled image used in OpImageRead, but unexpected " + "operand mask was used."); + + uint32_t samples = ops[5]; + imgexpr = join("texelFetch(", to_non_uniform_aware_expression(ops[2]), ", ivec2(gl_FragCoord.xy), ", + to_expression(samples), ")"); + } + else + { + // Implement subpass loads via texture barrier style sampling. 
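+ // e.g. a non-Vulkan, single-sample subpass load is emitted as roughly texelFetch(attachment, ivec2(gl_FragCoord.xy), 0), where "attachment" stands for the remapped input expression (illustrative note).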
+ imgexpr = join("texelFetch(", to_non_uniform_aware_expression(ops[2]), ", ivec2(gl_FragCoord.xy), 0)"); + } + } + imgexpr = remap_swizzle(get(result_type), 4, imgexpr); + pure = true; + } + else + { + bool sparse = opcode == OpImageSparseRead; + uint32_t sparse_code_id = 0; + uint32_t sparse_texel_id = 0; + if (sparse) + emit_sparse_feedback_temporaries(ops[0], ops[1], sparse_code_id, sparse_texel_id); + + // imageLoad only accepts int coords, not uint. + auto coord_expr = to_expression(ops[3]); + auto target_coord_type = expression_type(ops[3]); + target_coord_type.basetype = SPIRType::Int; + coord_expr = bitcast_expression(target_coord_type, expression_type(ops[3]).basetype, coord_expr); + + // ES needs to emulate 1D images as 2D. + if (type.image.dim == Dim1D && options.es) + coord_expr = join("ivec2(", coord_expr, ", 0)"); + + // Plain image load/store. + if (sparse) + { + if (type.image.ms) + { + uint32_t operands = ops[4]; + if (operands != ImageOperandsSampleMask || length != 6) + SPIRV_CROSS_THROW("Multisampled image used in OpImageRead, but unexpected " + "operand mask was used."); + + uint32_t samples = ops[5]; + statement(to_expression(sparse_code_id), " = sparseImageLoadARB(", to_non_uniform_aware_expression(ops[2]), ", ", + coord_expr, ", ", to_expression(samples), ", ", to_expression(sparse_texel_id), ");"); + } + else + { + statement(to_expression(sparse_code_id), " = sparseImageLoadARB(", to_non_uniform_aware_expression(ops[2]), ", ", + coord_expr, ", ", to_expression(sparse_texel_id), ");"); + } + imgexpr = join(type_to_glsl(get(result_type)), "(", to_expression(sparse_code_id), ", ", + to_expression(sparse_texel_id), ")"); + } + else + { + if (type.image.ms) + { + uint32_t operands = ops[4]; + if (operands != ImageOperandsSampleMask || length != 6) + SPIRV_CROSS_THROW("Multisampled image used in OpImageRead, but unexpected " + "operand mask was used."); + + uint32_t samples = ops[5]; + imgexpr = + join("imageLoad(", to_non_uniform_aware_expression(ops[2]), ", ", coord_expr, ", ", to_expression(samples), ")"); + } + else + imgexpr = join("imageLoad(", to_non_uniform_aware_expression(ops[2]), ", ", coord_expr, ")"); + } + + if (!sparse) + imgexpr = remap_swizzle(get(result_type), 4, imgexpr); + pure = false; + } + + if (var) + { + bool forward = forced_temporaries.find(id) == end(forced_temporaries); + auto &e = emit_op(result_type, id, imgexpr, forward); + + // We only need to track dependencies if we're reading from image load/store. + if (!pure) + { + e.loaded_from = var->self; + if (forward) + var->dependees.push_back(id); + } + } + else + emit_op(result_type, id, imgexpr, false); + + inherit_expression_dependencies(id, ops[2]); + if (type.image.ms) + inherit_expression_dependencies(id, ops[5]); + break; + } + + case OpImageTexelPointer: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + auto coord_expr = to_expression(ops[3]); + auto target_coord_type = expression_type(ops[3]); + target_coord_type.basetype = SPIRType::Int; + coord_expr = bitcast_expression(target_coord_type, expression_type(ops[3]).basetype, coord_expr); + + auto expr = join(to_expression(ops[2]), ", ", coord_expr); + auto &e = set(id, expr, result_type, true); + + // When using the pointer, we need to know which variable it is actually loaded from. + auto *var = maybe_get_backing_variable(ops[2]); + e.loaded_from = var ? 
var->self : ID(0); + inherit_expression_dependencies(id, ops[3]); + break; + } + + case OpImageWrite: + { + // We added Nonwritable speculatively to the OpImage variable due to glslangValidator + // not adding the proper qualifiers. + // If it turns out we need to write to the image after all, remove the qualifier and recompile. + auto *var = maybe_get_backing_variable(ops[0]); + if (var) + { + if (has_decoration(var->self, DecorationNonWritable)) + { + unset_decoration(var->self, DecorationNonWritable); + force_recompile(); + } + } + + auto &type = expression_type(ops[0]); + auto &value_type = expression_type(ops[2]); + auto store_type = value_type; + store_type.vecsize = 4; + + // imageStore only accepts int coords, not uint. + auto coord_expr = to_expression(ops[1]); + auto target_coord_type = expression_type(ops[1]); + target_coord_type.basetype = SPIRType::Int; + coord_expr = bitcast_expression(target_coord_type, expression_type(ops[1]).basetype, coord_expr); + + // ES needs to emulate 1D images as 2D. + if (type.image.dim == Dim1D && options.es) + coord_expr = join("ivec2(", coord_expr, ", 0)"); + + if (type.image.ms) + { + uint32_t operands = ops[3]; + if (operands != ImageOperandsSampleMask || length != 5) + SPIRV_CROSS_THROW("Multisampled image used in OpImageWrite, but unexpected operand mask was used."); + uint32_t samples = ops[4]; + statement("imageStore(", to_non_uniform_aware_expression(ops[0]), ", ", coord_expr, ", ", to_expression(samples), ", ", + remap_swizzle(store_type, value_type.vecsize, to_expression(ops[2])), ");"); + } + else + statement("imageStore(", to_non_uniform_aware_expression(ops[0]), ", ", coord_expr, ", ", + remap_swizzle(store_type, value_type.vecsize, to_expression(ops[2])), ");"); + + if (var && variable_storage_is_aliased(*var)) + flush_all_aliased_variables(); + break; + } + + case OpImageQuerySize: + { + auto &type = expression_type(ops[2]); + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + if (type.basetype == SPIRType::Image) + { + string expr; + if (type.image.sampled == 2) + { + if (!options.es && options.version < 430) + require_extension_internal("GL_ARB_shader_image_size"); + else if (options.es && options.version < 310) + SPIRV_CROSS_THROW("At least ESSL 3.10 required for imageSize."); + + // The size of an image is always constant. + expr = join("imageSize(", to_non_uniform_aware_expression(ops[2]), ")"); + } + else + { + // This path is hit for samplerBuffers and multisampled images which do not have LOD. 
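+ // e.g. querying a samplerBuffer here emits textureSize(...) with no LOD argument, unlike the OpImageQuerySizeLod path above (illustrative note).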
+ std::string fname = "textureSize"; + if (is_legacy()) + { + auto &imgtype = get(type.self); + fname = legacy_tex_op(fname, imgtype, ops[2]); + } + expr = join(fname, "(", convert_separate_image_to_expression(ops[2]), ")"); + } + + auto &restype = get(ops[0]); + expr = bitcast_expression(restype, SPIRType::Int, expr); + emit_op(result_type, id, expr, true); + } + else + SPIRV_CROSS_THROW("Invalid type for OpImageQuerySize."); + break; + } + + case OpImageSampleWeightedQCOM: + case OpImageBoxFilterQCOM: + case OpImageBlockMatchSSDQCOM: + case OpImageBlockMatchSADQCOM: + { + require_extension_internal("GL_QCOM_image_processing"); + uint32_t result_type_id = ops[0]; + uint32_t id = ops[1]; + string expr; + switch (opcode) + { + case OpImageSampleWeightedQCOM: + expr = "textureWeightedQCOM"; + break; + case OpImageBoxFilterQCOM: + expr = "textureBoxFilterQCOM"; + break; + case OpImageBlockMatchSSDQCOM: + expr = "textureBlockMatchSSDQCOM"; + break; + case OpImageBlockMatchSADQCOM: + expr = "textureBlockMatchSADQCOM"; + break; + default: + SPIRV_CROSS_THROW("Invalid opcode for QCOM_image_processing."); + } + expr += "("; + + bool forward = false; + expr += to_expression(ops[2]); + expr += ", " + to_expression(ops[3]); + + switch (opcode) + { + case OpImageSampleWeightedQCOM: + expr += ", " + to_non_uniform_aware_expression(ops[4]); + break; + case OpImageBoxFilterQCOM: + expr += ", " + to_expression(ops[4]); + break; + case OpImageBlockMatchSSDQCOM: + case OpImageBlockMatchSADQCOM: + expr += ", " + to_non_uniform_aware_expression(ops[4]); + expr += ", " + to_expression(ops[5]); + expr += ", " + to_expression(ops[6]); + break; + default: + SPIRV_CROSS_THROW("Invalid opcode for QCOM_image_processing."); + } + + expr += ")"; + emit_op(result_type_id, id, expr, forward); + + inherit_expression_dependencies(id, ops[3]); + if (opcode == OpImageBlockMatchSSDQCOM || opcode == OpImageBlockMatchSADQCOM) + inherit_expression_dependencies(id, ops[5]); + + break; + } + + // Compute + case OpControlBarrier: + case OpMemoryBarrier: + { + uint32_t execution_scope = 0; + uint32_t memory; + uint32_t semantics; + + if (opcode == OpMemoryBarrier) + { + memory = evaluate_constant_u32(ops[0]); + semantics = evaluate_constant_u32(ops[1]); + } + else + { + execution_scope = evaluate_constant_u32(ops[0]); + memory = evaluate_constant_u32(ops[1]); + semantics = evaluate_constant_u32(ops[2]); + } + + if (execution_scope == ScopeSubgroup || memory == ScopeSubgroup) + { + // OpControlBarrier with ScopeSubgroup is subgroupBarrier() + if (opcode != OpControlBarrier) + { + request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupMemBarrier); + } + else + { + request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupBarrier); + } + } + + if (execution_scope != ScopeSubgroup && get_entry_point().model == ExecutionModelTessellationControl) + { + // Control shaders only have barriers, and it implies memory barriers. + if (opcode == OpControlBarrier) + statement("barrier();"); + break; + } + + // We only care about these flags, acquire/release and friends are not relevant to GLSL. + semantics = mask_relevant_memory_semantics(semantics); + + if (opcode == OpMemoryBarrier) + { + // If we are a memory barrier, and the next instruction is a control barrier, check if that memory barrier + // does what we need, so we avoid redundant barriers. 
+ const Instruction *next = get_next_instruction_in_block(instruction); + if (next && next->op == OpControlBarrier) + { + auto *next_ops = stream(*next); + uint32_t next_memory = evaluate_constant_u32(next_ops[1]); + uint32_t next_semantics = evaluate_constant_u32(next_ops[2]); + next_semantics = mask_relevant_memory_semantics(next_semantics); + + bool memory_scope_covered = false; + if (next_memory == memory) + memory_scope_covered = true; + else if (next_semantics == MemorySemanticsWorkgroupMemoryMask) + { + // If we only care about workgroup memory, either Device or Workgroup scope is fine, + // scope does not have to match. + if ((next_memory == ScopeDevice || next_memory == ScopeWorkgroup) && + (memory == ScopeDevice || memory == ScopeWorkgroup)) + { + memory_scope_covered = true; + } + } + else if (memory == ScopeWorkgroup && next_memory == ScopeDevice) + { + // The control barrier has device scope, but the memory barrier just has workgroup scope. + memory_scope_covered = true; + } + + // If we have the same memory scope, and all memory types are covered, we're good. + if (memory_scope_covered && (semantics & next_semantics) == semantics) + break; + } + } + + // We are synchronizing some memory or syncing execution, + // so we cannot forward any loads beyond the memory barrier. + if (semantics || opcode == OpControlBarrier) + { + assert(current_emitting_block); + flush_control_dependent_expressions(current_emitting_block->self); + flush_all_active_variables(); + } + + if (memory == ScopeWorkgroup) // Only need to consider memory within a group + { + if (semantics == MemorySemanticsWorkgroupMemoryMask) + { + // OpControlBarrier implies a memory barrier for shared memory as well. + bool implies_shared_barrier = opcode == OpControlBarrier && execution_scope == ScopeWorkgroup; + if (!implies_shared_barrier) + statement("memoryBarrierShared();"); + } + else if (semantics != 0) + statement("groupMemoryBarrier();"); + } + else if (memory == ScopeSubgroup) + { + const uint32_t all_barriers = + MemorySemanticsWorkgroupMemoryMask | MemorySemanticsUniformMemoryMask | MemorySemanticsImageMemoryMask; + + if (semantics & (MemorySemanticsCrossWorkgroupMemoryMask | MemorySemanticsSubgroupMemoryMask)) + { + // These are not relevant for GLSL, but assume it means memoryBarrier(). + // memoryBarrier() does everything, so no need to test anything else. + statement("subgroupMemoryBarrier();"); + } + else if ((semantics & all_barriers) == all_barriers) + { + // Short-hand instead of emitting 3 barriers. + statement("subgroupMemoryBarrier();"); + } + else + { + // Pick out individual barriers. + if (semantics & MemorySemanticsWorkgroupMemoryMask) + statement("subgroupMemoryBarrierShared();"); + if (semantics & MemorySemanticsUniformMemoryMask) + statement("subgroupMemoryBarrierBuffer();"); + if (semantics & MemorySemanticsImageMemoryMask) + statement("subgroupMemoryBarrierImage();"); + } + } + else + { + const uint32_t all_barriers = + MemorySemanticsWorkgroupMemoryMask | MemorySemanticsUniformMemoryMask | MemorySemanticsImageMemoryMask; + + if (semantics & (MemorySemanticsCrossWorkgroupMemoryMask | MemorySemanticsSubgroupMemoryMask)) + { + // These are not relevant for GLSL, but assume it means memoryBarrier(). + // memoryBarrier() does everything, so no need to test anything else. + statement("memoryBarrier();"); + } + else if ((semantics & all_barriers) == all_barriers) + { + // Short-hand instead of emitting 4 barriers. + statement("memoryBarrier();"); + } + else + { + // Pick out individual barriers. 
+ if (semantics & MemorySemanticsWorkgroupMemoryMask) + statement("memoryBarrierShared();"); + if (semantics & MemorySemanticsUniformMemoryMask) + statement("memoryBarrierBuffer();"); + if (semantics & MemorySemanticsImageMemoryMask) + statement("memoryBarrierImage();"); + } + } + + if (opcode == OpControlBarrier) + { + if (execution_scope == ScopeSubgroup) + statement("subgroupBarrier();"); + else + statement("barrier();"); + } + break; + } + + case OpExtInst: + { + uint32_t extension_set = ops[2]; + auto ext = get(extension_set).ext; + + if (ext == SPIRExtension::GLSL) + { + emit_glsl_op(ops[0], ops[1], ops[3], &ops[4], length - 4); + } + else if (ext == SPIRExtension::SPV_AMD_shader_ballot) + { + emit_spv_amd_shader_ballot_op(ops[0], ops[1], ops[3], &ops[4], length - 4); + } + else if (ext == SPIRExtension::SPV_AMD_shader_explicit_vertex_parameter) + { + emit_spv_amd_shader_explicit_vertex_parameter_op(ops[0], ops[1], ops[3], &ops[4], length - 4); + } + else if (ext == SPIRExtension::SPV_AMD_shader_trinary_minmax) + { + emit_spv_amd_shader_trinary_minmax_op(ops[0], ops[1], ops[3], &ops[4], length - 4); + } + else if (ext == SPIRExtension::SPV_AMD_gcn_shader) + { + emit_spv_amd_gcn_shader_op(ops[0], ops[1], ops[3], &ops[4], length - 4); + } + else if (ext == SPIRExtension::SPV_debug_info || + ext == SPIRExtension::NonSemanticShaderDebugInfo || + ext == SPIRExtension::NonSemanticGeneric) + { + break; // Ignore SPIR-V debug information extended instructions. + } + else if (ext == SPIRExtension::NonSemanticDebugPrintf) + { + // Operation 1 is printf. + if (ops[3] == 1) + { + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Debug printf is only supported in Vulkan GLSL.\n"); + require_extension_internal("GL_EXT_debug_printf"); + auto &format_string = get(ops[4]).str; + string expr = join("debugPrintfEXT(\"", format_string, "\""); + for (uint32_t i = 5; i < length; i++) + { + expr += ", "; + expr += to_expression(ops[i]); + } + statement(expr, ");"); + } + } + else + { + statement("// unimplemented ext op ", instruction.op); + break; + } + + break; + } + + // Legacy sub-group stuff ... 
+ case OpSubgroupBallotKHR: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + string expr; + expr = join("uvec4(unpackUint2x32(ballotARB(" + to_expression(ops[2]) + ")), 0u, 0u)"); + emit_op(result_type, id, expr, should_forward(ops[2])); + + require_extension_internal("GL_ARB_shader_ballot"); + inherit_expression_dependencies(id, ops[2]); + register_control_dependent_expression(ops[1]); + break; + } + + case OpSubgroupFirstInvocationKHR: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + emit_unary_func_op(result_type, id, ops[2], "readFirstInvocationARB"); + + require_extension_internal("GL_ARB_shader_ballot"); + register_control_dependent_expression(ops[1]); + break; + } + + case OpSubgroupReadInvocationKHR: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + emit_binary_func_op(result_type, id, ops[2], ops[3], "readInvocationARB"); + + require_extension_internal("GL_ARB_shader_ballot"); + register_control_dependent_expression(ops[1]); + break; + } + + case OpSubgroupAllKHR: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + emit_unary_func_op(result_type, id, ops[2], "allInvocationsARB"); + + require_extension_internal("GL_ARB_shader_group_vote"); + register_control_dependent_expression(ops[1]); + break; + } + + case OpSubgroupAnyKHR: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + emit_unary_func_op(result_type, id, ops[2], "anyInvocationARB"); + + require_extension_internal("GL_ARB_shader_group_vote"); + register_control_dependent_expression(ops[1]); + break; + } + + case OpSubgroupAllEqualKHR: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + emit_unary_func_op(result_type, id, ops[2], "allInvocationsEqualARB"); + + require_extension_internal("GL_ARB_shader_group_vote"); + register_control_dependent_expression(ops[1]); + break; + } + + case OpGroupIAddNonUniformAMD: + case OpGroupFAddNonUniformAMD: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + emit_unary_func_op(result_type, id, ops[4], "addInvocationsNonUniformAMD"); + + require_extension_internal("GL_AMD_shader_ballot"); + register_control_dependent_expression(ops[1]); + break; + } + + case OpGroupFMinNonUniformAMD: + case OpGroupUMinNonUniformAMD: + case OpGroupSMinNonUniformAMD: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + emit_unary_func_op(result_type, id, ops[4], "minInvocationsNonUniformAMD"); + + require_extension_internal("GL_AMD_shader_ballot"); + register_control_dependent_expression(ops[1]); + break; + } + + case OpGroupFMaxNonUniformAMD: + case OpGroupUMaxNonUniformAMD: + case OpGroupSMaxNonUniformAMD: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + emit_unary_func_op(result_type, id, ops[4], "maxInvocationsNonUniformAMD"); + + require_extension_internal("GL_AMD_shader_ballot"); + register_control_dependent_expression(ops[1]); + break; + } + + case OpFragmentMaskFetchAMD: + { + auto &type = expression_type(ops[2]); + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + if (type.image.dim == spv::DimSubpassData) + { + emit_unary_func_op(result_type, id, ops[2], "fragmentMaskFetchAMD"); + } + else + { + emit_binary_func_op(result_type, id, ops[2], ops[3], "fragmentMaskFetchAMD"); + } + + require_extension_internal("GL_AMD_shader_fragment_mask"); + break; + } + + case OpFragmentFetchAMD: + { + auto &type = expression_type(ops[2]); + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + if (type.image.dim == spv::DimSubpassData) + { + emit_binary_func_op(result_type, id, ops[2], ops[4], 
"fragmentFetchAMD"); + } + else + { + emit_trinary_func_op(result_type, id, ops[2], ops[3], ops[4], "fragmentFetchAMD"); + } + + require_extension_internal("GL_AMD_shader_fragment_mask"); + break; + } + + // Vulkan 1.1 sub-group stuff ... + case OpGroupNonUniformElect: + case OpGroupNonUniformBroadcast: + case OpGroupNonUniformBroadcastFirst: + case OpGroupNonUniformBallot: + case OpGroupNonUniformInverseBallot: + case OpGroupNonUniformBallotBitExtract: + case OpGroupNonUniformBallotBitCount: + case OpGroupNonUniformBallotFindLSB: + case OpGroupNonUniformBallotFindMSB: + case OpGroupNonUniformShuffle: + case OpGroupNonUniformShuffleXor: + case OpGroupNonUniformShuffleUp: + case OpGroupNonUniformShuffleDown: + case OpGroupNonUniformAll: + case OpGroupNonUniformAny: + case OpGroupNonUniformAllEqual: + case OpGroupNonUniformFAdd: + case OpGroupNonUniformIAdd: + case OpGroupNonUniformFMul: + case OpGroupNonUniformIMul: + case OpGroupNonUniformFMin: + case OpGroupNonUniformFMax: + case OpGroupNonUniformSMin: + case OpGroupNonUniformSMax: + case OpGroupNonUniformUMin: + case OpGroupNonUniformUMax: + case OpGroupNonUniformBitwiseAnd: + case OpGroupNonUniformBitwiseOr: + case OpGroupNonUniformBitwiseXor: + case OpGroupNonUniformLogicalAnd: + case OpGroupNonUniformLogicalOr: + case OpGroupNonUniformLogicalXor: + case OpGroupNonUniformQuadSwap: + case OpGroupNonUniformQuadBroadcast: + emit_subgroup_op(instruction); + break; + + case OpFUnordEqual: + case OpFUnordLessThan: + case OpFUnordGreaterThan: + case OpFUnordLessThanEqual: + case OpFUnordGreaterThanEqual: + { + // GLSL doesn't specify if floating point comparisons are ordered or unordered, + // but glslang always emits ordered floating point compares for GLSL. + // To get unordered compares, we can test the opposite thing and invert the result. + // This way, we force true when there is any NaN present. + uint32_t op0 = ops[2]; + uint32_t op1 = ops[3]; + + string expr; + if (expression_type(op0).vecsize > 1) + { + const char *comp_op = nullptr; + switch (opcode) + { + case OpFUnordEqual: + comp_op = "notEqual"; + break; + + case OpFUnordLessThan: + comp_op = "greaterThanEqual"; + break; + + case OpFUnordLessThanEqual: + comp_op = "greaterThan"; + break; + + case OpFUnordGreaterThan: + comp_op = "lessThanEqual"; + break; + + case OpFUnordGreaterThanEqual: + comp_op = "lessThan"; + break; + + default: + assert(0); + break; + } + + expr = join("not(", comp_op, "(", to_unpacked_expression(op0), ", ", to_unpacked_expression(op1), "))"); + } + else + { + const char *comp_op = nullptr; + switch (opcode) + { + case OpFUnordEqual: + comp_op = " != "; + break; + + case OpFUnordLessThan: + comp_op = " >= "; + break; + + case OpFUnordLessThanEqual: + comp_op = " > "; + break; + + case OpFUnordGreaterThan: + comp_op = " <= "; + break; + + case OpFUnordGreaterThanEqual: + comp_op = " < "; + break; + + default: + assert(0); + break; + } + + expr = join("!(", to_enclosed_unpacked_expression(op0), comp_op, to_enclosed_unpacked_expression(op1), ")"); + } + + emit_op(ops[0], ops[1], expr, should_forward(op0) && should_forward(op1)); + inherit_expression_dependencies(ops[1], op0); + inherit_expression_dependencies(ops[1], op1); + break; + } + + case OpReportIntersectionKHR: + // NV is same opcode. 
+ forced_temporaries.insert(ops[1]); + if (ray_tracing_is_khr) + GLSL_BFOP(reportIntersectionEXT); + else + GLSL_BFOP(reportIntersectionNV); + flush_control_dependent_expressions(current_emitting_block->self); + break; + case OpIgnoreIntersectionNV: + // KHR variant is a terminator. + statement("ignoreIntersectionNV();"); + flush_control_dependent_expressions(current_emitting_block->self); + break; + case OpTerminateRayNV: + // KHR variant is a terminator. + statement("terminateRayNV();"); + flush_control_dependent_expressions(current_emitting_block->self); + break; + case OpTraceNV: + statement("traceNV(", to_non_uniform_aware_expression(ops[0]), ", ", to_expression(ops[1]), ", ", to_expression(ops[2]), ", ", + to_expression(ops[3]), ", ", to_expression(ops[4]), ", ", to_expression(ops[5]), ", ", + to_expression(ops[6]), ", ", to_expression(ops[7]), ", ", to_expression(ops[8]), ", ", + to_expression(ops[9]), ", ", to_expression(ops[10]), ");"); + flush_control_dependent_expressions(current_emitting_block->self); + break; + case OpTraceRayKHR: + if (!has_decoration(ops[10], DecorationLocation)) + SPIRV_CROSS_THROW("A memory declaration object must be used in TraceRayKHR."); + statement("traceRayEXT(", to_non_uniform_aware_expression(ops[0]), ", ", to_expression(ops[1]), ", ", to_expression(ops[2]), ", ", + to_expression(ops[3]), ", ", to_expression(ops[4]), ", ", to_expression(ops[5]), ", ", + to_expression(ops[6]), ", ", to_expression(ops[7]), ", ", to_expression(ops[8]), ", ", + to_expression(ops[9]), ", ", get_decoration(ops[10], DecorationLocation), ");"); + flush_control_dependent_expressions(current_emitting_block->self); + break; + case OpExecuteCallableNV: + statement("executeCallableNV(", to_expression(ops[0]), ", ", to_expression(ops[1]), ");"); + flush_control_dependent_expressions(current_emitting_block->self); + break; + case OpExecuteCallableKHR: + if (!has_decoration(ops[1], DecorationLocation)) + SPIRV_CROSS_THROW("A memory declaration object must be used in ExecuteCallableKHR."); + statement("executeCallableEXT(", to_expression(ops[0]), ", ", get_decoration(ops[1], DecorationLocation), ");"); + flush_control_dependent_expressions(current_emitting_block->self); + break; + + // Don't bother forwarding temporaries. Avoids having to test expression invalidation with ray query objects. 
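+ // e.g. the bool result of OpRayQueryProceedKHR is always materialized into its own temporary rather than inlined at each use site (illustrative note).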
+ case OpRayQueryInitializeKHR: + flush_variable_declaration(ops[0]); + statement("rayQueryInitializeEXT(", + to_expression(ops[0]), ", ", to_expression(ops[1]), ", ", + to_expression(ops[2]), ", ", to_expression(ops[3]), ", ", + to_expression(ops[4]), ", ", to_expression(ops[5]), ", ", + to_expression(ops[6]), ", ", to_expression(ops[7]), ");"); + break; + case OpRayQueryProceedKHR: + flush_variable_declaration(ops[0]); + emit_op(ops[0], ops[1], join("rayQueryProceedEXT(", to_expression(ops[2]), ")"), false); + break; + case OpRayQueryTerminateKHR: + flush_variable_declaration(ops[0]); + statement("rayQueryTerminateEXT(", to_expression(ops[0]), ");"); + break; + case OpRayQueryGenerateIntersectionKHR: + flush_variable_declaration(ops[0]); + statement("rayQueryGenerateIntersectionEXT(", to_expression(ops[0]), ", ", to_expression(ops[1]), ");"); + break; + case OpRayQueryConfirmIntersectionKHR: + flush_variable_declaration(ops[0]); + statement("rayQueryConfirmIntersectionEXT(", to_expression(ops[0]), ");"); + break; +#define GLSL_RAY_QUERY_GET_OP(op) \ + case OpRayQueryGet##op##KHR: \ + flush_variable_declaration(ops[2]); \ + emit_op(ops[0], ops[1], join("rayQueryGet" #op "EXT(", to_expression(ops[2]), ")"), false); \ + break +#define GLSL_RAY_QUERY_GET_OP2(op) \ + case OpRayQueryGet##op##KHR: \ + flush_variable_declaration(ops[2]); \ + emit_op(ops[0], ops[1], join("rayQueryGet" #op "EXT(", to_expression(ops[2]), ", ", "bool(", to_expression(ops[3]), "))"), false); \ + break + GLSL_RAY_QUERY_GET_OP(RayTMin); + GLSL_RAY_QUERY_GET_OP(RayFlags); + GLSL_RAY_QUERY_GET_OP(WorldRayOrigin); + GLSL_RAY_QUERY_GET_OP(WorldRayDirection); + GLSL_RAY_QUERY_GET_OP(IntersectionCandidateAABBOpaque); + GLSL_RAY_QUERY_GET_OP2(IntersectionType); + GLSL_RAY_QUERY_GET_OP2(IntersectionT); + GLSL_RAY_QUERY_GET_OP2(IntersectionInstanceCustomIndex); + GLSL_RAY_QUERY_GET_OP2(IntersectionInstanceId); + GLSL_RAY_QUERY_GET_OP2(IntersectionInstanceShaderBindingTableRecordOffset); + GLSL_RAY_QUERY_GET_OP2(IntersectionGeometryIndex); + GLSL_RAY_QUERY_GET_OP2(IntersectionPrimitiveIndex); + GLSL_RAY_QUERY_GET_OP2(IntersectionBarycentrics); + GLSL_RAY_QUERY_GET_OP2(IntersectionFrontFace); + GLSL_RAY_QUERY_GET_OP2(IntersectionObjectRayDirection); + GLSL_RAY_QUERY_GET_OP2(IntersectionObjectRayOrigin); + GLSL_RAY_QUERY_GET_OP2(IntersectionObjectToWorld); + GLSL_RAY_QUERY_GET_OP2(IntersectionWorldToObject); +#undef GLSL_RAY_QUERY_GET_OP +#undef GLSL_RAY_QUERY_GET_OP2 + + case OpConvertUToAccelerationStructureKHR: + { + require_extension_internal("GL_EXT_ray_tracing"); + + bool elide_temporary = should_forward(ops[2]) && forced_temporaries.count(ops[1]) == 0 && + !hoisted_temporaries.count(ops[1]); + + if (elide_temporary) + { + GLSL_UFOP(accelerationStructureEXT); + } + else + { + // Force this path in subsequent iterations. + forced_temporaries.insert(ops[1]); + + // We cannot declare a temporary acceleration structure in GLSL. + // If we get to this point, we'll have to emit a temporary uvec2, + // and cast to RTAS on demand. + statement(declare_temporary(expression_type_id(ops[2]), ops[1]), to_unpacked_expression(ops[2]), ";"); + // Use raw SPIRExpression interface to block all usage tracking. 
+ set(ops[1], join("accelerationStructureEXT(", to_name(ops[1]), ")"), ops[0], true); + } + break; + } + + case OpConvertUToPtr: + { + auto &type = get(ops[0]); + if (type.storage != StorageClassPhysicalStorageBufferEXT) + SPIRV_CROSS_THROW("Only StorageClassPhysicalStorageBufferEXT is supported by OpConvertUToPtr."); + + auto &in_type = expression_type(ops[2]); + if (in_type.vecsize == 2) + require_extension_internal("GL_EXT_buffer_reference_uvec2"); + + auto op = type_to_glsl(type); + emit_unary_func_op(ops[0], ops[1], ops[2], op.c_str()); + break; + } + + case OpConvertPtrToU: + { + auto &type = get(ops[0]); + auto &ptr_type = expression_type(ops[2]); + if (ptr_type.storage != StorageClassPhysicalStorageBufferEXT) + SPIRV_CROSS_THROW("Only StorageClassPhysicalStorageBufferEXT is supported by OpConvertPtrToU."); + + if (type.vecsize == 2) + require_extension_internal("GL_EXT_buffer_reference_uvec2"); + + auto op = type_to_glsl(type); + emit_unary_func_op(ops[0], ops[1], ops[2], op.c_str()); + break; + } + + case OpUndef: + // Undefined value has been declared. + break; + + case OpLine: + { + emit_line_directive(ops[0], ops[1]); + break; + } + + case OpNoLine: + break; + + case OpDemoteToHelperInvocationEXT: + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("GL_EXT_demote_to_helper_invocation is only supported in Vulkan GLSL."); + require_extension_internal("GL_EXT_demote_to_helper_invocation"); + statement(backend.demote_literal, ";"); + break; + + case OpIsHelperInvocationEXT: + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("GL_EXT_demote_to_helper_invocation is only supported in Vulkan GLSL."); + require_extension_internal("GL_EXT_demote_to_helper_invocation"); + // Helper lane state with demote is volatile by nature. + // Do not forward this. + emit_op(ops[0], ops[1], "helperInvocationEXT()", false); + break; + + case OpBeginInvocationInterlockEXT: + // If the interlock is complex, we emit this elsewhere. + if (!interlocked_is_complex) + { + statement("SPIRV_Cross_beginInvocationInterlock();"); + flush_all_active_variables(); + // Make sure forwarding doesn't propagate outside interlock region. + } + break; + + case OpEndInvocationInterlockEXT: + // If the interlock is complex, we emit this elsewhere. + if (!interlocked_is_complex) + { + statement("SPIRV_Cross_endInvocationInterlock();"); + flush_all_active_variables(); + // Make sure forwarding doesn't propagate outside interlock region. + } + break; + + case OpSetMeshOutputsEXT: + statement("SetMeshOutputsEXT(", to_unpacked_expression(ops[0]), ", ", to_unpacked_expression(ops[1]), ");"); + break; + + case OpReadClockKHR: + { + auto &type = get(ops[0]); + auto scope = static_cast(evaluate_constant_u32(ops[2])); + const char *op = nullptr; + // Forwarding clock statements leads to a scenario where an SSA value can take on different + // values every time it's evaluated. Block any forwarding attempt. + // We also might want to invalidate all expressions to function as a sort of optimization + // barrier, but might be overkill for now. 
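+ // e.g. a device-scope read returning uint64_t maps to clockRealtimeEXT() and the uvec2 form to clockRealtime2x32EXT(); either way the value is pinned to a temporary (illustrative note).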
+ if (scope == ScopeDevice) + { + require_extension_internal("GL_EXT_shader_realtime_clock"); + if (type.basetype == SPIRType::BaseType::UInt64) + op = "clockRealtimeEXT()"; + else if (type.basetype == SPIRType::BaseType::UInt && type.vecsize == 2) + op = "clockRealtime2x32EXT()"; + else + SPIRV_CROSS_THROW("Unsupported result type for OpReadClockKHR opcode."); + } + else if (scope == ScopeSubgroup) + { + require_extension_internal("GL_ARB_shader_clock"); + if (type.basetype == SPIRType::BaseType::UInt64) + op = "clockARB()"; + else if (type.basetype == SPIRType::BaseType::UInt && type.vecsize == 2) + op = "clock2x32ARB()"; + else + SPIRV_CROSS_THROW("Unsupported result type for OpReadClockKHR opcode."); + } + else + SPIRV_CROSS_THROW("Unsupported scope for OpReadClockKHR opcode."); + + emit_op(ops[0], ops[1], op, false); + break; + } + + default: + statement("// unimplemented op ", instruction.op); + break; + } +} + +// Appends function arguments, mapped from global variables, beyond the specified arg index. +// This is used when a function call uses fewer arguments than the function defines. +// This situation may occur if the function signature has been dynamically modified to +// extract global variables referenced from within the function, and convert them to +// function arguments. This is necessary for shader languages that do not support global +// access to shader input content from within a function (eg. Metal). Each additional +// function args uses the name of the global variable. Function nesting will modify the +// functions and function calls all the way up the nesting chain. +void CompilerGLSL::append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector &arglist) +{ + auto &args = func.arguments; + uint32_t arg_cnt = uint32_t(args.size()); + for (uint32_t arg_idx = index; arg_idx < arg_cnt; arg_idx++) + { + auto &arg = args[arg_idx]; + assert(arg.alias_global_variable); + + // If the underlying variable needs to be declared + // (ie. a local variable with deferred declaration), do so now. 
+ uint32_t var_id = get(arg.id).basevariable; + if (var_id) + flush_variable_declaration(var_id); + + arglist.push_back(to_func_call_arg(arg, arg.id)); + } +} + +string CompilerGLSL::to_member_name(const SPIRType &type, uint32_t index) +{ + if (type.type_alias != TypeID(0) && + !has_extended_decoration(type.type_alias, SPIRVCrossDecorationBufferBlockRepacked)) + { + return to_member_name(get(type.type_alias), index); + } + + auto &memb = ir.meta[type.self].members; + if (index < memb.size() && !memb[index].alias.empty()) + return memb[index].alias; + else + return join("_m", index); +} + +string CompilerGLSL::to_member_reference(uint32_t, const SPIRType &type, uint32_t index, bool) +{ + return join(".", to_member_name(type, index)); +} + +string CompilerGLSL::to_multi_member_reference(const SPIRType &type, const SmallVector &indices) +{ + string ret; + auto *member_type = &type; + for (auto &index : indices) + { + ret += join(".", to_member_name(*member_type, index)); + member_type = &get(member_type->member_types[index]); + } + return ret; +} + +void CompilerGLSL::add_member_name(SPIRType &type, uint32_t index) +{ + auto &memb = ir.meta[type.self].members; + if (index < memb.size() && !memb[index].alias.empty()) + { + auto &name = memb[index].alias; + if (name.empty()) + return; + + ParsedIR::sanitize_identifier(name, true, true); + update_name_cache(type.member_name_cache, name); + } +} + +// Checks whether the ID is a row_major matrix that requires conversion before use +bool CompilerGLSL::is_non_native_row_major_matrix(uint32_t id) +{ + // Natively supported row-major matrices do not need to be converted. + // Legacy targets do not support row major. + if (backend.native_row_major_matrix && !is_legacy()) + return false; + + auto *e = maybe_get(id); + if (e) + return e->need_transpose; + else + return has_decoration(id, DecorationRowMajor); +} + +// Checks whether the member is a row_major matrix that requires conversion before use +bool CompilerGLSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index) +{ + // Natively supported row-major matrices do not need to be converted. + if (backend.native_row_major_matrix && !is_legacy()) + return false; + + // Non-matrix or column-major matrix types do not need to be converted. + if (!has_member_decoration(type.self, index, DecorationRowMajor)) + return false; + + // Only square row-major matrices can be converted at this time. + // Converting non-square matrices will require defining custom GLSL function that + // swaps matrix elements while retaining the original dimensional form of the matrix. + const auto mbr_type = get(type.member_types[index]); + if (mbr_type.columns != mbr_type.vecsize) + SPIRV_CROSS_THROW("Row-major matrices must be square on this platform."); + + return true; +} + +// Checks if we need to remap physical type IDs when declaring the type in a buffer. +bool CompilerGLSL::member_is_remapped_physical_type(const SPIRType &type, uint32_t index) const +{ + return has_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID); +} + +// Checks whether the member is in packed data type, that might need to be unpacked. +bool CompilerGLSL::member_is_packed_physical_type(const SPIRType &type, uint32_t index) const +{ + return has_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypePacked); +} + +// Wraps the expression string in a function call that converts the +// row_major matrix result of the expression to a column_major matrix. 
+// Base implementation uses the standard library transpose() function. +// Subclasses may override to use a different function. +string CompilerGLSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type, uint32_t /* physical_type_id */, + bool /*is_packed*/, bool relaxed) +{ + strip_enclosed_expression(exp_str); + if (!is_matrix(exp_type)) + { + auto column_index = exp_str.find_last_of('['); + if (column_index == string::npos) + return exp_str; + + auto column_expr = exp_str.substr(column_index); + exp_str.resize(column_index); + + auto end_deferred_index = column_expr.find_last_of(']'); + if (end_deferred_index != string::npos && end_deferred_index + 1 != column_expr.size()) + { + // If we have any data member fixups, it must be transposed so that it refers to this index. + // E.g. [0].data followed by [1] would be shuffled to [1][0].data which is wrong, + // and needs to be [1].data[0] instead. + end_deferred_index++; + column_expr = column_expr.substr(end_deferred_index) + + column_expr.substr(0, end_deferred_index); + } + + auto transposed_expr = type_to_glsl_constructor(exp_type) + "("; + + // Loading a column from a row-major matrix. Unroll the load. + for (uint32_t c = 0; c < exp_type.vecsize; c++) + { + transposed_expr += join(exp_str, '[', c, ']', column_expr); + if (c + 1 < exp_type.vecsize) + transposed_expr += ", "; + } + + transposed_expr += ")"; + return transposed_expr; + } + else if (options.version < 120) + { + // GLSL 110, ES 100 do not have transpose(), so emulate it. Note that + // these GLSL versions do not support non-square matrices. + if (exp_type.vecsize == 2 && exp_type.columns == 2) + require_polyfill(PolyfillTranspose2x2, relaxed); + else if (exp_type.vecsize == 3 && exp_type.columns == 3) + require_polyfill(PolyfillTranspose3x3, relaxed); + else if (exp_type.vecsize == 4 && exp_type.columns == 4) + require_polyfill(PolyfillTranspose4x4, relaxed); + else + SPIRV_CROSS_THROW("Non-square matrices are not supported in legacy GLSL, cannot transpose."); + return join("spvTranspose", (options.es && relaxed) ? "MP" : "", "(", exp_str, ")"); + } + else + return join("transpose(", exp_str, ")"); +} + +string CompilerGLSL::variable_decl(const SPIRType &type, const string &name, uint32_t id) +{ + string type_name = type_to_glsl(type, id); + remap_variable_type_name(type, name, type_name); + return join(type_name, " ", name, type_to_array_glsl(type, id)); +} + +bool CompilerGLSL::variable_decl_is_remapped_storage(const SPIRVariable &var, StorageClass storage) const +{ + return var.storage == storage; +} + +// Emit a structure member. Subclasses may override to modify output, +// or to dynamically add a padding member if needed. 
+void CompilerGLSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index, + const string &qualifier, uint32_t) +{ + auto &membertype = get(member_type_id); + + Bitset memberflags; + auto &memb = ir.meta[type.self].members; + if (index < memb.size()) + memberflags = memb[index].decoration_flags; + + string qualifiers; + bool is_block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock) || + ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock); + + if (is_block) + qualifiers = to_interpolation_qualifiers(memberflags); + + statement(layout_for_member(type, index), qualifiers, qualifier, flags_to_qualifiers_glsl(membertype, memberflags), + variable_decl(membertype, to_member_name(type, index)), ";"); +} + +void CompilerGLSL::emit_struct_padding_target(const SPIRType &) +{ +} + +string CompilerGLSL::flags_to_qualifiers_glsl(const SPIRType &type, const Bitset &flags) +{ + // GL_EXT_buffer_reference variables can be marked as restrict. + if (flags.get(DecorationRestrictPointerEXT)) + return "restrict "; + + string qual; + + if (type_is_floating_point(type) && flags.get(DecorationNoContraction) && backend.support_precise_qualifier) + qual = "precise "; + + // Structs do not have precision qualifiers, neither do doubles (desktop only anyways, so no mediump/highp). + bool type_supports_precision = + type.basetype == SPIRType::Float || type.basetype == SPIRType::Int || type.basetype == SPIRType::UInt || + type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage || + type.basetype == SPIRType::Sampler; + + if (!type_supports_precision) + return qual; + + if (options.es) + { + auto &execution = get_entry_point(); + + if (type.basetype == SPIRType::UInt && is_legacy_es()) + { + // HACK: This is a bool. See comment in type_to_glsl(). + qual += "lowp "; + } + else if (flags.get(DecorationRelaxedPrecision)) + { + bool implied_fmediump = type.basetype == SPIRType::Float && + options.fragment.default_float_precision == Options::Mediump && + execution.model == ExecutionModelFragment; + + bool implied_imediump = (type.basetype == SPIRType::Int || type.basetype == SPIRType::UInt) && + options.fragment.default_int_precision == Options::Mediump && + execution.model == ExecutionModelFragment; + + qual += (implied_fmediump || implied_imediump) ? "" : "mediump "; + } + else + { + bool implied_fhighp = + type.basetype == SPIRType::Float && ((options.fragment.default_float_precision == Options::Highp && + execution.model == ExecutionModelFragment) || + (execution.model != ExecutionModelFragment)); + + bool implied_ihighp = (type.basetype == SPIRType::Int || type.basetype == SPIRType::UInt) && + ((options.fragment.default_int_precision == Options::Highp && + execution.model == ExecutionModelFragment) || + (execution.model != ExecutionModelFragment)); + + qual += (implied_fhighp || implied_ihighp) ? "" : "highp "; + } + } + else if (backend.allow_precision_qualifiers) + { + // Vulkan GLSL supports precision qualifiers, even in desktop profiles, which is convenient. + // The default is highp however, so only emit mediump in the rare case that a shader has these. 
+ if (flags.get(DecorationRelaxedPrecision)) + qual += "mediump "; + } + + return qual; +} + +string CompilerGLSL::to_precision_qualifiers_glsl(uint32_t id) +{ + auto &type = expression_type(id); + bool use_precision_qualifiers = backend.allow_precision_qualifiers; + if (use_precision_qualifiers && (type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage)) + { + // Force mediump for the sampler type. We cannot declare 16-bit or smaller image types. + auto &result_type = get(type.image.type); + if (result_type.width < 32) + return "mediump "; + } + return flags_to_qualifiers_glsl(type, ir.meta[id].decoration.decoration_flags); +} + +void CompilerGLSL::fixup_io_block_patch_primitive_qualifiers(const SPIRVariable &var) +{ + // Works around weird behavior in glslangValidator where + // a patch out block is translated to just block members getting the decoration. + // To make glslang not complain when we compile again, we have to transform this back to a case where + // the variable itself has Patch decoration, and not members. + // Same for perprimitiveEXT. + auto &type = get(var.basetype); + if (has_decoration(type.self, DecorationBlock)) + { + uint32_t member_count = uint32_t(type.member_types.size()); + Decoration promoted_decoration = {}; + bool do_promote_decoration = false; + for (uint32_t i = 0; i < member_count; i++) + { + if (has_member_decoration(type.self, i, DecorationPatch)) + { + promoted_decoration = DecorationPatch; + do_promote_decoration = true; + break; + } + else if (has_member_decoration(type.self, i, DecorationPerPrimitiveEXT)) + { + promoted_decoration = DecorationPerPrimitiveEXT; + do_promote_decoration = true; + break; + } + } + + if (do_promote_decoration) + { + set_decoration(var.self, promoted_decoration); + for (uint32_t i = 0; i < member_count; i++) + unset_member_decoration(type.self, i, promoted_decoration); + } + } +} + +string CompilerGLSL::to_qualifiers_glsl(uint32_t id) +{ + auto &flags = get_decoration_bitset(id); + string res; + + auto *var = maybe_get(id); + + if (var && var->storage == StorageClassWorkgroup && !backend.shared_is_implied) + res += "shared "; + else if (var && var->storage == StorageClassTaskPayloadWorkgroupEXT && !backend.shared_is_implied) + res += "taskPayloadSharedEXT "; + + res += to_interpolation_qualifiers(flags); + if (var) + res += to_storage_qualifiers_glsl(*var); + + auto &type = expression_type(id); + if (type.image.dim != DimSubpassData && type.image.sampled == 2) + { + if (flags.get(DecorationCoherent)) + res += "coherent "; + if (flags.get(DecorationRestrict)) + res += "restrict "; + + if (flags.get(DecorationNonWritable)) + res += "readonly "; + + bool formatted_load = type.image.format == ImageFormatUnknown; + if (flags.get(DecorationNonReadable)) + { + res += "writeonly "; + formatted_load = false; + } + + if (formatted_load) + { + if (!options.es) + require_extension_internal("GL_EXT_shader_image_load_formatted"); + else + SPIRV_CROSS_THROW("Cannot use GL_EXT_shader_image_load_formatted in ESSL."); + } + } + + res += to_precision_qualifiers_glsl(id); + + return res; +} + +string CompilerGLSL::argument_decl(const SPIRFunction::Parameter &arg) +{ + // glslangValidator seems to make all arguments pointer no matter what which is rather bizarre ... 
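+ // e.g. a pointer argument that is both read and written is declared inout, a write-only one becomes out, and a read-only one gets no direction qualifier (illustrative note).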
+ auto &type = expression_type(arg.id); + const char *direction = ""; + + if (type.pointer) + { + if (arg.write_count && arg.read_count) + direction = "inout "; + else if (arg.write_count) + direction = "out "; + } + + return join(direction, to_qualifiers_glsl(arg.id), variable_decl(type, to_name(arg.id), arg.id)); +} + +string CompilerGLSL::to_initializer_expression(const SPIRVariable &var) +{ + return to_unpacked_expression(var.initializer); +} + +string CompilerGLSL::to_zero_initialized_expression(uint32_t type_id) +{ +#ifndef NDEBUG + auto &type = get(type_id); + assert(type.storage == StorageClassPrivate || type.storage == StorageClassFunction || + type.storage == StorageClassGeneric); +#endif + uint32_t id = ir.increase_bound_by(1); + ir.make_constant_null(id, type_id, false); + return constant_expression(get(id)); +} + +bool CompilerGLSL::type_can_zero_initialize(const SPIRType &type) const +{ + if (type.pointer) + return false; + + if (!type.array.empty() && options.flatten_multidimensional_arrays) + return false; + + for (auto &literal : type.array_size_literal) + if (!literal) + return false; + + for (auto &memb : type.member_types) + if (!type_can_zero_initialize(get(memb))) + return false; + + return true; +} + +string CompilerGLSL::variable_decl(const SPIRVariable &variable) +{ + // Ignore the pointer type since GLSL doesn't have pointers. + auto &type = get_variable_data_type(variable); + + if (type.pointer_depth > 1 && !backend.support_pointer_to_pointer) + SPIRV_CROSS_THROW("Cannot declare pointer-to-pointer types."); + + auto res = join(to_qualifiers_glsl(variable.self), variable_decl(type, to_name(variable.self), variable.self)); + + if (variable.loop_variable && variable.static_expression) + { + uint32_t expr = variable.static_expression; + if (ir.ids[expr].get_type() != TypeUndef) + res += join(" = ", to_unpacked_expression(variable.static_expression)); + else if (options.force_zero_initialized_variables && type_can_zero_initialize(type)) + res += join(" = ", to_zero_initialized_expression(get_variable_data_type_id(variable))); + } + else if (variable.initializer && !variable_decl_is_remapped_storage(variable, StorageClassWorkgroup)) + { + uint32_t expr = variable.initializer; + if (ir.ids[expr].get_type() != TypeUndef) + res += join(" = ", to_initializer_expression(variable)); + else if (options.force_zero_initialized_variables && type_can_zero_initialize(type)) + res += join(" = ", to_zero_initialized_expression(get_variable_data_type_id(variable))); + } + + return res; +} + +const char *CompilerGLSL::to_pls_qualifiers_glsl(const SPIRVariable &variable) +{ + auto &flags = get_decoration_bitset(variable.self); + if (flags.get(DecorationRelaxedPrecision)) + return "mediump "; + else + return "highp "; +} + +string CompilerGLSL::pls_decl(const PlsRemap &var) +{ + auto &variable = get(var.id); + + auto op_and_basetype = pls_format_to_basetype(var.format); + + SPIRType type { op_and_basetype.first }; + type.basetype = op_and_basetype.second; + auto vecsize = pls_format_to_components(var.format); + if (vecsize > 1) + { + type.op = OpTypeVector; + type.vecsize = vecsize; + } + + return join(to_pls_layout(var.format), to_pls_qualifiers_glsl(variable), type_to_glsl(type), " ", + to_name(variable.self)); +} + +uint32_t CompilerGLSL::to_array_size_literal(const SPIRType &type) const +{ + return to_array_size_literal(type, uint32_t(type.array.size() - 1)); +} + +uint32_t CompilerGLSL::to_array_size_literal(const SPIRType &type, uint32_t index) const +{ + assert(type.array.size() 
== type.array_size_literal.size()); + + if (type.array_size_literal[index]) + { + return type.array[index]; + } + else + { + // Use the default spec constant value. + // This is the best we can do. + return evaluate_constant_u32(type.array[index]); + } +} + +string CompilerGLSL::to_array_size(const SPIRType &type, uint32_t index) +{ + assert(type.array.size() == type.array_size_literal.size()); + + auto &size = type.array[index]; + if (!type.array_size_literal[index]) + return to_expression(size); + else if (size) + return convert_to_string(size); + else if (!backend.unsized_array_supported) + { + // For runtime-sized arrays, we can work around + // lack of standard support for this by simply having + // a single element array. + // + // Runtime length arrays must always be the last element + // in an interface block. + return "1"; + } + else + return ""; +} + +string CompilerGLSL::type_to_array_glsl(const SPIRType &type, uint32_t) +{ + if (type.pointer && type.storage == StorageClassPhysicalStorageBufferEXT && type.basetype != SPIRType::Struct) + { + // We are using a wrapped pointer type, and we should not emit any array declarations here. + return ""; + } + + if (type.array.empty()) + return ""; + + if (options.flatten_multidimensional_arrays) + { + string res; + res += "["; + for (auto i = uint32_t(type.array.size()); i; i--) + { + res += enclose_expression(to_array_size(type, i - 1)); + if (i > 1) + res += " * "; + } + res += "]"; + return res; + } + else + { + if (type.array.size() > 1) + { + if (!options.es && options.version < 430) + require_extension_internal("GL_ARB_arrays_of_arrays"); + else if (options.es && options.version < 310) + SPIRV_CROSS_THROW("Arrays of arrays not supported before ESSL version 310. " + "Try using --flatten-multidimensional-arrays or set " + "options.flatten_multidimensional_arrays to true."); + } + + string res; + for (auto i = uint32_t(type.array.size()); i; i--) + { + res += "["; + res += to_array_size(type, i - 1); + res += "]"; + } + return res; + } +} + +string CompilerGLSL::image_type_glsl(const SPIRType &type, uint32_t id, bool /*member*/) +{ + auto &imagetype = get(type.image.type); + string res; + + switch (imagetype.basetype) + { + case SPIRType::Int64: + res = "i64"; + require_extension_internal("GL_EXT_shader_image_int64"); + break; + case SPIRType::UInt64: + res = "u64"; + require_extension_internal("GL_EXT_shader_image_int64"); + break; + case SPIRType::Int: + case SPIRType::Short: + case SPIRType::SByte: + res = "i"; + break; + case SPIRType::UInt: + case SPIRType::UShort: + case SPIRType::UByte: + res = "u"; + break; + default: + break; + } + + // For half image types, we will force mediump for the sampler, and cast to f16 after any sampling operation. + // We cannot express a true half texture type in GLSL. Neither for short integer formats for that matter. + + if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData && options.vulkan_semantics) + return res + "subpassInput" + (type.image.ms ? "MS" : ""); + else if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData && + subpass_input_is_framebuffer_fetch(id)) + { + SPIRType sampled_type = get(type.image.type); + sampled_type.vecsize = 4; + return type_to_glsl(sampled_type); + } + + // If we're emulating subpassInput with samplers, force sampler2D + // so we don't have to specify format. 
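+ // e.g. with GL semantics a subpassInput falls through to the sampler path below and ends up declared as a sampler2D, with the usual i/u prefix for integer formats (illustrative note).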
+ if (type.basetype == SPIRType::Image && type.image.dim != DimSubpassData) + { + // Sampler buffers are always declared as samplerBuffer even though they might be separate images in the SPIR-V. + if (type.image.dim == DimBuffer && type.image.sampled == 1) + res += "sampler"; + else + res += type.image.sampled == 2 ? "image" : "texture"; + } + else + res += "sampler"; + + switch (type.image.dim) + { + case Dim1D: + // ES doesn't support 1D. Fake it with 2D. + res += options.es ? "2D" : "1D"; + break; + case Dim2D: + res += "2D"; + break; + case Dim3D: + res += "3D"; + break; + case DimCube: + res += "Cube"; + break; + case DimRect: + if (options.es) + SPIRV_CROSS_THROW("Rectangle textures are not supported on OpenGL ES."); + + if (is_legacy_desktop()) + require_extension_internal("GL_ARB_texture_rectangle"); + + res += "2DRect"; + break; + + case DimBuffer: + if (options.es && options.version < 320) + require_extension_internal("GL_EXT_texture_buffer"); + else if (!options.es && options.version < 300) + require_extension_internal("GL_EXT_texture_buffer_object"); + res += "Buffer"; + break; + + case DimSubpassData: + res += "2D"; + break; + default: + SPIRV_CROSS_THROW("Only 1D, 2D, 2DRect, 3D, Buffer, InputTarget and Cube textures supported."); + } + + if (type.image.ms) + res += "MS"; + if (type.image.arrayed) + { + if (is_legacy_desktop()) + require_extension_internal("GL_EXT_texture_array"); + res += "Array"; + } + + // "Shadow" state in GLSL only exists for samplers and combined image samplers. + if (((type.basetype == SPIRType::SampledImage) || (type.basetype == SPIRType::Sampler)) && + is_depth_image(type, id)) + { + res += "Shadow"; + + if (type.image.dim == DimCube && is_legacy()) + { + if (!options.es) + require_extension_internal("GL_EXT_gpu_shader4"); + else + { + require_extension_internal("GL_NV_shadow_samplers_cube"); + res += "NV"; + } + } + } + + return res; +} + +string CompilerGLSL::type_to_glsl_constructor(const SPIRType &type) +{ + if (backend.use_array_constructor && type.array.size() > 1) + { + if (options.flatten_multidimensional_arrays) + SPIRV_CROSS_THROW("Cannot flatten constructors of multidimensional array constructors, " + "e.g. float[][]()."); + else if (!options.es && options.version < 430) + require_extension_internal("GL_ARB_arrays_of_arrays"); + else if (options.es && options.version < 310) + SPIRV_CROSS_THROW("Arrays of arrays not supported before ESSL version 310."); + } + + auto e = type_to_glsl(type); + if (backend.use_array_constructor) + { + for (uint32_t i = 0; i < type.array.size(); i++) + e += "[]"; + } + return e; +} + +// The optional id parameter indicates the object whose type we are trying +// to find the description for. It is optional. Most type descriptions do not +// depend on a specific object's use of that type. +string CompilerGLSL::type_to_glsl(const SPIRType &type, uint32_t id) +{ + if (is_physical_pointer(type) && !is_physical_pointer_to_buffer_block(type)) + { + // Need to create a magic type name which compacts the entire type information. + auto *parent = &get_pointee_type(type); + string name = type_to_glsl(*parent); + + uint32_t array_stride = get_decoration(type.parent_type, DecorationArrayStride); + + // Resolve all array dimensions in one go since once we lose the pointer type, + // array information is left to to_array_type_glsl. The base type loses array information. 
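+ // As a rough example, a physical pointer to float[4] with ArrayStride 16 would be given a
+ // name along the lines of "float4_stride_16Pointer" by the loop below.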
+ while (is_array(*parent)) + { + if (parent->array_size_literal.back()) + name += join(type.array.back(), "_"); + else + name += join("id", type.array.back(), "_"); + + name += "stride_" + std::to_string(array_stride); + + array_stride = get_decoration(parent->parent_type, DecorationArrayStride); + parent = &get(parent->parent_type); + } + + name += "Pointer"; + return name; + } + + switch (type.basetype) + { + case SPIRType::Struct: + // Need OpName lookup here to get a "sensible" name for a struct. + if (backend.explicit_struct_type) + return join("struct ", to_name(type.self)); + else + return to_name(type.self); + + case SPIRType::Image: + case SPIRType::SampledImage: + return image_type_glsl(type, id); + + case SPIRType::Sampler: + // The depth field is set by calling code based on the variable ID of the sampler, effectively reintroducing + // this distinction into the type system. + return comparison_ids.count(id) ? "samplerShadow" : "sampler"; + + case SPIRType::AccelerationStructure: + return ray_tracing_is_khr ? "accelerationStructureEXT" : "accelerationStructureNV"; + + case SPIRType::RayQuery: + return "rayQueryEXT"; + + case SPIRType::Void: + return "void"; + + default: + break; + } + + if (type.basetype == SPIRType::UInt && is_legacy()) + { + if (options.es) + // HACK: spirv-cross changes bools into uints and generates code which compares them to + // zero. Input code will have already been validated as not to have contained any uints, + // so any remaining uints must in fact be bools. However, simply returning "bool" here + // will result in invalid code. Instead, return an int. + return backend.basic_int_type; + else + require_extension_internal("GL_EXT_gpu_shader4"); + } + + if (type.basetype == SPIRType::AtomicCounter) + { + if (options.es && options.version < 310) + SPIRV_CROSS_THROW("At least ESSL 3.10 required for atomic counters."); + else if (!options.es && options.version < 420) + require_extension_internal("GL_ARB_shader_atomic_counters"); + } + + if (type.vecsize == 1 && type.columns == 1) // Scalar builtin + { + switch (type.basetype) + { + case SPIRType::Boolean: + return "bool"; + case SPIRType::SByte: + return backend.basic_int8_type; + case SPIRType::UByte: + return backend.basic_uint8_type; + case SPIRType::Short: + return backend.basic_int16_type; + case SPIRType::UShort: + return backend.basic_uint16_type; + case SPIRType::Int: + return backend.basic_int_type; + case SPIRType::UInt: + return backend.basic_uint_type; + case SPIRType::AtomicCounter: + return "atomic_uint"; + case SPIRType::Half: + return "float16_t"; + case SPIRType::Float: + return "float"; + case SPIRType::Double: + return "double"; + case SPIRType::Int64: + return "int64_t"; + case SPIRType::UInt64: + return "uint64_t"; + default: + return "???"; + } + } + else if (type.vecsize > 1 && type.columns == 1) // Vector builtin + { + switch (type.basetype) + { + case SPIRType::Boolean: + return join("bvec", type.vecsize); + case SPIRType::SByte: + return join("i8vec", type.vecsize); + case SPIRType::UByte: + return join("u8vec", type.vecsize); + case SPIRType::Short: + return join("i16vec", type.vecsize); + case SPIRType::UShort: + return join("u16vec", type.vecsize); + case SPIRType::Int: + return join("ivec", type.vecsize); + case SPIRType::UInt: + return join("uvec", type.vecsize); + case SPIRType::Half: + return join("f16vec", type.vecsize); + case SPIRType::Float: + return join("vec", type.vecsize); + case SPIRType::Double: + return join("dvec", type.vecsize); + case SPIRType::Int64: + 
return join("i64vec", type.vecsize); + case SPIRType::UInt64: + return join("u64vec", type.vecsize); + default: + return "???"; + } + } + else if (type.vecsize == type.columns) // Simple Matrix builtin + { + switch (type.basetype) + { + case SPIRType::Boolean: + return join("bmat", type.vecsize); + case SPIRType::Int: + return join("imat", type.vecsize); + case SPIRType::UInt: + return join("umat", type.vecsize); + case SPIRType::Half: + return join("f16mat", type.vecsize); + case SPIRType::Float: + return join("mat", type.vecsize); + case SPIRType::Double: + return join("dmat", type.vecsize); + // Matrix types not supported for int64/uint64. + default: + return "???"; + } + } + else + { + switch (type.basetype) + { + case SPIRType::Boolean: + return join("bmat", type.columns, "x", type.vecsize); + case SPIRType::Int: + return join("imat", type.columns, "x", type.vecsize); + case SPIRType::UInt: + return join("umat", type.columns, "x", type.vecsize); + case SPIRType::Half: + return join("f16mat", type.columns, "x", type.vecsize); + case SPIRType::Float: + return join("mat", type.columns, "x", type.vecsize); + case SPIRType::Double: + return join("dmat", type.columns, "x", type.vecsize); + // Matrix types not supported for int64/uint64. + default: + return "???"; + } + } +} + +void CompilerGLSL::add_variable(unordered_set &variables_primary, + const unordered_set &variables_secondary, string &name) +{ + if (name.empty()) + return; + + ParsedIR::sanitize_underscores(name); + if (ParsedIR::is_globally_reserved_identifier(name, true)) + { + name.clear(); + return; + } + + update_name_cache(variables_primary, variables_secondary, name); +} + +void CompilerGLSL::add_local_variable_name(uint32_t id) +{ + add_variable(local_variable_names, block_names, ir.meta[id].decoration.alias); +} + +void CompilerGLSL::add_resource_name(uint32_t id) +{ + add_variable(resource_names, block_names, ir.meta[id].decoration.alias); +} + +void CompilerGLSL::add_header_line(const std::string &line) +{ + header_lines.push_back(line); +} + +bool CompilerGLSL::has_extension(const std::string &ext) const +{ + auto itr = find(begin(forced_extensions), end(forced_extensions), ext); + return itr != end(forced_extensions); +} + +void CompilerGLSL::require_extension(const std::string &ext) +{ + if (!has_extension(ext)) + forced_extensions.push_back(ext); +} + +const SmallVector &CompilerGLSL::get_required_extensions() const +{ + return forced_extensions; +} + +void CompilerGLSL::require_extension_internal(const string &ext) +{ + if (backend.supports_extensions && !has_extension(ext)) + { + forced_extensions.push_back(ext); + force_recompile(); + } +} + +void CompilerGLSL::flatten_buffer_block(VariableID id) +{ + auto &var = get(id); + auto &type = get(var.basetype); + auto name = to_name(type.self, false); + auto &flags = get_decoration_bitset(type.self); + + if (!type.array.empty()) + SPIRV_CROSS_THROW(name + " is an array of UBOs."); + if (type.basetype != SPIRType::Struct) + SPIRV_CROSS_THROW(name + " is not a struct."); + if (!flags.get(DecorationBlock)) + SPIRV_CROSS_THROW(name + " is not a block."); + if (type.member_types.empty()) + SPIRV_CROSS_THROW(name + " is an empty struct."); + + flattened_buffer_blocks.insert(id); +} + +bool CompilerGLSL::builtin_translates_to_nonarray(spv::BuiltIn /*builtin*/) const +{ + return false; // GLSL itself does not need to translate array builtin types to non-array builtin types +} + +bool CompilerGLSL::is_user_type_structured(uint32_t /*id*/) const +{ + return false; // GLSL itself 
does not have structured user type, but HLSL does with StructuredBuffer and RWStructuredBuffer resources. +} + +bool CompilerGLSL::check_atomic_image(uint32_t id) +{ + auto &type = expression_type(id); + if (type.storage == StorageClassImage) + { + if (options.es && options.version < 320) + require_extension_internal("GL_OES_shader_image_atomic"); + + auto *var = maybe_get_backing_variable(id); + if (var) + { + if (has_decoration(var->self, DecorationNonWritable) || has_decoration(var->self, DecorationNonReadable)) + { + unset_decoration(var->self, DecorationNonWritable); + unset_decoration(var->self, DecorationNonReadable); + force_recompile(); + } + } + return true; + } + else + return false; +} + +void CompilerGLSL::add_function_overload(const SPIRFunction &func) +{ + Hasher hasher; + for (auto &arg : func.arguments) + { + // Parameters can vary with pointer type or not, + // but that will not change the signature in GLSL/HLSL, + // so strip the pointer type before hashing. + uint32_t type_id = get_pointee_type_id(arg.type); + auto &type = get(type_id); + + if (!combined_image_samplers.empty()) + { + // If we have combined image samplers, we cannot really trust the image and sampler arguments + // we pass down to callees, because they may be shuffled around. + // Ignore these arguments, to make sure that functions need to differ in some other way + // to be considered different overloads. + if (type.basetype == SPIRType::SampledImage || + (type.basetype == SPIRType::Image && type.image.sampled == 1) || type.basetype == SPIRType::Sampler) + { + continue; + } + } + + hasher.u32(type_id); + } + uint64_t types_hash = hasher.get(); + + auto function_name = to_name(func.self); + auto itr = function_overloads.find(function_name); + if (itr != end(function_overloads)) + { + // There exists a function with this name already. + auto &overloads = itr->second; + if (overloads.count(types_hash) != 0) + { + // Overload conflict, assign a new name. + add_resource_name(func.self); + function_overloads[to_name(func.self)].insert(types_hash); + } + else + { + // Can reuse the name. + overloads.insert(types_hash); + } + } + else + { + // First time we see this function name. + add_resource_name(func.self); + function_overloads[to_name(func.self)].insert(types_hash); + } +} + +void CompilerGLSL::emit_function_prototype(SPIRFunction &func, const Bitset &return_flags) +{ + if (func.self != ir.default_entry_point) + add_function_overload(func); + + // Avoid shadow declarations. + local_variable_names = resource_names; + + string decl; + + auto &type = get(func.return_type); + decl += flags_to_qualifiers_glsl(type, return_flags); + decl += type_to_glsl(type); + decl += type_to_array_glsl(type, 0); + decl += " "; + + if (func.self == ir.default_entry_point) + { + // If we need complex fallback in GLSL, we just wrap main() in a function + // and interlock the entire shader ... + if (interlocked_is_complex) + decl += "spvMainInterlockedBody"; + else + decl += "main"; + + processing_entry_point = true; + } + else + decl += to_name(func.self); + + decl += "("; + SmallVector arglist; + for (auto &arg : func.arguments) + { + // Do not pass in separate images or samplers if we're remapping + // to combined image samplers. + if (skip_argument(arg.id)) + continue; + + // Might change the variable name if it already exists in this function. + // SPIRV OpName doesn't have any semantic effect, so it's valid for an implementation + // to use same name for variables. 
+ // Since we want to make the GLSL debuggable and somewhat sane, use fallback names for variables which are duplicates. + add_local_variable_name(arg.id); + + arglist.push_back(argument_decl(arg)); + + // Hold a pointer to the parameter so we can invalidate the readonly field if needed. + auto *var = maybe_get(arg.id); + if (var) + var->parameter = &arg; + } + + for (auto &arg : func.shadow_arguments) + { + // Might change the variable name if it already exists in this function. + // SPIRV OpName doesn't have any semantic effect, so it's valid for an implementation + // to use same name for variables. + // Since we want to make the GLSL debuggable and somewhat sane, use fallback names for variables which are duplicates. + add_local_variable_name(arg.id); + + arglist.push_back(argument_decl(arg)); + + // Hold a pointer to the parameter so we can invalidate the readonly field if needed. + auto *var = maybe_get(arg.id); + if (var) + var->parameter = &arg; + } + + decl += merge(arglist); + decl += ")"; + statement(decl); +} + +void CompilerGLSL::emit_function(SPIRFunction &func, const Bitset &return_flags) +{ + // Avoid potential cycles. + if (func.active) + return; + func.active = true; + + // If we depend on a function, emit that function before we emit our own function. + for (auto block : func.blocks) + { + auto &b = get(block); + for (auto &i : b.ops) + { + auto ops = stream(i); + auto op = static_cast(i.op); + + if (op == OpFunctionCall) + { + // Recursively emit functions which are called. + uint32_t id = ops[2]; + emit_function(get(id), ir.meta[ops[1]].decoration.decoration_flags); + } + } + } + + if (func.entry_line.file_id != 0) + emit_line_directive(func.entry_line.file_id, func.entry_line.line_literal); + emit_function_prototype(func, return_flags); + begin_scope(); + + if (func.self == ir.default_entry_point) + emit_entry_point_declarations(); + + current_function = &func; + auto &entry_block = get(func.entry_block); + + sort(begin(func.constant_arrays_needed_on_stack), end(func.constant_arrays_needed_on_stack)); + for (auto &array : func.constant_arrays_needed_on_stack) + { + auto &c = get(array); + auto &type = get(c.constant_type); + statement(variable_decl(type, join("_", array, "_array_copy")), " = ", constant_expression(c), ";"); + } + + for (auto &v : func.local_variables) + { + auto &var = get(v); + var.deferred_declaration = false; + + if (variable_decl_is_remapped_storage(var, StorageClassWorkgroup)) + { + // Special variable type which cannot have initializer, + // need to be declared as standalone variables. + // Comes from MSL which can push global variables as local variables in main function. + add_local_variable_name(var.self); + statement(variable_decl(var), ";"); + var.deferred_declaration = false; + } + else if (var.storage == StorageClassPrivate) + { + // These variables will not have had their CFG usage analyzed, so move it to the entry block. + // Comes from MSL which can push global variables as local variables in main function. + // We could just declare them right now, but we would miss out on an important initialization case which is + // LUT declaration in MSL. + // If we don't declare the variable when it is assigned we're forced to go through a helper function + // which copies elements one by one. 
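+ // E.g. a private LUT with an initializer can then be declared up front as something like
+ // "float _lut[3] = float[](1.0, 2.0, 3.0);" instead of being copied element by element.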
+ add_local_variable_name(var.self); + + if (var.initializer) + { + statement(variable_decl(var), ";"); + var.deferred_declaration = false; + } + else + { + auto &dominated = entry_block.dominated_variables; + if (find(begin(dominated), end(dominated), var.self) == end(dominated)) + entry_block.dominated_variables.push_back(var.self); + var.deferred_declaration = true; + } + } + else if (var.storage == StorageClassFunction && var.remapped_variable && var.static_expression) + { + // No need to declare this variable, it has a static expression. + var.deferred_declaration = false; + } + else if (expression_is_lvalue(v)) + { + add_local_variable_name(var.self); + + // Loop variables should never be declared early, they are explicitly emitted in a loop. + if (var.initializer && !var.loop_variable) + statement(variable_decl_function_local(var), ";"); + else + { + // Don't declare variable until first use to declutter the GLSL output quite a lot. + // If we don't touch the variable before first branch, + // declare it then since we need variable declaration to be in top scope. + var.deferred_declaration = true; + } + } + else + { + // HACK: SPIR-V in older glslang output likes to use samplers and images as local variables, but GLSL does not allow this. + // For these types (non-lvalue), we enforce forwarding through a shadowed variable. + // This means that when we OpStore to these variables, we just write in the expression ID directly. + // This breaks any kind of branching, since the variable must be statically assigned. + // Branching on samplers and images would be pretty much impossible to fake in GLSL. + var.statically_assigned = true; + } + + var.loop_variable_enable = false; + + // Loop variables are never declared outside their for-loop, so block any implicit declaration. + if (var.loop_variable) + { + var.deferred_declaration = false; + // Need to reset the static expression so we can fallback to initializer if need be. + var.static_expression = 0; + } + } + + // Enforce declaration order for regression testing purposes. + for (auto &block_id : func.blocks) + { + auto &block = get(block_id); + sort(begin(block.dominated_variables), end(block.dominated_variables)); + } + + for (auto &line : current_function->fixup_hooks_in) + line(); + + emit_block_chain(entry_block); + + end_scope(); + processing_entry_point = false; + statement(""); + + // Make sure deferred declaration state for local variables is cleared when we are done with function. + // We risk declaring Private/Workgroup variables in places we are not supposed to otherwise. + for (auto &v : func.local_variables) + { + auto &var = get(v); + var.deferred_declaration = false; + } +} + +void CompilerGLSL::emit_fixup() +{ + if (is_vertex_like_shader()) + { + if (options.vertex.fixup_clipspace) + { + const char *suffix = backend.float_literal_suffix ? "f" : ""; + statement("gl_Position.z = 2.0", suffix, " * gl_Position.z - gl_Position.w;"); + } + + if (options.vertex.flip_vert_y) + statement("gl_Position.y = -gl_Position.y;"); + } +} + +void CompilerGLSL::flush_phi(BlockID from, BlockID to) +{ + auto &child = get(to); + if (child.ignore_phi_from_block == from) + return; + + unordered_set temporary_phi_variables; + + for (auto itr = begin(child.phi_variables); itr != end(child.phi_variables); ++itr) + { + auto &phi = *itr; + + if (phi.parent == from) + { + auto &var = get(phi.function_variable); + + // A Phi variable might be a loop variable, so flush to static expression. 
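+ // In the plain case below this just becomes an assignment, roughly "phi_var = incoming_value;",
+ // with a "_<id>_copy" temporary only when a later Phi in the target block still reads the old value.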
+ if (var.loop_variable && !var.loop_variable_enable) + var.static_expression = phi.local_variable; + else + { + flush_variable_declaration(phi.function_variable); + + // Check if we are going to write to a Phi variable that another statement will read from + // as part of another Phi node in our target block. + // For this case, we will need to copy phi.function_variable to a temporary, and use that for future reads. + // This is judged to be extremely rare, so deal with it here using a simple, but suboptimal algorithm. + bool need_saved_temporary = + find_if(itr + 1, end(child.phi_variables), [&](const SPIRBlock::Phi &future_phi) -> bool { + return future_phi.local_variable == ID(phi.function_variable) && future_phi.parent == from; + }) != end(child.phi_variables); + + if (need_saved_temporary) + { + // Need to make sure we declare the phi variable with a copy at the right scope. + // We cannot safely declare a temporary here since we might be inside a continue block. + if (!var.allocate_temporary_copy) + { + var.allocate_temporary_copy = true; + force_recompile(); + } + statement("_", phi.function_variable, "_copy", " = ", to_name(phi.function_variable), ";"); + temporary_phi_variables.insert(phi.function_variable); + } + + // This might be called in continue block, so make sure we + // use this to emit ESSL 1.0 compliant increments/decrements. + auto lhs = to_expression(phi.function_variable); + + string rhs; + if (temporary_phi_variables.count(phi.local_variable)) + rhs = join("_", phi.local_variable, "_copy"); + else + rhs = to_pointer_expression(phi.local_variable); + + if (!optimize_read_modify_write(get(var.basetype), lhs, rhs)) + statement(lhs, " = ", rhs, ";"); + } + + register_write(phi.function_variable); + } + } +} + +void CompilerGLSL::branch_to_continue(BlockID from, BlockID to) +{ + auto &to_block = get(to); + if (from == to) + return; + + assert(is_continue(to)); + if (to_block.complex_continue) + { + // Just emit the whole block chain as is. + auto usage_counts = expression_usage_counts; + + emit_block_chain(to_block); + + // Expression usage counts are moot after returning from the continue block. + expression_usage_counts = usage_counts; + } + else + { + auto &from_block = get(from); + bool outside_control_flow = false; + uint32_t loop_dominator = 0; + + // FIXME: Refactor this to not use the old loop_dominator tracking. + if (from_block.merge_block) + { + // If we are a loop header, we don't set the loop dominator, + // so just use "self" here. + loop_dominator = from; + } + else if (from_block.loop_dominator != BlockID(SPIRBlock::NoDominator)) + { + loop_dominator = from_block.loop_dominator; + } + + if (loop_dominator != 0) + { + auto &cfg = get_cfg_for_current_function(); + + // For non-complex continue blocks, we implicitly branch to the continue block + // by having the continue block be part of the loop header in for (; ; continue-block). + outside_control_flow = cfg.node_terminates_control_flow_in_sub_graph(loop_dominator, from); + } + + // Some simplification for for-loops. We always end up with a useless continue; + // statement since we branch to a loop block. + // Walk the CFG, if we unconditionally execute the block calling continue assuming we're in the loop block, + // we can avoid writing out an explicit continue statement. + // Similar optimization to return statements if we know we're outside flow control. 
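+ // E.g. a "continue;" as the unconditional last statement of the loop body is redundant
+ // (the for-header already continues), while a continue inside an if-branch must still be emitted.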
+ if (!outside_control_flow) + statement("continue;"); + } +} + +void CompilerGLSL::branch(BlockID from, BlockID to) +{ + flush_phi(from, to); + flush_control_dependent_expressions(from); + + bool to_is_continue = is_continue(to); + + // This is only a continue if we branch to our loop dominator. + if ((ir.block_meta[to] & ParsedIR::BLOCK_META_LOOP_HEADER_BIT) != 0 && get(from).loop_dominator == to) + { + // This can happen if we had a complex continue block which was emitted. + // Once the continue block tries to branch to the loop header, just emit continue; + // and end the chain here. + statement("continue;"); + } + else if (from != to && is_break(to)) + { + // We cannot break to ourselves, so check explicitly for from != to. + // This case can trigger if a loop header is all three of these things: + // - Continue block + // - Loop header + // - Break merge target all at once ... + + // Very dirty workaround. + // Switch constructs are able to break, but they cannot break out of a loop at the same time, + // yet SPIR-V allows it. + // Only sensible solution is to make a ladder variable, which we declare at the top of the switch block, + // write to the ladder here, and defer the break. + // The loop we're breaking out of must dominate the switch block, or there is no ladder breaking case. + if (is_loop_break(to)) + { + for (size_t n = current_emitting_switch_stack.size(); n; n--) + { + auto *current_emitting_switch = current_emitting_switch_stack[n - 1]; + + if (current_emitting_switch && + current_emitting_switch->loop_dominator != BlockID(SPIRBlock::NoDominator) && + get(current_emitting_switch->loop_dominator).merge_block == to) + { + if (!current_emitting_switch->need_ladder_break) + { + force_recompile(); + current_emitting_switch->need_ladder_break = true; + } + + statement("_", current_emitting_switch->self, "_ladder_break = true;"); + } + else + break; + } + } + statement("break;"); + } + else if (to_is_continue || from == to) + { + // For from == to case can happen for a do-while loop which branches into itself. + // We don't mark these cases as continue blocks, but the only possible way to branch into + // ourselves is through means of continue blocks. + + // If we are merging to a continue block, there is no need to emit the block chain for continue here. + // We can branch to the continue block after we merge execution. + + // Here we make use of structured control flow rules from spec: + // 2.11: - the merge block declared by a header block cannot be a merge block declared by any other header block + // - each header block must strictly dominate its merge block, unless the merge block is unreachable in the CFG + // If we are branching to a merge block, we must be inside a construct which dominates the merge block. + auto &block_meta = ir.block_meta[to]; + bool branching_to_merge = + (block_meta & (ParsedIR::BLOCK_META_SELECTION_MERGE_BIT | ParsedIR::BLOCK_META_MULTISELECT_MERGE_BIT | + ParsedIR::BLOCK_META_LOOP_MERGE_BIT)) != 0; + if (!to_is_continue || !branching_to_merge) + branch_to_continue(from, to); + } + else if (!is_conditional(to)) + emit_block_chain(get(to)); + + // It is important that we check for break before continue. + // A block might serve two purposes, a break block for the inner scope, and + // a continue block in the outer scope. + // Inner scope always takes precedence. 
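+ // (E.g. a block acting both as the inner loop's break target and the outer loop's continue
+ // block emits "break;" for the inner scope rather than "continue;".)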
+} + +void CompilerGLSL::branch(BlockID from, uint32_t cond, BlockID true_block, BlockID false_block) +{ + auto &from_block = get(from); + BlockID merge_block = from_block.merge == SPIRBlock::MergeSelection ? from_block.next_block : BlockID(0); + + // If we branch directly to our selection merge target, we don't need a code path. + bool true_block_needs_code = true_block != merge_block || flush_phi_required(from, true_block); + bool false_block_needs_code = false_block != merge_block || flush_phi_required(from, false_block); + + if (!true_block_needs_code && !false_block_needs_code) + return; + + // We might have a loop merge here. Only consider selection flattening constructs. + // Loop hints are handled explicitly elsewhere. + if (from_block.hint == SPIRBlock::HintFlatten || from_block.hint == SPIRBlock::HintDontFlatten) + emit_block_hints(from_block); + + if (true_block_needs_code) + { + statement("if (", to_expression(cond), ")"); + begin_scope(); + branch(from, true_block); + end_scope(); + + if (false_block_needs_code) + { + statement("else"); + begin_scope(); + branch(from, false_block); + end_scope(); + } + } + else if (false_block_needs_code) + { + // Only need false path, use negative conditional. + statement("if (!", to_enclosed_expression(cond), ")"); + begin_scope(); + branch(from, false_block); + end_scope(); + } +} + +// FIXME: This currently cannot handle complex continue blocks +// as in do-while. +// This should be seen as a "trivial" continue block. +string CompilerGLSL::emit_continue_block(uint32_t continue_block, bool follow_true_block, bool follow_false_block) +{ + auto *block = &get(continue_block); + + // While emitting the continue block, declare_temporary will check this + // if we have to emit temporaries. + current_continue_block = block; + + SmallVector statements; + + // Capture all statements into our list. + auto *old = redirect_statement; + redirect_statement = &statements; + + // Stamp out all blocks one after each other. + while ((ir.block_meta[block->self] & ParsedIR::BLOCK_META_LOOP_HEADER_BIT) == 0) + { + // Write out all instructions we have in this block. + emit_block_instructions(*block); + + // For plain branchless for/while continue blocks. + if (block->next_block) + { + flush_phi(continue_block, block->next_block); + block = &get(block->next_block); + } + // For do while blocks. The last block will be a select block. + else if (block->true_block && follow_true_block) + { + flush_phi(continue_block, block->true_block); + block = &get(block->true_block); + } + else if (block->false_block && follow_false_block) + { + flush_phi(continue_block, block->false_block); + block = &get(block->false_block); + } + else + { + SPIRV_CROSS_THROW("Invalid continue block detected!"); + } + } + + // Restore old pointer. + redirect_statement = old; + + // Somewhat ugly, strip off the last ';' since we use ',' instead. + // Ideally, we should select this behavior in statement(). + for (auto &s : statements) + { + if (!s.empty() && s.back() == ';') + s.erase(s.size() - 1, 1); + } + + current_continue_block = nullptr; + return merge(statements); +} + +void CompilerGLSL::emit_while_loop_initializers(const SPIRBlock &block) +{ + // While loops do not take initializers, so declare all of them outside. 
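+ // E.g. "int i = 0;" is emitted just before the "while (...)" header rather than inside it,
+ // since a while loop has no initializer slot.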
+ for (auto &loop_var : block.loop_variables) + { + auto &var = get(loop_var); + statement(variable_decl(var), ";"); + } +} + +string CompilerGLSL::emit_for_loop_initializers(const SPIRBlock &block) +{ + if (block.loop_variables.empty()) + return ""; + + bool same_types = for_loop_initializers_are_same_type(block); + // We can only declare for loop initializers if all variables are of same type. + // If we cannot do this, declare individual variables before the loop header. + + // We might have a loop variable candidate which was not assigned to for some reason. + uint32_t missing_initializers = 0; + for (auto &variable : block.loop_variables) + { + uint32_t expr = get(variable).static_expression; + + // Sometimes loop variables are initialized with OpUndef, but we can just declare + // a plain variable without initializer in this case. + if (expr == 0 || ir.ids[expr].get_type() == TypeUndef) + missing_initializers++; + } + + if (block.loop_variables.size() == 1 && missing_initializers == 0) + { + return variable_decl(get(block.loop_variables.front())); + } + else if (!same_types || missing_initializers == uint32_t(block.loop_variables.size())) + { + for (auto &loop_var : block.loop_variables) + statement(variable_decl(get(loop_var)), ";"); + return ""; + } + else + { + // We have a mix of loop variables, either ones with a clear initializer, or ones without. + // Separate the two streams. + string expr; + + for (auto &loop_var : block.loop_variables) + { + uint32_t static_expr = get(loop_var).static_expression; + if (static_expr == 0 || ir.ids[static_expr].get_type() == TypeUndef) + { + statement(variable_decl(get(loop_var)), ";"); + } + else + { + auto &var = get(loop_var); + auto &type = get_variable_data_type(var); + if (expr.empty()) + { + // For loop initializers are of the form (block.true_block), get(block.merge_block))) + condition = join("!", enclose_expression(condition)); + + statement("while (", condition, ")"); + break; + } + + default: + block.disable_block_optimization = true; + force_recompile(); + begin_scope(); // We'll see an end_scope() later. + return false; + } + + begin_scope(); + return true; + } + else + { + block.disable_block_optimization = true; + force_recompile(); + begin_scope(); // We'll see an end_scope() later. + return false; + } + } + else if (method == SPIRBlock::MergeToDirectForLoop) + { + auto &child = get(block.next_block); + + // This block may be a dominating block, so make sure we flush undeclared variables before building the for loop header. + flush_undeclared_variables(child); + + uint32_t current_count = statement_count; + + // If we're trying to create a true for loop, + // we need to make sure that all opcodes before branch statement do not actually emit any code. + // We can then take the condition expression and create a for (; cond ; ) { body; } structure instead. + emit_block_instructions_with_masked_debug(child); + + bool condition_is_temporary = forced_temporaries.find(child.condition) == end(forced_temporaries); + + bool flushes_phi = flush_phi_required(child.self, child.true_block) || + flush_phi_required(child.self, child.false_block); + + if (!flushes_phi && current_count == statement_count && condition_is_temporary) + { + uint32_t target_block = child.true_block; + + switch (continue_type) + { + case SPIRBlock::ForLoop: + { + // Important that we do this in this order because + // emitting the continue block can invalidate the condition expression. 
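+ // The end result is a plain header along the lines of "for (int i = 0; i < n; i++)",
+ // with the continue block folded into the increment slot.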
+ auto initializer = emit_for_loop_initializers(block); + auto condition = to_expression(child.condition); + + // Condition might have to be inverted. + if (execution_is_noop(get(child.true_block), get(block.merge_block))) + { + condition = join("!", enclose_expression(condition)); + target_block = child.false_block; + } + + auto continue_block = emit_continue_block(block.continue_block, false, false); + emit_block_hints(block); + statement("for (", initializer, "; ", condition, "; ", continue_block, ")"); + break; + } + + case SPIRBlock::WhileLoop: + { + emit_while_loop_initializers(block); + emit_block_hints(block); + + auto condition = to_expression(child.condition); + // Condition might have to be inverted. + if (execution_is_noop(get(child.true_block), get(block.merge_block))) + { + condition = join("!", enclose_expression(condition)); + target_block = child.false_block; + } + + statement("while (", condition, ")"); + break; + } + + default: + block.disable_block_optimization = true; + force_recompile(); + begin_scope(); // We'll see an end_scope() later. + return false; + } + + begin_scope(); + branch(child.self, target_block); + return true; + } + else + { + block.disable_block_optimization = true; + force_recompile(); + begin_scope(); // We'll see an end_scope() later. + return false; + } + } + else + return false; +} + +void CompilerGLSL::flush_undeclared_variables(SPIRBlock &block) +{ + for (auto &v : block.dominated_variables) + flush_variable_declaration(v); +} + +void CompilerGLSL::emit_hoisted_temporaries(SmallVector> &temporaries) +{ + // If we need to force temporaries for certain IDs due to continue blocks, do it before starting loop header. + // Need to sort these to ensure that reference output is stable. + sort(begin(temporaries), end(temporaries), + [](const pair &a, const pair &b) { return a.second < b.second; }); + + for (auto &tmp : temporaries) + { + auto &type = get(tmp.first); + + // There are some rare scenarios where we are asked to declare pointer types as hoisted temporaries. + // This should be ignored unless we're doing actual variable pointers and backend supports it. + // Access chains cannot normally be lowered to temporaries in GLSL and HLSL. + if (type.pointer && !backend.native_pointers) + continue; + + add_local_variable_name(tmp.second); + auto &flags = get_decoration_bitset(tmp.second); + + // Not all targets support pointer literals, so don't bother with that case. + string initializer; + if (options.force_zero_initialized_variables && type_can_zero_initialize(type)) + initializer = join(" = ", to_zero_initialized_expression(tmp.first)); + + statement(flags_to_qualifiers_glsl(type, flags), variable_decl(type, to_name(tmp.second)), initializer, ";"); + + hoisted_temporaries.insert(tmp.second); + forced_temporaries.insert(tmp.second); + + // The temporary might be read from before it's assigned, set up the expression now. + set(tmp.second, to_name(tmp.second), tmp.first, true); + + // If we have hoisted temporaries in multi-precision contexts, emit that here too ... + // We will not be able to analyze hoisted-ness for dependent temporaries that we hallucinate here. 
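+ // A hoisted temporary is simply declared ahead of the construct, roughly:
+ //   float _24; for (;;) { _24 = ...; break; } use(_24);
+ // and any mediump/highp mirror alias handled below gets the same treatment.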
+ auto mirrored_precision_itr = temporary_to_mirror_precision_alias.find(tmp.second); + if (mirrored_precision_itr != temporary_to_mirror_precision_alias.end()) + { + uint32_t mirror_id = mirrored_precision_itr->second; + auto &mirror_flags = get_decoration_bitset(mirror_id); + statement(flags_to_qualifiers_glsl(type, mirror_flags), + variable_decl(type, to_name(mirror_id)), + initializer, ";"); + // The temporary might be read from before it's assigned, set up the expression now. + set(mirror_id, to_name(mirror_id), tmp.first, true); + hoisted_temporaries.insert(mirror_id); + } + } +} + +void CompilerGLSL::emit_block_chain(SPIRBlock &block) +{ + bool select_branch_to_true_block = false; + bool select_branch_to_false_block = false; + bool skip_direct_branch = false; + bool emitted_loop_header_variables = false; + bool force_complex_continue_block = false; + ValueSaver loop_level_saver(current_loop_level); + + if (block.merge == SPIRBlock::MergeLoop) + add_loop_level(); + + // If we're emitting PHI variables with precision aliases, we have to emit them as hoisted temporaries. + for (auto var_id : block.dominated_variables) + { + auto &var = get(var_id); + if (var.phi_variable) + { + auto mirrored_precision_itr = temporary_to_mirror_precision_alias.find(var_id); + if (mirrored_precision_itr != temporary_to_mirror_precision_alias.end() && + find_if(block.declare_temporary.begin(), block.declare_temporary.end(), + [mirrored_precision_itr](const std::pair &p) { + return p.second == mirrored_precision_itr->second; + }) == block.declare_temporary.end()) + { + block.declare_temporary.push_back({ var.basetype, mirrored_precision_itr->second }); + } + } + } + + emit_hoisted_temporaries(block.declare_temporary); + + SPIRBlock::ContinueBlockType continue_type = SPIRBlock::ContinueNone; + if (block.continue_block) + { + continue_type = continue_block_type(get(block.continue_block)); + // If we know we cannot emit a loop, mark the block early as a complex loop so we don't force unnecessary recompiles. + if (continue_type == SPIRBlock::ComplexLoop) + block.complex_continue = true; + } + + // If we have loop variables, stop masking out access to the variable now. + for (auto var_id : block.loop_variables) + { + auto &var = get(var_id); + var.loop_variable_enable = true; + // We're not going to declare the variable directly, so emit a copy here. + emit_variable_temporary_copies(var); + } + + // Remember deferred declaration state. We will restore it before returning. + SmallVector rearm_dominated_variables(block.dominated_variables.size()); + for (size_t i = 0; i < block.dominated_variables.size(); i++) + { + uint32_t var_id = block.dominated_variables[i]; + auto &var = get(var_id); + rearm_dominated_variables[i] = var.deferred_declaration; + } + + // This is the method often used by spirv-opt to implement loops. + // The loop header goes straight into the continue block. + // However, don't attempt this on ESSL 1.0, because if a loop variable is used in a continue block, + // it *MUST* be used in the continue block. This loop method will not work. 
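+ // In that pattern the loop condition lives in the header block itself, so the emitted body is
+ // simply the true (or false) successor; see the select_branch_to_* handling further down.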
+ if (!is_legacy_es() && block_is_loop_candidate(block, SPIRBlock::MergeToSelectContinueForLoop)) + { + flush_undeclared_variables(block); + if (attempt_emit_loop_header(block, SPIRBlock::MergeToSelectContinueForLoop)) + { + if (execution_is_noop(get(block.true_block), get(block.merge_block))) + select_branch_to_false_block = true; + else + select_branch_to_true_block = true; + + emitted_loop_header_variables = true; + force_complex_continue_block = true; + } + } + // This is the older loop behavior in glslang which branches to loop body directly from the loop header. + else if (block_is_loop_candidate(block, SPIRBlock::MergeToSelectForLoop)) + { + flush_undeclared_variables(block); + if (attempt_emit_loop_header(block, SPIRBlock::MergeToSelectForLoop)) + { + // The body of while, is actually just the true (or false) block, so always branch there unconditionally. + if (execution_is_noop(get(block.true_block), get(block.merge_block))) + select_branch_to_false_block = true; + else + select_branch_to_true_block = true; + + emitted_loop_header_variables = true; + } + } + // This is the newer loop behavior in glslang which branches from Loop header directly to + // a new block, which in turn has a OpBranchSelection without a selection merge. + else if (block_is_loop_candidate(block, SPIRBlock::MergeToDirectForLoop)) + { + flush_undeclared_variables(block); + if (attempt_emit_loop_header(block, SPIRBlock::MergeToDirectForLoop)) + { + skip_direct_branch = true; + emitted_loop_header_variables = true; + } + } + else if (continue_type == SPIRBlock::DoWhileLoop) + { + flush_undeclared_variables(block); + emit_while_loop_initializers(block); + emitted_loop_header_variables = true; + // We have some temporaries where the loop header is the dominator. + // We risk a case where we have code like: + // for (;;) { create-temporary; break; } consume-temporary; + // so force-declare temporaries here. + emit_hoisted_temporaries(block.potential_declare_temporary); + statement("do"); + begin_scope(); + + emit_block_instructions(block); + } + else if (block.merge == SPIRBlock::MergeLoop) + { + flush_undeclared_variables(block); + emit_while_loop_initializers(block); + emitted_loop_header_variables = true; + + // We have a generic loop without any distinguishable pattern like for, while or do while. + get(block.continue_block).complex_continue = true; + continue_type = SPIRBlock::ComplexLoop; + + // We have some temporaries where the loop header is the dominator. + // We risk a case where we have code like: + // for (;;) { create-temporary; break; } consume-temporary; + // so force-declare temporaries here. + emit_hoisted_temporaries(block.potential_declare_temporary); + emit_block_hints(block); + statement("for (;;)"); + begin_scope(); + + emit_block_instructions(block); + } + else + { + emit_block_instructions(block); + } + + // If we didn't successfully emit a loop header and we had loop variable candidates, we have a problem + // as writes to said loop variables might have been masked out, we need a recompile. + if (!emitted_loop_header_variables && !block.loop_variables.empty()) + { + force_recompile_guarantee_forward_progress(); + for (auto var : block.loop_variables) + get(var).loop_variable = false; + block.loop_variables.clear(); + } + + flush_undeclared_variables(block); + bool emit_next_block = true; + + // Handle end of block. + switch (block.terminator) + { + case SPIRBlock::Direct: + // True when emitting complex continue block. 
+ if (block.loop_dominator == block.next_block) + { + branch(block.self, block.next_block); + emit_next_block = false; + } + // True if MergeToDirectForLoop succeeded. + else if (skip_direct_branch) + emit_next_block = false; + else if (is_continue(block.next_block) || is_break(block.next_block) || is_conditional(block.next_block)) + { + branch(block.self, block.next_block); + emit_next_block = false; + } + break; + + case SPIRBlock::Select: + // True if MergeToSelectForLoop or MergeToSelectContinueForLoop succeeded. + if (select_branch_to_true_block) + { + if (force_complex_continue_block) + { + assert(block.true_block == block.continue_block); + + // We're going to emit a continue block directly here, so make sure it's marked as complex. + auto &complex_continue = get(block.continue_block).complex_continue; + bool old_complex = complex_continue; + complex_continue = true; + branch(block.self, block.true_block); + complex_continue = old_complex; + } + else + branch(block.self, block.true_block); + } + else if (select_branch_to_false_block) + { + if (force_complex_continue_block) + { + assert(block.false_block == block.continue_block); + + // We're going to emit a continue block directly here, so make sure it's marked as complex. + auto &complex_continue = get(block.continue_block).complex_continue; + bool old_complex = complex_continue; + complex_continue = true; + branch(block.self, block.false_block); + complex_continue = old_complex; + } + else + branch(block.self, block.false_block); + } + else + branch(block.self, block.condition, block.true_block, block.false_block); + break; + + case SPIRBlock::MultiSelect: + { + auto &type = expression_type(block.condition); + bool unsigned_case = type.basetype == SPIRType::UInt || type.basetype == SPIRType::UShort || + type.basetype == SPIRType::UByte || type.basetype == SPIRType::UInt64; + + if (block.merge == SPIRBlock::MergeNone) + SPIRV_CROSS_THROW("Switch statement is not structured"); + + if (!backend.support_64bit_switch && (type.basetype == SPIRType::UInt64 || type.basetype == SPIRType::Int64)) + { + // SPIR-V spec suggests this is allowed, but we cannot support it in higher level languages. + SPIRV_CROSS_THROW("Cannot use 64-bit switch selectors."); + } + + const char *label_suffix = ""; + if (type.basetype == SPIRType::UInt && backend.uint32_t_literal_suffix) + label_suffix = "u"; + else if (type.basetype == SPIRType::Int64 && backend.support_64bit_switch) + label_suffix = "l"; + else if (type.basetype == SPIRType::UInt64 && backend.support_64bit_switch) + label_suffix = "ul"; + else if (type.basetype == SPIRType::UShort) + label_suffix = backend.uint16_t_literal_suffix; + else if (type.basetype == SPIRType::Short) + label_suffix = backend.int16_t_literal_suffix; + + current_emitting_switch_stack.push_back(&block); + + if (block.need_ladder_break) + statement("bool _", block.self, "_ladder_break = false;"); + + // Find all unique case constructs. + unordered_map> case_constructs; + SmallVector block_declaration_order; + SmallVector literals_to_merge; + + // If a switch case branches to the default block for some reason, we can just remove that literal from consideration + // and let the default: block handle it. + // 2.11 in SPIR-V spec states that for fall-through cases, there is a very strict declaration order which we can take advantage of here. + // We only need to consider possible fallthrough if order[i] branches to order[i + 1]. 
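+ // Roughly: for "OpSwitch %x %default 1 %A 2 %A 3 %B" the loop below records A -> {1, 2} and
+ // B -> {3}, keeping first-declaration order so fall-through can be detected later.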
+ auto &cases = get_case_list(block); + for (auto &c : cases) + { + if (c.block != block.next_block && c.block != block.default_block) + { + if (!case_constructs.count(c.block)) + block_declaration_order.push_back(c.block); + case_constructs[c.block].push_back(c.value); + } + else if (c.block == block.next_block && block.default_block != block.next_block) + { + // We might have to flush phi inside specific case labels. + // If we can piggyback on default:, do so instead. + literals_to_merge.push_back(c.value); + } + } + + // Empty literal array -> default. + if (block.default_block != block.next_block) + { + auto &default_block = get(block.default_block); + + // We need to slide in the default block somewhere in this chain + // if there are fall-through scenarios since the default is declared separately in OpSwitch. + // Only consider trivial fall-through cases here. + size_t num_blocks = block_declaration_order.size(); + bool injected_block = false; + + for (size_t i = 0; i < num_blocks; i++) + { + auto &case_block = get(block_declaration_order[i]); + if (execution_is_direct_branch(case_block, default_block)) + { + // Fallthrough to default block, we must inject the default block here. + block_declaration_order.insert(begin(block_declaration_order) + i + 1, block.default_block); + injected_block = true; + break; + } + else if (execution_is_direct_branch(default_block, case_block)) + { + // Default case is falling through to another case label, we must inject the default block here. + block_declaration_order.insert(begin(block_declaration_order) + i, block.default_block); + injected_block = true; + break; + } + } + + // Order does not matter. + if (!injected_block) + block_declaration_order.push_back(block.default_block); + else if (is_legacy_es()) + SPIRV_CROSS_THROW("Default case label fallthrough to other case label is not supported in ESSL 1.0."); + + case_constructs[block.default_block] = {}; + } + + size_t num_blocks = block_declaration_order.size(); + + const auto to_case_label = [](uint64_t literal, uint32_t width, bool is_unsigned_case) -> string + { + if (is_unsigned_case) + return convert_to_string(literal); + + // For smaller cases, the literals are compiled as 32 bit wide + // literals so we don't need to care for all sizes specifically. + if (width <= 32) + { + return convert_to_string(int64_t(int32_t(literal))); + } + + return convert_to_string(int64_t(literal)); + }; + + const auto to_legacy_case_label = [&](uint32_t condition, const SmallVector &labels, + const char *suffix) -> string { + string ret; + size_t count = labels.size(); + for (size_t i = 0; i < count; i++) + { + if (i) + ret += " || "; + ret += join(count > 1 ? "(" : "", to_enclosed_expression(condition), " == ", labels[i], suffix, + count > 1 ? ")" : ""); + } + return ret; + }; + + // We need to deal with a complex scenario for OpPhi. If we have case-fallthrough and Phi in the picture, + // we need to flush phi nodes outside the switch block in a branch, + // and skip any Phi handling inside the case label to make fall-through work as expected. + // This kind of code-gen is super awkward and it's a last resort. Normally we would want to handle this + // inside the case label if at all possible. 
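+ // Concretely this can emit something like "if (x == 1 || x == 2) { phi_var = header_value; }"
+ // right before the switch, and the case label is then marked to skip its own Phi flush.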
+ for (size_t i = 1; backend.support_case_fallthrough && i < num_blocks; i++) + { + if (flush_phi_required(block.self, block_declaration_order[i]) && + flush_phi_required(block_declaration_order[i - 1], block_declaration_order[i])) + { + uint32_t target_block = block_declaration_order[i]; + + // Make sure we flush Phi, it might have been marked to be ignored earlier. + get(target_block).ignore_phi_from_block = 0; + + auto &literals = case_constructs[target_block]; + + if (literals.empty()) + { + // Oh boy, gotta make a complete negative test instead! o.o + // Find all possible literals that would *not* make us enter the default block. + // If none of those literals match, we flush Phi ... + SmallVector conditions; + for (size_t j = 0; j < num_blocks; j++) + { + auto &negative_literals = case_constructs[block_declaration_order[j]]; + for (auto &case_label : negative_literals) + conditions.push_back(join(to_enclosed_expression(block.condition), + " != ", to_case_label(case_label, type.width, unsigned_case))); + } + + statement("if (", merge(conditions, " && "), ")"); + begin_scope(); + flush_phi(block.self, target_block); + end_scope(); + } + else + { + SmallVector conditions; + conditions.reserve(literals.size()); + for (auto &case_label : literals) + conditions.push_back(join(to_enclosed_expression(block.condition), + " == ", to_case_label(case_label, type.width, unsigned_case))); + statement("if (", merge(conditions, " || "), ")"); + begin_scope(); + flush_phi(block.self, target_block); + end_scope(); + } + + // Mark the block so that we don't flush Phi from header to case label. + get(target_block).ignore_phi_from_block = block.self; + } + } + + // If there is only one default block, and no cases, this is a case where SPIRV-opt decided to emulate + // non-structured exits with the help of a switch block. + // This is buggy on FXC, so just emit the logical equivalent of a do { } while(false), which is more idiomatic. + bool block_like_switch = cases.empty(); + + // If this is true, the switch is completely meaningless, and we should just avoid it. + bool collapsed_switch = block_like_switch && block.default_block == block.next_block; + + if (!collapsed_switch) + { + if (block_like_switch || is_legacy_es()) + { + // ESSL 1.0 is not guaranteed to support do/while. + if (is_legacy_es()) + { + uint32_t counter = statement_count; + statement("for (int spvDummy", counter, " = 0; spvDummy", counter, " < 1; spvDummy", counter, + "++)"); + } + else + statement("do"); + } + else + { + emit_block_hints(block); + statement("switch (", to_unpacked_expression(block.condition), ")"); + } + begin_scope(); + } + + for (size_t i = 0; i < num_blocks; i++) + { + uint32_t target_block = block_declaration_order[i]; + auto &literals = case_constructs[target_block]; + + if (literals.empty()) + { + // Default case. + if (!block_like_switch) + { + if (is_legacy_es()) + statement("else"); + else + statement("default:"); + } + } + else + { + if (is_legacy_es()) + { + statement((i ? "else " : ""), "if (", to_legacy_case_label(block.condition, literals, label_suffix), + ")"); + } + else + { + for (auto &case_literal : literals) + { + // The case label value must be sign-extended properly in SPIR-V, so we can assume 32-bit values here. 
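+ // E.g. a 32-bit signed selector with literal 0xFFFFFFFF is printed as "case -1:" by
+ // to_case_label(), while unsigned selectors keep the raw value (plus any "u"-style suffix).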
+ statement("case ", to_case_label(case_literal, type.width, unsigned_case), label_suffix, ":"); + } + } + } + + auto &case_block = get(target_block); + if (backend.support_case_fallthrough && i + 1 < num_blocks && + execution_is_direct_branch(case_block, get(block_declaration_order[i + 1]))) + { + // We will fall through here, so just terminate the block chain early. + // We still need to deal with Phi potentially. + // No need for a stack-like thing here since we only do fall-through when there is a + // single trivial branch to fall-through target.. + current_emitting_switch_fallthrough = true; + } + else + current_emitting_switch_fallthrough = false; + + if (!block_like_switch) + begin_scope(); + branch(block.self, target_block); + if (!block_like_switch) + end_scope(); + + current_emitting_switch_fallthrough = false; + } + + // Might still have to flush phi variables if we branch from loop header directly to merge target. + // This is supposed to emit all cases where we branch from header to merge block directly. + // There are two main scenarios where cannot rely on default fallthrough. + // - There is an explicit default: label already. + // In this case, literals_to_merge need to form their own "default" case, so that we avoid executing that block. + // - Header -> Merge requires flushing PHI. In this case, we need to collect all cases and flush PHI there. + bool header_merge_requires_phi = flush_phi_required(block.self, block.next_block); + bool need_fallthrough_block = block.default_block == block.next_block || !literals_to_merge.empty(); + if (!collapsed_switch && ((header_merge_requires_phi && need_fallthrough_block) || !literals_to_merge.empty())) + { + for (auto &case_literal : literals_to_merge) + statement("case ", to_case_label(case_literal, type.width, unsigned_case), label_suffix, ":"); + + if (block.default_block == block.next_block) + { + if (is_legacy_es()) + statement("else"); + else + statement("default:"); + } + + begin_scope(); + flush_phi(block.self, block.next_block); + statement("break;"); + end_scope(); + } + + if (!collapsed_switch) + { + if (block_like_switch && !is_legacy_es()) + end_scope_decl("while(false)"); + else + end_scope(); + } + else + flush_phi(block.self, block.next_block); + + if (block.need_ladder_break) + { + statement("if (_", block.self, "_ladder_break)"); + begin_scope(); + statement("break;"); + end_scope(); + } + + current_emitting_switch_stack.pop_back(); + break; + } + + case SPIRBlock::Return: + { + for (auto &line : current_function->fixup_hooks_out) + line(); + + if (processing_entry_point) + emit_fixup(); + + auto &cfg = get_cfg_for_current_function(); + + if (block.return_value) + { + auto &type = expression_type(block.return_value); + if (!type.array.empty() && !backend.can_return_array) + { + // If we cannot return arrays, we will have a special out argument we can write to instead. + // The backend is responsible for setting this up, and redirection the return values as appropriate. + if (ir.ids[block.return_value].get_type() != TypeUndef) + { + emit_array_copy("spvReturnValue", 0, block.return_value, StorageClassFunction, + get_expression_effective_storage_class(block.return_value)); + } + + if (!cfg.node_terminates_control_flow_in_sub_graph(current_function->entry_block, block.self) || + block.loop_dominator != BlockID(SPIRBlock::NoDominator)) + { + statement("return;"); + } + } + else + { + // OpReturnValue can return Undef, so don't emit anything for this case. 
+ if (ir.ids[block.return_value].get_type() != TypeUndef) + statement("return ", to_unpacked_expression(block.return_value), ";"); + } + } + else if (!cfg.node_terminates_control_flow_in_sub_graph(current_function->entry_block, block.self) || + block.loop_dominator != BlockID(SPIRBlock::NoDominator)) + { + // If this block is the very final block and not called from control flow, + // we do not need an explicit return which looks out of place. Just end the function here. + // In the very weird case of for(;;) { return; } executing return is unconditional, + // but we actually need a return here ... + statement("return;"); + } + break; + } + + // If the Kill is terminating a block with a (probably synthetic) return value, emit a return value statement. + case SPIRBlock::Kill: + statement(backend.discard_literal, ";"); + if (block.return_value) + statement("return ", to_unpacked_expression(block.return_value), ";"); + break; + + case SPIRBlock::Unreachable: + { + // Avoid emitting false fallthrough, which can happen for + // if (cond) break; else discard; inside a case label. + // Discard is not always implementable as a terminator. + + auto &cfg = get_cfg_for_current_function(); + bool inner_dominator_is_switch = false; + ID id = block.self; + + while (id) + { + auto &iter_block = get(id); + if (iter_block.terminator == SPIRBlock::MultiSelect || + iter_block.merge == SPIRBlock::MergeLoop) + { + ID next_block = iter_block.merge == SPIRBlock::MergeLoop ? + iter_block.merge_block : iter_block.next_block; + bool outside_construct = next_block && cfg.find_common_dominator(next_block, block.self) == next_block; + if (!outside_construct) + { + inner_dominator_is_switch = iter_block.terminator == SPIRBlock::MultiSelect; + break; + } + } + + if (cfg.get_preceding_edges(id).empty()) + break; + + id = cfg.get_immediate_dominator(id); + } + + if (inner_dominator_is_switch) + statement("break; // unreachable workaround"); + + emit_next_block = false; + break; + } + + case SPIRBlock::IgnoreIntersection: + statement("ignoreIntersectionEXT;"); + break; + + case SPIRBlock::TerminateRay: + statement("terminateRayEXT;"); + break; + + case SPIRBlock::EmitMeshTasks: + emit_mesh_tasks(block); + break; + + default: + SPIRV_CROSS_THROW("Unimplemented block terminator."); + } + + if (block.next_block && emit_next_block) + { + // If we hit this case, we're dealing with an unconditional branch, which means we will output + // that block after this. If we had selection merge, we already flushed phi variables. + if (block.merge != SPIRBlock::MergeSelection) + { + flush_phi(block.self, block.next_block); + // For a direct branch, need to remember to invalidate expressions in the next linear block instead. + get(block.next_block).invalidate_expressions = block.invalidate_expressions; + } + + // For switch fallthrough cases, we terminate the chain here, but we still need to handle Phi. + if (!current_emitting_switch_fallthrough) + { + // For merge selects we might have ignored the fact that a merge target + // could have been a break; or continue; + // We will need to deal with it here. + if (is_loop_break(block.next_block)) + { + // Cannot check for just break, because switch statements will also use break. 
+ assert(block.merge == SPIRBlock::MergeSelection); + statement("break;"); + } + else if (is_continue(block.next_block)) + { + assert(block.merge == SPIRBlock::MergeSelection); + branch_to_continue(block.self, block.next_block); + } + else if (BlockID(block.self) != block.next_block) + emit_block_chain(get(block.next_block)); + } + } + + if (block.merge == SPIRBlock::MergeLoop) + { + if (continue_type == SPIRBlock::DoWhileLoop) + { + // Make sure that we run the continue block to get the expressions set, but this + // should become an empty string. + // We have no fallbacks if we cannot forward everything to temporaries ... + const auto &continue_block = get(block.continue_block); + bool positive_test = execution_is_noop(get(continue_block.true_block), + get(continue_block.loop_dominator)); + + uint32_t current_count = statement_count; + auto statements = emit_continue_block(block.continue_block, positive_test, !positive_test); + if (statement_count != current_count) + { + // The DoWhile block has side effects, force ComplexLoop pattern next pass. + get(block.continue_block).complex_continue = true; + force_recompile(); + } + + // Might have to invert the do-while test here. + auto condition = to_expression(continue_block.condition); + if (!positive_test) + condition = join("!", enclose_expression(condition)); + + end_scope_decl(join("while (", condition, ")")); + } + else + end_scope(); + + loop_level_saver.release(); + + // We cannot break out of two loops at once, so don't check for break; here. + // Using block.self as the "from" block isn't quite right, but it has the same scope + // and dominance structure, so it's fine. + if (is_continue(block.merge_block)) + branch_to_continue(block.self, block.merge_block); + else + emit_block_chain(get(block.merge_block)); + } + + // Forget about control dependent expressions now. + block.invalidate_expressions.clear(); + + // After we return, we must be out of scope, so if we somehow have to re-emit this function, + // re-declare variables if necessary. + assert(rearm_dominated_variables.size() == block.dominated_variables.size()); + for (size_t i = 0; i < block.dominated_variables.size(); i++) + { + uint32_t var = block.dominated_variables[i]; + get(var).deferred_declaration = rearm_dominated_variables[i]; + } + + // Just like for deferred declaration, we need to forget about loop variable enable + // if our block chain is reinstantiated later. + for (auto &var_id : block.loop_variables) + get(var_id).loop_variable_enable = false; +} + +void CompilerGLSL::begin_scope() +{ + statement("{"); + indent++; +} + +void CompilerGLSL::end_scope() +{ + if (!indent) + SPIRV_CROSS_THROW("Popping empty indent stack."); + indent--; + statement("}"); +} + +void CompilerGLSL::end_scope(const string &trailer) +{ + if (!indent) + SPIRV_CROSS_THROW("Popping empty indent stack."); + indent--; + statement("}", trailer); +} + +void CompilerGLSL::end_scope_decl() +{ + if (!indent) + SPIRV_CROSS_THROW("Popping empty indent stack."); + indent--; + statement("};"); +} + +void CompilerGLSL::end_scope_decl(const string &decl) +{ + if (!indent) + SPIRV_CROSS_THROW("Popping empty indent stack."); + indent--; + statement("} ", decl, ";"); +} + +void CompilerGLSL::check_function_call_constraints(const uint32_t *args, uint32_t length) +{ + // If our variable is remapped, and we rely on type-remapping information as + // well, then we cannot pass the variable as a function parameter. 
+    // Fixing this is non-trivial without stamping out variants of the same function,
+    // so for now warn about this and suggest workarounds instead.
+    for (uint32_t i = 0; i < length; i++)
+    {
+        auto *var = maybe_get<SPIRVariable>(args[i]);
+        if (!var || !var->remapped_variable)
+            continue;
+
+        auto &type = get<SPIRType>(var->basetype);
+        if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData)
+        {
+            SPIRV_CROSS_THROW("Tried passing a remapped subpassInput variable to a function. "
+                              "This will not work correctly because type-remapping information is lost. "
+                              "To workaround, please consider not passing the subpass input as a function parameter, "
+                              "or use in/out variables instead which do not need type remapping information.");
+        }
+    }
+}
+
+const Instruction *CompilerGLSL::get_next_instruction_in_block(const Instruction &instr)
+{
+    // FIXME: This is kind of hacky. There should be a cleaner way.
+    auto offset = uint32_t(&instr - current_emitting_block->ops.data());
+    if ((offset + 1) < current_emitting_block->ops.size())
+        return &current_emitting_block->ops[offset + 1];
+    else
+        return nullptr;
+}
+
+uint32_t CompilerGLSL::mask_relevant_memory_semantics(uint32_t semantics)
+{
+    return semantics & (MemorySemanticsAtomicCounterMemoryMask | MemorySemanticsImageMemoryMask |
+                        MemorySemanticsWorkgroupMemoryMask | MemorySemanticsUniformMemoryMask |
+                        MemorySemanticsCrossWorkgroupMemoryMask | MemorySemanticsSubgroupMemoryMask);
+}
+
+bool CompilerGLSL::emit_array_copy(const char *expr, uint32_t lhs_id, uint32_t rhs_id, StorageClass, StorageClass)
+{
+    string lhs;
+    if (expr)
+        lhs = expr;
+    else
+        lhs = to_expression(lhs_id);
+
+    statement(lhs, " = ", to_expression(rhs_id), ";");
+    return true;
+}
+
+bool CompilerGLSL::unroll_array_to_complex_store(uint32_t target_id, uint32_t source_id)
+{
+    if (!backend.force_gl_in_out_block)
+        return false;
+    // This path is only relevant for GL backends.
+
+    auto *var = maybe_get<SPIRVariable>(target_id);
+    if (!var || var->storage != StorageClassOutput)
+        return false;
+
+    if (!is_builtin_variable(*var) || BuiltIn(get_decoration(var->self, DecorationBuiltIn)) != BuiltInSampleMask)
+        return false;
+
+    auto &type = expression_type(source_id);
+    string array_expr;
+    if (type.array_size_literal.back())
+    {
+        array_expr = convert_to_string(type.array.back());
+        if (type.array.back() == 0)
+            SPIRV_CROSS_THROW("Cannot unroll an array copy from unsized array.");
+    }
+    else
+        array_expr = to_expression(type.array.back());
+
+    SPIRType target_type { OpTypeInt };
+    target_type.basetype = SPIRType::Int;
+
+    statement("for (int i = 0; i < int(", array_expr, "); i++)");
+    begin_scope();
+    statement(to_expression(target_id), "[i] = ",
+              bitcast_expression(target_type, type.basetype, join(to_expression(source_id), "[i]")),
+              ";");
+    end_scope();
+
+    return true;
+}
+
+void CompilerGLSL::unroll_array_from_complex_load(uint32_t target_id, uint32_t source_id, std::string &expr)
+{
+    if (!backend.force_gl_in_out_block)
+        return;
+    // This path is only relevant for GL backends.
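As a point of reference for `mask_relevant_memory_semantics()` above, the sketch below shows the same filtering in isolation: a combined SPIR-V memory-semantics word keeps only the memory-class bits and drops ordering bits such as AcquireRelease. The constants restate the MemorySemantics bit values from the SPIR-V headers; everything else (names, `main()`) is illustrative and not part of this patch.

```cpp
#include <cstdint>
#include <cstdio>

// Bit values as defined for spv::MemorySemantics* masks.
enum : uint32_t
{
	SemAcquireRelease       = 0x0008,
	SemUniformMemory        = 0x0040,
	SemSubgroupMemory       = 0x0080,
	SemWorkgroupMemory      = 0x0100,
	SemCrossWorkgroupMemory = 0x0200,
	SemAtomicCounterMemory  = 0x0400,
	SemImageMemory          = 0x0800,
};

static uint32_t mask_relevant(uint32_t semantics)
{
	// Keep only the bits that say *which* memory classes are affected.
	return semantics & (SemAtomicCounterMemory | SemImageMemory | SemWorkgroupMemory |
	                    SemUniformMemory | SemCrossWorkgroupMemory | SemSubgroupMemory);
}

int main()
{
	uint32_t semantics = SemAcquireRelease | SemWorkgroupMemory | SemImageMemory;
	// Prints 0x0900: the AcquireRelease ordering bit is stripped, the memory classes remain.
	std::printf("0x%04x\n", unsigned(mask_relevant(semantics)));
}
```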
+ + auto *var = maybe_get(source_id); + if (!var) + return; + + if (var->storage != StorageClassInput && var->storage != StorageClassOutput) + return; + + auto &type = get_variable_data_type(*var); + if (type.array.empty()) + return; + + auto builtin = BuiltIn(get_decoration(var->self, DecorationBuiltIn)); + bool is_builtin = is_builtin_variable(*var) && + (builtin == BuiltInPointSize || + builtin == BuiltInPosition || + builtin == BuiltInSampleMask); + bool is_tess = is_tessellation_shader(); + bool is_patch = has_decoration(var->self, DecorationPatch); + bool is_sample_mask = is_builtin && builtin == BuiltInSampleMask; + + // Tessellation input arrays are special in that they are unsized, so we cannot directly copy from it. + // We must unroll the array load. + // For builtins, we couldn't catch this case normally, + // because this is resolved in the OpAccessChain in most cases. + // If we load the entire array, we have no choice but to unroll here. + if (!is_patch && (is_builtin || is_tess)) + { + auto new_expr = join("_", target_id, "_unrolled"); + statement(variable_decl(type, new_expr, target_id), ";"); + string array_expr; + if (type.array_size_literal.back()) + { + array_expr = convert_to_string(type.array.back()); + if (type.array.back() == 0) + SPIRV_CROSS_THROW("Cannot unroll an array copy from unsized array."); + } + else + array_expr = to_expression(type.array.back()); + + // The array size might be a specialization constant, so use a for-loop instead. + statement("for (int i = 0; i < int(", array_expr, "); i++)"); + begin_scope(); + if (is_builtin && !is_sample_mask) + statement(new_expr, "[i] = gl_in[i].", expr, ";"); + else if (is_sample_mask) + { + SPIRType target_type { OpTypeInt }; + target_type.basetype = SPIRType::Int; + statement(new_expr, "[i] = ", bitcast_expression(target_type, type.basetype, join(expr, "[i]")), ";"); + } + else + statement(new_expr, "[i] = ", expr, "[i];"); + end_scope(); + + expr = std::move(new_expr); + } +} + +void CompilerGLSL::cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type) +{ + // We will handle array cases elsewhere. + if (!expr_type.array.empty()) + return; + + auto *var = maybe_get_backing_variable(source_id); + if (var) + source_id = var->self; + + // Only interested in standalone builtin variables. + if (!has_decoration(source_id, DecorationBuiltIn)) + { + // Except for int attributes in legacy GLSL, which are cast from float. + if (is_legacy() && expr_type.basetype == SPIRType::Int && var && var->storage == StorageClassInput) + expr = join(type_to_glsl(expr_type), "(", expr, ")"); + return; + } + + auto builtin = static_cast(get_decoration(source_id, DecorationBuiltIn)); + auto expected_type = expr_type.basetype; + + // TODO: Fill in for more builtins. 
+ switch (builtin) + { + case BuiltInLayer: + case BuiltInPrimitiveId: + case BuiltInViewportIndex: + case BuiltInInstanceId: + case BuiltInInstanceIndex: + case BuiltInVertexId: + case BuiltInVertexIndex: + case BuiltInSampleId: + case BuiltInBaseVertex: + case BuiltInBaseInstance: + case BuiltInDrawIndex: + case BuiltInFragStencilRefEXT: + case BuiltInInstanceCustomIndexNV: + case BuiltInSampleMask: + case BuiltInPrimitiveShadingRateKHR: + case BuiltInShadingRateKHR: + expected_type = SPIRType::Int; + break; + + case BuiltInGlobalInvocationId: + case BuiltInLocalInvocationId: + case BuiltInWorkgroupId: + case BuiltInLocalInvocationIndex: + case BuiltInWorkgroupSize: + case BuiltInNumWorkgroups: + case BuiltInIncomingRayFlagsNV: + case BuiltInLaunchIdNV: + case BuiltInLaunchSizeNV: + case BuiltInPrimitiveTriangleIndicesEXT: + case BuiltInPrimitiveLineIndicesEXT: + case BuiltInPrimitivePointIndicesEXT: + expected_type = SPIRType::UInt; + break; + + default: + break; + } + + if (expected_type != expr_type.basetype) + expr = bitcast_expression(expr_type, expected_type, expr); +} + +SPIRType::BaseType CompilerGLSL::get_builtin_basetype(BuiltIn builtin, SPIRType::BaseType default_type) +{ + // TODO: Fill in for more builtins. + switch (builtin) + { + case BuiltInLayer: + case BuiltInPrimitiveId: + case BuiltInViewportIndex: + case BuiltInFragStencilRefEXT: + case BuiltInSampleMask: + case BuiltInPrimitiveShadingRateKHR: + case BuiltInShadingRateKHR: + return SPIRType::Int; + + default: + return default_type; + } +} + +void CompilerGLSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) +{ + auto *var = maybe_get_backing_variable(target_id); + if (var) + target_id = var->self; + + // Only interested in standalone builtin variables. + if (!has_decoration(target_id, DecorationBuiltIn)) + return; + + auto builtin = static_cast(get_decoration(target_id, DecorationBuiltIn)); + auto expected_type = get_builtin_basetype(builtin, expr_type.basetype); + + if (expected_type != expr_type.basetype) + { + auto type = expr_type; + type.basetype = expected_type; + expr = bitcast_expression(type, expr_type.basetype, expr); + } +} + +void CompilerGLSL::convert_non_uniform_expression(string &expr, uint32_t ptr_id) +{ + if (*backend.nonuniform_qualifier == '\0') + return; + + auto *var = maybe_get_backing_variable(ptr_id); + if (!var) + return; + + if (var->storage != StorageClassUniformConstant && + var->storage != StorageClassStorageBuffer && + var->storage != StorageClassUniform) + return; + + auto &backing_type = get(var->basetype); + if (backing_type.array.empty()) + return; + + // If we get here, we know we're accessing an arrayed resource which + // might require nonuniform qualifier. + + auto start_array_index = expr.find_first_of('['); + + if (start_array_index == string::npos) + return; + + // We've opened a bracket, track expressions until we can close the bracket. + // This must be our resource index. + size_t end_array_index = string::npos; + unsigned bracket_count = 1; + for (size_t index = start_array_index + 1; index < expr.size(); index++) + { + if (expr[index] == ']') + { + if (--bracket_count == 0) + { + end_array_index = index; + break; + } + } + else if (expr[index] == '[') + bracket_count++; + } + + assert(bracket_count == 0); + + // Doesn't really make sense to declare a non-arrayed image with nonuniformEXT, but there's + // nothing we can do here to express that. 
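The bracket scan above finds the innermost resource index, and the `join()` that follows wraps that index in the backend's nonuniform qualifier. A standalone sketch of that string transformation, with an illustrative expression and qualifier name:

```cpp
#include <cstdio>
#include <string>

static std::string wrap_first_index(const std::string &expr, const std::string &qualifier)
{
	// Find the opening bracket of the resource index.
	size_t start = expr.find_first_of('[');
	if (start == std::string::npos)
		return expr;

	// Scan forward, tracking nested brackets, until the matching ']' is found.
	size_t end = std::string::npos;
	unsigned depth = 1;
	for (size_t i = start + 1; i < expr.size(); i++)
	{
		if (expr[i] == '[')
			depth++;
		else if (expr[i] == ']' && --depth == 0)
		{
			end = i;
			break;
		}
	}

	if (end == std::string::npos)
		return expr;

	// Rebuild the expression with the index wrapped: a[i].b -> a[qualifier(i)].b
	return expr.substr(0, start + 1) + qualifier + "(" +
	       expr.substr(start + 1, end - (start + 1)) + ")" + expr.substr(end);
}

int main()
{
	// Prints: uTextures[nonuniformEXT(index + 1)].data[2]
	std::printf("%s\n", wrap_first_index("uTextures[index + 1].data[2]", "nonuniformEXT").c_str());
}
```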
+ if (start_array_index == string::npos || end_array_index == string::npos || end_array_index < start_array_index) + return; + + start_array_index++; + + expr = join(expr.substr(0, start_array_index), backend.nonuniform_qualifier, "(", + expr.substr(start_array_index, end_array_index - start_array_index), ")", + expr.substr(end_array_index, string::npos)); +} + +void CompilerGLSL::emit_block_hints(const SPIRBlock &block) +{ + if ((options.es && options.version < 310) || (!options.es && options.version < 140)) + return; + + switch (block.hint) + { + case SPIRBlock::HintFlatten: + require_extension_internal("GL_EXT_control_flow_attributes"); + statement("SPIRV_CROSS_FLATTEN"); + break; + case SPIRBlock::HintDontFlatten: + require_extension_internal("GL_EXT_control_flow_attributes"); + statement("SPIRV_CROSS_BRANCH"); + break; + case SPIRBlock::HintUnroll: + require_extension_internal("GL_EXT_control_flow_attributes"); + statement("SPIRV_CROSS_UNROLL"); + break; + case SPIRBlock::HintDontUnroll: + require_extension_internal("GL_EXT_control_flow_attributes"); + statement("SPIRV_CROSS_LOOP"); + break; + default: + break; + } +} + +void CompilerGLSL::preserve_alias_on_reset(uint32_t id) +{ + preserved_aliases[id] = get_name(id); +} + +void CompilerGLSL::reset_name_caches() +{ + for (auto &preserved : preserved_aliases) + set_name(preserved.first, preserved.second); + + preserved_aliases.clear(); + resource_names.clear(); + block_input_names.clear(); + block_output_names.clear(); + block_ubo_names.clear(); + block_ssbo_names.clear(); + block_names.clear(); + function_overloads.clear(); +} + +void CompilerGLSL::fixup_anonymous_struct_names(std::unordered_set &visited, const SPIRType &type) +{ + if (visited.count(type.self)) + return; + visited.insert(type.self); + + for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++) + { + auto &mbr_type = get(type.member_types[i]); + + if (mbr_type.basetype == SPIRType::Struct) + { + // If there are multiple aliases, the output might be somewhat unpredictable, + // but the only real alternative in that case is to do nothing, which isn't any better. + // This check should be fine in practice. + if (get_name(mbr_type.self).empty() && !get_member_name(type.self, i).empty()) + { + auto anon_name = join("anon_", get_member_name(type.self, i)); + ParsedIR::sanitize_underscores(anon_name); + set_name(mbr_type.self, anon_name); + } + + fixup_anonymous_struct_names(visited, mbr_type); + } + } +} + +void CompilerGLSL::fixup_anonymous_struct_names() +{ + // HLSL codegen can often end up emitting anonymous structs inside blocks, which + // breaks GL linking since all names must match ... + // Try to emit sensible code, so attempt to find such structs and emit anon_$member. + + // Breaks exponential explosion with weird type trees. + std::unordered_set visited; + + ir.for_each_typed_id([&](uint32_t, SPIRType &type) { + if (type.basetype == SPIRType::Struct && + (has_decoration(type.self, DecorationBlock) || + has_decoration(type.self, DecorationBufferBlock))) + { + fixup_anonymous_struct_names(visited, type); + } + }); +} + +void CompilerGLSL::fixup_type_alias() +{ + // Due to how some backends work, the "master" type of type_alias must be a block-like type if it exists. + ir.for_each_typed_id([&](uint32_t self, SPIRType &type) { + if (!type.type_alias) + return; + + if (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)) + { + // Top-level block types should never alias anything else. 
+ type.type_alias = 0; + } + else if (type_is_block_like(type) && type.self == ID(self)) + { + // A block-like type is any type which contains Offset decoration, but not top-level blocks, + // i.e. blocks which are placed inside buffers. + // Become the master. + ir.for_each_typed_id([&](uint32_t other_id, SPIRType &other_type) { + if (other_id == self) + return; + + if (other_type.type_alias == type.type_alias) + other_type.type_alias = self; + }); + + this->get(type.type_alias).type_alias = self; + type.type_alias = 0; + } + }); +} + +void CompilerGLSL::reorder_type_alias() +{ + // Reorder declaration of types so that the master of the type alias is always emitted first. + // We need this in case a type B depends on type A (A must come before in the vector), but A is an alias of a type Abuffer, which + // means declaration of A doesn't happen (yet), and order would be B, ABuffer and not ABuffer, B. Fix this up here. + auto loop_lock = ir.create_loop_hard_lock(); + + auto &type_ids = ir.ids_for_type[TypeType]; + for (auto alias_itr = begin(type_ids); alias_itr != end(type_ids); ++alias_itr) + { + auto &type = get(*alias_itr); + if (type.type_alias != TypeID(0) && + !has_extended_decoration(type.type_alias, SPIRVCrossDecorationBufferBlockRepacked)) + { + // We will skip declaring this type, so make sure the type_alias type comes before. + auto master_itr = find(begin(type_ids), end(type_ids), ID(type.type_alias)); + assert(master_itr != end(type_ids)); + + if (alias_itr < master_itr) + { + // Must also swap the type order for the constant-type joined array. + auto &joined_types = ir.ids_for_constant_undef_or_type; + auto alt_alias_itr = find(begin(joined_types), end(joined_types), *alias_itr); + auto alt_master_itr = find(begin(joined_types), end(joined_types), *master_itr); + assert(alt_alias_itr != end(joined_types)); + assert(alt_master_itr != end(joined_types)); + + swap(*alias_itr, *master_itr); + swap(*alt_alias_itr, *alt_master_itr); + } + } + } +} + +void CompilerGLSL::emit_line_directive(uint32_t file_id, uint32_t line_literal) +{ + // If we are redirecting statements, ignore the line directive. + // Common case here is continue blocks. + if (redirect_statement) + return; + + // If we're emitting code in a sensitive context such as condition blocks in for loops, don't emit + // any line directives, because it's not possible. + if (block_debug_directives) + return; + + if (options.emit_line_directives) + { + require_extension_internal("GL_GOOGLE_cpp_style_line_directive"); + statement_no_indent("#line ", line_literal, " \"", get(file_id).str, "\""); + } +} + +void CompilerGLSL::emit_copy_logical_type(uint32_t lhs_id, uint32_t lhs_type_id, uint32_t rhs_id, uint32_t rhs_type_id, + SmallVector chain) +{ + // Fully unroll all member/array indices one by one. + + auto &lhs_type = get(lhs_type_id); + auto &rhs_type = get(rhs_type_id); + + if (!lhs_type.array.empty()) + { + // Could use a loop here to support specialization constants, but it gets rather complicated with nested array types, + // and this is a rather obscure opcode anyways, keep it simple unless we are forced to. 
+ uint32_t array_size = to_array_size_literal(lhs_type); + chain.push_back(0); + + for (uint32_t i = 0; i < array_size; i++) + { + chain.back() = i; + emit_copy_logical_type(lhs_id, lhs_type.parent_type, rhs_id, rhs_type.parent_type, chain); + } + } + else if (lhs_type.basetype == SPIRType::Struct) + { + chain.push_back(0); + uint32_t member_count = uint32_t(lhs_type.member_types.size()); + for (uint32_t i = 0; i < member_count; i++) + { + chain.back() = i; + emit_copy_logical_type(lhs_id, lhs_type.member_types[i], rhs_id, rhs_type.member_types[i], chain); + } + } + else + { + // Need to handle unpack/packing fixups since this can differ wildly between the logical types, + // particularly in MSL. + // To deal with this, we emit access chains and go through emit_store_statement + // to deal with all the special cases we can encounter. + + AccessChainMeta lhs_meta, rhs_meta; + auto lhs = access_chain_internal(lhs_id, chain.data(), uint32_t(chain.size()), + ACCESS_CHAIN_INDEX_IS_LITERAL_BIT, &lhs_meta); + auto rhs = access_chain_internal(rhs_id, chain.data(), uint32_t(chain.size()), + ACCESS_CHAIN_INDEX_IS_LITERAL_BIT, &rhs_meta); + + uint32_t id = ir.increase_bound_by(2); + lhs_id = id; + rhs_id = id + 1; + + { + auto &lhs_expr = set(lhs_id, std::move(lhs), lhs_type_id, true); + lhs_expr.need_transpose = lhs_meta.need_transpose; + + if (lhs_meta.storage_is_packed) + set_extended_decoration(lhs_id, SPIRVCrossDecorationPhysicalTypePacked); + if (lhs_meta.storage_physical_type != 0) + set_extended_decoration(lhs_id, SPIRVCrossDecorationPhysicalTypeID, lhs_meta.storage_physical_type); + + forwarded_temporaries.insert(lhs_id); + suppressed_usage_tracking.insert(lhs_id); + } + + { + auto &rhs_expr = set(rhs_id, std::move(rhs), rhs_type_id, true); + rhs_expr.need_transpose = rhs_meta.need_transpose; + + if (rhs_meta.storage_is_packed) + set_extended_decoration(rhs_id, SPIRVCrossDecorationPhysicalTypePacked); + if (rhs_meta.storage_physical_type != 0) + set_extended_decoration(rhs_id, SPIRVCrossDecorationPhysicalTypeID, rhs_meta.storage_physical_type); + + forwarded_temporaries.insert(rhs_id); + suppressed_usage_tracking.insert(rhs_id); + } + + emit_store_statement(lhs_id, rhs_id); + } +} + +bool CompilerGLSL::subpass_input_is_framebuffer_fetch(uint32_t id) const +{ + if (!has_decoration(id, DecorationInputAttachmentIndex)) + return false; + + uint32_t input_attachment_index = get_decoration(id, DecorationInputAttachmentIndex); + for (auto &remap : subpass_to_framebuffer_fetch_attachment) + if (remap.first == input_attachment_index) + return true; + + return false; +} + +const SPIRVariable *CompilerGLSL::find_subpass_input_by_attachment_index(uint32_t index) const +{ + const SPIRVariable *ret = nullptr; + ir.for_each_typed_id([&](uint32_t, const SPIRVariable &var) { + if (has_decoration(var.self, DecorationInputAttachmentIndex) && + get_decoration(var.self, DecorationInputAttachmentIndex) == index) + { + ret = &var; + } + }); + return ret; +} + +const SPIRVariable *CompilerGLSL::find_color_output_by_location(uint32_t location) const +{ + const SPIRVariable *ret = nullptr; + ir.for_each_typed_id([&](uint32_t, const SPIRVariable &var) { + if (var.storage == StorageClassOutput && get_decoration(var.self, DecorationLocation) == location) + ret = &var; + }); + return ret; +} + +void CompilerGLSL::emit_inout_fragment_outputs_copy_to_subpass_inputs() +{ + for (auto &remap : subpass_to_framebuffer_fetch_attachment) + { + auto *subpass_var = find_subpass_input_by_attachment_index(remap.first); + auto 
*output_var = find_color_output_by_location(remap.second); + if (!subpass_var) + continue; + if (!output_var) + SPIRV_CROSS_THROW("Need to declare the corresponding fragment output variable to be able " + "to read from it."); + if (is_array(get(output_var->basetype))) + SPIRV_CROSS_THROW("Cannot use GL_EXT_shader_framebuffer_fetch with arrays of color outputs."); + + auto &func = get(get_entry_point().self); + func.fixup_hooks_in.push_back([=]() { + if (is_legacy()) + { + statement(to_expression(subpass_var->self), " = ", "gl_LastFragData[", + get_decoration(output_var->self, DecorationLocation), "];"); + } + else + { + uint32_t num_rt_components = this->get(output_var->basetype).vecsize; + statement(to_expression(subpass_var->self), vector_swizzle(num_rt_components, 0), " = ", + to_expression(output_var->self), ";"); + } + }); + } +} + +bool CompilerGLSL::variable_is_depth_or_compare(VariableID id) const +{ + return is_depth_image(get(get(id).basetype), id); +} + +const char *CompilerGLSL::ShaderSubgroupSupportHelper::get_extension_name(Candidate c) +{ + static const char *const retval[CandidateCount] = { "GL_KHR_shader_subgroup_ballot", + "GL_KHR_shader_subgroup_basic", + "GL_KHR_shader_subgroup_vote", + "GL_KHR_shader_subgroup_arithmetic", + "GL_NV_gpu_shader_5", + "GL_NV_shader_thread_group", + "GL_NV_shader_thread_shuffle", + "GL_ARB_shader_ballot", + "GL_ARB_shader_group_vote", + "GL_AMD_gcn_shader" }; + return retval[c]; +} + +SmallVector CompilerGLSL::ShaderSubgroupSupportHelper::get_extra_required_extension_names(Candidate c) +{ + switch (c) + { + case ARB_shader_ballot: + return { "GL_ARB_shader_int64" }; + case AMD_gcn_shader: + return { "GL_AMD_gpu_shader_int64", "GL_NV_gpu_shader5" }; + default: + return {}; + } +} + +const char *CompilerGLSL::ShaderSubgroupSupportHelper::get_extra_required_extension_predicate(Candidate c) +{ + switch (c) + { + case ARB_shader_ballot: + return "defined(GL_ARB_shader_int64)"; + case AMD_gcn_shader: + return "(defined(GL_AMD_gpu_shader_int64) || defined(GL_NV_gpu_shader5))"; + default: + return ""; + } +} + +CompilerGLSL::ShaderSubgroupSupportHelper::FeatureVector CompilerGLSL::ShaderSubgroupSupportHelper:: + get_feature_dependencies(Feature feature) +{ + switch (feature) + { + case SubgroupAllEqualT: + return { SubgroupBroadcast_First, SubgroupAll_Any_AllEqualBool }; + case SubgroupElect: + return { SubgroupBallotFindLSB_MSB, SubgroupBallot, SubgroupInvocationID }; + case SubgroupInverseBallot_InclBitCount_ExclBitCout: + return { SubgroupMask }; + case SubgroupBallotBitCount: + return { SubgroupBallot }; + case SubgroupArithmeticIAddReduce: + case SubgroupArithmeticIAddInclusiveScan: + case SubgroupArithmeticFAddReduce: + case SubgroupArithmeticFAddInclusiveScan: + case SubgroupArithmeticIMulReduce: + case SubgroupArithmeticIMulInclusiveScan: + case SubgroupArithmeticFMulReduce: + case SubgroupArithmeticFMulInclusiveScan: + return { SubgroupSize, SubgroupBallot, SubgroupBallotBitCount, SubgroupMask, SubgroupBallotBitExtract }; + case SubgroupArithmeticIAddExclusiveScan: + case SubgroupArithmeticFAddExclusiveScan: + case SubgroupArithmeticIMulExclusiveScan: + case SubgroupArithmeticFMulExclusiveScan: + return { SubgroupSize, SubgroupBallot, SubgroupBallotBitCount, + SubgroupMask, SubgroupElect, SubgroupBallotBitExtract }; + default: + return {}; + } +} + +CompilerGLSL::ShaderSubgroupSupportHelper::FeatureMask CompilerGLSL::ShaderSubgroupSupportHelper:: + get_feature_dependency_mask(Feature feature) +{ + return 
build_mask(get_feature_dependencies(feature)); +} + +bool CompilerGLSL::ShaderSubgroupSupportHelper::can_feature_be_implemented_without_extensions(Feature feature) +{ + static const bool retval[FeatureCount] = { + false, false, false, false, false, false, + true, // SubgroupBalloFindLSB_MSB + false, false, false, false, + true, // SubgroupMemBarrier - replaced with workgroup memory barriers + false, false, true, false, + false, false, false, false, false, false, // iadd, fadd + false, false, false, false, false, false, // imul , fmul + }; + + return retval[feature]; +} + +CompilerGLSL::ShaderSubgroupSupportHelper::Candidate CompilerGLSL::ShaderSubgroupSupportHelper:: + get_KHR_extension_for_feature(Feature feature) +{ + static const Candidate extensions[FeatureCount] = { + KHR_shader_subgroup_ballot, KHR_shader_subgroup_basic, KHR_shader_subgroup_basic, KHR_shader_subgroup_basic, + KHR_shader_subgroup_basic, KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot, KHR_shader_subgroup_vote, + KHR_shader_subgroup_vote, KHR_shader_subgroup_basic, KHR_shader_subgroup_basic, KHR_shader_subgroup_basic, + KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot, + KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, + KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, + KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, + KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, + }; + + return extensions[feature]; +} + +void CompilerGLSL::ShaderSubgroupSupportHelper::request_feature(Feature feature) +{ + feature_mask |= (FeatureMask(1) << feature) | get_feature_dependency_mask(feature); +} + +bool CompilerGLSL::ShaderSubgroupSupportHelper::is_feature_requested(Feature feature) const +{ + return (feature_mask & (1u << feature)) != 0; +} + +CompilerGLSL::ShaderSubgroupSupportHelper::Result CompilerGLSL::ShaderSubgroupSupportHelper::resolve() const +{ + Result res; + + for (uint32_t i = 0u; i < FeatureCount; ++i) + { + if (feature_mask & (1u << i)) + { + auto feature = static_cast(i); + std::unordered_set unique_candidates; + + auto candidates = get_candidates_for_feature(feature); + unique_candidates.insert(candidates.begin(), candidates.end()); + + auto deps = get_feature_dependencies(feature); + for (Feature d : deps) + { + candidates = get_candidates_for_feature(d); + if (!candidates.empty()) + unique_candidates.insert(candidates.begin(), candidates.end()); + } + + for (uint32_t c : unique_candidates) + ++res.weights[static_cast(c)]; + } + } + + return res; +} + +CompilerGLSL::ShaderSubgroupSupportHelper::CandidateVector CompilerGLSL::ShaderSubgroupSupportHelper:: + get_candidates_for_feature(Feature ft, const Result &r) +{ + auto c = get_candidates_for_feature(ft); + auto cmp = [&r](Candidate a, Candidate b) { + if (r.weights[a] == r.weights[b]) + return a < b; // Prefer candidates with lower enum value + return r.weights[a] > r.weights[b]; + }; + std::sort(c.begin(), c.end(), cmp); + return c; +} + +CompilerGLSL::ShaderSubgroupSupportHelper::CandidateVector CompilerGLSL::ShaderSubgroupSupportHelper:: + get_candidates_for_feature(Feature feature) +{ + switch (feature) + { + case SubgroupMask: + return { KHR_shader_subgroup_ballot, NV_shader_thread_group, ARB_shader_ballot }; + case SubgroupSize: + return { KHR_shader_subgroup_basic, NV_shader_thread_group, AMD_gcn_shader, 
ARB_shader_ballot }; + case SubgroupInvocationID: + return { KHR_shader_subgroup_basic, NV_shader_thread_group, ARB_shader_ballot }; + case SubgroupID: + return { KHR_shader_subgroup_basic, NV_shader_thread_group }; + case NumSubgroups: + return { KHR_shader_subgroup_basic, NV_shader_thread_group }; + case SubgroupBroadcast_First: + return { KHR_shader_subgroup_ballot, NV_shader_thread_shuffle, ARB_shader_ballot }; + case SubgroupBallotFindLSB_MSB: + return { KHR_shader_subgroup_ballot, NV_shader_thread_group }; + case SubgroupAll_Any_AllEqualBool: + return { KHR_shader_subgroup_vote, NV_gpu_shader_5, ARB_shader_group_vote, AMD_gcn_shader }; + case SubgroupAllEqualT: + return {}; // depends on other features only + case SubgroupElect: + return {}; // depends on other features only + case SubgroupBallot: + return { KHR_shader_subgroup_ballot, NV_shader_thread_group, ARB_shader_ballot }; + case SubgroupBarrier: + return { KHR_shader_subgroup_basic, NV_shader_thread_group, ARB_shader_ballot, AMD_gcn_shader }; + case SubgroupMemBarrier: + return { KHR_shader_subgroup_basic }; + case SubgroupInverseBallot_InclBitCount_ExclBitCout: + return {}; + case SubgroupBallotBitExtract: + return { NV_shader_thread_group }; + case SubgroupBallotBitCount: + return {}; + case SubgroupArithmeticIAddReduce: + case SubgroupArithmeticIAddExclusiveScan: + case SubgroupArithmeticIAddInclusiveScan: + case SubgroupArithmeticFAddReduce: + case SubgroupArithmeticFAddExclusiveScan: + case SubgroupArithmeticFAddInclusiveScan: + case SubgroupArithmeticIMulReduce: + case SubgroupArithmeticIMulExclusiveScan: + case SubgroupArithmeticIMulInclusiveScan: + case SubgroupArithmeticFMulReduce: + case SubgroupArithmeticFMulExclusiveScan: + case SubgroupArithmeticFMulInclusiveScan: + return { KHR_shader_subgroup_arithmetic, NV_shader_thread_shuffle }; + default: + return {}; + } +} + +CompilerGLSL::ShaderSubgroupSupportHelper::FeatureMask CompilerGLSL::ShaderSubgroupSupportHelper::build_mask( + const SmallVector &features) +{ + FeatureMask mask = 0; + for (Feature f : features) + mask |= FeatureMask(1) << f; + return mask; +} + +CompilerGLSL::ShaderSubgroupSupportHelper::Result::Result() +{ + for (auto &weight : weights) + weight = 0; + + // Make sure KHR_shader_subgroup extensions are always prefered. + const uint32_t big_num = FeatureCount; + weights[KHR_shader_subgroup_ballot] = big_num; + weights[KHR_shader_subgroup_basic] = big_num; + weights[KHR_shader_subgroup_vote] = big_num; + weights[KHR_shader_subgroup_arithmetic] = big_num; +} + +void CompilerGLSL::request_workaround_wrapper_overload(TypeID id) +{ + // Must be ordered to maintain deterministic output, so vector is appropriate. + if (find(begin(workaround_ubo_load_overload_types), end(workaround_ubo_load_overload_types), id) == + end(workaround_ubo_load_overload_types)) + { + force_recompile(); + workaround_ubo_load_overload_types.push_back(id); + } +} + +void CompilerGLSL::rewrite_load_for_wrapped_row_major(std::string &expr, TypeID loaded_type, ID ptr) +{ + // Loading row-major matrices from UBOs on older AMD Windows OpenGL drivers is problematic. + // To load these types correctly, we must first wrap them in a dummy function which only purpose is to + // ensure row_major decoration is actually respected. 
+ auto *var = maybe_get_backing_variable(ptr); + if (!var) + return; + + auto &backing_type = get(var->basetype); + bool is_ubo = backing_type.basetype == SPIRType::Struct && backing_type.storage == StorageClassUniform && + has_decoration(backing_type.self, DecorationBlock); + if (!is_ubo) + return; + + auto *type = &get(loaded_type); + bool rewrite = false; + bool relaxed = options.es; + + if (is_matrix(*type)) + { + // To avoid adding a lot of unnecessary meta tracking to forward the row_major state, + // we will simply look at the base struct itself. It is exceptionally rare to mix and match row-major/col-major state. + // If there is any row-major action going on, we apply the workaround. + // It is harmless to apply the workaround to column-major matrices, so this is still a valid solution. + // If an access chain occurred, the workaround is not required, so loading vectors or scalars don't need workaround. + type = &backing_type; + } + else + { + // If we're loading a composite, we don't have overloads like these. + relaxed = false; + } + + if (type->basetype == SPIRType::Struct) + { + // If we're loading a struct where any member is a row-major matrix, apply the workaround. + for (uint32_t i = 0; i < uint32_t(type->member_types.size()); i++) + { + auto decorations = combined_decoration_for_member(*type, i); + if (decorations.get(DecorationRowMajor)) + rewrite = true; + + // Since we decide on a per-struct basis, only use mediump wrapper if all candidates are mediump. + if (!decorations.get(DecorationRelaxedPrecision)) + relaxed = false; + } + } + + if (rewrite) + { + request_workaround_wrapper_overload(loaded_type); + expr = join("spvWorkaroundRowMajor", (relaxed ? "MP" : ""), "(", expr, ")"); + } +} + +void CompilerGLSL::mask_stage_output_by_location(uint32_t location, uint32_t component) +{ + masked_output_locations.insert({ location, component }); +} + +void CompilerGLSL::mask_stage_output_by_builtin(BuiltIn builtin) +{ + masked_output_builtins.insert(builtin); +} + +bool CompilerGLSL::is_stage_output_variable_masked(const SPIRVariable &var) const +{ + auto &type = get(var.basetype); + bool is_block = has_decoration(type.self, DecorationBlock); + // Blocks by themselves are never masked. Must be masked per-member. 
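The masking bookkeeping introduced here boils down to two small sets: user outputs keyed by (location, component) pairs and built-ins keyed by their enum value. A minimal illustrative sketch, standing in for the actual SPIRV-Cross containers:

```cpp
#include <cstdint>
#include <cstdio>
#include <set>
#include <utility>

struct OutputMask
{
	std::set<std::pair<uint32_t, uint32_t>> masked_locations; // (location, component)
	std::set<uint32_t> masked_builtins;                       // spv::BuiltIn values

	void mask_location(uint32_t location, uint32_t component) { masked_locations.insert({ location, component }); }
	void mask_builtin(uint32_t builtin) { masked_builtins.insert(builtin); }

	bool is_location_masked(uint32_t location, uint32_t component) const
	{
		return masked_locations.count({ location, component }) != 0;
	}
	bool is_builtin_masked(uint32_t builtin) const { return masked_builtins.count(builtin) != 0; }
};

int main()
{
	OutputMask mask;
	mask.mask_location(1, 0);
	// Location 1, component 0 is masked; component 1 at the same location is not. Prints: 1 0
	std::printf("%d %d\n", int(mask.is_location_masked(1, 0)), int(mask.is_location_masked(1, 1)));
}
```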
+ if (is_block) + return false; + + bool is_builtin = has_decoration(var.self, DecorationBuiltIn); + + if (is_builtin) + { + return is_stage_output_builtin_masked(BuiltIn(get_decoration(var.self, DecorationBuiltIn))); + } + else + { + if (!has_decoration(var.self, DecorationLocation)) + return false; + + return is_stage_output_location_masked( + get_decoration(var.self, DecorationLocation), + get_decoration(var.self, DecorationComponent)); + } +} + +bool CompilerGLSL::is_stage_output_block_member_masked(const SPIRVariable &var, uint32_t index, bool strip_array) const +{ + auto &type = get(var.basetype); + bool is_block = has_decoration(type.self, DecorationBlock); + if (!is_block) + return false; + + BuiltIn builtin = BuiltInMax; + if (is_member_builtin(type, index, &builtin)) + { + return is_stage_output_builtin_masked(builtin); + } + else + { + uint32_t location = get_declared_member_location(var, index, strip_array); + uint32_t component = get_member_decoration(type.self, index, DecorationComponent); + return is_stage_output_location_masked(location, component); + } +} + +bool CompilerGLSL::is_per_primitive_variable(const SPIRVariable &var) const +{ + if (has_decoration(var.self, DecorationPerPrimitiveEXT)) + return true; + + auto &type = get(var.basetype); + if (!has_decoration(type.self, DecorationBlock)) + return false; + + for (uint32_t i = 0, n = uint32_t(type.member_types.size()); i < n; i++) + if (!has_member_decoration(type.self, i, DecorationPerPrimitiveEXT)) + return false; + + return true; +} + +bool CompilerGLSL::is_stage_output_location_masked(uint32_t location, uint32_t component) const +{ + return masked_output_locations.count({ location, component }) != 0; +} + +bool CompilerGLSL::is_stage_output_builtin_masked(spv::BuiltIn builtin) const +{ + return masked_output_builtins.count(builtin) != 0; +} + +uint32_t CompilerGLSL::get_declared_member_location(const SPIRVariable &var, uint32_t mbr_idx, bool strip_array) const +{ + auto &block_type = get(var.basetype); + if (has_member_decoration(block_type.self, mbr_idx, DecorationLocation)) + return get_member_decoration(block_type.self, mbr_idx, DecorationLocation); + else + return get_accumulated_member_location(var, mbr_idx, strip_array); +} + +uint32_t CompilerGLSL::get_accumulated_member_location(const SPIRVariable &var, uint32_t mbr_idx, bool strip_array) const +{ + auto &type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var); + uint32_t location = get_decoration(var.self, DecorationLocation); + + for (uint32_t i = 0; i < mbr_idx; i++) + { + auto &mbr_type = get(type.member_types[i]); + + // Start counting from any place we have a new location decoration. + if (has_member_decoration(type.self, mbr_idx, DecorationLocation)) + location = get_member_decoration(type.self, mbr_idx, DecorationLocation); + + uint32_t location_count = type_to_location_count(mbr_type); + location += location_count; + } + + return location; +} + +StorageClass CompilerGLSL::get_expression_effective_storage_class(uint32_t ptr) +{ + auto *var = maybe_get_backing_variable(ptr); + + // If the expression has been lowered to a temporary, we need to use the Generic storage class. + // We're looking for the effective storage class of a given expression. + // An access chain or forwarded OpLoads from such access chains + // will generally have the storage class of the underlying variable, but if the load was not forwarded + // we have lost any address space qualifiers. 
+ bool forced_temporary = ir.ids[ptr].get_type() == TypeExpression && !get(ptr).access_chain && + (forced_temporaries.count(ptr) != 0 || forwarded_temporaries.count(ptr) == 0); + + if (var && !forced_temporary) + { + if (variable_decl_is_remapped_storage(*var, StorageClassWorkgroup)) + return StorageClassWorkgroup; + if (variable_decl_is_remapped_storage(*var, StorageClassStorageBuffer)) + return StorageClassStorageBuffer; + + // Normalize SSBOs to StorageBuffer here. + if (var->storage == StorageClassUniform && + has_decoration(get(var->basetype).self, DecorationBufferBlock)) + return StorageClassStorageBuffer; + else + return var->storage; + } + else + return expression_type(ptr).storage; +} + +uint32_t CompilerGLSL::type_to_location_count(const SPIRType &type) const +{ + uint32_t count; + if (type.basetype == SPIRType::Struct) + { + uint32_t mbr_count = uint32_t(type.member_types.size()); + count = 0; + for (uint32_t i = 0; i < mbr_count; i++) + count += type_to_location_count(get(type.member_types[i])); + } + else + { + count = type.columns > 1 ? type.columns : 1; + } + + uint32_t dim_count = uint32_t(type.array.size()); + for (uint32_t i = 0; i < dim_count; i++) + count *= to_array_size_literal(type, i); + + return count; +} + +std::string CompilerGLSL::format_float(float value) const +{ + if (float_formatter) + return float_formatter->format_float(value); + + // default behavior + return convert_to_string(value, current_locale_radix_character); +} + +std::string CompilerGLSL::format_double(double value) const +{ + if (float_formatter) + return float_formatter->format_double(value); + + // default behavior + return convert_to_string(value, current_locale_radix_character); +} + diff --git a/thirdparty/spirv-cross/spirv_glsl.hpp b/thirdparty/spirv-cross/spirv_glsl.hpp new file mode 100644 index 00000000000..8a00263234d --- /dev/null +++ b/thirdparty/spirv-cross/spirv_glsl.hpp @@ -0,0 +1,1074 @@ +/* + * Copyright 2015-2021 Arm Limited + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . 
+ */ + +#ifndef SPIRV_CROSS_GLSL_HPP +#define SPIRV_CROSS_GLSL_HPP + +#include "GLSL.std.450.h" +#include "spirv_cross.hpp" +#include +#include +#include + +namespace SPIRV_CROSS_NAMESPACE +{ +enum PlsFormat +{ + PlsNone = 0, + + PlsR11FG11FB10F, + PlsR32F, + PlsRG16F, + PlsRGB10A2, + PlsRGBA8, + PlsRG16, + + PlsRGBA8I, + PlsRG16I, + + PlsRGB10A2UI, + PlsRGBA8UI, + PlsRG16UI, + PlsR32UI +}; + +struct PlsRemap +{ + uint32_t id; + PlsFormat format; +}; + +enum AccessChainFlagBits +{ + ACCESS_CHAIN_INDEX_IS_LITERAL_BIT = 1 << 0, + ACCESS_CHAIN_CHAIN_ONLY_BIT = 1 << 1, + ACCESS_CHAIN_PTR_CHAIN_BIT = 1 << 2, + ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT = 1 << 3, + ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4, + ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5, + ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6, + ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7, + ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT = 1 << 8 +}; +typedef uint32_t AccessChainFlags; + +class CompilerGLSL : public Compiler +{ +public: + struct Options + { + // The shading language version. Corresponds to #version $VALUE. + uint32_t version = 450; + + // Emit the OpenGL ES shading language instead of desktop OpenGL. + bool es = false; + + // Debug option to always emit temporary variables for all expressions. + bool force_temporary = false; + // Debug option, can be increased in an attempt to workaround SPIRV-Cross bugs temporarily. + // If this limit has to be increased, it points to an implementation bug. + // In certain scenarios, the maximum number of debug iterations may increase beyond this limit + // as long as we can prove we're making certain kinds of forward progress. + uint32_t force_recompile_max_debug_iterations = 3; + + // If true, Vulkan GLSL features are used instead of GL-compatible features. + // Mostly useful for debugging SPIR-V files. + bool vulkan_semantics = false; + + // If true, gl_PerVertex is explicitly redeclared in vertex, geometry and tessellation shaders. + // The members of gl_PerVertex is determined by which built-ins are declared by the shader. + // This option is ignored in ES versions, as redeclaration in ES is not required, and it depends on a different extension + // (EXT_shader_io_blocks) which makes things a bit more fuzzy. + bool separate_shader_objects = false; + + // Flattens multidimensional arrays, e.g. float foo[a][b][c] into single-dimensional arrays, + // e.g. float foo[a * b * c]. + // This function does not change the actual SPIRType of any object. + // Only the generated code, including declarations of interface variables are changed to be single array dimension. + bool flatten_multidimensional_arrays = false; + + // For older desktop GLSL targets than version 420, the + // GL_ARB_shading_language_420pack extensions is used to be able to support + // layout(binding) on UBOs and samplers. + // If disabled on older targets, binding decorations will be stripped. + bool enable_420pack_extension = true; + + // In non-Vulkan GLSL, emit push constant blocks as UBOs rather than plain uniforms. + bool emit_push_constant_as_uniform_buffer = false; + + // Always emit uniform blocks as plain uniforms, regardless of the GLSL version, even when UBOs are supported. + // Does not apply to shader storage or push constant blocks. + bool emit_uniform_buffer_as_plain_uniforms = false; + + // Emit OpLine directives if present in the module. + // May not correspond exactly to original source, but should be a good approximation. 
+ bool emit_line_directives = false; + + // In cases where readonly/writeonly decoration are not used at all, + // we try to deduce which qualifier(s) we should actually used, since actually emitting + // read-write decoration is very rare, and older glslang/HLSL compilers tend to just emit readwrite as a matter of fact. + // The default (true) is to enable automatic deduction for these cases, but if you trust the decorations set + // by the SPIR-V, it's recommended to set this to false. + bool enable_storage_image_qualifier_deduction = true; + + // On some targets (WebGPU), uninitialized variables are banned. + // If this is enabled, all variables (temporaries, Private, Function) + // which would otherwise be uninitialized will now be initialized to 0 instead. + bool force_zero_initialized_variables = false; + + // In GLSL, force use of I/O block flattening, similar to + // what happens on legacy GLSL targets for blocks and structs. + bool force_flattened_io_blocks = false; + + // For opcodes where we have to perform explicit additional nan checks, very ugly code is generated. + // If we opt-in, ignore these requirements. + // In opcodes like NClamp/NMin/NMax and FP compare, ignore NaN behavior. + // Use FClamp/FMin/FMax semantics for clamps and lets implementation choose ordered or unordered + // compares. + bool relax_nan_checks = false; + + // Loading row-major matrices from UBOs on older AMD Windows OpenGL drivers is problematic. + // To load these types correctly, we must generate a wrapper. them in a dummy function which only purpose is to + // ensure row_major decoration is actually respected. + // This workaround may cause significant performance degeneration on some Android devices. + bool enable_row_major_load_workaround = true; + + // If non-zero, controls layout(num_views = N) in; in GL_OVR_multiview2. + uint32_t ovr_multiview_view_count = 0; + + enum Precision + { + DontCare, + Lowp, + Mediump, + Highp + }; + + struct VertexOptions + { + // "Vertex-like shader" here is any shader stage that can write BuiltInPosition. + + // GLSL: In vertex-like shaders, rewrite [0, w] depth (Vulkan/D3D style) to [-w, w] depth (GL style). + // MSL: In vertex-like shaders, rewrite [-w, w] depth (GL style) to [0, w] depth. + // HLSL: In vertex-like shaders, rewrite [-w, w] depth (GL style) to [0, w] depth. + bool fixup_clipspace = false; + + // In vertex-like shaders, inverts gl_Position.y or equivalent. + bool flip_vert_y = false; + + // GLSL only, for HLSL version of this option, see CompilerHLSL. + // If true, the backend will assume that InstanceIndex will need to apply + // a base instance offset. Set to false if you know you will never use base instance + // functionality as it might remove some internal uniforms. + bool support_nonzero_base_instance = true; + } vertex; + + struct FragmentOptions + { + // Add precision mediump float in ES targets when emitting GLES source. + // Add precision highp int in ES targets when emitting GLES source. + Precision default_float_precision = Mediump; + Precision default_int_precision = Highp; + } fragment; + }; + + void remap_pixel_local_storage(std::vector inputs, std::vector outputs) + { + pls_inputs = std::move(inputs); + pls_outputs = std::move(outputs); + remap_pls_variables(); + } + + // Redirect a subpassInput reading from input_attachment_index to instead load its value from + // the color attachment at location = color_location. Requires ESSL. + // If coherent, uses GL_EXT_shader_framebuffer_fetch, if not, uses noncoherent variant. 
+ void remap_ext_framebuffer_fetch(uint32_t input_attachment_index, uint32_t color_location, bool coherent); + + explicit CompilerGLSL(std::vector spirv_) + : Compiler(std::move(spirv_)) + { + init(); + } + + CompilerGLSL(const uint32_t *ir_, size_t word_count) + : Compiler(ir_, word_count) + { + init(); + } + + explicit CompilerGLSL(const ParsedIR &ir_) + : Compiler(ir_) + { + init(); + } + + explicit CompilerGLSL(ParsedIR &&ir_) + : Compiler(std::move(ir_)) + { + init(); + } + + const Options &get_common_options() const + { + return options; + } + + void set_common_options(const Options &opts) + { + options = opts; + } + + std::string compile() override; + + // Returns the current string held in the conversion buffer. Useful for + // capturing what has been converted so far when compile() throws an error. + std::string get_partial_source(); + + // Adds a line to be added right after #version in GLSL backend. + // This is useful for enabling custom extensions which are outside the scope of SPIRV-Cross. + // This can be combined with variable remapping. + // A new-line will be added. + // + // While add_header_line() is a more generic way of adding arbitrary text to the header + // of a GLSL file, require_extension() should be used when adding extensions since it will + // avoid creating collisions with SPIRV-Cross generated extensions. + // + // Code added via add_header_line() is typically backend-specific. + void add_header_line(const std::string &str); + + // Adds an extension which is required to run this shader, e.g. + // require_extension("GL_KHR_my_extension"); + void require_extension(const std::string &ext); + + // Returns the list of required extensions. After compilation this will contains any other + // extensions that the compiler used automatically, in addition to the user specified ones. + const SmallVector &get_required_extensions() const; + + // Legacy GLSL compatibility method. + // Takes a uniform or push constant variable and flattens it into a (i|u)vec4 array[N]; array instead. + // For this to work, all types in the block must be the same basic type, e.g. mixing vec2 and vec4 is fine, but + // mixing int and float is not. + // The name of the uniform array will be the same as the interface block name. + void flatten_buffer_block(VariableID id); + + // After compilation, query if a variable ID was used as a depth resource. + // This is meaningful for MSL since descriptor types depend on this knowledge. + // Cases which return true: + // - Images which are declared with depth = 1 image type. + // - Samplers which are statically used at least once with Dref opcodes. + // - Images which are statically used at least once with Dref opcodes. + bool variable_is_depth_or_compare(VariableID id) const; + + // If a shader output is active in this stage, but inactive in a subsequent stage, + // this can be signalled here. This can be used to work around certain cross-stage matching problems + // which plagues MSL and HLSL in certain scenarios. + // An output which matches one of these will not be emitted in stage output interfaces, but rather treated as a private + // variable. + // This option is only meaningful for MSL and HLSL, since GLSL matches by location directly. + // Masking builtins only takes effect if the builtin in question is part of the stage output interface. + void mask_stage_output_by_location(uint32_t location, uint32_t component); + void mask_stage_output_by_builtin(spv::BuiltIn builtin); + + // Allow to control how to format float literals in the output. 
+ // Set to "nullptr" to use the default "convert_to_string" function. + // This handle is not owned by SPIRV-Cross and must remain valid until compile() has been called. + void set_float_formatter(FloatFormatter *formatter) + { + float_formatter = formatter; + } + +protected: + struct ShaderSubgroupSupportHelper + { + // lower enum value = greater priority + enum Candidate + { + KHR_shader_subgroup_ballot, + KHR_shader_subgroup_basic, + KHR_shader_subgroup_vote, + KHR_shader_subgroup_arithmetic, + NV_gpu_shader_5, + NV_shader_thread_group, + NV_shader_thread_shuffle, + ARB_shader_ballot, + ARB_shader_group_vote, + AMD_gcn_shader, + + CandidateCount + }; + + static const char *get_extension_name(Candidate c); + static SmallVector get_extra_required_extension_names(Candidate c); + static const char *get_extra_required_extension_predicate(Candidate c); + + enum Feature + { + SubgroupMask = 0, + SubgroupSize = 1, + SubgroupInvocationID = 2, + SubgroupID = 3, + NumSubgroups = 4, + SubgroupBroadcast_First = 5, + SubgroupBallotFindLSB_MSB = 6, + SubgroupAll_Any_AllEqualBool = 7, + SubgroupAllEqualT = 8, + SubgroupElect = 9, + SubgroupBarrier = 10, + SubgroupMemBarrier = 11, + SubgroupBallot = 12, + SubgroupInverseBallot_InclBitCount_ExclBitCout = 13, + SubgroupBallotBitExtract = 14, + SubgroupBallotBitCount = 15, + SubgroupArithmeticIAddReduce = 16, + SubgroupArithmeticIAddExclusiveScan = 17, + SubgroupArithmeticIAddInclusiveScan = 18, + SubgroupArithmeticFAddReduce = 19, + SubgroupArithmeticFAddExclusiveScan = 20, + SubgroupArithmeticFAddInclusiveScan = 21, + SubgroupArithmeticIMulReduce = 22, + SubgroupArithmeticIMulExclusiveScan = 23, + SubgroupArithmeticIMulInclusiveScan = 24, + SubgroupArithmeticFMulReduce = 25, + SubgroupArithmeticFMulExclusiveScan = 26, + SubgroupArithmeticFMulInclusiveScan = 27, + FeatureCount + }; + + using FeatureMask = uint32_t; + static_assert(sizeof(FeatureMask) * 8u >= FeatureCount, "Mask type needs more bits."); + + using CandidateVector = SmallVector; + using FeatureVector = SmallVector; + + static FeatureVector get_feature_dependencies(Feature feature); + static FeatureMask get_feature_dependency_mask(Feature feature); + static bool can_feature_be_implemented_without_extensions(Feature feature); + static Candidate get_KHR_extension_for_feature(Feature feature); + + struct Result + { + Result(); + uint32_t weights[CandidateCount]; + }; + + void request_feature(Feature feature); + bool is_feature_requested(Feature feature) const; + Result resolve() const; + + static CandidateVector get_candidates_for_feature(Feature ft, const Result &r); + + private: + static CandidateVector get_candidates_for_feature(Feature ft); + static FeatureMask build_mask(const SmallVector &features); + FeatureMask feature_mask = 0; + }; + + // TODO remove this function when all subgroup ops are supported (or make it always return true) + static bool is_supported_subgroup_op_in_opengl(spv::Op op, const uint32_t *ops); + + void reset(uint32_t iteration_count); + void emit_function(SPIRFunction &func, const Bitset &return_flags); + + bool has_extension(const std::string &ext) const; + void require_extension_internal(const std::string &ext); + + // Virtualize methods which need to be overridden by subclass targets like C++ and such. 
+ virtual void emit_function_prototype(SPIRFunction &func, const Bitset &return_flags); + + SPIRBlock *current_emitting_block = nullptr; + SmallVector current_emitting_switch_stack; + bool current_emitting_switch_fallthrough = false; + + virtual void emit_instruction(const Instruction &instr); + struct TemporaryCopy + { + uint32_t dst_id; + uint32_t src_id; + }; + TemporaryCopy handle_instruction_precision(const Instruction &instr); + void emit_block_instructions(SPIRBlock &block); + void emit_block_instructions_with_masked_debug(SPIRBlock &block); + + // For relax_nan_checks. + GLSLstd450 get_remapped_glsl_op(GLSLstd450 std450_op) const; + spv::Op get_remapped_spirv_op(spv::Op op) const; + + virtual void emit_glsl_op(uint32_t result_type, uint32_t result_id, uint32_t op, const uint32_t *args, + uint32_t count); + virtual void emit_spv_amd_shader_ballot_op(uint32_t result_type, uint32_t result_id, uint32_t op, + const uint32_t *args, uint32_t count); + virtual void emit_spv_amd_shader_explicit_vertex_parameter_op(uint32_t result_type, uint32_t result_id, uint32_t op, + const uint32_t *args, uint32_t count); + virtual void emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type, uint32_t result_id, uint32_t op, + const uint32_t *args, uint32_t count); + virtual void emit_spv_amd_gcn_shader_op(uint32_t result_type, uint32_t result_id, uint32_t op, const uint32_t *args, + uint32_t count); + virtual void emit_header(); + void emit_line_directive(uint32_t file_id, uint32_t line_literal); + void build_workgroup_size(SmallVector &arguments, const SpecializationConstant &x, + const SpecializationConstant &y, const SpecializationConstant &z); + + void request_subgroup_feature(ShaderSubgroupSupportHelper::Feature feature); + + virtual void emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id); + virtual void emit_texture_op(const Instruction &i, bool sparse); + virtual std::string to_texture_op(const Instruction &i, bool sparse, bool *forward, + SmallVector &inherited_expressions); + virtual void emit_subgroup_op(const Instruction &i); + virtual std::string type_to_glsl(const SPIRType &type, uint32_t id = 0); + virtual std::string builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClass storage); + virtual void emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index, + const std::string &qualifier = "", uint32_t base_offset = 0); + virtual void emit_struct_padding_target(const SPIRType &type); + virtual std::string image_type_glsl(const SPIRType &type, uint32_t id = 0, bool member = false); + std::string constant_expression(const SPIRConstant &c, + bool inside_block_like_struct_scope = false, + bool inside_struct_scope = false); + virtual std::string constant_op_expression(const SPIRConstantOp &cop); + virtual std::string constant_expression_vector(const SPIRConstant &c, uint32_t vector); + virtual void emit_fixup(); + virtual std::string variable_decl(const SPIRType &type, const std::string &name, uint32_t id = 0); + virtual bool variable_decl_is_remapped_storage(const SPIRVariable &var, spv::StorageClass storage) const; + virtual std::string to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id); + + struct TextureFunctionBaseArguments + { + // GCC 4.8 workarounds, it doesn't understand '{}' constructor here, use explicit default constructor. 
+ TextureFunctionBaseArguments() = default; + VariableID img = 0; + const SPIRType *imgtype = nullptr; + bool is_fetch = false, is_gather = false, is_proj = false; + }; + + struct TextureFunctionNameArguments + { + // GCC 4.8 workarounds, it doesn't understand '{}' constructor here, use explicit default constructor. + TextureFunctionNameArguments() = default; + TextureFunctionBaseArguments base; + bool has_array_offsets = false, has_offset = false, has_grad = false; + bool has_dref = false, is_sparse_feedback = false, has_min_lod = false; + uint32_t lod = 0; + }; + virtual std::string to_function_name(const TextureFunctionNameArguments &args); + + struct TextureFunctionArguments + { + // GCC 4.8 workarounds, it doesn't understand '{}' constructor here, use explicit default constructor. + TextureFunctionArguments() = default; + TextureFunctionBaseArguments base; + uint32_t coord = 0, coord_components = 0, dref = 0; + uint32_t grad_x = 0, grad_y = 0, lod = 0, offset = 0; + uint32_t bias = 0, component = 0, sample = 0, sparse_texel = 0, min_lod = 0; + bool nonuniform_expression = false, has_array_offsets = false; + }; + virtual std::string to_function_args(const TextureFunctionArguments &args, bool *p_forward); + + void emit_sparse_feedback_temporaries(uint32_t result_type_id, uint32_t id, uint32_t &feedback_id, + uint32_t &texel_id); + uint32_t get_sparse_feedback_texel_id(uint32_t id) const; + virtual void emit_buffer_block(const SPIRVariable &type); + virtual void emit_push_constant_block(const SPIRVariable &var); + virtual void emit_uniform(const SPIRVariable &var); + virtual std::string unpack_expression_type(std::string expr_str, const SPIRType &type, uint32_t physical_type_id, + bool packed_type, bool row_major); + + virtual bool builtin_translates_to_nonarray(spv::BuiltIn builtin) const; + + virtual bool is_user_type_structured(uint32_t id) const; + + void emit_copy_logical_type(uint32_t lhs_id, uint32_t lhs_type_id, uint32_t rhs_id, uint32_t rhs_type_id, + SmallVector chain); + + StringStream<> buffer; + + template + inline void statement_inner(T &&t) + { + buffer << std::forward(t); + statement_count++; + } + + template + inline void statement_inner(T &&t, Ts &&... ts) + { + buffer << std::forward(t); + statement_count++; + statement_inner(std::forward(ts)...); + } + + template + inline void statement(Ts &&... ts) + { + if (is_forcing_recompilation()) + { + // Do not bother emitting code while force_recompile is active. + // We will compile again. + statement_count++; + return; + } + + if (redirect_statement) + { + redirect_statement->push_back(join(std::forward(ts)...)); + statement_count++; + } + else + { + for (uint32_t i = 0; i < indent; i++) + buffer << " "; + statement_inner(std::forward(ts)...); + buffer << '\n'; + } + } + + template + inline void statement_no_indent(Ts &&... ts) + { + auto old_indent = indent; + indent = 0; + statement(std::forward(ts)...); + indent = old_indent; + } + + // Used for implementing continue blocks where + // we want to obtain a list of statements we can merge + // on a single line separated by comma. 
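+ // For example, a continue block consisting of the statements "i++;" and "j += 2;" can be
+ // collapsed into the single loop-continue expression "i++, j += 2": the emitter temporarily
+ // points redirect_statement at a local SmallVector, emits the block through statement() as
+ // usual, and then joins the captured strings with ", ".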
+ SmallVector *redirect_statement = nullptr; + const SPIRBlock *current_continue_block = nullptr; + bool block_temporary_hoisting = false; + bool block_debug_directives = false; + + void begin_scope(); + void end_scope(); + void end_scope(const std::string &trailer); + void end_scope_decl(); + void end_scope_decl(const std::string &decl); + + Options options; + + // Allow Metal to use the array template to make arrays a value type + virtual std::string type_to_array_glsl(const SPIRType &type, uint32_t variable_id); + std::string to_array_size(const SPIRType &type, uint32_t index); + uint32_t to_array_size_literal(const SPIRType &type, uint32_t index) const; + uint32_t to_array_size_literal(const SPIRType &type) const; + virtual std::string variable_decl(const SPIRVariable &variable); // Threadgroup arrays can't have a wrapper type + std::string variable_decl_function_local(SPIRVariable &variable); + + void add_local_variable_name(uint32_t id); + void add_resource_name(uint32_t id); + void add_member_name(SPIRType &type, uint32_t name); + void add_function_overload(const SPIRFunction &func); + + virtual bool is_non_native_row_major_matrix(uint32_t id); + virtual bool member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index); + bool member_is_remapped_physical_type(const SPIRType &type, uint32_t index) const; + bool member_is_packed_physical_type(const SPIRType &type, uint32_t index) const; + virtual std::string convert_row_major_matrix(std::string exp_str, const SPIRType &exp_type, + uint32_t physical_type_id, bool is_packed, + bool relaxed = false); + + std::unordered_set local_variable_names; + std::unordered_set resource_names; + std::unordered_set block_input_names; + std::unordered_set block_output_names; + std::unordered_set block_ubo_names; + std::unordered_set block_ssbo_names; + std::unordered_set block_names; // A union of all block_*_names. + std::unordered_map> function_overloads; + std::unordered_map preserved_aliases; + void preserve_alias_on_reset(uint32_t id); + void reset_name_caches(); + + bool processing_entry_point = false; + + // Can be overriden by subclass backends for trivial things which + // shouldn't need polymorphism. 
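+ // For example, a hypothetical backend whose target language spells boolean selection as
+ // "select" rather than "mix" would simply assign backend.boolean_mix_function = "select"
+ // during its compile() setup instead of overriding the mix-emission path.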
+ struct BackendVariations + { + std::string discard_literal = "discard"; + std::string demote_literal = "demote"; + std::string null_pointer_literal = ""; + bool float_literal_suffix = false; + bool double_literal_suffix = true; + bool uint32_t_literal_suffix = true; + bool long_long_literal_suffix = false; + const char *basic_int_type = "int"; + const char *basic_uint_type = "uint"; + const char *basic_int8_type = "int8_t"; + const char *basic_uint8_type = "uint8_t"; + const char *basic_int16_type = "int16_t"; + const char *basic_uint16_type = "uint16_t"; + const char *int16_t_literal_suffix = "s"; + const char *uint16_t_literal_suffix = "us"; + const char *nonuniform_qualifier = "nonuniformEXT"; + const char *boolean_mix_function = "mix"; + SPIRType::BaseType boolean_in_struct_remapped_type = SPIRType::Boolean; + bool swizzle_is_function = false; + bool shared_is_implied = false; + bool unsized_array_supported = true; + bool explicit_struct_type = false; + bool use_initializer_list = false; + bool use_typed_initializer_list = false; + bool can_declare_struct_inline = true; + bool can_declare_arrays_inline = true; + bool native_row_major_matrix = true; + bool use_constructor_splatting = true; + bool allow_precision_qualifiers = false; + bool can_swizzle_scalar = false; + bool force_gl_in_out_block = false; + bool force_merged_mesh_block = false; + bool can_return_array = true; + bool allow_truncated_access_chain = false; + bool supports_extensions = false; + bool supports_empty_struct = false; + bool array_is_value_type = true; + bool array_is_value_type_in_buffer_blocks = true; + bool comparison_image_samples_scalar = false; + bool native_pointers = false; + bool support_small_type_sampling_result = false; + bool support_case_fallthrough = true; + bool use_array_constructor = false; + bool needs_row_major_load_workaround = false; + bool support_pointer_to_pointer = false; + bool support_precise_qualifier = false; + bool support_64bit_switch = false; + bool workgroup_size_is_hidden = false; + bool requires_relaxed_precision_analysis = false; + bool implicit_c_integer_promotion_rules = false; + } backend; + + void emit_struct(SPIRType &type); + void emit_resources(); + void emit_extension_workarounds(spv::ExecutionModel model); + void emit_subgroup_arithmetic_workaround(const std::string &func, spv::Op op, spv::GroupOperation group_op); + void emit_polyfills(uint32_t polyfills, bool relaxed); + void emit_buffer_block_native(const SPIRVariable &var); + void emit_buffer_reference_block(uint32_t type_id, bool forward_declaration); + void emit_buffer_block_legacy(const SPIRVariable &var); + void emit_buffer_block_flattened(const SPIRVariable &type); + void fixup_implicit_builtin_block_names(spv::ExecutionModel model); + void emit_declared_builtin_block(spv::StorageClass storage, spv::ExecutionModel model); + bool should_force_emit_builtin_block(spv::StorageClass storage); + void emit_push_constant_block_vulkan(const SPIRVariable &var); + void emit_push_constant_block_glsl(const SPIRVariable &var); + void emit_interface_block(const SPIRVariable &type); + void emit_flattened_io_block(const SPIRVariable &var, const char *qual); + void emit_flattened_io_block_struct(const std::string &basename, const SPIRType &type, const char *qual, + const SmallVector &indices); + void emit_flattened_io_block_member(const std::string &basename, const SPIRType &type, const char *qual, + const SmallVector &indices); + void emit_block_chain(SPIRBlock &block); + void emit_hoisted_temporaries(SmallVector> 
&temporaries); + std::string constant_value_macro_name(uint32_t id); + int get_constant_mapping_to_workgroup_component(const SPIRConstant &constant) const; + void emit_constant(const SPIRConstant &constant); + void emit_specialization_constant_op(const SPIRConstantOp &constant); + std::string emit_continue_block(uint32_t continue_block, bool follow_true_block, bool follow_false_block); + bool attempt_emit_loop_header(SPIRBlock &block, SPIRBlock::Method method); + + void branch(BlockID from, BlockID to); + void branch_to_continue(BlockID from, BlockID to); + void branch(BlockID from, uint32_t cond, BlockID true_block, BlockID false_block); + void flush_phi(BlockID from, BlockID to); + void flush_variable_declaration(uint32_t id); + void flush_undeclared_variables(SPIRBlock &block); + void emit_variable_temporary_copies(const SPIRVariable &var); + + bool should_dereference(uint32_t id); + bool should_forward(uint32_t id) const; + bool should_suppress_usage_tracking(uint32_t id) const; + void emit_mix_op(uint32_t result_type, uint32_t id, uint32_t left, uint32_t right, uint32_t lerp); + void emit_nminmax_op(uint32_t result_type, uint32_t id, uint32_t op0, uint32_t op1, GLSLstd450 op); + void emit_emulated_ahyper_op(uint32_t result_type, uint32_t result_id, uint32_t op0, GLSLstd450 op); + bool to_trivial_mix_op(const SPIRType &type, std::string &op, uint32_t left, uint32_t right, uint32_t lerp); + void emit_quaternary_func_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, uint32_t op2, + uint32_t op3, const char *op); + void emit_trinary_func_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, uint32_t op2, + const char *op); + void emit_binary_func_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op); + void emit_atomic_func_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op); + void emit_atomic_func_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, uint32_t op2, const char *op); + + void emit_unary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op, + SPIRType::BaseType input_type, SPIRType::BaseType expected_result_type); + void emit_binary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op, + SPIRType::BaseType input_type, bool skip_cast_if_equal_type); + void emit_binary_func_op_cast_clustered(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + const char *op, SPIRType::BaseType input_type); + void emit_trinary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, uint32_t op2, + const char *op, SPIRType::BaseType input_type); + void emit_trinary_func_op_bitextract(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + uint32_t op2, const char *op, SPIRType::BaseType expected_result_type, + SPIRType::BaseType input_type0, SPIRType::BaseType input_type1, + SPIRType::BaseType input_type2); + void emit_bitfield_insert_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, uint32_t op2, + uint32_t op3, const char *op, SPIRType::BaseType offset_count_type); + + void emit_unary_func_op(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op); + void emit_unrolled_unary_op(uint32_t result_type, uint32_t result_id, uint32_t operand, const char *op); + void emit_binary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op); + void 
emit_unrolled_binary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op, + bool negate, SPIRType::BaseType expected_type); + void emit_binary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op, + SPIRType::BaseType input_type, bool skip_cast_if_equal_type, bool implicit_integer_promotion); + + SPIRType binary_op_bitcast_helper(std::string &cast_op0, std::string &cast_op1, SPIRType::BaseType &input_type, + uint32_t op0, uint32_t op1, bool skip_cast_if_equal_type); + + virtual bool emit_complex_bitcast(uint32_t result_type, uint32_t id, uint32_t op0); + + std::string to_ternary_expression(const SPIRType &result_type, uint32_t select, uint32_t true_value, + uint32_t false_value); + + void emit_unary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op); + void emit_unary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op); + virtual void emit_mesh_tasks(SPIRBlock &block); + bool expression_is_forwarded(uint32_t id) const; + bool expression_suppresses_usage_tracking(uint32_t id) const; + bool expression_read_implies_multiple_reads(uint32_t id) const; + SPIRExpression &emit_op(uint32_t result_type, uint32_t result_id, const std::string &rhs, bool forward_rhs, + bool suppress_usage_tracking = false); + + void access_chain_internal_append_index(std::string &expr, uint32_t base, const SPIRType *type, + AccessChainFlags flags, bool &access_chain_is_arrayed, uint32_t index); + + std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags, + AccessChainMeta *meta); + + // Only meaningful on backends with physical pointer support ala MSL. + // Relevant for PtrAccessChain / BDA. + virtual uint32_t get_physical_type_stride(const SPIRType &type) const; + + spv::StorageClass get_expression_effective_storage_class(uint32_t ptr); + virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base); + + virtual void check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type); + virtual bool prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type, + spv::StorageClass storage, bool &is_packed); + + std::string access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type, + AccessChainMeta *meta = nullptr, bool ptr_chain = false); + + std::string flattened_access_chain(uint32_t base, const uint32_t *indices, uint32_t count, + const SPIRType &target_type, uint32_t offset, uint32_t matrix_stride, + uint32_t array_stride, bool need_transpose); + std::string flattened_access_chain_struct(uint32_t base, const uint32_t *indices, uint32_t count, + const SPIRType &target_type, uint32_t offset); + std::string flattened_access_chain_matrix(uint32_t base, const uint32_t *indices, uint32_t count, + const SPIRType &target_type, uint32_t offset, uint32_t matrix_stride, + bool need_transpose); + std::string flattened_access_chain_vector(uint32_t base, const uint32_t *indices, uint32_t count, + const SPIRType &target_type, uint32_t offset, uint32_t matrix_stride, + bool need_transpose); + std::pair flattened_access_chain_offset(const SPIRType &basetype, const uint32_t *indices, + uint32_t count, uint32_t offset, + uint32_t word_stride, bool *need_transpose = nullptr, + uint32_t *matrix_stride = nullptr, + uint32_t *array_stride = nullptr, + bool ptr_chain = false); + + const char *index_to_swizzle(uint32_t index); + std::string remap_swizzle(const 
SPIRType &result_type, uint32_t input_components, const std::string &expr); + std::string declare_temporary(uint32_t type, uint32_t id); + void emit_uninitialized_temporary(uint32_t type, uint32_t id); + SPIRExpression &emit_uninitialized_temporary_expression(uint32_t type, uint32_t id); + void append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector &arglist); + std::string to_non_uniform_aware_expression(uint32_t id); + std::string to_expression(uint32_t id, bool register_expression_read = true); + std::string to_composite_constructor_expression(const SPIRType &parent_type, uint32_t id, bool block_like_type); + std::string to_rerolled_array_expression(const SPIRType &parent_type, const std::string &expr, const SPIRType &type); + std::string to_enclosed_expression(uint32_t id, bool register_expression_read = true); + std::string to_unpacked_expression(uint32_t id, bool register_expression_read = true); + std::string to_unpacked_row_major_matrix_expression(uint32_t id); + std::string to_enclosed_unpacked_expression(uint32_t id, bool register_expression_read = true); + std::string to_dereferenced_expression(uint32_t id, bool register_expression_read = true); + std::string to_pointer_expression(uint32_t id, bool register_expression_read = true); + std::string to_enclosed_pointer_expression(uint32_t id, bool register_expression_read = true); + std::string to_extract_component_expression(uint32_t id, uint32_t index); + std::string to_extract_constant_composite_expression(uint32_t result_type, const SPIRConstant &c, + const uint32_t *chain, uint32_t length); + static bool needs_enclose_expression(const std::string &expr); + std::string enclose_expression(const std::string &expr); + std::string dereference_expression(const SPIRType &expression_type, const std::string &expr); + std::string address_of_expression(const std::string &expr); + void strip_enclosed_expression(std::string &expr); + std::string to_member_name(const SPIRType &type, uint32_t index); + virtual std::string to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain_is_resolved); + std::string to_multi_member_reference(const SPIRType &type, const SmallVector &indices); + std::string type_to_glsl_constructor(const SPIRType &type); + std::string argument_decl(const SPIRFunction::Parameter &arg); + virtual std::string to_qualifiers_glsl(uint32_t id); + void fixup_io_block_patch_primitive_qualifiers(const SPIRVariable &var); + void emit_output_variable_initializer(const SPIRVariable &var); + std::string to_precision_qualifiers_glsl(uint32_t id); + virtual const char *to_storage_qualifiers_glsl(const SPIRVariable &var); + std::string flags_to_qualifiers_glsl(const SPIRType &type, const Bitset &flags); + const char *format_to_glsl(spv::ImageFormat format); + virtual std::string layout_for_member(const SPIRType &type, uint32_t index); + virtual std::string to_interpolation_qualifiers(const Bitset &flags); + std::string layout_for_variable(const SPIRVariable &variable); + std::string to_combined_image_sampler(VariableID image_id, VariableID samp_id); + virtual bool skip_argument(uint32_t id) const; + virtual bool emit_array_copy(const char *expr, uint32_t lhs_id, uint32_t rhs_id, + spv::StorageClass lhs_storage, spv::StorageClass rhs_storage); + virtual void emit_block_hints(const SPIRBlock &block); + virtual std::string to_initializer_expression(const SPIRVariable &var); + virtual std::string to_zero_initialized_expression(uint32_t type_id); + bool type_can_zero_initialize(const 
SPIRType &type) const; + + bool buffer_is_packing_standard(const SPIRType &type, BufferPackingStandard packing, + uint32_t *failed_index = nullptr, uint32_t start_offset = 0, + uint32_t end_offset = ~(0u)); + std::string buffer_to_packing_standard(const SPIRType &type, + bool support_std430_without_scalar_layout, + bool support_enhanced_layouts); + + uint32_t type_to_packed_base_size(const SPIRType &type, BufferPackingStandard packing); + uint32_t type_to_packed_alignment(const SPIRType &type, const Bitset &flags, BufferPackingStandard packing); + uint32_t type_to_packed_array_stride(const SPIRType &type, const Bitset &flags, BufferPackingStandard packing); + uint32_t type_to_packed_size(const SPIRType &type, const Bitset &flags, BufferPackingStandard packing); + uint32_t type_to_location_count(const SPIRType &type) const; + + std::string bitcast_glsl(const SPIRType &result_type, uint32_t arg); + virtual std::string bitcast_glsl_op(const SPIRType &result_type, const SPIRType &argument_type); + + std::string bitcast_expression(SPIRType::BaseType target_type, uint32_t arg); + std::string bitcast_expression(const SPIRType &target_type, SPIRType::BaseType expr_type, const std::string &expr); + + std::string build_composite_combiner(uint32_t result_type, const uint32_t *elems, uint32_t length); + bool remove_duplicate_swizzle(std::string &op); + bool remove_unity_swizzle(uint32_t base, std::string &op); + + // Can modify flags to remote readonly/writeonly if image type + // and force recompile. + bool check_atomic_image(uint32_t id); + + virtual void replace_illegal_names(); + void replace_illegal_names(const std::unordered_set &keywords); + virtual void emit_entry_point_declarations(); + + void replace_fragment_output(SPIRVariable &var); + void replace_fragment_outputs(); + std::string legacy_tex_op(const std::string &op, const SPIRType &imgtype, uint32_t id); + + void forward_relaxed_precision(uint32_t dst_id, const uint32_t *args, uint32_t length); + void analyze_precision_requirements(uint32_t type_id, uint32_t dst_id, uint32_t *args, uint32_t length); + Options::Precision analyze_expression_precision(const uint32_t *args, uint32_t length) const; + + uint32_t indent = 0; + + std::unordered_set emitted_functions; + + // Ensure that we declare phi-variable copies even if the original declaration isn't deferred + std::unordered_set flushed_phi_variables; + + std::unordered_set flattened_buffer_blocks; + std::unordered_map flattened_structs; + + ShaderSubgroupSupportHelper shader_subgroup_supporter; + + std::string load_flattened_struct(const std::string &basename, const SPIRType &type); + std::string to_flattened_struct_member(const std::string &basename, const SPIRType &type, uint32_t index); + void store_flattened_struct(uint32_t lhs_id, uint32_t value); + void store_flattened_struct(const std::string &basename, uint32_t rhs, const SPIRType &type, + const SmallVector &indices); + std::string to_flattened_access_chain_expression(uint32_t id); + + // Usage tracking. If a temporary is used more than once, use the temporary instead to + // avoid AST explosion when SPIRV is generated with pure SSA and doesn't write stuff to variables. + std::unordered_map expression_usage_counts; + void track_expression_read(uint32_t id); + + SmallVector forced_extensions; + SmallVector header_lines; + + // Used when expressions emit extra opcodes with their own unique IDs, + // and we need to reuse the IDs across recompilation loops. + // Currently used by NMin/Max/Clamp implementations. 
+ std::unordered_map extra_sub_expressions; + + SmallVector workaround_ubo_load_overload_types; + void request_workaround_wrapper_overload(TypeID id); + void rewrite_load_for_wrapped_row_major(std::string &expr, TypeID loaded_type, ID ptr); + + uint32_t statement_count = 0; + + inline bool is_legacy() const + { + return (options.es && options.version < 300) || (!options.es && options.version < 130); + } + + inline bool is_legacy_es() const + { + return options.es && options.version < 300; + } + + inline bool is_legacy_desktop() const + { + return !options.es && options.version < 130; + } + + enum Polyfill : uint32_t + { + PolyfillTranspose2x2 = 1 << 0, + PolyfillTranspose3x3 = 1 << 1, + PolyfillTranspose4x4 = 1 << 2, + PolyfillDeterminant2x2 = 1 << 3, + PolyfillDeterminant3x3 = 1 << 4, + PolyfillDeterminant4x4 = 1 << 5, + PolyfillMatrixInverse2x2 = 1 << 6, + PolyfillMatrixInverse3x3 = 1 << 7, + PolyfillMatrixInverse4x4 = 1 << 8, + PolyfillNMin16 = 1 << 9, + PolyfillNMin32 = 1 << 10, + PolyfillNMin64 = 1 << 11, + PolyfillNMax16 = 1 << 12, + PolyfillNMax32 = 1 << 13, + PolyfillNMax64 = 1 << 14, + PolyfillNClamp16 = 1 << 15, + PolyfillNClamp32 = 1 << 16, + PolyfillNClamp64 = 1 << 17, + }; + + uint32_t required_polyfills = 0; + uint32_t required_polyfills_relaxed = 0; + void require_polyfill(Polyfill polyfill, bool relaxed); + + bool ray_tracing_is_khr = false; + bool barycentric_is_nv = false; + void ray_tracing_khr_fixup_locations(); + + bool args_will_forward(uint32_t id, const uint32_t *args, uint32_t num_args, bool pure); + void register_call_out_argument(uint32_t id); + void register_impure_function_call(); + void register_control_dependent_expression(uint32_t expr); + + // GL_EXT_shader_pixel_local_storage support. + std::vector pls_inputs; + std::vector pls_outputs; + std::string pls_decl(const PlsRemap &variable); + const char *to_pls_qualifiers_glsl(const SPIRVariable &variable); + void emit_pls(); + void remap_pls_variables(); + + // GL_EXT_shader_framebuffer_fetch support. + std::vector> subpass_to_framebuffer_fetch_attachment; + std::vector> inout_color_attachments; + bool location_is_framebuffer_fetch(uint32_t location) const; + bool location_is_non_coherent_framebuffer_fetch(uint32_t location) const; + bool subpass_input_is_framebuffer_fetch(uint32_t id) const; + void emit_inout_fragment_outputs_copy_to_subpass_inputs(); + const SPIRVariable *find_subpass_input_by_attachment_index(uint32_t index) const; + const SPIRVariable *find_color_output_by_location(uint32_t location) const; + + // A variant which takes two sets of name. The secondary is only used to verify there are no collisions, + // but the set is not updated when we have found a new name. + // Used primarily when adding block interface names. 
+ void add_variable(std::unordered_set &variables_primary, + const std::unordered_set &variables_secondary, std::string &name); + + void check_function_call_constraints(const uint32_t *args, uint32_t length); + void handle_invalid_expression(uint32_t id); + void force_temporary_and_recompile(uint32_t id); + void find_static_extensions(); + + uint32_t consume_temporary_in_precision_context(uint32_t type_id, uint32_t id, Options::Precision precision); + std::unordered_map temporary_to_mirror_precision_alias; + std::unordered_set composite_insert_overwritten; + std::unordered_set block_composite_insert_overwrite; + + std::string emit_for_loop_initializers(const SPIRBlock &block); + void emit_while_loop_initializers(const SPIRBlock &block); + bool for_loop_initializers_are_same_type(const SPIRBlock &block); + bool optimize_read_modify_write(const SPIRType &type, const std::string &lhs, const std::string &rhs); + void fixup_image_load_store_access(); + + bool type_is_empty(const SPIRType &type); + + bool can_use_io_location(spv::StorageClass storage, bool block); + const Instruction *get_next_instruction_in_block(const Instruction &instr); + static uint32_t mask_relevant_memory_semantics(uint32_t semantics); + + std::string convert_half_to_string(const SPIRConstant &value, uint32_t col, uint32_t row); + std::string convert_float_to_string(const SPIRConstant &value, uint32_t col, uint32_t row); + std::string convert_double_to_string(const SPIRConstant &value, uint32_t col, uint32_t row); + + std::string convert_separate_image_to_expression(uint32_t id); + + // Builtins in GLSL are always specific signedness, but the SPIR-V can declare them + // as either unsigned or signed. + // Sometimes we will need to automatically perform casts on load and store to make this work. 
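+ // For example, gl_VertexID and gl_InstanceID are plain signed int in GLSL, while SPIR-V
+ // produced from HLSL commonly declares the corresponding builtins as 32-bit unsigned
+ // integers, so loads and stores of such a variable get wrapped in the appropriate
+ // int(...)/uint(...) cast.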
+ virtual SPIRType::BaseType get_builtin_basetype(spv::BuiltIn builtin, SPIRType::BaseType default_type); + virtual void cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type); + virtual void cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type); + void unroll_array_from_complex_load(uint32_t target_id, uint32_t source_id, std::string &expr); + bool unroll_array_to_complex_store(uint32_t target_id, uint32_t source_id); + void convert_non_uniform_expression(std::string &expr, uint32_t ptr_id); + + void handle_store_to_invariant_variable(uint32_t store_id, uint32_t value_id); + void disallow_forwarding_in_expression_chain(const SPIRExpression &expr); + + bool expression_is_constant_null(uint32_t id) const; + bool expression_is_non_value_type_array(uint32_t ptr); + virtual void emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression); + + uint32_t get_integer_width_for_instruction(const Instruction &instr) const; + uint32_t get_integer_width_for_glsl_instruction(GLSLstd450 op, const uint32_t *arguments, uint32_t length) const; + + bool variable_is_lut(const SPIRVariable &var) const; + + char current_locale_radix_character = '.'; + + void fixup_type_alias(); + void reorder_type_alias(); + void fixup_anonymous_struct_names(); + void fixup_anonymous_struct_names(std::unordered_set &visited, const SPIRType &type); + + static const char *vector_swizzle(int vecsize, int index); + + bool is_stage_output_location_masked(uint32_t location, uint32_t component) const; + bool is_stage_output_builtin_masked(spv::BuiltIn builtin) const; + bool is_stage_output_variable_masked(const SPIRVariable &var) const; + bool is_stage_output_block_member_masked(const SPIRVariable &var, uint32_t index, bool strip_array) const; + bool is_per_primitive_variable(const SPIRVariable &var) const; + uint32_t get_accumulated_member_location(const SPIRVariable &var, uint32_t mbr_idx, bool strip_array) const; + uint32_t get_declared_member_location(const SPIRVariable &var, uint32_t mbr_idx, bool strip_array) const; + std::unordered_set masked_output_locations; + std::unordered_set masked_output_builtins; + + FloatFormatter *float_formatter = nullptr; + std::string format_float(float value) const; + std::string format_double(double value) const; + +private: + void init(); + + SmallVector get_composite_constant_ids(ConstantID const_id); + void fill_composite_constant(SPIRConstant &constant, TypeID type_id, const SmallVector &initializers); + void set_composite_constant(ConstantID const_id, TypeID type_id, const SmallVector &initializers); + TypeID get_composite_member_type(TypeID type_id, uint32_t member_idx); + std::unordered_map> const_composite_insert_ids; +}; +} // namespace SPIRV_CROSS_NAMESPACE + +#endif diff --git a/thirdparty/spirv-cross/spirv_msl.cpp b/thirdparty/spirv-cross/spirv_msl.cpp new file mode 100644 index 00000000000..383ce688e98 --- /dev/null +++ b/thirdparty/spirv-cross/spirv_msl.cpp @@ -0,0 +1,18810 @@ +/* + * Copyright 2016-2021 The Brenwill Workshop Ltd. + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . + */ + +#include "spirv_msl.hpp" +#include "GLSL.std.450.h" + +#include +#include +#include + +using namespace spv; +using namespace SPIRV_CROSS_NAMESPACE; +using namespace std; + +static const uint32_t k_unknown_location = ~0u; +static const uint32_t k_unknown_component = ~0u; +static const char *force_inline = "static inline __attribute__((always_inline))"; + +CompilerMSL::CompilerMSL(std::vector spirv_) + : CompilerGLSL(std::move(spirv_)) +{ +} + +CompilerMSL::CompilerMSL(const uint32_t *ir_, size_t word_count) + : CompilerGLSL(ir_, word_count) +{ +} + +CompilerMSL::CompilerMSL(const ParsedIR &ir_) + : CompilerGLSL(ir_) +{ +} + +CompilerMSL::CompilerMSL(ParsedIR &&ir_) + : CompilerGLSL(std::move(ir_)) +{ +} + +void CompilerMSL::add_msl_shader_input(const MSLShaderInterfaceVariable &si) +{ + inputs_by_location[{si.location, si.component}] = si; + if (si.builtin != BuiltInMax && !inputs_by_builtin.count(si.builtin)) + inputs_by_builtin[si.builtin] = si; +} + +void CompilerMSL::add_msl_shader_output(const MSLShaderInterfaceVariable &so) +{ + outputs_by_location[{so.location, so.component}] = so; + if (so.builtin != BuiltInMax && !outputs_by_builtin.count(so.builtin)) + outputs_by_builtin[so.builtin] = so; +} + +void CompilerMSL::add_msl_resource_binding(const MSLResourceBinding &binding) +{ + StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding }; + resource_bindings[tuple] = { binding, false }; + + // If we might need to pad argument buffer members to positionally align + // arg buffer indexes, also maintain a lookup by argument buffer index. + if (msl_options.pad_argument_buffer_resources) + { + StageSetBinding arg_idx_tuple = { binding.stage, binding.desc_set, k_unknown_component }; + +#define ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(rez) \ + arg_idx_tuple.binding = binding.msl_##rez; \ + resource_arg_buff_idx_to_binding_number[arg_idx_tuple] = binding.binding + + switch (binding.basetype) + { + case SPIRType::Void: + case SPIRType::Boolean: + case SPIRType::SByte: + case SPIRType::UByte: + case SPIRType::Short: + case SPIRType::UShort: + case SPIRType::Int: + case SPIRType::UInt: + case SPIRType::Int64: + case SPIRType::UInt64: + case SPIRType::AtomicCounter: + case SPIRType::Half: + case SPIRType::Float: + case SPIRType::Double: + ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(buffer); + break; + case SPIRType::Image: + ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(texture); + break; + case SPIRType::Sampler: + ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(sampler); + break; + case SPIRType::SampledImage: + ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(texture); + ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(sampler); + break; + default: + SPIRV_CROSS_THROW("Unexpected argument buffer resource base type. 
When padding argument buffer elements, " + "all descriptor set resources must be supplied with a base type by the app."); + } +#undef ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP + } +} + +void CompilerMSL::add_dynamic_buffer(uint32_t desc_set, uint32_t binding, uint32_t index) +{ + SetBindingPair pair = { desc_set, binding }; + buffers_requiring_dynamic_offset[pair] = { index, 0 }; +} + +void CompilerMSL::add_inline_uniform_block(uint32_t desc_set, uint32_t binding) +{ + SetBindingPair pair = { desc_set, binding }; + inline_uniform_blocks.insert(pair); +} + +void CompilerMSL::add_discrete_descriptor_set(uint32_t desc_set) +{ + if (desc_set < kMaxArgumentBuffers) + argument_buffer_discrete_mask |= 1u << desc_set; +} + +void CompilerMSL::set_argument_buffer_device_address_space(uint32_t desc_set, bool device_storage) +{ + if (desc_set < kMaxArgumentBuffers) + { + if (device_storage) + argument_buffer_device_storage_mask |= 1u << desc_set; + else + argument_buffer_device_storage_mask &= ~(1u << desc_set); + } +} + +bool CompilerMSL::is_msl_shader_input_used(uint32_t location) +{ + // Don't report internal location allocations to app. + return location_inputs_in_use.count(location) != 0 && + location_inputs_in_use_fallback.count(location) == 0; +} + +bool CompilerMSL::is_msl_shader_output_used(uint32_t location) +{ + // Don't report internal location allocations to app. + return location_outputs_in_use.count(location) != 0 && + location_outputs_in_use_fallback.count(location) == 0; +} + +uint32_t CompilerMSL::get_automatic_builtin_input_location(spv::BuiltIn builtin) const +{ + auto itr = builtin_to_automatic_input_location.find(builtin); + if (itr == builtin_to_automatic_input_location.end()) + return k_unknown_location; + else + return itr->second; +} + +uint32_t CompilerMSL::get_automatic_builtin_output_location(spv::BuiltIn builtin) const +{ + auto itr = builtin_to_automatic_output_location.find(builtin); + if (itr == builtin_to_automatic_output_location.end()) + return k_unknown_location; + else + return itr->second; +} + +bool CompilerMSL::is_msl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const +{ + StageSetBinding tuple = { model, desc_set, binding }; + auto itr = resource_bindings.find(tuple); + return itr != end(resource_bindings) && itr->second.second; +} + +bool CompilerMSL::is_var_runtime_size_array(const SPIRVariable &var) const +{ + auto& type = get_variable_data_type(var); + return is_runtime_size_array(type) && get_resource_array_size(type, var.self) == 0; +} + +// Returns the size of the array of resources used by the variable with the specified type and id. +// The size is first retrieved from the type, but in the case of runtime array sizing, +// the size is retrieved from the resource binding added using add_msl_resource_binding(). +uint32_t CompilerMSL::get_resource_array_size(const SPIRType &type, uint32_t id) const +{ + uint32_t array_size = to_array_size_literal(type); + + // If we have argument buffers, we need to honor the ABI by using the correct array size + // from the layout. Only use shader declared size if we're not using argument buffers. + uint32_t desc_set = get_decoration(id, DecorationDescriptorSet); + if (!descriptor_set_is_argument_buffer(desc_set) && array_size) + return array_size; + + StageSetBinding tuple = { get_entry_point().model, desc_set, + get_decoration(id, DecorationBinding) }; + auto itr = resource_bindings.find(tuple); + return itr != end(resource_bindings) ? 
itr->second.first.count : array_size; +} + +uint32_t CompilerMSL::get_automatic_msl_resource_binding(uint32_t id) const +{ + return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexPrimary); +} + +uint32_t CompilerMSL::get_automatic_msl_resource_binding_secondary(uint32_t id) const +{ + return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexSecondary); +} + +uint32_t CompilerMSL::get_automatic_msl_resource_binding_tertiary(uint32_t id) const +{ + return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexTertiary); +} + +uint32_t CompilerMSL::get_automatic_msl_resource_binding_quaternary(uint32_t id) const +{ + return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexQuaternary); +} + +void CompilerMSL::set_fragment_output_components(uint32_t location, uint32_t components) +{ + fragment_output_components[location] = components; +} + +bool CompilerMSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const +{ + return (builtin == BuiltInSampleMask); +} + +void CompilerMSL::build_implicit_builtins() +{ + bool need_sample_pos = active_input_builtins.get(BuiltInSamplePosition); + bool need_vertex_params = capture_output_to_buffer && get_execution_model() == ExecutionModelVertex && + !msl_options.vertex_for_tessellation; + bool need_tesc_params = is_tesc_shader(); + bool need_tese_params = is_tese_shader() && msl_options.raw_buffer_tese_input; + bool need_subgroup_mask = + active_input_builtins.get(BuiltInSubgroupEqMask) || active_input_builtins.get(BuiltInSubgroupGeMask) || + active_input_builtins.get(BuiltInSubgroupGtMask) || active_input_builtins.get(BuiltInSubgroupLeMask) || + active_input_builtins.get(BuiltInSubgroupLtMask); + bool need_subgroup_ge_mask = !msl_options.is_ios() && (active_input_builtins.get(BuiltInSubgroupGeMask) || + active_input_builtins.get(BuiltInSubgroupGtMask)); + bool need_multiview = get_execution_model() == ExecutionModelVertex && !msl_options.view_index_from_device_index && + msl_options.multiview_layered_rendering && + (msl_options.multiview || active_input_builtins.get(BuiltInViewIndex)); + bool need_dispatch_base = + msl_options.dispatch_base && get_execution_model() == ExecutionModelGLCompute && + (active_input_builtins.get(BuiltInWorkgroupId) || active_input_builtins.get(BuiltInGlobalInvocationId)); + bool need_grid_params = get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation; + bool need_vertex_base_params = + need_grid_params && + (active_input_builtins.get(BuiltInVertexId) || active_input_builtins.get(BuiltInVertexIndex) || + active_input_builtins.get(BuiltInBaseVertex) || active_input_builtins.get(BuiltInInstanceId) || + active_input_builtins.get(BuiltInInstanceIndex) || active_input_builtins.get(BuiltInBaseInstance)); + bool need_local_invocation_index = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInSubgroupId); + bool need_workgroup_size = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInNumSubgroups); + bool force_frag_depth_passthrough = + get_execution_model() == ExecutionModelFragment && !uses_explicit_early_fragment_test() && need_subpass_input && + msl_options.enable_frag_depth_builtin && msl_options.input_attachment_is_ds_attachment; + + if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params || + need_tese_params || need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params || + needs_sample_id || needs_subgroup_invocation_id || needs_subgroup_size || 
needs_helper_invocation || + has_additional_fixed_sample_mask() || need_local_invocation_index || need_workgroup_size || force_frag_depth_passthrough) + { + bool has_frag_coord = false; + bool has_sample_id = false; + bool has_vertex_idx = false; + bool has_base_vertex = false; + bool has_instance_idx = false; + bool has_base_instance = false; + bool has_invocation_id = false; + bool has_primitive_id = false; + bool has_subgroup_invocation_id = false; + bool has_subgroup_size = false; + bool has_view_idx = false; + bool has_layer = false; + bool has_helper_invocation = false; + bool has_local_invocation_index = false; + bool has_workgroup_size = false; + bool has_frag_depth = false; + uint32_t workgroup_id_type = 0; + + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + if (var.storage != StorageClassInput && var.storage != StorageClassOutput) + return; + if (!interface_variable_exists_in_entry_point(var.self)) + return; + if (!has_decoration(var.self, DecorationBuiltIn)) + return; + + BuiltIn builtin = ir.meta[var.self].decoration.builtin_type; + + if (var.storage == StorageClassOutput) + { + if (has_additional_fixed_sample_mask() && builtin == BuiltInSampleMask) + { + builtin_sample_mask_id = var.self; + mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var.self); + does_shader_write_sample_mask = true; + } + + if (force_frag_depth_passthrough && builtin == BuiltInFragDepth) + { + builtin_frag_depth_id = var.self; + mark_implicit_builtin(StorageClassOutput, BuiltInFragDepth, var.self); + has_frag_depth = true; + } + } + + if (var.storage != StorageClassInput) + return; + + // Use Metal's native frame-buffer fetch API for subpass inputs. + if (need_subpass_input && (!msl_options.use_framebuffer_fetch_subpasses)) + { + switch (builtin) + { + case BuiltInFragCoord: + mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var.self); + builtin_frag_coord_id = var.self; + has_frag_coord = true; + break; + case BuiltInLayer: + if (!msl_options.arrayed_subpass_input || msl_options.multiview) + break; + mark_implicit_builtin(StorageClassInput, BuiltInLayer, var.self); + builtin_layer_id = var.self; + has_layer = true; + break; + case BuiltInViewIndex: + if (!msl_options.multiview) + break; + mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self); + builtin_view_idx_id = var.self; + has_view_idx = true; + break; + default: + break; + } + } + + if ((need_sample_pos || needs_sample_id) && builtin == BuiltInSampleId) + { + builtin_sample_id_id = var.self; + mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var.self); + has_sample_id = true; + } + + if (need_vertex_params) + { + switch (builtin) + { + case BuiltInVertexIndex: + builtin_vertex_idx_id = var.self; + mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var.self); + has_vertex_idx = true; + break; + case BuiltInBaseVertex: + builtin_base_vertex_id = var.self; + mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var.self); + has_base_vertex = true; + break; + case BuiltInInstanceIndex: + builtin_instance_idx_id = var.self; + mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self); + has_instance_idx = true; + break; + case BuiltInBaseInstance: + builtin_base_instance_id = var.self; + mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self); + has_base_instance = true; + break; + default: + break; + } + } + + if (need_tesc_params && builtin == BuiltInInvocationId) + { + builtin_invocation_id_id = var.self; + mark_implicit_builtin(StorageClassInput, 
BuiltInInvocationId, var.self); + has_invocation_id = true; + } + + if ((need_tesc_params || need_tese_params) && builtin == BuiltInPrimitiveId) + { + builtin_primitive_id_id = var.self; + mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var.self); + has_primitive_id = true; + } + + if (need_tese_params && builtin == BuiltInTessLevelOuter) + { + tess_level_outer_var_id = var.self; + } + + if (need_tese_params && builtin == BuiltInTessLevelInner) + { + tess_level_inner_var_id = var.self; + } + + if ((need_subgroup_mask || needs_subgroup_invocation_id) && builtin == BuiltInSubgroupLocalInvocationId) + { + builtin_subgroup_invocation_id_id = var.self; + mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var.self); + has_subgroup_invocation_id = true; + } + + if ((need_subgroup_ge_mask || needs_subgroup_size) && builtin == BuiltInSubgroupSize) + { + builtin_subgroup_size_id = var.self; + mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var.self); + has_subgroup_size = true; + } + + if (need_multiview) + { + switch (builtin) + { + case BuiltInInstanceIndex: + // The view index here is derived from the instance index. + builtin_instance_idx_id = var.self; + mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self); + has_instance_idx = true; + break; + case BuiltInBaseInstance: + // If a non-zero base instance is used, we need to adjust for it when calculating the view index. + builtin_base_instance_id = var.self; + mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self); + has_base_instance = true; + break; + case BuiltInViewIndex: + builtin_view_idx_id = var.self; + mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self); + has_view_idx = true; + break; + default: + break; + } + } + + if (needs_helper_invocation && builtin == BuiltInHelperInvocation) + { + builtin_helper_invocation_id = var.self; + mark_implicit_builtin(StorageClassInput, BuiltInHelperInvocation, var.self); + has_helper_invocation = true; + } + + if (need_local_invocation_index && builtin == BuiltInLocalInvocationIndex) + { + builtin_local_invocation_index_id = var.self; + mark_implicit_builtin(StorageClassInput, BuiltInLocalInvocationIndex, var.self); + has_local_invocation_index = true; + } + + if (need_workgroup_size && builtin == BuiltInLocalInvocationId) + { + builtin_workgroup_size_id = var.self; + mark_implicit_builtin(StorageClassInput, BuiltInWorkgroupSize, var.self); + has_workgroup_size = true; + } + + // The base workgroup needs to have the same type and vector size + // as the workgroup or invocation ID, so keep track of the type that + // was used. + if (need_dispatch_base && workgroup_id_type == 0 && + (builtin == BuiltInWorkgroupId || builtin == BuiltInGlobalInvocationId)) + workgroup_id_type = var.basetype; + }); + + // Use Metal's native frame-buffer fetch API for subpass inputs. + if ((!has_frag_coord || (msl_options.multiview && !has_view_idx) || + (msl_options.arrayed_subpass_input && !msl_options.multiview && !has_layer)) && + (!msl_options.use_framebuffer_fetch_subpasses) && need_subpass_input) + { + if (!has_frag_coord) + { + uint32_t offset = ir.increase_bound_by(3); + uint32_t type_id = offset; + uint32_t type_ptr_id = offset + 1; + uint32_t var_id = offset + 2; + + // Create gl_FragCoord. 
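+ // (A float vec4 Input pointer decorated BuiltInFragCoord is synthesized here so that the
+ // subpass-input emulation path can read the attachment texture at the fragment's window
+ // coordinate when Metal's native framebuffer-fetch API is not being used.)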
+ SPIRType vec4_type { OpTypeVector }; + vec4_type.basetype = SPIRType::Float; + vec4_type.width = 32; + vec4_type.vecsize = 4; + set(type_id, vec4_type); + + SPIRType vec4_type_ptr = vec4_type; + vec4_type_ptr.op = OpTypePointer; + vec4_type_ptr.pointer = true; + vec4_type_ptr.pointer_depth++; + vec4_type_ptr.parent_type = type_id; + vec4_type_ptr.storage = StorageClassInput; + auto &ptr_type = set(type_ptr_id, vec4_type_ptr); + ptr_type.self = type_id; + + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInFragCoord); + builtin_frag_coord_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var_id); + } + + if (!has_layer && msl_options.arrayed_subpass_input && !msl_options.multiview) + { + uint32_t offset = ir.increase_bound_by(2); + uint32_t type_ptr_id = offset; + uint32_t var_id = offset + 1; + + // Create gl_Layer. + SPIRType uint_type_ptr = get_uint_type(); + uint_type_ptr.op = OpTypePointer; + uint_type_ptr.pointer = true; + uint_type_ptr.pointer_depth++; + uint_type_ptr.parent_type = get_uint_type_id(); + uint_type_ptr.storage = StorageClassInput; + auto &ptr_type = set(type_ptr_id, uint_type_ptr); + ptr_type.self = get_uint_type_id(); + + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInLayer); + builtin_layer_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInLayer, var_id); + } + + if (!has_view_idx && msl_options.multiview) + { + uint32_t offset = ir.increase_bound_by(2); + uint32_t type_ptr_id = offset; + uint32_t var_id = offset + 1; + + // Create gl_ViewIndex. + SPIRType uint_type_ptr = get_uint_type(); + uint_type_ptr.op = OpTypePointer; + uint_type_ptr.pointer = true; + uint_type_ptr.pointer_depth++; + uint_type_ptr.parent_type = get_uint_type_id(); + uint_type_ptr.storage = StorageClassInput; + auto &ptr_type = set(type_ptr_id, uint_type_ptr); + ptr_type.self = get_uint_type_id(); + + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex); + builtin_view_idx_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id); + } + } + + if (!has_sample_id && (need_sample_pos || needs_sample_id)) + { + uint32_t offset = ir.increase_bound_by(2); + uint32_t type_ptr_id = offset; + uint32_t var_id = offset + 1; + + // Create gl_SampleID. 
+ SPIRType uint_type_ptr = get_uint_type(); + uint_type_ptr.op = OpTypePointer; + uint_type_ptr.pointer = true; + uint_type_ptr.pointer_depth++; + uint_type_ptr.parent_type = get_uint_type_id(); + uint_type_ptr.storage = StorageClassInput; + auto &ptr_type = set(type_ptr_id, uint_type_ptr); + ptr_type.self = get_uint_type_id(); + + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInSampleId); + builtin_sample_id_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var_id); + } + + if ((need_vertex_params && (!has_vertex_idx || !has_base_vertex || !has_instance_idx || !has_base_instance)) || + (need_multiview && (!has_instance_idx || !has_base_instance || !has_view_idx))) + { + uint32_t type_ptr_id = ir.increase_bound_by(1); + + SPIRType uint_type_ptr = get_uint_type(); + uint_type_ptr.op = OpTypePointer; + uint_type_ptr.pointer = true; + uint_type_ptr.pointer_depth++; + uint_type_ptr.parent_type = get_uint_type_id(); + uint_type_ptr.storage = StorageClassInput; + auto &ptr_type = set(type_ptr_id, uint_type_ptr); + ptr_type.self = get_uint_type_id(); + + if (need_vertex_params && !has_vertex_idx) + { + uint32_t var_id = ir.increase_bound_by(1); + + // Create gl_VertexIndex. + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInVertexIndex); + builtin_vertex_idx_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var_id); + } + + if (need_vertex_params && !has_base_vertex) + { + uint32_t var_id = ir.increase_bound_by(1); + + // Create gl_BaseVertex. + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInBaseVertex); + builtin_base_vertex_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var_id); + } + + if (!has_instance_idx) // Needed by both multiview and tessellation + { + uint32_t var_id = ir.increase_bound_by(1); + + // Create gl_InstanceIndex. + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInInstanceIndex); + builtin_instance_idx_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var_id); + } + + if (!has_base_instance) // Needed by both multiview and tessellation + { + uint32_t var_id = ir.increase_bound_by(1); + + // Create gl_BaseInstance. + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInBaseInstance); + builtin_base_instance_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var_id); + } + + if (need_multiview) + { + // Multiview shaders are not allowed to write to gl_Layer, ostensibly because + // it is implicitly written from gl_ViewIndex, but we have to do that explicitly. + // Note that we can't just abuse gl_ViewIndex for this purpose: it's an input, but + // gl_Layer is an output in vertex-pipeline shaders. 
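+ // (A dedicated gl_Layer output is therefore synthesized below and written with the
+ // computed view index; in MSL this surfaces as the [[render_target_array_index]]
+ // vertex output.)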
+ uint32_t type_ptr_out_id = ir.increase_bound_by(2); + SPIRType uint_type_ptr_out = get_uint_type(); + uint_type_ptr.op = OpTypePointer; + uint_type_ptr_out.pointer = true; + uint_type_ptr_out.pointer_depth++; + uint_type_ptr_out.parent_type = get_uint_type_id(); + uint_type_ptr_out.storage = StorageClassOutput; + auto &ptr_out_type = set(type_ptr_out_id, uint_type_ptr_out); + ptr_out_type.self = get_uint_type_id(); + uint32_t var_id = type_ptr_out_id + 1; + set(var_id, type_ptr_out_id, StorageClassOutput); + set_decoration(var_id, DecorationBuiltIn, BuiltInLayer); + builtin_layer_id = var_id; + mark_implicit_builtin(StorageClassOutput, BuiltInLayer, var_id); + } + + if (need_multiview && !has_view_idx) + { + uint32_t var_id = ir.increase_bound_by(1); + + // Create gl_ViewIndex. + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex); + builtin_view_idx_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id); + } + } + + if ((need_tesc_params && (msl_options.multi_patch_workgroup || !has_invocation_id || !has_primitive_id)) || + (need_tese_params && !has_primitive_id) || need_grid_params) + { + uint32_t type_ptr_id = ir.increase_bound_by(1); + + SPIRType uint_type_ptr = get_uint_type(); + uint_type_ptr.op = OpTypePointer; + uint_type_ptr.pointer = true; + uint_type_ptr.pointer_depth++; + uint_type_ptr.parent_type = get_uint_type_id(); + uint_type_ptr.storage = StorageClassInput; + auto &ptr_type = set(type_ptr_id, uint_type_ptr); + ptr_type.self = get_uint_type_id(); + + if ((need_tesc_params && msl_options.multi_patch_workgroup) || need_grid_params) + { + uint32_t var_id = ir.increase_bound_by(1); + + // Create gl_GlobalInvocationID. + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInGlobalInvocationId); + builtin_invocation_id_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInGlobalInvocationId, var_id); + } + else if (need_tesc_params && !has_invocation_id) + { + uint32_t var_id = ir.increase_bound_by(1); + + // Create gl_InvocationID. + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInInvocationId); + builtin_invocation_id_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var_id); + } + + if ((need_tesc_params || need_tese_params) && !has_primitive_id) + { + uint32_t var_id = ir.increase_bound_by(1); + + // Create gl_PrimitiveID. + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInPrimitiveId); + builtin_primitive_id_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var_id); + } + + if (need_grid_params) + { + uint32_t var_id = ir.increase_bound_by(1); + + set(var_id, build_extended_vector_type(get_uint_type_id(), 3), StorageClassInput); + set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize); + get_entry_point().interface_variables.push_back(var_id); + set_name(var_id, "spvStageInputSize"); + builtin_stage_input_size_id = var_id; + } + } + + if (!has_subgroup_invocation_id && (need_subgroup_mask || needs_subgroup_invocation_id)) + { + uint32_t offset = ir.increase_bound_by(2); + uint32_t type_ptr_id = offset; + uint32_t var_id = offset + 1; + + // Create gl_SubgroupInvocationID. 
+ SPIRType uint_type_ptr = get_uint_type(); + uint_type_ptr.op = OpTypePointer; + uint_type_ptr.pointer = true; + uint_type_ptr.pointer_depth++; + uint_type_ptr.parent_type = get_uint_type_id(); + uint_type_ptr.storage = StorageClassInput; + auto &ptr_type = set(type_ptr_id, uint_type_ptr); + ptr_type.self = get_uint_type_id(); + + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupLocalInvocationId); + builtin_subgroup_invocation_id_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var_id); + } + + if (!has_subgroup_size && (need_subgroup_ge_mask || needs_subgroup_size)) + { + uint32_t offset = ir.increase_bound_by(2); + uint32_t type_ptr_id = offset; + uint32_t var_id = offset + 1; + + // Create gl_SubgroupSize. + SPIRType uint_type_ptr = get_uint_type(); + uint_type_ptr.op = OpTypePointer; + uint_type_ptr.pointer = true; + uint_type_ptr.pointer_depth++; + uint_type_ptr.parent_type = get_uint_type_id(); + uint_type_ptr.storage = StorageClassInput; + auto &ptr_type = set(type_ptr_id, uint_type_ptr); + ptr_type.self = get_uint_type_id(); + + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupSize); + builtin_subgroup_size_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var_id); + } + + if (need_dispatch_base || need_vertex_base_params) + { + if (workgroup_id_type == 0) + workgroup_id_type = build_extended_vector_type(get_uint_type_id(), 3); + uint32_t var_id; + if (msl_options.supports_msl_version(1, 2)) + { + // If we have MSL 1.2, we can (ab)use the [[grid_origin]] builtin + // to convey this information and save a buffer slot. + uint32_t offset = ir.increase_bound_by(1); + var_id = offset; + + set(var_id, workgroup_id_type, StorageClassInput); + set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase); + get_entry_point().interface_variables.push_back(var_id); + } + else + { + // Otherwise, we need to fall back to a good ol' fashioned buffer. + uint32_t offset = ir.increase_bound_by(2); + var_id = offset; + uint32_t type_id = offset + 1; + + SPIRType var_type = get(workgroup_id_type); + var_type.storage = StorageClassUniform; + set(type_id, var_type); + + set(var_id, type_id, StorageClassUniform); + // This should never match anything. + set_decoration(var_id, DecorationDescriptorSet, ~(5u)); + set_decoration(var_id, DecorationBinding, msl_options.indirect_params_buffer_index); + set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, + msl_options.indirect_params_buffer_index); + } + set_name(var_id, "spvDispatchBase"); + builtin_dispatch_base_id = var_id; + } + + if (has_additional_fixed_sample_mask() && !does_shader_write_sample_mask) + { + uint32_t offset = ir.increase_bound_by(2); + uint32_t var_id = offset + 1; + + // Create gl_SampleMask. 
+ SPIRType uint_type_ptr_out = get_uint_type(); + uint_type_ptr_out.op = OpTypePointer; + uint_type_ptr_out.pointer = true; + uint_type_ptr_out.pointer_depth++; + uint_type_ptr_out.parent_type = get_uint_type_id(); + uint_type_ptr_out.storage = StorageClassOutput; + + auto &ptr_out_type = set(offset, uint_type_ptr_out); + ptr_out_type.self = get_uint_type_id(); + set(var_id, offset, StorageClassOutput); + set_decoration(var_id, DecorationBuiltIn, BuiltInSampleMask); + builtin_sample_mask_id = var_id; + mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var_id); + } + + if (!has_helper_invocation && needs_helper_invocation) + { + uint32_t offset = ir.increase_bound_by(3); + uint32_t type_id = offset; + uint32_t type_ptr_id = offset + 1; + uint32_t var_id = offset + 2; + + // Create gl_HelperInvocation. + SPIRType bool_type { OpTypeBool }; + bool_type.basetype = SPIRType::Boolean; + bool_type.width = 8; + bool_type.vecsize = 1; + set(type_id, bool_type); + + SPIRType bool_type_ptr_in = bool_type; + bool_type_ptr_in.op = spv::OpTypePointer; + bool_type_ptr_in.pointer = true; + bool_type_ptr_in.pointer_depth++; + bool_type_ptr_in.parent_type = type_id; + bool_type_ptr_in.storage = StorageClassInput; + + auto &ptr_in_type = set(type_ptr_id, bool_type_ptr_in); + ptr_in_type.self = type_id; + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInHelperInvocation); + builtin_helper_invocation_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInHelperInvocation, var_id); + } + + if (need_local_invocation_index && !has_local_invocation_index) + { + uint32_t offset = ir.increase_bound_by(2); + uint32_t type_ptr_id = offset; + uint32_t var_id = offset + 1; + + // Create gl_LocalInvocationIndex. + SPIRType uint_type_ptr = get_uint_type(); + uint_type_ptr.op = OpTypePointer; + uint_type_ptr.pointer = true; + uint_type_ptr.pointer_depth++; + uint_type_ptr.parent_type = get_uint_type_id(); + uint_type_ptr.storage = StorageClassInput; + + auto &ptr_type = set(type_ptr_id, uint_type_ptr); + ptr_type.self = get_uint_type_id(); + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInLocalInvocationIndex); + builtin_local_invocation_index_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInLocalInvocationIndex, var_id); + } + + if (need_workgroup_size && !has_workgroup_size) + { + uint32_t offset = ir.increase_bound_by(2); + uint32_t type_ptr_id = offset; + uint32_t var_id = offset + 1; + + // Create gl_WorkgroupSize. 
+ uint32_t type_id = build_extended_vector_type(get_uint_type_id(), 3); + SPIRType uint_type_ptr = get(type_id); + uint_type_ptr.op = OpTypePointer; + uint_type_ptr.pointer = true; + uint_type_ptr.pointer_depth++; + uint_type_ptr.parent_type = type_id; + uint_type_ptr.storage = StorageClassInput; + + auto &ptr_type = set(type_ptr_id, uint_type_ptr); + ptr_type.self = type_id; + set(var_id, type_ptr_id, StorageClassInput); + set_decoration(var_id, DecorationBuiltIn, BuiltInWorkgroupSize); + builtin_workgroup_size_id = var_id; + mark_implicit_builtin(StorageClassInput, BuiltInWorkgroupSize, var_id); + } + + if (!has_frag_depth && force_frag_depth_passthrough) + { + uint32_t offset = ir.increase_bound_by(3); + uint32_t type_id = offset; + uint32_t type_ptr_id = offset + 1; + uint32_t var_id = offset + 2; + + // Create gl_FragDepth + SPIRType float_type { OpTypeFloat }; + float_type.basetype = SPIRType::Float; + float_type.width = 32; + float_type.vecsize = 1; + set(type_id, float_type); + + SPIRType float_type_ptr_in = float_type; + float_type_ptr_in.op = spv::OpTypePointer; + float_type_ptr_in.pointer = true; + float_type_ptr_in.pointer_depth++; + float_type_ptr_in.parent_type = type_id; + float_type_ptr_in.storage = StorageClassOutput; + + auto &ptr_in_type = set(type_ptr_id, float_type_ptr_in); + ptr_in_type.self = type_id; + set(var_id, type_ptr_id, StorageClassOutput); + set_decoration(var_id, DecorationBuiltIn, BuiltInFragDepth); + builtin_frag_depth_id = var_id; + mark_implicit_builtin(StorageClassOutput, BuiltInFragDepth, var_id); + active_output_builtins.set(BuiltInFragDepth); + } + } + + if (needs_swizzle_buffer_def) + { + uint32_t var_id = build_constant_uint_array_pointer(); + set_name(var_id, "spvSwizzleConstants"); + // This should never match anything. + set_decoration(var_id, DecorationDescriptorSet, kSwizzleBufferBinding); + set_decoration(var_id, DecorationBinding, msl_options.swizzle_buffer_index); + set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.swizzle_buffer_index); + swizzle_buffer_id = var_id; + } + + if (needs_buffer_size_buffer()) + { + uint32_t var_id = build_constant_uint_array_pointer(); + set_name(var_id, "spvBufferSizeConstants"); + // This should never match anything. + set_decoration(var_id, DecorationDescriptorSet, kBufferSizeBufferBinding); + set_decoration(var_id, DecorationBinding, msl_options.buffer_size_buffer_index); + set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.buffer_size_buffer_index); + buffer_size_buffer_id = var_id; + } + + if (needs_view_mask_buffer()) + { + uint32_t var_id = build_constant_uint_array_pointer(); + set_name(var_id, "spvViewMask"); + // This should never match anything. + set_decoration(var_id, DecorationDescriptorSet, ~(4u)); + set_decoration(var_id, DecorationBinding, msl_options.view_mask_buffer_index); + set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.view_mask_buffer_index); + view_mask_buffer_id = var_id; + } + + if (!buffers_requiring_dynamic_offset.empty()) + { + uint32_t var_id = build_constant_uint_array_pointer(); + set_name(var_id, "spvDynamicOffsets"); + // This should never match anything. 
+ set_decoration(var_id, DecorationDescriptorSet, ~(5u)); + set_decoration(var_id, DecorationBinding, msl_options.dynamic_offsets_buffer_index); + set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, + msl_options.dynamic_offsets_buffer_index); + dynamic_offsets_buffer_id = var_id; + } + + // If we're returning a struct from a vertex-like entry point, we must return a position attribute. + bool need_position = (get_execution_model() == ExecutionModelVertex || is_tese_shader()) && + !capture_output_to_buffer && !get_is_rasterization_disabled() && + !active_output_builtins.get(BuiltInPosition); + + if (need_position) + { + // If we can get away with returning void from entry point, we don't need to care. + // If there is at least one other stage output, we need to return [[position]], + // so we need to create one if it doesn't appear in the SPIR-V. Before adding the + // implicit variable, check if it actually exists already, but just has not been used + // or initialized, and if so, mark it as active, and do not create the implicit variable. + bool has_output = false; + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + if (var.storage == StorageClassOutput && interface_variable_exists_in_entry_point(var.self)) + { + has_output = true; + + // Check if the var is the Position builtin + if (has_decoration(var.self, DecorationBuiltIn) && get_decoration(var.self, DecorationBuiltIn) == BuiltInPosition) + active_output_builtins.set(BuiltInPosition); + + // If the var is a struct, check if any members is the Position builtin + auto &var_type = get_variable_element_type(var); + if (var_type.basetype == SPIRType::Struct) + { + auto mbr_cnt = var_type.member_types.size(); + for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++) + { + auto builtin = BuiltInMax; + bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin); + if (is_builtin && builtin == BuiltInPosition) + active_output_builtins.set(BuiltInPosition); + } + } + } + }); + need_position = has_output && !active_output_builtins.get(BuiltInPosition); + } + + if (need_position) + { + uint32_t offset = ir.increase_bound_by(3); + uint32_t type_id = offset; + uint32_t type_ptr_id = offset + 1; + uint32_t var_id = offset + 2; + + // Create gl_Position. + SPIRType vec4_type { OpTypeVector }; + vec4_type.basetype = SPIRType::Float; + vec4_type.width = 32; + vec4_type.vecsize = 4; + set(type_id, vec4_type); + + SPIRType vec4_type_ptr = vec4_type; + vec4_type_ptr.op = OpTypePointer; + vec4_type_ptr.pointer = true; + vec4_type_ptr.pointer_depth++; + vec4_type_ptr.parent_type = type_id; + vec4_type_ptr.storage = StorageClassOutput; + auto &ptr_type = set(type_ptr_id, vec4_type_ptr); + ptr_type.self = type_id; + + set(var_id, type_ptr_id, StorageClassOutput); + set_decoration(var_id, DecorationBuiltIn, BuiltInPosition); + mark_implicit_builtin(StorageClassOutput, BuiltInPosition, var_id); + } +} + +// Checks if the specified builtin variable (e.g. gl_InstanceIndex) is marked as active. +// If not, it marks it as active and forces a recompilation. +// This might be used when the optimization of inactive builtins was too optimistic (e.g. when "spvOut" is emitted). 
+void CompilerMSL::ensure_builtin(spv::StorageClass storage, spv::BuiltIn builtin) +{ + Bitset *active_builtins = nullptr; + switch (storage) + { + case StorageClassInput: + active_builtins = &active_input_builtins; + break; + + case StorageClassOutput: + active_builtins = &active_output_builtins; + break; + + default: + break; + } + + // At this point, the specified builtin variable must have already been declared in the entry point. + // If not, mark as active and force recompile. + if (active_builtins != nullptr && !active_builtins->get(builtin)) + { + active_builtins->set(builtin); + force_recompile(); + } +} + +void CompilerMSL::mark_implicit_builtin(StorageClass storage, BuiltIn builtin, uint32_t id) +{ + Bitset *active_builtins = nullptr; + switch (storage) + { + case StorageClassInput: + active_builtins = &active_input_builtins; + break; + + case StorageClassOutput: + active_builtins = &active_output_builtins; + break; + + default: + break; + } + + assert(active_builtins != nullptr); + active_builtins->set(builtin); + + auto &var = get_entry_point().interface_variables; + if (find(begin(var), end(var), VariableID(id)) == end(var)) + var.push_back(id); +} + +uint32_t CompilerMSL::build_constant_uint_array_pointer() +{ + uint32_t offset = ir.increase_bound_by(3); + uint32_t type_ptr_id = offset; + uint32_t type_ptr_ptr_id = offset + 1; + uint32_t var_id = offset + 2; + + // Create a buffer to hold extra data, including the swizzle constants. + SPIRType uint_type_pointer = get_uint_type(); + uint_type_pointer.op = OpTypePointer; + uint_type_pointer.pointer = true; + uint_type_pointer.pointer_depth++; + uint_type_pointer.parent_type = get_uint_type_id(); + uint_type_pointer.storage = StorageClassUniform; + set(type_ptr_id, uint_type_pointer); + set_decoration(type_ptr_id, DecorationArrayStride, 4); + + SPIRType uint_type_pointer2 = uint_type_pointer; + uint_type_pointer2.pointer_depth++; + uint_type_pointer2.parent_type = type_ptr_id; + set(type_ptr_ptr_id, uint_type_pointer2); + + set(var_id, type_ptr_ptr_id, StorageClassUniformConstant); + return var_id; +} + +static string create_sampler_address(const char *prefix, MSLSamplerAddress addr) +{ + switch (addr) + { + case MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE: + return join(prefix, "address::clamp_to_edge"); + case MSL_SAMPLER_ADDRESS_CLAMP_TO_ZERO: + return join(prefix, "address::clamp_to_zero"); + case MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER: + return join(prefix, "address::clamp_to_border"); + case MSL_SAMPLER_ADDRESS_REPEAT: + return join(prefix, "address::repeat"); + case MSL_SAMPLER_ADDRESS_MIRRORED_REPEAT: + return join(prefix, "address::mirrored_repeat"); + default: + SPIRV_CROSS_THROW("Invalid sampler addressing mode."); + } +} + +SPIRType &CompilerMSL::get_stage_in_struct_type() +{ + auto &si_var = get(stage_in_var_id); + return get_variable_data_type(si_var); +} + +SPIRType &CompilerMSL::get_stage_out_struct_type() +{ + auto &so_var = get(stage_out_var_id); + return get_variable_data_type(so_var); +} + +SPIRType &CompilerMSL::get_patch_stage_in_struct_type() +{ + auto &si_var = get(patch_stage_in_var_id); + return get_variable_data_type(si_var); +} + +SPIRType &CompilerMSL::get_patch_stage_out_struct_type() +{ + auto &so_var = get(patch_stage_out_var_id); + return get_variable_data_type(so_var); +} + +std::string CompilerMSL::get_tess_factor_struct_name() +{ + if (is_tessellating_triangles()) + return "MTLTriangleTessellationFactorsHalf"; + return "MTLQuadTessellationFactorsHalf"; +} + +SPIRType &CompilerMSL::get_uint_type() +{ + 
return get(get_uint_type_id()); +} + +uint32_t CompilerMSL::get_uint_type_id() +{ + if (uint_type_id != 0) + return uint_type_id; + + uint_type_id = ir.increase_bound_by(1); + + SPIRType type { OpTypeInt }; + type.basetype = SPIRType::UInt; + type.width = 32; + set(uint_type_id, type); + return uint_type_id; +} + +void CompilerMSL::emit_entry_point_declarations() +{ + // FIXME: Get test coverage here ... + // Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries + declare_complex_constant_arrays(); + + // Emit constexpr samplers here. + for (auto &samp : constexpr_samplers_by_id) + { + auto &var = get(samp.first); + auto &type = get(var.basetype); + if (type.basetype == SPIRType::Sampler) + add_resource_name(samp.first); + + SmallVector args; + auto &s = samp.second; + + if (s.coord != MSL_SAMPLER_COORD_NORMALIZED) + args.push_back("coord::pixel"); + + if (s.min_filter == s.mag_filter) + { + if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST) + args.push_back("filter::linear"); + } + else + { + if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST) + args.push_back("min_filter::linear"); + if (s.mag_filter != MSL_SAMPLER_FILTER_NEAREST) + args.push_back("mag_filter::linear"); + } + + switch (s.mip_filter) + { + case MSL_SAMPLER_MIP_FILTER_NONE: + // Default + break; + case MSL_SAMPLER_MIP_FILTER_NEAREST: + args.push_back("mip_filter::nearest"); + break; + case MSL_SAMPLER_MIP_FILTER_LINEAR: + args.push_back("mip_filter::linear"); + break; + default: + SPIRV_CROSS_THROW("Invalid mip filter."); + } + + if (s.s_address == s.t_address && s.s_address == s.r_address) + { + if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE) + args.push_back(create_sampler_address("", s.s_address)); + } + else + { + if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE) + args.push_back(create_sampler_address("s_", s.s_address)); + if (s.t_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE) + args.push_back(create_sampler_address("t_", s.t_address)); + if (s.r_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE) + args.push_back(create_sampler_address("r_", s.r_address)); + } + + if (s.compare_enable) + { + switch (s.compare_func) + { + case MSL_SAMPLER_COMPARE_FUNC_ALWAYS: + args.push_back("compare_func::always"); + break; + case MSL_SAMPLER_COMPARE_FUNC_NEVER: + args.push_back("compare_func::never"); + break; + case MSL_SAMPLER_COMPARE_FUNC_EQUAL: + args.push_back("compare_func::equal"); + break; + case MSL_SAMPLER_COMPARE_FUNC_NOT_EQUAL: + args.push_back("compare_func::not_equal"); + break; + case MSL_SAMPLER_COMPARE_FUNC_LESS: + args.push_back("compare_func::less"); + break; + case MSL_SAMPLER_COMPARE_FUNC_LESS_EQUAL: + args.push_back("compare_func::less_equal"); + break; + case MSL_SAMPLER_COMPARE_FUNC_GREATER: + args.push_back("compare_func::greater"); + break; + case MSL_SAMPLER_COMPARE_FUNC_GREATER_EQUAL: + args.push_back("compare_func::greater_equal"); + break; + default: + SPIRV_CROSS_THROW("Invalid sampler compare function."); + } + } + + if (s.s_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER || s.t_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER || + s.r_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER) + { + switch (s.border_color) + { + case MSL_SAMPLER_BORDER_COLOR_OPAQUE_BLACK: + args.push_back("border_color::opaque_black"); + break; + case MSL_SAMPLER_BORDER_COLOR_OPAQUE_WHITE: + args.push_back("border_color::opaque_white"); + break; + case MSL_SAMPLER_BORDER_COLOR_TRANSPARENT_BLACK: + args.push_back("border_color::transparent_black"); + break; + default: + 
SPIRV_CROSS_THROW("Invalid sampler border color."); + } + } + + if (s.anisotropy_enable) + args.push_back(join("max_anisotropy(", s.max_anisotropy, ")")); + if (s.lod_clamp_enable) + { + args.push_back(join("lod_clamp(", format_float(s.lod_clamp_min), ", ", format_float(s.lod_clamp_max), ")")); + } + + // If we would emit no arguments, then omit the parentheses entirely. Otherwise, + // we'll wind up with a "most vexing parse" situation. + if (args.empty()) + statement("constexpr sampler ", + type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first), + ";"); + else + statement("constexpr sampler ", + type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first), + "(", merge(args), ");"); + } + + // Emit dynamic buffers here. + for (auto &dynamic_buffer : buffers_requiring_dynamic_offset) + { + if (!dynamic_buffer.second.second) + { + // Could happen if no buffer was used at requested binding point. + continue; + } + + const auto &var = get(dynamic_buffer.second.second); + uint32_t var_id = var.self; + const auto &type = get_variable_data_type(var); + string name = to_name(var.self); + uint32_t desc_set = get_decoration(var.self, DecorationDescriptorSet); + uint32_t arg_id = argument_buffer_ids[desc_set]; + uint32_t base_index = dynamic_buffer.second.first; + + if (is_array(type)) + { + is_using_builtin_array = true; + statement(get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id, true), name, + type_to_array_glsl(type, var_id), " ="); + + uint32_t array_size = get_resource_array_size(type, var_id); + if (array_size == 0) + SPIRV_CROSS_THROW("Size of runtime array with dynamic offset could not be determined from resource bindings."); + + begin_scope(); + + for (uint32_t i = 0; i < array_size; i++) + { + statement("(", get_argument_address_space(var), " ", type_to_glsl(type), "* ", + to_restrict(var_id, false), ")((", get_argument_address_space(var), " char* ", + to_restrict(var_id, false), ")", to_name(arg_id), ".", ensure_valid_name(name, "m"), + "[", i, "]", " + ", to_name(dynamic_offsets_buffer_id), "[", base_index + i, "]),"); + } + + end_scope_decl(); + statement_no_indent(""); + is_using_builtin_array = false; + } + else + { + statement(get_argument_address_space(var), " auto& ", to_restrict(var_id, true), name, " = *(", + get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id, false), ")((", + get_argument_address_space(var), " char* ", to_restrict(var_id, false), ")", to_name(arg_id), ".", + ensure_valid_name(name, "m"), " + ", to_name(dynamic_offsets_buffer_id), "[", base_index, "]);"); + } + } + + bool has_runtime_array_declaration = false; + for (SPIRVariable *arg : entry_point_bindings) + { + const auto &var = *arg; + const auto &type = get_variable_data_type(var); + const auto &buffer_type = get_variable_element_type(var); + const string name = to_name(var.self); + + if (is_var_runtime_size_array(var)) + { + if (msl_options.argument_buffers_tier < Options::ArgumentBuffersTier::Tier2) + { + SPIRV_CROSS_THROW("Unsized array of descriptors requires argument buffer tier 2"); + } + + string resource_name; + if (descriptor_set_is_argument_buffer(get_decoration(var.self, DecorationDescriptorSet))) + resource_name = ir.meta[var.self].decoration.qualified_alias; + else + resource_name = name + "_"; + + switch (type.basetype) + { + case SPIRType::Image: + case SPIRType::Sampler: + case SPIRType::AccelerationStructure: + statement("spvDescriptorArray<", 
type_to_glsl(buffer_type, var.self), "> ", name, " {", resource_name, "};"); + break; + case SPIRType::SampledImage: + statement("spvDescriptorArray<", type_to_glsl(buffer_type, var.self), "> ", name, " {", resource_name, "};"); + // Unsupported with argument buffer for now. + statement("spvDescriptorArray ", name, "Smplr {", name, "Smplr_};"); + break; + case SPIRType::Struct: + statement("spvDescriptorArray<", get_argument_address_space(var), " ", type_to_glsl(buffer_type), "*> ", + name, " {", resource_name, "};"); + break; + default: + break; + } + has_runtime_array_declaration = true; + } + else if (!type.array.empty() && type.basetype == SPIRType::Struct) + { + // Emit only buffer arrays here. + statement(get_argument_address_space(var), " ", type_to_glsl(buffer_type), "* ", + to_restrict(var.self, true), name, "[] ="); + begin_scope(); + uint32_t array_size = get_resource_array_size(type, var.self); + for (uint32_t i = 0; i < array_size; ++i) + statement(name, "_", i, ","); + end_scope_decl(); + statement_no_indent(""); + } + } + + if (has_runtime_array_declaration) + statement_no_indent(""); + + // Emit buffer aliases here. + for (auto &var_id : buffer_aliases_discrete) + { + const auto &var = get(var_id); + const auto &type = get_variable_data_type(var); + auto addr_space = get_argument_address_space(var); + auto name = to_name(var_id); + + uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet); + uint32_t desc_binding = get_decoration(var_id, DecorationBinding); + auto alias_name = join("spvBufferAliasSet", desc_set, "Binding", desc_binding); + + statement(addr_space, " auto& ", to_restrict(var_id, true), + name, + " = *(", addr_space, " ", type_to_glsl(type), "*)", alias_name, ";"); + } + // Discrete descriptors are processed in entry point emission every compiler iteration. + buffer_aliases_discrete.clear(); + + for (auto &var_pair : buffer_aliases_argument) + { + uint32_t var_id = var_pair.first; + uint32_t alias_id = var_pair.second; + + const auto &var = get(var_id); + const auto &type = get_variable_data_type(var); + auto addr_space = get_argument_address_space(var); + + if (type.array.empty()) + { + statement(addr_space, " auto& ", to_restrict(var_id, true), to_name(var_id), " = (", addr_space, " ", + type_to_glsl(type), "&)", ir.meta[alias_id].decoration.qualified_alias, ";"); + } + else + { + const char *desc_addr_space = descriptor_address_space(var_id, var.storage, "thread"); + + // Esoteric type cast. Reference to array of pointers. + // Auto here defers to UBO or SSBO. The address space of the reference needs to refer to the + // address space of the argument buffer itself, which is usually constant, but can be const device for + // large argument buffers. + is_using_builtin_array = true; + statement(desc_addr_space, " auto& ", to_restrict(var_id, true), to_name(var_id), " = (", addr_space, " ", + type_to_glsl(type), "* ", desc_addr_space, " (&)", + type_to_array_glsl(type, var_id), ")", ir.meta[alias_id].decoration.qualified_alias, ";"); + is_using_builtin_array = false; + } + } + + // Emit disabled fragment outputs. 
+ std::sort(disabled_frag_outputs.begin(), disabled_frag_outputs.end()); + for (uint32_t var_id : disabled_frag_outputs) + { + auto &var = get(var_id); + add_local_variable_name(var_id); + statement(CompilerGLSL::variable_decl(var), ";"); + var.deferred_declaration = false; + } +} + +string CompilerMSL::compile() +{ + replace_illegal_entry_point_names(); + ir.fixup_reserved_names(); + + // Do not deal with GLES-isms like precision, older extensions and such. + options.vulkan_semantics = true; + options.es = false; + options.version = 450; + backend.null_pointer_literal = "nullptr"; + backend.float_literal_suffix = false; + backend.uint32_t_literal_suffix = true; + backend.int16_t_literal_suffix = ""; + backend.uint16_t_literal_suffix = ""; + backend.basic_int_type = "int"; + backend.basic_uint_type = "uint"; + backend.basic_int8_type = "char"; + backend.basic_uint8_type = "uchar"; + backend.basic_int16_type = "short"; + backend.basic_uint16_type = "ushort"; + backend.boolean_mix_function = "select"; + backend.swizzle_is_function = false; + backend.shared_is_implied = false; + backend.use_initializer_list = true; + backend.use_typed_initializer_list = true; + backend.native_row_major_matrix = false; + backend.unsized_array_supported = false; + backend.can_declare_arrays_inline = false; + backend.allow_truncated_access_chain = true; + backend.comparison_image_samples_scalar = true; + backend.native_pointers = true; + backend.nonuniform_qualifier = ""; + backend.support_small_type_sampling_result = true; + backend.supports_empty_struct = true; + backend.support_64bit_switch = true; + backend.boolean_in_struct_remapped_type = SPIRType::Short; + + // Allow Metal to use the array template unless we force it off. + backend.can_return_array = !msl_options.force_native_arrays; + backend.array_is_value_type = !msl_options.force_native_arrays; + // Arrays which are part of buffer objects are never considered to be value types (just plain C-style). + backend.array_is_value_type_in_buffer_blocks = false; + backend.support_pointer_to_pointer = true; + backend.implicit_c_integer_promotion_rules = true; + + capture_output_to_buffer = msl_options.capture_output_to_buffer; + is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer; + + // Initialize array here rather than constructor, MSVC 2013 workaround. + for (auto &id : next_metal_resource_ids) + id = 0; + + fixup_anonymous_struct_names(); + fixup_type_alias(); + replace_illegal_names(); + sync_entry_point_aliases_and_names(); + + build_function_control_flow_graphs_and_analyze(); + update_active_builtins(); + analyze_image_and_sampler_usage(); + analyze_sampled_image_usage(); + analyze_interlocked_resource_usage(); + preprocess_op_codes(); + build_implicit_builtins(); + + if (needs_manual_helper_invocation_updates() && + (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation)) + { + string builtin_helper_invocation = builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput); + string discard_expr = join(builtin_helper_invocation, " = true, discard_fragment()"); + if (msl_options.force_fragment_with_side_effects_execution) + discard_expr = join("!", builtin_helper_invocation, " ? 
(", discard_expr, ") : (void)0"); + backend.discard_literal = discard_expr; + backend.demote_literal = discard_expr; + } + else + { + backend.discard_literal = "discard_fragment()"; + backend.demote_literal = "discard_fragment()"; + } + + fixup_image_load_store_access(); + + set_enabled_interface_variables(get_active_interface_variables()); + if (msl_options.force_active_argument_buffer_resources) + activate_argument_buffer_resources(); + + if (swizzle_buffer_id) + add_active_interface_variable(swizzle_buffer_id); + if (buffer_size_buffer_id) + add_active_interface_variable(buffer_size_buffer_id); + if (view_mask_buffer_id) + add_active_interface_variable(view_mask_buffer_id); + if (dynamic_offsets_buffer_id) + add_active_interface_variable(dynamic_offsets_buffer_id); + if (builtin_layer_id) + add_active_interface_variable(builtin_layer_id); + if (builtin_dispatch_base_id && !msl_options.supports_msl_version(1, 2)) + add_active_interface_variable(builtin_dispatch_base_id); + if (builtin_sample_mask_id) + add_active_interface_variable(builtin_sample_mask_id); + if (builtin_frag_depth_id) + add_active_interface_variable(builtin_frag_depth_id); + + // Create structs to hold input, output and uniform variables. + // Do output first to ensure out. is declared at top of entry function. + qual_pos_var_name = ""; + stage_out_var_id = add_interface_block(StorageClassOutput); + patch_stage_out_var_id = add_interface_block(StorageClassOutput, true); + stage_in_var_id = add_interface_block(StorageClassInput); + if (is_tese_shader()) + patch_stage_in_var_id = add_interface_block(StorageClassInput, true); + + if (is_tesc_shader()) + stage_out_ptr_var_id = add_interface_block_pointer(stage_out_var_id, StorageClassOutput); + if (is_tessellation_shader()) + stage_in_ptr_var_id = add_interface_block_pointer(stage_in_var_id, StorageClassInput); + + // Metal vertex functions that define no output must disable rasterization and return void. + if (!stage_out_var_id) + is_rasterization_disabled = true; + + // Convert the use of global variables to recursively-passed function parameters + localize_global_variables(); + extract_global_variables_from_functions(); + + // Mark any non-stage-in structs to be tightly packed. + mark_packable_structs(); + reorder_type_alias(); + + // Add fixup hooks required by shader inputs and outputs. This needs to happen before + // the loop, so the hooks aren't added multiple times. + fix_up_shader_inputs_outputs(); + + // If we are using argument buffers, we create argument buffer structures for them here. + // These buffers will be used in the entry point, not the individual resources. + if (msl_options.argument_buffers) + { + if (!msl_options.supports_msl_version(2, 0)) + SPIRV_CROSS_THROW("Argument buffers can only be used with MSL 2.0 and up."); + analyze_argument_buffers(); + } + + uint32_t pass_count = 0; + do + { + reset(pass_count); + + // Start bindings at zero. + next_metal_resource_index_buffer = 0; + next_metal_resource_index_texture = 0; + next_metal_resource_index_sampler = 0; + for (auto &id : next_metal_resource_ids) + id = 0; + + // Move constructor for this type is broken on GCC 4.9 ... + buffer.reset(); + + emit_header(); + emit_custom_templates(); + emit_custom_functions(); + emit_specialization_constants_and_structs(); + emit_resources(); + emit_function(get(ir.default_entry_point), Bitset()); + + pass_count++; + } while (is_forcing_recompilation()); + + return buffer.str(); +} + +// Register the need to output any custom functions. 
+void CompilerMSL::preprocess_op_codes()
+{
+ OpCodePreprocessor preproc(*this);
+ traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), preproc);
+
+ suppress_missing_prototypes = preproc.suppress_missing_prototypes;
+
+ if (preproc.uses_atomics)
+ {
+ add_header_line("#include <metal_atomic>");
+ add_pragma_line("#pragma clang diagnostic ignored \"-Wunused-variable\"");
+ }
+
+ // Before MSL 2.1 (2.2 for textures), Metal vertex functions that write to
+ // resources must disable rasterization and return void.
+ if ((preproc.uses_buffer_write && !msl_options.supports_msl_version(2, 1)) ||
+ (preproc.uses_image_write && !msl_options.supports_msl_version(2, 2)))
+ is_rasterization_disabled = true;
+
+ // Tessellation control shaders are run as compute functions in Metal, and so
+ // must capture their output to a buffer.
+ if (is_tesc_shader() || (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
+ {
+ is_rasterization_disabled = true;
+ capture_output_to_buffer = true;
+ }
+
+ if (preproc.needs_subgroup_invocation_id)
+ needs_subgroup_invocation_id = true;
+ if (preproc.needs_subgroup_size)
+ needs_subgroup_size = true;
+ // build_implicit_builtins() hasn't run yet, and in fact, this needs to execute
+ // before then so that gl_SampleID will get added; so we also need to check if
+ // that function would add gl_FragCoord.
+ if (preproc.needs_sample_id || msl_options.force_sample_rate_shading ||
+ (is_sample_rate() && (active_input_builtins.get(BuiltInFragCoord) ||
+ (need_subpass_input_ms && !msl_options.use_framebuffer_fetch_subpasses))))
+ needs_sample_id = true;
+ if (preproc.needs_helper_invocation)
+ needs_helper_invocation = true;
+
+ // OpKill is removed by the parser, so we need to identify those by inspecting
+ // blocks.
+ ir.for_each_typed_id<SPIRBlock>([&preproc](uint32_t, SPIRBlock &block) {
+ if (block.terminator == SPIRBlock::Kill)
+ preproc.uses_discard = true;
+ });
+
+ // Fragment shaders that both write to storage resources and discard fragments
+ // need checks on the writes, to work around Metal allowing these writes despite
+ // the fragment being dead. We also need to force Metal to execute fragment
+ // shaders instead of discarding them prematurely.
+ if (preproc.uses_discard && (preproc.uses_buffer_write || preproc.uses_image_write))
+ {
+ bool should_enable = (msl_options.check_discarded_frag_stores || msl_options.force_fragment_with_side_effects_execution);
+ frag_shader_needs_discard_checks |= msl_options.check_discarded_frag_stores;
+ needs_helper_invocation |= should_enable;
+ // Fragment discard store checks imply manual HelperInvocation updates.
+ msl_options.manual_helper_invocation_updates |= should_enable;
+ }
+
+ if (is_intersection_query())
+ {
+ add_header_line("#if __METAL_VERSION__ >= 230");
+ add_header_line("#include <metal_raytracing>");
+ add_header_line("using namespace metal::raytracing;");
+ add_header_line("#endif");
+ }
+}
+
+// Move the Private and Workgroup global variables to the entry function.
+// Non-constant variables cannot have global scope in Metal.
+void CompilerMSL::localize_global_variables()
+{
+ auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
+ auto iter = global_variables.begin();
+ while (iter != global_variables.end())
+ {
+ uint32_t v_id = *iter;
+ auto &var = get<SPIRVariable>(v_id);
+ if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup)
+ {
+ if (!variable_is_lut(var))
+ entry_func.add_local_variable(v_id);
+ iter = global_variables.erase(iter);
+ }
+ else
+ iter++;
+ }
+}
+
+// For any global variable accessed directly by a function,
+// extract that variable and add it as an argument to that function.
+void CompilerMSL::extract_global_variables_from_functions()
+{
+ // Uniforms
+ unordered_set<uint32_t> global_var_ids;
+ ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
+ // Some builtins resolve directly to a function call which does not need any declared variables.
+ // Skip these.
+ if (var.storage == StorageClassInput && has_decoration(var.self, DecorationBuiltIn))
+ {
+ auto bi_type = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
+ if (bi_type == BuiltInHelperInvocation && !needs_manual_helper_invocation_updates())
+ return;
+ if (bi_type == BuiltInHelperInvocation && needs_manual_helper_invocation_updates())
+ {
+ if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
+ SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.3 on iOS.");
+ else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
+ SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
+ // Make sure this is declared and initialized.
+ // Force this to have the proper name.
+ set_name(var.self, builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput));
+ auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
+ entry_func.add_local_variable(var.self);
+ vars_needing_early_declaration.push_back(var.self);
+ entry_func.fixup_hooks_in.push_back([this, &var]()
+ { statement(to_name(var.self), " = simd_is_helper_thread();"); });
+ }
+ }
+
+ if (var.storage == StorageClassInput || var.storage == StorageClassOutput ||
+ var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
+ var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)
+ {
+ global_var_ids.insert(var.self);
+ }
+ });
+
+ // Local vars that are declared in the main function and accessed directly by a function
+ auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
+ for (auto &var : entry_func.local_variables)
+ if (get<SPIRVariable>(var).storage != StorageClassFunction)
+ global_var_ids.insert(var);
+
+ std::set<uint32_t> added_arg_ids;
+ unordered_set<uint32_t> processed_func_ids;
+ extract_global_variables_from_function(ir.default_entry_point, added_arg_ids, global_var_ids, processed_func_ids);
+}
+
+// MSL does not support the use of global variables for shader input content.
+// For any global variable accessed directly by the specified function, extract that variable,
+// add it as an argument to that function, and add the arg to the added_arg_ids collection.
+void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::set<uint32_t> &added_arg_ids,
+ unordered_set<uint32_t> &global_var_ids,
+ unordered_set<uint32_t> &processed_func_ids)
+{
+ // Avoid processing a function more than once
+ if (processed_func_ids.find(func_id) != processed_func_ids.end())
+ {
+ // Return function global variables
+ added_arg_ids = function_global_vars[func_id];
+ return;
+ }
+
+ processed_func_ids.insert(func_id);
+
+ auto &func = get<SPIRFunction>(func_id);
+
+ // Recursively establish global args added to functions on which we depend.
+ for (auto block : func.blocks) + { + auto &b = get(block); + for (auto &i : b.ops) + { + auto ops = stream(i); + auto op = static_cast(i.op); + + switch (op) + { + case OpLoad: + case OpInBoundsAccessChain: + case OpAccessChain: + case OpPtrAccessChain: + case OpArrayLength: + { + uint32_t base_id = ops[2]; + if (global_var_ids.find(base_id) != global_var_ids.end()) + added_arg_ids.insert(base_id); + + // Use Metal's native frame-buffer fetch API for subpass inputs. + auto &type = get(ops[0]); + if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData && + (!msl_options.use_framebuffer_fetch_subpasses)) + { + // Implicitly reads gl_FragCoord. + assert(builtin_frag_coord_id != 0); + added_arg_ids.insert(builtin_frag_coord_id); + if (msl_options.multiview) + { + // Implicitly reads gl_ViewIndex. + assert(builtin_view_idx_id != 0); + added_arg_ids.insert(builtin_view_idx_id); + } + else if (msl_options.arrayed_subpass_input) + { + // Implicitly reads gl_Layer. + assert(builtin_layer_id != 0); + added_arg_ids.insert(builtin_layer_id); + } + } + + break; + } + + case OpFunctionCall: + { + // First see if any of the function call args are globals + for (uint32_t arg_idx = 3; arg_idx < i.length; arg_idx++) + { + uint32_t arg_id = ops[arg_idx]; + if (global_var_ids.find(arg_id) != global_var_ids.end()) + added_arg_ids.insert(arg_id); + } + + // Then recurse into the function itself to extract globals used internally in the function + uint32_t inner_func_id = ops[2]; + std::set inner_func_args; + extract_global_variables_from_function(inner_func_id, inner_func_args, global_var_ids, + processed_func_ids); + added_arg_ids.insert(inner_func_args.begin(), inner_func_args.end()); + break; + } + + case OpStore: + { + uint32_t base_id = ops[0]; + if (global_var_ids.find(base_id) != global_var_ids.end()) + { + added_arg_ids.insert(base_id); + + if (msl_options.input_attachment_is_ds_attachment && base_id == builtin_frag_depth_id) + writes_to_depth = true; + } + + uint32_t rvalue_id = ops[1]; + if (global_var_ids.find(rvalue_id) != global_var_ids.end()) + added_arg_ids.insert(rvalue_id); + + if (needs_frag_discard_checks()) + added_arg_ids.insert(builtin_helper_invocation_id); + + break; + } + + case OpSelect: + { + uint32_t base_id = ops[3]; + if (global_var_ids.find(base_id) != global_var_ids.end()) + added_arg_ids.insert(base_id); + base_id = ops[4]; + if (global_var_ids.find(base_id) != global_var_ids.end()) + added_arg_ids.insert(base_id); + break; + } + + case OpAtomicExchange: + case OpAtomicCompareExchange: + case OpAtomicStore: + case OpAtomicIIncrement: + case OpAtomicIDecrement: + case OpAtomicIAdd: + case OpAtomicFAddEXT: + case OpAtomicISub: + case OpAtomicSMin: + case OpAtomicUMin: + case OpAtomicSMax: + case OpAtomicUMax: + case OpAtomicAnd: + case OpAtomicOr: + case OpAtomicXor: + case OpImageWrite: + { + if (needs_frag_discard_checks()) + added_arg_ids.insert(builtin_helper_invocation_id); + uint32_t ptr = 0; + if (op == OpAtomicStore || op == OpImageWrite) + ptr = ops[0]; + else + ptr = ops[2]; + if (global_var_ids.find(ptr) != global_var_ids.end()) + added_arg_ids.insert(ptr); + break; + } + + // Emulate texture2D atomic operations + case OpImageTexelPointer: + { + // When using the pointer, we need to know which variable it is actually loaded from. 
+ uint32_t base_id = ops[2]; + auto *var = maybe_get_backing_variable(base_id); + if (var) + { + if (atomic_image_vars_emulated.count(var->self) && + !get(var->basetype).array.empty()) + { + SPIRV_CROSS_THROW( + "Cannot emulate array of storage images with atomics. Use MSL 3.1 for native support."); + } + + if (global_var_ids.find(base_id) != global_var_ids.end()) + added_arg_ids.insert(base_id); + } + break; + } + + case OpExtInst: + { + uint32_t extension_set = ops[2]; + if (get(extension_set).ext == SPIRExtension::GLSL) + { + auto op_450 = static_cast(ops[3]); + switch (op_450) + { + case GLSLstd450InterpolateAtCentroid: + case GLSLstd450InterpolateAtSample: + case GLSLstd450InterpolateAtOffset: + { + // For these, we really need the stage-in block. It is theoretically possible to pass the + // interpolant object, but a) doing so would require us to create an entirely new variable + // with Interpolant type, and b) if we have a struct or array, handling all the members and + // elements could get unwieldy fast. + added_arg_ids.insert(stage_in_var_id); + break; + } + + case GLSLstd450Modf: + case GLSLstd450Frexp: + { + uint32_t base_id = ops[5]; + if (global_var_ids.find(base_id) != global_var_ids.end()) + added_arg_ids.insert(base_id); + break; + } + + default: + break; + } + } + break; + } + + case OpGroupNonUniformInverseBallot: + { + added_arg_ids.insert(builtin_subgroup_invocation_id_id); + break; + } + + case OpGroupNonUniformBallotFindLSB: + case OpGroupNonUniformBallotFindMSB: + { + added_arg_ids.insert(builtin_subgroup_size_id); + break; + } + + case OpGroupNonUniformBallotBitCount: + { + auto operation = static_cast(ops[3]); + switch (operation) + { + case GroupOperationReduce: + added_arg_ids.insert(builtin_subgroup_size_id); + break; + case GroupOperationInclusiveScan: + case GroupOperationExclusiveScan: + added_arg_ids.insert(builtin_subgroup_invocation_id_id); + break; + default: + break; + } + break; + } + + case OpDemoteToHelperInvocation: + if (needs_manual_helper_invocation_updates() && + (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation)) + added_arg_ids.insert(builtin_helper_invocation_id); + break; + + case OpIsHelperInvocationEXT: + if (needs_manual_helper_invocation_updates()) + added_arg_ids.insert(builtin_helper_invocation_id); + break; + + case OpRayQueryInitializeKHR: + case OpRayQueryProceedKHR: + case OpRayQueryTerminateKHR: + case OpRayQueryGenerateIntersectionKHR: + case OpRayQueryConfirmIntersectionKHR: + { + // Ray query accesses memory directly, need check pass down object if using Private storage class. 
+ uint32_t base_id = ops[0]; + if (global_var_ids.find(base_id) != global_var_ids.end()) + added_arg_ids.insert(base_id); + break; + } + + case OpRayQueryGetRayTMinKHR: + case OpRayQueryGetRayFlagsKHR: + case OpRayQueryGetWorldRayOriginKHR: + case OpRayQueryGetWorldRayDirectionKHR: + case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR: + case OpRayQueryGetIntersectionTypeKHR: + case OpRayQueryGetIntersectionTKHR: + case OpRayQueryGetIntersectionInstanceCustomIndexKHR: + case OpRayQueryGetIntersectionInstanceIdKHR: + case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR: + case OpRayQueryGetIntersectionGeometryIndexKHR: + case OpRayQueryGetIntersectionPrimitiveIndexKHR: + case OpRayQueryGetIntersectionBarycentricsKHR: + case OpRayQueryGetIntersectionFrontFaceKHR: + case OpRayQueryGetIntersectionObjectRayDirectionKHR: + case OpRayQueryGetIntersectionObjectRayOriginKHR: + case OpRayQueryGetIntersectionObjectToWorldKHR: + case OpRayQueryGetIntersectionWorldToObjectKHR: + { + // Ray query accesses memory directly, need check pass down object if using Private storage class. + uint32_t base_id = ops[2]; + if (global_var_ids.find(base_id) != global_var_ids.end()) + added_arg_ids.insert(base_id); + break; + } + + default: + break; + } + + if (needs_manual_helper_invocation_updates() && b.terminator == SPIRBlock::Kill && + (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation)) + added_arg_ids.insert(builtin_helper_invocation_id); + + // TODO: Add all other operations which can affect memory. + // We should consider a more unified system here to reduce boiler-plate. + // This kind of analysis is done in several places ... + } + } + + function_global_vars[func_id] = added_arg_ids; + + // Add the global variables as arguments to the function + if (func_id != ir.default_entry_point) + { + bool control_point_added_in = false; + bool control_point_added_out = false; + bool patch_added_in = false; + bool patch_added_out = false; + + for (uint32_t arg_id : added_arg_ids) + { + auto &var = get(arg_id); + uint32_t type_id = var.basetype; + auto *p_type = &get(type_id); + BuiltIn bi_type = BuiltIn(get_decoration(arg_id, DecorationBuiltIn)); + + bool is_patch = has_decoration(arg_id, DecorationPatch) || is_patch_block(*p_type); + bool is_block = has_decoration(p_type->self, DecorationBlock); + bool is_control_point_storage = + !is_patch && ((is_tessellation_shader() && var.storage == StorageClassInput) || + (is_tesc_shader() && var.storage == StorageClassOutput)); + bool is_patch_block_storage = is_patch && is_block && var.storage == StorageClassOutput; + bool is_builtin = is_builtin_variable(var); + bool variable_is_stage_io = + !is_builtin || bi_type == BuiltInPosition || bi_type == BuiltInPointSize || + bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance || + p_type->basetype == SPIRType::Struct; + bool is_redirected_to_global_stage_io = (is_control_point_storage || is_patch_block_storage) && + variable_is_stage_io; + + // If output is masked it is not considered part of the global stage IO interface. + if (is_redirected_to_global_stage_io && var.storage == StorageClassOutput) + is_redirected_to_global_stage_io = !is_stage_output_variable_masked(var); + + if (is_redirected_to_global_stage_io) + { + // Tessellation control shaders see inputs and per-point outputs as arrays. + // Similarly, tessellation evaluation shaders see per-point inputs as arrays. 
+ // We collected them into a structure; we must pass the array of this + // structure to the function. + std::string name; + if (is_patch) + name = var.storage == StorageClassInput ? patch_stage_in_var_name : patch_stage_out_var_name; + else + name = var.storage == StorageClassInput ? "gl_in" : "gl_out"; + + if (var.storage == StorageClassOutput && has_decoration(p_type->self, DecorationBlock)) + { + // If we're redirecting a block, we might still need to access the original block + // variable if we're masking some members. + for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(p_type->member_types.size()); mbr_idx++) + { + if (is_stage_output_block_member_masked(var, mbr_idx, true)) + { + func.add_parameter(var.basetype, var.self, true); + break; + } + } + } + + if (var.storage == StorageClassInput) + { + auto &added_in = is_patch ? patch_added_in : control_point_added_in; + if (added_in) + continue; + arg_id = is_patch ? patch_stage_in_var_id : stage_in_ptr_var_id; + added_in = true; + } + else if (var.storage == StorageClassOutput) + { + auto &added_out = is_patch ? patch_added_out : control_point_added_out; + if (added_out) + continue; + arg_id = is_patch ? patch_stage_out_var_id : stage_out_ptr_var_id; + added_out = true; + } + + type_id = get(arg_id).basetype; + uint32_t next_id = ir.increase_bound_by(1); + func.add_parameter(type_id, next_id, true); + set(next_id, type_id, StorageClassFunction, 0, arg_id); + + set_name(next_id, name); + if (is_tese_shader() && msl_options.raw_buffer_tese_input && var.storage == StorageClassInput) + set_decoration(next_id, DecorationNonWritable); + } + else if (is_builtin && has_decoration(p_type->self, DecorationBlock)) + { + // Get the pointee type + type_id = get_pointee_type_id(type_id); + p_type = &get(type_id); + + uint32_t mbr_idx = 0; + for (auto &mbr_type_id : p_type->member_types) + { + BuiltIn builtin = BuiltInMax; + is_builtin = is_member_builtin(*p_type, mbr_idx, &builtin); + if (is_builtin && has_active_builtin(builtin, var.storage)) + { + // Add a arg variable with the same type and decorations as the member + uint32_t next_ids = ir.increase_bound_by(2); + uint32_t ptr_type_id = next_ids + 0; + uint32_t var_id = next_ids + 1; + + // Make sure we have an actual pointer type, + // so that we will get the appropriate address space when declaring these builtins. + auto &ptr = set(ptr_type_id, get(mbr_type_id)); + ptr.self = mbr_type_id; + ptr.storage = var.storage; + ptr.pointer = true; + ptr.pointer_depth++; + ptr.parent_type = mbr_type_id; + + func.add_parameter(mbr_type_id, var_id, true); + set(var_id, ptr_type_id, StorageClassFunction); + ir.meta[var_id].decoration = ir.meta[type_id].members[mbr_idx]; + } + mbr_idx++; + } + } + else + { + uint32_t next_id = ir.increase_bound_by(1); + func.add_parameter(type_id, next_id, true); + set(next_id, type_id, StorageClassFunction, 0, arg_id); + + // Ensure the new variable has all the same meta info + ir.meta[next_id] = ir.meta[arg_id]; + } + } + } +} + +// For all variables that are some form of non-input-output interface block, mark that all the structs +// that are recursively contained within the type referenced by that variable should be packed tightly. 
+void CompilerMSL::mark_packable_structs()
+{
+ ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
+ if (var.storage != StorageClassFunction && !is_hidden_variable(var))
+ {
+ auto &type = this->get<SPIRType>(var.basetype);
+ if (type.pointer &&
+ (type.storage == StorageClassUniform || type.storage == StorageClassUniformConstant ||
+ type.storage == StorageClassPushConstant || type.storage == StorageClassStorageBuffer) &&
+ (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
+ mark_as_packable(type);
+ }
+
+ if (var.storage == StorageClassWorkgroup)
+ {
+ auto *type = &this->get<SPIRType>(var.basetype);
+ if (type->basetype == SPIRType::Struct)
+ mark_as_workgroup_struct(*type);
+ }
+ });
+
+ // Physical storage buffer pointers can appear outside of the context of a variable, if the address
+ // is calculated from a ulong or uvec2 and cast to a pointer, so check if they need to be packed too.
+ ir.for_each_typed_id<SPIRType>([&](uint32_t, SPIRType &type) {
+ if (type.basetype == SPIRType::Struct && type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
+ mark_as_packable(type);
+ });
+}
+
+// If the specified type is a struct, it and any nested structs
+// are marked as packable with the SPIRVCrossDecorationBufferBlockRepacked decoration.
+void CompilerMSL::mark_as_packable(SPIRType &type)
+{
+ // If this is not the base type (eg. it's a pointer or array), tunnel down
+ if (type.parent_type)
+ {
+ mark_as_packable(get<SPIRType>(type.parent_type));
+ return;
+ }
+
+ // Handle possible recursion when a struct contains a pointer to its own type nested somewhere.
+ if (type.basetype == SPIRType::Struct && !has_extended_decoration(type.self, SPIRVCrossDecorationBufferBlockRepacked))
+ {
+ set_extended_decoration(type.self, SPIRVCrossDecorationBufferBlockRepacked);
+
+ // Recurse
+ uint32_t mbr_cnt = uint32_t(type.member_types.size());
+ for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
+ {
+ uint32_t mbr_type_id = type.member_types[mbr_idx];
+ auto &mbr_type = get<SPIRType>(mbr_type_id);
+ mark_as_packable(mbr_type);
+ if (mbr_type.type_alias)
+ {
+ auto &mbr_type_alias = get<SPIRType>(mbr_type.type_alias);
+ mark_as_packable(mbr_type_alias);
+ }
+ }
+ }
+}
+
+// If the specified type is a struct, it and any nested structs
+// are marked as used with workgroup storage using the SPIRVCrossDecorationWorkgroupStruct decoration.
+void CompilerMSL::mark_as_workgroup_struct(SPIRType &type)
+{
+ // If this is not the base type (eg. it's a pointer or array), tunnel down
+ if (type.parent_type)
+ {
+ mark_as_workgroup_struct(get<SPIRType>(type.parent_type));
+ return;
+ }
+
+ // Handle possible recursion when a struct contains a pointer to its own type nested somewhere.
+ if (type.basetype == SPIRType::Struct && !has_extended_decoration(type.self, SPIRVCrossDecorationWorkgroupStruct)) + { + set_extended_decoration(type.self, SPIRVCrossDecorationWorkgroupStruct); + + // Recurse + uint32_t mbr_cnt = uint32_t(type.member_types.size()); + for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++) + { + uint32_t mbr_type_id = type.member_types[mbr_idx]; + auto &mbr_type = get(mbr_type_id); + mark_as_workgroup_struct(mbr_type); + if (mbr_type.type_alias) + { + auto &mbr_type_alias = get(mbr_type.type_alias); + mark_as_workgroup_struct(mbr_type_alias); + } + } + } +} + +// If a shader input exists at the location, it is marked as being used by this shader +void CompilerMSL::mark_location_as_used_by_shader(uint32_t location, const SPIRType &type, + StorageClass storage, bool fallback) +{ + uint32_t count = type_to_location_count(type); + switch (storage) + { + case StorageClassInput: + for (uint32_t i = 0; i < count; i++) + { + location_inputs_in_use.insert(location + i); + if (fallback) + location_inputs_in_use_fallback.insert(location + i); + } + break; + case StorageClassOutput: + for (uint32_t i = 0; i < count; i++) + { + location_outputs_in_use.insert(location + i); + if (fallback) + location_outputs_in_use_fallback.insert(location + i); + } + break; + default: + return; + } +} + +uint32_t CompilerMSL::get_target_components_for_fragment_location(uint32_t location) const +{ + auto itr = fragment_output_components.find(location); + if (itr == end(fragment_output_components)) + return 4; + else + return itr->second; +} + +uint32_t CompilerMSL::build_extended_vector_type(uint32_t type_id, uint32_t components, SPIRType::BaseType basetype) +{ + assert(components > 1); + uint32_t new_type_id = ir.increase_bound_by(1); + const auto *p_old_type = &get(type_id); + const SPIRType *old_ptr_t = nullptr; + const SPIRType *old_array_t = nullptr; + + if (is_pointer(*p_old_type)) + { + old_ptr_t = p_old_type; + p_old_type = &get_pointee_type(*old_ptr_t); + } + + if (is_array(*p_old_type)) + { + old_array_t = p_old_type; + p_old_type = &get_type(old_array_t->parent_type); + } + + auto *type = &set(new_type_id, *p_old_type); + assert(is_scalar(*type) || is_vector(*type)); + type->op = OpTypeVector; + type->vecsize = components; + if (basetype != SPIRType::Unknown) + type->basetype = basetype; + type->self = new_type_id; + // We want parent type to point to the scalar type. + type->parent_type = is_scalar(*p_old_type) ? 
TypeID(p_old_type->self) : p_old_type->parent_type; + assert(is_scalar(get(type->parent_type))); + type->array.clear(); + type->array_size_literal.clear(); + type->pointer = false; + + if (old_array_t) + { + uint32_t array_type_id = ir.increase_bound_by(1); + type = &set(array_type_id, *type); + type->op = OpTypeArray; + type->parent_type = new_type_id; + type->array = old_array_t->array; + type->array_size_literal = old_array_t->array_size_literal; + new_type_id = array_type_id; + } + + if (old_ptr_t) + { + uint32_t ptr_type_id = ir.increase_bound_by(1); + type = &set(ptr_type_id, *type); + type->op = OpTypePointer; + type->parent_type = new_type_id; + type->storage = old_ptr_t->storage; + type->pointer = true; + type->pointer_depth++; + new_type_id = ptr_type_id; + } + + return new_type_id; +} + +uint32_t CompilerMSL::build_msl_interpolant_type(uint32_t type_id, bool is_noperspective) +{ + uint32_t new_type_id = ir.increase_bound_by(1); + SPIRType &type = set(new_type_id, get(type_id)); + type.basetype = SPIRType::Interpolant; + type.parent_type = type_id; + // In Metal, the pull-model interpolant type encodes perspective-vs-no-perspective in the type itself. + // Add this decoration so we know which argument to pass to the template. + if (is_noperspective) + set_decoration(new_type_id, DecorationNoPerspective); + return new_type_id; +} + +bool CompilerMSL::add_component_variable_to_interface_block(spv::StorageClass storage, const std::string &ib_var_ref, + SPIRVariable &var, + const SPIRType &type, + InterfaceBlockMeta &meta) +{ + // Deal with Component decorations. + const InterfaceBlockMeta::LocationMeta *location_meta = nullptr; + uint32_t location = ~0u; + if (has_decoration(var.self, DecorationLocation)) + { + location = get_decoration(var.self, DecorationLocation); + auto location_meta_itr = meta.location_meta.find(location); + if (location_meta_itr != end(meta.location_meta)) + location_meta = &location_meta_itr->second; + } + + // Check if we need to pad fragment output to match a certain number of components. + if (location_meta) + { + bool pad_fragment_output = has_decoration(var.self, DecorationLocation) && + msl_options.pad_fragment_output_components && + get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput; + + auto &entry_func = get(ir.default_entry_point); + uint32_t start_component = get_decoration(var.self, DecorationComponent); + uint32_t type_components = type.vecsize; + uint32_t num_components = location_meta->num_components; + + if (pad_fragment_output) + { + uint32_t locn = get_decoration(var.self, DecorationLocation); + num_components = max(num_components, get_target_components_for_fragment_location(locn)); + } + + // We have already declared an IO block member as m_location_N. + // Just emit an early-declared variable and fixup as needed. + // Arrays need to be unrolled here since each location might need a different number of components. 
+ entry_func.add_local_variable(var.self); + vars_needing_early_declaration.push_back(var.self); + + if (var.storage == StorageClassInput) + { + entry_func.fixup_hooks_in.push_back([=, &type, &var]() { + if (!type.array.empty()) + { + uint32_t array_size = to_array_size_literal(type); + for (uint32_t loc_off = 0; loc_off < array_size; loc_off++) + { + statement(to_name(var.self), "[", loc_off, "]", " = ", ib_var_ref, + ".m_location_", location + loc_off, + vector_swizzle(type_components, start_component), ";"); + } + } + else + { + statement(to_name(var.self), " = ", ib_var_ref, ".m_location_", location, + vector_swizzle(type_components, start_component), ";"); + } + }); + } + else + { + entry_func.fixup_hooks_out.push_back([=, &type, &var]() { + if (!type.array.empty()) + { + uint32_t array_size = to_array_size_literal(type); + for (uint32_t loc_off = 0; loc_off < array_size; loc_off++) + { + statement(ib_var_ref, ".m_location_", location + loc_off, + vector_swizzle(type_components, start_component), " = ", + to_name(var.self), "[", loc_off, "];"); + } + } + else + { + statement(ib_var_ref, ".m_location_", location, + vector_swizzle(type_components, start_component), " = ", to_name(var.self), ";"); + } + }); + } + return true; + } + else + return false; +} + +void CompilerMSL::add_plain_variable_to_interface_block(StorageClass storage, const string &ib_var_ref, + SPIRType &ib_type, SPIRVariable &var, InterfaceBlockMeta &meta) +{ + bool is_builtin = is_builtin_variable(var); + BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn)); + bool is_flat = has_decoration(var.self, DecorationFlat); + bool is_noperspective = has_decoration(var.self, DecorationNoPerspective); + bool is_centroid = has_decoration(var.self, DecorationCentroid); + bool is_sample = has_decoration(var.self, DecorationSample); + + // Add a reference to the variable type to the interface struct. + uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size()); + uint32_t type_id = ensure_correct_builtin_type(var.basetype, builtin); + var.basetype = type_id; + + type_id = get_pointee_type_id(var.basetype); + if (meta.strip_array && is_array(get(type_id))) + type_id = get(type_id).parent_type; + auto &type = get(type_id); + uint32_t target_components = 0; + uint32_t type_components = type.vecsize; + + bool padded_output = false; + bool padded_input = false; + uint32_t start_component = 0; + + auto &entry_func = get(ir.default_entry_point); + + if (add_component_variable_to_interface_block(storage, ib_var_ref, var, type, meta)) + return; + + bool pad_fragment_output = has_decoration(var.self, DecorationLocation) && + msl_options.pad_fragment_output_components && + get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput; + + if (pad_fragment_output) + { + uint32_t locn = get_decoration(var.self, DecorationLocation); + target_components = get_target_components_for_fragment_location(locn); + if (type_components < target_components) + { + // Make a new type here. 
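+				// For instance, a float2 colour output at a location the pipeline treats as
+				// four components wide is widened to a float4 member here; the output fixup
+				// then stores only the original components through a swizzle, roughly
+				//     out.foo.xy = foo;   (names illustrative).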
+ type_id = build_extended_vector_type(type_id, target_components); + padded_output = true; + } + } + + if (storage == StorageClassInput && pull_model_inputs.count(var.self)) + ib_type.member_types.push_back(build_msl_interpolant_type(type_id, is_noperspective)); + else + ib_type.member_types.push_back(type_id); + + // Give the member a name + string mbr_name = ensure_valid_name(to_expression(var.self), "m"); + set_member_name(ib_type.self, ib_mbr_idx, mbr_name); + + // Update the original variable reference to include the structure reference + string qual_var_name = ib_var_ref + "." + mbr_name; + // If using pull-model interpolation, need to add a call to the correct interpolation method. + if (storage == StorageClassInput && pull_model_inputs.count(var.self)) + { + if (is_centroid) + qual_var_name += ".interpolate_at_centroid()"; + else if (is_sample) + qual_var_name += join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")"); + else + qual_var_name += ".interpolate_at_center()"; + } + + if (padded_output || padded_input) + { + entry_func.add_local_variable(var.self); + vars_needing_early_declaration.push_back(var.self); + + if (padded_output) + { + entry_func.fixup_hooks_out.push_back([=, &var]() { + statement(qual_var_name, vector_swizzle(type_components, start_component), " = ", to_name(var.self), + ";"); + }); + } + else + { + entry_func.fixup_hooks_in.push_back([=, &var]() { + statement(to_name(var.self), " = ", qual_var_name, vector_swizzle(type_components, start_component), + ";"); + }); + } + } + else if (!meta.strip_array) + ir.meta[var.self].decoration.qualified_alias = qual_var_name; + + if (var.storage == StorageClassOutput && var.initializer != ID(0)) + { + if (padded_output || padded_input) + { + entry_func.fixup_hooks_in.push_back( + [=, &var]() { statement(to_name(var.self), " = ", to_expression(var.initializer), ";"); }); + } + else + { + if (meta.strip_array) + { + entry_func.fixup_hooks_in.push_back([=, &var]() { + uint32_t index = get_extended_decoration(var.self, SPIRVCrossDecorationInterfaceMemberIndex); + auto invocation = to_tesc_invocation_id(); + statement(to_expression(stage_out_ptr_var_id), "[", + invocation, "].", + to_member_name(ib_type, index), " = ", to_expression(var.initializer), "[", + invocation, "];"); + }); + } + else + { + entry_func.fixup_hooks_in.push_back([=, &var]() { + statement(qual_var_name, " = ", to_expression(var.initializer), ";"); + }); + } + } + } + + // Copy the variable location from the original variable to the member + if (get_decoration_bitset(var.self).get(DecorationLocation)) + { + uint32_t locn = get_decoration(var.self, DecorationLocation); + uint32_t comp = get_decoration(var.self, DecorationComponent); + if (storage == StorageClassInput) + { + type_id = ensure_correct_input_type(var.basetype, locn, comp, 0, meta.strip_array); + var.basetype = type_id; + + type_id = get_pointee_type_id(type_id); + if (meta.strip_array && is_array(get(type_id))) + type_id = get(type_id).parent_type; + if (pull_model_inputs.count(var.self)) + ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(type_id, is_noperspective); + else + ib_type.member_types[ib_mbr_idx] = type_id; + } + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn); + if (comp) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp); + mark_location_as_used_by_shader(locn, get(type_id), storage); + } + else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && 
inputs_by_builtin.count(builtin)) + { + uint32_t locn = inputs_by_builtin[builtin].location; + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn); + mark_location_as_used_by_shader(locn, type, storage); + } + else if (is_builtin && capture_output_to_buffer && storage == StorageClassOutput && outputs_by_builtin.count(builtin)) + { + uint32_t locn = outputs_by_builtin[builtin].location; + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn); + mark_location_as_used_by_shader(locn, type, storage); + } + + if (get_decoration_bitset(var.self).get(DecorationComponent)) + { + uint32_t component = get_decoration(var.self, DecorationComponent); + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, component); + } + + if (get_decoration_bitset(var.self).get(DecorationIndex)) + { + uint32_t index = get_decoration(var.self, DecorationIndex); + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index); + } + + // Mark the member as builtin if needed + if (is_builtin) + { + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin); + if (builtin == BuiltInPosition && storage == StorageClassOutput) + qual_pos_var_name = qual_var_name; + } + + // Copy interpolation decorations if needed + if (storage != StorageClassInput || !pull_model_inputs.count(var.self)) + { + if (is_flat) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat); + if (is_noperspective) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective); + if (is_centroid) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid); + if (is_sample) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample); + } + + set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self); +} + +void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage, const string &ib_var_ref, + SPIRType &ib_type, SPIRVariable &var, + InterfaceBlockMeta &meta) +{ + auto &entry_func = get(ir.default_entry_point); + auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var); + uint32_t elem_cnt = 0; + + if (add_component_variable_to_interface_block(storage, ib_var_ref, var, var_type, meta)) + return; + + if (is_matrix(var_type)) + { + if (is_array(var_type)) + SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables."); + + elem_cnt = var_type.columns; + } + else if (is_array(var_type)) + { + if (var_type.array.size() != 1) + SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables."); + + elem_cnt = to_array_size_literal(var_type); + } + + bool is_builtin = is_builtin_variable(var); + BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn)); + bool is_flat = has_decoration(var.self, DecorationFlat); + bool is_noperspective = has_decoration(var.self, DecorationNoPerspective); + bool is_centroid = has_decoration(var.self, DecorationCentroid); + bool is_sample = has_decoration(var.self, DecorationSample); + + auto *usable_type = &var_type; + if (usable_type->pointer) + usable_type = &get(usable_type->parent_type); + while (is_array(*usable_type) || is_matrix(*usable_type)) + usable_type = &get(usable_type->parent_type); + + // If a builtin, force it to have the proper name. 
+ if (is_builtin) + set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction)); + + bool flatten_from_ib_var = false; + string flatten_from_ib_mbr_name; + + if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance) + { + // Also declare [[clip_distance]] attribute here. + uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size()); + ib_type.member_types.push_back(get_variable_data_type_id(var)); + set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance); + + flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput); + set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name); + + // When we flatten, we flatten directly from the "out" struct, + // not from a function variable. + flatten_from_ib_var = true; + + if (!msl_options.enable_clip_distance_user_varying) + return; + } + else if (!meta.strip_array) + { + // Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped. + entry_func.add_local_variable(var.self); + // We need to declare the variable early and at entry-point scope. + vars_needing_early_declaration.push_back(var.self); + } + + for (uint32_t i = 0; i < elem_cnt; i++) + { + // Add a reference to the variable type to the interface struct. + uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size()); + + uint32_t target_components = 0; + bool padded_output = false; + uint32_t type_id = usable_type->self; + + // Check if we need to pad fragment output to match a certain number of components. + if (get_decoration_bitset(var.self).get(DecorationLocation) && msl_options.pad_fragment_output_components && + get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput) + { + uint32_t locn = get_decoration(var.self, DecorationLocation) + i; + target_components = get_target_components_for_fragment_location(locn); + if (usable_type->vecsize < target_components) + { + // Make a new type here. + type_id = build_extended_vector_type(usable_type->self, target_components); + padded_output = true; + } + } + + if (storage == StorageClassInput && pull_model_inputs.count(var.self)) + ib_type.member_types.push_back(build_msl_interpolant_type(get_pointee_type_id(type_id), is_noperspective)); + else + ib_type.member_types.push_back(get_pointee_type_id(type_id)); + + // Give the member a name + string mbr_name = ensure_valid_name(join(to_expression(var.self), "_", i), "m"); + set_member_name(ib_type.self, ib_mbr_idx, mbr_name); + + // There is no qualified alias since we need to flatten the internal array on return. 
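+		// As an illustration, a float4 out_color[2] output is split into two members,
+		// out_color_0 and out_color_1, at consecutive locations and flattened on return:
+		//     out.out_color_0 = out_color[0]; out.out_color_1 = out_color[1];
+		// (matrices are split into their columns the same way; names illustrative).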
+ if (get_decoration_bitset(var.self).get(DecorationLocation)) + { + uint32_t locn = get_decoration(var.self, DecorationLocation) + i; + uint32_t comp = get_decoration(var.self, DecorationComponent); + if (storage == StorageClassInput) + { + var.basetype = ensure_correct_input_type(var.basetype, locn, comp, 0, meta.strip_array); + uint32_t mbr_type_id = ensure_correct_input_type(usable_type->self, locn, comp, 0, meta.strip_array); + if (storage == StorageClassInput && pull_model_inputs.count(var.self)) + ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective); + else + ib_type.member_types[ib_mbr_idx] = mbr_type_id; + } + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn); + if (comp) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp); + mark_location_as_used_by_shader(locn, *usable_type, storage); + } + else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(builtin)) + { + uint32_t locn = inputs_by_builtin[builtin].location + i; + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn); + mark_location_as_used_by_shader(locn, *usable_type, storage); + } + else if (is_builtin && capture_output_to_buffer && storage == StorageClassOutput && outputs_by_builtin.count(builtin)) + { + uint32_t locn = outputs_by_builtin[builtin].location + i; + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn); + mark_location_as_used_by_shader(locn, *usable_type, storage); + } + else if (is_builtin && (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance)) + { + // Declare the Clip/CullDistance as [[user(clip/cullN)]]. + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin); + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, i); + } + + if (get_decoration_bitset(var.self).get(DecorationIndex)) + { + uint32_t index = get_decoration(var.self, DecorationIndex); + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index); + } + + if (storage != StorageClassInput || !pull_model_inputs.count(var.self)) + { + // Copy interpolation decorations if needed + if (is_flat) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat); + if (is_noperspective) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective); + if (is_centroid) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid); + if (is_sample) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample); + } + + set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self); + + // Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped. 
+ if (!meta.strip_array) + { + switch (storage) + { + case StorageClassInput: + entry_func.fixup_hooks_in.push_back([=, &var]() { + if (pull_model_inputs.count(var.self)) + { + string lerp_call; + if (is_centroid) + lerp_call = ".interpolate_at_centroid()"; + else if (is_sample) + lerp_call = join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")"); + else + lerp_call = ".interpolate_at_center()"; + statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, lerp_call, ";"); + } + else + { + statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, ";"); + } + }); + break; + + case StorageClassOutput: + entry_func.fixup_hooks_out.push_back([=, &var]() { + if (padded_output) + { + auto &padded_type = this->get(type_id); + statement( + ib_var_ref, ".", mbr_name, " = ", + remap_swizzle(padded_type, usable_type->vecsize, join(to_name(var.self), "[", i, "]")), + ";"); + } + else if (flatten_from_ib_var) + statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i, + "];"); + else + statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), "[", i, "];"); + }); + break; + + default: + break; + } + } + } +} + +void CompilerMSL::add_composite_member_variable_to_interface_block(StorageClass storage, + const string &ib_var_ref, SPIRType &ib_type, + SPIRVariable &var, SPIRType &var_type, + uint32_t mbr_idx, InterfaceBlockMeta &meta, + const string &mbr_name_qual, + const string &var_chain_qual, + uint32_t &location, uint32_t &var_mbr_idx, + const Bitset &interpolation_qual) +{ + auto &entry_func = get(ir.default_entry_point); + + BuiltIn builtin = BuiltInMax; + bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin); + bool is_flat = interpolation_qual.get(DecorationFlat) || + has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || + has_decoration(var.self, DecorationFlat); + bool is_noperspective = interpolation_qual.get(DecorationNoPerspective) || + has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) || + has_decoration(var.self, DecorationNoPerspective); + bool is_centroid = interpolation_qual.get(DecorationCentroid) || + has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) || + has_decoration(var.self, DecorationCentroid); + bool is_sample = interpolation_qual.get(DecorationSample) || + has_member_decoration(var_type.self, mbr_idx, DecorationSample) || + has_decoration(var.self, DecorationSample); + + Bitset inherited_qual; + if (is_flat) + inherited_qual.set(DecorationFlat); + if (is_noperspective) + inherited_qual.set(DecorationNoPerspective); + if (is_centroid) + inherited_qual.set(DecorationCentroid); + if (is_sample) + inherited_qual.set(DecorationSample); + + uint32_t mbr_type_id = var_type.member_types[mbr_idx]; + auto &mbr_type = get(mbr_type_id); + + bool mbr_is_indexable = false; + uint32_t elem_cnt = 1; + if (is_matrix(mbr_type)) + { + if (is_array(mbr_type)) + SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables."); + + mbr_is_indexable = true; + elem_cnt = mbr_type.columns; + } + else if (is_array(mbr_type)) + { + if (mbr_type.array.size() != 1) + SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables."); + + mbr_is_indexable = true; + elem_cnt = to_array_size_literal(mbr_type); + } + + auto *usable_type = &mbr_type; + if (usable_type->pointer) + usable_type = &get(usable_type->parent_type); + while (is_array(*usable_type) || is_matrix(*usable_type)) + usable_type = 
&get(usable_type->parent_type); + + bool flatten_from_ib_var = false; + string flatten_from_ib_mbr_name; + + if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance) + { + // Also declare [[clip_distance]] attribute here. + uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size()); + ib_type.member_types.push_back(mbr_type_id); + set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance); + + flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput); + set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name); + + // When we flatten, we flatten directly from the "out" struct, + // not from a function variable. + flatten_from_ib_var = true; + + if (!msl_options.enable_clip_distance_user_varying) + return; + } + + // Recursively handle nested structures. + if (mbr_type.basetype == SPIRType::Struct) + { + for (uint32_t i = 0; i < elem_cnt; i++) + { + string mbr_name = append_member_name(mbr_name_qual, var_type, mbr_idx) + (mbr_is_indexable ? join("_", i) : ""); + string var_chain = join(var_chain_qual, ".", to_member_name(var_type, mbr_idx), (mbr_is_indexable ? join("[", i, "]") : "")); + uint32_t sub_mbr_cnt = uint32_t(mbr_type.member_types.size()); + for (uint32_t sub_mbr_idx = 0; sub_mbr_idx < sub_mbr_cnt; sub_mbr_idx++) + { + add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type, + var, mbr_type, sub_mbr_idx, + meta, mbr_name, var_chain, + location, var_mbr_idx, inherited_qual); + // FIXME: Recursive structs and tessellation breaks here. + var_mbr_idx++; + } + } + return; + } + + for (uint32_t i = 0; i < elem_cnt; i++) + { + // Add a reference to the variable type to the interface struct. + uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size()); + if (storage == StorageClassInput && pull_model_inputs.count(var.self)) + ib_type.member_types.push_back(build_msl_interpolant_type(usable_type->self, is_noperspective)); + else + ib_type.member_types.push_back(usable_type->self); + + // Give the member a name + string mbr_name = ensure_valid_name(append_member_name(mbr_name_qual, var_type, mbr_idx) + (mbr_is_indexable ? join("_", i) : ""), "m"); + set_member_name(ib_type.self, ib_mbr_idx, mbr_name); + + // Once we determine the location of the first member within nested structures, + // from a var of the topmost structure, the remaining flattened members of + // the nested structures will have consecutive location values. At this point, + // we've recursively tunnelled into structs, arrays, and matrices, and are + // down to a single location for each member now. 
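+		// For example, if the first flattened leaf of a nested struct is assigned
+		// location 4, the remaining leaves simply take 5, 6, ... in declaration order
+		// rather than each carrying its own explicit Location decoration.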
+ if (!is_builtin && location != UINT32_MAX) + { + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, location); + mark_location_as_used_by_shader(location, *usable_type, storage); + location++; + } + else if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation)) + { + location = get_member_decoration(var_type.self, mbr_idx, DecorationLocation) + i; + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, location); + mark_location_as_used_by_shader(location, *usable_type, storage); + location++; + } + else if (has_decoration(var.self, DecorationLocation)) + { + location = get_accumulated_member_location(var, mbr_idx, meta.strip_array) + i; + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, location); + mark_location_as_used_by_shader(location, *usable_type, storage); + location++; + } + else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(builtin)) + { + location = inputs_by_builtin[builtin].location + i; + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, location); + mark_location_as_used_by_shader(location, *usable_type, storage); + location++; + } + else if (is_builtin && capture_output_to_buffer && storage == StorageClassOutput && outputs_by_builtin.count(builtin)) + { + location = outputs_by_builtin[builtin].location + i; + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, location); + mark_location_as_used_by_shader(location, *usable_type, storage); + location++; + } + else if (is_builtin && (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance)) + { + // Declare the Clip/CullDistance as [[user(clip/cullN)]]. + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin); + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, i); + } + + if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent)) + SPIRV_CROSS_THROW("DecorationComponent on matrices and arrays is not supported."); + + if (storage != StorageClassInput || !pull_model_inputs.count(var.self)) + { + // Copy interpolation decorations if needed + if (is_flat) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat); + if (is_noperspective) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective); + if (is_centroid) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid); + if (is_sample) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample); + } + + set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self); + set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, var_mbr_idx); + + // Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate. + if (!meta.strip_array && meta.allow_local_declaration) + { + string var_chain = join(var_chain_qual, ".", to_member_name(var_type, mbr_idx), (mbr_is_indexable ? 
join("[", i, "]") : "")); + switch (storage) + { + case StorageClassInput: + entry_func.fixup_hooks_in.push_back([=, &var]() { + string lerp_call; + if (pull_model_inputs.count(var.self)) + { + if (is_centroid) + lerp_call = ".interpolate_at_centroid()"; + else if (is_sample) + lerp_call = join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")"); + else + lerp_call = ".interpolate_at_center()"; + } + statement(var_chain, " = ", ib_var_ref, ".", mbr_name, lerp_call, ";"); + }); + break; + + case StorageClassOutput: + entry_func.fixup_hooks_out.push_back([=]() { + if (flatten_from_ib_var) + statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i, "];"); + else + statement(ib_var_ref, ".", mbr_name, " = ", var_chain, ";"); + }); + break; + + default: + break; + } + } + } +} + +void CompilerMSL::add_plain_member_variable_to_interface_block(StorageClass storage, + const string &ib_var_ref, SPIRType &ib_type, + SPIRVariable &var, SPIRType &var_type, + uint32_t mbr_idx, InterfaceBlockMeta &meta, + const string &mbr_name_qual, + const string &var_chain_qual, + uint32_t &location, uint32_t &var_mbr_idx) +{ + auto &entry_func = get(ir.default_entry_point); + + BuiltIn builtin = BuiltInMax; + bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin); + bool is_flat = + has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat); + bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) || + has_decoration(var.self, DecorationNoPerspective); + bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) || + has_decoration(var.self, DecorationCentroid); + bool is_sample = + has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample); + + // Add a reference to the member to the interface struct. + uint32_t mbr_type_id = var_type.member_types[mbr_idx]; + uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size()); + mbr_type_id = ensure_correct_builtin_type(mbr_type_id, builtin); + var_type.member_types[mbr_idx] = mbr_type_id; + if (storage == StorageClassInput && pull_model_inputs.count(var.self)) + ib_type.member_types.push_back(build_msl_interpolant_type(mbr_type_id, is_noperspective)); + else + ib_type.member_types.push_back(mbr_type_id); + + // Give the member a name + string mbr_name = ensure_valid_name(append_member_name(mbr_name_qual, var_type, mbr_idx), "m"); + set_member_name(ib_type.self, ib_mbr_idx, mbr_name); + + // Update the original variable reference to include the structure reference + string qual_var_name = ib_var_ref + "." + mbr_name; + // If using pull-model interpolation, need to add a call to the correct interpolation method. + if (storage == StorageClassInput && pull_model_inputs.count(var.self)) + { + if (is_centroid) + qual_var_name += ".interpolate_at_centroid()"; + else if (is_sample) + qual_var_name += join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")"); + else + qual_var_name += ".interpolate_at_center()"; + } + + bool flatten_stage_out = false; + string var_chain = var_chain_qual + "." + to_member_name(var_type, mbr_idx); + if (is_builtin && !meta.strip_array) + { + // For the builtin gl_PerVertex, we cannot treat it as a block anyways, + // so redirect to qualified name. 
+ set_member_qualified_name(var_type.self, mbr_idx, qual_var_name); + } + else if (!meta.strip_array && meta.allow_local_declaration) + { + // Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate. + switch (storage) + { + case StorageClassInput: + entry_func.fixup_hooks_in.push_back([=]() { + statement(var_chain, " = ", qual_var_name, ";"); + }); + break; + + case StorageClassOutput: + flatten_stage_out = true; + entry_func.fixup_hooks_out.push_back([=]() { + statement(qual_var_name, " = ", var_chain, ";"); + }); + break; + + default: + break; + } + } + + // Once we determine the location of the first member within nested structures, + // from a var of the topmost structure, the remaining flattened members of + // the nested structures will have consecutive location values. At this point, + // we've recursively tunnelled into structs, arrays, and matrices, and are + // down to a single location for each member now. + if (!is_builtin && location != UINT32_MAX) + { + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, location); + mark_location_as_used_by_shader(location, get(mbr_type_id), storage); + location += type_to_location_count(get(mbr_type_id)); + } + else if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation)) + { + location = get_member_decoration(var_type.self, mbr_idx, DecorationLocation); + uint32_t comp = get_member_decoration(var_type.self, mbr_idx, DecorationComponent); + if (storage == StorageClassInput) + { + mbr_type_id = ensure_correct_input_type(mbr_type_id, location, comp, 0, meta.strip_array); + var_type.member_types[mbr_idx] = mbr_type_id; + if (storage == StorageClassInput && pull_model_inputs.count(var.self)) + ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective); + else + ib_type.member_types[ib_mbr_idx] = mbr_type_id; + } + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, location); + mark_location_as_used_by_shader(location, get(mbr_type_id), storage); + location += type_to_location_count(get(mbr_type_id)); + } + else if (has_decoration(var.self, DecorationLocation)) + { + location = get_accumulated_member_location(var, mbr_idx, meta.strip_array); + if (storage == StorageClassInput) + { + mbr_type_id = ensure_correct_input_type(mbr_type_id, location, 0, 0, meta.strip_array); + var_type.member_types[mbr_idx] = mbr_type_id; + if (storage == StorageClassInput && pull_model_inputs.count(var.self)) + ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective); + else + ib_type.member_types[ib_mbr_idx] = mbr_type_id; + } + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, location); + mark_location_as_used_by_shader(location, get(mbr_type_id), storage); + location += type_to_location_count(get(mbr_type_id)); + } + else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(builtin)) + { + location = inputs_by_builtin[builtin].location; + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, location); + mark_location_as_used_by_shader(location, get(mbr_type_id), storage); + location += type_to_location_count(get(mbr_type_id)); + } + else if (is_builtin && capture_output_to_buffer && storage == StorageClassOutput && outputs_by_builtin.count(builtin)) + { + location = outputs_by_builtin[builtin].location; + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, location); + mark_location_as_used_by_shader(location, get(mbr_type_id), 
storage); + location += type_to_location_count(get(mbr_type_id)); + } + + // Copy the component location, if present. + if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent)) + { + uint32_t comp = get_member_decoration(var_type.self, mbr_idx, DecorationComponent); + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp); + } + + // Mark the member as builtin if needed + if (is_builtin) + { + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin); + if (builtin == BuiltInPosition && storage == StorageClassOutput) + qual_pos_var_name = qual_var_name; + } + + const SPIRConstant *c = nullptr; + if (!flatten_stage_out && var.storage == StorageClassOutput && + var.initializer != ID(0) && (c = maybe_get(var.initializer))) + { + if (meta.strip_array) + { + entry_func.fixup_hooks_in.push_back([=, &var]() { + auto &type = this->get(var.basetype); + uint32_t index = get_extended_member_decoration(var.self, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex); + + auto invocation = to_tesc_invocation_id(); + auto constant_chain = join(to_expression(var.initializer), "[", invocation, "]"); + statement(to_expression(stage_out_ptr_var_id), "[", + invocation, "].", + to_member_name(ib_type, index), " = ", + constant_chain, ".", to_member_name(type, mbr_idx), ";"); + }); + } + else + { + entry_func.fixup_hooks_in.push_back([=]() { + statement(qual_var_name, " = ", constant_expression( + this->get(c->subconstants[mbr_idx])), ";"); + }); + } + } + + if (storage != StorageClassInput || !pull_model_inputs.count(var.self)) + { + // Copy interpolation decorations if needed + if (is_flat) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat); + if (is_noperspective) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective); + if (is_centroid) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid); + if (is_sample) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample); + } + + set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self); + set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, var_mbr_idx); +} + +// In Metal, the tessellation levels are stored as tightly packed half-precision floating point values. +// But, stage-in attribute offsets and strides must be multiples of four, so we can't pass the levels +// individually. Therefore, we must pass them as vectors. Triangles get a single float4, with the outer +// levels in 'xyz' and the inner level in 'w'. Quads get a float4 containing the outer levels and a +// float2 containing the inner levels. +void CompilerMSL::add_tess_level_input_to_interface_block(const std::string &ib_var_ref, SPIRType &ib_type, + SPIRVariable &var) +{ + auto &var_type = get_variable_element_type(var); + + BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn)); + bool triangles = is_tessellating_triangles(); + string mbr_name; + + // Add a reference to the variable type to the interface struct. 
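+	// Concretely, for triangle domains this tends to surface as a single stage-in member
+	// along the lines of
+	//     float4 gl_TessLevel [[attribute(N)]];
+	// whose xyz components are copied into gl_TessLevelOuter[0..2] and whose w component
+	// into gl_TessLevelInner[0], while quad domains get separate 4- and 2-component
+	// members for the outer and inner levels (N here is illustrative).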
+ uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size()); + + const auto mark_locations = [&](const SPIRType &new_var_type) { + if (get_decoration_bitset(var.self).get(DecorationLocation)) + { + uint32_t locn = get_decoration(var.self, DecorationLocation); + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn); + mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput); + } + else if (inputs_by_builtin.count(builtin)) + { + uint32_t locn = inputs_by_builtin[builtin].location; + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn); + mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput); + } + }; + + if (triangles) + { + // Triangles are tricky, because we want only one member in the struct. + mbr_name = "gl_TessLevel"; + + // If we already added the other one, we can skip this step. + if (!added_builtin_tess_level) + { + uint32_t type_id = build_extended_vector_type(var_type.self, 4); + + ib_type.member_types.push_back(type_id); + + // Give the member a name + set_member_name(ib_type.self, ib_mbr_idx, mbr_name); + + // We cannot decorate both, but the important part is that + // it's marked as builtin so we can get automatic attribute assignment if needed. + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin); + + mark_locations(var_type); + added_builtin_tess_level = true; + } + } + else + { + mbr_name = builtin_to_glsl(builtin, StorageClassFunction); + + uint32_t type_id = build_extended_vector_type(var_type.self, builtin == BuiltInTessLevelOuter ? 4 : 2); + + uint32_t ptr_type_id = ir.increase_bound_by(1); + auto &new_var_type = set(ptr_type_id, get(type_id)); + new_var_type.pointer = true; + new_var_type.pointer_depth++; + new_var_type.storage = StorageClassInput; + new_var_type.parent_type = type_id; + + ib_type.member_types.push_back(type_id); + + // Give the member a name + set_member_name(ib_type.self, ib_mbr_idx, mbr_name); + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin); + + mark_locations(new_var_type); + } + + add_tess_level_input(ib_var_ref, mbr_name, var); +} + +void CompilerMSL::add_tess_level_input(const std::string &base_ref, const std::string &mbr_name, SPIRVariable &var) +{ + auto &entry_func = get(ir.default_entry_point); + BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn)); + + // Force the variable to have the proper name. + string var_name = builtin_to_glsl(builtin, StorageClassFunction); + set_name(var.self, var_name); + + // We need to declare the variable early and at entry-point scope. 
+ entry_func.add_local_variable(var.self); + vars_needing_early_declaration.push_back(var.self); + bool triangles = is_tessellating_triangles(); + + if (builtin == BuiltInTessLevelOuter) + { + entry_func.fixup_hooks_in.push_back( + [=]() + { + statement(var_name, "[0] = ", base_ref, ".", mbr_name, "[0];"); + statement(var_name, "[1] = ", base_ref, ".", mbr_name, "[1];"); + statement(var_name, "[2] = ", base_ref, ".", mbr_name, "[2];"); + if (!triangles) + statement(var_name, "[3] = ", base_ref, ".", mbr_name, "[3];"); + }); + } + else + { + entry_func.fixup_hooks_in.push_back([=]() { + if (triangles) + { + if (msl_options.raw_buffer_tese_input) + statement(var_name, "[0] = ", base_ref, ".", mbr_name, ";"); + else + statement(var_name, "[0] = ", base_ref, ".", mbr_name, "[3];"); + } + else + { + statement(var_name, "[0] = ", base_ref, ".", mbr_name, "[0];"); + statement(var_name, "[1] = ", base_ref, ".", mbr_name, "[1];"); + } + }); + } +} + +bool CompilerMSL::variable_storage_requires_stage_io(spv::StorageClass storage) const +{ + if (storage == StorageClassOutput) + return !capture_output_to_buffer; + else if (storage == StorageClassInput) + return !(is_tesc_shader() && msl_options.multi_patch_workgroup) && + !(is_tese_shader() && msl_options.raw_buffer_tese_input); + else + return false; +} + +string CompilerMSL::to_tesc_invocation_id() +{ + if (msl_options.multi_patch_workgroup) + { + // n.b. builtin_invocation_id_id here is the dispatch global invocation ID, + // not the TC invocation ID. + return join(to_expression(builtin_invocation_id_id), ".x % ", get_entry_point().output_vertices); + } + else + return builtin_to_glsl(BuiltInInvocationId, StorageClassInput); +} + +void CompilerMSL::emit_local_masked_variable(const SPIRVariable &masked_var, bool strip_array) +{ + auto &entry_func = get(ir.default_entry_point); + bool threadgroup_storage = variable_decl_is_remapped_storage(masked_var, StorageClassWorkgroup); + + if (threadgroup_storage && msl_options.multi_patch_workgroup) + { + // We need one threadgroup block per patch, so fake this. + entry_func.fixup_hooks_in.push_back([this, &masked_var]() { + auto &type = get_variable_data_type(masked_var); + add_local_variable_name(masked_var.self); + + const uint32_t max_control_points_per_patch = 32u; + uint32_t max_num_instances = + (max_control_points_per_patch + get_entry_point().output_vertices - 1u) / + get_entry_point().output_vertices; + statement("threadgroup ", type_to_glsl(type), " ", + "spvStorage", to_name(masked_var.self), "[", max_num_instances, "]", + type_to_array_glsl(type, 0), ";"); + + // Assign a threadgroup slice to each PrimitiveID. + // We assume here that workgroup size is rounded to 32, + // since that's the maximum number of control points per patch. + // We cannot size the array based on fixed dispatch parameters, + // since Metal does not allow that. :( + // FIXME: We will likely need an option to support passing down target workgroup size, + // so we can emit appropriate size here. + statement("threadgroup auto ", + "&", to_name(masked_var.self), + " = spvStorage", to_name(masked_var.self), "[", + "(", to_expression(builtin_invocation_id_id), ".x / ", + get_entry_point().output_vertices, ") % ", + max_num_instances, "];"); + }); + } + else + { + entry_func.add_local_variable(masked_var.self); + } + + if (!threadgroup_storage) + { + vars_needing_early_declaration.push_back(masked_var.self); + } + else if (masked_var.initializer) + { + // Cannot directly initialize threadgroup variables. Need fixup hooks. 
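+		// e.g. for a per-control-point (strip_array) variable this emits, inside the
+		// entry point, something like
+		//     var[gl_InvocationID] = init[gl_InvocationID];
+		// and a plain  var = init;  otherwise (with the invocation index remapped in
+		// multi-patch mode; names illustrative).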
+ ID initializer = masked_var.initializer; + if (strip_array) + { + entry_func.fixup_hooks_in.push_back([this, &masked_var, initializer]() { + auto invocation = to_tesc_invocation_id(); + statement(to_expression(masked_var.self), "[", + invocation, "] = ", + to_expression(initializer), "[", + invocation, "];"); + }); + } + else + { + entry_func.fixup_hooks_in.push_back([this, &masked_var, initializer]() { + statement(to_expression(masked_var.self), " = ", to_expression(initializer), ";"); + }); + } + } +} + +void CompilerMSL::add_variable_to_interface_block(StorageClass storage, const string &ib_var_ref, SPIRType &ib_type, + SPIRVariable &var, InterfaceBlockMeta &meta) +{ + auto &entry_func = get(ir.default_entry_point); + // Tessellation control I/O variables and tessellation evaluation per-point inputs are + // usually declared as arrays. In these cases, we want to add the element type to the + // interface block, since in Metal it's the interface block itself which is arrayed. + auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var); + bool is_builtin = is_builtin_variable(var); + auto builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn)); + bool is_block = has_decoration(var_type.self, DecorationBlock); + + // If stage variables are masked out, emit them as plain variables instead. + // For builtins, we query them one by one later. + // IO blocks are not masked here, we need to mask them per-member instead. + if (storage == StorageClassOutput && is_stage_output_variable_masked(var)) + { + // If we ignore an output, we must still emit it, since it might be used by app. + // Instead, just emit it as early declaration. + emit_local_masked_variable(var, meta.strip_array); + return; + } + + if (storage == StorageClassInput && has_decoration(var.self, DecorationPerVertexKHR)) + SPIRV_CROSS_THROW("PerVertexKHR decoration is not supported in MSL."); + + // If variable names alias, they will end up with wrong names in the interface struct, because + // there might be aliases in the member name cache and there would be a mismatch in fixup_in code. + // Make sure to register the variables as unique resource names ahead of time. + // This would normally conflict with the name cache when emitting local variables, + // but this happens in the setup stage, before we hit compilation loops. + // The name cache is cleared before we actually emit code, so this is safe. + add_resource_name(var.self); + + if (var_type.basetype == SPIRType::Struct) + { + bool block_requires_flattening = + variable_storage_requires_stage_io(storage) || (is_block && var_type.array.empty()); + bool needs_local_declaration = !is_builtin && block_requires_flattening && meta.allow_local_declaration; + + if (needs_local_declaration) + { + // For I/O blocks or structs, we will need to pass the block itself around + // to functions if they are used globally in leaf functions. + // Rather than passing down member by member, + // we unflatten I/O blocks while running the shader, + // and pass the actual struct type down to leaf functions. + // We then unflatten inputs, and flatten outputs in the "fixup" stages. + emit_local_masked_variable(var, meta.strip_array); + } + + if (!block_requires_flattening) + { + // In Metal tessellation shaders, the interface block itself is arrayed. This makes things + // very complicated, since stage-in structures in MSL don't support nested structures. 
+ // Luckily, for stage-out when capturing output, we can avoid this and just add + // composite members directly, because the stage-out structure is stored to a buffer, + // not returned. + add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta); + } + else + { + bool masked_block = false; + uint32_t location = UINT32_MAX; + uint32_t var_mbr_idx = 0; + uint32_t elem_cnt = 1; + if (is_matrix(var_type)) + { + if (is_array(var_type)) + SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables."); + + elem_cnt = var_type.columns; + } + else if (is_array(var_type)) + { + if (var_type.array.size() != 1) + SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables."); + + elem_cnt = to_array_size_literal(var_type); + } + + for (uint32_t elem_idx = 0; elem_idx < elem_cnt; elem_idx++) + { + // Flatten the struct members into the interface struct + for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++) + { + builtin = BuiltInMax; + is_builtin = is_member_builtin(var_type, mbr_idx, &builtin); + auto &mbr_type = get(var_type.member_types[mbr_idx]); + + if (storage == StorageClassOutput && is_stage_output_block_member_masked(var, mbr_idx, meta.strip_array)) + { + location = UINT32_MAX; // Skip this member and resolve location again on next var member + + if (is_block) + masked_block = true; + + // Non-builtin block output variables are just ignored, since they will still access + // the block variable as-is. They're just not flattened. + if (is_builtin && !meta.strip_array) + { + // Emit a fake variable instead. + uint32_t ids = ir.increase_bound_by(2); + uint32_t ptr_type_id = ids + 0; + uint32_t var_id = ids + 1; + + auto ptr_type = mbr_type; + ptr_type.pointer = true; + ptr_type.pointer_depth++; + ptr_type.parent_type = var_type.member_types[mbr_idx]; + ptr_type.storage = StorageClassOutput; + + uint32_t initializer = 0; + if (var.initializer) + if (auto *c = maybe_get(var.initializer)) + initializer = c->subconstants[mbr_idx]; + + set(ptr_type_id, ptr_type); + set(var_id, ptr_type_id, StorageClassOutput, initializer); + entry_func.add_local_variable(var_id); + vars_needing_early_declaration.push_back(var_id); + set_name(var_id, builtin_to_glsl(builtin, StorageClassOutput)); + set_decoration(var_id, DecorationBuiltIn, builtin); + } + } + else if (!is_builtin || has_active_builtin(builtin, storage)) + { + bool is_composite_type = is_matrix(mbr_type) || is_array(mbr_type) || mbr_type.basetype == SPIRType::Struct; + bool attribute_load_store = + storage == StorageClassInput && get_execution_model() != ExecutionModelFragment; + bool storage_is_stage_io = variable_storage_requires_stage_io(storage); + + // Clip/CullDistance always need to be declared as user attributes. 
+ if (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance) + is_builtin = false; + + const string var_name = to_name(var.self); + string mbr_name_qual = var_name; + string var_chain_qual = var_name; + if (elem_cnt > 1) + { + mbr_name_qual += join("_", elem_idx); + var_chain_qual += join("[", elem_idx, "]"); + } + + if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type) + { + add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type, + var, var_type, mbr_idx, meta, + mbr_name_qual, var_chain_qual, + location, var_mbr_idx, {}); + } + else + { + add_plain_member_variable_to_interface_block(storage, ib_var_ref, ib_type, + var, var_type, mbr_idx, meta, + mbr_name_qual, var_chain_qual, + location, var_mbr_idx); + } + } + var_mbr_idx++; + } + } + + // If we're redirecting a block, we might still need to access the original block + // variable if we're masking some members. + if (masked_block && !needs_local_declaration && (!is_builtin_variable(var) || is_tesc_shader())) + { + if (is_builtin_variable(var)) + { + // Ensure correct names for the block members if we're actually going to + // declare gl_PerVertex. + for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++) + { + set_member_name(var_type.self, mbr_idx, builtin_to_glsl( + BuiltIn(get_member_decoration(var_type.self, mbr_idx, DecorationBuiltIn)), + StorageClassOutput)); + } + + set_name(var_type.self, "gl_PerVertex"); + set_name(var.self, "gl_out_masked"); + stage_out_masked_builtin_type_id = var_type.self; + } + emit_local_masked_variable(var, meta.strip_array); + } + } + } + else if (is_tese_shader() && storage == StorageClassInput && !meta.strip_array && is_builtin && + (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner)) + { + add_tess_level_input_to_interface_block(ib_var_ref, ib_type, var); + } + else if (var_type.basetype == SPIRType::Boolean || var_type.basetype == SPIRType::Char || + type_is_integral(var_type) || type_is_floating_point(var_type)) + { + if (!is_builtin || has_active_builtin(builtin, storage)) + { + bool is_composite_type = is_matrix(var_type) || is_array(var_type); + bool storage_is_stage_io = variable_storage_requires_stage_io(storage); + bool attribute_load_store = storage == StorageClassInput && get_execution_model() != ExecutionModelFragment; + + // Clip/CullDistance always needs to be declared as user attributes. + if (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance) + is_builtin = false; + + // MSL does not allow matrices or arrays in input or output variables, so need to handle it specially. + if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type) + { + add_composite_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta); + } + else + { + add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta); + } + } + } +} + +// Fix up the mapping of variables to interface member indices, which is used to compile access chains +// for per-vertex variables in a tessellation control shader. +void CompilerMSL::fix_up_interface_member_indices(StorageClass storage, uint32_t ib_type_id) +{ + // Only needed for tessellation shaders and pull-model interpolants. + // Need to redirect interface indices back to variables themselves. + // For structs, each member of the struct need a separate instance. 
+ if (!is_tesc_shader() && !(is_tese_shader() && storage == StorageClassInput) && + !(get_execution_model() == ExecutionModelFragment && storage == StorageClassInput && + !pull_model_inputs.empty())) + return; + + auto mbr_cnt = uint32_t(ir.meta[ib_type_id].members.size()); + for (uint32_t i = 0; i < mbr_cnt; i++) + { + uint32_t var_id = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceOrigID); + if (!var_id) + continue; + auto &var = get(var_id); + + auto &type = get_variable_element_type(var); + + bool flatten_composites = variable_storage_requires_stage_io(var.storage); + bool is_block = has_decoration(type.self, DecorationBlock); + + uint32_t mbr_idx = uint32_t(-1); + if (type.basetype == SPIRType::Struct && (flatten_composites || is_block)) + mbr_idx = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceMemberIndex); + + if (mbr_idx != uint32_t(-1)) + { + // Only set the lowest InterfaceMemberIndex for each variable member. + // IB struct members will be emitted in-order w.r.t. interface member index. + if (!has_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex)) + set_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, i); + } + else + { + // Only set the lowest InterfaceMemberIndex for each variable. + // IB struct members will be emitted in-order w.r.t. interface member index. + if (!has_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex)) + set_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex, i); + } + } +} + +// Add an interface structure for the type of storage, which is either StorageClassInput or StorageClassOutput. +// Returns the ID of the newly added variable, or zero if no variable was added. +uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch) +{ + // Accumulate the variables that should appear in the interface struct. + SmallVector vars; + bool incl_builtins = storage == StorageClassOutput || is_tessellation_shader(); + bool has_seen_barycentric = false; + + InterfaceBlockMeta meta; + + // Varying interfaces between stages which use "user()" attribute can be dealt with + // without explicit packing and unpacking of components. For any variables which link against the runtime + // in some way (vertex attributes, fragment output, etc), we'll need to deal with it somehow. 
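+	// For example, two vertex inputs that alias location 0 via the Component decoration
+	// (say at components 0 and 2) are packed into a single 4-component [[attribute(0)]]
+	// member and swizzled back into the individual variables in the entry-point fixups,
+	// whereas stage-to-stage varyings can keep one user() member per variable.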
+ bool pack_components = + (storage == StorageClassInput && get_execution_model() == ExecutionModelVertex) || + (storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment) || + (storage == StorageClassOutput && get_execution_model() == ExecutionModelVertex && capture_output_to_buffer); + + ir.for_each_typed_id([&](uint32_t var_id, SPIRVariable &var) { + if (var.storage != storage) + return; + + auto &type = this->get(var.basetype); + + bool is_builtin = is_builtin_variable(var); + bool is_block = has_decoration(type.self, DecorationBlock); + + auto bi_type = BuiltInMax; + bool builtin_is_gl_in_out = false; + if (is_builtin && !is_block) + { + bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn)); + builtin_is_gl_in_out = bi_type == BuiltInPosition || bi_type == BuiltInPointSize || + bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance; + } + + if (is_builtin && is_block) + builtin_is_gl_in_out = true; + + uint32_t location = get_decoration(var_id, DecorationLocation); + + bool builtin_is_stage_in_out = builtin_is_gl_in_out || + bi_type == BuiltInLayer || bi_type == BuiltInViewportIndex || + bi_type == BuiltInBaryCoordKHR || bi_type == BuiltInBaryCoordNoPerspKHR || + bi_type == BuiltInFragDepth || + bi_type == BuiltInFragStencilRefEXT || bi_type == BuiltInSampleMask; + + // These builtins are part of the stage in/out structs. + bool is_interface_block_builtin = + builtin_is_stage_in_out || (is_tese_shader() && !msl_options.raw_buffer_tese_input && + (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner)); + + bool is_active = interface_variable_exists_in_entry_point(var.self); + if (is_builtin && is_active) + { + // Only emit the builtin if it's active in this entry point. Interface variable list might lie. + if (is_block) + { + // If any builtin is active, the block is active. + uint32_t mbr_cnt = uint32_t(type.member_types.size()); + for (uint32_t i = 0; !is_active && i < mbr_cnt; i++) + is_active = has_active_builtin(BuiltIn(get_member_decoration(type.self, i, DecorationBuiltIn)), storage); + } + else + { + is_active = has_active_builtin(bi_type, storage); + } + } + + bool filter_patch_decoration = (has_decoration(var_id, DecorationPatch) || is_patch_block(type)) == patch; + + bool hidden = is_hidden_variable(var, incl_builtins); + + // ClipDistance is never hidden, we need to emulate it when used as an input. + if (bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance) + hidden = false; + + // It's not enough to simply avoid marking fragment outputs if the pipeline won't + // accept them. We can't put them in the struct at all, or otherwise the compiler + // complains that the outputs weren't explicitly marked. + // Frag depth and stencil outputs are incompatible with explicit early fragment tests. + // In GLSL, depth and stencil outputs are just ignored when explicit early fragment tests are required. + // In Metal, it's a compilation error, so we need to exclude them from the output struct. 
+ if (get_execution_model() == ExecutionModelFragment && storage == StorageClassOutput && !patch && + ((is_builtin && ((bi_type == BuiltInFragDepth && (!msl_options.enable_frag_depth_builtin || uses_explicit_early_fragment_test())) || + (bi_type == BuiltInFragStencilRefEXT && (!msl_options.enable_frag_stencil_ref_builtin || uses_explicit_early_fragment_test())))) || + (!is_builtin && !(msl_options.enable_frag_output_mask & (1 << location))))) + { + hidden = true; + disabled_frag_outputs.push_back(var_id); + // If a builtin, force it to have the proper name, and mark it as not part of the output struct. + if (is_builtin) + { + set_name(var_id, builtin_to_glsl(bi_type, StorageClassFunction)); + mask_stage_output_by_builtin(bi_type); + } + } + + // Barycentric inputs must be emitted in stage-in, because they can have interpolation arguments. + if (is_active && (bi_type == BuiltInBaryCoordKHR || bi_type == BuiltInBaryCoordNoPerspKHR)) + { + if (has_seen_barycentric) + SPIRV_CROSS_THROW("Cannot declare both BaryCoordNV and BaryCoordNoPerspNV in same shader in MSL."); + has_seen_barycentric = true; + hidden = false; + } + + if (is_active && !hidden && type.pointer && filter_patch_decoration && + (!is_builtin || is_interface_block_builtin)) + { + vars.push_back(&var); + + if (!is_builtin) + { + // Need to deal specially with DecorationComponent. + // Multiple variables can alias the same Location, and try to make sure each location is declared only once. + // We will swizzle data in and out to make this work. + // This is only relevant for vertex inputs and fragment outputs. + // Technically tessellation as well, but it is too complicated to support. + uint32_t component = get_decoration(var_id, DecorationComponent); + if (component != 0) + { + if (is_tessellation_shader()) + SPIRV_CROSS_THROW("Component decoration is not supported in tessellation shaders."); + else if (pack_components) + { + uint32_t array_size = 1; + if (!type.array.empty()) + array_size = to_array_size_literal(type); + + for (uint32_t location_offset = 0; location_offset < array_size; location_offset++) + { + auto &location_meta = meta.location_meta[location + location_offset]; + location_meta.num_components = max(location_meta.num_components, component + type.vecsize); + + // For variables sharing location, decorations and base type must match. + location_meta.base_type_id = type.self; + location_meta.flat = has_decoration(var.self, DecorationFlat); + location_meta.noperspective = has_decoration(var.self, DecorationNoPerspective); + location_meta.centroid = has_decoration(var.self, DecorationCentroid); + location_meta.sample = has_decoration(var.self, DecorationSample); + } + } + } + } + } + + if (is_tese_shader() && msl_options.raw_buffer_tese_input && patch && storage == StorageClassInput && + (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner)) + { + // In this case, we won't add the builtin to the interface struct, + // but we still need the hook to run to populate the arrays. + string base_ref = join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id), "]"); + const char *mbr_name = + bi_type == BuiltInTessLevelOuter ? "edgeTessellationFactor" : "insideTessellationFactor"; + add_tess_level_input(base_ref, mbr_name, var); + if (inputs_by_builtin.count(bi_type)) + { + uint32_t locn = inputs_by_builtin[bi_type].location; + mark_location_as_used_by_shader(locn, type, StorageClassInput); + } + } + }); + + // If no variables qualify, leave. 
+ // For patch input in a tessellation evaluation shader, the per-vertex stage inputs + // are included in a special patch control point array. + if (vars.empty() && + !(!msl_options.raw_buffer_tese_input && storage == StorageClassInput && patch && stage_in_var_id)) + return 0; + + // Add a new typed variable for this interface structure. + // The initializer expression is allocated here, but populated when the function + // declaraion is emitted, because it is cleared after each compilation pass. + uint32_t next_id = ir.increase_bound_by(3); + uint32_t ib_type_id = next_id++; + auto &ib_type = set(ib_type_id, OpTypeStruct); + ib_type.basetype = SPIRType::Struct; + ib_type.storage = storage; + set_decoration(ib_type_id, DecorationBlock); + + uint32_t ib_var_id = next_id++; + auto &var = set(ib_var_id, ib_type_id, storage, 0); + var.initializer = next_id++; + + string ib_var_ref; + auto &entry_func = get(ir.default_entry_point); + switch (storage) + { + case StorageClassInput: + ib_var_ref = patch ? patch_stage_in_var_name : stage_in_var_name; + switch (get_execution_model()) + { + case ExecutionModelTessellationControl: + // Add a hook to populate the shared workgroup memory containing the gl_in array. + entry_func.fixup_hooks_in.push_back([=]() { + // Can't use PatchVertices, PrimitiveId, or InvocationId yet; the hooks for those may not have run yet. + if (msl_options.multi_patch_workgroup) + { + // n.b. builtin_invocation_id_id here is the dispatch global invocation ID, + // not the TC invocation ID. + statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_in = &", + input_buffer_var_name, "[min(", to_expression(builtin_invocation_id_id), ".x / ", + get_entry_point().output_vertices, + ", spvIndirectParams[1] - 1) * spvIndirectParams[0]];"); + } + else + { + // It's safe to use InvocationId here because it's directly mapped to a + // Metal builtin, and therefore doesn't need a hook. + statement("if (", to_expression(builtin_invocation_id_id), " < spvIndirectParams[0])"); + statement(" ", input_wg_var_name, "[", to_expression(builtin_invocation_id_id), + "] = ", ib_var_ref, ";"); + statement("threadgroup_barrier(mem_flags::mem_threadgroup);"); + statement("if (", to_expression(builtin_invocation_id_id), + " >= ", get_entry_point().output_vertices, ")"); + statement(" return;"); + } + }); + break; + case ExecutionModelTessellationEvaluation: + if (!msl_options.raw_buffer_tese_input) + break; + if (patch) + { + entry_func.fixup_hooks_in.push_back( + [=]() + { + statement("const device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref, + " = ", patch_input_buffer_var_name, "[", to_expression(builtin_primitive_id_id), + "];"); + }); + } + else + { + entry_func.fixup_hooks_in.push_back( + [=]() + { + statement("const device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_in = &", + input_buffer_var_name, "[", to_expression(builtin_primitive_id_id), " * ", + get_entry_point().output_vertices, "];"); + }); + } + break; + default: + break; + } + break; + + case StorageClassOutput: + { + ib_var_ref = patch ? patch_stage_out_var_name : stage_out_var_name; + + // Add the output interface struct as a local variable to the entry function. + // If the entry point should return the output struct, set the entry function + // to return the output interface struct, otherwise to return nothing. + // Watch out for the rare case where the terminator of the last entry point block is a + // Kill, instead of a Return. 
Based on SPIR-V's block-domination rules, we assume that + // any block that has a Kill will also have a terminating Return, except the last block. + // Indicate the output var requires early initialization. + bool ep_should_return_output = !get_is_rasterization_disabled(); + uint32_t rtn_id = ep_should_return_output ? ib_var_id : 0; + if (!capture_output_to_buffer) + { + entry_func.add_local_variable(ib_var_id); + for (auto &blk_id : entry_func.blocks) + { + auto &blk = get(blk_id); + if (blk.terminator == SPIRBlock::Return || (blk.terminator == SPIRBlock::Kill && blk_id == entry_func.blocks.back())) + blk.return_value = rtn_id; + } + vars_needing_early_declaration.push_back(ib_var_id); + } + else + { + switch (get_execution_model()) + { + case ExecutionModelVertex: + case ExecutionModelTessellationEvaluation: + // Instead of declaring a struct variable to hold the output and then + // copying that to the output buffer, we'll declare the output variable + // as a reference to the final output element in the buffer. Then we can + // avoid the extra copy. + entry_func.fixup_hooks_in.push_back([=]() { + if (stage_out_var_id) + { + // The first member of the indirect buffer is always the number of vertices + // to draw. + // We zero-base the InstanceID & VertexID variables for HLSL emulation elsewhere, so don't do it twice + if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation) + { + statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref, + " = ", output_buffer_var_name, "[", to_expression(builtin_invocation_id_id), + ".y * ", to_expression(builtin_stage_input_size_id), ".x + ", + to_expression(builtin_invocation_id_id), ".x];"); + } + else if (msl_options.enable_base_index_zero) + { + statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref, + " = ", output_buffer_var_name, "[", to_expression(builtin_instance_idx_id), + " * spvIndirectParams[0] + ", to_expression(builtin_vertex_idx_id), "];"); + } + else + { + statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref, + " = ", output_buffer_var_name, "[(", to_expression(builtin_instance_idx_id), + " - ", to_expression(builtin_base_instance_id), ") * spvIndirectParams[0] + ", + to_expression(builtin_vertex_idx_id), " - ", + to_expression(builtin_base_vertex_id), "];"); + } + } + }); + break; + case ExecutionModelTessellationControl: + if (msl_options.multi_patch_workgroup) + { + // We cannot use PrimitiveId here, because the hook may not have run yet. 
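+					// Roughly, assuming the default names (entry point main0, spvOut, gl_GlobalInvocationID)
+					// and e.g. 4 output vertices, the non-patch hook below emits:
+					//   device main0_out* gl_out = &spvOut[gl_GlobalInvocationID.x - gl_GlobalInvocationID.x % 4];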
+ if (patch) + { + entry_func.fixup_hooks_in.push_back([=]() { + statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref, + " = ", patch_output_buffer_var_name, "[", to_expression(builtin_invocation_id_id), + ".x / ", get_entry_point().output_vertices, "];"); + }); + } + else + { + entry_func.fixup_hooks_in.push_back([=]() { + statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &", + output_buffer_var_name, "[", to_expression(builtin_invocation_id_id), ".x - ", + to_expression(builtin_invocation_id_id), ".x % ", + get_entry_point().output_vertices, "];"); + }); + } + } + else + { + if (patch) + { + entry_func.fixup_hooks_in.push_back([=]() { + statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref, + " = ", patch_output_buffer_var_name, "[", to_expression(builtin_primitive_id_id), + "];"); + }); + } + else + { + entry_func.fixup_hooks_in.push_back([=]() { + statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &", + output_buffer_var_name, "[", to_expression(builtin_primitive_id_id), " * ", + get_entry_point().output_vertices, "];"); + }); + } + } + break; + default: + break; + } + } + break; + } + + default: + break; + } + + set_name(ib_type_id, to_name(ir.default_entry_point) + "_" + ib_var_ref); + set_name(ib_var_id, ib_var_ref); + + for (auto *p_var : vars) + { + bool strip_array = (is_tesc_shader() || (is_tese_shader() && storage == StorageClassInput)) && !patch; + + // Fixing up flattened stores in TESC is impossible since the memory is group shared either via + // device (not masked) or threadgroup (masked) storage classes and it's race condition city. + meta.strip_array = strip_array; + meta.allow_local_declaration = !strip_array && !(is_tesc_shader() && storage == StorageClassOutput); + add_variable_to_interface_block(storage, ib_var_ref, ib_type, *p_var, meta); + } + + if (((is_tesc_shader() && msl_options.multi_patch_workgroup) || + (is_tese_shader() && msl_options.raw_buffer_tese_input)) && + storage == StorageClassInput) + { + // For tessellation inputs, add all outputs from the previous stage to ensure + // the struct containing them is the correct size and layout. + for (auto &input : inputs_by_location) + { + if (location_inputs_in_use.count(input.first.location) != 0) + continue; + + if (patch != (input.second.rate == MSL_SHADER_VARIABLE_RATE_PER_PATCH)) + continue; + + // Tessellation levels have their own struct, so there's no need to add them here. + if (input.second.builtin == BuiltInTessLevelOuter || input.second.builtin == BuiltInTessLevelInner) + continue; + + // Create a fake variable to put at the location. 
+ uint32_t offset = ir.increase_bound_by(5); + uint32_t type_id = offset; + uint32_t vec_type_id = offset + 1; + uint32_t array_type_id = offset + 2; + uint32_t ptr_type_id = offset + 3; + uint32_t var_id = offset + 4; + + SPIRType type { OpTypeInt }; + switch (input.second.format) + { + case MSL_SHADER_VARIABLE_FORMAT_UINT16: + case MSL_SHADER_VARIABLE_FORMAT_ANY16: + type.basetype = SPIRType::UShort; + type.width = 16; + break; + case MSL_SHADER_VARIABLE_FORMAT_ANY32: + default: + type.basetype = SPIRType::UInt; + type.width = 32; + break; + } + set(type_id, type); + if (input.second.vecsize > 1) + { + type.op = OpTypeVector; + type.vecsize = input.second.vecsize; + set(vec_type_id, type); + type_id = vec_type_id; + } + + type.op = OpTypeArray; + type.array.push_back(0); + type.array_size_literal.push_back(true); + type.parent_type = type_id; + set(array_type_id, type); + type.self = type_id; + + type.op = OpTypePointer; + type.pointer = true; + type.pointer_depth++; + type.parent_type = array_type_id; + type.storage = storage; + auto &ptr_type = set(ptr_type_id, type); + ptr_type.self = array_type_id; + + auto &fake_var = set(var_id, ptr_type_id, storage); + set_decoration(var_id, DecorationLocation, input.first.location); + if (input.first.component) + set_decoration(var_id, DecorationComponent, input.first.component); + + meta.strip_array = true; + meta.allow_local_declaration = false; + add_variable_to_interface_block(storage, ib_var_ref, ib_type, fake_var, meta); + } + } + + if (capture_output_to_buffer && storage == StorageClassOutput) + { + // For captured output, add all inputs from the next stage to ensure + // the struct containing them is the correct size and layout. This is + // necessary for certain implicit builtins that may nonetheless be read, + // even when they aren't written. + for (auto &output : outputs_by_location) + { + if (location_outputs_in_use.count(output.first.location) != 0) + continue; + + // Create a fake variable to put at the location. + uint32_t offset = ir.increase_bound_by(5); + uint32_t type_id = offset; + uint32_t vec_type_id = offset + 1; + uint32_t array_type_id = offset + 2; + uint32_t ptr_type_id = offset + 3; + uint32_t var_id = offset + 4; + + SPIRType type { OpTypeInt }; + switch (output.second.format) + { + case MSL_SHADER_VARIABLE_FORMAT_UINT16: + case MSL_SHADER_VARIABLE_FORMAT_ANY16: + type.basetype = SPIRType::UShort; + type.width = 16; + break; + case MSL_SHADER_VARIABLE_FORMAT_ANY32: + default: + type.basetype = SPIRType::UInt; + type.width = 32; + break; + } + set(type_id, type); + if (output.second.vecsize > 1) + { + type.op = OpTypeVector; + type.vecsize = output.second.vecsize; + set(vec_type_id, type); + type_id = vec_type_id; + } + + if (is_tesc_shader()) + { + type.op = OpTypeArray; + type.array.push_back(0); + type.array_size_literal.push_back(true); + type.parent_type = type_id; + set(array_type_id, type); + } + + type.op = OpTypePointer; + type.pointer = true; + type.pointer_depth++; + type.parent_type = is_tesc_shader() ? 
array_type_id : type_id; + type.storage = storage; + auto &ptr_type = set(ptr_type_id, type); + ptr_type.self = type.parent_type; + + auto &fake_var = set(var_id, ptr_type_id, storage); + set_decoration(var_id, DecorationLocation, output.first.location); + if (output.first.component) + set_decoration(var_id, DecorationComponent, output.first.component); + + meta.strip_array = true; + meta.allow_local_declaration = false; + add_variable_to_interface_block(storage, ib_var_ref, ib_type, fake_var, meta); + } + } + + // When multiple variables need to access same location, + // unroll locations one by one and we will flatten output or input as necessary. + for (auto &loc : meta.location_meta) + { + uint32_t location = loc.first; + auto &location_meta = loc.second; + + uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size()); + uint32_t type_id = build_extended_vector_type(location_meta.base_type_id, location_meta.num_components); + ib_type.member_types.push_back(type_id); + + set_member_name(ib_type.self, ib_mbr_idx, join("m_location_", location)); + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, location); + mark_location_as_used_by_shader(location, get(type_id), storage); + + if (location_meta.flat) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat); + if (location_meta.noperspective) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective); + if (location_meta.centroid) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid); + if (location_meta.sample) + set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample); + } + + // Sort the members of the structure by their locations. + MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::LocationThenBuiltInType); + member_sorter.sort(); + + // The member indices were saved to the original variables, but after the members + // were sorted, those indices are now likely incorrect. Fix those up now. + fix_up_interface_member_indices(storage, ib_type_id); + + // For patch inputs, add one more member, holding the array of control point data. + if (is_tese_shader() && !msl_options.raw_buffer_tese_input && storage == StorageClassInput && patch && + stage_in_var_id) + { + uint32_t pcp_type_id = ir.increase_bound_by(1); + auto &pcp_type = set(pcp_type_id, ib_type); + pcp_type.basetype = SPIRType::ControlPointArray; + pcp_type.parent_type = pcp_type.type_alias = get_stage_in_struct_type().self; + pcp_type.storage = storage; + ir.meta[pcp_type_id] = ir.meta[ib_type.self]; + uint32_t mbr_idx = uint32_t(ib_type.member_types.size()); + ib_type.member_types.push_back(pcp_type_id); + set_member_name(ib_type.self, mbr_idx, "gl_in"); + } + + if (storage == StorageClassInput) + set_decoration(ib_var_id, DecorationNonWritable); + + return ib_var_id; +} + +uint32_t CompilerMSL::add_interface_block_pointer(uint32_t ib_var_id, StorageClass storage) +{ + if (!ib_var_id) + return 0; + + uint32_t ib_ptr_var_id; + uint32_t next_id = ir.increase_bound_by(3); + auto &ib_type = expression_type(ib_var_id); + if (is_tesc_shader() || (is_tese_shader() && msl_options.raw_buffer_tese_input)) + { + // Tessellation control per-vertex I/O is presented as an array, so we must + // do the same with our struct here. 
+ uint32_t ib_ptr_type_id = next_id++; + auto &ib_ptr_type = set(ib_ptr_type_id, ib_type); + ib_ptr_type.op = OpTypePointer; + ib_ptr_type.parent_type = ib_ptr_type.type_alias = ib_type.self; + ib_ptr_type.pointer = true; + ib_ptr_type.pointer_depth++; + ib_ptr_type.storage = storage == StorageClassInput ? + ((is_tesc_shader() && msl_options.multi_patch_workgroup) || + (is_tese_shader() && msl_options.raw_buffer_tese_input) ? + StorageClassStorageBuffer : + StorageClassWorkgroup) : + StorageClassStorageBuffer; + ir.meta[ib_ptr_type_id] = ir.meta[ib_type.self]; + // To ensure that get_variable_data_type() doesn't strip off the pointer, + // which we need, use another pointer. + uint32_t ib_ptr_ptr_type_id = next_id++; + auto &ib_ptr_ptr_type = set(ib_ptr_ptr_type_id, ib_ptr_type); + ib_ptr_ptr_type.parent_type = ib_ptr_type_id; + ib_ptr_ptr_type.type_alias = ib_type.self; + ib_ptr_ptr_type.storage = StorageClassFunction; + ir.meta[ib_ptr_ptr_type_id] = ir.meta[ib_type.self]; + + ib_ptr_var_id = next_id; + set(ib_ptr_var_id, ib_ptr_ptr_type_id, StorageClassFunction, 0); + set_name(ib_ptr_var_id, storage == StorageClassInput ? "gl_in" : "gl_out"); + if (storage == StorageClassInput) + set_decoration(ib_ptr_var_id, DecorationNonWritable); + } + else + { + // Tessellation evaluation per-vertex inputs are also presented as arrays. + // But, in Metal, this array uses a very special type, 'patch_control_point', + // which is a container that can be used to access the control point data. + // To represent this, a special 'ControlPointArray' type has been added to the + // SPIRV-Cross type system. It should only be generated by and seen in the MSL + // backend (i.e. this one). + uint32_t pcp_type_id = next_id++; + auto &pcp_type = set(pcp_type_id, ib_type); + pcp_type.basetype = SPIRType::ControlPointArray; + pcp_type.parent_type = pcp_type.type_alias = ib_type.self; + pcp_type.storage = storage; + ir.meta[pcp_type_id] = ir.meta[ib_type.self]; + + ib_ptr_var_id = next_id; + set(ib_ptr_var_id, pcp_type_id, storage, 0); + set_name(ib_ptr_var_id, "gl_in"); + ir.meta[ib_ptr_var_id].decoration.qualified_alias = join(patch_stage_in_var_name, ".gl_in"); + } + return ib_ptr_var_id; +} + +// Ensure that the type is compatible with the builtin. +// If it is, simply return the given type ID. +// Otherwise, create a new type, and return it's ID. +uint32_t CompilerMSL::ensure_correct_builtin_type(uint32_t type_id, BuiltIn builtin) +{ + auto &type = get(type_id); + auto &pointee_type = get_pointee_type(type); + + if ((builtin == BuiltInSampleMask && is_array(pointee_type)) || + ((builtin == BuiltInLayer || builtin == BuiltInViewportIndex || builtin == BuiltInFragStencilRefEXT) && + pointee_type.basetype != SPIRType::UInt)) + { + uint32_t next_id = ir.increase_bound_by(is_pointer(type) ? 2 : 1); + uint32_t base_type_id = next_id++; + auto &base_type = set(base_type_id, OpTypeInt); + base_type.basetype = SPIRType::UInt; + base_type.width = 32; + + if (!is_pointer(type)) + return base_type_id; + + uint32_t ptr_type_id = next_id++; + auto &ptr_type = set(ptr_type_id, base_type); + ptr_type.op = spv::OpTypePointer; + ptr_type.pointer = true; + ptr_type.pointer_depth++; + ptr_type.storage = type.storage; + ptr_type.parent_type = base_type_id; + return ptr_type_id; + } + + return type_id; +} + +// Ensure that the type is compatible with the shader input. +// If it is, simply return the given type ID. +// Otherwise, create a new type, and return its ID. 
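+// For example (illustrative), if the host describes a 4-component attribute at a location where the
+// shader only declares a float2, the shader-side type is widened to a 4-component vector so the
+// stage-in member matches what the host actually provides.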
+uint32_t CompilerMSL::ensure_correct_input_type(uint32_t type_id, uint32_t location, uint32_t component, uint32_t num_components, bool strip_array) +{ + auto &type = get(type_id); + + uint32_t max_array_dimensions = strip_array ? 1 : 0; + + // Struct and array types must match exactly. + if (type.basetype == SPIRType::Struct || type.array.size() > max_array_dimensions) + return type_id; + + auto p_va = inputs_by_location.find({location, component}); + if (p_va == end(inputs_by_location)) + { + if (num_components > type.vecsize) + return build_extended_vector_type(type_id, num_components); + else + return type_id; + } + + if (num_components == 0) + num_components = p_va->second.vecsize; + + switch (p_va->second.format) + { + case MSL_SHADER_VARIABLE_FORMAT_UINT8: + { + switch (type.basetype) + { + case SPIRType::UByte: + case SPIRType::UShort: + case SPIRType::UInt: + if (num_components > type.vecsize) + return build_extended_vector_type(type_id, num_components); + else + return type_id; + + case SPIRType::Short: + return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize, + SPIRType::UShort); + case SPIRType::Int: + return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize, + SPIRType::UInt); + + default: + SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader"); + } + } + + case MSL_SHADER_VARIABLE_FORMAT_UINT16: + { + switch (type.basetype) + { + case SPIRType::UShort: + case SPIRType::UInt: + if (num_components > type.vecsize) + return build_extended_vector_type(type_id, num_components); + else + return type_id; + + case SPIRType::Int: + return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize, + SPIRType::UInt); + + default: + SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader"); + } + } + + default: + if (num_components > type.vecsize) + type_id = build_extended_vector_type(type_id, num_components); + break; + } + + return type_id; +} + +void CompilerMSL::mark_struct_members_packed(const SPIRType &type) +{ + // Handle possible recursion when a struct contains a pointer to its own type nested somewhere. + if (has_extended_decoration(type.self, SPIRVCrossDecorationPhysicalTypePacked)) + return; + + set_extended_decoration(type.self, SPIRVCrossDecorationPhysicalTypePacked); + + // Problem case! Struct needs to be placed at an awkward alignment. + // Mark every member of the child struct as packed. + uint32_t mbr_cnt = uint32_t(type.member_types.size()); + for (uint32_t i = 0; i < mbr_cnt; i++) + { + auto &mbr_type = get(type.member_types[i]); + if (mbr_type.basetype == SPIRType::Struct) + { + // Recursively mark structs as packed. + auto *struct_type = &mbr_type; + while (!struct_type->array.empty()) + struct_type = &get(struct_type->parent_type); + mark_struct_members_packed(*struct_type); + } + else if (!is_scalar(mbr_type)) + set_extended_member_decoration(type.self, i, SPIRVCrossDecorationPhysicalTypePacked); + } +} + +void CompilerMSL::mark_scalar_layout_structs(const SPIRType &type) +{ + uint32_t mbr_cnt = uint32_t(type.member_types.size()); + for (uint32_t i = 0; i < mbr_cnt; i++) + { + // Handle possible recursion when a struct contains a pointer to its own type nested somewhere. 
+ auto &mbr_type = get(type.member_types[i]); + if (mbr_type.basetype == SPIRType::Struct && !(mbr_type.pointer && mbr_type.storage == StorageClassPhysicalStorageBuffer)) + { + auto *struct_type = &mbr_type; + while (!struct_type->array.empty()) + struct_type = &get(struct_type->parent_type); + + if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPhysicalTypePacked)) + continue; + + uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, i); + uint32_t msl_size = get_declared_struct_member_size_msl(type, i); + uint32_t spirv_offset = type_struct_member_offset(type, i); + uint32_t spirv_offset_next; + if (i + 1 < mbr_cnt) + spirv_offset_next = type_struct_member_offset(type, i + 1); + else + spirv_offset_next = spirv_offset + msl_size; + + // Both are complicated cases. In scalar layout, a struct of float3 might just consume 12 bytes, + // and the next member will be placed at offset 12. + bool struct_is_misaligned = (spirv_offset % msl_alignment) != 0; + bool struct_is_too_large = spirv_offset + msl_size > spirv_offset_next; + uint32_t array_stride = 0; + bool struct_needs_explicit_padding = false; + + // Verify that if a struct is used as an array that ArrayStride matches the effective size of the struct. + if (!mbr_type.array.empty()) + { + array_stride = type_struct_member_array_stride(type, i); + uint32_t dimensions = uint32_t(mbr_type.array.size() - 1); + for (uint32_t dim = 0; dim < dimensions; dim++) + { + uint32_t array_size = to_array_size_literal(mbr_type, dim); + array_stride /= max(array_size, 1u); + } + + // Set expected struct size based on ArrayStride. + struct_needs_explicit_padding = true; + + // If struct size is larger than array stride, we might be able to fit, if we tightly pack. + if (get_declared_struct_size_msl(*struct_type) > array_stride) + struct_is_too_large = true; + } + + if (struct_is_misaligned || struct_is_too_large) + mark_struct_members_packed(*struct_type); + mark_scalar_layout_structs(*struct_type); + + if (struct_needs_explicit_padding) + { + msl_size = get_declared_struct_size_msl(*struct_type, true, true); + if (array_stride < msl_size) + { + SPIRV_CROSS_THROW("Cannot express an array stride smaller than size of struct type."); + } + else + { + if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget)) + { + if (array_stride != + get_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget)) + SPIRV_CROSS_THROW( + "A struct is used with different array strides. Cannot express this in MSL."); + } + else + set_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget, array_stride); + } + } + } + } +} + +// Sort the members of the struct type by offset, and pack and then pad members where needed +// to align MSL members with SPIR-V offsets. The struct members are iterated twice. Packing +// occurs first, followed by padding, because packing a member reduces both its size and its +// natural alignment, possibly requiring a padding member to be added ahead of it. +void CompilerMSL::align_struct(SPIRType &ib_type, unordered_set &aligned_structs) +{ + // We align structs recursively, so stop any redundant work. + ID &ib_type_id = ib_type.self; + if (aligned_structs.count(ib_type_id)) + return; + aligned_structs.insert(ib_type_id); + + // Sort the members of the interface structure by their offset. + // They should already be sorted per SPIR-V spec anyway. 
+ MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Offset); + member_sorter.sort(); + + auto mbr_cnt = uint32_t(ib_type.member_types.size()); + + for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++) + { + // Pack any dependent struct types before we pack a parent struct. + auto &mbr_type = get(ib_type.member_types[mbr_idx]); + if (mbr_type.basetype == SPIRType::Struct) + align_struct(mbr_type, aligned_structs); + } + + // Test the alignment of each member, and if a member should be closer to the previous + // member than the default spacing expects, it is likely that the previous member is in + // a packed format. If so, and the previous member is packable, pack it. + // For example ... this applies to any 3-element vector that is followed by a scalar. + uint32_t msl_offset = 0; + for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++) + { + // This checks the member in isolation, if the member needs some kind of type remapping to conform to SPIR-V + // offsets, array strides and matrix strides. + ensure_member_packing_rules_msl(ib_type, mbr_idx); + + // Align current offset to the current member's default alignment. If the member was packed, it will observe + // the updated alignment here. + uint32_t msl_align_mask = get_declared_struct_member_alignment_msl(ib_type, mbr_idx) - 1; + uint32_t aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask; + + // Fetch the member offset as declared in the SPIRV. + uint32_t spirv_mbr_offset = get_member_decoration(ib_type_id, mbr_idx, DecorationOffset); + if (spirv_mbr_offset > aligned_msl_offset) + { + // Since MSL and SPIR-V have slightly different struct member alignment and + // size rules, we'll pad to standard C-packing rules with a char[] array. If the member is farther + // away than C-packing, expects, add an inert padding member before the the member. + uint32_t padding_bytes = spirv_mbr_offset - aligned_msl_offset; + set_extended_member_decoration(ib_type_id, mbr_idx, SPIRVCrossDecorationPaddingTarget, padding_bytes); + + // Re-align as a sanity check that aligning post-padding matches up. + msl_offset += padding_bytes; + aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask; + } + else if (spirv_mbr_offset < aligned_msl_offset) + { + // This should not happen, but deal with unexpected scenarios. + // It *might* happen if a sub-struct has a larger alignment requirement in MSL than SPIR-V. + SPIRV_CROSS_THROW("Cannot represent buffer block correctly in MSL."); + } + + assert(aligned_msl_offset == spirv_mbr_offset); + + // Increment the current offset to be positioned immediately after the current member. + // Don't do this for the last member since it can be unsized, and it is not relevant for padding purposes here. + if (mbr_idx + 1 < mbr_cnt) + msl_offset = aligned_msl_offset + get_declared_struct_member_size_msl(ib_type, mbr_idx); + } +} + +bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const +{ + auto &mbr_type = get(type.member_types[index]); + uint32_t spirv_offset = get_member_decoration(type.self, index, DecorationOffset); + + if (index + 1 < type.member_types.size()) + { + // First, we will check offsets. If SPIR-V offset + MSL size > SPIR-V offset of next member, + // we *must* perform some kind of remapping, no way getting around it. + // We can always pad after this member if necessary, so that case is fine. 
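+		// Example (illustrative): a float3 member at offset 0 followed by a float at offset 12 fails this
+		// check when declared as a natural float3 (16 bytes > 12), which is what later forces the member
+		// into a packed_float3.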
+ uint32_t spirv_offset_next = get_member_decoration(type.self, index + 1, DecorationOffset); + assert(spirv_offset_next >= spirv_offset); + uint32_t maximum_size = spirv_offset_next - spirv_offset; + uint32_t msl_mbr_size = get_declared_struct_member_size_msl(type, index); + if (msl_mbr_size > maximum_size) + return false; + } + + if (is_array(mbr_type)) + { + // If we have an array type, array stride must match exactly with SPIR-V. + + // An exception to this requirement is if we have one array element. + // This comes from DX scalar layout workaround. + // If app tries to be cheeky and access the member out of bounds, this will not work, but this is the best we can do. + // In OpAccessChain with logical memory models, access chains must be in-bounds in SPIR-V specification. + bool relax_array_stride = mbr_type.array.back() == 1 && mbr_type.array_size_literal.back(); + + if (!relax_array_stride) + { + uint32_t spirv_array_stride = type_struct_member_array_stride(type, index); + uint32_t msl_array_stride = get_declared_struct_member_array_stride_msl(type, index); + if (spirv_array_stride != msl_array_stride) + return false; + } + } + + if (is_matrix(mbr_type)) + { + // Need to check MatrixStride as well. + uint32_t spirv_matrix_stride = type_struct_member_matrix_stride(type, index); + uint32_t msl_matrix_stride = get_declared_struct_member_matrix_stride_msl(type, index); + if (spirv_matrix_stride != msl_matrix_stride) + return false; + } + + // Now, we check alignment. + uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, index); + if ((spirv_offset % msl_alignment) != 0) + return false; + + // We're in the clear. + return true; +} + +// Here we need to verify that the member type we declare conforms to Offset, ArrayStride or MatrixStride restrictions. +// If there is a mismatch, we need to emit remapped types, either normal types, or "packed_X" types. +// In odd cases we need to emit packed and remapped types, for e.g. weird matrices or arrays with weird array strides. +void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t index) +{ + if (validate_member_packing_rules_msl(ib_type, index)) + return; + + // We failed validation. + // This case will be nightmare-ish to deal with. This could possibly happen if struct alignment does not quite + // match up with what we want. Scalar block layout comes to mind here where we might have to work around the rule + // that struct alignment == max alignment of all members and struct size depends on this alignment. + // Can't repack structs, but can repack pointers to structs. + auto &mbr_type = get(ib_type.member_types[index]); + bool is_buff_ptr = mbr_type.pointer && mbr_type.storage == StorageClassPhysicalStorageBuffer; + if (mbr_type.basetype == SPIRType::Struct && !is_buff_ptr) + SPIRV_CROSS_THROW("Cannot perform any repacking for structs when it is used as a member of another struct."); + + // Perform remapping here. + // There is nothing to be gained by using packed scalars, so don't attempt it. + if (!is_scalar(ib_type)) + set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked); + + // Try validating again, now with packed. + if (validate_member_packing_rules_msl(ib_type, index)) + return; + + // We're in deep trouble, and we need to create a new PhysicalType which matches up with what we expect. + // A lot of work goes here ... + // We will need remapping on Load and Store to translate the types between Logical and Physical. 
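+	// Example (illustrative): a float[4] member with ArrayStride 16 (std140-style) is redeclared with a
+	// float4 physical type below; loads then swizzle the scalar back out with ".x".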
+ + // First, we check if we have small vector std140 array. + // We detect this if we have an array of vectors, and array stride is greater than number of elements. + if (!mbr_type.array.empty() && !is_matrix(mbr_type)) + { + uint32_t array_stride = type_struct_member_array_stride(ib_type, index); + + // Hack off array-of-arrays until we find the array stride per element we must have to make it work. + uint32_t dimensions = uint32_t(mbr_type.array.size() - 1); + for (uint32_t dim = 0; dim < dimensions; dim++) + array_stride /= max(to_array_size_literal(mbr_type, dim), 1u); + + // Pointers are 8 bytes + uint32_t mbr_width_in_bytes = is_buff_ptr ? 8 : (mbr_type.width / 8); + uint32_t elems_per_stride = array_stride / mbr_width_in_bytes; + + if (elems_per_stride == 3) + SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios."); + else if (elems_per_stride > 4 && elems_per_stride != 8) + SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL."); + + if (elems_per_stride == 8) + { + if (mbr_type.width == 16) + add_spv_func_and_recompile(SPVFuncImplPaddedStd140); + else + SPIRV_CROSS_THROW("Unexpected type in std140 wide array resolve."); + } + + auto physical_type = mbr_type; + physical_type.vecsize = elems_per_stride; + physical_type.parent_type = 0; + + // If this is a physical buffer pointer, replace type with a ulongn vector. + if (is_buff_ptr) + { + physical_type.width = 64; + physical_type.basetype = to_unsigned_basetype(physical_type.width); + physical_type.pointer = false; + physical_type.pointer_depth = false; + physical_type.forward_pointer = false; + } + + uint32_t type_id = ir.increase_bound_by(1); + set(type_id, physical_type); + set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id); + set_decoration(type_id, DecorationArrayStride, array_stride); + + // Remove packed_ for vectors of size 1, 2 and 4. + unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked); + } + else if (is_matrix(mbr_type)) + { + // MatrixStride might be std140-esque. + uint32_t matrix_stride = type_struct_member_matrix_stride(ib_type, index); + + uint32_t elems_per_stride = matrix_stride / (mbr_type.width / 8); + + if (elems_per_stride == 3) + SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios."); + else if (elems_per_stride > 4 && elems_per_stride != 8) + SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL."); + + if (elems_per_stride == 8) + { + if (mbr_type.basetype != SPIRType::Half) + SPIRV_CROSS_THROW("Unexpected type in std140 wide matrix stride resolve."); + add_spv_func_and_recompile(SPVFuncImplPaddedStd140); + } + + bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor); + auto physical_type = mbr_type; + physical_type.parent_type = 0; + + if (row_major) + physical_type.columns = elems_per_stride; + else + physical_type.vecsize = elems_per_stride; + uint32_t type_id = ir.increase_bound_by(1); + set(type_id, physical_type); + set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id); + + // Remove packed_ for vectors of size 1, 2 and 4. + unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked); + } + else + SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL."); + + // Try validating again, now with physical type remapping. 
+ if (validate_member_packing_rules_msl(ib_type, index)) + return; + + // We might have a particular odd scalar layout case where the last element of an array + // does not take up as much space as the ArrayStride or MatrixStride. This can happen with DX cbuffers. + // The "proper" workaround for this is extremely painful and essentially impossible in the edge case of float3[], + // so we hack around it by declaring the offending array or matrix with one less array size/col/row, + // and rely on padding to get the correct value. We will technically access arrays out of bounds into the padding region, + // but it should spill over gracefully without too much trouble. We rely on behavior like this for unsized arrays anyways. + + // E.g. we might observe a physical layout of: + // { float2 a[2]; float b; } in cbuffer layout where ArrayStride of a is 16, but offset of b is 24, packed right after a[1] ... + uint32_t type_id = get_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID); + auto &type = get(type_id); + + // Modify the physical type in-place. This is safe since each physical type workaround is a copy. + if (is_array(type)) + { + if (type.array.back() > 1) + { + if (!type.array_size_literal.back()) + SPIRV_CROSS_THROW("Cannot apply scalar layout workaround with spec constant array size."); + type.array.back() -= 1; + } + else + { + // We have an array of size 1, so we cannot decrement that. Our only option now is to + // force a packed layout instead, and drop the physical type remap since ArrayStride is meaningless now. + unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID); + set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked); + } + } + else if (is_matrix(type)) + { + bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor); + if (!row_major) + { + // Slice off one column. If we only have 2 columns, this might turn the matrix into a vector with one array element instead. + if (type.columns > 2) + { + type.columns--; + } + else if (type.columns == 2) + { + type.columns = 1; + assert(type.array.empty()); + type.op = OpTypeArray; + type.array.push_back(1); + type.array_size_literal.push_back(true); + } + } + else + { + // Slice off one row. If we only have 2 rows, this might turn the matrix into a vector with one array element instead. + if (type.vecsize > 2) + { + type.vecsize--; + } + else if (type.vecsize == 2) + { + type.vecsize = type.columns; + type.columns = 1; + assert(type.array.empty()); + type.op = OpTypeArray; + type.array.push_back(1); + type.array_size_literal.push_back(true); + } + } + } + + // This better validate now, or we must fail gracefully. 
+ if (!validate_member_packing_rules_msl(ib_type, index)) + SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL."); +} + +void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression) +{ + auto &type = expression_type(rhs_expression); + + bool lhs_remapped_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID); + bool lhs_packed_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypePacked); + auto *lhs_e = maybe_get(lhs_expression); + auto *rhs_e = maybe_get(rhs_expression); + + bool transpose = lhs_e && lhs_e->need_transpose; + + if (has_decoration(lhs_expression, DecorationBuiltIn) && + BuiltIn(get_decoration(lhs_expression, DecorationBuiltIn)) == BuiltInSampleMask && + is_array(type)) + { + // Storing an array to SampleMask, have to remove the array-ness before storing. + statement(to_expression(lhs_expression), " = ", to_enclosed_unpacked_expression(rhs_expression), "[0];"); + register_write(lhs_expression); + } + else if (!lhs_remapped_type && !lhs_packed_type) + { + // No physical type remapping, and no packed type, so can just emit a store directly. + + // We might not be dealing with remapped physical types or packed types, + // but we might be doing a clean store to a row-major matrix. + // In this case, we just flip transpose states, and emit the store, a transpose must be in the RHS expression, if any. + if (is_matrix(type) && lhs_e && lhs_e->need_transpose) + { + lhs_e->need_transpose = false; + + if (rhs_e && rhs_e->need_transpose) + { + // Direct copy, but might need to unpack RHS. + // Skip the transpose, as we will transpose when writing to LHS and transpose(transpose(T)) == T. + rhs_e->need_transpose = false; + statement(to_expression(lhs_expression), " = ", to_unpacked_row_major_matrix_expression(rhs_expression), + ";"); + rhs_e->need_transpose = true; + } + else + statement(to_expression(lhs_expression), " = transpose(", to_unpacked_expression(rhs_expression), ");"); + + lhs_e->need_transpose = true; + register_write(lhs_expression); + } + else if (lhs_e && lhs_e->need_transpose) + { + lhs_e->need_transpose = false; + + // Storing a column to a row-major matrix. Unroll the write. + for (uint32_t c = 0; c < type.vecsize; c++) + { + auto lhs_expr = to_dereferenced_expression(lhs_expression); + auto column_index = lhs_expr.find_last_of('['); + if (column_index != string::npos) + { + statement(lhs_expr.insert(column_index, join('[', c, ']')), " = ", + to_extract_component_expression(rhs_expression, c), ";"); + } + } + lhs_e->need_transpose = true; + register_write(lhs_expression); + } + else + CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression); + } + else if (!lhs_remapped_type && !is_matrix(type) && !transpose) + { + // Even if the target type is packed, we can directly store to it. We cannot store to packed matrices directly, + // since they are declared as array of vectors instead, and we need the fallback path below. + CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression); + } + else + { + // Special handling when storing to a remapped physical type. + // This is mostly to deal with std140 padded matrices or vectors. + + TypeID physical_type_id = lhs_remapped_type ? 
+ ID(get_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID)) : + type.self; + + auto &physical_type = get(physical_type_id); + + string cast_addr_space = "thread"; + auto *p_var_lhs = maybe_get_backing_variable(lhs_expression); + if (p_var_lhs) + cast_addr_space = get_type_address_space(get(p_var_lhs->basetype), lhs_expression); + + if (is_matrix(type)) + { + const char *packed_pfx = lhs_packed_type ? "packed_" : ""; + + // Packed matrices are stored as arrays of packed vectors, so we need + // to assign the vectors one at a time. + // For row-major matrices, we need to transpose the *right-hand* side, + // not the left-hand side. + + // Lots of cases to cover here ... + + bool rhs_transpose = rhs_e && rhs_e->need_transpose; + SPIRType write_type = type; + string cast_expr; + + // We're dealing with transpose manually. + if (rhs_transpose) + rhs_e->need_transpose = false; + + if (transpose) + { + // We're dealing with transpose manually. + lhs_e->need_transpose = false; + write_type.vecsize = type.columns; + write_type.columns = 1; + + if (physical_type.columns != type.columns) + cast_expr = join("(", cast_addr_space, " ", packed_pfx, type_to_glsl(write_type), "&)"); + + if (rhs_transpose) + { + // If RHS is also transposed, we can just copy row by row. + for (uint32_t i = 0; i < type.vecsize; i++) + { + statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", + to_unpacked_row_major_matrix_expression(rhs_expression), "[", i, "];"); + } + } + else + { + auto vector_type = expression_type(rhs_expression); + vector_type.vecsize = vector_type.columns; + vector_type.columns = 1; + + // Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad, + // so pick out individual components instead. + for (uint32_t i = 0; i < type.vecsize; i++) + { + string rhs_row = type_to_glsl_constructor(vector_type) + "("; + for (uint32_t j = 0; j < vector_type.vecsize; j++) + { + rhs_row += join(to_enclosed_unpacked_expression(rhs_expression), "[", j, "][", i, "]"); + if (j + 1 < vector_type.vecsize) + rhs_row += ", "; + } + rhs_row += ")"; + + statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";"); + } + } + + // We're dealing with transpose manually. + lhs_e->need_transpose = true; + } + else + { + write_type.columns = 1; + + if (physical_type.vecsize != type.vecsize) + cast_expr = join("(", cast_addr_space, " ", packed_pfx, type_to_glsl(write_type), "&)"); + + if (rhs_transpose) + { + auto vector_type = expression_type(rhs_expression); + vector_type.columns = 1; + + // Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad, + // so pick out individual components instead. + for (uint32_t i = 0; i < type.columns; i++) + { + string rhs_row = type_to_glsl_constructor(vector_type) + "("; + for (uint32_t j = 0; j < vector_type.vecsize; j++) + { + // Need to explicitly unpack expression since we've mucked with transpose state. + auto unpacked_expr = to_unpacked_row_major_matrix_expression(rhs_expression); + rhs_row += join(unpacked_expr, "[", j, "][", i, "]"); + if (j + 1 < vector_type.vecsize) + rhs_row += ", "; + } + rhs_row += ")"; + + statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";"); + } + } + else + { + // Copy column-by-column. 
+ for (uint32_t i = 0; i < type.columns; i++) + { + statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", + to_enclosed_unpacked_expression(rhs_expression), "[", i, "];"); + } + } + } + + // We're dealing with transpose manually. + if (rhs_transpose) + rhs_e->need_transpose = true; + } + else if (transpose) + { + lhs_e->need_transpose = false; + + SPIRType write_type = type; + write_type.vecsize = 1; + write_type.columns = 1; + + // Storing a column to a row-major matrix. Unroll the write. + for (uint32_t c = 0; c < type.vecsize; c++) + { + auto lhs_expr = to_enclosed_expression(lhs_expression); + auto column_index = lhs_expr.find_last_of('['); + + // Get rid of any ".data" half8 handling here, we're casting to scalar anyway. + auto end_column_index = lhs_expr.find_last_of(']'); + auto end_dot_index = lhs_expr.find_last_of('.'); + if (end_dot_index != string::npos && end_dot_index > end_column_index) + lhs_expr.resize(end_dot_index); + + if (column_index != string::npos) + { + statement("((", cast_addr_space, " ", type_to_glsl(write_type), "*)&", + lhs_expr.insert(column_index, join('[', c, ']', ")")), " = ", + to_extract_component_expression(rhs_expression, c), ";"); + } + } + + lhs_e->need_transpose = true; + } + else if ((is_matrix(physical_type) || is_array(physical_type)) && + physical_type.vecsize <= 4 && + physical_type.vecsize > type.vecsize) + { + assert(type.vecsize >= 1 && type.vecsize <= 3); + + // If we have packed types, we cannot use swizzled stores. + // We could technically unroll the store for each element if needed. + // When remapping to a std140 physical type, we always get float4, + // and the packed decoration should always be removed. + assert(!lhs_packed_type); + + string lhs = to_dereferenced_expression(lhs_expression); + string rhs = to_pointer_expression(rhs_expression); + + // Unpack the expression so we can store to it with a float or float2. + // It's still an l-value, so it's fine. Most other unpacking of expressions turn them into r-values instead. + lhs = join("(", cast_addr_space, " ", type_to_glsl(type), "&)", enclose_expression(lhs)); + if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs)) + statement(lhs, " = ", rhs, ";"); + } + else if (!is_matrix(type)) + { + string lhs = to_dereferenced_expression(lhs_expression); + string rhs = to_pointer_expression(rhs_expression); + if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs)) + statement(lhs, " = ", rhs, ";"); + } + + register_write(lhs_expression); + } +} + +static bool expression_ends_with(const string &expr_str, const std::string &ending) +{ + if (expr_str.length() >= ending.length()) + return (expr_str.compare(expr_str.length() - ending.length(), ending.length(), ending) == 0); + else + return false; +} + +// Converts the format of the current expression from packed to unpacked, +// by wrapping the expression in a constructor of the appropriate type. +// Also, handle special physical ID remapping scenarios, similar to emit_store_statement(). +string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type, uint32_t physical_type_id, + bool packed, bool row_major) +{ + // Trivial case, nothing to do. + if (physical_type_id == 0 && !packed) + return expr_str; + + const SPIRType *physical_type = nullptr; + if (physical_type_id) + physical_type = &get(physical_type_id); + + static const char *swizzle_lut[] = { + ".x", + ".xy", + ".xyz", + "", + }; + + // TODO: Move everything to the template wrapper? 
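+	// Illustrative examples of the unpacking below: a packed_float3 member is rebuilt as "float3(expr)",
+	// a vector stored in a padded/std140 physical type is narrowed with a swizzle such as ".xyz", and
+	// packed matrices are reassembled column by column, e.g. "float2x2(float2(m[0]), float2(m[1]))".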
+ bool uses_std140_wrapper = physical_type && physical_type->vecsize > 4; + + if (physical_type && is_vector(*physical_type) && is_array(*physical_type) && + !uses_std140_wrapper && + physical_type->vecsize > type.vecsize && !expression_ends_with(expr_str, swizzle_lut[type.vecsize - 1])) + { + // std140 array cases for vectors. + assert(type.vecsize >= 1 && type.vecsize <= 3); + return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1]; + } + else if (physical_type && is_matrix(*physical_type) && is_vector(type) && + !uses_std140_wrapper && + physical_type->vecsize > type.vecsize) + { + // Extract column from padded matrix. + assert(type.vecsize >= 1 && type.vecsize <= 4); + return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1]; + } + else if (is_matrix(type)) + { + // Packed matrices are stored as arrays of packed vectors. Unfortunately, + // we can't just pass the array straight to the matrix constructor. We have to + // pass each vector individually, so that they can be unpacked to normal vectors. + if (!physical_type) + physical_type = &type; + + uint32_t vecsize = type.vecsize; + uint32_t columns = type.columns; + if (row_major) + swap(vecsize, columns); + + uint32_t physical_vecsize = row_major ? physical_type->columns : physical_type->vecsize; + + const char *base_type = type.width == 16 ? "half" : "float"; + string unpack_expr = join(base_type, columns, "x", vecsize, "("); + + const char *load_swiz = ""; + const char *data_swiz = physical_vecsize > 4 ? ".data" : ""; + + if (physical_vecsize != vecsize) + load_swiz = swizzle_lut[vecsize - 1]; + + for (uint32_t i = 0; i < columns; i++) + { + if (i > 0) + unpack_expr += ", "; + + if (packed) + unpack_expr += join(base_type, physical_vecsize, "(", expr_str, "[", i, "]", ")", load_swiz); + else + unpack_expr += join(expr_str, "[", i, "]", data_swiz, load_swiz); + } + + unpack_expr += ")"; + return unpack_expr; + } + else + { + return join(type_to_glsl(type), "(", expr_str, ")"); + } +} + +// Emits the file header info +void CompilerMSL::emit_header() +{ + // This particular line can be overridden during compilation, so make it a flag and not a pragma line. 
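+	// The emitted preamble normally boils down to the two standard includes plus "using namespace metal;",
+	// with accumulated pragma lines before them and header/typedef lines after.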
+	if (suppress_missing_prototypes)
+		statement("#pragma clang diagnostic ignored \"-Wmissing-prototypes\"");
+	if (suppress_incompatible_pointer_types_discard_qualifiers)
+		statement("#pragma clang diagnostic ignored \"-Wincompatible-pointer-types-discards-qualifiers\"");
+
+	// Disable warning about missing braces for array<T> template to make arrays a value type
+	if (spv_function_implementations.count(SPVFuncImplUnsafeArray) != 0)
+		statement("#pragma clang diagnostic ignored \"-Wmissing-braces\"");
+
+	for (auto &pragma : pragma_lines)
+		statement(pragma);
+
+	if (!pragma_lines.empty() || suppress_missing_prototypes)
+		statement("");
+
+	statement("#include <metal_stdlib>");
+	statement("#include <simd/simd.h>");
+
+	for (auto &header : header_lines)
+		statement(header);
+
+	statement("");
+	statement("using namespace metal;");
+	statement("");
+
+	for (auto &td : typedef_lines)
+		statement(td);
+
+	if (!typedef_lines.empty())
+		statement("");
+}
+
+void CompilerMSL::add_pragma_line(const string &line)
+{
+	auto rslt = pragma_lines.insert(line);
+	if (rslt.second)
+		force_recompile();
+}
+
+void CompilerMSL::add_typedef_line(const string &line)
+{
+	auto rslt = typedef_lines.insert(line);
+	if (rslt.second)
+		force_recompile();
+}
+
+// Template struct like spvUnsafeArray<> need to be declared *before* any resources are declared
+void CompilerMSL::emit_custom_templates()
+{
+	static const char * const address_spaces[] = {
+		"thread", "constant", "device", "threadgroup", "threadgroup_imageblock", "ray_data", "object_data"
+	};
+
+	for (const auto &spv_func : spv_function_implementations)
+	{
+		switch (spv_func)
+		{
+		case SPVFuncImplUnsafeArray:
+			statement("template<typename T, size_t Num>");
+			statement("struct spvUnsafeArray");
+			begin_scope();
+			statement("T elements[Num ? Num : 1];");
+			statement("");
+			statement("thread T& operator [] (size_t pos) thread");
+			begin_scope();
+			statement("return elements[pos];");
+			end_scope();
+			statement("constexpr const thread T& operator [] (size_t pos) const thread");
+			begin_scope();
+			statement("return elements[pos];");
+			end_scope();
+			statement("");
+			statement("device T& operator [] (size_t pos) device");
+			begin_scope();
+			statement("return elements[pos];");
+			end_scope();
+			statement("constexpr const device T& operator [] (size_t pos) const device");
+			begin_scope();
+			statement("return elements[pos];");
+			end_scope();
+			statement("");
+			statement("constexpr const constant T& operator [] (size_t pos) const constant");
+			begin_scope();
+			statement("return elements[pos];");
+			end_scope();
+			statement("");
+			statement("threadgroup T& operator [] (size_t pos) threadgroup");
+			begin_scope();
+			statement("return elements[pos];");
+			end_scope();
+			statement("constexpr const threadgroup T& operator [] (size_t pos) const threadgroup");
+			begin_scope();
+			statement("return elements[pos];");
+			end_scope();
+			end_scope_decl();
+			statement("");
+			break;
+
+		case SPVFuncImplStorageMatrix:
+			statement("template<typename T, int Cols, int Rows=Cols>");
+			statement("struct spvStorageMatrix");
+			begin_scope();
+			statement("vec<T, Rows> columns[Cols];");
+			statement("");
+			for (size_t method_idx = 0; method_idx < sizeof(address_spaces) / sizeof(address_spaces[0]); ++method_idx)
+			{
+				// Some address spaces require particular features.
+				if (method_idx == 4) // threadgroup_imageblock
+					statement("#ifdef __HAVE_IMAGEBLOCKS__");
+				else if (method_idx == 5) // ray_data
+					statement("#ifdef __HAVE_RAYTRACING__");
+				else if (method_idx == 6) // object_data
+					statement("#ifdef __HAVE_MESH__");
+				const string &method_as = address_spaces[method_idx];
+				statement("spvStorageMatrix() ", method_as, " = default;");
+				if (method_idx != 1) // constant
+				{
+					statement(method_as, " spvStorageMatrix& operator=(initializer_list<vec<T, Rows>> cols) ",
+					          method_as);
+					begin_scope();
+					statement("size_t i;");
+					statement("thread vec<T, Rows>* col;");
+					statement("for (i = 0, col = cols.begin(); i < Cols; ++i, ++col)");
+					statement("    columns[i] = *col;");
+					statement("return *this;");
+					end_scope();
+				}
+				statement("");
+				for (size_t param_idx = 0; param_idx < sizeof(address_spaces) / sizeof(address_spaces[0]); ++param_idx)
+				{
+					if (param_idx != method_idx)
+					{
+						if (param_idx == 4) // threadgroup_imageblock
+							statement("#ifdef __HAVE_IMAGEBLOCKS__");
+						else if (param_idx == 5) // ray_data
+							statement("#ifdef __HAVE_RAYTRACING__");
+						else if (param_idx == 6) // object_data
+							statement("#ifdef __HAVE_MESH__");
+					}
+					const string &param_as = address_spaces[param_idx];
+					statement("spvStorageMatrix(const ", param_as, " matrix<T, Cols, Rows>& m) ", method_as);
+					begin_scope();
+					statement("for (size_t i = 0; i < Cols; ++i)");
+					statement("    columns[i] = m.columns[i];");
+					end_scope();
+					statement("spvStorageMatrix(const ", param_as, " spvStorageMatrix& m) ", method_as, " = default;");
+					if (method_idx != 1) // constant
+					{
+						statement(method_as, " spvStorageMatrix& operator=(const ", param_as,
+						          " matrix<T, Cols, Rows>& m) ", method_as);
+						begin_scope();
+						statement("for (size_t i = 0; i < Cols; ++i)");
+						statement("    columns[i] = m.columns[i];");
+						statement("return *this;");
+						end_scope();
+						statement(method_as, " spvStorageMatrix& operator=(const ", param_as, " spvStorageMatrix& m) ",
+						          method_as, " = default;");
+					}
+					if (param_idx != method_idx && param_idx >= 4)
+						statement("#endif");
+					statement("");
+				}
+				statement("operator matrix<T, Cols, Rows>() const ", method_as);
+				begin_scope();
+				statement("matrix<T, Cols, Rows> m;");
+				statement("for (int i = 0; i < Cols; ++i)");
+				statement("    m.columns[i] = columns[i];");
+				statement("return m;");
+				end_scope();
+				statement("");
+				statement("vec<T, Rows> operator[](size_t idx) const ", method_as);
+				begin_scope();
+				statement("return columns[idx];");
+				end_scope();
+				if (method_idx != 1) // constant
+				{
+					statement(method_as, " vec<T, Rows>& operator[](size_t idx) ", method_as);
+					begin_scope();
+					statement("return columns[idx];");
+					end_scope();
+				}
+				if (method_idx >= 4)
+					statement("#endif");
+				statement("");
+			}
+			end_scope_decl();
+			statement("");
+			statement("template<typename T, int Cols, int Rows>");
+			statement("matrix<T, Rows, Cols> transpose(spvStorageMatrix<T, Cols, Rows> m)");
+			begin_scope();
+			statement("return transpose(matrix<T, Cols, Rows>(m));");
+			end_scope();
+			statement("");
+			statement("typedef spvStorageMatrix<half, 2, 2> spvStorage_half2x2;");
+			statement("typedef spvStorageMatrix<half, 2, 3> spvStorage_half2x3;");
+			statement("typedef spvStorageMatrix<half, 2, 4> spvStorage_half2x4;");
+			statement("typedef spvStorageMatrix<half, 3, 2> spvStorage_half3x2;");
+			statement("typedef spvStorageMatrix<half, 3, 3> spvStorage_half3x3;");
+			statement("typedef spvStorageMatrix<half, 3, 4> spvStorage_half3x4;");
+			statement("typedef spvStorageMatrix<half, 4, 2> spvStorage_half4x2;");
+			statement("typedef spvStorageMatrix<half, 4, 3> spvStorage_half4x3;");
+			statement("typedef spvStorageMatrix<half, 4, 4> spvStorage_half4x4;");
+			statement("typedef spvStorageMatrix<float, 2, 2> spvStorage_float2x2;");
+			statement("typedef spvStorageMatrix<float, 2, 3> spvStorage_float2x3;");
statement("typedef spvStorageMatrix spvStorage_float2x4;"); + statement("typedef spvStorageMatrix spvStorage_float3x2;"); + statement("typedef spvStorageMatrix spvStorage_float3x3;"); + statement("typedef spvStorageMatrix spvStorage_float3x4;"); + statement("typedef spvStorageMatrix spvStorage_float4x2;"); + statement("typedef spvStorageMatrix spvStorage_float4x3;"); + statement("typedef spvStorageMatrix spvStorage_float4x4;"); + statement(""); + break; + + default: + break; + } + } +} + +// Emits any needed custom function bodies. +// Metal helper functions must be static force-inline, i.e. static inline __attribute__((always_inline)) +// otherwise they will cause problems when linked together in a single Metallib. +void CompilerMSL::emit_custom_functions() +{ + // Use when outputting overloaded functions to cover different address spaces. + static const char *texture_addr_spaces[] = { "device", "constant", "thread" }; + static uint32_t texture_addr_space_count = sizeof(texture_addr_spaces) / sizeof(char*); + + if (spv_function_implementations.count(SPVFuncImplArrayCopyMultidim)) + spv_function_implementations.insert(SPVFuncImplArrayCopy); + + if (spv_function_implementations.count(SPVFuncImplDynamicImageSampler)) + { + // Unfortunately, this one needs a lot of the other functions to compile OK. + if (!msl_options.supports_msl_version(2)) + SPIRV_CROSS_THROW( + "spvDynamicImageSampler requires default-constructible texture objects, which require MSL 2.0."); + spv_function_implementations.insert(SPVFuncImplForwardArgs); + spv_function_implementations.insert(SPVFuncImplTextureSwizzle); + if (msl_options.swizzle_texture_samples) + spv_function_implementations.insert(SPVFuncImplGatherSwizzle); + for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane; + i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++) + spv_function_implementations.insert(static_cast(i)); + spv_function_implementations.insert(SPVFuncImplExpandITUFullRange); + spv_function_implementations.insert(SPVFuncImplExpandITUNarrowRange); + spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT709); + spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT601); + spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT2020); + } + + for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane; + i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++) + if (spv_function_implementations.count(static_cast(i))) + spv_function_implementations.insert(SPVFuncImplForwardArgs); + + if (spv_function_implementations.count(SPVFuncImplTextureSwizzle) || + spv_function_implementations.count(SPVFuncImplGatherSwizzle) || + spv_function_implementations.count(SPVFuncImplGatherCompareSwizzle)) + { + spv_function_implementations.insert(SPVFuncImplForwardArgs); + spv_function_implementations.insert(SPVFuncImplGetSwizzle); + } + + for (const auto &spv_func : spv_function_implementations) + { + switch (spv_func) + { + case SPVFuncImplMod: + statement("// Implementation of the GLSL mod() function, which is slightly different than Metal fmod()"); + statement("template"); + statement("inline Tx mod(Tx x, Ty y)"); + begin_scope(); + statement("return x - y * floor(x / y);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplRadians: + statement("// Implementation of the GLSL radians() function"); + statement("template"); + statement("inline T radians(T d)"); + begin_scope(); + statement("return d * T(0.01745329251);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplDegrees: + 
statement("// Implementation of the GLSL degrees() function"); + statement("template"); + statement("inline T degrees(T r)"); + begin_scope(); + statement("return r * T(57.2957795131);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplFindILsb: + statement("// Implementation of the GLSL findLSB() function"); + statement("template"); + statement("inline T spvFindLSB(T x)"); + begin_scope(); + statement("return select(ctz(x), T(-1), x == T(0));"); + end_scope(); + statement(""); + break; + + case SPVFuncImplFindUMsb: + statement("// Implementation of the unsigned GLSL findMSB() function"); + statement("template"); + statement("inline T spvFindUMSB(T x)"); + begin_scope(); + statement("return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));"); + end_scope(); + statement(""); + break; + + case SPVFuncImplFindSMsb: + statement("// Implementation of the signed GLSL findMSB() function"); + statement("template"); + statement("inline T spvFindSMSB(T x)"); + begin_scope(); + statement("T v = select(x, T(-1) - x, x < T(0));"); + statement("return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0));"); + end_scope(); + statement(""); + break; + + case SPVFuncImplSSign: + statement("// Implementation of the GLSL sign() function for integer types"); + statement("template::value>::type>"); + statement("inline T sign(T x)"); + begin_scope(); + statement("return select(select(select(x, T(0), x == T(0)), T(1), x > T(0)), T(-1), x < T(0));"); + end_scope(); + statement(""); + break; + + case SPVFuncImplArrayCopy: + case SPVFuncImplArrayCopyMultidim: + { + // Unfortunately we cannot template on the address space, so combinatorial explosion it is. + static const char *function_name_tags[] = { + "FromConstantToStack", "FromConstantToThreadGroup", "FromStackToStack", + "FromStackToThreadGroup", "FromThreadGroupToStack", "FromThreadGroupToThreadGroup", + "FromDeviceToDevice", "FromConstantToDevice", "FromStackToDevice", + "FromThreadGroupToDevice", "FromDeviceToStack", "FromDeviceToThreadGroup", + }; + + static const char *src_address_space[] = { + "constant", "constant", "thread const", "thread const", + "threadgroup const", "threadgroup const", "device const", "constant", + "thread const", "threadgroup const", "device const", "device const", + }; + + static const char *dst_address_space[] = { + "thread", "threadgroup", "thread", "threadgroup", "thread", "threadgroup", + "device", "device", "device", "device", "thread", "threadgroup", + }; + + for (uint32_t variant = 0; variant < 12; variant++) + { + bool is_multidim = spv_func == SPVFuncImplArrayCopyMultidim; + const char* dim = is_multidim ? "[N][M]" : "[N]"; + statement("template" : ">"); + statement("inline void spvArrayCopy", function_name_tags[variant], "(", + dst_address_space[variant], " T (&dst)", dim, ", ", + src_address_space[variant], " T (&src)", dim, ")"); + begin_scope(); + statement("for (uint i = 0; i < N; i++)"); + begin_scope(); + if (is_multidim) + statement("spvArrayCopy", function_name_tags[variant], "(dst[i], src[i]);"); + else + statement("dst[i] = src[i];"); + end_scope(); + end_scope(); + statement(""); + } + break; + } + + // Support for Metal 2.1's new texture_buffer type. 
+ case SPVFuncImplTexelBufferCoords: + { + if (msl_options.texel_buffer_texture_width > 0) + { + string tex_width_str = convert_to_string(msl_options.texel_buffer_texture_width); + statement("// Returns 2D texture coords corresponding to 1D texel buffer coords"); + statement(force_inline); + statement("uint2 spvTexelBufferCoord(uint tc)"); + begin_scope(); + statement(join("return uint2(tc % ", tex_width_str, ", tc / ", tex_width_str, ");")); + end_scope(); + statement(""); + } + else + { + statement("// Returns 2D texture coords corresponding to 1D texel buffer coords"); + statement( + "#define spvTexelBufferCoord(tc, tex) uint2((tc) % (tex).get_width(), (tc) / (tex).get_width())"); + statement(""); + } + break; + } + + // Emulate texture2D atomic operations + case SPVFuncImplImage2DAtomicCoords: + { + if (msl_options.supports_msl_version(1, 2)) + { + statement("// The required alignment of a linear texture of R32Uint format."); + statement("constant uint spvLinearTextureAlignmentOverride [[function_constant(", + msl_options.r32ui_alignment_constant_id, ")]];"); + statement("constant uint spvLinearTextureAlignment = ", + "is_function_constant_defined(spvLinearTextureAlignmentOverride) ? ", + "spvLinearTextureAlignmentOverride : ", msl_options.r32ui_linear_texture_alignment, ";"); + } + else + { + statement("// The required alignment of a linear texture of R32Uint format."); + statement("constant uint spvLinearTextureAlignment = ", msl_options.r32ui_linear_texture_alignment, + ";"); + } + statement("// Returns buffer coords corresponding to 2D texture coords for emulating 2D texture atomics"); + statement("#define spvImage2DAtomicCoord(tc, tex) (((((tex).get_width() + ", + " spvLinearTextureAlignment / 4 - 1) & ~(", + " spvLinearTextureAlignment / 4 - 1)) * (tc).y) + (tc).x)"); + statement(""); + break; + } + + // Fix up gradient vectors when sampling a cube texture for Apple Silicon. + // h/t Alexey Knyazev (https://github.com/KhronosGroup/MoltenVK/issues/2068#issuecomment-1817799067) for the code. + case SPVFuncImplGradientCube: + statement("static inline gradientcube spvGradientCube(float3 P, float3 dPdx, float3 dPdy)"); + begin_scope(); + statement("// Major axis selection"); + statement("float3 absP = abs(P);"); + statement("bool xMajor = absP.x >= max(absP.y, absP.z);"); + statement("bool yMajor = absP.y >= absP.z;"); + statement("float3 Q = xMajor ? P.yzx : (yMajor ? P.xzy : P);"); + statement("float3 dQdx = xMajor ? dPdx.yzx : (yMajor ? dPdx.xzy : dPdx);"); + statement("float3 dQdy = xMajor ? dPdy.yzx : (yMajor ? dPdy.xzy : dPdy);"); + statement_no_indent(""); + statement("// Skip a couple of operations compared to usual projection"); + statement("float4 d = float4(dQdx.xy, dQdy.xy) - (Q.xy / Q.z).xyxy * float4(dQdx.zz, dQdy.zz);"); + statement_no_indent(""); + statement("// Final swizzle to put the intermediate values into non-ignored components"); + statement("// X major: X and Z"); + statement("// Y major: X and Y"); + statement("// Z major: Y and Z"); + statement("return gradientcube(xMajor ? d.xxy : d.xyx, xMajor ? 
d.zzw : d.zwz);"); + end_scope(); + statement(""); + break; + + // "fadd" intrinsic support + case SPVFuncImplFAdd: + statement("template"); + statement("[[clang::optnone]] T spvFAdd(T l, T r)"); + begin_scope(); + statement("return fma(T(1), l, r);"); + end_scope(); + statement(""); + break; + + // "fsub" intrinsic support + case SPVFuncImplFSub: + statement("template"); + statement("[[clang::optnone]] T spvFSub(T l, T r)"); + begin_scope(); + statement("return fma(T(-1), r, l);"); + end_scope(); + statement(""); + break; + + // "fmul' intrinsic support + case SPVFuncImplFMul: + statement("template"); + statement("[[clang::optnone]] T spvFMul(T l, T r)"); + begin_scope(); + statement("return fma(l, r, T(0));"); + end_scope(); + statement(""); + + statement("template"); + statement("[[clang::optnone]] vec spvFMulVectorMatrix(vec v, matrix m)"); + begin_scope(); + statement("vec res = vec(0);"); + statement("for (uint i = Rows; i > 0; --i)"); + begin_scope(); + statement("vec tmp(0);"); + statement("for (uint j = 0; j < Cols; ++j)"); + begin_scope(); + statement("tmp[j] = m[j][i - 1];"); + end_scope(); + statement("res = fma(tmp, vec(v[i - 1]), res);"); + end_scope(); + statement("return res;"); + end_scope(); + statement(""); + + statement("template"); + statement("[[clang::optnone]] vec spvFMulMatrixVector(matrix m, vec v)"); + begin_scope(); + statement("vec res = vec(0);"); + statement("for (uint i = Cols; i > 0; --i)"); + begin_scope(); + statement("res = fma(m[i - 1], vec(v[i - 1]), res);"); + end_scope(); + statement("return res;"); + end_scope(); + statement(""); + + statement("template"); + statement("[[clang::optnone]] matrix spvFMulMatrixMatrix(matrix l, matrix r)"); + begin_scope(); + statement("matrix res;"); + statement("for (uint i = 0; i < RCols; i++)"); + begin_scope(); + statement("vec tmp(0);"); + statement("for (uint j = 0; j < LCols; j++)"); + begin_scope(); + statement("tmp = fma(vec(r[i][j]), l[j], tmp);"); + end_scope(); + statement("res[i] = tmp;"); + end_scope(); + statement("return res;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplQuantizeToF16: + // Ensure fast-math is disabled to match Vulkan results. + // SpvHalfTypeSelector is used to match the half* template type to the float* template type. + // Depending on GPU, MSL does not always flush converted subnormal halfs to zero, + // as required by OpQuantizeToF16, so check for subnormals and flush them to zero. + statement("template struct SpvHalfTypeSelector;"); + statement("template <> struct SpvHalfTypeSelector { public: using H = half; };"); + statement("template struct SpvHalfTypeSelector> { using H = vec; };"); + statement("template::H>"); + statement("[[clang::optnone]] F spvQuantizeToF16(F fval)"); + begin_scope(); + statement("H hval = H(fval);"); + statement("hval = select(copysign(H(0), hval), hval, isnormal(hval) || isinf(hval) || isnan(hval));"); + statement("return F(hval);"); + end_scope(); + statement(""); + break; + + // Emulate texturecube_array with texture2d_array for iOS where this type is not available + case SPVFuncImplCubemapTo2DArrayFace: + statement(force_inline); + statement("float3 spvCubemapTo2DArrayFace(float3 P)"); + begin_scope(); + statement("float3 Coords = abs(P.xyz);"); + statement("float CubeFace = 0;"); + statement("float ProjectionAxis = 0;"); + statement("float u = 0;"); + statement("float v = 0;"); + statement("if (Coords.x >= Coords.y && Coords.x >= Coords.z)"); + begin_scope(); + statement("CubeFace = P.x >= 0 ? 
0 : 1;"); + statement("ProjectionAxis = Coords.x;"); + statement("u = P.x >= 0 ? -P.z : P.z;"); + statement("v = -P.y;"); + end_scope(); + statement("else if (Coords.y >= Coords.x && Coords.y >= Coords.z)"); + begin_scope(); + statement("CubeFace = P.y >= 0 ? 2 : 3;"); + statement("ProjectionAxis = Coords.y;"); + statement("u = P.x;"); + statement("v = P.y >= 0 ? P.z : -P.z;"); + end_scope(); + statement("else"); + begin_scope(); + statement("CubeFace = P.z >= 0 ? 4 : 5;"); + statement("ProjectionAxis = Coords.z;"); + statement("u = P.z >= 0 ? P.x : -P.x;"); + statement("v = -P.y;"); + end_scope(); + statement("u = 0.5 * (u/ProjectionAxis + 1);"); + statement("v = 0.5 * (v/ProjectionAxis + 1);"); + statement("return float3(u, v, CubeFace);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplInverse4x4: + statement("// Returns the determinant of a 2x2 matrix."); + statement(force_inline); + statement("float spvDet2x2(float a1, float a2, float b1, float b2)"); + begin_scope(); + statement("return a1 * b2 - b1 * a2;"); + end_scope(); + statement(""); + + statement("// Returns the determinant of a 3x3 matrix."); + statement(force_inline); + statement("float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, " + "float c2, float c3)"); + begin_scope(); + statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * spvDet2x2(a2, a3, " + "b2, b3);"); + end_scope(); + statement(""); + statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical"); + statement("// adjoint and dividing by the determinant. The contents of the matrix are changed."); + statement(force_inline); + statement("float4x4 spvInverse4x4(float4x4 m)"); + begin_scope(); + statement("float4x4 adj; // The adjoint matrix (inverse after dividing by determinant)"); + statement_no_indent(""); + statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix."); + statement("adj[0][0] = spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], " + "m[3][3]);"); + statement("adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], " + "m[3][3]);"); + statement("adj[0][2] = spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], " + "m[3][3]);"); + statement("adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], " + "m[2][3]);"); + statement_no_indent(""); + statement("adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], " + "m[3][3]);"); + statement("adj[1][1] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], " + "m[3][3]);"); + statement("adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], " + "m[3][3]);"); + statement("adj[1][3] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], " + "m[2][3]);"); + statement_no_indent(""); + statement("adj[2][0] = spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], " + "m[3][3]);"); + statement("adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], " + "m[3][3]);"); + statement("adj[2][2] = spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], " + "m[3][3]);"); + statement("adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], " + "m[2][3]);"); + 
statement_no_indent(""); + statement("adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], " + "m[3][2]);"); + statement("adj[3][1] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], " + "m[3][2]);"); + statement("adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], " + "m[3][2]);"); + statement("adj[3][3] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], " + "m[2][2]);"); + statement_no_indent(""); + statement("// Calculate the determinant as a combination of the cofactors of the first row."); + statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] " + "* m[3][0]);"); + statement_no_indent(""); + statement("// Divide the classical adjoint matrix by the determinant."); + statement("// If determinant is zero, matrix is not invertable, so leave it unchanged."); + statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplInverse3x3: + if (spv_function_implementations.count(SPVFuncImplInverse4x4) == 0) + { + statement("// Returns the determinant of a 2x2 matrix."); + statement(force_inline); + statement("float spvDet2x2(float a1, float a2, float b1, float b2)"); + begin_scope(); + statement("return a1 * b2 - b1 * a2;"); + end_scope(); + statement(""); + } + + statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical"); + statement("// adjoint and dividing by the determinant. The contents of the matrix are changed."); + statement(force_inline); + statement("float3x3 spvInverse3x3(float3x3 m)"); + begin_scope(); + statement("float3x3 adj; // The adjoint matrix (inverse after dividing by determinant)"); + statement_no_indent(""); + statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix."); + statement("adj[0][0] = spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);"); + statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);"); + statement("adj[0][2] = spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);"); + statement_no_indent(""); + statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);"); + statement("adj[1][1] = spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);"); + statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);"); + statement_no_indent(""); + statement("adj[2][0] = spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);"); + statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);"); + statement("adj[2][2] = spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);"); + statement_no_indent(""); + statement("// Calculate the determinant as a combination of the cofactors of the first row."); + statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);"); + statement_no_indent(""); + statement("// Divide the classical adjoint matrix by the determinant."); + statement("// If determinant is zero, matrix is not invertable, so leave it unchanged."); + statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplInverse2x2: + statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical"); + statement("// adjoint and dividing by the determinant. 
The contents of the matrix are changed."); + statement(force_inline); + statement("float2x2 spvInverse2x2(float2x2 m)"); + begin_scope(); + statement("float2x2 adj; // The adjoint matrix (inverse after dividing by determinant)"); + statement_no_indent(""); + statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix."); + statement("adj[0][0] = m[1][1];"); + statement("adj[0][1] = -m[0][1];"); + statement_no_indent(""); + statement("adj[1][0] = -m[1][0];"); + statement("adj[1][1] = m[0][0];"); + statement_no_indent(""); + statement("// Calculate the determinant as a combination of the cofactors of the first row."); + statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);"); + statement_no_indent(""); + statement("// Divide the classical adjoint matrix by the determinant."); + statement("// If determinant is zero, matrix is not invertable, so leave it unchanged."); + statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplForwardArgs: + statement("template struct spvRemoveReference { typedef T type; };"); + statement("template struct spvRemoveReference { typedef T type; };"); + statement("template struct spvRemoveReference { typedef T type; };"); + statement("template inline constexpr thread T&& spvForward(thread typename " + "spvRemoveReference::type& x)"); + begin_scope(); + statement("return static_cast(x);"); + end_scope(); + statement("template inline constexpr thread T&& spvForward(thread typename " + "spvRemoveReference::type&& x)"); + begin_scope(); + statement("return static_cast(x);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplGetSwizzle: + statement("enum class spvSwizzle : uint"); + begin_scope(); + statement("none = 0,"); + statement("zero,"); + statement("one,"); + statement("red,"); + statement("green,"); + statement("blue,"); + statement("alpha"); + end_scope_decl(); + statement(""); + statement("template"); + statement("inline T spvGetSwizzle(vec x, T c, spvSwizzle s)"); + begin_scope(); + statement("switch (s)"); + begin_scope(); + statement("case spvSwizzle::none:"); + statement(" return c;"); + statement("case spvSwizzle::zero:"); + statement(" return 0;"); + statement("case spvSwizzle::one:"); + statement(" return 1;"); + statement("case spvSwizzle::red:"); + statement(" return x.r;"); + statement("case spvSwizzle::green:"); + statement(" return x.g;"); + statement("case spvSwizzle::blue:"); + statement(" return x.b;"); + statement("case spvSwizzle::alpha:"); + statement(" return x.a;"); + end_scope(); + end_scope(); + statement(""); + break; + + case SPVFuncImplTextureSwizzle: + statement("// Wrapper function that swizzles texture samples and fetches."); + statement("template"); + statement("inline vec spvTextureSwizzle(vec x, uint s)"); + begin_scope(); + statement("if (!s)"); + statement(" return x;"); + statement("return vec(spvGetSwizzle(x, x.r, spvSwizzle((s >> 0) & 0xFF)), " + "spvGetSwizzle(x, x.g, spvSwizzle((s >> 8) & 0xFF)), spvGetSwizzle(x, x.b, spvSwizzle((s >> 16) " + "& 0xFF)), " + "spvGetSwizzle(x, x.a, spvSwizzle((s >> 24) & 0xFF)));"); + end_scope(); + statement(""); + statement("template"); + statement("inline T spvTextureSwizzle(T x, uint s)"); + begin_scope(); + statement("return spvTextureSwizzle(vec(x, 0, 0, 1), s).x;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplGatherSwizzle: + statement("// Wrapper function that swizzles texture gathers."); + statement("template class Tex, " + "typename... 
Ts>"); + statement("inline vec spvGatherSwizzle(const thread Tex& t, sampler s, " + "uint sw, component c, Ts... params) METAL_CONST_ARG(c)"); + begin_scope(); + statement("if (sw)"); + begin_scope(); + statement("switch (spvSwizzle((sw >> (uint(c) * 8)) & 0xFF))"); + begin_scope(); + statement("case spvSwizzle::none:"); + statement(" break;"); + statement("case spvSwizzle::zero:"); + statement(" return vec(0, 0, 0, 0);"); + statement("case spvSwizzle::one:"); + statement(" return vec(1, 1, 1, 1);"); + statement("case spvSwizzle::red:"); + statement(" return t.gather(s, spvForward(params)..., component::x);"); + statement("case spvSwizzle::green:"); + statement(" return t.gather(s, spvForward(params)..., component::y);"); + statement("case spvSwizzle::blue:"); + statement(" return t.gather(s, spvForward(params)..., component::z);"); + statement("case spvSwizzle::alpha:"); + statement(" return t.gather(s, spvForward(params)..., component::w);"); + end_scope(); + end_scope(); + // texture::gather insists on its component parameter being a constant + // expression, so we need this silly workaround just to compile the shader. + statement("switch (c)"); + begin_scope(); + statement("case component::x:"); + statement(" return t.gather(s, spvForward(params)..., component::x);"); + statement("case component::y:"); + statement(" return t.gather(s, spvForward(params)..., component::y);"); + statement("case component::z:"); + statement(" return t.gather(s, spvForward(params)..., component::z);"); + statement("case component::w:"); + statement(" return t.gather(s, spvForward(params)..., component::w);"); + end_scope(); + end_scope(); + statement(""); + break; + + case SPVFuncImplGatherCompareSwizzle: + statement("// Wrapper function that swizzles depth texture gathers."); + statement("template class Tex, " + "typename... Ts>"); + statement("inline vec spvGatherCompareSwizzle(const thread Tex& t, sampler " + "s, uint sw, Ts... params) "); + begin_scope(); + statement("if (sw)"); + begin_scope(); + statement("switch (spvSwizzle(sw & 0xFF))"); + begin_scope(); + statement("case spvSwizzle::none:"); + statement("case spvSwizzle::red:"); + statement(" break;"); + statement("case spvSwizzle::zero:"); + statement("case spvSwizzle::green:"); + statement("case spvSwizzle::blue:"); + statement("case spvSwizzle::alpha:"); + statement(" return vec(0, 0, 0, 0);"); + statement("case spvSwizzle::one:"); + statement(" return vec(1, 1, 1, 1);"); + end_scope(); + end_scope(); + statement("return t.gather_compare(s, spvForward(params)...);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplGatherConstOffsets: + // Because we are passing a texture reference, we have to output an overloaded version of this function for each address space. + for (uint32_t i = 0; i < texture_addr_space_count; i++) + { + statement("// Wrapper function that processes a ", texture_addr_spaces[i], " texture gather with a constant offset array."); + statement("template class Tex, " + "typename Toff, typename... Tp>"); + statement("inline vec spvGatherConstOffsets(const ", texture_addr_spaces[i], " Tex& t, sampler s, " + "Toff coffsets, component c, Tp... 
params) METAL_CONST_ARG(c)"); + begin_scope(); + statement("vec rslts[4];"); + statement("for (uint i = 0; i < 4; i++)"); + begin_scope(); + statement("switch (c)"); + begin_scope(); + // Work around texture::gather() requiring its component parameter to be a constant expression + statement("case component::x:"); + statement(" rslts[i] = t.gather(s, spvForward(params)..., coffsets[i], component::x);"); + statement(" break;"); + statement("case component::y:"); + statement(" rslts[i] = t.gather(s, spvForward(params)..., coffsets[i], component::y);"); + statement(" break;"); + statement("case component::z:"); + statement(" rslts[i] = t.gather(s, spvForward(params)..., coffsets[i], component::z);"); + statement(" break;"); + statement("case component::w:"); + statement(" rslts[i] = t.gather(s, spvForward(params)..., coffsets[i], component::w);"); + statement(" break;"); + end_scope(); + end_scope(); + // Pull all values from the i0j0 component of each gather footprint + statement("return vec(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);"); + end_scope(); + statement(""); + } + break; + + case SPVFuncImplGatherCompareConstOffsets: + // Because we are passing a texture reference, we have to output an overloaded version of this function for each address space. + for (uint32_t i = 0; i < texture_addr_space_count; i++) + { + statement("// Wrapper function that processes a ", texture_addr_spaces[i], " texture gather with a constant offset array."); + statement("template class Tex, " + "typename Toff, typename... Tp>"); + statement("inline vec spvGatherCompareConstOffsets(const ", texture_addr_spaces[i], " Tex& t, sampler s, " + "Toff coffsets, Tp... params)"); + begin_scope(); + statement("vec rslts[4];"); + statement("for (uint i = 0; i < 4; i++)"); + begin_scope(); + statement(" rslts[i] = t.gather_compare(s, spvForward(params)..., coffsets[i]);"); + end_scope(); + // Pull all values from the i0j0 component of each gather footprint + statement("return vec(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);"); + end_scope(); + statement(""); + } + break; + + case SPVFuncImplSubgroupBroadcast: + // Metal doesn't allow broadcasting boolean values directly, but we can work around that by broadcasting + // them as integers. 
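A condensed sketch of what that workaround looks like in the generated MSL follows, showing the SIMD-group flavor; when msl_options.use_quadgroup_operation() is set, the same shapes are emitted with quad_broadcast instead, and the template parameter spelling is an assumption.

```c++
// Sketch of the emitted broadcast helpers. The bool specialization round-trips
// through ushort because the Metal broadcast intrinsics do not accept bool.
template <typename T>
inline T spvSubgroupBroadcast(T value, ushort lane)
{
    return simd_broadcast(value, lane);
}

template <>
inline bool spvSubgroupBroadcast(bool value, ushort lane)
{
    return !!simd_broadcast((ushort)value, lane);
}

template <uint N>
inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)
{
    return (vec<bool, N>)simd_broadcast((vec<ushort, N>)value, lane);
}
```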
+ statement("template"); + statement("inline T spvSubgroupBroadcast(T value, ushort lane)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return quad_broadcast(value, lane);"); + else + statement("return simd_broadcast(value, lane);"); + end_scope(); + statement(""); + statement("template<>"); + statement("inline bool spvSubgroupBroadcast(bool value, ushort lane)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return !!quad_broadcast((ushort)value, lane);"); + else + statement("return !!simd_broadcast((ushort)value, lane);"); + end_scope(); + statement(""); + statement("template"); + statement("inline vec spvSubgroupBroadcast(vec value, ushort lane)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return (vec)quad_broadcast((vec)value, lane);"); + else + statement("return (vec)simd_broadcast((vec)value, lane);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplSubgroupBroadcastFirst: + statement("template"); + statement("inline T spvSubgroupBroadcastFirst(T value)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return quad_broadcast_first(value);"); + else + statement("return simd_broadcast_first(value);"); + end_scope(); + statement(""); + statement("template<>"); + statement("inline bool spvSubgroupBroadcastFirst(bool value)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return !!quad_broadcast_first((ushort)value);"); + else + statement("return !!simd_broadcast_first((ushort)value);"); + end_scope(); + statement(""); + statement("template"); + statement("inline vec spvSubgroupBroadcastFirst(vec value)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return (vec)quad_broadcast_first((vec)value);"); + else + statement("return (vec)simd_broadcast_first((vec)value);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplSubgroupBallot: + statement("inline uint4 spvSubgroupBallot(bool value)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + { + statement("return uint4((quad_vote::vote_t)quad_ballot(value), 0, 0, 0);"); + } + else if (msl_options.is_ios()) + { + // The current simd_vote on iOS uses a 32-bit integer-like object. + statement("return uint4((simd_vote::vote_t)simd_ballot(value), 0, 0, 0);"); + } + else + { + statement("simd_vote vote = simd_ballot(value);"); + statement("// simd_ballot() returns a 64-bit integer-like object, but"); + statement("// SPIR-V callers expect a uint4. 
We must convert."); + statement("// FIXME: This won't include higher bits if Apple ever supports"); + statement("// 128 lanes in an SIMD-group."); + statement("return uint4(as_type((simd_vote::vote_t)vote), 0, 0);"); + } + end_scope(); + statement(""); + break; + + case SPVFuncImplSubgroupBallotBitExtract: + statement("inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)"); + begin_scope(); + statement("return !!extract_bits(ballot[bit / 32], bit % 32, 1);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplSubgroupBallotFindLSB: + statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)"); + begin_scope(); + if (msl_options.is_ios()) + { + statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));"); + } + else + { + statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), " + "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));"); + } + statement("ballot &= mask;"); + statement("return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + " + "ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplSubgroupBallotFindMSB: + statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)"); + begin_scope(); + if (msl_options.is_ios()) + { + statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));"); + } + else + { + statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), " + "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));"); + } + statement("ballot &= mask;"); + statement("return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - " + "(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), " + "ballot.z == 0), ballot.w == 0);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplSubgroupBallotBitCount: + statement("inline uint spvPopCount4(uint4 ballot)"); + begin_scope(); + statement("return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);"); + end_scope(); + statement(""); + statement("inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)"); + begin_scope(); + if (msl_options.is_ios()) + { + statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));"); + } + else + { + statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), " + "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));"); + } + statement("return spvPopCount4(ballot & mask);"); + end_scope(); + statement(""); + statement("inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)"); + begin_scope(); + if (msl_options.is_ios()) + { + statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID + 1), uint3(0));"); + } + else + { + statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), " + "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), " + "uint2(0));"); + } + statement("return spvPopCount4(ballot & mask);"); + end_scope(); + statement(""); + statement("inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)"); + begin_scope(); + if (msl_options.is_ios()) + { + statement("uint4 
mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID), uint2(0));"); + } + else + { + statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), " + "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));"); + } + statement("return spvPopCount4(ballot & mask);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplSubgroupAllEqual: + // Metal doesn't provide a function to evaluate this directly. But, we can + // implement this by comparing every thread's value to one thread's value + // (in this case, the value of the first active thread). Then, by the transitive + // property of equality, if all comparisons return true, then they are all equal. + statement("template"); + statement("inline bool spvSubgroupAllEqual(T value)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return quad_all(all(value == quad_broadcast_first(value)));"); + else + statement("return simd_all(all(value == simd_broadcast_first(value)));"); + end_scope(); + statement(""); + statement("template<>"); + statement("inline bool spvSubgroupAllEqual(bool value)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return quad_all(value) || !quad_any(value);"); + else + statement("return simd_all(value) || !simd_any(value);"); + end_scope(); + statement(""); + statement("template"); + statement("inline bool spvSubgroupAllEqual(vec value)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return quad_all(all(value == (vec)quad_broadcast_first((vec)value)));"); + else + statement("return simd_all(all(value == (vec)simd_broadcast_first((vec)value)));"); + end_scope(); + statement(""); + break; + + case SPVFuncImplSubgroupShuffle: + statement("template"); + statement("inline T spvSubgroupShuffle(T value, ushort lane)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return quad_shuffle(value, lane);"); + else + statement("return simd_shuffle(value, lane);"); + end_scope(); + statement(""); + statement("template<>"); + statement("inline bool spvSubgroupShuffle(bool value, ushort lane)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return !!quad_shuffle((ushort)value, lane);"); + else + statement("return !!simd_shuffle((ushort)value, lane);"); + end_scope(); + statement(""); + statement("template"); + statement("inline vec spvSubgroupShuffle(vec value, ushort lane)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return (vec)quad_shuffle((vec)value, lane);"); + else + statement("return (vec)simd_shuffle((vec)value, lane);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplSubgroupShuffleXor: + statement("template"); + statement("inline T spvSubgroupShuffleXor(T value, ushort mask)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return quad_shuffle_xor(value, mask);"); + else + statement("return simd_shuffle_xor(value, mask);"); + end_scope(); + statement(""); + statement("template<>"); + statement("inline bool spvSubgroupShuffleXor(bool value, ushort mask)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return !!quad_shuffle_xor((ushort)value, mask);"); + else + statement("return !!simd_shuffle_xor((ushort)value, mask);"); + end_scope(); + statement(""); + statement("template"); + statement("inline vec spvSubgroupShuffleXor(vec value, ushort mask)"); + begin_scope(); + if 
(msl_options.use_quadgroup_operation()) + statement("return (vec)quad_shuffle_xor((vec)value, mask);"); + else + statement("return (vec)simd_shuffle_xor((vec)value, mask);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplSubgroupShuffleUp: + statement("template"); + statement("inline T spvSubgroupShuffleUp(T value, ushort delta)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return quad_shuffle_up(value, delta);"); + else + statement("return simd_shuffle_up(value, delta);"); + end_scope(); + statement(""); + statement("template<>"); + statement("inline bool spvSubgroupShuffleUp(bool value, ushort delta)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return !!quad_shuffle_up((ushort)value, delta);"); + else + statement("return !!simd_shuffle_up((ushort)value, delta);"); + end_scope(); + statement(""); + statement("template"); + statement("inline vec spvSubgroupShuffleUp(vec value, ushort delta)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return (vec)quad_shuffle_up((vec)value, delta);"); + else + statement("return (vec)simd_shuffle_up((vec)value, delta);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplSubgroupShuffleDown: + statement("template"); + statement("inline T spvSubgroupShuffleDown(T value, ushort delta)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return quad_shuffle_down(value, delta);"); + else + statement("return simd_shuffle_down(value, delta);"); + end_scope(); + statement(""); + statement("template<>"); + statement("inline bool spvSubgroupShuffleDown(bool value, ushort delta)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return !!quad_shuffle_down((ushort)value, delta);"); + else + statement("return !!simd_shuffle_down((ushort)value, delta);"); + end_scope(); + statement(""); + statement("template"); + statement("inline vec spvSubgroupShuffleDown(vec value, ushort delta)"); + begin_scope(); + if (msl_options.use_quadgroup_operation()) + statement("return (vec)quad_shuffle_down((vec)value, delta);"); + else + statement("return (vec)simd_shuffle_down((vec)value, delta);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplQuadBroadcast: + statement("template"); + statement("inline T spvQuadBroadcast(T value, uint lane)"); + begin_scope(); + statement("return quad_broadcast(value, lane);"); + end_scope(); + statement(""); + statement("template<>"); + statement("inline bool spvQuadBroadcast(bool value, uint lane)"); + begin_scope(); + statement("return !!quad_broadcast((ushort)value, lane);"); + end_scope(); + statement(""); + statement("template"); + statement("inline vec spvQuadBroadcast(vec value, uint lane)"); + begin_scope(); + statement("return (vec)quad_broadcast((vec)value, lane);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplQuadSwap: + // We can implement this easily based on the following table giving + // the target lane ID from the direction and current lane ID: + // Direction + // | 0 | 1 | 2 | + // ---+---+---+---+ + // L 0 | 1 2 3 + // a 1 | 0 3 2 + // n 2 | 3 0 1 + // e 3 | 2 1 0 + // Notice that target = source ^ (direction + 1). 
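That XOR identity is what lets a single quad_shuffle_xor cover the horizontal (dir 0), vertical (dir 1) and diagonal (dir 2) swaps. A sketch of the emitted helper, scalar and bool forms only (the vec overload repeats the pattern, and the template spelling is an assumption):

```c++
// Sketch of the emitted quad-swap helper: lane ^ (dir + 1) selects the partner
// lane for horizontal (0), vertical (1) and diagonal (2) swaps.
template <typename T>
inline T spvQuadSwap(T value, uint dir)
{
    return quad_shuffle_xor(value, dir + 1);
}

// bool goes through ushort because quad_shuffle_xor does not accept bool.
template <>
inline bool spvQuadSwap(bool value, uint dir)
{
    return !!quad_shuffle_xor((ushort)value, dir + 1);
}
```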
+ statement("template"); + statement("inline T spvQuadSwap(T value, uint dir)"); + begin_scope(); + statement("return quad_shuffle_xor(value, dir + 1);"); + end_scope(); + statement(""); + statement("template<>"); + statement("inline bool spvQuadSwap(bool value, uint dir)"); + begin_scope(); + statement("return !!quad_shuffle_xor((ushort)value, dir + 1);"); + end_scope(); + statement(""); + statement("template"); + statement("inline vec spvQuadSwap(vec value, uint dir)"); + begin_scope(); + statement("return (vec)quad_shuffle_xor((vec)value, dir + 1);"); + end_scope(); + statement(""); + break; + + case SPVFuncImplReflectScalar: + // Metal does not support scalar versions of these functions. + // Ensure fast-math is disabled to match Vulkan results. + statement("template"); + statement("[[clang::optnone]] T spvReflect(T i, T n)"); + begin_scope(); + statement("return i - T(2) * i * n * n;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplRefractScalar: + // Metal does not support scalar versions of these functions. + statement("template"); + statement("inline T spvRefract(T i, T n, T eta)"); + begin_scope(); + statement("T NoI = n * i;"); + statement("T NoI2 = NoI * NoI;"); + statement("T k = T(1) - eta * eta * (T(1) - NoI2);"); + statement("if (k < T(0))"); + begin_scope(); + statement("return T(0);"); + end_scope(); + statement("else"); + begin_scope(); + statement("return eta * i - (eta * NoI + sqrt(k)) * n;"); + end_scope(); + end_scope(); + statement(""); + break; + + case SPVFuncImplFaceForwardScalar: + // Metal does not support scalar versions of these functions. + statement("template"); + statement("inline T spvFaceForward(T n, T i, T nref)"); + begin_scope(); + statement("return i * nref < T(0) ? n : -n;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplChromaReconstructNearest2Plane: + statement("template"); + statement("inline vec spvChromaReconstructNearest(texture2d plane0, texture2d plane1, sampler " + "samp, float2 coord, LodOptions... options)"); + begin_scope(); + statement("vec ycbcr = vec(0, 0, 0, 1);"); + statement("ycbcr.g = plane0.sample(samp, coord, spvForward(options)...).r;"); + statement("ycbcr.br = plane1.sample(samp, coord, spvForward(options)...).rg;"); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplChromaReconstructNearest3Plane: + statement("template"); + statement("inline vec spvChromaReconstructNearest(texture2d plane0, texture2d plane1, " + "texture2d plane2, sampler samp, float2 coord, LodOptions... options)"); + begin_scope(); + statement("vec ycbcr = vec(0, 0, 0, 1);"); + statement("ycbcr.g = plane0.sample(samp, coord, spvForward(options)...).r;"); + statement("ycbcr.b = plane1.sample(samp, coord, spvForward(options)...).r;"); + statement("ycbcr.r = plane2.sample(samp, coord, spvForward(options)...).r;"); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplChromaReconstructLinear422CositedEven2Plane: + statement("template"); + statement("inline vec spvChromaReconstructLinear422CositedEven(texture2d plane0, texture2d " + "plane1, sampler samp, float2 coord, LodOptions... 
options)"); + begin_scope(); + statement("vec ycbcr = vec(0, 0, 0, 1);"); + statement("ycbcr.g = plane0.sample(samp, coord, spvForward(options)...).r;"); + statement("if (fract(coord.x * plane1.get_width()) != 0.0)"); + begin_scope(); + statement("ycbcr.br = vec(mix(plane1.sample(samp, coord, spvForward(options)...), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 0)), 0.5).rg);"); + end_scope(); + statement("else"); + begin_scope(); + statement("ycbcr.br = plane1.sample(samp, coord, spvForward(options)...).rg;"); + end_scope(); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplChromaReconstructLinear422CositedEven3Plane: + statement("template"); + statement("inline vec spvChromaReconstructLinear422CositedEven(texture2d plane0, texture2d " + "plane1, texture2d plane2, sampler samp, float2 coord, LodOptions... options)"); + begin_scope(); + statement("vec ycbcr = vec(0, 0, 0, 1);"); + statement("ycbcr.g = plane0.sample(samp, coord, spvForward(options)...).r;"); + statement("if (fract(coord.x * plane1.get_width()) != 0.0)"); + begin_scope(); + statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward(options)...), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 0)), 0.5).r);"); + statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward(options)...), " + "plane2.sample(samp, coord, spvForward(options)..., int2(1, 0)), 0.5).r);"); + end_scope(); + statement("else"); + begin_scope(); + statement("ycbcr.b = plane1.sample(samp, coord, spvForward(options)...).r;"); + statement("ycbcr.r = plane2.sample(samp, coord, spvForward(options)...).r;"); + end_scope(); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplChromaReconstructLinear422Midpoint2Plane: + statement("template"); + statement("inline vec spvChromaReconstructLinear422Midpoint(texture2d plane0, texture2d " + "plane1, sampler samp, float2 coord, LodOptions... options)"); + begin_scope(); + statement("vec ycbcr = vec(0, 0, 0, 1);"); + statement("ycbcr.g = plane0.sample(samp, coord, spvForward(options)...).r;"); + statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);"); + statement("ycbcr.br = vec(mix(plane1.sample(samp, coord, spvForward(options)...), " + "plane1.sample(samp, coord, spvForward(options)..., offs), 0.25).rg);"); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplChromaReconstructLinear422Midpoint3Plane: + statement("template"); + statement("inline vec spvChromaReconstructLinear422Midpoint(texture2d plane0, texture2d " + "plane1, texture2d plane2, sampler samp, float2 coord, LodOptions... options)"); + begin_scope(); + statement("vec ycbcr = vec(0, 0, 0, 1);"); + statement("ycbcr.g = plane0.sample(samp, coord, spvForward(options)...).r;"); + statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 
1 : -1, 0);"); + statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward(options)...), " + "plane1.sample(samp, coord, spvForward(options)..., offs), 0.25).r);"); + statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward(options)...), " + "plane2.sample(samp, coord, spvForward(options)..., offs), 0.25).r);"); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane: + statement("template"); + statement("inline vec spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d plane0, " + "texture2d plane1, sampler samp, float2 coord, LodOptions... options)"); + begin_scope(); + statement("vec ycbcr = vec(0, 0, 0, 1);"); + statement("ycbcr.g = plane0.sample(samp, coord, spvForward(options)...).r;"); + statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);"); + statement("ycbcr.br = vec(mix(mix(plane1.sample(samp, coord, spvForward(options)...), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 0)), ab.x), " + "mix(plane1.sample(samp, coord, spvForward(options)..., int2(0, 1)), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 1)), ab.x), ab.y).rg);"); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane: + statement("template"); + statement("inline vec spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d plane0, " + "texture2d plane1, texture2d plane2, sampler samp, float2 coord, LodOptions... options)"); + begin_scope(); + statement("vec ycbcr = vec(0, 0, 0, 1);"); + statement("ycbcr.g = plane0.sample(samp, coord, spvForward(options)...).r;"); + statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);"); + statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward(options)...), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 0)), ab.x), " + "mix(plane1.sample(samp, coord, spvForward(options)..., int2(0, 1)), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 1)), ab.x), ab.y).r);"); + statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward(options)...), " + "plane2.sample(samp, coord, spvForward(options)..., int2(1, 0)), ab.x), " + "mix(plane2.sample(samp, coord, spvForward(options)..., int2(0, 1)), " + "plane2.sample(samp, coord, spvForward(options)..., int2(1, 1)), ab.x), ab.y).r);"); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane: + statement("template"); + statement("inline vec spvChromaReconstructLinear420XMidpointYCositedEven(texture2d plane0, " + "texture2d plane1, sampler samp, float2 coord, LodOptions... 
options)"); + begin_scope(); + statement("vec ycbcr = vec(0, 0, 0, 1);"); + statement("ycbcr.g = plane0.sample(samp, coord, spvForward(options)...).r;"); + statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, " + "0)) * 0.5);"); + statement("ycbcr.br = vec(mix(mix(plane1.sample(samp, coord, spvForward(options)...), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 0)), ab.x), " + "mix(plane1.sample(samp, coord, spvForward(options)..., int2(0, 1)), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 1)), ab.x), ab.y).rg);"); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane: + statement("template"); + statement("inline vec spvChromaReconstructLinear420XMidpointYCositedEven(texture2d plane0, " + "texture2d plane1, texture2d plane2, sampler samp, float2 coord, LodOptions... options)"); + begin_scope(); + statement("vec ycbcr = vec(0, 0, 0, 1);"); + statement("ycbcr.g = plane0.sample(samp, coord, spvForward(options)...).r;"); + statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, " + "0)) * 0.5);"); + statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward(options)...), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 0)), ab.x), " + "mix(plane1.sample(samp, coord, spvForward(options)..., int2(0, 1)), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 1)), ab.x), ab.y).r);"); + statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward(options)...), " + "plane2.sample(samp, coord, spvForward(options)..., int2(1, 0)), ab.x), " + "mix(plane2.sample(samp, coord, spvForward(options)..., int2(0, 1)), " + "plane2.sample(samp, coord, spvForward(options)..., int2(1, 1)), ab.x), ab.y).r);"); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane: + statement("template"); + statement("inline vec spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d plane0, " + "texture2d plane1, sampler samp, float2 coord, LodOptions... options)"); + begin_scope(); + statement("vec ycbcr = vec(0, 0, 0, 1);"); + statement("ycbcr.g = plane0.sample(samp, coord, spvForward(options)...).r;"); + statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, " + "0.5)) * 0.5);"); + statement("ycbcr.br = vec(mix(mix(plane1.sample(samp, coord, spvForward(options)...), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 0)), ab.x), " + "mix(plane1.sample(samp, coord, spvForward(options)..., int2(0, 1)), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 1)), ab.x), ab.y).rg);"); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane: + statement("template"); + statement("inline vec spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d plane0, " + "texture2d plane1, texture2d plane2, sampler samp, float2 coord, LodOptions... 
options)"); + begin_scope(); + statement("vec ycbcr = vec(0, 0, 0, 1);"); + statement("ycbcr.g = plane0.sample(samp, coord, spvForward(options)...).r;"); + statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, " + "0.5)) * 0.5);"); + statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward(options)...), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 0)), ab.x), " + "mix(plane1.sample(samp, coord, spvForward(options)..., int2(0, 1)), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 1)), ab.x), ab.y).r);"); + statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward(options)...), " + "plane2.sample(samp, coord, spvForward(options)..., int2(1, 0)), ab.x), " + "mix(plane2.sample(samp, coord, spvForward(options)..., int2(0, 1)), " + "plane2.sample(samp, coord, spvForward(options)..., int2(1, 1)), ab.x), ab.y).r);"); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane: + statement("template"); + statement("inline vec spvChromaReconstructLinear420XMidpointYMidpoint(texture2d plane0, " + "texture2d plane1, sampler samp, float2 coord, LodOptions... options)"); + begin_scope(); + statement("vec ycbcr = vec(0, 0, 0, 1);"); + statement("ycbcr.g = plane0.sample(samp, coord, spvForward(options)...).r;"); + statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, " + "0.5)) * 0.5);"); + statement("ycbcr.br = vec(mix(mix(plane1.sample(samp, coord, spvForward(options)...), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 0)), ab.x), " + "mix(plane1.sample(samp, coord, spvForward(options)..., int2(0, 1)), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 1)), ab.x), ab.y).rg);"); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane: + statement("template"); + statement("inline vec spvChromaReconstructLinear420XMidpointYMidpoint(texture2d plane0, " + "texture2d plane1, texture2d plane2, sampler samp, float2 coord, LodOptions... 
options)"); + begin_scope(); + statement("vec ycbcr = vec(0, 0, 0, 1);"); + statement("ycbcr.g = plane0.sample(samp, coord, spvForward(options)...).r;"); + statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, " + "0.5)) * 0.5);"); + statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward(options)...), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 0)), ab.x), " + "mix(plane1.sample(samp, coord, spvForward(options)..., int2(0, 1)), " + "plane1.sample(samp, coord, spvForward(options)..., int2(1, 1)), ab.x), ab.y).r);"); + statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward(options)...), " + "plane2.sample(samp, coord, spvForward(options)..., int2(1, 0)), ab.x), " + "mix(plane2.sample(samp, coord, spvForward(options)..., int2(0, 1)), " + "plane2.sample(samp, coord, spvForward(options)..., int2(1, 1)), ab.x), ab.y).r);"); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplExpandITUFullRange: + statement("template"); + statement("inline vec spvExpandITUFullRange(vec ycbcr, int n)"); + begin_scope(); + statement("ycbcr.br -= exp2(T(n-1))/(exp2(T(n))-1);"); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplExpandITUNarrowRange: + statement("template"); + statement("inline vec spvExpandITUNarrowRange(vec ycbcr, int n)"); + begin_scope(); + statement("ycbcr.g = (ycbcr.g * (exp2(T(n)) - 1) - ldexp(T(16), n - 8))/ldexp(T(219), n - 8);"); + statement("ycbcr.br = (ycbcr.br * (exp2(T(n)) - 1) - ldexp(T(128), n - 8))/ldexp(T(224), n - 8);"); + statement("return ycbcr;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplConvertYCbCrBT709: + statement("// cf. Khronos Data Format Specification, section 15.1.1"); + statement("constant float3x3 spvBT709Factors = {{1, 1, 1}, {0, -0.13397432/0.7152, 1.8556}, {1.5748, " + "-0.33480248/0.7152, 0}};"); + statement(""); + statement("template"); + statement("inline vec spvConvertYCbCrBT709(vec ycbcr)"); + begin_scope(); + statement("vec rgba;"); + statement("rgba.rgb = vec(spvBT709Factors * ycbcr.gbr);"); + statement("rgba.a = ycbcr.a;"); + statement("return rgba;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplConvertYCbCrBT601: + statement("// cf. Khronos Data Format Specification, section 15.1.2"); + statement("constant float3x3 spvBT601Factors = {{1, 1, 1}, {0, -0.202008/0.587, 1.772}, {1.402, " + "-0.419198/0.587, 0}};"); + statement(""); + statement("template"); + statement("inline vec spvConvertYCbCrBT601(vec ycbcr)"); + begin_scope(); + statement("vec rgba;"); + statement("rgba.rgb = vec(spvBT601Factors * ycbcr.gbr);"); + statement("rgba.a = ycbcr.a;"); + statement("return rgba;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplConvertYCbCrBT2020: + statement("// cf. 
Khronos Data Format Specification, section 15.1.3"); + statement("constant float3x3 spvBT2020Factors = {{1, 1, 1}, {0, -0.11156702/0.6780, 1.8814}, {1.4746, " + "-0.38737742/0.6780, 0}};"); + statement(""); + statement("template"); + statement("inline vec spvConvertYCbCrBT2020(vec ycbcr)"); + begin_scope(); + statement("vec rgba;"); + statement("rgba.rgb = vec(spvBT2020Factors * ycbcr.gbr);"); + statement("rgba.a = ycbcr.a;"); + statement("return rgba;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplDynamicImageSampler: + statement("enum class spvFormatResolution"); + begin_scope(); + statement("_444 = 0,"); + statement("_422,"); + statement("_420"); + end_scope_decl(); + statement(""); + statement("enum class spvChromaFilter"); + begin_scope(); + statement("nearest = 0,"); + statement("linear"); + end_scope_decl(); + statement(""); + statement("enum class spvXChromaLocation"); + begin_scope(); + statement("cosited_even = 0,"); + statement("midpoint"); + end_scope_decl(); + statement(""); + statement("enum class spvYChromaLocation"); + begin_scope(); + statement("cosited_even = 0,"); + statement("midpoint"); + end_scope_decl(); + statement(""); + statement("enum class spvYCbCrModelConversion"); + begin_scope(); + statement("rgb_identity = 0,"); + statement("ycbcr_identity,"); + statement("ycbcr_bt_709,"); + statement("ycbcr_bt_601,"); + statement("ycbcr_bt_2020"); + end_scope_decl(); + statement(""); + statement("enum class spvYCbCrRange"); + begin_scope(); + statement("itu_full = 0,"); + statement("itu_narrow"); + end_scope_decl(); + statement(""); + statement("struct spvComponentBits"); + begin_scope(); + statement("constexpr explicit spvComponentBits(int v) thread : value(v) {}"); + statement("uchar value : 6;"); + end_scope_decl(); + statement("// A class corresponding to metal::sampler which holds sampler"); + statement("// Y'CbCr conversion info."); + statement("struct spvYCbCrSampler"); + begin_scope(); + statement("constexpr spvYCbCrSampler() thread : val(build()) {}"); + statement("template"); + statement("constexpr spvYCbCrSampler(Ts... 
t) thread : val(build(t...)) {}"); + statement("constexpr spvYCbCrSampler(const thread spvYCbCrSampler& s) thread = default;"); + statement(""); + statement("spvFormatResolution get_resolution() const thread"); + begin_scope(); + statement("return spvFormatResolution((val & resolution_mask) >> resolution_base);"); + end_scope(); + statement("spvChromaFilter get_chroma_filter() const thread"); + begin_scope(); + statement("return spvChromaFilter((val & chroma_filter_mask) >> chroma_filter_base);"); + end_scope(); + statement("spvXChromaLocation get_x_chroma_offset() const thread"); + begin_scope(); + statement("return spvXChromaLocation((val & x_chroma_off_mask) >> x_chroma_off_base);"); + end_scope(); + statement("spvYChromaLocation get_y_chroma_offset() const thread"); + begin_scope(); + statement("return spvYChromaLocation((val & y_chroma_off_mask) >> y_chroma_off_base);"); + end_scope(); + statement("spvYCbCrModelConversion get_ycbcr_model() const thread"); + begin_scope(); + statement("return spvYCbCrModelConversion((val & ycbcr_model_mask) >> ycbcr_model_base);"); + end_scope(); + statement("spvYCbCrRange get_ycbcr_range() const thread"); + begin_scope(); + statement("return spvYCbCrRange((val & ycbcr_range_mask) >> ycbcr_range_base);"); + end_scope(); + statement("int get_bpc() const thread { return (val & bpc_mask) >> bpc_base; }"); + statement(""); + statement("private:"); + statement("ushort val;"); + statement(""); + statement("constexpr static constant ushort resolution_bits = 2;"); + statement("constexpr static constant ushort chroma_filter_bits = 2;"); + statement("constexpr static constant ushort x_chroma_off_bit = 1;"); + statement("constexpr static constant ushort y_chroma_off_bit = 1;"); + statement("constexpr static constant ushort ycbcr_model_bits = 3;"); + statement("constexpr static constant ushort ycbcr_range_bit = 1;"); + statement("constexpr static constant ushort bpc_bits = 6;"); + statement(""); + statement("constexpr static constant ushort resolution_base = 0;"); + statement("constexpr static constant ushort chroma_filter_base = 2;"); + statement("constexpr static constant ushort x_chroma_off_base = 4;"); + statement("constexpr static constant ushort y_chroma_off_base = 5;"); + statement("constexpr static constant ushort ycbcr_model_base = 6;"); + statement("constexpr static constant ushort ycbcr_range_base = 9;"); + statement("constexpr static constant ushort bpc_base = 10;"); + statement(""); + statement( + "constexpr static constant ushort resolution_mask = ((1 << resolution_bits) - 1) << resolution_base;"); + statement("constexpr static constant ushort chroma_filter_mask = ((1 << chroma_filter_bits) - 1) << " + "chroma_filter_base;"); + statement("constexpr static constant ushort x_chroma_off_mask = ((1 << x_chroma_off_bit) - 1) << " + "x_chroma_off_base;"); + statement("constexpr static constant ushort y_chroma_off_mask = ((1 << y_chroma_off_bit) - 1) << " + "y_chroma_off_base;"); + statement("constexpr static constant ushort ycbcr_model_mask = ((1 << ycbcr_model_bits) - 1) << " + "ycbcr_model_base;"); + statement("constexpr static constant ushort ycbcr_range_mask = ((1 << ycbcr_range_bit) - 1) << " + "ycbcr_range_base;"); + statement("constexpr static constant ushort bpc_mask = ((1 << bpc_bits) - 1) << bpc_base;"); + statement(""); + statement("static constexpr ushort build()"); + begin_scope(); + statement("return 0;"); + end_scope(); + statement(""); + statement("template"); + statement("static constexpr ushort build(spvFormatResolution res, Ts... 
t)"); + begin_scope(); + statement("return (ushort(res) << resolution_base) | (build(t...) & ~resolution_mask);"); + end_scope(); + statement(""); + statement("template"); + statement("static constexpr ushort build(spvChromaFilter filt, Ts... t)"); + begin_scope(); + statement("return (ushort(filt) << chroma_filter_base) | (build(t...) & ~chroma_filter_mask);"); + end_scope(); + statement(""); + statement("template"); + statement("static constexpr ushort build(spvXChromaLocation loc, Ts... t)"); + begin_scope(); + statement("return (ushort(loc) << x_chroma_off_base) | (build(t...) & ~x_chroma_off_mask);"); + end_scope(); + statement(""); + statement("template"); + statement("static constexpr ushort build(spvYChromaLocation loc, Ts... t)"); + begin_scope(); + statement("return (ushort(loc) << y_chroma_off_base) | (build(t...) & ~y_chroma_off_mask);"); + end_scope(); + statement(""); + statement("template"); + statement("static constexpr ushort build(spvYCbCrModelConversion model, Ts... t)"); + begin_scope(); + statement("return (ushort(model) << ycbcr_model_base) | (build(t...) & ~ycbcr_model_mask);"); + end_scope(); + statement(""); + statement("template"); + statement("static constexpr ushort build(spvYCbCrRange range, Ts... t)"); + begin_scope(); + statement("return (ushort(range) << ycbcr_range_base) | (build(t...) & ~ycbcr_range_mask);"); + end_scope(); + statement(""); + statement("template"); + statement("static constexpr ushort build(spvComponentBits bpc, Ts... t)"); + begin_scope(); + statement("return (ushort(bpc.value) << bpc_base) | (build(t...) & ~bpc_mask);"); + end_scope(); + end_scope_decl(); + statement(""); + statement("// A class which can hold up to three textures and a sampler, including"); + statement("// Y'CbCr conversion info, used to pass combined image-samplers"); + statement("// dynamically to functions."); + statement("template"); + statement("struct spvDynamicImageSampler"); + begin_scope(); + statement("texture2d plane0;"); + statement("texture2d plane1;"); + statement("texture2d plane2;"); + statement("sampler samp;"); + statement("spvYCbCrSampler ycbcr_samp;"); + statement("uint swizzle = 0;"); + statement(""); + if (msl_options.swizzle_texture_samples) + { + statement("constexpr spvDynamicImageSampler(texture2d tex, sampler samp, uint sw) thread :"); + statement(" plane0(tex), samp(samp), swizzle(sw) {}"); + } + else + { + statement("constexpr spvDynamicImageSampler(texture2d tex, sampler samp) thread :"); + statement(" plane0(tex), samp(samp) {}"); + } + statement("constexpr spvDynamicImageSampler(texture2d tex, sampler samp, spvYCbCrSampler ycbcr_samp, " + "uint sw) thread :"); + statement(" plane0(tex), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}"); + statement("constexpr spvDynamicImageSampler(texture2d plane0, texture2d plane1,"); + statement(" sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :"); + statement(" plane0(plane0), plane1(plane1), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}"); + statement( + "constexpr spvDynamicImageSampler(texture2d plane0, texture2d plane1, texture2d plane2,"); + statement(" sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :"); + statement(" plane0(plane0), plane1(plane1), plane2(plane2), samp(samp), ycbcr_samp(ycbcr_samp), " + "swizzle(sw) {}"); + statement(""); + // XXX This is really hard to follow... I've left comments to make it a bit easier. + statement("template"); + statement("vec do_sample(float2 coord, LodOptions... 
options) const thread"); + begin_scope(); + statement("if (!is_null_texture(plane1))"); + begin_scope(); + statement("if (ycbcr_samp.get_resolution() == spvFormatResolution::_444 ||"); + statement(" ycbcr_samp.get_chroma_filter() == spvChromaFilter::nearest)"); + begin_scope(); + statement("if (!is_null_texture(plane2))"); + statement(" return spvChromaReconstructNearest(plane0, plane1, plane2, samp, coord,"); + statement(" spvForward(options)...);"); + statement( + "return spvChromaReconstructNearest(plane0, plane1, samp, coord, spvForward(options)...);"); + end_scope(); // if (resolution == 422 || chroma_filter == nearest) + statement("switch (ycbcr_samp.get_resolution())"); + begin_scope(); + statement("case spvFormatResolution::_444: break;"); + statement("case spvFormatResolution::_422:"); + begin_scope(); + statement("switch (ycbcr_samp.get_x_chroma_offset())"); + begin_scope(); + statement("case spvXChromaLocation::cosited_even:"); + statement(" if (!is_null_texture(plane2))"); + statement(" return spvChromaReconstructLinear422CositedEven("); + statement(" plane0, plane1, plane2, samp,"); + statement(" coord, spvForward(options)...);"); + statement(" return spvChromaReconstructLinear422CositedEven("); + statement(" plane0, plane1, samp, coord,"); + statement(" spvForward(options)...);"); + statement("case spvXChromaLocation::midpoint:"); + statement(" if (!is_null_texture(plane2))"); + statement(" return spvChromaReconstructLinear422Midpoint("); + statement(" plane0, plane1, plane2, samp,"); + statement(" coord, spvForward(options)...);"); + statement(" return spvChromaReconstructLinear422Midpoint("); + statement(" plane0, plane1, samp, coord,"); + statement(" spvForward(options)...);"); + end_scope(); // switch (x_chroma_offset) + end_scope(); // case 422: + statement("case spvFormatResolution::_420:"); + begin_scope(); + statement("switch (ycbcr_samp.get_x_chroma_offset())"); + begin_scope(); + statement("case spvXChromaLocation::cosited_even:"); + begin_scope(); + statement("switch (ycbcr_samp.get_y_chroma_offset())"); + begin_scope(); + statement("case spvYChromaLocation::cosited_even:"); + statement(" if (!is_null_texture(plane2))"); + statement(" return spvChromaReconstructLinear420XCositedEvenYCositedEven("); + statement(" plane0, plane1, plane2, samp,"); + statement(" coord, spvForward(options)...);"); + statement(" return spvChromaReconstructLinear420XCositedEvenYCositedEven("); + statement(" plane0, plane1, samp, coord,"); + statement(" spvForward(options)...);"); + statement("case spvYChromaLocation::midpoint:"); + statement(" if (!is_null_texture(plane2))"); + statement(" return spvChromaReconstructLinear420XCositedEvenYMidpoint("); + statement(" plane0, plane1, plane2, samp,"); + statement(" coord, spvForward(options)...);"); + statement(" return spvChromaReconstructLinear420XCositedEvenYMidpoint("); + statement(" plane0, plane1, samp, coord,"); + statement(" spvForward(options)...);"); + end_scope(); // switch (y_chroma_offset) + end_scope(); // case x::cosited_even: + statement("case spvXChromaLocation::midpoint:"); + begin_scope(); + statement("switch (ycbcr_samp.get_y_chroma_offset())"); + begin_scope(); + statement("case spvYChromaLocation::cosited_even:"); + statement(" if (!is_null_texture(plane2))"); + statement(" return spvChromaReconstructLinear420XMidpointYCositedEven("); + statement(" plane0, plane1, plane2, samp,"); + statement(" coord, spvForward(options)...);"); + statement(" return spvChromaReconstructLinear420XMidpointYCositedEven("); + statement(" 
plane0, plane1, samp, coord,"); + statement(" spvForward(options)...);"); + statement("case spvYChromaLocation::midpoint:"); + statement(" if (!is_null_texture(plane2))"); + statement(" return spvChromaReconstructLinear420XMidpointYMidpoint("); + statement(" plane0, plane1, plane2, samp,"); + statement(" coord, spvForward(options)...);"); + statement(" return spvChromaReconstructLinear420XMidpointYMidpoint("); + statement(" plane0, plane1, samp, coord,"); + statement(" spvForward(options)...);"); + end_scope(); // switch (y_chroma_offset) + end_scope(); // case x::midpoint + end_scope(); // switch (x_chroma_offset) + end_scope(); // case 420: + end_scope(); // switch (resolution) + end_scope(); // if (multiplanar) + statement("return plane0.sample(samp, coord, spvForward(options)...);"); + end_scope(); // do_sample() + statement("template "); + statement("vec sample(float2 coord, LodOptions... options) const thread"); + begin_scope(); + statement( + "vec s = spvTextureSwizzle(do_sample(coord, spvForward(options)...), swizzle);"); + statement("if (ycbcr_samp.get_ycbcr_model() == spvYCbCrModelConversion::rgb_identity)"); + statement(" return s;"); + statement(""); + statement("switch (ycbcr_samp.get_ycbcr_range())"); + begin_scope(); + statement("case spvYCbCrRange::itu_full:"); + statement(" s = spvExpandITUFullRange(s, ycbcr_samp.get_bpc());"); + statement(" break;"); + statement("case spvYCbCrRange::itu_narrow:"); + statement(" s = spvExpandITUNarrowRange(s, ycbcr_samp.get_bpc());"); + statement(" break;"); + end_scope(); + statement(""); + statement("switch (ycbcr_samp.get_ycbcr_model())"); + begin_scope(); + statement("case spvYCbCrModelConversion::rgb_identity:"); // Silence Clang warning + statement("case spvYCbCrModelConversion::ycbcr_identity:"); + statement(" return s;"); + statement("case spvYCbCrModelConversion::ycbcr_bt_709:"); + statement(" return spvConvertYCbCrBT709(s);"); + statement("case spvYCbCrModelConversion::ycbcr_bt_601:"); + statement(" return spvConvertYCbCrBT601(s);"); + statement("case spvYCbCrModelConversion::ycbcr_bt_2020:"); + statement(" return spvConvertYCbCrBT2020(s);"); + end_scope(); + end_scope(); + statement(""); + // Sampler Y'CbCr conversion forbids offsets. + statement("vec sample(float2 coord, int2 offset) const thread"); + begin_scope(); + if (msl_options.swizzle_texture_samples) + statement("return spvTextureSwizzle(plane0.sample(samp, coord, offset), swizzle);"); + else + statement("return plane0.sample(samp, coord, offset);"); + end_scope(); + statement("template"); + statement("vec sample(float2 coord, lod_options options, int2 offset) const thread"); + begin_scope(); + if (msl_options.swizzle_texture_samples) + statement("return spvTextureSwizzle(plane0.sample(samp, coord, options, offset), swizzle);"); + else + statement("return plane0.sample(samp, coord, options, offset);"); + end_scope(); + statement("#if __HAVE_MIN_LOD_CLAMP__"); + statement("vec sample(float2 coord, bias b, min_lod_clamp min_lod, int2 offset) const thread"); + begin_scope(); + statement("return plane0.sample(samp, coord, b, min_lod, offset);"); + end_scope(); + statement( + "vec sample(float2 coord, gradient2d grad, min_lod_clamp min_lod, int2 offset) const thread"); + begin_scope(); + statement("return plane0.sample(samp, coord, grad, min_lod, offset);"); + end_scope(); + statement("#endif"); + statement(""); + // Y'CbCr conversion forbids all operations but sampling. 
+ statement("vec read(uint2 coord, uint lod = 0) const thread"); + begin_scope(); + statement("return plane0.read(coord, lod);"); + end_scope(); + statement(""); + statement("vec gather(float2 coord, int2 offset = int2(0), component c = component::x) const thread"); + begin_scope(); + if (msl_options.swizzle_texture_samples) + statement("return spvGatherSwizzle(plane0, samp, swizzle, c, coord, offset);"); + else + statement("return plane0.gather(samp, coord, offset, c);"); + end_scope(); + end_scope_decl(); + statement(""); + break; + + case SPVFuncImplRayQueryIntersectionParams: + statement("intersection_params spvMakeIntersectionParams(uint flags)"); + begin_scope(); + statement("intersection_params ip;"); + statement("if ((flags & ", RayFlagsOpaqueKHRMask, ") != 0)"); + statement(" ip.force_opacity(forced_opacity::opaque);"); + statement("if ((flags & ", RayFlagsNoOpaqueKHRMask, ") != 0)"); + statement(" ip.force_opacity(forced_opacity::non_opaque);"); + statement("if ((flags & ", RayFlagsTerminateOnFirstHitKHRMask, ") != 0)"); + statement(" ip.accept_any_intersection(true);"); + // RayFlagsSkipClosestHitShaderKHRMask is not available in MSL + statement("if ((flags & ", RayFlagsCullBackFacingTrianglesKHRMask, ") != 0)"); + statement(" ip.set_triangle_cull_mode(triangle_cull_mode::back);"); + statement("if ((flags & ", RayFlagsCullFrontFacingTrianglesKHRMask, ") != 0)"); + statement(" ip.set_triangle_cull_mode(triangle_cull_mode::front);"); + statement("if ((flags & ", RayFlagsCullOpaqueKHRMask, ") != 0)"); + statement(" ip.set_opacity_cull_mode(opacity_cull_mode::opaque);"); + statement("if ((flags & ", RayFlagsCullNoOpaqueKHRMask, ") != 0)"); + statement(" ip.set_opacity_cull_mode(opacity_cull_mode::non_opaque);"); + statement("if ((flags & ", RayFlagsSkipTrianglesKHRMask, ") != 0)"); + statement(" ip.set_geometry_cull_mode(geometry_cull_mode::triangle);"); + statement("if ((flags & ", RayFlagsSkipAABBsKHRMask, ") != 0)"); + statement(" ip.set_geometry_cull_mode(geometry_cull_mode::bounding_box);"); + statement("return ip;"); + end_scope(); + statement(""); + break; + + case SPVFuncImplVariableDescriptor: + statement("template"); + statement("struct spvDescriptor"); + begin_scope(); + statement("T value;"); + end_scope_decl(); + statement(""); + break; + + case SPVFuncImplVariableSizedDescriptor: + statement("template"); + statement("struct spvBufferDescriptor"); + begin_scope(); + statement("T value;"); + statement("int length;"); + statement("const device T& operator -> () const device"); + begin_scope(); + statement("return value;"); + end_scope(); + statement("const device T& operator * () const device"); + begin_scope(); + statement("return value;"); + end_scope(); + end_scope_decl(); + statement(""); + break; + + case SPVFuncImplVariableDescriptorArray: + if (spv_function_implementations.count(SPVFuncImplVariableDescriptor) != 0) + { + statement("template"); + statement("struct spvDescriptorArray"); + begin_scope(); + statement("spvDescriptorArray(const device spvDescriptor* ptr) : ptr(&ptr->value)"); + begin_scope(); + end_scope(); + statement("const device T& operator [] (size_t i) const"); + begin_scope(); + statement("return ptr[i];"); + end_scope(); + statement("const device T* ptr;"); + end_scope_decl(); + statement(""); + } + else + { + statement("template"); + statement("struct spvDescriptorArray;"); + statement(""); + } + + if (msl_options.runtime_array_rich_descriptor && + spv_function_implementations.count(SPVFuncImplVariableSizedDescriptor) != 0) + { + 
statement("template"); + statement("struct spvDescriptorArray"); + begin_scope(); + statement("spvDescriptorArray(const device spvBufferDescriptor* ptr) : ptr(ptr)"); + begin_scope(); + end_scope(); + statement("const device T* operator [] (size_t i) const"); + begin_scope(); + statement("return ptr[i].value;"); + end_scope(); + statement("const int length(int i) const"); + begin_scope(); + statement("return ptr[i].length;"); + end_scope(); + statement("const device spvBufferDescriptor* ptr;"); + end_scope_decl(); + statement(""); + } + break; + + case SPVFuncImplPaddedStd140: + // .data is used in access chain. + statement("template "); + statement("struct spvPaddedStd140 { alignas(16) T data; };"); + statement("template "); + statement("using spvPaddedStd140Matrix = spvPaddedStd140[n];"); + statement(""); + break; + + case SPVFuncImplReduceAdd: + // Metal doesn't support __builtin_reduce_add or simd_reduce_add, so we need this. + // Metal also doesn't support the other vector builtins, which would have been useful to make this a single template. + + statement("template "); + statement("T reduce_add(vec v) { return v.x + v.y; }"); + + statement("template "); + statement("T reduce_add(vec v) { return v.x + v.y + v.z; }"); + + statement("template "); + statement("T reduce_add(vec v) { return v.x + v.y + v.z + v.w; }"); + + statement(""); + break; + + case SPVFuncImplImageFence: + statement("template "); + statement("void spvImageFence(ImageT img) { img.fence(); }"); + statement(""); + break; + + case SPVFuncImplTextureCast: + statement("template "); + statement("T spvTextureCast(U img)"); + begin_scope(); + // MSL complains if you try to cast the texture itself, but casting the reference type is ... ok? *shrug* + // Gotta go what you gotta do I suppose. + statement("return reinterpret_cast(img);"); + end_scope(); + statement(""); + break; + + default: + break; + } + } +} + +static string inject_top_level_storage_qualifier(const string &expr, const string &qualifier) +{ + // Easier to do this through text munging since the qualifier does not exist in the type system at all, + // and plumbing in all that information is not very helpful. + size_t last_reference = expr.find_last_of('&'); + size_t last_pointer = expr.find_last_of('*'); + size_t last_significant = string::npos; + + if (last_reference == string::npos) + last_significant = last_pointer; + else if (last_pointer == string::npos) + last_significant = last_reference; + else + last_significant = max(last_reference, last_pointer); + + if (last_significant == string::npos) + return join(qualifier, " ", expr); + else + { + return join(expr.substr(0, last_significant + 1), " ", + qualifier, expr.substr(last_significant + 1, string::npos)); + } +} + +void CompilerMSL::declare_constant_arrays() +{ + bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1; + + // MSL cannot declare arrays inline (except when declaring a variable), so we must move them out to + // global constants directly, so we are able to use constants as variable expressions. + bool emitted = false; + + ir.for_each_typed_id([&](uint32_t, SPIRConstant &c) { + if (c.specialization) + return; + + auto &type = this->get(c.constant_type); + // Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries. + // FIXME: However, hoisting constants to main() means we need to pass down constant arrays to leaf functions if they are used there. 
+ // If there are multiple functions in the module, drop this case to avoid breaking use cases which do not need to + // link into Metal libraries. This is hacky. + if (is_array(type) && (!fully_inlined || is_scalar(type) || is_vector(type))) + { + add_resource_name(c.self); + auto name = to_name(c.self); + statement(inject_top_level_storage_qualifier(variable_decl(type, name), "constant"), + " = ", constant_expression(c), ";"); + emitted = true; + } + }); + + if (emitted) + statement(""); +} + +// Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries +void CompilerMSL::declare_complex_constant_arrays() +{ + // If we do not have a fully inlined module, we did not opt in to + // declaring constant arrays of complex types. See CompilerMSL::declare_constant_arrays(). + bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1; + if (!fully_inlined) + return; + + // MSL cannot declare arrays inline (except when declaring a variable), so we must move them out to + // global constants directly, so we are able to use constants as variable expressions. + bool emitted = false; + + ir.for_each_typed_id([&](uint32_t, SPIRConstant &c) { + if (c.specialization) + return; + + auto &type = this->get(c.constant_type); + if (is_array(type) && !(is_scalar(type) || is_vector(type))) + { + add_resource_name(c.self); + auto name = to_name(c.self); + statement("", variable_decl(type, name), " = ", constant_expression(c), ";"); + emitted = true; + } + }); + + if (emitted) + statement(""); +} + +void CompilerMSL::emit_resources() +{ + declare_constant_arrays(); + + // Emit the special [[stage_in]] and [[stage_out]] interface blocks which we created. + emit_interface_block(stage_out_var_id); + emit_interface_block(patch_stage_out_var_id); + emit_interface_block(stage_in_var_id); + emit_interface_block(patch_stage_in_var_id); +} + +// Emit declarations for the specialization Metal function constants +void CompilerMSL::emit_specialization_constants_and_structs() +{ + SpecializationConstant wg_x, wg_y, wg_z; + ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z); + bool emitted = false; + + unordered_set declared_structs; + unordered_set aligned_structs; + + // First, we need to deal with scalar block layout. + // It is possible that a struct may have to be placed at an alignment which does not match the innate alignment of the struct itself. + // In that case, if such a case exists for a struct, we must force that all elements of the struct become packed_ types. + // This makes the struct alignment as small as physically possible. + // When we actually align the struct later, we can insert padding as necessary to make the packed members behave like normally aligned types. + ir.for_each_typed_id([&](uint32_t type_id, const SPIRType &type) { + if (type.basetype == SPIRType::Struct && + has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked)) + mark_scalar_layout_structs(type); + }); + + bool builtin_block_type_is_required = false; + // Very special case. If gl_PerVertex is initialized as an array (tessellation) + // we have to potentially emit the gl_PerVertex struct type so that we can emit a constant LUT. + ir.for_each_typed_id([&](uint32_t, SPIRConstant &c) { + auto &type = this->get(c.constant_type); + if (is_array(type) && has_decoration(type.self, DecorationBlock) && is_builtin_type(type)) + builtin_block_type_is_required = true; + }); + + // Very particular use of the soft loop lock. 
+ // align_struct may need to create custom types on the fly, but we don't care about + // these types for purpose of iterating over them in ir.ids_for_type and friends. + auto loop_lock = ir.create_loop_soft_lock(); + + // Physical storage buffer pointers can have cyclical references, + // so emit forward declarations of them before other structs. + // Ignore type_id because we want the underlying struct type from the pointer. + ir.for_each_typed_id([&](uint32_t /* type_id */, const SPIRType &type) { + if (type.basetype == SPIRType::Struct && + type.pointer && type.storage == StorageClassPhysicalStorageBuffer && + declared_structs.count(type.self) == 0) + { + statement("struct ", to_name(type.self), ";"); + declared_structs.insert(type.self); + emitted = true; + } + }); + if (emitted) + statement(""); + + emitted = false; + declared_structs.clear(); + + // It is possible to have multiple spec constants that use the same spec constant ID. + // The most common cause of this is defining spec constants in GLSL while also declaring + // the workgroup size to use those spec constants. But, Metal forbids declaring more than + // one variable with the same function constant ID. + // In this case, we must only declare one variable with the [[function_constant(id)]] + // attribute, and use its initializer to initialize all the spec constants with + // that ID. + std::unordered_map unique_func_constants; + + for (auto &id_ : ir.ids_for_constant_undef_or_type) + { + auto &id = ir.ids[id_]; + + if (id.get_type() == TypeConstant) + { + auto &c = id.get(); + + if (c.self == workgroup_size_id) + { + // TODO: This can be expressed as a [[threads_per_threadgroup]] input semantic, but we need to know + // the work group size at compile time in SPIR-V, and [[threads_per_threadgroup]] would need to be passed around as a global. + // The work group size may be a specialization constant. + statement("constant uint3 ", builtin_to_glsl(BuiltInWorkgroupSize, StorageClassWorkgroup), + " [[maybe_unused]] = ", constant_expression(get(workgroup_size_id)), ";"); + emitted = true; + } + else if (c.specialization) + { + auto &type = get(c.constant_type); + string sc_type_name = type_to_glsl(type); + add_resource_name(c.self); + string sc_name = to_name(c.self); + + // Function constants are only supported in MSL 1.2 and later. + // If we don't support it just declare the "default" directly. + // This "default" value can be overridden to the true specialization constant by the API user. + // Specialization constants which are used as array length expressions cannot be function constants in MSL, + // so just fall back to macros. + if (msl_options.supports_msl_version(1, 2) && has_decoration(c.self, DecorationSpecId) && + !c.is_used_as_array_length) + { + // Only scalar, non-composite values can be function constants. + uint32_t constant_id = get_decoration(c.self, DecorationSpecId); + if (!unique_func_constants.count(constant_id)) + unique_func_constants.insert(make_pair(constant_id, c.self)); + SPIRType::BaseType sc_tmp_type = expression_type(unique_func_constants[constant_id]).basetype; + string sc_tmp_name = to_name(unique_func_constants[constant_id]) + "_tmp"; + if (unique_func_constants[constant_id] == c.self) + statement("constant ", sc_type_name, " ", sc_tmp_name, " [[function_constant(", constant_id, + ")]];"); + statement("constant ", sc_type_name, " ", sc_name, " = is_function_constant_defined(", sc_tmp_name, + ") ? 
", bitcast_expression(type, sc_tmp_type, sc_tmp_name), " : ", constant_expression(c), + ";"); + } + else if (has_decoration(c.self, DecorationSpecId)) + { + // Fallback to macro overrides. + c.specialization_constant_macro_name = + constant_value_macro_name(get_decoration(c.self, DecorationSpecId)); + + statement("#ifndef ", c.specialization_constant_macro_name); + statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c)); + statement("#endif"); + statement("constant ", sc_type_name, " ", sc_name, " = ", c.specialization_constant_macro_name, + ";"); + } + else + { + // Composite specialization constants must be built from other specialization constants. + statement("constant ", sc_type_name, " ", sc_name, " = ", constant_expression(c), ";"); + } + emitted = true; + } + } + else if (id.get_type() == TypeConstantOp) + { + auto &c = id.get(); + auto &type = get(c.basetype); + add_resource_name(c.self); + auto name = to_name(c.self); + statement("constant ", variable_decl(type, name), " = ", constant_op_expression(c), ";"); + emitted = true; + } + else if (id.get_type() == TypeType) + { + // Output non-builtin interface structs. These include local function structs + // and structs nested within uniform and read-write buffers. + auto &type = id.get(); + TypeID type_id = type.self; + + bool is_struct = (type.basetype == SPIRType::Struct) && type.array.empty() && !type.pointer; + bool is_block = + has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock); + + bool is_builtin_block = is_block && is_builtin_type(type); + bool is_declarable_struct = is_struct && (!is_builtin_block || builtin_block_type_is_required); + + // We'll declare this later. + if (stage_out_var_id && get_stage_out_struct_type().self == type_id) + is_declarable_struct = false; + if (patch_stage_out_var_id && get_patch_stage_out_struct_type().self == type_id) + is_declarable_struct = false; + if (stage_in_var_id && get_stage_in_struct_type().self == type_id) + is_declarable_struct = false; + if (patch_stage_in_var_id && get_patch_stage_in_struct_type().self == type_id) + is_declarable_struct = false; + + // Special case. Declare builtin struct anyways if we need to emit a threadgroup version of it. + if (stage_out_masked_builtin_type_id == type_id) + is_declarable_struct = true; + + // Align and emit declarable structs...but avoid declaring each more than once. + if (is_declarable_struct && declared_structs.count(type_id) == 0) + { + if (emitted) + statement(""); + emitted = false; + + declared_structs.insert(type_id); + + if (has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked)) + align_struct(type, aligned_structs); + + // Make sure we declare the underlying struct type, and not the "decorated" type with pointers, etc. + emit_struct(get(type_id)); + } + } + else if (id.get_type() == TypeUndef) + { + auto &undef = id.get(); + auto &type = get(undef.basetype); + // OpUndef can be void for some reason ... + if (type.basetype == SPIRType::Void) + return; + + // Undefined global memory is not allowed in MSL. + // Declare constant and init to zeros. Use {}, as global constructors can break Metal. 
+ statement( + inject_top_level_storage_qualifier(variable_decl(type, to_name(undef.self), undef.self), "constant"), + " = {};"); + emitted = true; + } + } + + if (emitted) + statement(""); +} + +void CompilerMSL::emit_binary_ptr_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op) +{ + bool forward = should_forward(op0) && should_forward(op1); + emit_op(result_type, result_id, join(to_ptr_expression(op0), " ", op, " ", to_ptr_expression(op1)), forward); + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); +} + +string CompilerMSL::to_ptr_expression(uint32_t id, bool register_expression_read) +{ + auto *e = maybe_get(id); + auto expr = enclose_expression(e && e->need_transpose ? e->expression : to_expression(id, register_expression_read)); + if (!should_dereference(id)) + expr = address_of_expression(expr); + return expr; +} + +void CompilerMSL::emit_binary_unord_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + const char *op) +{ + bool forward = should_forward(op0) && should_forward(op1); + emit_op(result_type, result_id, + join("(isunordered(", to_enclosed_unpacked_expression(op0), ", ", to_enclosed_unpacked_expression(op1), + ") || ", to_enclosed_unpacked_expression(op0), " ", op, " ", to_enclosed_unpacked_expression(op1), + ")"), + forward); + + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); +} + +bool CompilerMSL::emit_tessellation_io_load(uint32_t result_type_id, uint32_t id, uint32_t ptr) +{ + auto &ptr_type = expression_type(ptr); + auto &result_type = get(result_type_id); + if (ptr_type.storage != StorageClassInput && ptr_type.storage != StorageClassOutput) + return false; + if (ptr_type.storage == StorageClassOutput && is_tese_shader()) + return false; + + if (has_decoration(ptr, DecorationPatch)) + return false; + bool ptr_is_io_variable = ir.ids[ptr].get_type() == TypeVariable; + + bool flattened_io = variable_storage_requires_stage_io(ptr_type.storage); + + bool flat_data_type = flattened_io && + (is_matrix(result_type) || is_array(result_type) || result_type.basetype == SPIRType::Struct); + + // Edge case, even with multi-patch workgroups, we still need to unroll load + // if we're loading control points directly. + if (ptr_is_io_variable && is_array(result_type)) + flat_data_type = true; + + if (!flat_data_type) + return false; + + // Now, we must unflatten a composite type and take care of interleaving array access with gl_in/gl_out. + // Lots of painful code duplication since we *really* should not unroll these kinds of loads in entry point fixup + // unless we're forced to do this when the code is emitting inoptimal OpLoads. + string expr; + + uint32_t interface_index = get_extended_decoration(ptr, SPIRVCrossDecorationInterfaceMemberIndex); + auto *var = maybe_get_backing_variable(ptr); + auto &expr_type = get_pointee_type(ptr_type.self); + + const auto &iface_type = expression_type(stage_in_ptr_var_id); + + if (!flattened_io) + { + // Simplest case for multi-patch workgroups, just unroll array as-is. 
+ if (interface_index == uint32_t(-1)) + return false; + + expr += type_to_glsl(result_type) + "({ "; + uint32_t num_control_points = to_array_size_literal(result_type, uint32_t(result_type.array.size()) - 1); + + for (uint32_t i = 0; i < num_control_points; i++) + { + const uint32_t indices[2] = { i, interface_index }; + AccessChainMeta meta; + expr += access_chain_internal(stage_in_ptr_var_id, indices, 2, + ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta); + if (i + 1 < num_control_points) + expr += ", "; + } + expr += " })"; + } + else if (result_type.array.size() > 2) + { + SPIRV_CROSS_THROW("Cannot load tessellation IO variables with more than 2 dimensions."); + } + else if (result_type.array.size() == 2) + { + if (!ptr_is_io_variable) + SPIRV_CROSS_THROW("Loading an array-of-array must be loaded directly from an IO variable."); + if (interface_index == uint32_t(-1)) + SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue."); + if (result_type.basetype == SPIRType::Struct || is_matrix(result_type)) + SPIRV_CROSS_THROW("Cannot load array-of-array of composite type in tessellation IO."); + + expr += type_to_glsl(result_type) + "({ "; + uint32_t num_control_points = to_array_size_literal(result_type, 1); + uint32_t base_interface_index = interface_index; + + auto &sub_type = get(result_type.parent_type); + + for (uint32_t i = 0; i < num_control_points; i++) + { + expr += type_to_glsl(sub_type) + "({ "; + interface_index = base_interface_index; + uint32_t array_size = to_array_size_literal(result_type, 0); + for (uint32_t j = 0; j < array_size; j++, interface_index++) + { + const uint32_t indices[2] = { i, interface_index }; + + AccessChainMeta meta; + expr += access_chain_internal(stage_in_ptr_var_id, indices, 2, + ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta); + if (!is_matrix(sub_type) && sub_type.basetype != SPIRType::Struct && + expr_type.vecsize > sub_type.vecsize) + expr += vector_swizzle(sub_type.vecsize, 0); + + if (j + 1 < array_size) + expr += ", "; + } + expr += " })"; + if (i + 1 < num_control_points) + expr += ", "; + } + expr += " })"; + } + else if (result_type.basetype == SPIRType::Struct) + { + bool is_array_of_struct = is_array(result_type); + if (is_array_of_struct && !ptr_is_io_variable) + SPIRV_CROSS_THROW("Loading array of struct from IO variable must come directly from IO variable."); + + uint32_t num_control_points = 1; + if (is_array_of_struct) + { + num_control_points = to_array_size_literal(result_type, 0); + expr += type_to_glsl(result_type) + "({ "; + } + + auto &struct_type = is_array_of_struct ? get(result_type.parent_type) : result_type; + assert(struct_type.array.empty()); + + for (uint32_t i = 0; i < num_control_points; i++) + { + expr += type_to_glsl(struct_type) + "{ "; + for (uint32_t j = 0; j < uint32_t(struct_type.member_types.size()); j++) + { + // The base interface index is stored per variable for structs. + if (var) + { + interface_index = + get_extended_member_decoration(var->self, j, SPIRVCrossDecorationInterfaceMemberIndex); + } + + if (interface_index == uint32_t(-1)) + SPIRV_CROSS_THROW("Interface index is unknown. 
Cannot continue."); + + const auto &mbr_type = get(struct_type.member_types[j]); + const auto &expr_mbr_type = get(expr_type.member_types[j]); + if (is_matrix(mbr_type) && ptr_type.storage == StorageClassInput) + { + expr += type_to_glsl(mbr_type) + "("; + for (uint32_t k = 0; k < mbr_type.columns; k++, interface_index++) + { + if (is_array_of_struct) + { + const uint32_t indices[2] = { i, interface_index }; + AccessChainMeta meta; + expr += access_chain_internal( + stage_in_ptr_var_id, indices, 2, + ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta); + } + else + expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index); + if (expr_mbr_type.vecsize > mbr_type.vecsize) + expr += vector_swizzle(mbr_type.vecsize, 0); + + if (k + 1 < mbr_type.columns) + expr += ", "; + } + expr += ")"; + } + else if (is_array(mbr_type)) + { + expr += type_to_glsl(mbr_type) + "({ "; + uint32_t array_size = to_array_size_literal(mbr_type, 0); + for (uint32_t k = 0; k < array_size; k++, interface_index++) + { + if (is_array_of_struct) + { + const uint32_t indices[2] = { i, interface_index }; + AccessChainMeta meta; + expr += access_chain_internal( + stage_in_ptr_var_id, indices, 2, + ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta); + } + else + expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index); + if (expr_mbr_type.vecsize > mbr_type.vecsize) + expr += vector_swizzle(mbr_type.vecsize, 0); + + if (k + 1 < array_size) + expr += ", "; + } + expr += " })"; + } + else + { + if (is_array_of_struct) + { + const uint32_t indices[2] = { i, interface_index }; + AccessChainMeta meta; + expr += access_chain_internal(stage_in_ptr_var_id, indices, 2, + ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, + &meta); + } + else + expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index); + if (expr_mbr_type.vecsize > mbr_type.vecsize) + expr += vector_swizzle(mbr_type.vecsize, 0); + } + + if (j + 1 < struct_type.member_types.size()) + expr += ", "; + } + expr += " }"; + if (i + 1 < num_control_points) + expr += ", "; + } + if (is_array_of_struct) + expr += " })"; + } + else if (is_matrix(result_type)) + { + bool is_array_of_matrix = is_array(result_type); + if (is_array_of_matrix && !ptr_is_io_variable) + SPIRV_CROSS_THROW("Loading array of matrix from IO variable must come directly from IO variable."); + if (interface_index == uint32_t(-1)) + SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue."); + + if (is_array_of_matrix) + { + // Loading a matrix from each control point. 
+ uint32_t base_interface_index = interface_index; + uint32_t num_control_points = to_array_size_literal(result_type, 0); + expr += type_to_glsl(result_type) + "({ "; + + auto &matrix_type = get_variable_element_type(get(ptr)); + + for (uint32_t i = 0; i < num_control_points; i++) + { + interface_index = base_interface_index; + expr += type_to_glsl(matrix_type) + "("; + for (uint32_t j = 0; j < result_type.columns; j++, interface_index++) + { + const uint32_t indices[2] = { i, interface_index }; + + AccessChainMeta meta; + expr += access_chain_internal(stage_in_ptr_var_id, indices, 2, + ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta); + if (expr_type.vecsize > result_type.vecsize) + expr += vector_swizzle(result_type.vecsize, 0); + if (j + 1 < result_type.columns) + expr += ", "; + } + expr += ")"; + if (i + 1 < num_control_points) + expr += ", "; + } + + expr += " })"; + } + else + { + expr += type_to_glsl(result_type) + "("; + for (uint32_t i = 0; i < result_type.columns; i++, interface_index++) + { + expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index); + if (expr_type.vecsize > result_type.vecsize) + expr += vector_swizzle(result_type.vecsize, 0); + if (i + 1 < result_type.columns) + expr += ", "; + } + expr += ")"; + } + } + else if (ptr_is_io_variable) + { + assert(is_array(result_type)); + assert(result_type.array.size() == 1); + if (interface_index == uint32_t(-1)) + SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue."); + + // We're loading an array directly from a global variable. + // This means we're loading one member from each control point. + expr += type_to_glsl(result_type) + "({ "; + uint32_t num_control_points = to_array_size_literal(result_type, 0); + + for (uint32_t i = 0; i < num_control_points; i++) + { + const uint32_t indices[2] = { i, interface_index }; + + AccessChainMeta meta; + expr += access_chain_internal(stage_in_ptr_var_id, indices, 2, + ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta); + if (expr_type.vecsize > result_type.vecsize) + expr += vector_swizzle(result_type.vecsize, 0); + + if (i + 1 < num_control_points) + expr += ", "; + } + expr += " })"; + } + else + { + // We're loading an array from a concrete control point. + assert(is_array(result_type)); + assert(result_type.array.size() == 1); + if (interface_index == uint32_t(-1)) + SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue."); + + expr += type_to_glsl(result_type) + "({ "; + uint32_t array_size = to_array_size_literal(result_type, 0); + for (uint32_t i = 0; i < array_size; i++, interface_index++) + { + expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index); + if (expr_type.vecsize > result_type.vecsize) + expr += vector_swizzle(result_type.vecsize, 0); + if (i + 1 < array_size) + expr += ", "; + } + expr += " })"; + } + + emit_op(result_type_id, id, expr, false); + register_read(id, ptr, false); + return true; +} + +bool CompilerMSL::emit_tessellation_access_chain(const uint32_t *ops, uint32_t length) +{ + // If this is a per-vertex output, remap it to the I/O array buffer. + + // Any object which did not go through IO flattening shenanigans will go there instead. + // We will unflatten on-demand instead as needed, but not all possible cases can be supported, especially with arrays. 
+ + auto *var = maybe_get_backing_variable(ops[2]); + bool patch = false; + bool flat_data = false; + bool ptr_is_chain = false; + bool flatten_composites = false; + + bool is_block = false; + bool is_arrayed = false; + + if (var) + { + auto &type = get_variable_data_type(*var); + is_block = has_decoration(type.self, DecorationBlock); + is_arrayed = !type.array.empty(); + + flatten_composites = variable_storage_requires_stage_io(var->storage); + patch = has_decoration(ops[2], DecorationPatch) || is_patch_block(type); + + // Should match strip_array in add_interface_block. + flat_data = var->storage == StorageClassInput || (var->storage == StorageClassOutput && is_tesc_shader()); + + // Patch inputs are treated as normal block IO variables, so they don't deal with this path at all. + if (patch && (!is_block || is_arrayed || var->storage == StorageClassInput)) + flat_data = false; + + // We might have a chained access chain, where + // we first take the access chain to the control point, and then we chain into a member or something similar. + // In this case, we need to skip gl_in/gl_out remapping. + // Also, skip ptr chain for patches. + ptr_is_chain = var->self != ID(ops[2]); + } + + bool builtin_variable = false; + bool variable_is_flat = false; + + if (var && flat_data) + { + builtin_variable = is_builtin_variable(*var); + + BuiltIn bi_type = BuiltInMax; + if (builtin_variable && !is_block) + bi_type = BuiltIn(get_decoration(var->self, DecorationBuiltIn)); + + variable_is_flat = !builtin_variable || is_block || + bi_type == BuiltInPosition || bi_type == BuiltInPointSize || + bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance; + } + + if (variable_is_flat) + { + // If output is masked, it is emitted as a "normal" variable, just go through normal code paths. + // Only check this for the first level of access chain. + // Dealing with this for partial access chains should be possible, but awkward. + if (var->storage == StorageClassOutput && !ptr_is_chain) + { + bool masked = false; + if (is_block) + { + uint32_t relevant_member_index = patch ? 3 : 4; + // FIXME: This won't work properly if the application first access chains into gl_out element, + // then access chains into the member. Super weird, but theoretically possible ... + if (length > relevant_member_index) + { + uint32_t mbr_idx = get(ops[relevant_member_index]).scalar(); + masked = is_stage_output_block_member_masked(*var, mbr_idx, true); + } + } + else if (var) + masked = is_stage_output_variable_masked(*var); + + if (masked) + return false; + } + + AccessChainMeta meta; + SmallVector indices; + uint32_t next_id = ir.increase_bound_by(1); + + indices.reserve(length - 3 + 1); + + uint32_t first_non_array_index = (ptr_is_chain ? 3 : 4) - (patch ? 1 : 0); + + VariableID stage_var_id; + if (patch) + stage_var_id = var->storage == StorageClassInput ? patch_stage_in_var_id : patch_stage_out_var_id; + else + stage_var_id = var->storage == StorageClassInput ? stage_in_ptr_var_id : stage_out_ptr_var_id; + + VariableID ptr = ptr_is_chain ? VariableID(ops[2]) : stage_var_id; + if (!ptr_is_chain && !patch) + { + // Index into gl_in/gl_out with first array index. + indices.push_back(ops[first_non_array_index - 1]); + } + + auto &result_ptr_type = get(ops[0]); + + uint32_t const_mbr_id = next_id++; + uint32_t index = get_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex); + + // If we have a pointer chain expression, and we are no longer pointing to a composite + // object, we are in the clear. 
There is no longer a need to flatten anything. + bool further_access_chain_is_trivial = false; + if (ptr_is_chain && flatten_composites) + { + auto &ptr_type = expression_type(ptr); + if (!is_array(ptr_type) && !is_matrix(ptr_type) && ptr_type.basetype != SPIRType::Struct) + further_access_chain_is_trivial = true; + } + + if (!further_access_chain_is_trivial && (flatten_composites || is_block)) + { + uint32_t i = first_non_array_index; + auto *type = &get_variable_element_type(*var); + if (index == uint32_t(-1) && length >= (first_non_array_index + 1)) + { + // Maybe this is a struct type in the input class, in which case + // we put it as a decoration on the corresponding member. + uint32_t mbr_idx = get_constant(ops[first_non_array_index]).scalar(); + index = get_extended_member_decoration(var->self, mbr_idx, + SPIRVCrossDecorationInterfaceMemberIndex); + assert(index != uint32_t(-1)); + i++; + type = &get(type->member_types[mbr_idx]); + } + + // In this case, we're poking into flattened structures and arrays, so now we have to + // combine the following indices. If we encounter a non-constant index, + // we're hosed. + for (; flatten_composites && i < length; ++i) + { + if (!is_array(*type) && !is_matrix(*type) && type->basetype != SPIRType::Struct) + break; + + auto *c = maybe_get(ops[i]); + if (!c || c->specialization) + SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable in tessellation. " + "This is currently unsupported."); + + // We're in flattened space, so just increment the member index into IO block. + // We can only do this once in the current implementation, so either: + // Struct, Matrix or 1-dimensional array for a control point. + if (type->basetype == SPIRType::Struct && var->storage == StorageClassOutput) + { + // Need to consider holes, since individual block members might be masked away. + uint32_t mbr_idx = c->scalar(); + for (uint32_t j = 0; j < mbr_idx; j++) + if (!is_stage_output_block_member_masked(*var, j, true)) + index++; + } + else + index += c->scalar(); + + if (type->parent_type) + type = &get(type->parent_type); + else if (type->basetype == SPIRType::Struct) + type = &get(type->member_types[c->scalar()]); + } + + // We're not going to emit the actual member name, we let any further OpLoad take care of that. + // Tag the access chain with the member index we're referencing. + auto &result_pointee_type = get_pointee_type(result_ptr_type); + bool defer_access_chain = flatten_composites && (is_matrix(result_pointee_type) || is_array(result_pointee_type) || + result_pointee_type.basetype == SPIRType::Struct); + + if (!defer_access_chain) + { + // Access the appropriate member of gl_in/gl_out. + set(const_mbr_id, get_uint_type_id(), index, false); + indices.push_back(const_mbr_id); + + // Member index is now irrelevant. + index = uint32_t(-1); + + // Append any straggling access chain indices. + if (i < length) + indices.insert(indices.end(), ops + i, ops + length); + } + else + { + // We must have consumed the entire access chain if we're deferring it. + assert(i == length); + } + + if (index != uint32_t(-1)) + set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, index); + else + unset_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex); + } + else + { + if (index != uint32_t(-1)) + { + set(const_mbr_id, get_uint_type_id(), index, false); + indices.push_back(const_mbr_id); + } + + // Member index is now irrelevant. 
+ index = uint32_t(-1); + unset_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex); + + indices.insert(indices.end(), ops + first_non_array_index, ops + length); + } + + // We use the pointer to the base of the input/output array here, + // so this is always a pointer chain. + string e; + + if (!ptr_is_chain) + { + // This is the start of an access chain, use ptr_chain to index into control point array. + e = access_chain(ptr, indices.data(), uint32_t(indices.size()), result_ptr_type, &meta, !patch); + } + else + { + // If we're accessing a struct, we need to use member indices which are based on the IO block, + // not actual struct type, so we have to use a split access chain here where + // first path resolves the control point index, i.e. gl_in[index], and second half deals with + // looking up flattened member name. + + // However, it is possible that we partially accessed a struct, + // by taking pointer to member inside the control-point array. + // For this case, we fall back to a natural access chain since we have already dealt with remapping struct members. + // One way to check this here is if we have 2 implied read expressions. + // First one is the gl_in/gl_out struct itself, then an index into that array. + // If we have traversed further, we use a normal access chain formulation. + auto *ptr_expr = maybe_get(ptr); + bool split_access_chain_formulation = flatten_composites && ptr_expr && + ptr_expr->implied_read_expressions.size() == 2 && + !further_access_chain_is_trivial; + + if (split_access_chain_formulation) + { + e = join(to_expression(ptr), + access_chain_internal(stage_var_id, indices.data(), uint32_t(indices.size()), + ACCESS_CHAIN_CHAIN_ONLY_BIT, &meta)); + } + else + { + e = access_chain_internal(ptr, indices.data(), uint32_t(indices.size()), 0, &meta); + } + } + + // Get the actual type of the object that was accessed. If it's a vector type and we changed it, + // then we'll need to add a swizzle. + // For this, we can't necessarily rely on the type of the base expression, because it might be + // another access chain, and it will therefore already have the "correct" type. + auto *expr_type = &get_variable_data_type(*var); + if (has_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID)) + expr_type = &get(get_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID)); + for (uint32_t i = 3; i < length; i++) + { + if (!is_array(*expr_type) && expr_type->basetype == SPIRType::Struct) + expr_type = &get(expr_type->member_types[get(ops[i]).scalar()]); + else + expr_type = &get(expr_type->parent_type); + } + if (!is_array(*expr_type) && !is_matrix(*expr_type) && expr_type->basetype != SPIRType::Struct && + expr_type->vecsize > result_ptr_type.vecsize) + e += vector_swizzle(result_ptr_type.vecsize, 0); + + auto &expr = set(ops[1], std::move(e), ops[0], should_forward(ops[2])); + expr.loaded_from = var->self; + expr.need_transpose = meta.need_transpose; + expr.access_chain = true; + + // Mark the result as being packed if necessary. + if (meta.storage_is_packed) + set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypePacked); + if (meta.storage_physical_type != 0) + set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypeID, meta.storage_physical_type); + if (meta.storage_is_invariant) + set_decoration(ops[1], DecorationInvariant); + // Save the type we found in case the result is used in another access chain. 
+ set_extended_decoration(ops[1], SPIRVCrossDecorationTessIOOriginalInputTypeID, expr_type->self); + + // If we have some expression dependencies in our access chain, this access chain is technically a forwarded + // temporary which could be subject to invalidation. + // Need to assume we're forwarded while calling inherit_expression_depdendencies. + forwarded_temporaries.insert(ops[1]); + // The access chain itself is never forced to a temporary, but its dependencies might. + suppressed_usage_tracking.insert(ops[1]); + + for (uint32_t i = 2; i < length; i++) + { + inherit_expression_dependencies(ops[1], ops[i]); + add_implied_read_expression(expr, ops[i]); + } + + // If we have no dependencies after all, i.e., all indices in the access chain are immutable temporaries, + // we're not forwarded after all. + if (expr.expression_dependencies.empty()) + forwarded_temporaries.erase(ops[1]); + + return true; + } + + // If this is the inner tessellation level, and we're tessellating triangles, + // drop the last index. It isn't an array in this case, so we can't have an + // array reference here. We need to make this ID a variable instead of an + // expression so we don't try to dereference it as a variable pointer. + // Don't do this if the index is a constant 1, though. We need to drop stores + // to that one. + auto *m = ir.find_meta(var ? var->self : ID(0)); + if (is_tesc_shader() && var && m && m->decoration.builtin_type == BuiltInTessLevelInner && + is_tessellating_triangles()) + { + auto *c = maybe_get(ops[3]); + if (c && c->scalar() == 1) + return false; + auto &dest_var = set(ops[1], *var); + dest_var.basetype = ops[0]; + ir.meta[ops[1]] = ir.meta[ops[2]]; + inherit_expression_dependencies(ops[1], ops[2]); + return true; + } + + return false; +} + +bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs) +{ + if (!is_tessellating_triangles()) + return false; + + // In SPIR-V, TessLevelInner always has two elements and TessLevelOuter always has + // four. This is true even if we are tessellating triangles. This allows clients + // to use a single tessellation control shader with multiple tessellation evaluation + // shaders. + // In Metal, however, only the first element of TessLevelInner and the first three + // of TessLevelOuter are accessible. This stems from how in Metal, the tessellation + // levels must be stored to a dedicated buffer in a particular format that depends + // on the patch type. Therefore, in Triangles mode, any store to the second + // inner level or the fourth outer level must be dropped. + const auto *e = maybe_get(id_lhs); + if (!e || !e->access_chain) + return false; + BuiltIn builtin = BuiltIn(get_decoration(e->loaded_from, DecorationBuiltIn)); + if (builtin != BuiltInTessLevelInner && builtin != BuiltInTessLevelOuter) + return false; + auto *c = maybe_get(e->implied_read_expressions[1]); + if (!c) + return false; + return (builtin == BuiltInTessLevelInner && c->scalar() == 1) || + (builtin == BuiltInTessLevelOuter && c->scalar() == 3); +} + +bool CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type, + spv::StorageClass storage, bool &is_packed) +{ + // If there is any risk of writes happening with the access chain in question, + // and there is a risk of concurrent write access to other components, + // we must cast the access chain to a plain pointer to ensure we only access the exact scalars we expect. + // The MSL compiler refuses to allow component-level access for any non-packed vector types. 
+ if (!is_packed && (storage == StorageClassStorageBuffer || storage == StorageClassWorkgroup)) + { + const char *addr_space = storage == StorageClassWorkgroup ? "threadgroup" : "device"; + expr = join("((", addr_space, " ", type_to_glsl(type), "*)&", enclose_expression(expr), ")"); + + // Further indexing should happen with packed rules (array index, not swizzle). + is_packed = true; + return true; + } + else + return false; +} + +bool CompilerMSL::access_chain_needs_stage_io_builtin_translation(uint32_t base) +{ + auto *var = maybe_get_backing_variable(base); + if (!var || !is_tessellation_shader()) + return true; + + // We only need to rewrite builtin access chains when accessing flattened builtins like gl_ClipDistance_N. + // Avoid overriding it back to just gl_ClipDistance. + // This can only happen in scenarios where we cannot flatten/unflatten access chains, so, the only case + // where this triggers is evaluation shader inputs. + bool redirect_builtin = is_tese_shader() ? var->storage == StorageClassOutput : false; + return redirect_builtin; +} + +// Sets the interface member index for an access chain to a pull-model interpolant. +void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length) +{ + auto *var = maybe_get_backing_variable(ops[2]); + if (!var || !pull_model_inputs.count(var->self)) + return; + // Get the base index. + uint32_t interface_index; + auto &var_type = get_variable_data_type(*var); + auto &result_type = get(ops[0]); + auto *type = &var_type; + if (has_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex)) + { + interface_index = get_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex); + } + else + { + // Assume an access chain into a struct variable. + assert(var_type.basetype == SPIRType::Struct); + auto &c = get(ops[3 + var_type.array.size()]); + interface_index = + get_extended_member_decoration(var->self, c.scalar(), SPIRVCrossDecorationInterfaceMemberIndex); + } + // Accumulate indices. We'll have to skip over the one for the struct, if present, because we already accounted + // for that getting the base index. + for (uint32_t i = 3; i < length; ++i) + { + if (is_vector(*type) && !is_array(*type) && is_scalar(result_type)) + { + // We don't want to combine the next index. Actually, we need to save it + // so we know to apply a swizzle to the result of the interpolation. + set_extended_decoration(ops[1], SPIRVCrossDecorationInterpolantComponentExpr, ops[i]); + break; + } + + auto *c = maybe_get(ops[i]); + if (!c || c->specialization) + SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable using pull-model " + "interpolation. This is currently unsupported."); + + if (type->parent_type) + type = &get(type->parent_type); + else if (type->basetype == SPIRType::Struct) + type = &get(type->member_types[c->scalar()]); + + if (!has_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex) && + i - 3 == var_type.array.size()) + continue; + + interface_index += c->scalar(); + } + // Save this to the access chain itself so we can recover it later when calling an interpolation function. + set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, interface_index); +} + + +// If the physical type of a physical buffer pointer has been changed +// to a ulong or ulongn vector, add a cast back to the pointer type. 
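+// Roughly (type and variable names below are illustrative): if such a pointer was stored as a
+// ulong2, the function first selects the scalar lane with ".x" and then re-casts it, producing
+// something like ((device Foo*)ptr_as_ulong.x) so the expression can be used as a pointer again.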
+void CompilerMSL::check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type) +{ + auto *p_physical_type = maybe_get(physical_type); + if (p_physical_type && + p_physical_type->storage == StorageClassPhysicalStorageBuffer && + p_physical_type->basetype == to_unsigned_basetype(64)) + { + if (p_physical_type->vecsize > 1) + expr += ".x"; + + expr = join("((", type_to_glsl(*type), ")", expr, ")"); + } +} + +// Override for MSL-specific syntax instructions +void CompilerMSL::emit_instruction(const Instruction &instruction) +{ +#define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op) +#define MSL_PTR_BOP(op) emit_binary_ptr_op(ops[0], ops[1], ops[2], ops[3], #op) + // MSL does care about implicit integer promotion, but those cases are all handled in common code. +#define MSL_BOP_CAST(op, type) \ + emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode), false) +#define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op) +#define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op) +#define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op) +#define MSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op) +#define MSL_BFOP_CAST(op, type) \ + emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode)) +#define MSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op) +#define MSL_UNORD_BOP(op) emit_binary_unord_op(ops[0], ops[1], ops[2], ops[3], #op) + + auto ops = stream(instruction); + auto opcode = static_cast(instruction.op); + + opcode = get_remapped_spirv_op(opcode); + + // If we need to do implicit bitcasts, make sure we do it with the correct type. + uint32_t integer_width = get_integer_width_for_instruction(instruction); + auto int_type = to_signed_basetype(integer_width); + auto uint_type = to_unsigned_basetype(integer_width); + + switch (opcode) + { + case OpLoad: + { + uint32_t id = ops[1]; + uint32_t ptr = ops[2]; + if (is_tessellation_shader()) + { + if (!emit_tessellation_io_load(ops[0], id, ptr)) + CompilerGLSL::emit_instruction(instruction); + } + else + { + // Sample mask input for Metal is not an array + if (BuiltIn(get_decoration(ptr, DecorationBuiltIn)) == BuiltInSampleMask) + set_decoration(id, DecorationBuiltIn, BuiltInSampleMask); + CompilerGLSL::emit_instruction(instruction); + } + break; + } + + // Comparisons + case OpIEqual: + MSL_BOP_CAST(==, int_type); + break; + + case OpLogicalEqual: + case OpFOrdEqual: + MSL_BOP(==); + break; + + case OpINotEqual: + MSL_BOP_CAST(!=, int_type); + break; + + case OpLogicalNotEqual: + case OpFOrdNotEqual: + // TODO: Should probably negate the == result here. + // Typically OrdNotEqual comes from GLSL which itself does not really specify what + // happens with NaN. + // Consider fixing this if we run into real issues. 
+ MSL_BOP(!=); + break; + + case OpUGreaterThan: + MSL_BOP_CAST(>, uint_type); + break; + + case OpSGreaterThan: + MSL_BOP_CAST(>, int_type); + break; + + case OpFOrdGreaterThan: + MSL_BOP(>); + break; + + case OpUGreaterThanEqual: + MSL_BOP_CAST(>=, uint_type); + break; + + case OpSGreaterThanEqual: + MSL_BOP_CAST(>=, int_type); + break; + + case OpFOrdGreaterThanEqual: + MSL_BOP(>=); + break; + + case OpULessThan: + MSL_BOP_CAST(<, uint_type); + break; + + case OpSLessThan: + MSL_BOP_CAST(<, int_type); + break; + + case OpFOrdLessThan: + MSL_BOP(<); + break; + + case OpULessThanEqual: + MSL_BOP_CAST(<=, uint_type); + break; + + case OpSLessThanEqual: + MSL_BOP_CAST(<=, int_type); + break; + + case OpFOrdLessThanEqual: + MSL_BOP(<=); + break; + + case OpFUnordEqual: + MSL_UNORD_BOP(==); + break; + + case OpFUnordNotEqual: + // not equal in MSL generates une opcodes to begin with. + // Since unordered not equal is how it works in C, just inherit that behavior. + MSL_BOP(!=); + break; + + case OpFUnordGreaterThan: + MSL_UNORD_BOP(>); + break; + + case OpFUnordGreaterThanEqual: + MSL_UNORD_BOP(>=); + break; + + case OpFUnordLessThan: + MSL_UNORD_BOP(<); + break; + + case OpFUnordLessThanEqual: + MSL_UNORD_BOP(<=); + break; + + // Pointer math + case OpPtrEqual: + MSL_PTR_BOP(==); + break; + + case OpPtrNotEqual: + MSL_PTR_BOP(!=); + break; + + case OpPtrDiff: + MSL_PTR_BOP(-); + break; + + // Derivatives + case OpDPdx: + case OpDPdxFine: + case OpDPdxCoarse: + MSL_UFOP(dfdx); + register_control_dependent_expression(ops[1]); + break; + + case OpDPdy: + case OpDPdyFine: + case OpDPdyCoarse: + MSL_UFOP(dfdy); + register_control_dependent_expression(ops[1]); + break; + + case OpFwidth: + case OpFwidthCoarse: + case OpFwidthFine: + MSL_UFOP(fwidth); + register_control_dependent_expression(ops[1]); + break; + + // Bitfield + case OpBitFieldInsert: + { + emit_bitfield_insert_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], "insert_bits", SPIRType::UInt); + break; + } + + case OpBitFieldSExtract: + { + emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", int_type, int_type, + SPIRType::UInt, SPIRType::UInt); + break; + } + + case OpBitFieldUExtract: + { + emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", uint_type, uint_type, + SPIRType::UInt, SPIRType::UInt); + break; + } + + case OpBitReverse: + // BitReverse does not have issues with sign since result type must match input type. 
+ MSL_UFOP(reverse_bits); + break; + + case OpBitCount: + { + auto basetype = expression_type(ops[2]).basetype; + emit_unary_func_op_cast(ops[0], ops[1], ops[2], "popcount", basetype, basetype); + break; + } + + case OpFRem: + MSL_BFOP(fmod); + break; + + case OpFMul: + if (msl_options.invariant_float_math || has_decoration(ops[1], DecorationNoContraction)) + MSL_BFOP(spvFMul); + else + MSL_BOP(*); + break; + + case OpFAdd: + if (msl_options.invariant_float_math || has_decoration(ops[1], DecorationNoContraction)) + MSL_BFOP(spvFAdd); + else + MSL_BOP(+); + break; + + case OpFSub: + if (msl_options.invariant_float_math || has_decoration(ops[1], DecorationNoContraction)) + MSL_BFOP(spvFSub); + else + MSL_BOP(-); + break; + + // Atomics + case OpAtomicExchange: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t ptr = ops[2]; + uint32_t mem_sem = ops[4]; + uint32_t val = ops[5]; + emit_atomic_func_op(result_type, id, "atomic_exchange", opcode, mem_sem, mem_sem, false, ptr, val); + break; + } + + case OpAtomicCompareExchange: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t ptr = ops[2]; + uint32_t mem_sem_pass = ops[4]; + uint32_t mem_sem_fail = ops[5]; + uint32_t val = ops[6]; + uint32_t comp = ops[7]; + emit_atomic_func_op(result_type, id, "atomic_compare_exchange_weak", opcode, + mem_sem_pass, mem_sem_fail, true, + ptr, comp, true, false, val); + break; + } + + case OpAtomicCompareExchangeWeak: + SPIRV_CROSS_THROW("OpAtomicCompareExchangeWeak is only supported in kernel profile."); + + case OpAtomicLoad: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t ptr = ops[2]; + uint32_t mem_sem = ops[4]; + check_atomic_image(ptr); + emit_atomic_func_op(result_type, id, "atomic_load", opcode, mem_sem, mem_sem, false, ptr, 0); + break; + } + + case OpAtomicStore: + { + uint32_t result_type = expression_type(ops[0]).self; + uint32_t id = ops[0]; + uint32_t ptr = ops[0]; + uint32_t mem_sem = ops[2]; + uint32_t val = ops[3]; + check_atomic_image(ptr); + emit_atomic_func_op(result_type, id, "atomic_store", opcode, mem_sem, mem_sem, false, ptr, val); + break; + } + +#define MSL_AFMO_IMPL(op, valsrc, valconst) \ + do \ + { \ + uint32_t result_type = ops[0]; \ + uint32_t id = ops[1]; \ + uint32_t ptr = ops[2]; \ + uint32_t mem_sem = ops[4]; \ + uint32_t val = valsrc; \ + emit_atomic_func_op(result_type, id, "atomic_fetch_" #op, opcode, \ + mem_sem, mem_sem, false, ptr, val, \ + false, valconst); \ + } while (false) + +#define MSL_AFMO(op) MSL_AFMO_IMPL(op, ops[5], false) +#define MSL_AFMIO(op) MSL_AFMO_IMPL(op, 1, true) + + case OpAtomicIIncrement: + MSL_AFMIO(add); + break; + + case OpAtomicIDecrement: + MSL_AFMIO(sub); + break; + + case OpAtomicIAdd: + case OpAtomicFAddEXT: + MSL_AFMO(add); + break; + + case OpAtomicISub: + MSL_AFMO(sub); + break; + + case OpAtomicSMin: + case OpAtomicUMin: + MSL_AFMO(min); + break; + + case OpAtomicSMax: + case OpAtomicUMax: + MSL_AFMO(max); + break; + + case OpAtomicAnd: + MSL_AFMO(and); + break; + + case OpAtomicOr: + MSL_AFMO(or); + break; + + case OpAtomicXor: + MSL_AFMO(xor); + break; + + // Images + + // Reads == Fetches in Metal + case OpImageRead: + { + // Mark that this shader reads from this image + uint32_t img_id = ops[2]; + auto &type = expression_type(img_id); + auto *p_var = maybe_get_backing_variable(img_id); + if (type.image.dim != DimSubpassData) + { + if (p_var && has_decoration(p_var->self, DecorationNonReadable)) + { + unset_decoration(p_var->self, DecorationNonReadable); + 
force_recompile(); + } + } + + // Metal requires explicit fences to break up RAW hazards, even within the same shader invocation + if (msl_options.readwrite_texture_fences && p_var && !has_decoration(p_var->self, DecorationNonWritable)) + { + add_spv_func_and_recompile(SPVFuncImplImageFence); + // Need to wrap this with a value type, + // since the Metal headers are broken and do not consider case when the image is a reference. + statement("spvImageFence(", to_expression(img_id), ");"); + } + + emit_texture_op(instruction, false); + break; + } + + // Emulate texture2D atomic operations + case OpImageTexelPointer: + { + // When using the pointer, we need to know which variable it is actually loaded from. + auto *var = maybe_get_backing_variable(ops[2]); + if (var && atomic_image_vars_emulated.count(var->self)) + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + std::string coord = to_expression(ops[3]); + auto &type = expression_type(ops[2]); + if (type.image.dim == Dim2D) + { + coord = join("spvImage2DAtomicCoord(", coord, ", ", to_expression(ops[2]), ")"); + } + + auto &e = set(id, join(to_expression(ops[2]), "_atomic[", coord, "]"), result_type, true); + e.loaded_from = var ? var->self : ID(0); + inherit_expression_dependencies(id, ops[3]); + } + else + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + // Virtual expression. Split this up in the actual image atomic. + // In GLSL and HLSL we are able to resolve the dereference inline, but MSL has + // image.op(coord, ...) syntax. + auto &e = + set(id, join(to_expression(ops[2]), "@", + bitcast_expression(SPIRType::UInt, ops[3])), + result_type, true); + + // When using the pointer, we need to know which variable it is actually loaded from. + e.loaded_from = var ? var->self : ID(0); + inherit_expression_dependencies(id, ops[3]); + } + break; + } + + case OpImageWrite: + { + uint32_t img_id = ops[0]; + uint32_t coord_id = ops[1]; + uint32_t texel_id = ops[2]; + const uint32_t *opt = &ops[3]; + uint32_t length = instruction.length - 3; + + // Bypass pointers because we need the real image struct + auto &type = expression_type(img_id); + auto &img_type = get(type.self); + + // Ensure this image has been marked as being written to and force a + // recommpile so that the image type output will include write access + auto *p_var = maybe_get_backing_variable(img_id); + if (p_var && has_decoration(p_var->self, DecorationNonWritable)) + { + unset_decoration(p_var->self, DecorationNonWritable); + force_recompile(); + } + + bool forward = false; + uint32_t bias = 0; + uint32_t lod = 0; + uint32_t flags = 0; + + if (length) + { + flags = *opt++; + length--; + } + + auto test = [&](uint32_t &v, uint32_t flag) { + if (length && (flags & flag)) + { + v = *opt++; + length--; + } + }; + + test(bias, ImageOperandsBiasMask); + test(lod, ImageOperandsLodMask); + + auto &texel_type = expression_type(texel_id); + auto store_type = texel_type; + store_type.vecsize = 4; + + TextureFunctionArguments args = {}; + args.base.img = img_id; + args.base.imgtype = &img_type; + args.base.is_fetch = true; + args.coord = coord_id; + args.lod = lod; + + string expr; + if (needs_frag_discard_checks()) + expr = join("(", builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput), " ? 
((void)0) : "); + expr += join(to_expression(img_id), ".write(", + remap_swizzle(store_type, texel_type.vecsize, to_expression(texel_id)), ", ", + CompilerMSL::to_function_args(args, &forward), ")"); + if (needs_frag_discard_checks()) + expr += ")"; + statement(expr, ";"); + + if (p_var && variable_storage_is_aliased(*p_var)) + flush_all_aliased_variables(); + + break; + } + + case OpImageQuerySize: + case OpImageQuerySizeLod: + { + uint32_t rslt_type_id = ops[0]; + auto &rslt_type = get(rslt_type_id); + + uint32_t id = ops[1]; + + uint32_t img_id = ops[2]; + string img_exp = to_expression(img_id); + auto &img_type = expression_type(img_id); + Dim img_dim = img_type.image.dim; + bool img_is_array = img_type.image.arrayed; + + if (img_type.basetype != SPIRType::Image) + SPIRV_CROSS_THROW("Invalid type for OpImageQuerySize."); + + string lod; + if (opcode == OpImageQuerySizeLod) + { + // LOD index defaults to zero, so don't bother outputing level zero index + string decl_lod = to_expression(ops[3]); + if (decl_lod != "0") + lod = decl_lod; + } + + string expr = type_to_glsl(rslt_type) + "("; + expr += img_exp + ".get_width(" + lod + ")"; + + if (img_dim == Dim2D || img_dim == DimCube || img_dim == Dim3D) + expr += ", " + img_exp + ".get_height(" + lod + ")"; + + if (img_dim == Dim3D) + expr += ", " + img_exp + ".get_depth(" + lod + ")"; + + if (img_is_array) + { + expr += ", " + img_exp + ".get_array_size()"; + if (img_dim == DimCube && msl_options.emulate_cube_array) + expr += " / 6"; + } + + expr += ")"; + + emit_op(rslt_type_id, id, expr, should_forward(img_id)); + + break; + } + + case OpImageQueryLod: + { + if (!msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("ImageQueryLod is only supported on MSL 2.2 and up."); + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t image_id = ops[2]; + uint32_t coord_id = ops[3]; + emit_uninitialized_temporary_expression(result_type, id); + + std::string coord_expr = to_expression(coord_id); + auto sampler_expr = to_sampler_expression(image_id); + auto *combined = maybe_get(image_id); + auto image_expr = combined ? to_expression(combined->image) : to_expression(image_id); + const SPIRType &image_type = expression_type(image_id); + const SPIRType &coord_type = expression_type(coord_id); + + switch (image_type.image.dim) + { + case Dim1D: + if (!msl_options.texture_1D_as_2D) + SPIRV_CROSS_THROW("ImageQueryLod is not supported on 1D textures."); + [[fallthrough]]; + case Dim2D: + if (coord_type.vecsize > 2) + coord_expr = enclose_expression(coord_expr) + ".xy"; + break; + case DimCube: + case Dim3D: + if (coord_type.vecsize > 3) + coord_expr = enclose_expression(coord_expr) + ".xyz"; + break; + default: + SPIRV_CROSS_THROW("Bad image type given to OpImageQueryLod"); + } + + // TODO: It is unclear if calculcate_clamped_lod also conditionally rounds + // the reported LOD based on the sampler. NEAREST miplevel should + // round the LOD, but LINEAR miplevel should not round. + // Let's hope this does not become an issue ... 
+ statement(to_expression(id), ".x = ", image_expr, ".calculate_clamped_lod(", sampler_expr, ", ", + coord_expr, ");"); + statement(to_expression(id), ".y = ", image_expr, ".calculate_unclamped_lod(", sampler_expr, ", ", + coord_expr, ");"); + register_control_dependent_expression(id); + break; + } + +#define MSL_ImgQry(qrytype) \ + do \ + { \ + uint32_t rslt_type_id = ops[0]; \ + auto &rslt_type = get(rslt_type_id); \ + uint32_t id = ops[1]; \ + uint32_t img_id = ops[2]; \ + string img_exp = to_expression(img_id); \ + string expr = type_to_glsl(rslt_type) + "(" + img_exp + ".get_num_" #qrytype "())"; \ + emit_op(rslt_type_id, id, expr, should_forward(img_id)); \ + } while (false) + + case OpImageQueryLevels: + MSL_ImgQry(mip_levels); + break; + + case OpImageQuerySamples: + MSL_ImgQry(samples); + break; + + case OpImage: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + auto *combined = maybe_get(ops[2]); + + if (combined) + { + auto &e = emit_op(result_type, id, to_expression(combined->image), true, true); + auto *var = maybe_get_backing_variable(combined->image); + if (var) + e.loaded_from = var->self; + } + else + { + auto *var = maybe_get_backing_variable(ops[2]); + SPIRExpression *e; + if (var && has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler)) + e = &emit_op(result_type, id, join(to_expression(ops[2]), ".plane0"), true, true); + else + e = &emit_op(result_type, id, to_expression(ops[2]), true, true); + if (var) + e->loaded_from = var->self; + } + break; + } + + // Casting + case OpQuantizeToF16: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t arg = ops[2]; + string exp = join("spvQuantizeToF16(", to_expression(arg), ")"); + emit_op(result_type, id, exp, should_forward(arg)); + break; + } + + case OpInBoundsAccessChain: + case OpAccessChain: + case OpPtrAccessChain: + if (is_tessellation_shader()) + { + if (!emit_tessellation_access_chain(ops, instruction.length)) + CompilerGLSL::emit_instruction(instruction); + } + else + CompilerGLSL::emit_instruction(instruction); + fix_up_interpolant_access_chain(ops, instruction.length); + break; + + case OpStore: + { + const auto &type = expression_type(ops[0]); + + if (is_out_of_bounds_tessellation_level(ops[0])) + break; + + if (needs_frag_discard_checks() && + (type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform)) + { + // If we're in a continue block, this kludge will make the block too complex + // to emit normally. + assert(current_emitting_block); + auto cont_type = continue_block_type(*current_emitting_block); + if (cont_type != SPIRBlock::ContinueNone && cont_type != SPIRBlock::ComplexLoop) + { + current_emitting_block->complex_continue = true; + force_recompile(); + } + statement("if (!", builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput), ")"); + begin_scope(); + } + if (!maybe_emit_array_assignment(ops[0], ops[1])) + CompilerGLSL::emit_instruction(instruction); + if (needs_frag_discard_checks() && + (type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform)) + end_scope(); + break; + } + + // Compute barriers + case OpMemoryBarrier: + emit_barrier(0, ops[0], ops[1]); + break; + + case OpControlBarrier: + // In GLSL a memory barrier is often followed by a control barrier. + // But in MSL, memory barriers are also control barriers, so don't + // emit a simple control barrier if a memory barrier has just been emitted. 
+ if (previous_instruction_opcode != OpMemoryBarrier) + emit_barrier(ops[0], ops[1], ops[2]); + break; + + case OpOuterProduct: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t a = ops[2]; + uint32_t b = ops[3]; + + auto &type = get(result_type); + string expr = type_to_glsl_constructor(type); + expr += "("; + for (uint32_t col = 0; col < type.columns; col++) + { + expr += to_enclosed_unpacked_expression(a); + expr += " * "; + expr += to_extract_component_expression(b, col); + if (col + 1 < type.columns) + expr += ", "; + } + expr += ")"; + emit_op(result_type, id, expr, should_forward(a) && should_forward(b)); + inherit_expression_dependencies(id, a); + inherit_expression_dependencies(id, b); + break; + } + + case OpVectorTimesMatrix: + case OpMatrixTimesVector: + { + if (!msl_options.invariant_float_math && !has_decoration(ops[1], DecorationNoContraction)) + { + CompilerGLSL::emit_instruction(instruction); + break; + } + + // If the matrix needs transpose, just flip the multiply order. + auto *e = maybe_get(ops[opcode == OpMatrixTimesVector ? 2 : 3]); + if (e && e->need_transpose) + { + e->need_transpose = false; + string expr; + + if (opcode == OpMatrixTimesVector) + { + expr = join("spvFMulVectorMatrix(", to_enclosed_unpacked_expression(ops[3]), ", ", + to_unpacked_row_major_matrix_expression(ops[2]), ")"); + } + else + { + expr = join("spvFMulMatrixVector(", to_unpacked_row_major_matrix_expression(ops[3]), ", ", + to_enclosed_unpacked_expression(ops[2]), ")"); + } + + bool forward = should_forward(ops[2]) && should_forward(ops[3]); + emit_op(ops[0], ops[1], expr, forward); + e->need_transpose = true; + inherit_expression_dependencies(ops[1], ops[2]); + inherit_expression_dependencies(ops[1], ops[3]); + } + else + { + if (opcode == OpMatrixTimesVector) + MSL_BFOP(spvFMulMatrixVector); + else + MSL_BFOP(spvFMulVectorMatrix); + } + break; + } + + case OpMatrixTimesMatrix: + { + if (!msl_options.invariant_float_math && !has_decoration(ops[1], DecorationNoContraction)) + { + CompilerGLSL::emit_instruction(instruction); + break; + } + + auto *a = maybe_get(ops[2]); + auto *b = maybe_get(ops[3]); + + // If both matrices need transpose, we can multiply in flipped order and tag the expression as transposed. + // a^T * b^T = (b * a)^T. 
+ if (a && b && a->need_transpose && b->need_transpose) + { + a->need_transpose = false; + b->need_transpose = false; + + auto expr = + join("spvFMulMatrixMatrix(", enclose_expression(to_unpacked_row_major_matrix_expression(ops[3])), ", ", + enclose_expression(to_unpacked_row_major_matrix_expression(ops[2])), ")"); + + bool forward = should_forward(ops[2]) && should_forward(ops[3]); + auto &e = emit_op(ops[0], ops[1], expr, forward); + e.need_transpose = true; + a->need_transpose = true; + b->need_transpose = true; + inherit_expression_dependencies(ops[1], ops[2]); + inherit_expression_dependencies(ops[1], ops[3]); + } + else + MSL_BFOP(spvFMulMatrixMatrix); + + break; + } + + case OpIAddCarry: + case OpISubBorrow: + { + uint32_t result_type = ops[0]; + uint32_t result_id = ops[1]; + uint32_t op0 = ops[2]; + uint32_t op1 = ops[3]; + auto &type = get(result_type); + emit_uninitialized_temporary_expression(result_type, result_id); + + auto &res_type = get(type.member_types[1]); + if (opcode == OpIAddCarry) + { + statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", + to_enclosed_unpacked_expression(op0), " + ", to_enclosed_unpacked_expression(op1), ";"); + statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type), + "(1), ", type_to_glsl(res_type), "(0), ", to_unpacked_expression(result_id), ".", to_member_name(type, 0), + " >= max(", to_unpacked_expression(op0), ", ", to_unpacked_expression(op1), "));"); + } + else + { + statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_unpacked_expression(op0), " - ", + to_enclosed_unpacked_expression(op1), ";"); + statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type), + "(1), ", type_to_glsl(res_type), "(0), ", to_enclosed_unpacked_expression(op0), + " >= ", to_enclosed_unpacked_expression(op1), ");"); + } + break; + } + + case OpUMulExtended: + case OpSMulExtended: + { + uint32_t result_type = ops[0]; + uint32_t result_id = ops[1]; + uint32_t op0 = ops[2]; + uint32_t op1 = ops[3]; + auto &type = get(result_type); + auto input_type = opcode == OpSMulExtended ? int_type : uint_type; + string cast_op0, cast_op1; + + binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, false); + emit_uninitialized_temporary_expression(result_type, result_id); + statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", cast_op0, " * ", cast_op1, ";"); + statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(", cast_op0, ", ", cast_op1, ");"); + break; + } + + case OpArrayLength: + { + auto &type = expression_type(ops[2]); + uint32_t offset = type_struct_member_offset(type, ops[3]); + uint32_t stride = type_struct_member_array_stride(type, ops[3]); + + auto expr = join("(", to_buffer_size_expression(ops[2]), " - ", offset, ") / ", stride); + emit_op(ops[0], ops[1], expr, true); + break; + } + + // Legacy sub-group stuff ... 
+ case OpSubgroupBallotKHR: + case OpSubgroupFirstInvocationKHR: + case OpSubgroupReadInvocationKHR: + case OpSubgroupAllKHR: + case OpSubgroupAnyKHR: + case OpSubgroupAllEqualKHR: + emit_subgroup_op(instruction); + break; + + // SPV_INTEL_shader_integer_functions2 + case OpUCountLeadingZerosINTEL: + MSL_UFOP(clz); + break; + + case OpUCountTrailingZerosINTEL: + MSL_UFOP(ctz); + break; + + case OpAbsISubINTEL: + case OpAbsUSubINTEL: + MSL_BFOP(absdiff); + break; + + case OpIAddSatINTEL: + case OpUAddSatINTEL: + MSL_BFOP(addsat); + break; + + case OpIAverageINTEL: + case OpUAverageINTEL: + MSL_BFOP(hadd); + break; + + case OpIAverageRoundedINTEL: + case OpUAverageRoundedINTEL: + MSL_BFOP(rhadd); + break; + + case OpISubSatINTEL: + case OpUSubSatINTEL: + MSL_BFOP(subsat); + break; + + case OpIMul32x16INTEL: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t a = ops[2], b = ops[3]; + bool forward = should_forward(a) && should_forward(b); + emit_op(result_type, id, join("int(short(", to_unpacked_expression(a), ")) * int(short(", to_unpacked_expression(b), "))"), forward); + inherit_expression_dependencies(id, a); + inherit_expression_dependencies(id, b); + break; + } + + case OpUMul32x16INTEL: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t a = ops[2], b = ops[3]; + bool forward = should_forward(a) && should_forward(b); + emit_op(result_type, id, join("uint(ushort(", to_unpacked_expression(a), ")) * uint(ushort(", to_unpacked_expression(b), "))"), forward); + inherit_expression_dependencies(id, a); + inherit_expression_dependencies(id, b); + break; + } + + // SPV_EXT_demote_to_helper_invocation + case OpDemoteToHelperInvocationEXT: + if (!msl_options.supports_msl_version(2, 3)) + SPIRV_CROSS_THROW("discard_fragment() does not formally have demote semantics until MSL 2.3."); + CompilerGLSL::emit_instruction(instruction); + break; + + case OpIsHelperInvocationEXT: + if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3)) + SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.3 on iOS."); + else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1)) + SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.1 on macOS."); + emit_op(ops[0], ops[1], + needs_manual_helper_invocation_updates() ? 
builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput) : + "simd_is_helper_thread()", + false); + break; + + case OpBeginInvocationInterlockEXT: + case OpEndInvocationInterlockEXT: + if (!msl_options.supports_msl_version(2, 0)) + SPIRV_CROSS_THROW("Raster order groups require MSL 2.0."); + break; // Nothing to do in the body + + case OpConvertUToAccelerationStructureKHR: + SPIRV_CROSS_THROW("ConvertUToAccelerationStructure is not supported in MSL."); + case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR: + SPIRV_CROSS_THROW("BindingTableRecordOffset is not supported in MSL."); + + case OpRayQueryInitializeKHR: + { + flush_variable_declaration(ops[0]); + register_write(ops[0]); + add_spv_func_and_recompile(SPVFuncImplRayQueryIntersectionParams); + + statement(to_expression(ops[0]), ".reset(", "ray(", to_expression(ops[4]), ", ", to_expression(ops[6]), ", ", + to_expression(ops[5]), ", ", to_expression(ops[7]), "), ", to_expression(ops[1]), ", ", to_expression(ops[3]), + ", spvMakeIntersectionParams(", to_expression(ops[2]), "));"); + break; + } + case OpRayQueryProceedKHR: + { + flush_variable_declaration(ops[0]); + register_write(ops[2]); + emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".next()"), false); + break; + } +#define MSL_RAY_QUERY_IS_CANDIDATE get(ops[3]).scalar_i32() == 0 + +#define MSL_RAY_QUERY_GET_OP(op, msl_op) \ + case OpRayQueryGet##op##KHR: \ + flush_variable_declaration(ops[2]); \ + emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".get_" #msl_op "()"), false); \ + break + +#define MSL_RAY_QUERY_OP_INNER2(op, msl_prefix, msl_op) \ + case OpRayQueryGet##op##KHR: \ + flush_variable_declaration(ops[2]); \ + if (MSL_RAY_QUERY_IS_CANDIDATE) \ + emit_op(ops[0], ops[1], join(to_expression(ops[2]), #msl_prefix "_candidate_" #msl_op "()"), false); \ + else \ + emit_op(ops[0], ops[1], join(to_expression(ops[2]), #msl_prefix "_committed_" #msl_op "()"), false); \ + break + +#define MSL_RAY_QUERY_GET_OP2(op, msl_op) MSL_RAY_QUERY_OP_INNER2(op, .get, msl_op) +#define MSL_RAY_QUERY_IS_OP2(op, msl_op) MSL_RAY_QUERY_OP_INNER2(op, .is, msl_op) + + MSL_RAY_QUERY_GET_OP(RayTMin, ray_min_distance); + MSL_RAY_QUERY_GET_OP(WorldRayOrigin, world_space_ray_origin); + MSL_RAY_QUERY_GET_OP(WorldRayDirection, world_space_ray_direction); + MSL_RAY_QUERY_GET_OP2(IntersectionInstanceId, instance_id); + MSL_RAY_QUERY_GET_OP2(IntersectionInstanceCustomIndex, user_instance_id); + MSL_RAY_QUERY_GET_OP2(IntersectionBarycentrics, triangle_barycentric_coord); + MSL_RAY_QUERY_GET_OP2(IntersectionPrimitiveIndex, primitive_id); + MSL_RAY_QUERY_GET_OP2(IntersectionGeometryIndex, geometry_id); + MSL_RAY_QUERY_GET_OP2(IntersectionObjectRayOrigin, ray_origin); + MSL_RAY_QUERY_GET_OP2(IntersectionObjectRayDirection, ray_direction); + MSL_RAY_QUERY_GET_OP2(IntersectionObjectToWorld, object_to_world_transform); + MSL_RAY_QUERY_GET_OP2(IntersectionWorldToObject, world_to_object_transform); + MSL_RAY_QUERY_IS_OP2(IntersectionFrontFace, triangle_front_facing); + + case OpRayQueryGetIntersectionTypeKHR: + flush_variable_declaration(ops[2]); + if (MSL_RAY_QUERY_IS_CANDIDATE) + emit_op(ops[0], ops[1], join("uint(", to_expression(ops[2]), ".get_candidate_intersection_type()) - 1"), + false); + else + emit_op(ops[0], ops[1], join("uint(", to_expression(ops[2]), ".get_committed_intersection_type())"), false); + break; + case OpRayQueryGetIntersectionTKHR: + flush_variable_declaration(ops[2]); + if (MSL_RAY_QUERY_IS_CANDIDATE) + emit_op(ops[0], ops[1], join(to_expression(ops[2]), 
".get_candidate_triangle_distance()"), false); + else + emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".get_committed_distance()"), false); + break; + case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR: + { + flush_variable_declaration(ops[0]); + emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".is_candidate_non_opaque_bounding_box()"), false); + break; + } + case OpRayQueryConfirmIntersectionKHR: + flush_variable_declaration(ops[0]); + register_write(ops[0]); + statement(to_expression(ops[0]), ".commit_triangle_intersection();"); + break; + case OpRayQueryGenerateIntersectionKHR: + flush_variable_declaration(ops[0]); + register_write(ops[0]); + statement(to_expression(ops[0]), ".commit_bounding_box_intersection(", to_expression(ops[1]), ");"); + break; + case OpRayQueryTerminateKHR: + flush_variable_declaration(ops[0]); + register_write(ops[0]); + statement(to_expression(ops[0]), ".abort();"); + break; +#undef MSL_RAY_QUERY_GET_OP +#undef MSL_RAY_QUERY_IS_CANDIDATE +#undef MSL_RAY_QUERY_IS_OP2 +#undef MSL_RAY_QUERY_GET_OP2 +#undef MSL_RAY_QUERY_OP_INNER2 + + case OpConvertPtrToU: + case OpConvertUToPtr: + case OpBitcast: + { + auto &type = get(ops[0]); + auto &input_type = expression_type(ops[2]); + + if (opcode != OpBitcast || type.pointer || input_type.pointer) + { + string op; + + if (type.vecsize == 1 && input_type.vecsize == 1) + op = join("reinterpret_cast<", type_to_glsl(type), ">(", to_unpacked_expression(ops[2]), ")"); + else if (input_type.vecsize == 2) + op = join("reinterpret_cast<", type_to_glsl(type), ">(as_type(", to_unpacked_expression(ops[2]), "))"); + else + op = join("as_type<", type_to_glsl(type), ">(reinterpret_cast(", to_unpacked_expression(ops[2]), "))"); + + emit_op(ops[0], ops[1], op, should_forward(ops[2])); + inherit_expression_dependencies(ops[1], ops[2]); + } + else + CompilerGLSL::emit_instruction(instruction); + + break; + } + + case OpSDot: + case OpUDot: + case OpSUDot: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t vec1 = ops[2]; + uint32_t vec2 = ops[3]; + + auto &input_type1 = expression_type(vec1); + auto &input_type2 = expression_type(vec2); + + string vec1input, vec2input; + auto input_size = input_type1.vecsize; + if (instruction.length == 5) + { + if (ops[4] == PackedVectorFormatPackedVectorFormat4x8Bit) + { + string type = opcode == OpSDot || opcode == OpSUDot ? "char4" : "uchar4"; + vec1input = join("as_type<", type, ">(", to_expression(vec1), ")"); + type = opcode == OpSDot ? "char4" : "uchar4"; + vec2input = join("as_type<", type, ">(", to_expression(vec2), ")"); + input_size = 4; + } + else + SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit for integer dot product is not supported."); + } + else + { + // Inputs are sign or zero-extended to their target width. + SPIRType::BaseType vec1_expected_type = + opcode != OpUDot ? + to_signed_basetype(input_type1.width) : + to_unsigned_basetype(input_type1.width); + + SPIRType::BaseType vec2_expected_type = + opcode != OpSDot ? + to_unsigned_basetype(input_type2.width) : + to_signed_basetype(input_type2.width); + + vec1input = bitcast_expression(vec1_expected_type, vec1); + vec2input = bitcast_expression(vec2_expected_type, vec2); + } + + auto &type = get(result_type); + + // We'll get the appropriate sign-extend or zero-extend, no matter which type we cast to here. + // The addition in reduce_add is sign-invariant. 
+ auto result_type_cast = join(type_to_glsl(type), input_size); + + string exp = join("reduce_add(", + result_type_cast, "(", vec1input, ") * ", + result_type_cast, "(", vec2input, "))"); + + emit_op(result_type, id, exp, should_forward(vec1) && should_forward(vec2)); + inherit_expression_dependencies(id, vec1); + inherit_expression_dependencies(id, vec2); + break; + } + + case OpSDotAccSat: + case OpUDotAccSat: + case OpSUDotAccSat: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t vec1 = ops[2]; + uint32_t vec2 = ops[3]; + uint32_t acc = ops[4]; + + auto input_type1 = expression_type(vec1); + auto input_type2 = expression_type(vec2); + + string vec1input, vec2input; + if (instruction.length == 6) + { + if (ops[5] == PackedVectorFormatPackedVectorFormat4x8Bit) + { + string type = opcode == OpSDotAccSat || opcode == OpSUDotAccSat ? "char4" : "uchar4"; + vec1input = join("as_type<", type, ">(", to_expression(vec1), ")"); + type = opcode == OpSDotAccSat ? "char4" : "uchar4"; + vec2input = join("as_type<", type, ">(", to_expression(vec2), ")"); + input_type1.vecsize = 4; + input_type2.vecsize = 4; + } + else + SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit for integer dot product is not supported."); + } + else + { + // Inputs are sign or zero-extended to their target width. + SPIRType::BaseType vec1_expected_type = + opcode != OpUDotAccSat ? + to_signed_basetype(input_type1.width) : + to_unsigned_basetype(input_type1.width); + + SPIRType::BaseType vec2_expected_type = + opcode != OpSDotAccSat ? + to_unsigned_basetype(input_type2.width) : + to_signed_basetype(input_type2.width); + + vec1input = bitcast_expression(vec1_expected_type, vec1); + vec2input = bitcast_expression(vec2_expected_type, vec2); + } + + auto &type = get(result_type); + + SPIRType::BaseType pre_saturate_type = + opcode != OpUDotAccSat ? + to_signed_basetype(type.width) : + to_unsigned_basetype(type.width); + + input_type1.basetype = pre_saturate_type; + input_type2.basetype = pre_saturate_type; + + string exp = join(type_to_glsl(type), "(addsat(reduce_add(", + type_to_glsl(input_type1), "(", vec1input, ") * ", + type_to_glsl(input_type2), "(", vec2input, ")), ", + bitcast_expression(pre_saturate_type, acc), "))"); + + emit_op(result_type, id, exp, should_forward(vec1) && should_forward(vec2)); + inherit_expression_dependencies(id, vec1); + inherit_expression_dependencies(id, vec2); + break; + } + + default: + CompilerGLSL::emit_instruction(instruction); + break; + } + + previous_instruction_opcode = opcode; +} + +void CompilerMSL::emit_texture_op(const Instruction &i, bool sparse) +{ + if (sparse) + SPIRV_CROSS_THROW("Sparse feedback not yet supported in MSL."); + + if (msl_options.use_framebuffer_fetch_subpasses) + { + auto *ops = stream(i); + + uint32_t result_type_id = ops[0]; + uint32_t id = ops[1]; + uint32_t img = ops[2]; + + auto &type = expression_type(img); + auto &imgtype = get(type.self); + + // Use Metal's native frame-buffer fetch API for subpass inputs. + if (imgtype.image.dim == DimSubpassData) + { + // Subpass inputs cannot be invalidated, + // so just forward the expression directly. 
+ string expr = to_expression(img); + emit_op(result_type_id, id, expr, true); + return; + } + } + + // Fallback to default implementation + CompilerGLSL::emit_texture_op(i, sparse); +} + +void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem) +{ + if (get_execution_model() != ExecutionModelGLCompute && !is_tesc_shader()) + return; + + uint32_t exe_scope = id_exe_scope ? evaluate_constant_u32(id_exe_scope) : uint32_t(ScopeInvocation); + uint32_t mem_scope = id_mem_scope ? evaluate_constant_u32(id_mem_scope) : uint32_t(ScopeInvocation); + // Use the wider of the two scopes (smaller value) + exe_scope = min(exe_scope, mem_scope); + + if (msl_options.emulate_subgroups && exe_scope >= ScopeSubgroup && !id_mem_sem) + // In this case, we assume a "subgroup" size of 1. The barrier, then, is a noop. + return; + + string bar_stmt; + if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2)) + bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier"; + else + bar_stmt = "threadgroup_barrier"; + bar_stmt += "("; + + uint32_t mem_sem = id_mem_sem ? evaluate_constant_u32(id_mem_sem) : uint32_t(MemorySemanticsMaskNone); + + // Use the | operator to combine flags if we can. + if (msl_options.supports_msl_version(1, 2)) + { + string mem_flags = ""; + // For tesc shaders, this also affects objects in the Output storage class. + // Since in Metal, these are placed in a device buffer, we have to sync device memory here. + if (is_tesc_shader() || + (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask))) + mem_flags += "mem_flags::mem_device"; + + // Fix tessellation patch function processing + if (is_tesc_shader() || (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask))) + { + if (!mem_flags.empty()) + mem_flags += " | "; + mem_flags += "mem_flags::mem_threadgroup"; + } + if (mem_sem & MemorySemanticsImageMemoryMask) + { + if (!mem_flags.empty()) + mem_flags += " | "; + mem_flags += "mem_flags::mem_texture"; + } + + if (mem_flags.empty()) + mem_flags = "mem_flags::mem_none"; + + bar_stmt += mem_flags; + } + else + { + if ((mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) && + (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask))) + bar_stmt += "mem_flags::mem_device_and_threadgroup"; + else if (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) + bar_stmt += "mem_flags::mem_device"; + else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)) + bar_stmt += "mem_flags::mem_threadgroup"; + else if (mem_sem & MemorySemanticsImageMemoryMask) + bar_stmt += "mem_flags::mem_texture"; + else + bar_stmt += "mem_flags::mem_none"; + } + + bar_stmt += ");"; + + statement(bar_stmt); + + assert(current_emitting_block); + flush_control_dependent_expressions(current_emitting_block->self); + flush_all_active_variables(); +} + +static bool storage_class_array_is_thread(StorageClass storage) +{ + switch (storage) + { + case StorageClassInput: + case StorageClassOutput: + case StorageClassGeneric: + case StorageClassFunction: + case StorageClassPrivate: + return true; + + default: + return false; + } +} + +bool CompilerMSL::emit_array_copy(const char *expr, uint32_t lhs_id, uint32_t rhs_id, + StorageClass lhs_storage, StorageClass rhs_storage) +{ + // Allow Metal to use the array template to make arrays a 
value type. + // This, however, cannot be used for threadgroup address specifiers, so consider the custom array copy as fallback. + bool lhs_is_thread_storage = storage_class_array_is_thread(lhs_storage); + bool rhs_is_thread_storage = storage_class_array_is_thread(rhs_storage); + + bool lhs_is_array_template = lhs_is_thread_storage || lhs_storage == StorageClassWorkgroup; + bool rhs_is_array_template = rhs_is_thread_storage || rhs_storage == StorageClassWorkgroup; + + // Special considerations for stage IO variables. + // If the variable is actually backed by non-user visible device storage, we use array templates for those. + // + // Another special consideration is given to thread local variables which happen to have Offset decorations + // applied to them. Block-like types do not use array templates, so we need to force POD path if we detect + // these scenarios. This check isn't perfect since it would be technically possible to mix and match these things, + // and for a fully correct solution we might have to track array template state through access chains as well, + // but for all reasonable use cases, this should suffice. + // This special case should also only apply to Function/Private storage classes. + // We should not check backing variable for temporaries. + auto *lhs_var = maybe_get_backing_variable(lhs_id); + if (lhs_var && lhs_storage == StorageClassStorageBuffer && storage_class_array_is_thread(lhs_var->storage)) + lhs_is_array_template = true; + else if (lhs_var && lhs_storage != StorageClassGeneric && type_is_block_like(get(lhs_var->basetype))) + lhs_is_array_template = false; + + auto *rhs_var = maybe_get_backing_variable(rhs_id); + if (rhs_var && rhs_storage == StorageClassStorageBuffer && storage_class_array_is_thread(rhs_var->storage)) + rhs_is_array_template = true; + else if (rhs_var && rhs_storage != StorageClassGeneric && type_is_block_like(get(rhs_var->basetype))) + rhs_is_array_template = false; + + // If threadgroup storage qualifiers are *not* used: + // Avoid spvCopy* wrapper functions; Otherwise, spvUnsafeArray<> template cannot be used with that storage qualifier. + if (lhs_is_array_template && rhs_is_array_template && !using_builtin_array()) + { + // Fall back to normal copy path. + return false; + } + else + { + // Ensure the LHS variable has been declared + if (lhs_var) + flush_variable_declaration(lhs_var->self); + + string lhs; + if (expr) + lhs = expr; + else + lhs = to_expression(lhs_id); + + // Assignment from an array initializer is fine. + auto &type = expression_type(rhs_id); + auto *var = maybe_get_backing_variable(rhs_id); + + // Unfortunately, we cannot template on address space in MSL, + // so explicit address space redirection it is ... + bool is_constant = false; + if (ir.ids[rhs_id].get_type() == TypeConstant) + { + is_constant = true; + } + else if (var && var->remapped_variable && var->statically_assigned && + ir.ids[var->static_expression].get_type() == TypeConstant) + { + is_constant = true; + } + else if (rhs_storage == StorageClassUniform || rhs_storage == StorageClassUniformConstant) + { + is_constant = true; + } + + // For the case where we have OpLoad triggering an array copy, + // we cannot easily detect this case ahead of time since it's + // context dependent. We might have to force a recompile here + // if this is the only use of array copies in our shader. + add_spv_func_and_recompile(type.array.size() > 1 ? 
SPVFuncImplArrayCopyMultidim : SPVFuncImplArrayCopy);
+
+ const char *tag = nullptr;
+ if (lhs_is_thread_storage && is_constant)
+ tag = "FromConstantToStack";
+ else if (lhs_storage == StorageClassWorkgroup && is_constant)
+ tag = "FromConstantToThreadGroup";
+ else if (lhs_is_thread_storage && rhs_is_thread_storage)
+ tag = "FromStackToStack";
+ else if (lhs_storage == StorageClassWorkgroup && rhs_is_thread_storage)
+ tag = "FromStackToThreadGroup";
+ else if (lhs_is_thread_storage && rhs_storage == StorageClassWorkgroup)
+ tag = "FromThreadGroupToStack";
+ else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassWorkgroup)
+ tag = "FromThreadGroupToThreadGroup";
+ else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassStorageBuffer)
+ tag = "FromDeviceToDevice";
+ else if (lhs_storage == StorageClassStorageBuffer && is_constant)
+ tag = "FromConstantToDevice";
+ else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassWorkgroup)
+ tag = "FromThreadGroupToDevice";
+ else if (lhs_storage == StorageClassStorageBuffer && rhs_is_thread_storage)
+ tag = "FromStackToDevice";
+ else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassStorageBuffer)
+ tag = "FromDeviceToThreadGroup";
+ else if (lhs_is_thread_storage && rhs_storage == StorageClassStorageBuffer)
+ tag = "FromDeviceToStack";
+ else
+ SPIRV_CROSS_THROW("Unknown storage class used for copying arrays.");
+
+ // Pass internal array of spvUnsafeArray<> into wrapper functions
+ if (lhs_is_array_template && rhs_is_array_template && !msl_options.force_native_arrays)
+ statement("spvArrayCopy", tag, "(", lhs, ".elements, ", to_expression(rhs_id), ".elements);");
+ else if (lhs_is_array_template && !msl_options.force_native_arrays)
+ statement("spvArrayCopy", tag, "(", lhs, ".elements, ", to_expression(rhs_id), ");");
+ else if (rhs_is_array_template && !msl_options.force_native_arrays)
+ statement("spvArrayCopy", tag, "(", lhs, ", ", to_expression(rhs_id), ".elements);");
+ else
+ statement("spvArrayCopy", tag, "(", lhs, ", ", to_expression(rhs_id), ");");
+ }
+
+ return true;
+}
+
+uint32_t CompilerMSL::get_physical_tess_level_array_size(spv::BuiltIn builtin) const
+{
+ if (is_tessellating_triangles())
+ return builtin == BuiltInTessLevelInner ? 1 : 3;
+ else
+ return builtin == BuiltInTessLevelInner ? 2 : 4;
+}
+
+// Since MSL does not allow arrays to be copied via simple variable assignment,
+// if the LHS and RHS represent an assignment of an entire array, it must be
+// implemented by calling an array copy function.
+// Returns whether the struct assignment was emitted.
+bool CompilerMSL::maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs)
+{
+ // We only care about assignments of an entire array
+ auto &type = expression_type(id_lhs);
+ if (!is_array(get_pointee_type(type)))
+ return false;
+
+ auto *var = maybe_get<SPIRVariable>(id_lhs);
+
+ // Is this a remapped, static constant? Don't do anything.
+ if (var && var->remapped_variable && var->statically_assigned)
+ return true;
+
+ if (ir.ids[id_rhs].get_type() == TypeConstant && var && var->deferred_declaration)
+ {
+ // Special case, if we end up declaring a variable when assigning the constant array,
+ // we can avoid the copy by directly assigning the constant expression.
+ // This is likely necessary to be able to use a variable as a true look-up table, as it is unlikely
+ // the compiler will be able to optimize the spvArrayCopy() into a constant LUT.
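+ // Sketch of the intent (hypothetical values): rather than declaring the array and then calling
+ // an spvArrayCopyFromConstantToStack-style helper, the declaration is emitted with the constant
+ // initializer directly, e.g. something like spvUnsafeArray<float, 3> lut = { 1.0, 2.0, 3.0 };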
+ // After a variable has been declared, we can no longer assign constant arrays in MSL unfortunately. + statement(to_expression(id_lhs), " = ", constant_expression(get(id_rhs)), ";"); + return true; + } + + if (is_tesc_shader() && has_decoration(id_lhs, DecorationBuiltIn)) + { + auto builtin = BuiltIn(get_decoration(id_lhs, DecorationBuiltIn)); + // Need to manually unroll the array store. + if (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter) + { + uint32_t array_size = get_physical_tess_level_array_size(builtin); + if (array_size == 1) + statement(to_expression(id_lhs), " = half(", to_expression(id_rhs), "[0]);"); + else + { + for (uint32_t i = 0; i < array_size; i++) + statement(to_expression(id_lhs), "[", i, "] = half(", to_expression(id_rhs), "[", i, "]);"); + } + return true; + } + } + + auto lhs_storage = get_expression_effective_storage_class(id_lhs); + auto rhs_storage = get_expression_effective_storage_class(id_rhs); + if (!emit_array_copy(nullptr, id_lhs, id_rhs, lhs_storage, rhs_storage)) + return false; + + register_write(id_lhs); + + return true; +} + +// Emits one of the atomic functions. In MSL, the atomic functions operate on pointers +void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, Op opcode, + uint32_t mem_order_1, uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1, + bool op1_is_pointer, bool op1_is_literal, uint32_t op2) +{ + string exp; + + auto &ptr_type = expression_type(obj); + auto &type = get_pointee_type(ptr_type); + auto expected_type = type.basetype; + if (opcode == OpAtomicUMax || opcode == OpAtomicUMin) + expected_type = to_unsigned_basetype(type.width); + else if (opcode == OpAtomicSMax || opcode == OpAtomicSMin) + expected_type = to_signed_basetype(type.width); + + bool use_native_image_atomic; + if (msl_options.supports_msl_version(3, 1)) + use_native_image_atomic = check_atomic_image(obj); + else + use_native_image_atomic = false; + + if (type.width == 64) + SPIRV_CROSS_THROW("MSL currently does not support 64-bit atomics."); + + auto remapped_type = type; + remapped_type.basetype = expected_type; + + auto *var = maybe_get_backing_variable(obj); + const auto *res_type = var ? &get(var->basetype) : nullptr; + assert(type.storage != StorageClassImage || res_type); + + bool is_atomic_compare_exchange_strong = op1_is_pointer && op1; + + bool check_discard = opcode != OpAtomicLoad && needs_frag_discard_checks() && + ptr_type.storage != StorageClassWorkgroup; + + // Even compare exchange atomics are vec4 on metal for ... reasons :v + uint32_t vec4_temporary_id = 0; + if (use_native_image_atomic && is_atomic_compare_exchange_strong) + { + uint32_t &tmp_id = extra_sub_expressions[result_id]; + if (!tmp_id) + { + tmp_id = ir.increase_bound_by(2); + + auto vec4_type = get(result_type); + vec4_type.vecsize = 4; + set(tmp_id + 1, vec4_type); + } + + vec4_temporary_id = tmp_id; + } + + if (check_discard) + { + if (is_atomic_compare_exchange_strong) + { + // We're already emitting a CAS loop here; a conditional won't hurt. + emit_uninitialized_temporary_expression(result_type, result_id); + if (vec4_temporary_id) + emit_uninitialized_temporary_expression(vec4_temporary_id + 1, vec4_temporary_id); + statement("if (!", builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput), ")"); + begin_scope(); + } + else + exp = join("(!", builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput), " ? 
"); + } + + if (use_native_image_atomic) + { + auto obj_expression = to_expression(obj); + auto split_index = obj_expression.find_first_of('@'); + + // Will only be false if we're in "force recompile later" mode. + if (split_index != string::npos) + { + auto coord = obj_expression.substr(split_index + 1); + auto image_expr = obj_expression.substr(0, split_index); + + // Handle problem cases with sign where we need signed min/max on a uint image for example. + // It seems to work to cast the texture type itself, even if it is probably wildly outside of spec, + // but SPIR-V requires this to work. + if ((opcode == OpAtomicUMax || opcode == OpAtomicUMin || + opcode == OpAtomicSMax || opcode == OpAtomicSMin) && + type.basetype != expected_type) + { + auto *backing_var = maybe_get_backing_variable(obj); + if (backing_var) + { + add_spv_func_and_recompile(SPVFuncImplTextureCast); + + const auto *backing_type = &get(backing_var->basetype); + while (backing_type->op != OpTypeImage) + backing_type = &get(backing_type->parent_type); + + auto img_type = *backing_type; + auto tmp_type = type; + tmp_type.basetype = expected_type; + img_type.image.type = ir.increase_bound_by(1); + set(img_type.image.type, tmp_type); + + image_expr = join("spvTextureCast<", type_to_glsl(img_type, obj), ">(", image_expr, ")"); + } + } + + exp += join(image_expr, ".", op, "("); + if (ptr_type.storage == StorageClassImage && res_type->image.arrayed) + { + switch (res_type->image.dim) + { + case Dim1D: + if (msl_options.texture_1D_as_2D) + exp += join("uint2(", coord, ".x, 0), ", coord, ".y"); + else + exp += join(coord, ".x, ", coord, ".y"); + + break; + case Dim2D: + exp += join(coord, ".xy, ", coord, ".z"); + break; + default: + SPIRV_CROSS_THROW("Cannot do atomics on Cube textures."); + } + } + else if (ptr_type.storage == StorageClassImage && res_type->image.dim == Dim1D && msl_options.texture_1D_as_2D) + exp += join("uint2(", coord, ", 0)"); + else + exp += coord; + } + else + { + exp += obj_expression; + } + } + else + { + exp += string(op) + "_explicit("; + exp += "("; + // Emulate texture2D atomic operations + if (ptr_type.storage == StorageClassImage) + { + auto &flags = ir.get_decoration_bitset(var->self); + if (decoration_flags_signal_volatile(flags)) + exp += "volatile "; + exp += "device"; + } + else if (var && ptr_type.storage != StorageClassPhysicalStorageBuffer) + { + exp += get_argument_address_space(*var); + } + else + { + // Fallback scenario, could happen for raw pointers. + exp += ptr_type.storage == StorageClassWorkgroup ? "threadgroup" : "device"; + } + + exp += " atomic_"; + // For signed and unsigned min/max, we can signal this through the pointer type. + // There is no other way, since C++ does not have explicit signage for atomics. + exp += type_to_glsl(remapped_type); + exp += "*)"; + + exp += "&"; + exp += to_enclosed_expression(obj); + } + + if (is_atomic_compare_exchange_strong) + { + assert(strcmp(op, "atomic_compare_exchange_weak") == 0); + assert(op2); + assert(has_mem_order_2); + exp += ", &"; + exp += to_name(vec4_temporary_id ? vec4_temporary_id : result_id); + exp += ", "; + exp += to_expression(op2); + + if (!use_native_image_atomic) + { + exp += ", "; + exp += get_memory_order(mem_order_1); + exp += ", "; + exp += get_memory_order(mem_order_2); + } + exp += ")"; + + // MSL only supports the weak atomic compare exchange, so emit a CAS loop here. 
+ // The MSL function returns false if the atomic write fails OR the comparison test fails, + // so we must validate that it wasn't the comparison test that failed before continuing + // the CAS loop, otherwise it will loop infinitely, with the comparison test always failing. + // The function updates the comparator value from the memory value, so the additional + // comparison test evaluates the memory value against the expected value. + if (!check_discard) + { + emit_uninitialized_temporary_expression(result_type, result_id); + if (vec4_temporary_id) + emit_uninitialized_temporary_expression(vec4_temporary_id + 1, vec4_temporary_id); + } + + statement("do"); + begin_scope(); + + string scalar_expression; + if (vec4_temporary_id) + scalar_expression = join(to_expression(vec4_temporary_id), ".x"); + else + scalar_expression = to_expression(result_id); + + statement(scalar_expression, " = ", to_expression(op1), ";"); + end_scope_decl(join("while (!", exp, " && ", scalar_expression, " == ", to_enclosed_expression(op1), ")")); + if (vec4_temporary_id) + statement(to_expression(result_id), " = ", scalar_expression, ";"); + + // Vulkan: (section 9.29: ... and values returned by atomic instructions in helper invocations are undefined) + if (check_discard) + { + end_scope(); + statement("else"); + begin_scope(); + statement(to_expression(result_id), " = {};"); + end_scope(); + } + } + else + { + assert(strcmp(op, "atomic_compare_exchange_weak") != 0); + + if (op1) + { + exp += ", "; + if (op1_is_literal) + exp += to_string(op1); + else + exp += bitcast_expression(expected_type, op1); + } + + if (op2) + exp += ", " + to_expression(op2); + + if (!use_native_image_atomic) + { + exp += string(", ") + get_memory_order(mem_order_1); + if (has_mem_order_2) + exp += string(", ") + get_memory_order(mem_order_2); + } + + exp += ")"; + + // For some particular reason, atomics return vec4 in Metal ... + if (use_native_image_atomic) + exp += ".x"; + + // Vulkan: (section 9.29: ... and values returned by atomic instructions in helper invocations are undefined) + if (check_discard) + { + exp += " : "; + if (strcmp(op, "atomic_store") != 0) + exp += join(type_to_glsl(get(result_type)), "{}"); + else + exp += "((void)0)"; + exp += ")"; + } + + if (expected_type != type.basetype) + exp = bitcast_expression(type, expected_type, exp); + + if (strcmp(op, "atomic_store") != 0) + emit_op(result_type, result_id, exp, false); + else + statement(exp, ";"); + } + + flush_all_atomic_capable_variables(); +} + +// Metal only supports relaxed memory order for now +const char *CompilerMSL::get_memory_order(uint32_t) +{ + return "memory_order_relaxed"; +} + +// Override for MSL-specific extension syntax instructions. +// In some cases, deliberately select either the fast or precise versions of the MSL functions to match Vulkan math precision results. +void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count) +{ + auto op = static_cast(eop); + + // If we need to do implicit bitcasts, make sure we do it with the correct type. + uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count); + auto int_type = to_signed_basetype(integer_width); + auto uint_type = to_unsigned_basetype(integer_width); + + op = get_remapped_glsl_op(op); + + auto &restype = get(result_type); + + switch (op) + { + case GLSLstd450Sinh: + if (restype.basetype == SPIRType::Half) + { + // MSL does not have overload for half. Force-cast back to half. 
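+ // Illustratively, a half-typed operand x lowers to the string "half(fast::sinh(x))":
+ // fast::sinh has no half overload, so the result comes back as float and the explicit
+ // half(...) cast restores the expected result type.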
+ auto expr = join("half(fast::sinh(", to_unpacked_expression(args[0]), "))"); + emit_op(result_type, id, expr, should_forward(args[0])); + inherit_expression_dependencies(id, args[0]); + } + else + emit_unary_func_op(result_type, id, args[0], "fast::sinh"); + break; + case GLSLstd450Cosh: + if (restype.basetype == SPIRType::Half) + { + // MSL does not have overload for half. Force-cast back to half. + auto expr = join("half(fast::cosh(", to_unpacked_expression(args[0]), "))"); + emit_op(result_type, id, expr, should_forward(args[0])); + inherit_expression_dependencies(id, args[0]); + } + else + emit_unary_func_op(result_type, id, args[0], "fast::cosh"); + break; + case GLSLstd450Tanh: + if (restype.basetype == SPIRType::Half) + { + // MSL does not have overload for half. Force-cast back to half. + auto expr = join("half(fast::tanh(", to_unpacked_expression(args[0]), "))"); + emit_op(result_type, id, expr, should_forward(args[0])); + inherit_expression_dependencies(id, args[0]); + } + else + emit_unary_func_op(result_type, id, args[0], "precise::tanh"); + break; + case GLSLstd450Atan2: + if (restype.basetype == SPIRType::Half) + { + // MSL does not have overload for half. Force-cast back to half. + auto expr = join("half(fast::atan2(", to_unpacked_expression(args[0]), ", ", to_unpacked_expression(args[1]), "))"); + emit_op(result_type, id, expr, should_forward(args[0]) && should_forward(args[1])); + inherit_expression_dependencies(id, args[0]); + inherit_expression_dependencies(id, args[1]); + } + else + emit_binary_func_op(result_type, id, args[0], args[1], "precise::atan2"); + break; + case GLSLstd450InverseSqrt: + emit_unary_func_op(result_type, id, args[0], "rsqrt"); + break; + case GLSLstd450RoundEven: + emit_unary_func_op(result_type, id, args[0], "rint"); + break; + + case GLSLstd450FindILsb: + { + // In this template version of findLSB, we return T. 
+ auto basetype = expression_type(args[0]).basetype; + emit_unary_func_op_cast(result_type, id, args[0], "spvFindLSB", basetype, basetype); + break; + } + + case GLSLstd450FindSMsb: + emit_unary_func_op_cast(result_type, id, args[0], "spvFindSMSB", int_type, int_type); + break; + + case GLSLstd450FindUMsb: + emit_unary_func_op_cast(result_type, id, args[0], "spvFindUMSB", uint_type, uint_type); + break; + + case GLSLstd450PackSnorm4x8: + emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm4x8"); + break; + case GLSLstd450PackUnorm4x8: + emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm4x8"); + break; + case GLSLstd450PackSnorm2x16: + emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm2x16"); + break; + case GLSLstd450PackUnorm2x16: + emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm2x16"); + break; + + case GLSLstd450PackHalf2x16: + { + auto expr = join("as_type(half2(", to_expression(args[0]), "))"); + emit_op(result_type, id, expr, should_forward(args[0])); + inherit_expression_dependencies(id, args[0]); + break; + } + + case GLSLstd450UnpackSnorm4x8: + emit_unary_func_op(result_type, id, args[0], "unpack_snorm4x8_to_float"); + break; + case GLSLstd450UnpackUnorm4x8: + emit_unary_func_op(result_type, id, args[0], "unpack_unorm4x8_to_float"); + break; + case GLSLstd450UnpackSnorm2x16: + emit_unary_func_op(result_type, id, args[0], "unpack_snorm2x16_to_float"); + break; + case GLSLstd450UnpackUnorm2x16: + emit_unary_func_op(result_type, id, args[0], "unpack_unorm2x16_to_float"); + break; + + case GLSLstd450UnpackHalf2x16: + { + auto expr = join("float2(as_type(", to_expression(args[0]), "))"); + emit_op(result_type, id, expr, should_forward(args[0])); + inherit_expression_dependencies(id, args[0]); + break; + } + + case GLSLstd450PackDouble2x32: + emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450PackDouble2x32"); // Currently unsupported + break; + case GLSLstd450UnpackDouble2x32: + emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450UnpackDouble2x32"); // Currently unsupported + break; + + case GLSLstd450MatrixInverse: + { + auto &mat_type = get(result_type); + switch (mat_type.columns) + { + case 2: + emit_unary_func_op(result_type, id, args[0], "spvInverse2x2"); + break; + case 3: + emit_unary_func_op(result_type, id, args[0], "spvInverse3x3"); + break; + case 4: + emit_unary_func_op(result_type, id, args[0], "spvInverse4x4"); + break; + default: + break; + } + break; + } + + case GLSLstd450FMin: + // If the result type isn't float, don't bother calling the specific + // precise::/fast:: version. Metal doesn't have those for half and + // double types. + if (get(result_type).basetype != SPIRType::Float) + emit_binary_func_op(result_type, id, args[0], args[1], "min"); + else + emit_binary_func_op(result_type, id, args[0], args[1], "fast::min"); + break; + + case GLSLstd450FMax: + if (get(result_type).basetype != SPIRType::Float) + emit_binary_func_op(result_type, id, args[0], args[1], "max"); + else + emit_binary_func_op(result_type, id, args[0], args[1], "fast::max"); + break; + + case GLSLstd450FClamp: + // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call. 
+ if (get(result_type).basetype != SPIRType::Float) + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp"); + else + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "fast::clamp"); + break; + + case GLSLstd450NMin: + if (get(result_type).basetype != SPIRType::Float) + emit_binary_func_op(result_type, id, args[0], args[1], "min"); + else + emit_binary_func_op(result_type, id, args[0], args[1], "precise::min"); + break; + + case GLSLstd450NMax: + if (get(result_type).basetype != SPIRType::Float) + emit_binary_func_op(result_type, id, args[0], args[1], "max"); + else + emit_binary_func_op(result_type, id, args[0], args[1], "precise::max"); + break; + + case GLSLstd450NClamp: + // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call. + if (get(result_type).basetype != SPIRType::Float) + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp"); + else + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "precise::clamp"); + break; + + case GLSLstd450InterpolateAtCentroid: + { + // We can't just emit the expression normally, because the qualified name contains a call to the default + // interpolate method, or refers to a local variable. We saved the interface index we need; use it to construct + // the base for the method call. + uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex); + string component; + if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr)) + { + uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr); + auto *c = maybe_get(index_expr); + if (!c || c->specialization) + component = join("[", to_expression(index_expr), "]"); + else + component = join(".", index_to_swizzle(c->scalar())); + } + emit_op(result_type, id, + join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index), + ".interpolate_at_centroid()", component), + should_forward(args[0])); + break; + } + + case GLSLstd450InterpolateAtSample: + { + uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex); + string component; + if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr)) + { + uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr); + auto *c = maybe_get(index_expr); + if (!c || c->specialization) + component = join("[", to_expression(index_expr), "]"); + else + component = join(".", index_to_swizzle(c->scalar())); + } + emit_op(result_type, id, + join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index), + ".interpolate_at_sample(", to_expression(args[1]), ")", component), + should_forward(args[0]) && should_forward(args[1])); + break; + } + + case GLSLstd450InterpolateAtOffset: + { + uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex); + string component; + if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr)) + { + uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr); + auto *c = maybe_get(index_expr); + if (!c || c->specialization) + component = join("[", to_expression(index_expr), "]"); + else + component = join(".", index_to_swizzle(c->scalar())); + } + // Like Direct3D, Metal puts the (0, 0) at the upper-left corner, not the center as SPIR-V and GLSL do. 
+ // Offset the offset by (1/2 - 1/16), or 0.4375, to compensate for this. + // It has to be (1/2 - 1/16) and not 1/2, or several CTS tests subtly break on Intel. + emit_op(result_type, id, + join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index), + ".interpolate_at_offset(", to_expression(args[1]), " + 0.4375)", component), + should_forward(args[0]) && should_forward(args[1])); + break; + } + + case GLSLstd450Distance: + // MSL does not support scalar versions here. + if (expression_type(args[0]).vecsize == 1) + { + // Equivalent to length(a - b) -> abs(a - b). + emit_op(result_type, id, + join("abs(", to_enclosed_unpacked_expression(args[0]), " - ", + to_enclosed_unpacked_expression(args[1]), ")"), + should_forward(args[0]) && should_forward(args[1])); + inherit_expression_dependencies(id, args[0]); + inherit_expression_dependencies(id, args[1]); + } + else + CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count); + break; + + case GLSLstd450Length: + // MSL does not support scalar versions, so use abs(). + if (expression_type(args[0]).vecsize == 1) + emit_unary_func_op(result_type, id, args[0], "abs"); + else + CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count); + break; + + case GLSLstd450Normalize: + { + auto &exp_type = expression_type(args[0]); + // MSL does not support scalar versions here. + // MSL has no implementation for normalize in the fast:: namespace for half2 and half3 + // Returns -1 or 1 for valid input, sign() does the job. + if (exp_type.vecsize == 1) + emit_unary_func_op(result_type, id, args[0], "sign"); + else if (exp_type.vecsize <= 3 && exp_type.basetype == SPIRType::Half) + emit_unary_func_op(result_type, id, args[0], "normalize"); + else + emit_unary_func_op(result_type, id, args[0], "fast::normalize"); + break; + } + case GLSLstd450Reflect: + if (get(result_type).vecsize == 1) + emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect"); + else + CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count); + break; + + case GLSLstd450Refract: + if (get(result_type).vecsize == 1) + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract"); + else + CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count); + break; + + case GLSLstd450FaceForward: + if (get(result_type).vecsize == 1) + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvFaceForward"); + else + CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count); + break; + + case GLSLstd450Modf: + case GLSLstd450Frexp: + { + // Special case. If the variable is a scalar access chain, we cannot use it directly. We have to emit a temporary. + // Another special case is if the variable is in a storage class which is not thread. + auto *ptr = maybe_get(args[1]); + auto &type = expression_type(args[1]); + + bool is_thread_storage = storage_class_array_is_thread(type.storage); + if (type.storage == StorageClassOutput && capture_output_to_buffer) + is_thread_storage = false; + + if (!is_thread_storage || + (ptr && ptr->access_chain && is_scalar(expression_type(args[1])))) + { + register_call_out_argument(args[1]); + forced_temporaries.insert(id); + + // Need to create temporaries and copy over to access chain after. + // We cannot directly take the reference of a vector swizzle in MSL, even if it's scalar ... 
+ uint32_t &tmp_id = extra_sub_expressions[id]; + if (!tmp_id) + tmp_id = ir.increase_bound_by(1); + + uint32_t tmp_type_id = get_pointee_type_id(expression_type_id(args[1])); + emit_uninitialized_temporary_expression(tmp_type_id, tmp_id); + emit_binary_func_op(result_type, id, args[0], tmp_id, eop == GLSLstd450Modf ? "modf" : "frexp"); + statement(to_expression(args[1]), " = ", to_expression(tmp_id), ";"); + } + else + CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count); + break; + } + + case GLSLstd450Pow: + // powr makes x < 0.0 undefined, just like SPIR-V. + emit_binary_func_op(result_type, id, args[0], args[1], "powr"); + break; + + default: + CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count); + break; + } +} + +void CompilerMSL::emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type, uint32_t id, uint32_t eop, + const uint32_t *args, uint32_t count) +{ + enum AMDShaderTrinaryMinMax + { + FMin3AMD = 1, + UMin3AMD = 2, + SMin3AMD = 3, + FMax3AMD = 4, + UMax3AMD = 5, + SMax3AMD = 6, + FMid3AMD = 7, + UMid3AMD = 8, + SMid3AMD = 9 + }; + + if (!msl_options.supports_msl_version(2, 1)) + SPIRV_CROSS_THROW("Trinary min/max functions require MSL 2.1."); + + auto op = static_cast(eop); + + switch (op) + { + case FMid3AMD: + case UMid3AMD: + case SMid3AMD: + emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "median3"); + break; + default: + CompilerGLSL::emit_spv_amd_shader_trinary_minmax_op(result_type, id, eop, args, count); + break; + } +} + +// Emit a structure declaration for the specified interface variable. +void CompilerMSL::emit_interface_block(uint32_t ib_var_id) +{ + if (ib_var_id) + { + auto &ib_var = get(ib_var_id); + auto &ib_type = get_variable_data_type(ib_var); + //assert(ib_type.basetype == SPIRType::Struct && !ib_type.member_types.empty()); + assert(ib_type.basetype == SPIRType::Struct); + emit_struct(ib_type); + } +} + +// Emits the declaration signature of the specified function. +// If this is the entry point function, Metal-specific return value and function arguments are added. +void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &) +{ + if (func.self != ir.default_entry_point) + add_function_overload(func); + + local_variable_names = resource_names; + string decl; + + processing_entry_point = func.self == ir.default_entry_point; + + // Metal helper functions must be static force-inline otherwise they will cause problems when linked together in a single Metallib. + if (!processing_entry_point) + statement(force_inline); + + auto &type = get(func.return_type); + + if (!type.array.empty() && msl_options.force_native_arrays) + { + // We cannot return native arrays in MSL, so "return" through an out variable. + decl += "void"; + } + else + { + decl += func_type_decl(type); + } + + decl += " "; + decl += to_name(func.self); + decl += "("; + + if (!type.array.empty() && msl_options.force_native_arrays) + { + // Fake arrays returns by writing to an out array instead. + decl += "thread "; + decl += type_to_glsl(type); + decl += " (&spvReturnValue)"; + decl += type_to_array_glsl(type, 0); + if (!func.arguments.empty()) + decl += ", "; + } + + if (processing_entry_point) + { + if (msl_options.argument_buffers) + decl += entry_point_args_argument_buffer(!func.arguments.empty()); + else + decl += entry_point_args_classic(!func.arguments.empty()); + + // append entry point args to avoid conflicts in local variable names. 
+ local_variable_names.insert(resource_names.begin(), resource_names.end()); + + // If entry point function has variables that require early declaration, + // ensure they each have an empty initializer, creating one if needed. + // This is done at this late stage because the initialization expression + // is cleared after each compilation pass. + for (auto var_id : vars_needing_early_declaration) + { + auto &ed_var = get(var_id); + ID &initializer = ed_var.initializer; + if (!initializer) + initializer = ir.increase_bound_by(1); + + // Do not override proper initializers. + if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression) + set(ed_var.initializer, "{}", ed_var.basetype, true); + } + } + + for (auto &arg : func.arguments) + { + uint32_t name_id = arg.id; + + auto *var = maybe_get(arg.id); + if (var) + { + // If we need to modify the name of the variable, make sure we modify the original variable. + // Our alias is just a shadow variable. + if (arg.alias_global_variable && var->basevariable) + name_id = var->basevariable; + + var->parameter = &arg; // Hold a pointer to the parameter so we can invalidate the readonly field if needed. + } + + add_local_variable_name(name_id); + + decl += argument_decl(arg); + + bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler); + + auto &arg_type = get(arg.type); + if (arg_type.basetype == SPIRType::SampledImage && !is_dynamic_img_sampler) + { + // Manufacture automatic plane args for multiplanar texture + uint32_t planes = 1; + if (auto *constexpr_sampler = find_constexpr_sampler(name_id)) + if (constexpr_sampler->ycbcr_conversion_enable) + planes = constexpr_sampler->planes; + for (uint32_t i = 1; i < planes; i++) + decl += join(", ", argument_decl(arg), plane_name_suffix, i); + + // Manufacture automatic sampler arg for SampledImage texture + if (arg_type.image.dim != DimBuffer) + { + if (arg_type.array.empty() || (var ? is_var_runtime_size_array(*var) : is_runtime_size_array(arg_type))) + { + decl += join(", ", sampler_type(arg_type, arg.id, false), " ", to_sampler_expression(name_id)); + } + else + { + const char *sampler_address_space = + descriptor_address_space(name_id, + StorageClassUniformConstant, + "thread const"); + decl += join(", ", sampler_address_space, " ", sampler_type(arg_type, name_id, false), "& ", + to_sampler_expression(name_id)); + } + } + } + + // Manufacture automatic swizzle arg. + if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(arg_type) && + !is_dynamic_img_sampler) + { + bool arg_is_array = !arg_type.array.empty(); + decl += join(", constant uint", arg_is_array ? "* " : "& ", to_swizzle_expression(name_id)); + } + + if (buffer_requires_array_length(name_id)) + { + bool arg_is_array = !arg_type.array.empty(); + decl += join(", constant uint", arg_is_array ? "* " : "& ", to_buffer_size_expression(name_id)); + } + + if (&arg != &func.arguments.back()) + decl += ", "; + } + + decl += ")"; + statement(decl); +} + +static bool needs_chroma_reconstruction(const MSLConstexprSampler *constexpr_sampler) +{ + // For now, only multiplanar images need explicit reconstruction. GBGR and BGRG images + // use implicit reconstruction. + return constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && constexpr_sampler->planes > 1; +} + +// Returns the texture sampling function string for the specified image and sampling characteristics. 
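+// On the common (non-Y'CbCr, non-swizzled-gather) path the result is the texture expression
+// followed by ".read", ".gather" or ".sample", with "_compare" appended for depth-compare
+// lookups, e.g. (illustrative name) "shadowMap.sample_compare".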
+string CompilerMSL::to_function_name(const TextureFunctionNameArguments &args) +{ + VariableID img = args.base.img; + const MSLConstexprSampler *constexpr_sampler = nullptr; + bool is_dynamic_img_sampler = false; + if (auto *var = maybe_get_backing_variable(img)) + { + constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self)); + is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler); + } + + // Special-case gather. We have to alter the component being looked up in the swizzle case. + if (msl_options.swizzle_texture_samples && args.base.is_gather && !is_dynamic_img_sampler && + (!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable)) + { + bool is_compare = comparison_ids.count(img); + add_spv_func_and_recompile(is_compare ? SPVFuncImplGatherCompareSwizzle : SPVFuncImplGatherSwizzle); + return is_compare ? "spvGatherCompareSwizzle" : "spvGatherSwizzle"; + } + + // Special-case gather with an array of offsets. We have to lower into 4 separate gathers. + if (args.has_array_offsets && !is_dynamic_img_sampler && + (!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable)) + { + bool is_compare = comparison_ids.count(img); + add_spv_func_and_recompile(is_compare ? SPVFuncImplGatherCompareConstOffsets : SPVFuncImplGatherConstOffsets); + add_spv_func_and_recompile(SPVFuncImplForwardArgs); + return is_compare ? "spvGatherCompareConstOffsets" : "spvGatherConstOffsets"; + } + + auto *combined = maybe_get(img); + + // Texture reference + string fname; + if (needs_chroma_reconstruction(constexpr_sampler) && !is_dynamic_img_sampler) + { + if (constexpr_sampler->planes != 2 && constexpr_sampler->planes != 3) + SPIRV_CROSS_THROW("Unhandled number of color image planes!"); + // 444 images aren't downsampled, so we don't need to do linear filtering. 
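+ // The helper name is assembled below from filter, resolution and chroma offsets, e.g.
+ // (illustrative) spvChromaReconstructLinear420XMidpointYCositedEven; the matching 2-plane
+ // or 3-plane overload is pulled in through add_spv_func_and_recompile().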
+ if (constexpr_sampler->resolution == MSL_FORMAT_RESOLUTION_444 || + constexpr_sampler->chroma_filter == MSL_SAMPLER_FILTER_NEAREST) + { + if (constexpr_sampler->planes == 2) + add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest2Plane); + else + add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest3Plane); + fname = "spvChromaReconstructNearest"; + } + else // Linear with a downsampled format + { + fname = "spvChromaReconstructLinear"; + switch (constexpr_sampler->resolution) + { + case MSL_FORMAT_RESOLUTION_444: + assert(false); + break; // not reached + case MSL_FORMAT_RESOLUTION_422: + switch (constexpr_sampler->x_chroma_offset) + { + case MSL_CHROMA_LOCATION_COSITED_EVEN: + if (constexpr_sampler->planes == 2) + add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven2Plane); + else + add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven3Plane); + fname += "422CositedEven"; + break; + case MSL_CHROMA_LOCATION_MIDPOINT: + if (constexpr_sampler->planes == 2) + add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint2Plane); + else + add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint3Plane); + fname += "422Midpoint"; + break; + default: + SPIRV_CROSS_THROW("Invalid chroma location."); + } + break; + case MSL_FORMAT_RESOLUTION_420: + fname += "420"; + switch (constexpr_sampler->x_chroma_offset) + { + case MSL_CHROMA_LOCATION_COSITED_EVEN: + switch (constexpr_sampler->y_chroma_offset) + { + case MSL_CHROMA_LOCATION_COSITED_EVEN: + if (constexpr_sampler->planes == 2) + add_spv_func_and_recompile( + SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane); + else + add_spv_func_and_recompile( + SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane); + fname += "XCositedEvenYCositedEven"; + break; + case MSL_CHROMA_LOCATION_MIDPOINT: + if (constexpr_sampler->planes == 2) + add_spv_func_and_recompile( + SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane); + else + add_spv_func_and_recompile( + SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane); + fname += "XCositedEvenYMidpoint"; + break; + default: + SPIRV_CROSS_THROW("Invalid Y chroma location."); + } + break; + case MSL_CHROMA_LOCATION_MIDPOINT: + switch (constexpr_sampler->y_chroma_offset) + { + case MSL_CHROMA_LOCATION_COSITED_EVEN: + if (constexpr_sampler->planes == 2) + add_spv_func_and_recompile( + SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane); + else + add_spv_func_and_recompile( + SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane); + fname += "XMidpointYCositedEven"; + break; + case MSL_CHROMA_LOCATION_MIDPOINT: + if (constexpr_sampler->planes == 2) + add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane); + else + add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane); + fname += "XMidpointYMidpoint"; + break; + default: + SPIRV_CROSS_THROW("Invalid Y chroma location."); + } + break; + default: + SPIRV_CROSS_THROW("Invalid X chroma location."); + } + break; + default: + SPIRV_CROSS_THROW("Invalid format resolution."); + } + } + } + else + { + fname = to_expression(combined ? 
combined->image : img) + "."; + + // Texture function and sampler + if (args.base.is_fetch) + fname += "read"; + else if (args.base.is_gather) + fname += "gather"; + else + fname += "sample"; + + if (args.has_dref) + fname += "_compare"; + } + + return fname; +} + +string CompilerMSL::convert_to_f32(const string &expr, uint32_t components) +{ + SPIRType t { components > 1 ? OpTypeVector : OpTypeFloat }; + t.basetype = SPIRType::Float; + t.vecsize = components; + t.columns = 1; + return join(type_to_glsl_constructor(t), "(", expr, ")"); +} + +static inline bool sampling_type_needs_f32_conversion(const SPIRType &type) +{ + // Double is not supported to begin with, but doesn't hurt to check for completion. + return type.basetype == SPIRType::Half || type.basetype == SPIRType::Double; +} + +// Returns the function args for a texture sampling function for the specified image and sampling characteristics. +string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool *p_forward) +{ + VariableID img = args.base.img; + auto &imgtype = *args.base.imgtype; + uint32_t lod = args.lod; + uint32_t grad_x = args.grad_x; + uint32_t grad_y = args.grad_y; + uint32_t bias = args.bias; + + const MSLConstexprSampler *constexpr_sampler = nullptr; + bool is_dynamic_img_sampler = false; + if (auto *var = maybe_get_backing_variable(img)) + { + constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self)); + is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler); + } + + string farg_str; + bool forward = true; + + if (!is_dynamic_img_sampler) + { + // Texture reference (for some cases) + if (needs_chroma_reconstruction(constexpr_sampler)) + { + // Multiplanar images need two or three textures. + farg_str += to_expression(img); + for (uint32_t i = 1; i < constexpr_sampler->planes; i++) + farg_str += join(", ", to_expression(img), plane_name_suffix, i); + } + else if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) && + msl_options.swizzle_texture_samples && args.base.is_gather) + { + auto *combined = maybe_get(img); + farg_str += to_expression(combined ? combined->image : img); + } + + // Gathers with constant offsets call a special function, so include the texture. + if (args.has_array_offsets) + farg_str += to_expression(img); + + // Sampler reference + if (!args.base.is_fetch) + { + if (!farg_str.empty()) + farg_str += ", "; + farg_str += to_sampler_expression(img); + } + + if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) && + msl_options.swizzle_texture_samples && args.base.is_gather) + { + // Add the swizzle constant from the swizzle buffer. + farg_str += ", " + to_swizzle_expression(img); + used_swizzle_buffer = true; + } + + // Const offsets gather puts the const offsets before the other args. + if (args.has_array_offsets) + { + forward = forward && should_forward(args.offset); + farg_str += ", " + to_expression(args.offset); + } + + // Const offsets gather or swizzled gather puts the component before the other args. 
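+ // to_component_argument() turns the constant gather component into MSL's component enum,
+ // so e.g. (illustrative) a component value of 1 is passed as "component::y".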
+ if (args.component && (args.has_array_offsets || msl_options.swizzle_texture_samples)) + { + forward = forward && should_forward(args.component); + farg_str += ", " + to_component_argument(args.component); + } + } + + // Texture coordinates + forward = forward && should_forward(args.coord); + auto coord_expr = to_enclosed_expression(args.coord); + auto &coord_type = expression_type(args.coord); + bool coord_is_fp = type_is_floating_point(coord_type); + bool is_cube_fetch = false; + + string tex_coords = coord_expr; + uint32_t alt_coord_component = 0; + + switch (imgtype.image.dim) + { + + case Dim1D: + if (coord_type.vecsize > 1) + tex_coords = enclose_expression(tex_coords) + ".x"; + + if (args.base.is_fetch) + tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")"; + else if (sampling_type_needs_f32_conversion(coord_type)) + tex_coords = convert_to_f32(tex_coords, 1); + + if (msl_options.texture_1D_as_2D) + { + if (args.base.is_fetch) + tex_coords = "uint2(" + tex_coords + ", 0)"; + else + tex_coords = "float2(" + tex_coords + ", 0.5)"; + } + + alt_coord_component = 1; + break; + + case DimBuffer: + if (coord_type.vecsize > 1) + tex_coords = enclose_expression(tex_coords) + ".x"; + + if (msl_options.texture_buffer_native) + { + tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")"; + } + else + { + // Metal texel buffer textures are 2D, so convert 1D coord to 2D. + // Support for Metal 2.1's new texture_buffer type. + if (args.base.is_fetch) + { + if (msl_options.texel_buffer_texture_width > 0) + { + tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")"; + } + else + { + tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ", " + + to_expression(img) + ")"; + } + } + } + + alt_coord_component = 1; + break; + + case DimSubpassData: + // If we're using Metal's native frame-buffer fetch API for subpass inputs, + // this path will not be hit. + tex_coords = "uint2(gl_FragCoord.xy)"; + alt_coord_component = 2; + break; + + case Dim2D: + if (coord_type.vecsize > 2) + tex_coords = enclose_expression(tex_coords) + ".xy"; + + if (args.base.is_fetch) + tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")"; + else if (sampling_type_needs_f32_conversion(coord_type)) + tex_coords = convert_to_f32(tex_coords, 2); + + alt_coord_component = 2; + break; + + case Dim3D: + if (coord_type.vecsize > 3) + tex_coords = enclose_expression(tex_coords) + ".xyz"; + + if (args.base.is_fetch) + tex_coords = "uint3(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")"; + else if (sampling_type_needs_f32_conversion(coord_type)) + tex_coords = convert_to_f32(tex_coords, 3); + + alt_coord_component = 3; + break; + + case DimCube: + if (args.base.is_fetch) + { + is_cube_fetch = true; + tex_coords += ".xy"; + tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")"; + } + else + { + if (coord_type.vecsize > 3) + tex_coords = enclose_expression(tex_coords) + ".xyz"; + } + + if (sampling_type_needs_f32_conversion(coord_type)) + tex_coords = convert_to_f32(tex_coords, 3); + + alt_coord_component = 3; + break; + + default: + break; + } + + if (args.base.is_fetch && args.offset) + { + // Fetch offsets must be applied directly to the coordinate. 
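+ // The offset is added straight onto the integer texel coordinate, bitcast to uint first if
+ // it is signed; e.g. (illustrative) "uint2(coord) + offset" for a uint2 offset on a 2D
+ // fetch, while 1D-as-2D emulation pads it out to "uint2(offset, 0)".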
+ forward = forward && should_forward(args.offset); + auto &type = expression_type(args.offset); + if (imgtype.image.dim == Dim1D && msl_options.texture_1D_as_2D) + { + if (type.basetype != SPIRType::UInt) + tex_coords += join(" + uint2(", bitcast_expression(SPIRType::UInt, args.offset), ", 0)"); + else + tex_coords += join(" + uint2(", to_enclosed_expression(args.offset), ", 0)"); + } + else + { + if (type.basetype != SPIRType::UInt) + tex_coords += " + " + bitcast_expression(SPIRType::UInt, args.offset); + else + tex_coords += " + " + to_enclosed_expression(args.offset); + } + } + + // If projection, use alt coord as divisor + if (args.base.is_proj) + { + if (sampling_type_needs_f32_conversion(coord_type)) + tex_coords += " / " + convert_to_f32(to_extract_component_expression(args.coord, alt_coord_component), 1); + else + tex_coords += " / " + to_extract_component_expression(args.coord, alt_coord_component); + } + + if (!farg_str.empty()) + farg_str += ", "; + + if (imgtype.image.dim == DimCube && imgtype.image.arrayed && msl_options.emulate_cube_array) + { + farg_str += "spvCubemapTo2DArrayFace(" + tex_coords + ").xy"; + + if (is_cube_fetch) + farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ")"; + else + farg_str += + ", uint(spvCubemapTo2DArrayFace(" + tex_coords + ").z) + (uint(" + + round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) + + ") * 6u)"; + + add_spv_func_and_recompile(SPVFuncImplCubemapTo2DArrayFace); + } + else + { + farg_str += tex_coords; + + // If fetch from cube, add face explicitly + if (is_cube_fetch) + { + // Special case for cube arrays, face and layer are packed in one dimension. + if (imgtype.image.arrayed) + farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") % 6u"; + else + farg_str += + ", uint(" + round_fp_tex_coords(to_extract_component_expression(args.coord, 2), coord_is_fp) + ")"; + } + + // If array, use alt coord + if (imgtype.image.arrayed) + { + // Special case for cube arrays, face and layer are packed in one dimension. + if (imgtype.image.dim == DimCube && args.base.is_fetch) + { + farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") / 6u"; + } + else + { + farg_str += + ", uint(" + + round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) + + ")"; + if (imgtype.image.dim == DimSubpassData) + { + if (msl_options.multiview) + farg_str += " + gl_ViewIndex"; + else if (msl_options.arrayed_subpass_input) + farg_str += " + gl_Layer"; + } + } + } + else if (imgtype.image.dim == DimSubpassData) + { + if (msl_options.multiview) + farg_str += ", gl_ViewIndex"; + else if (msl_options.arrayed_subpass_input) + farg_str += ", gl_Layer"; + } + } + + // Depth compare reference value + if (args.dref) + { + forward = forward && should_forward(args.dref); + farg_str += ", "; + + auto &dref_type = expression_type(args.dref); + + string dref_expr; + if (args.base.is_proj) + dref_expr = join(to_enclosed_expression(args.dref), " / ", + to_extract_component_expression(args.coord, alt_coord_component)); + else + dref_expr = to_expression(args.dref); + + if (sampling_type_needs_f32_conversion(dref_type)) + dref_expr = convert_to_f32(dref_expr, 1); + + farg_str += dref_expr; + + if (msl_options.is_macos() && (grad_x || grad_y)) + { + // For sample compare, MSL does not support gradient2d for all targets (only iOS apparently according to docs). 
+ // However, the most common case here is to have a constant gradient of 0, as that is the only way to express + // LOD == 0 in GLSL with sampler2DArrayShadow (cascaded shadow mapping). + // We will detect a compile-time constant 0 value for gradient and promote that to level(0) on MSL. + bool constant_zero_x = !grad_x || expression_is_constant_null(grad_x); + bool constant_zero_y = !grad_y || expression_is_constant_null(grad_y); + if (constant_zero_x && constant_zero_y && + (!imgtype.image.arrayed || !msl_options.sample_dref_lod_array_as_grad)) + { + lod = 0; + grad_x = 0; + grad_y = 0; + farg_str += ", level(0)"; + } + else if (!msl_options.supports_msl_version(2, 3)) + { + SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not " + "supported on macOS prior to MSL 2.3."); + } + } + + if (msl_options.is_macos() && bias) + { + // Bias is not supported either on macOS with sample_compare. + // Verify it is compile-time zero, and drop the argument. + if (expression_is_constant_null(bias)) + { + bias = 0; + } + else if (!msl_options.supports_msl_version(2, 3)) + { + SPIRV_CROSS_THROW("Using non-constant 0.0 bias() qualifier for sample_compare. This is not supported " + "on macOS prior to MSL 2.3."); + } + } + } + + // LOD Options + // Metal does not support LOD for 1D textures. + if (bias && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D)) + { + forward = forward && should_forward(bias); + farg_str += ", bias(" + to_expression(bias) + ")"; + } + + // Metal does not support LOD for 1D textures. + if (lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D)) + { + forward = forward && should_forward(lod); + if (args.base.is_fetch) + { + farg_str += ", " + to_expression(lod); + } + else if (msl_options.sample_dref_lod_array_as_grad && args.dref && imgtype.image.arrayed) + { + if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3)) + SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not " + "supported on macOS prior to MSL 2.3."); + // Some Metal devices have a bug where the LoD is erroneously biased upward + // when using a level() argument. Since this doesn't happen as much with gradient2d(), + // if we perform the LoD calculation in reverse, we can pass a gradient + // instead. + // lod = log2(rhoMax/eta) -> exp2(lod) = rhoMax/eta + // If we make all of the scale factors the same, eta will be 1 and + // exp2(lod) = rho. + // rhoX = dP/dx * extent; rhoY = dP/dy * extent + // Therefore, dP/dx = dP/dy = exp2(lod)/extent. + // (Subtracting 0.5 before exponentiation gives better results.) 
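+ // Worked example (illustrative): sampling a W x H 2D array at lod L, the code below passes
+ //   gradient2d(exp2(L - 0.5) / float2(W, H), exp2(L - 0.5) / float2(W, H))
+ // so both implied scale factors equal exp2(L - 0.5) and the hardware derives (roughly) the
+ // requested LOD back from the gradients.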
+ string grad_opt, extent, grad_coord; + VariableID base_img = img; + if (auto *combined = maybe_get(img)) + base_img = combined->image; + switch (imgtype.image.dim) + { + case Dim1D: + grad_opt = "gradient2d"; + extent = join("float2(", to_expression(base_img), ".get_width(), 1.0)"); + break; + case Dim2D: + grad_opt = "gradient2d"; + extent = join("float2(", to_expression(base_img), ".get_width(), ", to_expression(base_img), ".get_height())"); + break; + case DimCube: + if (imgtype.image.arrayed && msl_options.emulate_cube_array) + { + grad_opt = "gradient2d"; + extent = join("float2(", to_expression(base_img), ".get_width())"); + } + else + { + if (msl_options.agx_manual_cube_grad_fixup) + { + add_spv_func_and_recompile(SPVFuncImplGradientCube); + grad_opt = "spvGradientCube"; + grad_coord = tex_coords + ", "; + } + else + { + grad_opt = "gradientcube"; + } + extent = join("float3(", to_expression(base_img), ".get_width())"); + } + break; + default: + grad_opt = "unsupported_gradient_dimension"; + extent = "float3(1.0)"; + break; + } + farg_str += join(", ", grad_opt, "(", grad_coord, "exp2(", to_expression(lod), " - 0.5) / ", extent, + ", exp2(", to_expression(lod), " - 0.5) / ", extent, ")"); + } + else + { + farg_str += ", level(" + to_expression(lod) + ")"; + } + } + else if (args.base.is_fetch && !lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D) && + imgtype.image.dim != DimBuffer && !imgtype.image.ms && imgtype.image.sampled != 2) + { + // Lod argument is optional in OpImageFetch, but we require a LOD value, pick 0 as the default. + // Check for sampled type as well, because is_fetch is also used for OpImageRead in MSL. + farg_str += ", 0"; + } + + // Metal does not support LOD for 1D textures. + if ((grad_x || grad_y) && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D)) + { + forward = forward && should_forward(grad_x); + forward = forward && should_forward(grad_y); + string grad_opt, grad_coord; + switch (imgtype.image.dim) + { + case Dim1D: + case Dim2D: + grad_opt = "gradient2d"; + break; + case Dim3D: + grad_opt = "gradient3d"; + break; + case DimCube: + if (imgtype.image.arrayed && msl_options.emulate_cube_array) + { + grad_opt = "gradient2d"; + } + else if (msl_options.agx_manual_cube_grad_fixup) + { + add_spv_func_and_recompile(SPVFuncImplGradientCube); + grad_opt = "spvGradientCube"; + grad_coord = tex_coords + ", "; + } + else + { + grad_opt = "gradientcube"; + } + break; + default: + grad_opt = "unsupported_gradient_dimension"; + break; + } + farg_str += join(", ", grad_opt, "(", grad_coord, to_expression(grad_x), ", ", to_expression(grad_y), ")"); + } + + if (args.min_lod) + { + if (!msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("min_lod_clamp() is only supported in MSL 2.2+ and up."); + + forward = forward && should_forward(args.min_lod); + farg_str += ", min_lod_clamp(" + to_expression(args.min_lod) + ")"; + } + + // Add offsets + string offset_expr; + const SPIRType *offset_type = nullptr; + if (args.offset && !args.base.is_fetch && !args.has_array_offsets) + { + forward = forward && should_forward(args.offset); + offset_expr = to_expression(args.offset); + offset_type = &expression_type(args.offset); + } + + if (!offset_expr.empty()) + { + switch (imgtype.image.dim) + { + case Dim1D: + if (!msl_options.texture_1D_as_2D) + break; + if (offset_type->vecsize > 1) + offset_expr = enclose_expression(offset_expr) + ".x"; + + farg_str += join(", int2(", offset_expr, ", 0)"); + break; + + case Dim2D: + if (offset_type->vecsize > 
2) + offset_expr = enclose_expression(offset_expr) + ".xy"; + + farg_str += ", " + offset_expr; + break; + + case Dim3D: + if (offset_type->vecsize > 3) + offset_expr = enclose_expression(offset_expr) + ".xyz"; + + farg_str += ", " + offset_expr; + break; + + default: + break; + } + } + + if (args.component && !args.has_array_offsets) + { + // If 2D has gather component, ensure it also has an offset arg + if (imgtype.image.dim == Dim2D && offset_expr.empty()) + farg_str += ", int2(0)"; + + if (!msl_options.swizzle_texture_samples || is_dynamic_img_sampler) + { + forward = forward && should_forward(args.component); + + uint32_t image_var = 0; + if (const auto *combined = maybe_get(img)) + { + if (const auto *img_var = maybe_get_backing_variable(combined->image)) + image_var = img_var->self; + } + else if (const auto *var = maybe_get_backing_variable(img)) + { + image_var = var->self; + } + + if (image_var == 0 || !is_depth_image(expression_type(image_var), image_var)) + farg_str += ", " + to_component_argument(args.component); + } + } + + if (args.sample) + { + forward = forward && should_forward(args.sample); + farg_str += ", "; + farg_str += to_expression(args.sample); + } + + *p_forward = forward; + + return farg_str; +} + +// If the texture coordinates are floating point, invokes MSL round() function to round them. +string CompilerMSL::round_fp_tex_coords(string tex_coords, bool coord_is_fp) +{ + return coord_is_fp ? ("rint(" + tex_coords + ")") : tex_coords; +} + +// Returns a string to use in an image sampling function argument. +// The ID must be a scalar constant. +string CompilerMSL::to_component_argument(uint32_t id) +{ + uint32_t component_index = evaluate_constant_u32(id); + switch (component_index) + { + case 0: + return "component::x"; + case 1: + return "component::y"; + case 2: + return "component::z"; + case 3: + return "component::w"; + + default: + SPIRV_CROSS_THROW("The value (" + to_string(component_index) + ") of OpConstant ID " + to_string(id) + + " is not a valid Component index, which must be one of 0, 1, 2, or 3."); + } +} + +// Establish sampled image as expression object and assign the sampler to it. +void CompilerMSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id) +{ + set(result_id, result_type, image_id, samp_id); +} + +string CompilerMSL::to_texture_op(const Instruction &i, bool sparse, bool *forward, + SmallVector &inherited_expressions) +{ + auto *ops = stream(i); + uint32_t result_type_id = ops[0]; + uint32_t img = ops[2]; + auto &result_type = get(result_type_id); + auto op = static_cast(i.op); + bool is_gather = (op == OpImageGather || op == OpImageDrefGather); + + // Bypass pointers because we need the real image struct + auto &type = expression_type(img); + auto &imgtype = get(type.self); + + const MSLConstexprSampler *constexpr_sampler = nullptr; + bool is_dynamic_img_sampler = false; + if (auto *var = maybe_get_backing_variable(img)) + { + constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self)); + is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler); + } + + string expr; + if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler) + { + // If this needs sampler Y'CbCr conversion, we need to do some additional + // processing. 
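+ // The sampled value ends up wrapped first in range expansion, then in model conversion, so a
+ // BT.709 narrow-range sampler emits roughly (illustrative):
+ //   spvConvertYCbCrBT709(spvExpandITUNarrowRange(<sampled expression>, <bpc>))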
+ switch (constexpr_sampler->ycbcr_model) + { + case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY: + case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY: + // Default + break; + case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709: + add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT709); + expr += "spvConvertYCbCrBT709("; + break; + case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601: + add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT601); + expr += "spvConvertYCbCrBT601("; + break; + case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020: + add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT2020); + expr += "spvConvertYCbCrBT2020("; + break; + default: + SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion."); + } + + if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY) + { + switch (constexpr_sampler->ycbcr_range) + { + case MSL_SAMPLER_YCBCR_RANGE_ITU_FULL: + add_spv_func_and_recompile(SPVFuncImplExpandITUFullRange); + expr += "spvExpandITUFullRange("; + break; + case MSL_SAMPLER_YCBCR_RANGE_ITU_NARROW: + add_spv_func_and_recompile(SPVFuncImplExpandITUNarrowRange); + expr += "spvExpandITUNarrowRange("; + break; + default: + SPIRV_CROSS_THROW("Invalid Y'CbCr range."); + } + } + } + else if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) && + !is_dynamic_img_sampler) + { + add_spv_func_and_recompile(SPVFuncImplTextureSwizzle); + expr += "spvTextureSwizzle("; + } + + string inner_expr = CompilerGLSL::to_texture_op(i, sparse, forward, inherited_expressions); + + if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler) + { + if (!constexpr_sampler->swizzle_is_identity()) + { + static const char swizzle_names[] = "rgba"; + if (!constexpr_sampler->swizzle_has_one_or_zero()) + { + // If we can, do it inline. + expr += inner_expr + "."; + for (uint32_t c = 0; c < 4; c++) + { + switch (constexpr_sampler->swizzle[c]) + { + case MSL_COMPONENT_SWIZZLE_IDENTITY: + expr += swizzle_names[c]; + break; + case MSL_COMPONENT_SWIZZLE_R: + case MSL_COMPONENT_SWIZZLE_G: + case MSL_COMPONENT_SWIZZLE_B: + case MSL_COMPONENT_SWIZZLE_A: + expr += swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R]; + break; + default: + SPIRV_CROSS_THROW("Invalid component swizzle."); + } + } + } + else + { + // Otherwise, we need to emit a temporary and swizzle that. + uint32_t temp_id = ir.increase_bound_by(1); + emit_op(result_type_id, temp_id, inner_expr, false); + for (auto &inherit : inherited_expressions) + inherit_expression_dependencies(temp_id, inherit); + inherited_expressions.clear(); + inherited_expressions.push_back(temp_id); + + switch (op) + { + case OpImageSampleDrefImplicitLod: + case OpImageSampleImplicitLod: + case OpImageSampleProjImplicitLod: + case OpImageSampleProjDrefImplicitLod: + register_control_dependent_expression(temp_id); + break; + + default: + break; + } + expr += type_to_glsl(result_type) + "("; + for (uint32_t c = 0; c < 4; c++) + { + switch (constexpr_sampler->swizzle[c]) + { + case MSL_COMPONENT_SWIZZLE_IDENTITY: + expr += to_expression(temp_id) + "." + swizzle_names[c]; + break; + case MSL_COMPONENT_SWIZZLE_ZERO: + expr += "0"; + break; + case MSL_COMPONENT_SWIZZLE_ONE: + expr += "1"; + break; + case MSL_COMPONENT_SWIZZLE_R: + case MSL_COMPONENT_SWIZZLE_G: + case MSL_COMPONENT_SWIZZLE_B: + case MSL_COMPONENT_SWIZZLE_A: + expr += to_expression(temp_id) + "." 
+ + swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R]; + break; + default: + SPIRV_CROSS_THROW("Invalid component swizzle."); + } + if (c < 3) + expr += ", "; + } + expr += ")"; + } + } + else + expr += inner_expr; + if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY) + { + expr += join(", ", constexpr_sampler->bpc, ")"); + if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY) + expr += ")"; + } + } + else + { + expr += inner_expr; + if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) && + !is_dynamic_img_sampler) + { + // Add the swizzle constant from the swizzle buffer. + expr += ", " + to_swizzle_expression(img) + ")"; + used_swizzle_buffer = true; + } + } + + return expr; +} + +static string create_swizzle(MSLComponentSwizzle swizzle) +{ + switch (swizzle) + { + case MSL_COMPONENT_SWIZZLE_IDENTITY: + return "spvSwizzle::none"; + case MSL_COMPONENT_SWIZZLE_ZERO: + return "spvSwizzle::zero"; + case MSL_COMPONENT_SWIZZLE_ONE: + return "spvSwizzle::one"; + case MSL_COMPONENT_SWIZZLE_R: + return "spvSwizzle::red"; + case MSL_COMPONENT_SWIZZLE_G: + return "spvSwizzle::green"; + case MSL_COMPONENT_SWIZZLE_B: + return "spvSwizzle::blue"; + case MSL_COMPONENT_SWIZZLE_A: + return "spvSwizzle::alpha"; + default: + SPIRV_CROSS_THROW("Invalid component swizzle."); + } +} + +// Returns a string representation of the ID, usable as a function arg. +// Manufacture automatic sampler arg for SampledImage texture. +string CompilerMSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id) +{ + string arg_str; + + auto &type = expression_type(id); + bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler); + // If the argument *itself* is a "dynamic" combined-image sampler, then we can just pass that around. + bool arg_is_dynamic_img_sampler = has_extended_decoration(id, SPIRVCrossDecorationDynamicImageSampler); + if (is_dynamic_img_sampler && !arg_is_dynamic_img_sampler) + arg_str = join("spvDynamicImageSampler<", type_to_glsl(get(type.image.type)), ">("); + + auto *c = maybe_get(id); + if (msl_options.force_native_arrays && c && !get(c->constant_type).array.empty()) + { + // If we are passing a constant array directly to a function for some reason, + // the callee will expect an argument in thread const address space + // (since we can only bind to arrays with references in MSL). + // To resolve this, we must emit a copy in this address space. + // This kind of code gen should be rare enough that performance is not a real concern. + // Inline the SPIR-V to avoid this kind of suboptimal codegen. + // + // We risk calling this inside a continue block (invalid code), + // so just create a thread local copy in the current function. + arg_str = join("_", id, "_array_copy"); + auto &constants = current_function->constant_arrays_needed_on_stack; + auto itr = find(begin(constants), end(constants), ID(id)); + if (itr == end(constants)) + { + force_recompile(); + constants.push_back(id); + } + } + // Dereference pointer variables where needed. + // FIXME: This dereference is actually backwards. We should really just support passing pointer variables between functions. + else if (should_dereference(id)) + arg_str += dereference_expression(type, CompilerGLSL::to_func_call_arg(arg, id)); + else + arg_str += CompilerGLSL::to_func_call_arg(arg, id); + + // Need to check the base variable in case we need to apply a qualified alias. 
+ uint32_t var_id = 0; + auto *var = maybe_get(id); + if (var) + var_id = var->basevariable; + + if (!arg_is_dynamic_img_sampler) + { + auto *constexpr_sampler = find_constexpr_sampler(var_id ? var_id : id); + if (type.basetype == SPIRType::SampledImage) + { + // Manufacture automatic plane args for multiplanar texture + uint32_t planes = 1; + if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable) + { + planes = constexpr_sampler->planes; + // If this parameter isn't aliasing a global, then we need to use + // the special "dynamic image-sampler" class to pass it--and we need + // to use it for *every* non-alias parameter, in case a combined + // image-sampler with a Y'CbCr conversion is passed. Hopefully, this + // pathological case is so rare that it should never be hit in practice. + if (!arg.alias_global_variable) + add_spv_func_and_recompile(SPVFuncImplDynamicImageSampler); + } + for (uint32_t i = 1; i < planes; i++) + arg_str += join(", ", CompilerGLSL::to_func_call_arg(arg, id), plane_name_suffix, i); + // Manufacture automatic sampler arg if the arg is a SampledImage texture. + if (type.image.dim != DimBuffer) + arg_str += ", " + to_sampler_expression(var_id ? var_id : id); + + // Add sampler Y'CbCr conversion info if we have it + if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable) + { + SmallVector samp_args; + + switch (constexpr_sampler->resolution) + { + case MSL_FORMAT_RESOLUTION_444: + // Default + break; + case MSL_FORMAT_RESOLUTION_422: + samp_args.push_back("spvFormatResolution::_422"); + break; + case MSL_FORMAT_RESOLUTION_420: + samp_args.push_back("spvFormatResolution::_420"); + break; + default: + SPIRV_CROSS_THROW("Invalid format resolution."); + } + + if (constexpr_sampler->chroma_filter != MSL_SAMPLER_FILTER_NEAREST) + samp_args.push_back("spvChromaFilter::linear"); + + if (constexpr_sampler->x_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN) + samp_args.push_back("spvXChromaLocation::midpoint"); + if (constexpr_sampler->y_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN) + samp_args.push_back("spvYChromaLocation::midpoint"); + switch (constexpr_sampler->ycbcr_model) + { + case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY: + // Default + break; + case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY: + samp_args.push_back("spvYCbCrModelConversion::ycbcr_identity"); + break; + case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709: + samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_709"); + break; + case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601: + samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_601"); + break; + case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020: + samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_2020"); + break; + default: + SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion."); + } + if (constexpr_sampler->ycbcr_range != MSL_SAMPLER_YCBCR_RANGE_ITU_FULL) + samp_args.push_back("spvYCbCrRange::itu_narrow"); + samp_args.push_back(join("spvComponentBits(", constexpr_sampler->bpc, ")")); + arg_str += join(", spvYCbCrSampler(", merge(samp_args), ")"); + } + } + + if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable) + arg_str += join(", (uint(", create_swizzle(constexpr_sampler->swizzle[3]), ") << 24) | (uint(", + create_swizzle(constexpr_sampler->swizzle[2]), ") << 16) | (uint(", + create_swizzle(constexpr_sampler->swizzle[1]), ") << 8) | uint(", + create_swizzle(constexpr_sampler->swizzle[0]), ")"); + else if 
(msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
+ arg_str += ", " + to_swizzle_expression(var_id ? var_id : id);
+
+ if (buffer_requires_array_length(var_id))
+ arg_str += ", " + to_buffer_size_expression(var_id ? var_id : id);
+
+ if (is_dynamic_img_sampler)
+ arg_str += ")";
+ }
+
+ // Emulate texture2D atomic operations
+ auto *backing_var = maybe_get_backing_variable(var_id);
+ if (backing_var && atomic_image_vars_emulated.count(backing_var->self))
+ {
+ arg_str += ", " + to_expression(var_id) + "_atomic";
+ }
+
+ return arg_str;
+}
+
+// If the ID represents a sampled image that has been assigned a sampler already,
+// generate an expression for the sampler, otherwise generate a fake sampler name
+// by appending a suffix to the expression constructed from the ID.
+string CompilerMSL::to_sampler_expression(uint32_t id)
+{
+ auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
+ if (combined && combined->sampler)
+ return to_expression(combined->sampler);
+
+ uint32_t expr_id = combined ? uint32_t(combined->image) : id;
+
+ // Constexpr samplers are declared as local variables,
+ // so exclude any qualifier names on the image expression.
+ if (auto *var = maybe_get_backing_variable(expr_id))
+ {
+ uint32_t img_id = var->basevariable ? var->basevariable : VariableID(var->self);
+ if (find_constexpr_sampler(img_id))
+ return Compiler::to_name(img_id) + sampler_name_suffix;
+ }
+
+ auto img_expr = to_expression(expr_id);
+ auto index = img_expr.find_first_of('[');
+ if (index == string::npos)
+ return img_expr + sampler_name_suffix;
+ else
+ return img_expr.substr(0, index) + sampler_name_suffix + img_expr.substr(index);
+}
+
+string CompilerMSL::to_swizzle_expression(uint32_t id)
+{
+ auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
+
+ auto expr = to_expression(combined ? combined->image : VariableID(id));
+ auto index = expr.find_first_of('[');
+
+ // If an image is part of an argument buffer, translate this to a legal identifier.
+ string::size_type period = 0;
+ while ((period = expr.find_first_of('.', period)) != string::npos && period < index)
+ expr[period] = '_';
+
+ if (index == string::npos)
+ return expr + swizzle_name_suffix;
+ else
+ {
+ auto image_expr = expr.substr(0, index);
+ auto array_expr = expr.substr(index);
+ return image_expr + swizzle_name_suffix + array_expr;
+ }
+}
+
+string CompilerMSL::to_buffer_size_expression(uint32_t id)
+{
+ auto expr = to_expression(id);
+ auto index = expr.find_first_of('[');
+
+ // This is quite crude, but we need to translate the reference name (*spvDescriptorSetN.name) to
+ // the pointer expression spvDescriptorSetN.name to make a reasonable expression here.
+ // This only happens if we have argument buffers and we are using OpArrayLength on a lone SSBO in that set.
+ if (expr.size() >= 3 && expr[0] == '(' && expr[1] == '*')
+ expr = address_of_expression(expr);
+
+ // If a buffer is part of an argument buffer, translate this to a legal identifier.
+ for (auto &c : expr) + if (c == '.') + c = '_'; + + if (index == string::npos) + return expr + buffer_size_name_suffix; + else + { + auto buffer_expr = expr.substr(0, index); + auto array_expr = expr.substr(index); + if (auto var = maybe_get_backing_variable(id)) + { + if (is_var_runtime_size_array(*var)) + { + if (!msl_options.runtime_array_rich_descriptor) + SPIRV_CROSS_THROW("OpArrayLength requires rich descriptor format"); + + auto last_pos = array_expr.find_last_of(']'); + if (last_pos != std::string::npos) + return buffer_expr + ".length(" + array_expr.substr(1, last_pos - 1) + ")"; + } + } + return buffer_expr + buffer_size_name_suffix + array_expr; + } +} + +// Checks whether the type is a Block all of whose members have DecorationPatch. +bool CompilerMSL::is_patch_block(const SPIRType &type) +{ + if (!has_decoration(type.self, DecorationBlock)) + return false; + + for (uint32_t i = 0; i < type.member_types.size(); i++) + { + if (!has_member_decoration(type.self, i, DecorationPatch)) + return false; + } + + return true; +} + +// Checks whether the ID is a row_major matrix that requires conversion before use +bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id) +{ + auto *e = maybe_get(id); + if (e) + return e->need_transpose; + else + return has_decoration(id, DecorationRowMajor); +} + +// Checks whether the member is a row_major matrix that requires conversion before use +bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index) +{ + return has_member_decoration(type.self, index, DecorationRowMajor); +} + +string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type, uint32_t physical_type_id, + bool is_packed, bool relaxed) +{ + if (!is_matrix(exp_type)) + { + return CompilerGLSL::convert_row_major_matrix(std::move(exp_str), exp_type, physical_type_id, is_packed, relaxed); + } + else + { + strip_enclosed_expression(exp_str); + if (physical_type_id != 0 || is_packed) + exp_str = unpack_expression_type(exp_str, exp_type, physical_type_id, is_packed, true); + return join("transpose(", exp_str, ")"); + } +} + +// Called automatically at the end of the entry point function +void CompilerMSL::emit_fixup() +{ + if (is_vertex_like_shader() && stage_out_var_id && !qual_pos_var_name.empty() && !capture_output_to_buffer) + { + if (options.vertex.fixup_clipspace) + statement(qual_pos_var_name, ".z = (", qual_pos_var_name, ".z + ", qual_pos_var_name, + ".w) * 0.5; // Adjust clip-space for Metal"); + + if (options.vertex.flip_vert_y) + statement(qual_pos_var_name, ".y = -(", qual_pos_var_name, ".y);", " // Invert Y-axis for Metal"); + } +} + +// Return a string defining a structure member, with padding and packing. +string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index, + const string &qualifier) +{ + uint32_t orig_member_type_id = member_type_id; + if (member_is_remapped_physical_type(type, index)) + member_type_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID); + auto &physical_type = get(member_type_id); + + // If this member is packed, mark it as so. 
+ string pack_pfx; + + // Allow Metal to use the array template to make arrays a value type + uint32_t orig_id = 0; + if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID)) + orig_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID); + + bool row_major = false; + if (is_matrix(physical_type)) + row_major = has_member_decoration(type.self, index, DecorationRowMajor); + + SPIRType row_major_physical_type { OpTypeMatrix }; + const SPIRType *declared_type = &physical_type; + + // If a struct is being declared with physical layout, + // do not use array wrappers. + // This avoids a lot of complicated cases with packed vectors and matrices, + // and generally we cannot copy full arrays in and out of buffers into Function + // address space. + // Array of resources should also be declared as builtin arrays. + if (has_member_decoration(type.self, index, DecorationOffset)) + is_using_builtin_array = true; + else if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary)) + is_using_builtin_array = true; + + if (member_is_packed_physical_type(type, index)) + { + // If we're packing a matrix, output an appropriate typedef + if (physical_type.basetype == SPIRType::Struct) + { + SPIRV_CROSS_THROW("Cannot emit a packed struct currently."); + } + else if (is_matrix(physical_type)) + { + uint32_t rows = physical_type.vecsize; + uint32_t cols = physical_type.columns; + pack_pfx = "packed_"; + if (row_major) + { + // These are stored transposed. + rows = physical_type.columns; + cols = physical_type.vecsize; + pack_pfx = "packed_rm_"; + } + string base_type = physical_type.width == 16 ? "half" : "float"; + string td_line = "typedef "; + td_line += "packed_" + base_type + to_string(rows); + td_line += " " + pack_pfx; + // Use the actual matrix size here. + td_line += base_type + to_string(physical_type.columns) + "x" + to_string(physical_type.vecsize); + td_line += "[" + to_string(cols) + "]"; + td_line += ";"; + add_typedef_line(td_line); + } + else if (!is_scalar(physical_type)) // scalar type is already packed. + pack_pfx = "packed_"; + } + else if (is_matrix(physical_type)) + { + if (!msl_options.supports_msl_version(3, 0) && + has_extended_decoration(type.self, SPIRVCrossDecorationWorkgroupStruct)) + { + pack_pfx = "spvStorage_"; + add_spv_func_and_recompile(SPVFuncImplStorageMatrix); + // The pack prefix causes problems with array wrappers. + is_using_builtin_array = true; + } + if (row_major) + { + // Need to declare type with flipped vecsize/columns. + row_major_physical_type = physical_type; + swap(row_major_physical_type.vecsize, row_major_physical_type.columns); + declared_type = &row_major_physical_type; + } + } + + // iOS Tier 1 argument buffers do not support writable images. + if (physical_type.basetype == SPIRType::Image && + physical_type.image.sampled == 2 && + msl_options.is_ios() && + msl_options.argument_buffers_tier <= Options::ArgumentBuffersTier::Tier1 && + !has_decoration(orig_id, DecorationNonWritable)) + { + SPIRV_CROSS_THROW("Writable images are not allowed on Tier1 argument buffers on iOS."); + } + + // Array information is baked into these types. + string array_type; + if (physical_type.basetype != SPIRType::Image && physical_type.basetype != SPIRType::Sampler && + physical_type.basetype != SPIRType::SampledImage) + { + BuiltIn builtin = BuiltInMax; + + // Special handling. 
In [[stage_out]] or [[stage_in]] blocks,
+ // we need flat arrays, but if we're somehow declaring gl_PerVertex for constant array reasons, we want
+ // template array types to be declared.
+ bool is_ib_in_out =
+ ((stage_out_var_id && get_stage_out_struct_type().self == type.self &&
+ variable_storage_requires_stage_io(StorageClassOutput)) ||
+ (stage_in_var_id && get_stage_in_struct_type().self == type.self &&
+ variable_storage_requires_stage_io(StorageClassInput)));
+ if (is_ib_in_out && is_member_builtin(type, index, &builtin))
+ is_using_builtin_array = true;
+ array_type = type_to_array_glsl(physical_type, orig_id);
+ }
+
+ if (orig_id)
+ {
+ auto *data_type = declared_type;
+ if (is_pointer(*data_type))
+ data_type = &get_pointee_type(*data_type);
+
+ if (is_array(*data_type) && get_resource_array_size(*data_type, orig_id) == 0)
+ {
+ // Hack for declaring unsized array of resources. Need to declare dummy sized array by value inline.
+ // This can then be wrapped in spvDescriptorArray as usual.
+ array_type = "[1] /* unsized array hack */";
+ }
+ }
+
+ string decl_type;
+ if (declared_type->vecsize > 4)
+ {
+ auto orig_type = get<SPIRType>(orig_member_type_id);
+ if (is_matrix(orig_type) && row_major)
+ swap(orig_type.vecsize, orig_type.columns);
+ orig_type.columns = 1;
+ decl_type = type_to_glsl(orig_type, orig_id, true);
+
+ if (declared_type->columns > 1)
+ decl_type = join("spvPaddedStd140Matrix<", decl_type, ", ", declared_type->columns, ">");
+ else
+ decl_type = join("spvPaddedStd140<", decl_type, ">");
+ }
+ else
+ decl_type = type_to_glsl(*declared_type, orig_id, true);
+
+ const char *overlapping_binding_tag =
+ has_extended_member_decoration(type.self, index, SPIRVCrossDecorationOverlappingBinding) ?
+ "// Overlapping binding: " : "";
+
+ auto result = join(overlapping_binding_tag, pack_pfx, decl_type, " ", qualifier,
+ to_member_name(type, index), member_attribute_qualifier(type, index), array_type, ";");
+
+ is_using_builtin_array = false;
+ return result;
+}
+
+// Emit a structure member, padding and packing to maintain the correct member alignments.
+void CompilerMSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
+ const string &qualifier, uint32_t)
+{
+ // If this member requires padding to maintain its declared offset, emit a dummy padding member before it.
+ if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget))
+ {
+ uint32_t pad_len = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget);
+ statement("char _m", index, "_pad", "[", pad_len, "];");
+ }
+
+ // Handle HLSL-style 0-based vertex/instance index.
+ builtin_declaration = true; + statement(to_struct_member(type, member_type_id, index, qualifier)); + builtin_declaration = false; +} + +void CompilerMSL::emit_struct_padding_target(const SPIRType &type) +{ + uint32_t struct_size = get_declared_struct_size_msl(type, true, true); + uint32_t target_size = get_extended_decoration(type.self, SPIRVCrossDecorationPaddingTarget); + if (target_size < struct_size) + SPIRV_CROSS_THROW("Cannot pad with negative bytes."); + else if (target_size > struct_size) + statement("char _m0_final_padding[", target_size - struct_size, "];"); +} + +// Return a MSL qualifier for the specified function attribute member +string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t index) +{ + auto &execution = get_entry_point(); + + uint32_t mbr_type_id = type.member_types[index]; + auto &mbr_type = get(mbr_type_id); + + BuiltIn builtin = BuiltInMax; + bool is_builtin = is_member_builtin(type, index, &builtin); + + if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary)) + { + string quals = join( + " [[id(", get_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary), ")"); + if (interlocked_resources.count( + get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID))) + quals += ", raster_order_group(0)"; + quals += "]]"; + return quals; + } + + // Vertex function inputs + if (execution.model == ExecutionModelVertex && type.storage == StorageClassInput) + { + if (is_builtin) + { + switch (builtin) + { + case BuiltInVertexId: + case BuiltInVertexIndex: + case BuiltInBaseVertex: + case BuiltInInstanceId: + case BuiltInInstanceIndex: + case BuiltInBaseInstance: + if (msl_options.vertex_for_tessellation) + return ""; + return string(" [[") + builtin_qualifier(builtin) + "]]"; + + case BuiltInDrawIndex: + SPIRV_CROSS_THROW("DrawIndex is not supported in MSL."); + + default: + return ""; + } + } + + uint32_t locn; + if (is_builtin) + locn = get_or_allocate_builtin_input_member_location(builtin, type.self, index); + else + locn = get_member_location(type.self, index); + + if (locn != k_unknown_location) + return string(" [[attribute(") + convert_to_string(locn) + ")]]"; + } + + // Vertex and tessellation evaluation function outputs + if (((execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation) || is_tese_shader()) && + type.storage == StorageClassOutput) + { + if (is_builtin) + { + switch (builtin) + { + case BuiltInPointSize: + // Only mark the PointSize builtin if really rendering points. + // Some shaders may include a PointSize builtin even when used to render + // non-point topologies, and Metal will reject this builtin when compiling + // the shader into a render pipeline that uses a non-point topology. + return msl_options.enable_point_size_builtin ? (string(" [[") + builtin_qualifier(builtin) + "]]") : ""; + + case BuiltInViewportIndex: + if (!msl_options.supports_msl_version(2, 0)) + SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0."); + /* fallthrough */ + case BuiltInPosition: + case BuiltInLayer: + return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " "); + + case BuiltInClipDistance: + if (has_member_decoration(type.self, index, DecorationIndex)) + return join(" [[user(clip", get_member_decoration(type.self, index, DecorationIndex), ")]]"); + else + return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? 
"" : " "); + + case BuiltInCullDistance: + if (has_member_decoration(type.self, index, DecorationIndex)) + return join(" [[user(cull", get_member_decoration(type.self, index, DecorationIndex), ")]]"); + else + return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " "); + + default: + return ""; + } + } + string loc_qual = member_location_attribute_qualifier(type, index); + if (!loc_qual.empty()) + return join(" [[", loc_qual, "]]"); + } + + if (execution.model == ExecutionModelVertex && msl_options.vertex_for_tessellation && type.storage == StorageClassOutput) + { + // For this type of shader, we always arrange for it to capture its + // output to a buffer. For this reason, qualifiers are irrelevant here. + if (is_builtin) + // We still have to assign a location so the output struct will sort correctly. + get_or_allocate_builtin_output_member_location(builtin, type.self, index); + return ""; + } + + // Tessellation control function inputs + if (is_tesc_shader() && type.storage == StorageClassInput) + { + if (is_builtin) + { + switch (builtin) + { + case BuiltInInvocationId: + case BuiltInPrimitiveId: + if (msl_options.multi_patch_workgroup) + return ""; + return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " "); + case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage + case BuiltInSubgroupSize: // FIXME: Should work in any stage + if (msl_options.emulate_subgroups) + return ""; + return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " "); + case BuiltInPatchVertices: + return ""; + // Others come from stage input. + default: + break; + } + } + if (msl_options.multi_patch_workgroup) + return ""; + + uint32_t locn; + if (is_builtin) + locn = get_or_allocate_builtin_input_member_location(builtin, type.self, index); + else + locn = get_member_location(type.self, index); + + if (locn != k_unknown_location) + return string(" [[attribute(") + convert_to_string(locn) + ")]]"; + } + + // Tessellation control function outputs + if (is_tesc_shader() && type.storage == StorageClassOutput) + { + // For this type of shader, we always arrange for it to capture its + // output to a buffer. For this reason, qualifiers are irrelevant here. + if (is_builtin) + // We still have to assign a location so the output struct will sort correctly. + get_or_allocate_builtin_output_member_location(builtin, type.self, index); + return ""; + } + + // Tessellation evaluation function inputs + if (is_tese_shader() && type.storage == StorageClassInput) + { + if (is_builtin) + { + switch (builtin) + { + case BuiltInPrimitiveId: + case BuiltInTessCoord: + return string(" [[") + builtin_qualifier(builtin) + "]]"; + case BuiltInPatchVertices: + return ""; + // Others come from stage input. + default: + break; + } + } + + if (msl_options.raw_buffer_tese_input) + return ""; + + // The special control point array must not be marked with an attribute. + if (get_type(type.member_types[index]).basetype == SPIRType::ControlPointArray) + return ""; + + uint32_t locn; + if (is_builtin) + locn = get_or_allocate_builtin_input_member_location(builtin, type.self, index); + else + locn = get_member_location(type.self, index); + + if (locn != k_unknown_location) + return string(" [[attribute(") + convert_to_string(locn) + ")]]"; + } + + // Tessellation evaluation function outputs were handled above. 
+ + // Fragment function inputs + if (execution.model == ExecutionModelFragment && type.storage == StorageClassInput) + { + string quals; + if (is_builtin) + { + switch (builtin) + { + case BuiltInViewIndex: + if (!msl_options.multiview || !msl_options.multiview_layered_rendering) + break; + /* fallthrough */ + case BuiltInFrontFacing: + case BuiltInPointCoord: + case BuiltInFragCoord: + case BuiltInSampleId: + case BuiltInSampleMask: + case BuiltInLayer: + case BuiltInBaryCoordKHR: + case BuiltInBaryCoordNoPerspKHR: + quals = builtin_qualifier(builtin); + break; + + case BuiltInClipDistance: + return join(" [[user(clip", get_member_decoration(type.self, index, DecorationIndex), ")]]"); + case BuiltInCullDistance: + return join(" [[user(cull", get_member_decoration(type.self, index, DecorationIndex), ")]]"); + + default: + break; + } + } + else + quals = member_location_attribute_qualifier(type, index); + + if (builtin == BuiltInBaryCoordKHR || builtin == BuiltInBaryCoordNoPerspKHR) + { + if (has_member_decoration(type.self, index, DecorationFlat) || + has_member_decoration(type.self, index, DecorationCentroid) || + has_member_decoration(type.self, index, DecorationSample) || + has_member_decoration(type.self, index, DecorationNoPerspective)) + { + // NoPerspective is baked into the builtin type. + SPIRV_CROSS_THROW( + "Flat, Centroid, Sample, NoPerspective decorations are not supported for BaryCoord inputs."); + } + } + + // Don't bother decorating integers with the 'flat' attribute; it's + // the default (in fact, the only option). Also don't bother with the + // FragCoord builtin; it's always noperspective on Metal. + if (!type_is_integral(mbr_type) && (!is_builtin || builtin != BuiltInFragCoord)) + { + if (has_member_decoration(type.self, index, DecorationFlat)) + { + if (!quals.empty()) + quals += ", "; + quals += "flat"; + } + else if (has_member_decoration(type.self, index, DecorationCentroid)) + { + if (!quals.empty()) + quals += ", "; + if (has_member_decoration(type.self, index, DecorationNoPerspective)) + quals += "centroid_no_perspective"; + else + quals += "centroid_perspective"; + } + else if (has_member_decoration(type.self, index, DecorationSample)) + { + if (!quals.empty()) + quals += ", "; + if (has_member_decoration(type.self, index, DecorationNoPerspective)) + quals += "sample_no_perspective"; + else + quals += "sample_perspective"; + } + else if (has_member_decoration(type.self, index, DecorationNoPerspective)) + { + if (!quals.empty()) + quals += ", "; + quals += "center_no_perspective"; + } + } + + if (!quals.empty()) + return " [[" + quals + "]]"; + } + + // Fragment function outputs + if (execution.model == ExecutionModelFragment && type.storage == StorageClassOutput) + { + if (is_builtin) + { + switch (builtin) + { + case BuiltInFragStencilRefEXT: + // Similar to PointSize, only mark FragStencilRef if there's a stencil buffer. + // Some shaders may include a FragStencilRef builtin even when used to render + // without a stencil attachment, and Metal will reject this builtin + // when compiling the shader into a render pipeline that does not set + // stencilAttachmentPixelFormat. + if (!msl_options.enable_frag_stencil_ref_builtin) + return ""; + if (!msl_options.supports_msl_version(2, 1)) + SPIRV_CROSS_THROW("Stencil export only supported in MSL 2.1 and up."); + return string(" [[") + builtin_qualifier(builtin) + "]]"; + + case BuiltInFragDepth: + // Ditto FragDepth. 
+ if (!msl_options.enable_frag_depth_builtin) + return ""; + /* fallthrough */ + case BuiltInSampleMask: + return string(" [[") + builtin_qualifier(builtin) + "]]"; + + default: + return ""; + } + } + uint32_t locn = get_member_location(type.self, index); + // Metal will likely complain about missing color attachments, too. + if (locn != k_unknown_location && !(msl_options.enable_frag_output_mask & (1 << locn))) + return ""; + if (locn != k_unknown_location && has_member_decoration(type.self, index, DecorationIndex)) + return join(" [[color(", locn, "), index(", get_member_decoration(type.self, index, DecorationIndex), + ")]]"); + else if (locn != k_unknown_location) + return join(" [[color(", locn, ")]]"); + else if (has_member_decoration(type.self, index, DecorationIndex)) + return join(" [[index(", get_member_decoration(type.self, index, DecorationIndex), ")]]"); + else + return ""; + } + + // Compute function inputs + if (execution.model == ExecutionModelGLCompute && type.storage == StorageClassInput) + { + if (is_builtin) + { + switch (builtin) + { + case BuiltInNumSubgroups: + case BuiltInSubgroupId: + case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage + case BuiltInSubgroupSize: // FIXME: Should work in any stage + if (msl_options.emulate_subgroups) + break; + /* fallthrough */ + case BuiltInGlobalInvocationId: + case BuiltInWorkgroupId: + case BuiltInNumWorkgroups: + case BuiltInLocalInvocationId: + case BuiltInLocalInvocationIndex: + return string(" [[") + builtin_qualifier(builtin) + "]]"; + + default: + return ""; + } + } + } + + return ""; +} + +// A user-defined output variable is considered to match an input variable in the subsequent +// stage if the two variables are declared with the same Location and Component decoration and +// match in type and decoration, except that interpolation decorations are not required to match. +// For the purposes of interface matching, variables declared without a Component decoration are +// considered to have a Component decoration of zero. +string CompilerMSL::member_location_attribute_qualifier(const SPIRType &type, uint32_t index) +{ + string quals; + uint32_t comp; + uint32_t locn = get_member_location(type.self, index, &comp); + if (locn != k_unknown_location) + { + quals += "user(locn"; + quals += convert_to_string(locn); + if (comp != k_unknown_component && comp != 0) + { + quals += "_"; + quals += convert_to_string(comp); + } + quals += ")"; + } + return quals; +} + +// Returns the location decoration of the member with the specified index in the specified type. +// If the location of the member has been explicitly set, that location is used. If not, this +// function assumes the members are ordered in their location order, and simply returns the +// index as the location. 
+uint32_t CompilerMSL::get_member_location(uint32_t type_id, uint32_t index, uint32_t *comp) const +{ + if (comp) + { + if (has_member_decoration(type_id, index, DecorationComponent)) + *comp = get_member_decoration(type_id, index, DecorationComponent); + else + *comp = k_unknown_component; + } + + if (has_member_decoration(type_id, index, DecorationLocation)) + return get_member_decoration(type_id, index, DecorationLocation); + else + return k_unknown_location; +} + +uint32_t CompilerMSL::get_or_allocate_builtin_input_member_location(spv::BuiltIn builtin, + uint32_t type_id, uint32_t index, + uint32_t *comp) +{ + uint32_t loc = get_member_location(type_id, index, comp); + if (loc != k_unknown_location) + return loc; + + if (comp) + *comp = k_unknown_component; + + // Late allocation. Find a location which is unused by the application. + // This can happen for built-in inputs in tessellation which are mixed and matched with user inputs. + auto &mbr_type = get(get(type_id).member_types[index]); + uint32_t count = type_to_location_count(mbr_type); + + loc = 0; + + const auto location_range_in_use = [this](uint32_t location, uint32_t location_count) -> bool { + for (uint32_t i = 0; i < location_count; i++) + if (location_inputs_in_use.count(location + i) != 0) + return true; + return false; + }; + + while (location_range_in_use(loc, count)) + loc++; + + set_member_decoration(type_id, index, DecorationLocation, loc); + + // Triangle tess level inputs are shared in one packed float4, + // mark both builtins as sharing one location. + if (!msl_options.raw_buffer_tese_input && is_tessellating_triangles() && + (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter)) + { + builtin_to_automatic_input_location[BuiltInTessLevelInner] = loc; + builtin_to_automatic_input_location[BuiltInTessLevelOuter] = loc; + } + else + builtin_to_automatic_input_location[builtin] = loc; + + mark_location_as_used_by_shader(loc, mbr_type, StorageClassInput, true); + return loc; +} + +uint32_t CompilerMSL::get_or_allocate_builtin_output_member_location(spv::BuiltIn builtin, + uint32_t type_id, uint32_t index, + uint32_t *comp) +{ + uint32_t loc = get_member_location(type_id, index, comp); + if (loc != k_unknown_location) + return loc; + loc = 0; + + if (comp) + *comp = k_unknown_component; + + // Late allocation. Find a location which is unused by the application. + // This can happen for built-in outputs in tessellation which are mixed and matched with user inputs. + auto &mbr_type = get(get(type_id).member_types[index]); + uint32_t count = type_to_location_count(mbr_type); + + const auto location_range_in_use = [this](uint32_t location, uint32_t location_count) -> bool { + for (uint32_t i = 0; i < location_count; i++) + if (location_outputs_in_use.count(location + i) != 0) + return true; + return false; + }; + + while (location_range_in_use(loc, count)) + loc++; + + set_member_decoration(type_id, index, DecorationLocation, loc); + + // Triangle tess level inputs are shared in one packed float4; + // mark both builtins as sharing one location. 
+ if (is_tessellating_triangles() && (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter)) + { + builtin_to_automatic_output_location[BuiltInTessLevelInner] = loc; + builtin_to_automatic_output_location[BuiltInTessLevelOuter] = loc; + } + else + builtin_to_automatic_output_location[builtin] = loc; + + mark_location_as_used_by_shader(loc, mbr_type, StorageClassOutput, true); + return loc; +} + +// Returns the type declaration for a function, including the +// entry type if the current function is the entry point function +string CompilerMSL::func_type_decl(SPIRType &type) +{ + // The regular function return type. If not processing the entry point function, that's all we need + string return_type = type_to_glsl(type) + type_to_array_glsl(type, 0); + if (!processing_entry_point) + return return_type; + + // If an outgoing interface block has been defined, and it should be returned, override the entry point return type + bool ep_should_return_output = !get_is_rasterization_disabled(); + if (stage_out_var_id && ep_should_return_output) + return_type = type_to_glsl(get_stage_out_struct_type()) + type_to_array_glsl(type, 0); + + // Prepend a entry type, based on the execution model + string entry_type; + auto &execution = get_entry_point(); + switch (execution.model) + { + case ExecutionModelVertex: + if (msl_options.vertex_for_tessellation && !msl_options.supports_msl_version(1, 2)) + SPIRV_CROSS_THROW("Tessellation requires Metal 1.2."); + entry_type = msl_options.vertex_for_tessellation ? "kernel" : "vertex"; + break; + case ExecutionModelTessellationEvaluation: + if (!msl_options.supports_msl_version(1, 2)) + SPIRV_CROSS_THROW("Tessellation requires Metal 1.2."); + if (execution.flags.get(ExecutionModeIsolines)) + SPIRV_CROSS_THROW("Metal does not support isoline tessellation."); + if (msl_options.is_ios()) + entry_type = join("[[ patch(", is_tessellating_triangles() ? "triangle" : "quad", ") ]] vertex"); + else + entry_type = join("[[ patch(", is_tessellating_triangles() ? "triangle" : "quad", ", ", + execution.output_vertices, ") ]] vertex"); + break; + case ExecutionModelFragment: + entry_type = uses_explicit_early_fragment_test() ? 
"[[ early_fragment_tests ]] fragment" : "fragment"; + break; + case ExecutionModelTessellationControl: + if (!msl_options.supports_msl_version(1, 2)) + SPIRV_CROSS_THROW("Tessellation requires Metal 1.2."); + if (execution.flags.get(ExecutionModeIsolines)) + SPIRV_CROSS_THROW("Metal does not support isoline tessellation."); + /* fallthrough */ + case ExecutionModelGLCompute: + case ExecutionModelKernel: + entry_type = "kernel"; + break; + default: + entry_type = "unknown"; + break; + } + + return entry_type + " " + return_type; +} + +bool CompilerMSL::is_tesc_shader() const +{ + return get_execution_model() == ExecutionModelTessellationControl; +} + +bool CompilerMSL::is_tese_shader() const +{ + return get_execution_model() == ExecutionModelTessellationEvaluation; +} + +bool CompilerMSL::uses_explicit_early_fragment_test() +{ + auto &ep_flags = get_entry_point().flags; + return ep_flags.get(ExecutionModeEarlyFragmentTests) || ep_flags.get(ExecutionModePostDepthCoverage); +} + +// In MSL, address space qualifiers are required for all pointer or reference variables +string CompilerMSL::get_argument_address_space(const SPIRVariable &argument) +{ + const auto &type = get(argument.basetype); + return get_type_address_space(type, argument.self, true); +} + +bool CompilerMSL::decoration_flags_signal_volatile(const Bitset &flags) +{ + return flags.get(DecorationVolatile) || flags.get(DecorationCoherent); +} + +string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bool argument) +{ + // This can be called for variable pointer contexts as well, so be very careful about which method we choose. + Bitset flags; + auto *var = maybe_get(id); + if (var && type.basetype == SPIRType::Struct && + (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock))) + flags = get_buffer_block_flags(id); + else + flags = get_decoration_bitset(id); + + const char *addr_space = nullptr; + switch (type.storage) + { + case StorageClassWorkgroup: + addr_space = "threadgroup"; + break; + + case StorageClassStorageBuffer: + case StorageClassPhysicalStorageBuffer: + { + // For arguments from variable pointers, we use the write count deduction, so + // we should not assume any constness here. Only for global SSBOs. + bool readonly = false; + if (!var || has_decoration(type.self, DecorationBlock)) + readonly = flags.get(DecorationNonWritable); + + addr_space = readonly ? "const device" : "device"; + break; + } + + case StorageClassUniform: + case StorageClassUniformConstant: + case StorageClassPushConstant: + if (type.basetype == SPIRType::Struct) + { + bool ssbo = has_decoration(type.self, DecorationBufferBlock); + if (ssbo) + addr_space = flags.get(DecorationNonWritable) ? "const device" : "device"; + else + addr_space = "constant"; + } + else if (!argument) + { + addr_space = "constant"; + } + else if (type_is_msl_framebuffer_fetch(type)) + { + // Subpass inputs are passed around by value. + addr_space = ""; + } + break; + + case StorageClassFunction: + case StorageClassGeneric: + break; + + case StorageClassInput: + if (is_tesc_shader() && var && var->basevariable == stage_in_ptr_var_id) + addr_space = msl_options.multi_patch_workgroup ? "const device" : "threadgroup"; + // Don't pass tessellation levels in the device AS; we load and convert them + // to float manually. 
+ if (is_tese_shader() && msl_options.raw_buffer_tese_input && var) + { + bool is_stage_in = var->basevariable == stage_in_ptr_var_id; + bool is_patch_stage_in = has_decoration(var->self, DecorationPatch); + bool is_builtin = has_decoration(var->self, DecorationBuiltIn); + BuiltIn builtin = (BuiltIn)get_decoration(var->self, DecorationBuiltIn); + bool is_tess_level = is_builtin && (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner); + if (is_stage_in || (is_patch_stage_in && !is_tess_level)) + addr_space = "const device"; + } + if (get_execution_model() == ExecutionModelFragment && var && var->basevariable == stage_in_var_id) + addr_space = "thread"; + break; + + case StorageClassOutput: + if (capture_output_to_buffer) + { + if (var && type.storage == StorageClassOutput) + { + bool is_masked = is_stage_output_variable_masked(*var); + + if (is_masked) + { + if (is_tessellation_shader()) + addr_space = "threadgroup"; + else + addr_space = "thread"; + } + else if (variable_decl_is_remapped_storage(*var, StorageClassWorkgroup)) + addr_space = "threadgroup"; + } + + if (!addr_space) + addr_space = "device"; + } + break; + + default: + break; + } + + if (!addr_space) + { + // No address space for plain values. + addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : ""; + } + + return join(decoration_flags_signal_volatile(flags) ? "volatile " : "", addr_space); +} + +const char *CompilerMSL::to_restrict(uint32_t id, bool space) +{ + // This can be called for variable pointer contexts as well, so be very careful about which method we choose. + Bitset flags; + if (ir.ids[id].get_type() == TypeVariable) + { + uint32_t type_id = expression_type_id(id); + auto &type = expression_type(id); + if (type.basetype == SPIRType::Struct && + (has_decoration(type_id, DecorationBlock) || has_decoration(type_id, DecorationBufferBlock))) + flags = get_buffer_block_flags(id); + else + flags = get_decoration_bitset(id); + } + else + flags = get_decoration_bitset(id); + + return flags.get(DecorationRestrict) || flags.get(DecorationRestrictPointerEXT) ? + (space ? "__restrict " : "__restrict") : ""; +} + +string CompilerMSL::entry_point_arg_stage_in() +{ + string decl; + + if ((is_tesc_shader() && msl_options.multi_patch_workgroup) || + (is_tese_shader() && msl_options.raw_buffer_tese_input)) + return decl; + + // Stage-in structure + uint32_t stage_in_id; + if (is_tese_shader()) + stage_in_id = patch_stage_in_var_id; + else + stage_in_id = stage_in_var_id; + + if (stage_in_id) + { + auto &var = get(stage_in_id); + auto &type = get_variable_data_type(var); + + add_resource_name(var.self); + decl = join(type_to_glsl(type), " ", to_name(var.self), " [[stage_in]]"); + } + + return decl; +} + +// Returns true if this input builtin should be a direct parameter on a shader function parameter list, +// and false for builtins that should be passed or calculated some other way. +bool CompilerMSL::is_direct_input_builtin(BuiltIn bi_type) +{ + switch (bi_type) + { + // Vertex function in + case BuiltInVertexId: + case BuiltInVertexIndex: + case BuiltInBaseVertex: + case BuiltInInstanceId: + case BuiltInInstanceIndex: + case BuiltInBaseInstance: + return get_execution_model() != ExecutionModelVertex || !msl_options.vertex_for_tessellation; + // Tess. 
control function in
+ case BuiltInPosition:
+ case BuiltInPointSize:
+ case BuiltInClipDistance:
+ case BuiltInCullDistance:
+ case BuiltInPatchVertices:
+ return false;
+ case BuiltInInvocationId:
+ case BuiltInPrimitiveId:
+ return !is_tesc_shader() || !msl_options.multi_patch_workgroup;
+ // Tess. evaluation function in
+ case BuiltInTessLevelInner:
+ case BuiltInTessLevelOuter:
+ return false;
+ // Fragment function in
+ case BuiltInSamplePosition:
+ case BuiltInHelperInvocation:
+ case BuiltInBaryCoordKHR:
+ case BuiltInBaryCoordNoPerspKHR:
+ return false;
+ case BuiltInViewIndex:
+ return get_execution_model() == ExecutionModelFragment && msl_options.multiview &&
+ msl_options.multiview_layered_rendering;
+ // Compute function in
+ case BuiltInSubgroupId:
+ case BuiltInNumSubgroups:
+ return !msl_options.emulate_subgroups;
+ // Any stage function in
+ case BuiltInDeviceIndex:
+ case BuiltInSubgroupEqMask:
+ case BuiltInSubgroupGeMask:
+ case BuiltInSubgroupGtMask:
+ case BuiltInSubgroupLeMask:
+ case BuiltInSubgroupLtMask:
+ return false;
+ case BuiltInSubgroupSize:
+ if (msl_options.fixed_subgroup_size != 0)
+ return false;
+ /* fallthrough */
+ case BuiltInSubgroupLocalInvocationId:
+ return !msl_options.emulate_subgroups;
+ default:
+ return true;
+ }
+}
+
+// Returns true if this is a fragment shader that runs per sample, and false otherwise.
+bool CompilerMSL::is_sample_rate() const
+{
+ auto &caps = get_declared_capabilities();
+ return get_execution_model() == ExecutionModelFragment &&
+ (msl_options.force_sample_rate_shading ||
+ std::find(caps.begin(), caps.end(), CapabilitySampleRateShading) != caps.end() ||
+ (msl_options.use_framebuffer_fetch_subpasses && need_subpass_input_ms));
+}
+
+bool CompilerMSL::is_intersection_query() const
+{
+ auto &caps = get_declared_capabilities();
+ return std::find(caps.begin(), caps.end(), CapabilityRayQueryKHR) != caps.end();
+}
+
+void CompilerMSL::entry_point_args_builtin(string &ep_args)
+{
+ // Builtin variables
+ SmallVector<pair<SPIRVariable *, BuiltIn>, 8> active_builtins;
+ ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
+ if (var.storage != StorageClassInput)
+ return;
+
+ auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
+
+ // Don't emit SamplePosition as a separate parameter. In the entry
+ // point, we get that by calling get_sample_position() on the sample ID.
+ if (is_builtin_variable(var) &&
+ get_variable_data_type(var).basetype != SPIRType::Struct &&
+ get_variable_data_type(var).basetype != SPIRType::ControlPointArray)
+ {
+ // If the builtin is not part of the active input builtin set, don't emit it.
+ // Relevant for multiple entry-point modules which might declare unused builtins.
+ if (!active_input_builtins.get(bi_type) || !interface_variable_exists_in_entry_point(var_id))
+ return;
+
+ // Remember this variable. We may need to correct its type.
+ active_builtins.push_back(make_pair(&var, bi_type));
+
+ if (is_direct_input_builtin(bi_type))
+ {
+ if (!ep_args.empty())
+ ep_args += ", ";
+
+ // Handle HLSL-style 0-based vertex/instance index.
+ builtin_declaration = true;
+
+ // Handle different MSL gl_TessCoord types. 
(float2, float3) + if (bi_type == BuiltInTessCoord && get_entry_point().flags.get(ExecutionModeQuads)) + ep_args += "float2 " + to_expression(var_id) + "In"; + else + ep_args += builtin_type_decl(bi_type, var_id) + " " + to_expression(var_id); + + ep_args += string(" [[") + builtin_qualifier(bi_type); + if (bi_type == BuiltInSampleMask && get_entry_point().flags.get(ExecutionModePostDepthCoverage)) + { + if (!msl_options.supports_msl_version(2)) + SPIRV_CROSS_THROW("Post-depth coverage requires MSL 2.0."); + if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3)) + SPIRV_CROSS_THROW("Post-depth coverage on Mac requires MSL 2.3."); + ep_args += ", post_depth_coverage"; + } + ep_args += "]]"; + builtin_declaration = false; + } + } + + if (has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase)) + { + // This is a special implicit builtin, not corresponding to any SPIR-V builtin, + // which holds the base that was passed to vkCmdDispatchBase() or vkCmdDrawIndexed(). If it's present, + // assume we emitted it for a good reason. + assert(msl_options.supports_msl_version(1, 2)); + if (!ep_args.empty()) + ep_args += ", "; + + ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_origin]]"; + } + + if (has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize)) + { + // This is another special implicit builtin, not corresponding to any SPIR-V builtin, + // which holds the number of vertices and instances to draw. If it's present, + // assume we emitted it for a good reason. + assert(msl_options.supports_msl_version(1, 2)); + if (!ep_args.empty()) + ep_args += ", "; + + ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_size]]"; + } + }); + + // Correct the types of all encountered active builtins. We couldn't do this before + // because ensure_correct_builtin_type() may increase the bound, which isn't allowed + // while iterating over IDs. + for (auto &var : active_builtins) + var.first->basetype = ensure_correct_builtin_type(var.first->basetype, var.second); + + // Handle HLSL-style 0-based vertex/instance index. + if (needs_base_vertex_arg == TriState::Yes) + ep_args += built_in_func_arg(BuiltInBaseVertex, !ep_args.empty()); + + if (needs_base_instance_arg == TriState::Yes) + ep_args += built_in_func_arg(BuiltInBaseInstance, !ep_args.empty()); + + if (capture_output_to_buffer) + { + // Add parameters to hold the indirect draw parameters and the shader output. This has to be handled + // specially because it needs to be a pointer, not a reference. 
+ if (stage_out_var_id) + { + if (!ep_args.empty()) + ep_args += ", "; + ep_args += join("device ", type_to_glsl(get_stage_out_struct_type()), "* ", output_buffer_var_name, + " [[buffer(", msl_options.shader_output_buffer_index, ")]]"); + } + + if (is_tesc_shader()) + { + if (!ep_args.empty()) + ep_args += ", "; + ep_args += + join("constant uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]"); + } + else if (stage_out_var_id && + !(get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)) + { + if (!ep_args.empty()) + ep_args += ", "; + ep_args += + join("device uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]"); + } + + if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation && + (active_input_builtins.get(BuiltInVertexIndex) || active_input_builtins.get(BuiltInVertexId)) && + msl_options.vertex_index_type != Options::IndexType::None) + { + // Add the index buffer so we can set gl_VertexIndex correctly. + if (!ep_args.empty()) + ep_args += ", "; + switch (msl_options.vertex_index_type) + { + case Options::IndexType::None: + break; + case Options::IndexType::UInt16: + ep_args += join("const device ushort* ", index_buffer_var_name, " [[buffer(", + msl_options.shader_index_buffer_index, ")]]"); + break; + case Options::IndexType::UInt32: + ep_args += join("const device uint* ", index_buffer_var_name, " [[buffer(", + msl_options.shader_index_buffer_index, ")]]"); + break; + } + } + + // Tessellation control shaders get three additional parameters: + // a buffer to hold the per-patch data, a buffer to hold the per-patch + // tessellation levels, and a block of workgroup memory to hold the + // input control point data. + if (is_tesc_shader()) + { + if (patch_stage_out_var_id) + { + if (!ep_args.empty()) + ep_args += ", "; + ep_args += + join("device ", type_to_glsl(get_patch_stage_out_struct_type()), "* ", patch_output_buffer_var_name, + " [[buffer(", convert_to_string(msl_options.shader_patch_output_buffer_index), ")]]"); + } + if (!ep_args.empty()) + ep_args += ", "; + ep_args += join("device ", get_tess_factor_struct_name(), "* ", tess_factor_buffer_var_name, " [[buffer(", + convert_to_string(msl_options.shader_tess_factor_buffer_index), ")]]"); + + // Initializer for tess factors must be handled specially since it's never declared as a normal variable. + uint32_t outer_factor_initializer_id = 0; + uint32_t inner_factor_initializer_id = 0; + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + if (!has_decoration(var.self, DecorationBuiltIn) || var.storage != StorageClassOutput || !var.initializer) + return; + + BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn)); + if (builtin == BuiltInTessLevelInner) + inner_factor_initializer_id = var.initializer; + else if (builtin == BuiltInTessLevelOuter) + outer_factor_initializer_id = var.initializer; + }); + + const SPIRConstant *c = nullptr; + + if (outer_factor_initializer_id && (c = maybe_get(outer_factor_initializer_id))) + { + auto &entry_func = get(ir.default_entry_point); + entry_func.fixup_hooks_in.push_back( + [=]() + { + uint32_t components = is_tessellating_triangles() ? 
3 : 4;
+ for (uint32_t i = 0; i < components; i++)
+ {
+ statement(builtin_to_glsl(BuiltInTessLevelOuter, StorageClassOutput), "[", i,
+ "] = ", "half(", to_expression(c->subconstants[i]), ");");
+ }
+ });
+ }
+
+ if (inner_factor_initializer_id && (c = maybe_get<SPIRConstant>(inner_factor_initializer_id)))
+ {
+ auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
+ if (is_tessellating_triangles())
+ {
+ entry_func.fixup_hooks_in.push_back([=]() {
+ statement(builtin_to_glsl(BuiltInTessLevelInner, StorageClassOutput), " = ", "half(",
+ to_expression(c->subconstants[0]), ");");
+ });
+ }
+ else
+ {
+ entry_func.fixup_hooks_in.push_back([=]() {
+ for (uint32_t i = 0; i < 2; i++)
+ {
+ statement(builtin_to_glsl(BuiltInTessLevelInner, StorageClassOutput), "[", i, "] = ",
+ "half(", to_expression(c->subconstants[i]), ");");
+ }
+ });
+ }
+ }
+
+ if (stage_in_var_id)
+ {
+ if (!ep_args.empty())
+ ep_args += ", ";
+ if (msl_options.multi_patch_workgroup)
+ {
+ ep_args += join("device ", type_to_glsl(get_stage_in_struct_type()), "* ", input_buffer_var_name,
+ " [[buffer(", convert_to_string(msl_options.shader_input_buffer_index), ")]]");
+ }
+ else
+ {
+ ep_args += join("threadgroup ", type_to_glsl(get_stage_in_struct_type()), "* ", input_wg_var_name,
+ " [[threadgroup(", convert_to_string(msl_options.shader_input_wg_index), ")]]");
+ }
+ }
+ }
+ }
+ // Tessellation evaluation shaders get three additional parameters:
+ // a buffer for the per-patch data, a buffer for the per-patch
+ // tessellation levels, and a buffer for the control point data.
+ if (is_tese_shader() && msl_options.raw_buffer_tese_input)
+ {
+ if (patch_stage_in_var_id)
+ {
+ if (!ep_args.empty())
+ ep_args += ", ";
+ ep_args +=
+ join("const device ", type_to_glsl(get_patch_stage_in_struct_type()), "* ", patch_input_buffer_var_name,
+ " [[buffer(", convert_to_string(msl_options.shader_patch_input_buffer_index), ")]]");
+ }
+
+ if (tess_level_inner_var_id || tess_level_outer_var_id)
+ {
+ if (!ep_args.empty())
+ ep_args += ", ";
+ ep_args += join("const device ", get_tess_factor_struct_name(), "* ", tess_factor_buffer_var_name,
+ " [[buffer(", convert_to_string(msl_options.shader_tess_factor_buffer_index), ")]]");
+ }
+
+ if (stage_in_var_id)
+ {
+ if (!ep_args.empty())
+ ep_args += ", ";
+ ep_args += join("const device ", type_to_glsl(get_stage_in_struct_type()), "* ", input_buffer_var_name,
+ " [[buffer(", convert_to_string(msl_options.shader_input_buffer_index), ")]]");
+ }
+ }
+}
+
+string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
+{
+ string ep_args = entry_point_arg_stage_in();
+ Bitset claimed_bindings;
+
+ for (uint32_t i = 0; i < kMaxArgumentBuffers; i++)
+ {
+ uint32_t id = argument_buffer_ids[i];
+ if (id == 0)
+ continue;
+
+ add_resource_name(id);
+ auto &var = get<SPIRVariable>(id);
+ auto &type = get_variable_data_type(var);
+
+ if (!ep_args.empty())
+ ep_args += ", ";
+
+ // Check if the argument buffer binding itself has been remapped.
+ uint32_t buffer_binding;
+ auto itr = resource_bindings.find({ get_entry_point().model, i, kArgumentBufferBinding });
+ if (itr != end(resource_bindings))
+ {
+ buffer_binding = itr->second.first.msl_buffer;
+ itr->second.second = true;
+ }
+ else
+ {
+ // As a fallback, directly map desc set <-> binding.
+ // If that was taken, take the next buffer binding.
+ if (claimed_bindings.get(i)) + buffer_binding = next_metal_resource_index_buffer; + else + buffer_binding = i; + } + + claimed_bindings.set(buffer_binding); + + ep_args += get_argument_address_space(var) + " "; + + if (recursive_inputs.count(type.self)) + ep_args += string("void* ") + to_restrict(id, true) + to_name(id) + "_vp"; + else + ep_args += type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id); + + ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]"; + + next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1); + } + + entry_point_args_discrete_descriptors(ep_args); + entry_point_args_builtin(ep_args); + + if (!ep_args.empty() && append_comma) + ep_args += ", "; + + return ep_args; +} + +const MSLConstexprSampler *CompilerMSL::find_constexpr_sampler(uint32_t id) const +{ + // Try by ID. + { + auto itr = constexpr_samplers_by_id.find(id); + if (itr != end(constexpr_samplers_by_id)) + return &itr->second; + } + + // Try by binding. + { + uint32_t desc_set = get_decoration(id, DecorationDescriptorSet); + uint32_t binding = get_decoration(id, DecorationBinding); + + auto itr = constexpr_samplers_by_binding.find({ desc_set, binding }); + if (itr != end(constexpr_samplers_by_binding)) + return &itr->second; + } + + return nullptr; +} + +void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args) +{ + // Output resources, sorted by resource index & type + // We need to sort to work around a bug on macOS 10.13 with NVidia drivers where switching between shaders + // with different order of buffers can result in issues with buffer assignments inside the driver. + struct Resource + { + SPIRVariable *var; + SPIRVariable *discrete_descriptor_alias; + string name; + SPIRType::BaseType basetype; + uint32_t index; + uint32_t plane; + uint32_t secondary_index; + }; + + SmallVector resources; + + entry_point_bindings.clear(); + ir.for_each_typed_id([&](uint32_t var_id, SPIRVariable &var) { + if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant || + var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer) && + !is_hidden_variable(var)) + { + auto &type = get_variable_data_type(var); + uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet); + + if (is_supported_argument_buffer_type(type) && var.storage != StorageClassPushConstant) + { + if (descriptor_set_is_argument_buffer(desc_set)) + { + if (is_var_runtime_size_array(var)) + { + // Runtime arrays need to be wrapped in spvDescriptorArray from argument buffer payload. + entry_point_bindings.push_back(&var); + // We'll wrap this, so to_name() will always use non-qualified name. + // We'll need the qualified name to create temporary variable instead. + ir.meta[var_id].decoration.qualified_alias_explicit_override = true; + } + return; + } + } + + // Handle descriptor aliasing of simple discrete cases. + // We can handle aliasing of buffers by casting pointers. + // The amount of aliasing we can perform for discrete descriptors is very limited. + // For fully mutable-style aliasing, we need argument buffers where we can exploit the fact + // that descriptors are all 8 bytes. 
+ SPIRVariable *discrete_descriptor_alias = nullptr; + if (var.storage == StorageClassUniform || var.storage == StorageClassStorageBuffer) + { + for (auto &resource : resources) + { + if (get_decoration(resource.var->self, DecorationDescriptorSet) == + get_decoration(var_id, DecorationDescriptorSet) && + get_decoration(resource.var->self, DecorationBinding) == + get_decoration(var_id, DecorationBinding) && + resource.basetype == SPIRType::Struct && type.basetype == SPIRType::Struct && + (resource.var->storage == StorageClassUniform || + resource.var->storage == StorageClassStorageBuffer)) + { + discrete_descriptor_alias = resource.var; + // Self-reference marks that we should declare the resource, + // and it's being used as an alias (so we can emit void* instead). + resource.discrete_descriptor_alias = resource.var; + // Need to promote interlocked usage so that the primary declaration is correct. + if (interlocked_resources.count(var_id)) + interlocked_resources.insert(resource.var->self); + break; + } + } + } + + const MSLConstexprSampler *constexpr_sampler = nullptr; + if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler) + { + constexpr_sampler = find_constexpr_sampler(var_id); + if (constexpr_sampler) + { + // Mark this ID as a constexpr sampler for later in case it came from set/bindings. + constexpr_samplers_by_id[var_id] = *constexpr_sampler; + } + } + + // Emulate texture2D atomic operations + uint32_t secondary_index = 0; + if (atomic_image_vars_emulated.count(var.self)) + { + secondary_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0); + } + + if (type.basetype == SPIRType::SampledImage) + { + add_resource_name(var_id); + + uint32_t plane_count = 1; + if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable) + plane_count = constexpr_sampler->planes; + + entry_point_bindings.push_back(&var); + for (uint32_t i = 0; i < plane_count; i++) + resources.push_back({&var, discrete_descriptor_alias, to_name(var_id), SPIRType::Image, + get_metal_resource_index(var, SPIRType::Image, i), i, secondary_index }); + + if (type.image.dim != DimBuffer && !constexpr_sampler) + { + resources.push_back({&var, discrete_descriptor_alias, to_sampler_expression(var_id), SPIRType::Sampler, + get_metal_resource_index(var, SPIRType::Sampler), 0, 0 }); + } + } + else if (!constexpr_sampler) + { + // constexpr samplers are not declared as resources. + add_resource_name(var_id); + + // Don't allocate resource indices for aliases. 
+ uint32_t resource_index = ~0u; + if (!discrete_descriptor_alias) + resource_index = get_metal_resource_index(var, type.basetype); + + entry_point_bindings.push_back(&var); + resources.push_back({&var, discrete_descriptor_alias, to_name(var_id), type.basetype, + resource_index, 0, secondary_index }); + } + } + }); + + stable_sort(resources.begin(), resources.end(), + [](const Resource &lhs, const Resource &rhs) + { return tie(lhs.basetype, lhs.index) < tie(rhs.basetype, rhs.index); }); + + for (auto &r : resources) + { + auto &var = *r.var; + auto &type = get_variable_data_type(var); + + uint32_t var_id = var.self; + + switch (r.basetype) + { + case SPIRType::Struct: + { + auto &m = ir.meta[type.self]; + if (m.members.size() == 0) + break; + + if (r.discrete_descriptor_alias) + { + if (r.var == r.discrete_descriptor_alias) + { + auto primary_name = join("spvBufferAliasSet", + get_decoration(var_id, DecorationDescriptorSet), + "Binding", + get_decoration(var_id, DecorationBinding)); + + // Declare the primary alias as void* + if (!ep_args.empty()) + ep_args += ", "; + ep_args += get_argument_address_space(var) + " void* " + primary_name; + ep_args += " [[buffer(" + convert_to_string(r.index) + ")"; + if (interlocked_resources.count(var_id)) + ep_args += ", raster_order_group(0)"; + ep_args += "]]"; + } + + buffer_aliases_discrete.push_back(r.var->self); + } + else if (!type.array.empty()) + { + if (type.array.size() > 1) + SPIRV_CROSS_THROW("Arrays of arrays of buffers are not supported."); + + is_using_builtin_array = true; + if (is_var_runtime_size_array(var)) + { + add_spv_func_and_recompile(SPVFuncImplVariableDescriptorArray); + if (!ep_args.empty()) + ep_args += ", "; + const bool ssbo = has_decoration(type.self, DecorationBufferBlock); + if ((var.storage == spv::StorageClassStorageBuffer || ssbo) && + msl_options.runtime_array_rich_descriptor) + { + add_spv_func_and_recompile(SPVFuncImplVariableSizedDescriptor); + ep_args += "const device spvBufferDescriptor<" + get_argument_address_space(var) + " " + + type_to_glsl(type) + "*>* "; + } + else + { + ep_args += "const device spvDescriptor<" + get_argument_address_space(var) + " " + + type_to_glsl(type) + "*>* "; + } + ep_args += to_restrict(var_id, true) + r.name + "_"; + ep_args += " [[buffer(" + convert_to_string(r.index) + ")"; + if (interlocked_resources.count(var_id)) + ep_args += ", raster_order_group(0)"; + ep_args += "]]"; + } + else + { + uint32_t array_size = get_resource_array_size(type, var_id); + for (uint32_t i = 0; i < array_size; ++i) + { + if (!ep_args.empty()) + ep_args += ", "; + ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "* " + + to_restrict(var_id, true) + r.name + "_" + convert_to_string(i); + ep_args += " [[buffer(" + convert_to_string(r.index + i) + ")"; + if (interlocked_resources.count(var_id)) + ep_args += ", raster_order_group(0)"; + ep_args += "]]"; + } + } + is_using_builtin_array = false; + } + else + { + if (!ep_args.empty()) + ep_args += ", "; + ep_args += get_argument_address_space(var) + " "; + + if (recursive_inputs.count(type.self)) + ep_args += string("void* ") + to_restrict(var_id, true) + r.name + "_vp"; + else + ep_args += type_to_glsl(type) + "& " + to_restrict(var_id, true) + r.name; + + ep_args += " [[buffer(" + convert_to_string(r.index) + ")"; + if (interlocked_resources.count(var_id)) + ep_args += ", raster_order_group(0)"; + ep_args += "]]"; + } + break; + } + case SPIRType::Sampler: + if (!ep_args.empty()) + ep_args += ", "; + ep_args += sampler_type(type, 
var_id, false) + " " + r.name; + if (is_var_runtime_size_array(var)) + ep_args += "_ [[buffer(" + convert_to_string(r.index) + ")]]"; + else + ep_args += " [[sampler(" + convert_to_string(r.index) + ")]]"; + break; + case SPIRType::Image: + { + if (!ep_args.empty()) + ep_args += ", "; + + // Use Metal's native frame-buffer fetch API for subpass inputs. + const auto &basetype = get(var.basetype); + if (!type_is_msl_framebuffer_fetch(basetype)) + { + ep_args += image_type_glsl(type, var_id, false) + " " + r.name; + if (r.plane > 0) + ep_args += join(plane_name_suffix, r.plane); + + if (is_var_runtime_size_array(var)) + ep_args += "_ [[buffer(" + convert_to_string(r.index) + ")"; + else + ep_args += " [[texture(" + convert_to_string(r.index) + ")"; + + if (interlocked_resources.count(var_id)) + ep_args += ", raster_order_group(0)"; + ep_args += "]]"; + } + else + { + if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3)) + SPIRV_CROSS_THROW("Framebuffer fetch on Mac is not supported before MSL 2.3."); + ep_args += image_type_glsl(type, var_id, false) + " " + r.name; + ep_args += " [[color(" + convert_to_string(r.index) + ")]]"; + } + + // Emulate texture2D atomic operations + if (atomic_image_vars_emulated.count(var.self)) + { + auto &flags = ir.get_decoration_bitset(var.self); + const char *cv_flags = decoration_flags_signal_volatile(flags) ? "volatile " : ""; + ep_args += join(", ", cv_flags, "device atomic_", type_to_glsl(get(basetype.image.type), 0)); + ep_args += "* " + r.name + "_atomic"; + ep_args += " [[buffer(" + convert_to_string(r.secondary_index) + ")"; + if (interlocked_resources.count(var_id)) + ep_args += ", raster_order_group(0)"; + ep_args += "]]"; + } + break; + } + case SPIRType::AccelerationStructure: + { + if (is_var_runtime_size_array(var)) + { + add_spv_func_and_recompile(SPVFuncImplVariableDescriptor); + const auto &parent_type = get(type.parent_type); + if (!ep_args.empty()) + ep_args += ", "; + ep_args += "const device spvDescriptor<" + type_to_glsl(parent_type) + ">* " + + to_restrict(var_id, true) + r.name + "_"; + ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]"; + } + else + { + if (!ep_args.empty()) + ep_args += ", "; + ep_args += type_to_glsl(type, var_id) + " " + r.name; + ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]"; + } + break; + } + default: + if (!ep_args.empty()) + ep_args += ", "; + if (!type.pointer) + ep_args += get_type_address_space(get(var.basetype), var_id) + " " + + type_to_glsl(type, var_id) + "& " + r.name; + else + ep_args += type_to_glsl(type, var_id) + " " + r.name; + ep_args += " [[buffer(" + convert_to_string(r.index) + ")"; + if (interlocked_resources.count(var_id)) + ep_args += ", raster_order_group(0)"; + ep_args += "]]"; + break; + } + } +} + +// Returns a string containing a comma-delimited list of args for the entry point function +// This is the "classic" method of MSL 1 when we don't have argument buffer support. +string CompilerMSL::entry_point_args_classic(bool append_comma) +{ + string ep_args = entry_point_arg_stage_in(); + entry_point_args_discrete_descriptors(ep_args); + entry_point_args_builtin(ep_args); + + if (!ep_args.empty() && append_comma) + ep_args += ", "; + + return ep_args; +} + +void CompilerMSL::fix_up_shader_inputs_outputs() +{ + auto &entry_func = this->get(ir.default_entry_point); + + // Emit a guard to ensure we don't execute beyond the last vertex. 
+ // Vertex shaders shouldn't have the problems with barriers in non-uniform control flow that + // tessellation control shaders do, so early returns should be OK. We may need to revisit this + // if it ever becomes possible to use barriers from a vertex shader. + if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation) + { + entry_func.fixup_hooks_in.push_back([this]() { + statement("if (any(", to_expression(builtin_invocation_id_id), + " >= ", to_expression(builtin_stage_input_size_id), "))"); + statement(" return;"); + }); + } + + // Look for sampled images and buffer. Add hooks to set up the swizzle constants or array lengths. + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + auto &type = get_variable_data_type(var); + uint32_t var_id = var.self; + bool ssbo = has_decoration(type.self, DecorationBufferBlock); + + if (var.storage == StorageClassUniformConstant && !is_hidden_variable(var)) + { + if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type)) + { + entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() { + bool is_array_type = !type.array.empty(); + + uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet); + if (descriptor_set_is_argument_buffer(desc_set)) + { + statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id), + is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]), + ".spvSwizzleConstants", "[", + convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];"); + } + else + { + // If we have an array of images, we need to be able to index into it, so take a pointer instead. + statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id), + is_array_type ? " = &" : " = ", to_name(swizzle_buffer_id), "[", + convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];"); + } + }); + } + } + else if ((var.storage == StorageClassStorageBuffer || (var.storage == StorageClassUniform && ssbo)) && + !is_hidden_variable(var)) + { + if (buffer_requires_array_length(var.self)) + { + entry_func.fixup_hooks_in.push_back( + [this, &type, &var, var_id]() + { + bool is_array_type = !type.array.empty() && !is_var_runtime_size_array(var); + + uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet); + if (descriptor_set_is_argument_buffer(desc_set)) + { + statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id), + is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]), + ".spvBufferSizeConstants", "[", + convert_to_string(get_metal_resource_index(var, SPIRType::UInt)), "];"); + } + else + { + // If we have an array of images, we need to be able to index into it, so take a pointer instead. + statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id), + is_array_type ? 
" = &" : " = ", to_name(buffer_size_buffer_id), "[", + convert_to_string(get_metal_resource_index(var, type.basetype)), "];"); + } + }); + } + } + + if (!msl_options.argument_buffers && + msl_options.replace_recursive_inputs && type_contains_recursion(type) && + (var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant || + var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)) + { + recursive_inputs.insert(type.self); + entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() { + auto addr_space = get_argument_address_space(var); + auto var_name = to_name(var_id); + statement(addr_space, " auto& ", to_restrict(var_id, true), var_name, + " = *(", addr_space, " ", type_to_glsl(type), "*)", var_name, "_vp;"); + }); + } + }); + + // Builtin variables + ir.for_each_typed_id([this, &entry_func](uint32_t, SPIRVariable &var) { + uint32_t var_id = var.self; + BuiltIn bi_type = ir.meta[var_id].decoration.builtin_type; + + if (var.storage != StorageClassInput && var.storage != StorageClassOutput) + return; + if (!interface_variable_exists_in_entry_point(var.self)) + return; + + if (var.storage == StorageClassInput && is_builtin_variable(var) && active_input_builtins.get(bi_type)) + { + switch (bi_type) + { + case BuiltInSamplePosition: + entry_func.fixup_hooks_in.push_back([=]() { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = get_sample_position(", + to_expression(builtin_sample_id_id), ");"); + }); + break; + case BuiltInFragCoord: + if (is_sample_rate()) + { + entry_func.fixup_hooks_in.push_back([=]() { + statement(to_expression(var_id), ".xy += get_sample_position(", + to_expression(builtin_sample_id_id), ") - 0.5;"); + }); + } + break; + case BuiltInInvocationId: + // This is direct-mapped without multi-patch workgroups. + if (!is_tesc_shader() || !msl_options.multi_patch_workgroup) + break; + + entry_func.fixup_hooks_in.push_back([=]() { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", + to_expression(builtin_invocation_id_id), ".x % ", this->get_entry_point().output_vertices, + ";"); + }); + break; + case BuiltInPrimitiveId: + // This is natively supported by fragment and tessellation evaluation shaders. + // In tessellation control shaders, this is direct-mapped without multi-patch workgroups. + if (!is_tesc_shader() || !msl_options.multi_patch_workgroup) + break; + + entry_func.fixup_hooks_in.push_back([=]() { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = min(", + to_expression(builtin_invocation_id_id), ".x / ", this->get_entry_point().output_vertices, + ", spvIndirectParams[1] - 1);"); + }); + break; + case BuiltInPatchVertices: + if (is_tese_shader()) + { + if (msl_options.raw_buffer_tese_input) + { + entry_func.fixup_hooks_in.push_back( + [=]() { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", + get_entry_point().output_vertices, ";"); + }); + } + else + { + entry_func.fixup_hooks_in.push_back( + [=]() + { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", + to_expression(patch_stage_in_var_id), ".gl_in.size();"); + }); + } + } + else + { + entry_func.fixup_hooks_in.push_back([=]() { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = spvIndirectParams[0];"); + }); + } + break; + case BuiltInTessCoord: + if (get_entry_point().flags.get(ExecutionModeQuads)) + { + // The entry point will only have a float2 TessCoord variable. + // Pad to float3. 
+ entry_func.fixup_hooks_in.push_back([=]() { + auto name = builtin_to_glsl(BuiltInTessCoord, StorageClassInput); + statement("float3 " + name + " = float3(" + name + "In.x, " + name + "In.y, 0.0);"); + }); + } + + // Emit a fixup to account for the shifted domain. Don't do this for triangles; + // MoltenVK will just reverse the winding order instead. + if (msl_options.tess_domain_origin_lower_left && !is_tessellating_triangles()) + { + string tc = to_expression(var_id); + entry_func.fixup_hooks_in.push_back([=]() { statement(tc, ".y = 1.0 - ", tc, ".y;"); }); + } + break; + case BuiltInSubgroupId: + if (!msl_options.emulate_subgroups) + break; + // For subgroup emulation, this is the same as the local invocation index. + entry_func.fixup_hooks_in.push_back([=]() { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", + to_expression(builtin_local_invocation_index_id), ";"); + }); + break; + case BuiltInNumSubgroups: + if (!msl_options.emulate_subgroups) + break; + // For subgroup emulation, this is the same as the workgroup size. + entry_func.fixup_hooks_in.push_back([=]() { + auto &type = expression_type(builtin_workgroup_size_id); + string size_expr = to_expression(builtin_workgroup_size_id); + if (type.vecsize >= 3) + size_expr = join(size_expr, ".x * ", size_expr, ".y * ", size_expr, ".z"); + else if (type.vecsize == 2) + size_expr = join(size_expr, ".x * ", size_expr, ".y"); + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", size_expr, ";"); + }); + break; + case BuiltInSubgroupLocalInvocationId: + if (!msl_options.emulate_subgroups) + break; + // For subgroup emulation, assume subgroups of size 1. + entry_func.fixup_hooks_in.push_back( + [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;"); }); + break; + case BuiltInSubgroupSize: + if (msl_options.emulate_subgroups) + { + // For subgroup emulation, assume subgroups of size 1. + entry_func.fixup_hooks_in.push_back( + [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = 1;"); }); + } + else if (msl_options.fixed_subgroup_size != 0) + { + entry_func.fixup_hooks_in.push_back([=]() { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", + msl_options.fixed_subgroup_size, ";"); + }); + } + break; + case BuiltInSubgroupEqMask: + if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS."); + if (!msl_options.supports_msl_version(2, 1)) + SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1."); + entry_func.fixup_hooks_in.push_back([=]() { + if (msl_options.is_ios()) + { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", "uint4(1 << ", + to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));"); + } + else + { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", + to_expression(builtin_subgroup_invocation_id_id), " >= 32 ? 
uint4(0, (1 << (", + to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ", + to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));"); + } + }); + break; + case BuiltInSubgroupGeMask: + if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS."); + if (!msl_options.supports_msl_version(2, 1)) + SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1."); + if (msl_options.fixed_subgroup_size != 0) + add_spv_func_and_recompile(SPVFuncImplSubgroupBallot); + entry_func.fixup_hooks_in.push_back([=]() { + // Case where index < 32, size < 32: + // mask0 = bfi(0, 0xFFFFFFFF, index, size - index); + // mask1 = bfi(0, 0xFFFFFFFF, 0, 0); // Gives 0 + // Case where index < 32 but size >= 32: + // mask0 = bfi(0, 0xFFFFFFFF, index, 32 - index); + // mask1 = bfi(0, 0xFFFFFFFF, 0, size - 32); + // Case where index >= 32: + // mask0 = bfi(0, 0xFFFFFFFF, 32, 0); // Gives 0 + // mask1 = bfi(0, 0xFFFFFFFF, index - 32, size - index); + // This is expressed without branches to avoid divergent + // control flow--hence the complicated min/max expressions. + // This is further complicated by the fact that if you attempt + // to bfi/bfe out-of-bounds on Metal, undefined behavior is the + // result. + if (msl_options.fixed_subgroup_size > 32) + { + // Don't use the subgroup size variable with fixed subgroup sizes, + // since the variables could be defined in the wrong order. + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), + " = uint4(insert_bits(0u, 0xFFFFFFFF, min(", + to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(32 - (int)", + to_expression(builtin_subgroup_invocation_id_id), + ", 0)), insert_bits(0u, 0xFFFFFFFF," + " (uint)max((int)", + to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), ", + msl_options.fixed_subgroup_size, " - max(", + to_expression(builtin_subgroup_invocation_id_id), + ", 32u)), uint2(0));"); + } + else if (msl_options.fixed_subgroup_size != 0) + { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), + " = uint4(insert_bits(0u, 0xFFFFFFFF, ", + to_expression(builtin_subgroup_invocation_id_id), ", ", + msl_options.fixed_subgroup_size, " - ", + to_expression(builtin_subgroup_invocation_id_id), + "), uint3(0));"); + } + else if (msl_options.is_ios()) + { + // On iOS, the SIMD-group size will currently never exceed 32. 
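+ // Illustrative output of the path below (assuming the usual gl_Subgroup* builtin names):
+ //   uint4 gl_SubgroupGeMask = uint4(insert_bits(0u, 0xFFFFFFFF, gl_SubgroupInvocationID,
+ //                                               gl_SubgroupSize - gl_SubgroupInvocationID), uint3(0));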
+ statement(builtin_type_decl(bi_type), " ", to_expression(var_id), + " = uint4(insert_bits(0u, 0xFFFFFFFF, ", + to_expression(builtin_subgroup_invocation_id_id), ", ", + to_expression(builtin_subgroup_size_id), " - ", + to_expression(builtin_subgroup_invocation_id_id), "), uint3(0));"); + } + else + { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), + " = uint4(insert_bits(0u, 0xFFFFFFFF, min(", + to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)", + to_expression(builtin_subgroup_size_id), ", 32) - (int)", + to_expression(builtin_subgroup_invocation_id_id), + ", 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)", + to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)", + to_expression(builtin_subgroup_size_id), " - (int)max(", + to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));"); + } + }); + break; + case BuiltInSubgroupGtMask: + if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS."); + if (!msl_options.supports_msl_version(2, 1)) + SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1."); + add_spv_func_and_recompile(SPVFuncImplSubgroupBallot); + entry_func.fixup_hooks_in.push_back([=]() { + // The same logic applies here, except now the index is one + // more than the subgroup invocation ID. + if (msl_options.fixed_subgroup_size > 32) + { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), + " = uint4(insert_bits(0u, 0xFFFFFFFF, min(", + to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(32 - (int)", + to_expression(builtin_subgroup_invocation_id_id), + " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)", + to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), ", + msl_options.fixed_subgroup_size, " - max(", + to_expression(builtin_subgroup_invocation_id_id), + " + 1, 32u)), uint2(0));"); + } + else if (msl_options.fixed_subgroup_size != 0) + { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), + " = uint4(insert_bits(0u, 0xFFFFFFFF, ", + to_expression(builtin_subgroup_invocation_id_id), " + 1, ", + msl_options.fixed_subgroup_size, " - ", + to_expression(builtin_subgroup_invocation_id_id), + " - 1), uint3(0));"); + } + else if (msl_options.is_ios()) + { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), + " = uint4(insert_bits(0u, 0xFFFFFFFF, ", + to_expression(builtin_subgroup_invocation_id_id), " + 1, ", + to_expression(builtin_subgroup_size_id), " - ", + to_expression(builtin_subgroup_invocation_id_id), " - 1), uint3(0));"); + } + else + { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), + " = uint4(insert_bits(0u, 0xFFFFFFFF, min(", + to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)", + to_expression(builtin_subgroup_size_id), ", 32) - (int)", + to_expression(builtin_subgroup_invocation_id_id), + " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)", + to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)", + to_expression(builtin_subgroup_size_id), " - (int)max(", + to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));"); + } + }); + break; + case BuiltInSubgroupLeMask: + if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS."); + if (!msl_options.supports_msl_version(2, 1)) + 
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1."); + add_spv_func_and_recompile(SPVFuncImplSubgroupBallot); + entry_func.fixup_hooks_in.push_back([=]() { + if (msl_options.is_ios()) + { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), + " = uint4(extract_bits(0xFFFFFFFF, 0, ", + to_expression(builtin_subgroup_invocation_id_id), " + 1), uint3(0));"); + } + else + { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), + " = uint4(extract_bits(0xFFFFFFFF, 0, min(", + to_expression(builtin_subgroup_invocation_id_id), + " + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)", + to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0)), uint2(0));"); + } + }); + break; + case BuiltInSubgroupLtMask: + if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS."); + if (!msl_options.supports_msl_version(2, 1)) + SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1."); + add_spv_func_and_recompile(SPVFuncImplSubgroupBallot); + entry_func.fixup_hooks_in.push_back([=]() { + if (msl_options.is_ios()) + { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), + " = uint4(extract_bits(0xFFFFFFFF, 0, ", + to_expression(builtin_subgroup_invocation_id_id), "), uint3(0));"); + } + else + { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), + " = uint4(extract_bits(0xFFFFFFFF, 0, min(", + to_expression(builtin_subgroup_invocation_id_id), + ", 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)", + to_expression(builtin_subgroup_invocation_id_id), " - 32, 0)), uint2(0));"); + } + }); + break; + case BuiltInViewIndex: + if (!msl_options.multiview) + { + // According to the Vulkan spec, when not running under a multiview + // render pass, ViewIndex is 0. + entry_func.fixup_hooks_in.push_back([=]() { + statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;"); + }); + } + else if (msl_options.view_index_from_device_index) + { + // In this case, we take the view index from that of the device we're running on. + entry_func.fixup_hooks_in.push_back([=]() { + statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", + msl_options.device_index, ";"); + }); + // We actually don't want to set the render_target_array_index here. + // Since every physical device is rendering a different view, + // there's no need for layered rendering here. + } + else if (!msl_options.multiview_layered_rendering) + { + // In this case, the views are rendered one at a time. The view index, then, + // is just the first part of the "view mask". + entry_func.fixup_hooks_in.push_back([=]() { + statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", + to_expression(view_mask_buffer_id), "[0];"); + }); + } + else if (get_execution_model() == ExecutionModelFragment) + { + // Because we adjusted the view index in the vertex shader, we have to + // adjust it back here. + entry_func.fixup_hooks_in.push_back([=]() { + statement(to_expression(var_id), " += ", to_expression(view_mask_buffer_id), "[0];"); + }); + } + else if (get_execution_model() == ExecutionModelVertex) + { + // Metal provides no special support for multiview, so we smuggle + // the view index in the instance index. 
+ entry_func.fixup_hooks_in.push_back([=]() { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", + to_expression(view_mask_buffer_id), "[0] + (", to_expression(builtin_instance_idx_id), + " - ", to_expression(builtin_base_instance_id), ") % ", + to_expression(view_mask_buffer_id), "[1];"); + statement(to_expression(builtin_instance_idx_id), " = (", + to_expression(builtin_instance_idx_id), " - ", + to_expression(builtin_base_instance_id), ") / ", to_expression(view_mask_buffer_id), + "[1] + ", to_expression(builtin_base_instance_id), ";"); + }); + // In addition to setting the variable itself, we also need to + // set the render_target_array_index with it on output. We have to + // offset this by the base view index, because Metal isn't in on + // our little game here. + entry_func.fixup_hooks_out.push_back([=]() { + statement(to_expression(builtin_layer_id), " = ", to_expression(var_id), " - ", + to_expression(view_mask_buffer_id), "[0];"); + }); + } + break; + case BuiltInDeviceIndex: + // Metal pipelines belong to the devices which create them, so we'll + // need to create a MTLPipelineState for every MTLDevice in a grouped + // VkDevice. We can assume, then, that the device index is constant. + entry_func.fixup_hooks_in.push_back([=]() { + statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", + msl_options.device_index, ";"); + }); + break; + case BuiltInWorkgroupId: + if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInWorkgroupId)) + break; + + // The vkCmdDispatchBase() command lets the client set the base value + // of WorkgroupId. Metal has no direct equivalent; we must make this + // adjustment ourselves. + entry_func.fixup_hooks_in.push_back([=]() { + statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), ";"); + }); + break; + case BuiltInGlobalInvocationId: + if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInGlobalInvocationId)) + break; + + // GlobalInvocationId is defined as LocalInvocationId + WorkgroupId * WorkgroupSize. + // This needs to be adjusted too. + entry_func.fixup_hooks_in.push_back([=]() { + auto &execution = this->get_entry_point(); + uint32_t workgroup_size_id = execution.workgroup_size.constant; + if (workgroup_size_id) + statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), + " * ", to_expression(workgroup_size_id), ";"); + else + statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), + " * uint3(", execution.workgroup_size.x, ", ", execution.workgroup_size.y, ", ", + execution.workgroup_size.z, ");"); + }); + break; + case BuiltInVertexId: + case BuiltInVertexIndex: + // This is direct-mapped normally. 
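+ // When the vertex shader runs as a compute kernel for tessellation, the index is rebuilt from
+ // the thread ID instead, roughly (names illustrative):
+ //   uint gl_VertexIndex = gl_GlobalInvocationID.x + spvDispatchBase.x;                // non-indexed
+ //   uint gl_VertexIndex = spvIndices[gl_GlobalInvocationID.x] + spvDispatchBase.x;    // indexed draws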
+ if (!msl_options.vertex_for_tessellation) + break; + + entry_func.fixup_hooks_in.push_back([=]() { + builtin_declaration = true; + switch (msl_options.vertex_index_type) + { + case Options::IndexType::None: + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", + to_expression(builtin_invocation_id_id), ".x + ", + to_expression(builtin_dispatch_base_id), ".x;"); + break; + case Options::IndexType::UInt16: + case Options::IndexType::UInt32: + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", index_buffer_var_name, + "[", to_expression(builtin_invocation_id_id), ".x] + ", + to_expression(builtin_dispatch_base_id), ".x;"); + break; + } + builtin_declaration = false; + }); + break; + case BuiltInBaseVertex: + // This is direct-mapped normally. + if (!msl_options.vertex_for_tessellation) + break; + + entry_func.fixup_hooks_in.push_back([=]() { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", + to_expression(builtin_dispatch_base_id), ".x;"); + }); + break; + case BuiltInInstanceId: + case BuiltInInstanceIndex: + // This is direct-mapped normally. + if (!msl_options.vertex_for_tessellation) + break; + + entry_func.fixup_hooks_in.push_back([=]() { + builtin_declaration = true; + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", + to_expression(builtin_invocation_id_id), ".y + ", to_expression(builtin_dispatch_base_id), + ".y;"); + builtin_declaration = false; + }); + break; + case BuiltInBaseInstance: + // This is direct-mapped normally. + if (!msl_options.vertex_for_tessellation) + break; + + entry_func.fixup_hooks_in.push_back([=]() { + statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", + to_expression(builtin_dispatch_base_id), ".y;"); + }); + break; + default: + break; + } + } + else if (var.storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment && + is_builtin_variable(var) && active_output_builtins.get(bi_type)) + { + switch (bi_type) + { + case BuiltInSampleMask: + if (has_additional_fixed_sample_mask()) + { + // If the additional fixed sample mask was set, we need to adjust the sample_mask + // output to reflect that. If the shader outputs the sample_mask itself too, we need + // to AND the two masks to get the final one. + string op_str = does_shader_write_sample_mask ? " &= " : " = "; + entry_func.fixup_hooks_out.push_back([=]() { + statement(to_expression(builtin_sample_mask_id), op_str, additional_fixed_sample_mask_str(), ";"); + }); + } + break; + case BuiltInFragDepth: + if (msl_options.input_attachment_is_ds_attachment && !writes_to_depth) + { + entry_func.fixup_hooks_out.push_back([=]() { + statement(to_expression(builtin_frag_depth_id), " = ", to_expression(builtin_frag_coord_id), ".z;"); + }); + } + break; + default: + break; + } + } + }); +} + +// Returns the Metal index of the resource of the specified type as used by the specified variable. +uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype, uint32_t plane) +{ + auto &execution = get_entry_point(); + auto &var_dec = ir.meta[var.self].decoration; + auto &var_type = get(var.basetype); + uint32_t var_desc_set = (var.storage == StorageClassPushConstant) ? kPushConstDescSet : var_dec.set; + uint32_t var_binding = (var.storage == StorageClassPushConstant) ? kPushConstBinding : var_dec.binding; + + // If a matching binding has been specified, find and use it. 
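+ // Client-side sketch of the remapping consumed here (hedged illustration; fields follow
+ // MSLResourceBinding as exposed by the public MSL API):
+ //   MSLResourceBinding b;
+ //   b.stage = spv::ExecutionModelFragment;
+ //   b.desc_set = 0;
+ //   b.binding = 1;
+ //   b.msl_buffer = b.msl_texture = b.msl_sampler = 2;
+ //   compiler.add_msl_resource_binding(b);
+ // A matching entry short-circuits the automatic index allocation further down.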
+ auto itr = resource_bindings.find({ execution.model, var_desc_set, var_binding }); + + // Atomic helper buffers for image atomics need to use secondary bindings as well. + bool use_secondary_binding = (var_type.basetype == SPIRType::SampledImage && basetype == SPIRType::Sampler) || + basetype == SPIRType::AtomicCounter; + + auto resource_decoration = + use_secondary_binding ? SPIRVCrossDecorationResourceIndexSecondary : SPIRVCrossDecorationResourceIndexPrimary; + + if (plane == 1) + resource_decoration = SPIRVCrossDecorationResourceIndexTertiary; + if (plane == 2) + resource_decoration = SPIRVCrossDecorationResourceIndexQuaternary; + + if (itr != end(resource_bindings)) + { + auto &remap = itr->second; + remap.second = true; + switch (basetype) + { + case SPIRType::Image: + set_extended_decoration(var.self, resource_decoration, remap.first.msl_texture + plane); + return remap.first.msl_texture + plane; + case SPIRType::Sampler: + set_extended_decoration(var.self, resource_decoration, remap.first.msl_sampler); + return remap.first.msl_sampler; + default: + set_extended_decoration(var.self, resource_decoration, remap.first.msl_buffer); + return remap.first.msl_buffer; + } + } + + // If we have already allocated an index, keep using it. + if (has_extended_decoration(var.self, resource_decoration)) + return get_extended_decoration(var.self, resource_decoration); + + auto &type = get(var.basetype); + + if (type_is_msl_framebuffer_fetch(type)) + { + // Frame-buffer fetch gets its fallback resource index from the input attachment index, + // which is then treated as color index. + return get_decoration(var.self, DecorationInputAttachmentIndex); + } + else if (msl_options.enable_decoration_binding) + { + // Allow user to enable decoration binding. + // If there is no explicit mapping of bindings to MSL, use the declared binding as a fallback. + if (has_decoration(var.self, DecorationBinding)) + { + var_binding = get_decoration(var.self, DecorationBinding); + // Avoid emitting sentinel bindings. + if (var_binding < 0x80000000u) + return var_binding; + } + } + + // If we did not explicitly remap, allocate bindings on demand. + // We cannot reliably use Binding decorations since SPIR-V and MSL's binding models are very different. + + bool allocate_argument_buffer_ids = false; + + if (var.storage != StorageClassPushConstant) + allocate_argument_buffer_ids = descriptor_set_is_argument_buffer(var_desc_set); + + uint32_t binding_stride = 1; + for (uint32_t i = 0; i < uint32_t(type.array.size()); i++) + binding_stride *= to_array_size_literal(type, i); + + // If a binding has not been specified, revert to incrementing resource indices. + uint32_t resource_index; + + if (allocate_argument_buffer_ids) + { + // Allocate from a flat ID binding space. + resource_index = next_metal_resource_ids[var_desc_set]; + next_metal_resource_ids[var_desc_set] += binding_stride; + } + else + { + if (is_var_runtime_size_array(var)) + { + basetype = SPIRType::Struct; + binding_stride = 1; + } + // Allocate from plain bindings which are allocated per resource type. 
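+ // That is, textures, samplers and buffers each advance an independent counter, which becomes
+ // the index used in [[texture(n)]], [[sampler(n)]] and [[buffer(n)]] respectively (the default
+ // case below routes everything else, e.g. acceleration structures, to the buffer counter).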
+ switch (basetype) + { + case SPIRType::Image: + resource_index = next_metal_resource_index_texture; + next_metal_resource_index_texture += binding_stride; + break; + case SPIRType::Sampler: + resource_index = next_metal_resource_index_sampler; + next_metal_resource_index_sampler += binding_stride; + break; + default: + resource_index = next_metal_resource_index_buffer; + next_metal_resource_index_buffer += binding_stride; + break; + } + } + + set_extended_decoration(var.self, resource_decoration, resource_index); + return resource_index; +} + +bool CompilerMSL::type_is_msl_framebuffer_fetch(const SPIRType &type) const +{ + return type.basetype == SPIRType::Image && type.image.dim == DimSubpassData && + msl_options.use_framebuffer_fetch_subpasses; +} + +const char *CompilerMSL::descriptor_address_space(uint32_t id, StorageClass storage, const char *plain_address_space) const +{ + if (msl_options.argument_buffers) + { + bool storage_class_is_descriptor = storage == StorageClassUniform || + storage == StorageClassStorageBuffer || + storage == StorageClassUniformConstant; + + uint32_t desc_set = get_decoration(id, DecorationDescriptorSet); + if (storage_class_is_descriptor && descriptor_set_is_argument_buffer(desc_set)) + { + // An awkward case where we need to emit *more* address space declarations (yay!). + // An example is where we pass down an array of buffer pointers to leaf functions. + // It's a constant array containing pointers to constants. + // The pointer array is always constant however. E.g. + // device SSBO * constant (&array)[N]. + // const device SSBO * constant (&array)[N]. + // constant SSBO * constant (&array)[N]. + // However, this only matters for argument buffers, since for MSL 1.0 style codegen, + // we emit the buffer array on stack instead, and that seems to work just fine apparently. + + // If the argument was marked as being in device address space, any pointer to member would + // be const device, not constant. + if (argument_buffer_device_storage_mask & (1u << desc_set)) + return "const device"; + else + return "constant"; + } + } + + return plain_address_space; +} + +string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg) +{ + auto &var = get(arg.id); + auto &type = get_variable_data_type(var); + auto &var_type = get(arg.type); + StorageClass type_storage = var_type.storage; + + // If we need to modify the name of the variable, make sure we use the original variable. + // Our alias is just a shadow variable. + uint32_t name_id = var.self; + if (arg.alias_global_variable && var.basevariable) + name_id = var.basevariable; + + bool constref = !arg.alias_global_variable && is_pointer(var_type) && arg.write_count == 0; + // Framebuffer fetch is plain value, const looks out of place, but it is not wrong. + if (type_is_msl_framebuffer_fetch(type)) + constref = false; + else if (type_storage == StorageClassUniformConstant) + constref = true; + + bool type_is_image = type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage || + type.basetype == SPIRType::Sampler; + bool type_is_tlas = type.basetype == SPIRType::AccelerationStructure; + + // For opaque types we handle const later due to descriptor address spaces. + const char *cv_qualifier = (constref && !type_is_image) ? 
"const " : ""; + string decl; + + // If this is a combined image-sampler for a 2D image with floating-point type, + // we emitted the 'spvDynamicImageSampler' type, and this is *not* an alias parameter + // for a global, then we need to emit a "dynamic" combined image-sampler. + // Unfortunately, this is necessary to properly support passing around + // combined image-samplers with Y'CbCr conversions on them. + bool is_dynamic_img_sampler = !arg.alias_global_variable && type.basetype == SPIRType::SampledImage && + type.image.dim == Dim2D && type_is_floating_point(get(type.image.type)) && + spv_function_implementations.count(SPVFuncImplDynamicImageSampler); + + // Allow Metal to use the array template to make arrays a value type + string address_space = get_argument_address_space(var); + bool builtin = has_decoration(var.self, DecorationBuiltIn); + auto builtin_type = BuiltIn(get_decoration(arg.id, DecorationBuiltIn)); + + if (var.basevariable && (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id)) + decl = join(cv_qualifier, type_to_glsl(type, arg.id)); + else if (builtin) + { + // Only use templated array for Clip/Cull distance when feasible. + // In other scenarios, we need need to override array length for tess levels (if used as outputs), + // or we need to emit the expected type for builtins (uint vs int). + auto storage = get(var.basetype).storage; + + if (storage == StorageClassInput && + (builtin_type == BuiltInTessLevelInner || builtin_type == BuiltInTessLevelOuter)) + { + is_using_builtin_array = false; + } + else if (builtin_type != BuiltInClipDistance && builtin_type != BuiltInCullDistance) + { + is_using_builtin_array = true; + } + + if (storage == StorageClassOutput && variable_storage_requires_stage_io(storage) && + !is_stage_output_builtin_masked(builtin_type)) + is_using_builtin_array = true; + + if (is_using_builtin_array) + decl = join(cv_qualifier, builtin_type_decl(builtin_type, arg.id)); + else + decl = join(cv_qualifier, type_to_glsl(type, arg.id)); + } + else if (is_var_runtime_size_array(var)) + { + const auto *parent_type = &get(type.parent_type); + auto type_name = type_to_glsl(*parent_type, arg.id); + if (type.basetype == SPIRType::AccelerationStructure) + decl = join("spvDescriptorArray<", type_name, ">"); + else if (type_is_image) + decl = join("spvDescriptorArray<", cv_qualifier, type_name, ">"); + else + decl = join("spvDescriptorArray<", address_space, " ", type_name, "*>"); + address_space = "const"; + } + else if ((type_storage == StorageClassUniform || type_storage == StorageClassStorageBuffer) && is_array(type)) + { + is_using_builtin_array = true; + decl += join(cv_qualifier, type_to_glsl(type, arg.id), "*"); + } + else if (is_dynamic_img_sampler) + { + decl = join(cv_qualifier, "spvDynamicImageSampler<", type_to_glsl(get(type.image.type)), ">"); + // Mark the variable so that we can handle passing it to another function. + set_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler); + } + else + { + // The type is a pointer type we need to emit cv_qualifier late. + if (is_pointer(type)) + { + decl = type_to_glsl(type, arg.id); + if (*cv_qualifier != '\0') + decl += join(" ", cv_qualifier); + } + else + { + decl = join(cv_qualifier, type_to_glsl(type, arg.id)); + } + } + + if (!builtin && !is_pointer(var_type) && + (type_storage == StorageClassFunction || type_storage == StorageClassGeneric)) + { + // If the argument is a pure value and not an opaque type, we will pass by value. 
+ if (msl_options.force_native_arrays && is_array(type)) + { + // We are receiving an array by value. This is problematic. + // We cannot be sure of the target address space since we are supposed to receive a copy, + // but this is not possible with MSL without some extra work. + // We will have to assume we're getting a reference in thread address space. + // If we happen to get a reference in constant address space, the caller must emit a copy and pass that. + // Thread const therefore becomes the only logical choice, since we cannot "create" a constant array from + // non-constant arrays, but we can create thread const from constant. + decl = string("thread const ") + decl; + decl += " (&"; + const char *restrict_kw = to_restrict(name_id, true); + if (*restrict_kw) + { + decl += " "; + decl += restrict_kw; + } + decl += to_expression(name_id); + decl += ")"; + decl += type_to_array_glsl(type, name_id); + } + else + { + if (!address_space.empty()) + decl = join(address_space, " ", decl); + decl += " "; + decl += to_expression(name_id); + } + } + else if (is_array(type) && !type_is_image) + { + // Arrays of opaque types are special cased. + if (!address_space.empty()) + decl = join(address_space, " ", decl); + + // spvDescriptorArray absorbs the address space inside the template. + if (!is_var_runtime_size_array(var)) + { + const char *argument_buffer_space = descriptor_address_space(name_id, type_storage, nullptr); + if (argument_buffer_space) + { + decl += " "; + decl += argument_buffer_space; + } + } + + // Special case, need to override the array size here if we're using tess level as an argument. + if (is_tesc_shader() && builtin && + (builtin_type == BuiltInTessLevelInner || builtin_type == BuiltInTessLevelOuter)) + { + uint32_t array_size = get_physical_tess_level_array_size(builtin_type); + if (array_size == 1) + { + decl += " &"; + decl += to_expression(name_id); + } + else + { + decl += " (&"; + decl += to_expression(name_id); + decl += ")"; + decl += join("[", array_size, "]"); + } + } + else if (is_var_runtime_size_array(var)) + { + decl += " " + to_expression(name_id); + } + else + { + auto array_size_decl = type_to_array_glsl(type, name_id); + if (array_size_decl.empty()) + decl += "& "; + else + decl += " (&"; + + const char *restrict_kw = to_restrict(name_id, true); + if (*restrict_kw) + { + decl += " "; + decl += restrict_kw; + } + decl += to_expression(name_id); + + if (!array_size_decl.empty()) + { + decl += ")"; + decl += array_size_decl; + } + } + } + else if (!type_is_image && !type_is_tlas && + (!pull_model_inputs.count(var.basevariable) || type.basetype == SPIRType::Struct)) + { + // If this is going to be a reference to a variable pointer, the address space + // for the reference has to go before the '&', but after the '*'. + if (!address_space.empty()) + { + if (is_pointer(type)) + { + if (*cv_qualifier == '\0') + decl += ' '; + decl += join(address_space, " "); + } + else + decl = join(address_space, " ", decl); + } + decl += "&"; + decl += " "; + decl += to_restrict(name_id, true); + decl += to_expression(name_id); + } + else if (type_is_image || type_is_tlas) + { + if (is_var_runtime_size_array(var)) + { + decl = address_space + " " + decl + " " + to_expression(name_id); + } + else if (type.array.empty()) + { + // For non-arrayed types we can just pass opaque descriptors by value. + // This fixes problems if descriptors are passed by value from argument buffers and plain descriptors + // in same shader. 
+ // There is no address space we can actually use, but value will work. + // This will break if applications attempt to pass down descriptor arrays as arguments, but + // fortunately that is extremely unlikely ... + decl += " "; + decl += to_expression(name_id); + } + else + { + const char *img_address_space = descriptor_address_space(name_id, type_storage, "thread const"); + decl = join(img_address_space, " ", decl); + decl += "& "; + decl += to_expression(name_id); + } + } + else + { + if (!address_space.empty()) + decl = join(address_space, " ", decl); + decl += " "; + decl += to_expression(name_id); + } + + // Emulate texture2D atomic operations + auto *backing_var = maybe_get_backing_variable(name_id); + if (backing_var && atomic_image_vars_emulated.count(backing_var->self)) + { + auto &flags = ir.get_decoration_bitset(backing_var->self); + const char *cv_flags = decoration_flags_signal_volatile(flags) ? "volatile " : ""; + decl += join(", ", cv_flags, "device atomic_", type_to_glsl(get(var_type.image.type), 0)); + decl += "* " + to_expression(name_id) + "_atomic"; + } + + is_using_builtin_array = false; + + return decl; +} + +// If we're currently in the entry point function, and the object +// has a qualified name, use it, otherwise use the standard name. +string CompilerMSL::to_name(uint32_t id, bool allow_alias) const +{ + if (current_function && (current_function->self == ir.default_entry_point)) + { + auto *m = ir.find_meta(id); + if (m && !m->decoration.qualified_alias_explicit_override && !m->decoration.qualified_alias.empty()) + return m->decoration.qualified_alias; + } + return Compiler::to_name(id, allow_alias); +} + +// Appends the name of the member to the variable qualifier string, except for Builtins. +string CompilerMSL::append_member_name(const string &qualifier, const SPIRType &type, uint32_t index) +{ + // Don't qualify Builtin names because they are unique and are treated as such when building expressions + BuiltIn builtin = BuiltInMax; + if (is_member_builtin(type, index, &builtin)) + return builtin_to_glsl(builtin, type.storage); + + // Strip any underscore prefix from member name + string mbr_name = to_member_name(type, index); + size_t startPos = mbr_name.find_first_not_of("_"); + mbr_name = (startPos != string::npos) ? mbr_name.substr(startPos) : ""; + return join(qualifier, "_", mbr_name); +} + +// Ensures that the specified name is permanently usable by prepending a prefix +// if the first chars are _ and a digit, which indicate a transient name. +string CompilerMSL::ensure_valid_name(string name, string pfx) +{ + return (name.size() >= 2 && name[0] == '_' && isdigit(name[1])) ? 
(pfx + name) : name; +} + +const std::unordered_set &CompilerMSL::get_reserved_keyword_set() +{ + static const unordered_set keywords = { + "kernel", + "vertex", + "fragment", + "compute", + "constant", + "device", + "bias", + "level", + "gradient2d", + "gradientcube", + "gradient3d", + "min_lod_clamp", + "assert", + "VARIABLE_TRACEPOINT", + "STATIC_DATA_TRACEPOINT", + "STATIC_DATA_TRACEPOINT_V", + "METAL_ALIGN", + "METAL_ASM", + "METAL_CONST", + "METAL_DEPRECATED", + "METAL_ENABLE_IF", + "METAL_FUNC", + "METAL_INTERNAL", + "METAL_NON_NULL_RETURN", + "METAL_NORETURN", + "METAL_NOTHROW", + "METAL_PURE", + "METAL_UNAVAILABLE", + "METAL_IMPLICIT", + "METAL_EXPLICIT", + "METAL_CONST_ARG", + "METAL_ARG_UNIFORM", + "METAL_ZERO_ARG", + "METAL_VALID_LOD_ARG", + "METAL_VALID_LEVEL_ARG", + "METAL_VALID_STORE_ORDER", + "METAL_VALID_LOAD_ORDER", + "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER", + "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS", + "METAL_VALID_RENDER_TARGET", + "is_function_constant_defined", + "CHAR_BIT", + "SCHAR_MAX", + "SCHAR_MIN", + "UCHAR_MAX", + "CHAR_MAX", + "CHAR_MIN", + "USHRT_MAX", + "SHRT_MAX", + "SHRT_MIN", + "UINT_MAX", + "INT_MAX", + "INT_MIN", + "FLT_DIG", + "FLT_MANT_DIG", + "FLT_MAX_10_EXP", + "FLT_MAX_EXP", + "FLT_MIN_10_EXP", + "FLT_MIN_EXP", + "FLT_RADIX", + "FLT_MAX", + "FLT_MIN", + "FLT_EPSILON", + "FP_ILOGB0", + "FP_ILOGBNAN", + "MAXFLOAT", + "HUGE_VALF", + "INFINITY", + "NAN", + "M_E_F", + "M_LOG2E_F", + "M_LOG10E_F", + "M_LN2_F", + "M_LN10_F", + "M_PI_F", + "M_PI_2_F", + "M_PI_4_F", + "M_1_PI_F", + "M_2_PI_F", + "M_2_SQRTPI_F", + "M_SQRT2_F", + "M_SQRT1_2_F", + "HALF_DIG", + "HALF_MANT_DIG", + "HALF_MAX_10_EXP", + "HALF_MAX_EXP", + "HALF_MIN_10_EXP", + "HALF_MIN_EXP", + "HALF_RADIX", + "HALF_MAX", + "HALF_MIN", + "HALF_EPSILON", + "MAXHALF", + "HUGE_VALH", + "M_E_H", + "M_LOG2E_H", + "M_LOG10E_H", + "M_LN2_H", + "M_LN10_H", + "M_PI_H", + "M_PI_2_H", + "M_PI_4_H", + "M_1_PI_H", + "M_2_PI_H", + "M_2_SQRTPI_H", + "M_SQRT2_H", + "M_SQRT1_2_H", + "DBL_DIG", + "DBL_MANT_DIG", + "DBL_MAX_10_EXP", + "DBL_MAX_EXP", + "DBL_MIN_10_EXP", + "DBL_MIN_EXP", + "DBL_RADIX", + "DBL_MAX", + "DBL_MIN", + "DBL_EPSILON", + "HUGE_VAL", + "M_E", + "M_LOG2E", + "M_LOG10E", + "M_LN2", + "M_LN10", + "M_PI", + "M_PI_2", + "M_PI_4", + "M_1_PI", + "M_2_PI", + "M_2_SQRTPI", + "M_SQRT2", + "M_SQRT1_2", + "quad_broadcast", + "thread", + "threadgroup", + }; + + return keywords; +} + +const std::unordered_set &CompilerMSL::get_illegal_func_names() +{ + static const unordered_set illegal_func_names = { + "main", + "saturate", + "assert", + "fmin3", + "fmax3", + "divide", + "median3", + "VARIABLE_TRACEPOINT", + "STATIC_DATA_TRACEPOINT", + "STATIC_DATA_TRACEPOINT_V", + "METAL_ALIGN", + "METAL_ASM", + "METAL_CONST", + "METAL_DEPRECATED", + "METAL_ENABLE_IF", + "METAL_FUNC", + "METAL_INTERNAL", + "METAL_NON_NULL_RETURN", + "METAL_NORETURN", + "METAL_NOTHROW", + "METAL_PURE", + "METAL_UNAVAILABLE", + "METAL_IMPLICIT", + "METAL_EXPLICIT", + "METAL_CONST_ARG", + "METAL_ARG_UNIFORM", + "METAL_ZERO_ARG", + "METAL_VALID_LOD_ARG", + "METAL_VALID_LEVEL_ARG", + "METAL_VALID_STORE_ORDER", + "METAL_VALID_LOAD_ORDER", + "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER", + "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS", + "METAL_VALID_RENDER_TARGET", + "is_function_constant_defined", + "CHAR_BIT", + "SCHAR_MAX", + "SCHAR_MIN", + "UCHAR_MAX", + "CHAR_MAX", + "CHAR_MIN", + "USHRT_MAX", + "SHRT_MAX", + "SHRT_MIN", + "UINT_MAX", + "INT_MAX", + "INT_MIN", + "FLT_DIG", + "FLT_MANT_DIG", + "FLT_MAX_10_EXP", + "FLT_MAX_EXP", + 
"FLT_MIN_10_EXP", + "FLT_MIN_EXP", + "FLT_RADIX", + "FLT_MAX", + "FLT_MIN", + "FLT_EPSILON", + "FP_ILOGB0", + "FP_ILOGBNAN", + "MAXFLOAT", + "HUGE_VALF", + "INFINITY", + "NAN", + "M_E_F", + "M_LOG2E_F", + "M_LOG10E_F", + "M_LN2_F", + "M_LN10_F", + "M_PI_F", + "M_PI_2_F", + "M_PI_4_F", + "M_1_PI_F", + "M_2_PI_F", + "M_2_SQRTPI_F", + "M_SQRT2_F", + "M_SQRT1_2_F", + "HALF_DIG", + "HALF_MANT_DIG", + "HALF_MAX_10_EXP", + "HALF_MAX_EXP", + "HALF_MIN_10_EXP", + "HALF_MIN_EXP", + "HALF_RADIX", + "HALF_MAX", + "HALF_MIN", + "HALF_EPSILON", + "MAXHALF", + "HUGE_VALH", + "M_E_H", + "M_LOG2E_H", + "M_LOG10E_H", + "M_LN2_H", + "M_LN10_H", + "M_PI_H", + "M_PI_2_H", + "M_PI_4_H", + "M_1_PI_H", + "M_2_PI_H", + "M_2_SQRTPI_H", + "M_SQRT2_H", + "M_SQRT1_2_H", + "DBL_DIG", + "DBL_MANT_DIG", + "DBL_MAX_10_EXP", + "DBL_MAX_EXP", + "DBL_MIN_10_EXP", + "DBL_MIN_EXP", + "DBL_RADIX", + "DBL_MAX", + "DBL_MIN", + "DBL_EPSILON", + "HUGE_VAL", + "M_E", + "M_LOG2E", + "M_LOG10E", + "M_LN2", + "M_LN10", + "M_PI", + "M_PI_2", + "M_PI_4", + "M_1_PI", + "M_2_PI", + "M_2_SQRTPI", + "M_SQRT2", + "M_SQRT1_2", + }; + + return illegal_func_names; +} + +// Replace all names that match MSL keywords or Metal Standard Library functions. +void CompilerMSL::replace_illegal_names() +{ + // FIXME: MSL and GLSL are doing two different things here. + // Agree on convention and remove this override. + auto &keywords = get_reserved_keyword_set(); + auto &illegal_func_names = get_illegal_func_names(); + + ir.for_each_typed_id([&](uint32_t self, SPIRVariable &) { + auto *meta = ir.find_meta(self); + if (!meta) + return; + + auto &dec = meta->decoration; + if (keywords.find(dec.alias) != end(keywords)) + dec.alias += "0"; + }); + + ir.for_each_typed_id([&](uint32_t self, SPIRFunction &) { + auto *meta = ir.find_meta(self); + if (!meta) + return; + + auto &dec = meta->decoration; + if (illegal_func_names.find(dec.alias) != end(illegal_func_names)) + dec.alias += "0"; + }); + + ir.for_each_typed_id([&](uint32_t self, SPIRType &) { + auto *meta = ir.find_meta(self); + if (!meta) + return; + + for (auto &mbr_dec : meta->members) + if (keywords.find(mbr_dec.alias) != end(keywords)) + mbr_dec.alias += "0"; + }); + + CompilerGLSL::replace_illegal_names(); +} + +void CompilerMSL::replace_illegal_entry_point_names() +{ + auto &illegal_func_names = get_illegal_func_names(); + + // It is important to this before we fixup identifiers, + // since if ep_name is reserved, we will need to fix that up, + // and then copy alias back into entry.name after the fixup. + for (auto &entry : ir.entry_points) + { + // Change both the entry point name and the alias, to keep them synced. + string &ep_name = entry.second.name; + if (illegal_func_names.find(ep_name) != end(illegal_func_names)) + ep_name += "0"; + + ir.meta[entry.first].decoration.alias = ep_name; + } +} + +void CompilerMSL::sync_entry_point_aliases_and_names() +{ + for (auto &entry : ir.entry_points) + entry.second.name = ir.meta[entry.first].decoration.alias; +} + +string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain_is_resolved) +{ + auto *var = maybe_get_backing_variable(base); + // If this is a buffer array, we have to dereference the buffer pointers. + // Otherwise, if this is a pointer expression, dereference it. + + bool declared_as_pointer = false; + + if (var) + { + // Only allow -> dereference for block types. This is so we get expressions like + // buffer[i]->first_member.second_member, rather than buffer[i]->first->second. 
+ const bool is_block =
+ has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
+
+ bool is_buffer_variable =
+ is_block && (var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer);
+ declared_as_pointer = is_buffer_variable && is_array(get_pointee_type(var->basetype));
+ }
+
+ if (declared_as_pointer || (!ptr_chain_is_resolved && should_dereference(base)))
+ return join("->", to_member_name(type, index));
+ else
+ return join(".", to_member_name(type, index));
+}
+
+string CompilerMSL::to_qualifiers_glsl(uint32_t id)
+{
+ string quals;
+
+ auto *var = maybe_get<SPIRVariable>(id);
+ auto &type = expression_type(id);
+
+ if (type.storage == StorageClassWorkgroup || (var && variable_decl_is_remapped_storage(*var, StorageClassWorkgroup)))
+ quals += "threadgroup ";
+
+ return quals;
+}
+
+// The optional id parameter indicates the object whose type we are trying
+// to find the description for. It is optional. Most type descriptions do not
+// depend on a specific object's use of that type.
+string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id, bool member)
+{
+ string type_name;
+
+ // Pointer?
+ if (is_pointer(type) || type_is_array_of_pointers(type))
+ {
+ assert(type.pointer_depth > 0);
+
+ const char *restrict_kw;
+
+ auto type_address_space = get_type_address_space(type, id);
+ const auto *p_parent_type = &get<SPIRType>(type.parent_type);
+
+ // If we're wrapping buffer descriptors in a spvDescriptorArray, we'll have to handle it as a special case.
+ if (member && id)
+ {
+ auto &var = get<SPIRVariable>(id);
+ if (is_var_runtime_size_array(var) && is_runtime_size_array(*p_parent_type))
+ {
+ const bool ssbo = has_decoration(p_parent_type->self, DecorationBufferBlock);
+ bool buffer_desc =
+ (var.storage == StorageClassStorageBuffer || ssbo) &&
+ msl_options.runtime_array_rich_descriptor;
+
+ const char *wrapper_type = buffer_desc ? "spvBufferDescriptor" : "spvDescriptor";
+ add_spv_func_and_recompile(SPVFuncImplVariableDescriptorArray);
+ add_spv_func_and_recompile(buffer_desc ? SPVFuncImplVariableSizedDescriptor : SPVFuncImplVariableDescriptor);
+
+ type_name = join(wrapper_type, "<", type_address_space, " ", type_to_glsl(*p_parent_type, id), " *>");
+ return type_name;
+ }
+ }
+
+ // Work around C pointer qualifier rules. If glsl_type is a pointer type as well
+ // we'll need to emit the address space to the right.
+ // We could always go this route, but it makes the code unnatural.
+ // Prefer emitting thread T *foo over T thread* foo since it's more readable,
+ // but we'll have to emit thread T * thread * T constant bar; for example.
+ if (is_pointer(type) && is_pointer(*p_parent_type))
+ type_name = join(type_to_glsl(*p_parent_type, id), " ", type_address_space, " ");
+ else
+ {
+ // Since this is not a pointer-to-pointer, ensure we've dug down to the base type.
+ // Some situations chain pointers even though they are not formally pointers-of-pointers.
+ while (is_pointer(*p_parent_type))
+ p_parent_type = &get<SPIRType>(p_parent_type->parent_type);
+
+ // If we're emitting BDA, just use the templated type.
+ // Emitting builtin arrays need a lot of cooperation with other code to ensure
+ // the C-style nesting works right.
+ // FIXME: This is somewhat of a hack.
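+ // Illustrative: a physical-storage (buffer device address) pointer therefore comes out in the
+ // plain templated form, e.g. "device SomeBlock*" (SomeBlock standing in for the block's struct
+ // name), instead of the C-style builtin-array nesting.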
+ bool old_is_using_builtin_array = is_using_builtin_array; + if (is_physical_pointer(type)) + is_using_builtin_array = false; + + type_name = join(type_address_space, " ", type_to_glsl(*p_parent_type, id)); + + is_using_builtin_array = old_is_using_builtin_array; + } + + switch (type.basetype) + { + case SPIRType::Image: + case SPIRType::SampledImage: + case SPIRType::Sampler: + // These are handles. + break; + default: + // Anything else can be a raw pointer. + type_name += "*"; + restrict_kw = to_restrict(id, false); + if (*restrict_kw) + { + type_name += " "; + type_name += restrict_kw; + } + break; + } + return type_name; + } + + switch (type.basetype) + { + case SPIRType::Struct: + // Need OpName lookup here to get a "sensible" name for a struct. + // Allow Metal to use the array template to make arrays a value type + type_name = to_name(type.self); + break; + + case SPIRType::Image: + case SPIRType::SampledImage: + return image_type_glsl(type, id, member); + + case SPIRType::Sampler: + return sampler_type(type, id, member); + + case SPIRType::Void: + return "void"; + + case SPIRType::AtomicCounter: + return "atomic_uint"; + + case SPIRType::ControlPointArray: + return join("patch_control_point<", type_to_glsl(get(type.parent_type), id), ">"); + + case SPIRType::Interpolant: + return join("interpolant<", type_to_glsl(get(type.parent_type), id), ", interpolation::", + has_decoration(type.self, DecorationNoPerspective) ? "no_perspective" : "perspective", ">"); + + // Scalars + case SPIRType::Boolean: + { + auto *var = maybe_get_backing_variable(id); + if (var && var->basevariable) + var = &get(var->basevariable); + + // Need to special-case threadgroup booleans. They are supposed to be logical + // storage, but MSL compilers will sometimes crash if you use threadgroup bool. + // Workaround this by using 16-bit types instead and fixup on load-store to this data. + if ((var && var->storage == StorageClassWorkgroup) || type.storage == StorageClassWorkgroup || member) + type_name = "short"; + else + type_name = "bool"; + break; + } + + case SPIRType::Char: + case SPIRType::SByte: + type_name = "char"; + break; + case SPIRType::UByte: + type_name = "uchar"; + break; + case SPIRType::Short: + type_name = "short"; + break; + case SPIRType::UShort: + type_name = "ushort"; + break; + case SPIRType::Int: + type_name = "int"; + break; + case SPIRType::UInt: + type_name = "uint"; + break; + case SPIRType::Int64: + if (!msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above."); + type_name = "long"; + break; + case SPIRType::UInt64: + if (!msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above."); + type_name = "ulong"; + break; + case SPIRType::Half: + type_name = "half"; + break; + case SPIRType::Float: + type_name = "float"; + break; + case SPIRType::Double: + type_name = "double"; // Currently unsupported + break; + case SPIRType::AccelerationStructure: + if (msl_options.supports_msl_version(2, 4)) + type_name = "raytracing::acceleration_structure"; + else if (msl_options.supports_msl_version(2, 3)) + type_name = "raytracing::instance_acceleration_structure"; + else + SPIRV_CROSS_THROW("Acceleration Structure Type is supported in MSL 2.3 and above."); + break; + case SPIRType::RayQuery: + return "raytracing::intersection_query"; + + default: + return "unknown_type"; + } + + // Matrix? 
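+ // (Composed as base scalar + columns + "x" + rows, e.g. "float" becomes "float3x4" for a
+ // 3-column, 4-row matrix; threadgroup matrices before MSL 3.0 get the spvStorage_ prefix.)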
+ if (type.columns > 1) + { + auto *var = maybe_get_backing_variable(id); + if (var && var->basevariable) + var = &get(var->basevariable); + + // Need to special-case threadgroup matrices. Due to an oversight, Metal's + // matrix struct prior to Metal 3 lacks constructors in the threadgroup AS, + // preventing us from default-constructing or initializing matrices in threadgroup storage. + // Work around this by using our own type as storage. + if (((var && var->storage == StorageClassWorkgroup) || type.storage == StorageClassWorkgroup) && + !msl_options.supports_msl_version(3, 0)) + { + add_spv_func_and_recompile(SPVFuncImplStorageMatrix); + type_name = "spvStorage_" + type_name; + } + + type_name += to_string(type.columns) + "x"; + } + + // Vector or Matrix? + if (type.vecsize > 1) + type_name += to_string(type.vecsize); + + if (type.array.empty() || using_builtin_array()) + { + return type_name; + } + else + { + // Allow Metal to use the array template to make arrays a value type + add_spv_func_and_recompile(SPVFuncImplUnsafeArray); + string res; + string sizes; + + for (uint32_t i = 0; i < uint32_t(type.array.size()); i++) + { + res += "spvUnsafeArray<"; + sizes += ", "; + sizes += to_array_size(type, i); + sizes += ">"; + } + + res += type_name + sizes; + return res; + } +} + +string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id) +{ + return type_to_glsl(type, id, false); +} + +string CompilerMSL::type_to_array_glsl(const SPIRType &type, uint32_t variable_id) +{ + // Allow Metal to use the array template to make arrays a value type + switch (type.basetype) + { + case SPIRType::AtomicCounter: + case SPIRType::ControlPointArray: + case SPIRType::RayQuery: + return CompilerGLSL::type_to_array_glsl(type, variable_id); + + default: + if (type_is_array_of_pointers(type) || using_builtin_array()) + { + const SPIRVariable *var = variable_id ? &get(variable_id) : nullptr; + if (var && (var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer) && + is_array(get_variable_data_type(*var))) + { + return join("[", get_resource_array_size(type, variable_id), "]"); + } + else + return CompilerGLSL::type_to_array_glsl(type, variable_id); + } + else + return ""; + } +} + +string CompilerMSL::constant_op_expression(const SPIRConstantOp &cop) +{ + switch (cop.opcode) + { + case OpQuantizeToF16: + add_spv_func_and_recompile(SPVFuncImplQuantizeToF16); + return join("spvQuantizeToF16(", to_expression(cop.arguments[0]), ")"); + default: + return CompilerGLSL::constant_op_expression(cop); + } +} + +bool CompilerMSL::variable_decl_is_remapped_storage(const SPIRVariable &variable, spv::StorageClass storage) const +{ + if (variable.storage == storage) + return true; + + if (storage == StorageClassWorkgroup) + { + // Specially masked IO block variable. + // Normally, we will never access IO blocks directly here. + // The only scenario which that should occur is with a masked IO block. + if (is_tesc_shader() && variable.storage == StorageClassOutput && + has_decoration(get(variable.basetype).self, DecorationBlock)) + { + return true; + } + + return variable.storage == StorageClassOutput && is_tesc_shader() && is_stage_output_variable_masked(variable); + } + else if (storage == StorageClassStorageBuffer) + { + // These builtins are passed directly; we don't want to use remapping + // for them. 
+ auto builtin = (BuiltIn)get_decoration(variable.self, DecorationBuiltIn); + if (is_tese_shader() && is_builtin_variable(variable) && (builtin == BuiltInTessCoord || builtin == BuiltInPrimitiveId)) + return false; + + // We won't be able to catch writes to control point outputs here since variable + // refers to a function local pointer. + // This is fine, as there cannot be concurrent writers to that memory anyways, + // so we just ignore that case. + + return (variable.storage == StorageClassOutput || variable.storage == StorageClassInput) && + !variable_storage_requires_stage_io(variable.storage) && + (variable.storage != StorageClassOutput || !is_stage_output_variable_masked(variable)); + } + else + { + return false; + } +} + +// GCC workaround of lambdas calling protected funcs +std::string CompilerMSL::variable_decl(const SPIRType &type, const std::string &name, uint32_t id) +{ + return CompilerGLSL::variable_decl(type, name, id); +} + +std::string CompilerMSL::sampler_type(const SPIRType &type, uint32_t id, bool member) +{ + auto *var = maybe_get(id); + if (var && var->basevariable) + { + // Check against the base variable, and not a fake ID which might have been generated for this variable. + id = var->basevariable; + } + + if (!type.array.empty()) + { + if (!msl_options.supports_msl_version(2)) + SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of samplers."); + + if (type.array.size() > 1) + SPIRV_CROSS_THROW("Arrays of arrays of samplers are not supported in MSL."); + + // Arrays of samplers in MSL must be declared with a special array syntax ala C++11 std::array. + // If we have a runtime array, it could be a variable-count descriptor set binding. + auto &parent = get(get_pointee_type(type).parent_type); + uint32_t array_size = get_resource_array_size(type, id); + + if (array_size == 0) + { + add_spv_func_and_recompile(SPVFuncImplVariableDescriptor); + add_spv_func_and_recompile(SPVFuncImplVariableDescriptorArray); + + const char *descriptor_wrapper = processing_entry_point ? "const device spvDescriptor" : "const spvDescriptorArray"; + if (member) + descriptor_wrapper = "spvDescriptor"; + return join(descriptor_wrapper, "<", sampler_type(parent, id, false), ">", + processing_entry_point ? "*" : ""); + } + else + { + return join("array<", sampler_type(parent, id, false), ", ", array_size, ">"); + } + } + else + return "sampler"; +} + +// Returns an MSL string describing the SPIR-V image type +string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id, bool member) +{ + auto *var = maybe_get(id); + if (var && var->basevariable) + { + // For comparison images, check against the base variable, + // and not the fake ID which might have been generated for this variable. + id = var->basevariable; + } + + if (!type.array.empty()) + { + uint32_t major = 2, minor = 0; + if (msl_options.is_ios()) + { + major = 1; + minor = 2; + } + if (!msl_options.supports_msl_version(major, minor)) + { + if (msl_options.is_ios()) + SPIRV_CROSS_THROW("MSL 1.2 or greater is required for arrays of textures."); + else + SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of textures."); + } + + if (type.array.size() > 1) + SPIRV_CROSS_THROW("Arrays of arrays of textures are not supported in MSL."); + + // Arrays of images in MSL must be declared with a special array syntax ala C++11 std::array. + // If we have a runtime array, it could be a variable-count descriptor set binding. 
+ auto &parent = get(get_pointee_type(type).parent_type); + uint32_t array_size = get_resource_array_size(type, id); + + if (array_size == 0) + { + add_spv_func_and_recompile(SPVFuncImplVariableDescriptor); + add_spv_func_and_recompile(SPVFuncImplVariableDescriptorArray); + const char *descriptor_wrapper = processing_entry_point ? "const device spvDescriptor" : "const spvDescriptorArray"; + if (member) + { + descriptor_wrapper = "spvDescriptor"; + // This requires a specialized wrapper type that packs image and sampler side by side. + // It is possible in theory. + if (type.basetype == SPIRType::SampledImage) + SPIRV_CROSS_THROW("Argument buffer runtime array currently not supported for combined image sampler."); + } + return join(descriptor_wrapper, "<", image_type_glsl(parent, id, false), ">", + processing_entry_point ? "*" : ""); + } + else + { + return join("array<", image_type_glsl(parent, id, false), ", ", array_size, ">"); + } + } + + string img_type_name; + + auto &img_type = type.image; + + if (is_depth_image(type, id)) + { + switch (img_type.dim) + { + case Dim1D: + case Dim2D: + if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D) + { + // Use a native Metal 1D texture + img_type_name += "depth1d_unsupported_by_metal"; + break; + } + + if (img_type.ms && img_type.arrayed) + { + if (!msl_options.supports_msl_version(2, 1)) + SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1."); + img_type_name += "depth2d_ms_array"; + } + else if (img_type.ms) + img_type_name += "depth2d_ms"; + else if (img_type.arrayed) + img_type_name += "depth2d_array"; + else + img_type_name += "depth2d"; + break; + case Dim3D: + img_type_name += "depth3d_unsupported_by_metal"; + break; + case DimCube: + if (!msl_options.emulate_cube_array) + img_type_name += (img_type.arrayed ? "depthcube_array" : "depthcube"); + else + img_type_name += (img_type.arrayed ? "depth2d_array" : "depthcube"); + break; + default: + img_type_name += "unknown_depth_texture_type"; + break; + } + } + else + { + switch (img_type.dim) + { + case DimBuffer: + if (img_type.ms || img_type.arrayed) + SPIRV_CROSS_THROW("Cannot use texel buffers with multisampling or array layers."); + + if (msl_options.texture_buffer_native) + { + if (!msl_options.supports_msl_version(2, 1)) + SPIRV_CROSS_THROW("Native texture_buffer type is only supported in MSL 2.1."); + img_type_name = "texture_buffer"; + } + else + img_type_name += "texture2d"; + break; + case Dim1D: + case Dim2D: + case DimSubpassData: + { + bool subpass_array = + img_type.dim == DimSubpassData && (msl_options.multiview || msl_options.arrayed_subpass_input); + if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D) + { + // Use a native Metal 1D texture + img_type_name += (img_type.arrayed ? "texture1d_array" : "texture1d"); + break; + } + + // Use Metal's native frame-buffer fetch API for subpass inputs. 
+ if (type_is_msl_framebuffer_fetch(type)) + { + auto img_type_4 = get(img_type.type); + img_type_4.vecsize = 4; + return type_to_glsl(img_type_4); + } + if (img_type.ms && (img_type.arrayed || subpass_array)) + { + if (!msl_options.supports_msl_version(2, 1)) + SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1."); + img_type_name += "texture2d_ms_array"; + } + else if (img_type.ms) + img_type_name += "texture2d_ms"; + else if (img_type.arrayed || subpass_array) + img_type_name += "texture2d_array"; + else + img_type_name += "texture2d"; + break; + } + case Dim3D: + img_type_name += "texture3d"; + break; + case DimCube: + if (!msl_options.emulate_cube_array) + img_type_name += (img_type.arrayed ? "texturecube_array" : "texturecube"); + else + img_type_name += (img_type.arrayed ? "texture2d_array" : "texturecube"); + break; + default: + img_type_name += "unknown_texture_type"; + break; + } + } + + // Append the pixel type + img_type_name += "<"; + img_type_name += type_to_glsl(get(img_type.type)); + + // For unsampled images, append the sample/read/write access qualifier. + // For kernel images, the access qualifier my be supplied directly by SPIR-V. + // Otherwise it may be set based on whether the image is read from or written to within the shader. + if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData) + { + switch (img_type.access) + { + case AccessQualifierReadOnly: + img_type_name += ", access::read"; + break; + + case AccessQualifierWriteOnly: + img_type_name += ", access::write"; + break; + + case AccessQualifierReadWrite: + img_type_name += ", access::read_write"; + break; + + default: + { + auto *p_var = maybe_get_backing_variable(id); + if (p_var && p_var->basevariable) + p_var = maybe_get(p_var->basevariable); + if (p_var && !has_decoration(p_var->self, DecorationNonWritable)) + { + img_type_name += ", access::"; + + if (!has_decoration(p_var->self, DecorationNonReadable)) + img_type_name += "read_"; + + img_type_name += "write"; + } + break; + } + } + } + + img_type_name += ">"; + + return img_type_name; +} + +void CompilerMSL::emit_subgroup_op(const Instruction &i) +{ + const uint32_t *ops = stream(i); + auto op = static_cast(i.op); + + if (msl_options.emulate_subgroups) + { + // In this mode, only the GroupNonUniform cap is supported. The only op + // we need to handle, then, is OpGroupNonUniformElect. + if (op != OpGroupNonUniformElect) + SPIRV_CROSS_THROW("Subgroup emulation does not support operations other than Elect."); + // In this mode, the subgroup size is assumed to be one, so every invocation + // is elected. + emit_op(ops[0], ops[1], "true", true); + return; + } + + // Metal 2.0 is required. iOS only supports quad ops on 11.0 (2.0), with + // full support in 13.0 (2.2). macOS only supports broadcast and shuffle on + // 10.13 (2.0), with full support in 10.14 (2.1). + // Note that Apple GPUs before A13 make no distinction between a quad-group + // and a SIMD-group; all SIMD-groups are quad-groups on those. + if (!msl_options.supports_msl_version(2)) + SPIRV_CROSS_THROW("Subgroups are only supported in Metal 2.0 and up."); + + // If we need to do implicit bitcasts, make sure we do it with the correct type. 
+ uint32_t integer_width = get_integer_width_for_instruction(i); + auto int_type = to_signed_basetype(integer_width); + auto uint_type = to_unsigned_basetype(integer_width); + + if (msl_options.is_ios() && (!msl_options.supports_msl_version(2, 3) || !msl_options.ios_use_simdgroup_functions)) + { + switch (op) + { + default: + SPIRV_CROSS_THROW("Subgroup ops beyond broadcast, ballot, and shuffle on iOS require Metal 2.3 and up."); + case OpGroupNonUniformBroadcastFirst: + if (!msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("BroadcastFirst on iOS requires Metal 2.2 and up."); + break; + case OpGroupNonUniformElect: + if (!msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("Elect on iOS requires Metal 2.2 and up."); + break; + case OpGroupNonUniformAny: + case OpGroupNonUniformAll: + case OpGroupNonUniformAllEqual: + case OpGroupNonUniformBallot: + case OpGroupNonUniformInverseBallot: + case OpGroupNonUniformBallotBitExtract: + case OpGroupNonUniformBallotFindLSB: + case OpGroupNonUniformBallotFindMSB: + case OpGroupNonUniformBallotBitCount: + case OpSubgroupBallotKHR: + case OpSubgroupAllKHR: + case OpSubgroupAnyKHR: + case OpSubgroupAllEqualKHR: + if (!msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("Ballot ops on iOS requires Metal 2.2 and up."); + break; + case OpGroupNonUniformBroadcast: + case OpGroupNonUniformShuffle: + case OpGroupNonUniformShuffleXor: + case OpGroupNonUniformShuffleUp: + case OpGroupNonUniformShuffleDown: + case OpGroupNonUniformQuadSwap: + case OpGroupNonUniformQuadBroadcast: + case OpSubgroupReadInvocationKHR: + break; + } + } + + if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1)) + { + switch (op) + { + default: + SPIRV_CROSS_THROW("Subgroup ops beyond broadcast and shuffle on macOS require Metal 2.1 and up."); + case OpGroupNonUniformBroadcast: + case OpGroupNonUniformShuffle: + case OpGroupNonUniformShuffleXor: + case OpGroupNonUniformShuffleUp: + case OpGroupNonUniformShuffleDown: + case OpSubgroupReadInvocationKHR: + break; + } + } + + uint32_t op_idx = 0; + uint32_t result_type = ops[op_idx++]; + uint32_t id = ops[op_idx++]; + + Scope scope; + switch (op) + { + case OpSubgroupBallotKHR: + case OpSubgroupFirstInvocationKHR: + case OpSubgroupReadInvocationKHR: + case OpSubgroupAllKHR: + case OpSubgroupAnyKHR: + case OpSubgroupAllEqualKHR: + // These earlier instructions don't have the scope operand. 
+ scope = ScopeSubgroup; + break; + default: + scope = static_cast(evaluate_constant_u32(ops[op_idx++])); + break; + } + if (scope != ScopeSubgroup) + SPIRV_CROSS_THROW("Only subgroup scope is supported."); + + switch (op) + { + case OpGroupNonUniformElect: + if (msl_options.use_quadgroup_operation()) + emit_op(result_type, id, "quad_is_first()", false); + else + emit_op(result_type, id, "simd_is_first()", false); + break; + + case OpGroupNonUniformBroadcast: + case OpSubgroupReadInvocationKHR: + emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvSubgroupBroadcast"); + break; + + case OpGroupNonUniformBroadcastFirst: + case OpSubgroupFirstInvocationKHR: + emit_unary_func_op(result_type, id, ops[op_idx], "spvSubgroupBroadcastFirst"); + break; + + case OpGroupNonUniformBallot: + case OpSubgroupBallotKHR: + emit_unary_func_op(result_type, id, ops[op_idx], "spvSubgroupBallot"); + break; + + case OpGroupNonUniformInverseBallot: + emit_binary_func_op(result_type, id, ops[op_idx], builtin_subgroup_invocation_id_id, "spvSubgroupBallotBitExtract"); + break; + + case OpGroupNonUniformBallotBitExtract: + emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvSubgroupBallotBitExtract"); + break; + + case OpGroupNonUniformBallotFindLSB: + emit_binary_func_op(result_type, id, ops[op_idx], builtin_subgroup_size_id, "spvSubgroupBallotFindLSB"); + break; + + case OpGroupNonUniformBallotFindMSB: + emit_binary_func_op(result_type, id, ops[op_idx], builtin_subgroup_size_id, "spvSubgroupBallotFindMSB"); + break; + + case OpGroupNonUniformBallotBitCount: + { + auto operation = static_cast(ops[op_idx++]); + switch (operation) + { + case GroupOperationReduce: + emit_binary_func_op(result_type, id, ops[op_idx], builtin_subgroup_size_id, "spvSubgroupBallotBitCount"); + break; + case GroupOperationInclusiveScan: + emit_binary_func_op(result_type, id, ops[op_idx], builtin_subgroup_invocation_id_id, + "spvSubgroupBallotInclusiveBitCount"); + break; + case GroupOperationExclusiveScan: + emit_binary_func_op(result_type, id, ops[op_idx], builtin_subgroup_invocation_id_id, + "spvSubgroupBallotExclusiveBitCount"); + break; + default: + SPIRV_CROSS_THROW("Invalid BitCount operation."); + } + break; + } + + case OpGroupNonUniformShuffle: + emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvSubgroupShuffle"); + break; + + case OpGroupNonUniformShuffleXor: + emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvSubgroupShuffleXor"); + break; + + case OpGroupNonUniformShuffleUp: + emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvSubgroupShuffleUp"); + break; + + case OpGroupNonUniformShuffleDown: + emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvSubgroupShuffleDown"); + break; + + case OpGroupNonUniformAll: + case OpSubgroupAllKHR: + if (msl_options.use_quadgroup_operation()) + emit_unary_func_op(result_type, id, ops[op_idx], "quad_all"); + else + emit_unary_func_op(result_type, id, ops[op_idx], "simd_all"); + break; + + case OpGroupNonUniformAny: + case OpSubgroupAnyKHR: + if (msl_options.use_quadgroup_operation()) + emit_unary_func_op(result_type, id, ops[op_idx], "quad_any"); + else + emit_unary_func_op(result_type, id, ops[op_idx], "simd_any"); + break; + + case OpGroupNonUniformAllEqual: + case OpSubgroupAllEqualKHR: + emit_unary_func_op(result_type, id, ops[op_idx], "spvSubgroupAllEqual"); + break; + + // clang-format off +#define MSL_GROUP_OP(op, msl_op) \ +case OpGroupNonUniform##op: \ + { \ + auto 
operation = static_cast<GroupOperation>(ops[op_idx++]); \
+		if (operation == GroupOperationReduce) \
+			emit_unary_func_op(result_type, id, ops[op_idx], "simd_" #msl_op); \
+		else if (operation == GroupOperationInclusiveScan) \
+			emit_unary_func_op(result_type, id, ops[op_idx], "simd_prefix_inclusive_" #msl_op); \
+		else if (operation == GroupOperationExclusiveScan) \
+			emit_unary_func_op(result_type, id, ops[op_idx], "simd_prefix_exclusive_" #msl_op); \
+		else if (operation == GroupOperationClusteredReduce) \
+		{ \
+			/* Only cluster sizes of 4 are supported. */ \
+			uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
+			if (cluster_size != 4) \
+				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
+			emit_unary_func_op(result_type, id, ops[op_idx], "quad_" #msl_op); \
+		} \
+		else \
+			SPIRV_CROSS_THROW("Invalid group operation."); \
+		break; \
+	}
+	MSL_GROUP_OP(FAdd, sum)
+	MSL_GROUP_OP(FMul, product)
+	MSL_GROUP_OP(IAdd, sum)
+	MSL_GROUP_OP(IMul, product)
+#undef MSL_GROUP_OP
+	// The others, unfortunately, don't support InclusiveScan or ExclusiveScan.
+
+#define MSL_GROUP_OP(op, msl_op) \
+case OpGroupNonUniform##op: \
+	{ \
+		auto operation = static_cast<GroupOperation>(ops[op_idx++]); \
+		if (operation == GroupOperationReduce) \
+			emit_unary_func_op(result_type, id, ops[op_idx], "simd_" #msl_op); \
+		else if (operation == GroupOperationInclusiveScan) \
+			SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
+		else if (operation == GroupOperationExclusiveScan) \
+			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
+		else if (operation == GroupOperationClusteredReduce) \
+		{ \
+			/* Only cluster sizes of 4 are supported. */ \
+			uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
+			if (cluster_size != 4) \
+				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
+			emit_unary_func_op(result_type, id, ops[op_idx], "quad_" #msl_op); \
+		} \
+		else \
+			SPIRV_CROSS_THROW("Invalid group operation."); \
+		break; \
+	}
+
+#define MSL_GROUP_OP_CAST(op, msl_op, type) \
+case OpGroupNonUniform##op: \
+	{ \
+		auto operation = static_cast<GroupOperation>(ops[op_idx++]); \
+		if (operation == GroupOperationReduce) \
+			emit_unary_func_op_cast(result_type, id, ops[op_idx], "simd_" #msl_op, type, type); \
+		else if (operation == GroupOperationInclusiveScan) \
+			SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
+		else if (operation == GroupOperationExclusiveScan) \
+			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
+		else if (operation == GroupOperationClusteredReduce) \
+		{ \
+			/* Only cluster sizes of 4 are supported.
*/ \ + uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \ + if (cluster_size != 4) \ + SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \ + emit_unary_func_op_cast(result_type, id, ops[op_idx], "quad_" #msl_op, type, type); \ + } \ + else \ + SPIRV_CROSS_THROW("Invalid group operation."); \ + break; \ + } + + MSL_GROUP_OP(FMin, min) + MSL_GROUP_OP(FMax, max) + MSL_GROUP_OP_CAST(SMin, min, int_type) + MSL_GROUP_OP_CAST(SMax, max, int_type) + MSL_GROUP_OP_CAST(UMin, min, uint_type) + MSL_GROUP_OP_CAST(UMax, max, uint_type) + MSL_GROUP_OP(BitwiseAnd, and) + MSL_GROUP_OP(BitwiseOr, or) + MSL_GROUP_OP(BitwiseXor, xor) + MSL_GROUP_OP(LogicalAnd, and) + MSL_GROUP_OP(LogicalOr, or) + MSL_GROUP_OP(LogicalXor, xor) + // clang-format on +#undef MSL_GROUP_OP +#undef MSL_GROUP_OP_CAST + + case OpGroupNonUniformQuadSwap: + emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvQuadSwap"); + break; + + case OpGroupNonUniformQuadBroadcast: + emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvQuadBroadcast"); + break; + + default: + SPIRV_CROSS_THROW("Invalid opcode for subgroup."); + } + + register_control_dependent_expression(id); +} + +string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type) +{ + if (out_type.basetype == in_type.basetype) + return ""; + + assert(out_type.basetype != SPIRType::Boolean); + assert(in_type.basetype != SPIRType::Boolean); + + bool integral_cast = type_is_integral(out_type) && type_is_integral(in_type) && (out_type.vecsize == in_type.vecsize); + bool same_size_cast = (out_type.width * out_type.vecsize) == (in_type.width * in_type.vecsize); + + // Bitcasting can only be used between types of the same overall size. + // And always formally cast between integers, because it's trivial, and also + // because Metal can internally cast the results of some integer ops to a larger + // size (eg. short shift right becomes int), which means chaining integer ops + // together may introduce size variations that SPIR-V doesn't know about. + if (same_size_cast && !integral_cast) + return "as_type<" + type_to_glsl(out_type) + ">"; + else + return type_to_glsl(out_type); +} + +bool CompilerMSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t) +{ + // This is handled from the outside where we deal with PtrToU/UToPtr and friends. + return false; +} + +// Returns an MSL string identifying the name of a SPIR-V builtin. +// Output builtins are qualified with the name of the stage out structure. +string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage) +{ + switch (builtin) + { + // Handle HLSL-style 0-based vertex/instance index. 
+ // Override GLSL compiler strictness + case BuiltInVertexId: + ensure_builtin(StorageClassInput, BuiltInVertexId); + if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) && + (msl_options.ios_support_base_vertex_instance || msl_options.is_macos())) + { + if (builtin_declaration) + { + if (needs_base_vertex_arg != TriState::No) + needs_base_vertex_arg = TriState::Yes; + return "gl_VertexID"; + } + else + { + ensure_builtin(StorageClassInput, BuiltInBaseVertex); + return "(gl_VertexID - gl_BaseVertex)"; + } + } + else + { + return "gl_VertexID"; + } + case BuiltInInstanceId: + ensure_builtin(StorageClassInput, BuiltInInstanceId); + if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) && + (msl_options.ios_support_base_vertex_instance || msl_options.is_macos())) + { + if (builtin_declaration) + { + if (needs_base_instance_arg != TriState::No) + needs_base_instance_arg = TriState::Yes; + return "gl_InstanceID"; + } + else + { + ensure_builtin(StorageClassInput, BuiltInBaseInstance); + return "(gl_InstanceID - gl_BaseInstance)"; + } + } + else + { + return "gl_InstanceID"; + } + case BuiltInVertexIndex: + ensure_builtin(StorageClassInput, BuiltInVertexIndex); + if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) && + (msl_options.ios_support_base_vertex_instance || msl_options.is_macos())) + { + if (builtin_declaration) + { + if (needs_base_vertex_arg != TriState::No) + needs_base_vertex_arg = TriState::Yes; + return "gl_VertexIndex"; + } + else + { + ensure_builtin(StorageClassInput, BuiltInBaseVertex); + return "(gl_VertexIndex - gl_BaseVertex)"; + } + } + else + { + return "gl_VertexIndex"; + } + case BuiltInInstanceIndex: + ensure_builtin(StorageClassInput, BuiltInInstanceIndex); + if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) && + (msl_options.ios_support_base_vertex_instance || msl_options.is_macos())) + { + if (builtin_declaration) + { + if (needs_base_instance_arg != TriState::No) + needs_base_instance_arg = TriState::Yes; + return "gl_InstanceIndex"; + } + else + { + ensure_builtin(StorageClassInput, BuiltInBaseInstance); + return "(gl_InstanceIndex - gl_BaseInstance)"; + } + } + else + { + return "gl_InstanceIndex"; + } + case BuiltInBaseVertex: + if (msl_options.supports_msl_version(1, 1) && + (msl_options.ios_support_base_vertex_instance || msl_options.is_macos())) + { + needs_base_vertex_arg = TriState::No; + return "gl_BaseVertex"; + } + else + { + SPIRV_CROSS_THROW("BaseVertex requires Metal 1.1 and Mac or Apple A9+ hardware."); + } + case BuiltInBaseInstance: + if (msl_options.supports_msl_version(1, 1) && + (msl_options.ios_support_base_vertex_instance || msl_options.is_macos())) + { + needs_base_instance_arg = TriState::No; + return "gl_BaseInstance"; + } + else + { + SPIRV_CROSS_THROW("BaseInstance requires Metal 1.1 and Mac or Apple A9+ hardware."); + } + case BuiltInDrawIndex: + SPIRV_CROSS_THROW("DrawIndex is not supported in MSL."); + + // When used in the entry function, output builtins are qualified with output struct name. + // Test storage class as NOT Input, as output builtins might be part of generic type. + // Also don't do this for tessellation control shaders. 
+ case BuiltInViewportIndex: + if (!msl_options.supports_msl_version(2, 0)) + SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0."); + /* fallthrough */ + case BuiltInFragDepth: + case BuiltInFragStencilRefEXT: + if ((builtin == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) || + (builtin == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin)) + break; + /* fallthrough */ + case BuiltInPosition: + case BuiltInPointSize: + case BuiltInClipDistance: + case BuiltInCullDistance: + case BuiltInLayer: + if (is_tesc_shader()) + break; + if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) && + !is_stage_output_builtin_masked(builtin)) + return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage); + break; + + case BuiltInSampleMask: + if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point) && + (has_additional_fixed_sample_mask() || needs_sample_id)) + { + string samp_mask_in; + samp_mask_in += "(" + CompilerGLSL::builtin_to_glsl(builtin, storage); + if (has_additional_fixed_sample_mask()) + samp_mask_in += " & " + additional_fixed_sample_mask_str(); + if (needs_sample_id) + samp_mask_in += " & (1 << gl_SampleID)"; + samp_mask_in += ")"; + return samp_mask_in; + } + if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) && + !is_stage_output_builtin_masked(builtin)) + return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage); + break; + + case BuiltInBaryCoordKHR: + case BuiltInBaryCoordNoPerspKHR: + if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point)) + return stage_in_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage); + break; + + case BuiltInTessLevelOuter: + if (is_tesc_shader() && storage != StorageClassInput && current_function && + (current_function->self == ir.default_entry_point)) + { + return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id), + "].edgeTessellationFactor"); + } + break; + + case BuiltInTessLevelInner: + if (is_tesc_shader() && storage != StorageClassInput && current_function && + (current_function->self == ir.default_entry_point)) + { + return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id), + "].insideTessellationFactor"); + } + break; + + case BuiltInHelperInvocation: + if (needs_manual_helper_invocation_updates()) + break; + if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3)) + SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.3 on iOS."); + else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1)) + SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS."); + // In SPIR-V 1.6 with Volatile HelperInvocation, we cannot emit a fixup early. 
+ return "simd_is_helper_thread()"; + + default: + break; + } + + return CompilerGLSL::builtin_to_glsl(builtin, storage); +} + +// Returns an MSL string attribute qualifer for a SPIR-V builtin +string CompilerMSL::builtin_qualifier(BuiltIn builtin) +{ + auto &execution = get_entry_point(); + + switch (builtin) + { + // Vertex function in + case BuiltInVertexId: + return "vertex_id"; + case BuiltInVertexIndex: + return "vertex_id"; + case BuiltInBaseVertex: + return "base_vertex"; + case BuiltInInstanceId: + return "instance_id"; + case BuiltInInstanceIndex: + return "instance_id"; + case BuiltInBaseInstance: + return "base_instance"; + case BuiltInDrawIndex: + SPIRV_CROSS_THROW("DrawIndex is not supported in MSL."); + + // Vertex function out + case BuiltInClipDistance: + return "clip_distance"; + case BuiltInPointSize: + return "point_size"; + case BuiltInPosition: + if (position_invariant) + { + if (!msl_options.supports_msl_version(2, 1)) + SPIRV_CROSS_THROW("Invariant position is only supported on MSL 2.1 and up."); + return "position, invariant"; + } + else + return "position"; + case BuiltInLayer: + return "render_target_array_index"; + case BuiltInViewportIndex: + if (!msl_options.supports_msl_version(2, 0)) + SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0."); + return "viewport_array_index"; + + // Tess. control function in + case BuiltInInvocationId: + if (msl_options.multi_patch_workgroup) + { + // Shouldn't be reached. + SPIRV_CROSS_THROW("InvocationId is computed manually with multi-patch workgroups in MSL."); + } + return "thread_index_in_threadgroup"; + case BuiltInPatchVertices: + // Shouldn't be reached. + SPIRV_CROSS_THROW("PatchVertices is derived from the auxiliary buffer in MSL."); + case BuiltInPrimitiveId: + switch (execution.model) + { + case ExecutionModelTessellationControl: + if (msl_options.multi_patch_workgroup) + { + // Shouldn't be reached. + SPIRV_CROSS_THROW("PrimitiveId is computed manually with multi-patch workgroups in MSL."); + } + return "threadgroup_position_in_grid"; + case ExecutionModelTessellationEvaluation: + return "patch_id"; + case ExecutionModelFragment: + if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3)) + SPIRV_CROSS_THROW("PrimitiveId on iOS requires MSL 2.3."); + else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("PrimitiveId on macOS requires MSL 2.2."); + return "primitive_id"; + default: + SPIRV_CROSS_THROW("PrimitiveId is not supported in this execution model."); + } + + // Tess. control function out + case BuiltInTessLevelOuter: + case BuiltInTessLevelInner: + // Shouldn't be reached. + SPIRV_CROSS_THROW("Tessellation levels are handled specially in MSL."); + + // Tess. evaluation function in + case BuiltInTessCoord: + return "position_in_patch"; + + // Fragment function in + case BuiltInFrontFacing: + return "front_facing"; + case BuiltInPointCoord: + return "point_coord"; + case BuiltInFragCoord: + return "position"; + case BuiltInSampleId: + return "sample_id"; + case BuiltInSampleMask: + return "sample_mask"; + case BuiltInSamplePosition: + // Shouldn't be reached. + SPIRV_CROSS_THROW("Sample position is retrieved by a function in MSL."); + case BuiltInViewIndex: + if (execution.model != ExecutionModelFragment) + SPIRV_CROSS_THROW("ViewIndex is handled specially outside fragment shaders."); + // The ViewIndex was implicitly used in the prior stages to set the render_target_array_index, + // so we can get it from there. 
+ return "render_target_array_index"; + + // Fragment function out + case BuiltInFragDepth: + if (execution.flags.get(ExecutionModeDepthGreater)) + return "depth(greater)"; + else if (execution.flags.get(ExecutionModeDepthLess)) + return "depth(less)"; + else + return "depth(any)"; + + case BuiltInFragStencilRefEXT: + return "stencil"; + + // Compute function in + case BuiltInGlobalInvocationId: + return "thread_position_in_grid"; + + case BuiltInWorkgroupId: + return "threadgroup_position_in_grid"; + + case BuiltInNumWorkgroups: + return "threadgroups_per_grid"; + + case BuiltInLocalInvocationId: + return "thread_position_in_threadgroup"; + + case BuiltInLocalInvocationIndex: + return "thread_index_in_threadgroup"; + + case BuiltInSubgroupSize: + if (msl_options.emulate_subgroups || msl_options.fixed_subgroup_size != 0) + // Shouldn't be reached. + SPIRV_CROSS_THROW("Emitting threads_per_simdgroup attribute with fixed subgroup size??"); + if (execution.model == ExecutionModelFragment) + { + if (!msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("threads_per_simdgroup requires Metal 2.2 in fragment shaders."); + return "threads_per_simdgroup"; + } + else + { + // thread_execution_width is an alias for threads_per_simdgroup, and it's only available since 1.0, + // but not in fragment. + return "thread_execution_width"; + } + + case BuiltInNumSubgroups: + if (msl_options.emulate_subgroups) + // Shouldn't be reached. + SPIRV_CROSS_THROW("NumSubgroups is handled specially with emulation."); + if (!msl_options.supports_msl_version(2)) + SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0."); + return msl_options.use_quadgroup_operation() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup"; + + case BuiltInSubgroupId: + if (msl_options.emulate_subgroups) + // Shouldn't be reached. + SPIRV_CROSS_THROW("SubgroupId is handled specially with emulation."); + if (!msl_options.supports_msl_version(2)) + SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0."); + return msl_options.use_quadgroup_operation() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup"; + + case BuiltInSubgroupLocalInvocationId: + if (msl_options.emulate_subgroups) + // Shouldn't be reached. + SPIRV_CROSS_THROW("SubgroupLocalInvocationId is handled specially with emulation."); + if (execution.model == ExecutionModelFragment) + { + if (!msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("thread_index_in_simdgroup requires Metal 2.2 in fragment shaders."); + return "thread_index_in_simdgroup"; + } + else if (execution.model == ExecutionModelKernel || execution.model == ExecutionModelGLCompute || + execution.model == ExecutionModelTessellationControl || + (execution.model == ExecutionModelVertex && msl_options.vertex_for_tessellation)) + { + // We are generating a Metal kernel function. + if (!msl_options.supports_msl_version(2)) + SPIRV_CROSS_THROW("Subgroup builtins in kernel functions require Metal 2.0."); + return msl_options.use_quadgroup_operation() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup"; + } + else + SPIRV_CROSS_THROW("Subgroup builtins are not available in this type of function."); + + case BuiltInSubgroupEqMask: + case BuiltInSubgroupGeMask: + case BuiltInSubgroupGtMask: + case BuiltInSubgroupLeMask: + case BuiltInSubgroupLtMask: + // Shouldn't be reached. 
+ SPIRV_CROSS_THROW("Subgroup ballot masks are handled specially in MSL."); + + case BuiltInBaryCoordKHR: + if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3)) + SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS."); + else if (!msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS."); + return "barycentric_coord, center_perspective"; + + case BuiltInBaryCoordNoPerspKHR: + if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3)) + SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS."); + else if (!msl_options.supports_msl_version(2, 2)) + SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS."); + return "barycentric_coord, center_no_perspective"; + + default: + return "unsupported-built-in"; + } +} + +// Returns an MSL string type declaration for a SPIR-V builtin +string CompilerMSL::builtin_type_decl(BuiltIn builtin, uint32_t id) +{ + switch (builtin) + { + // Vertex function in + case BuiltInVertexId: + return "uint"; + case BuiltInVertexIndex: + return "uint"; + case BuiltInBaseVertex: + return "uint"; + case BuiltInInstanceId: + return "uint"; + case BuiltInInstanceIndex: + return "uint"; + case BuiltInBaseInstance: + return "uint"; + case BuiltInDrawIndex: + SPIRV_CROSS_THROW("DrawIndex is not supported in MSL."); + + // Vertex function out + case BuiltInClipDistance: + case BuiltInCullDistance: + return "float"; + case BuiltInPointSize: + return "float"; + case BuiltInPosition: + return "float4"; + case BuiltInLayer: + return "uint"; + case BuiltInViewportIndex: + if (!msl_options.supports_msl_version(2, 0)) + SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0."); + return "uint"; + + // Tess. control function in + case BuiltInInvocationId: + return "uint"; + case BuiltInPatchVertices: + return "uint"; + case BuiltInPrimitiveId: + return "uint"; + + // Tess. control function out + case BuiltInTessLevelInner: + if (is_tese_shader()) + return (msl_options.raw_buffer_tese_input || is_tessellating_triangles()) ? "float" : "float2"; + return "half"; + case BuiltInTessLevelOuter: + if (is_tese_shader()) + return (msl_options.raw_buffer_tese_input || is_tessellating_triangles()) ? "float" : "float4"; + return "half"; + + // Tess. evaluation function in + case BuiltInTessCoord: + return "float3"; + + // Fragment function in + case BuiltInFrontFacing: + return "bool"; + case BuiltInPointCoord: + return "float2"; + case BuiltInFragCoord: + return "float4"; + case BuiltInSampleId: + return "uint"; + case BuiltInSampleMask: + return "uint"; + case BuiltInSamplePosition: + return "float2"; + case BuiltInViewIndex: + return "uint"; + + case BuiltInHelperInvocation: + return "bool"; + + case BuiltInBaryCoordKHR: + case BuiltInBaryCoordNoPerspKHR: + // Use the type as declared, can be 1, 2 or 3 components. 
+ return type_to_glsl(get_variable_data_type(get(id))); + + // Fragment function out + case BuiltInFragDepth: + return "float"; + + case BuiltInFragStencilRefEXT: + return "uint"; + + // Compute function in + case BuiltInGlobalInvocationId: + case BuiltInLocalInvocationId: + case BuiltInNumWorkgroups: + case BuiltInWorkgroupId: + return "uint3"; + case BuiltInLocalInvocationIndex: + case BuiltInNumSubgroups: + case BuiltInSubgroupId: + case BuiltInSubgroupSize: + case BuiltInSubgroupLocalInvocationId: + return "uint"; + case BuiltInSubgroupEqMask: + case BuiltInSubgroupGeMask: + case BuiltInSubgroupGtMask: + case BuiltInSubgroupLeMask: + case BuiltInSubgroupLtMask: + return "uint4"; + + case BuiltInDeviceIndex: + return "int"; + + default: + return "unsupported-built-in-type"; + } +} + +// Returns the declaration of a built-in argument to a function +string CompilerMSL::built_in_func_arg(BuiltIn builtin, bool prefix_comma) +{ + string bi_arg; + if (prefix_comma) + bi_arg += ", "; + + // Handle HLSL-style 0-based vertex/instance index. + builtin_declaration = true; + bi_arg += builtin_type_decl(builtin); + bi_arg += string(" ") + builtin_to_glsl(builtin, StorageClassInput); + bi_arg += string(" [[") + builtin_qualifier(builtin) + string("]]"); + builtin_declaration = false; + + return bi_arg; +} + +const SPIRType &CompilerMSL::get_physical_member_type(const SPIRType &type, uint32_t index) const +{ + if (member_is_remapped_physical_type(type, index)) + return get(get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID)); + else + return get(type.member_types[index]); +} + +SPIRType CompilerMSL::get_presumed_input_type(const SPIRType &ib_type, uint32_t index) const +{ + SPIRType type = get_physical_member_type(ib_type, index); + uint32_t loc = get_member_decoration(ib_type.self, index, DecorationLocation); + uint32_t cmp = get_member_decoration(ib_type.self, index, DecorationComponent); + auto p_va = inputs_by_location.find({loc, cmp}); + if (p_va != end(inputs_by_location) && p_va->second.vecsize > type.vecsize) + type.vecsize = p_va->second.vecsize; + + return type; +} + +uint32_t CompilerMSL::get_declared_type_array_stride_msl(const SPIRType &type, bool is_packed, bool row_major) const +{ + // Array stride in MSL is always size * array_size. sizeof(float3) == 16, + // unlike GLSL and HLSL where array stride would be 16 and size 12. + + // We could use parent type here and recurse, but that makes creating physical type remappings + // far more complicated. We'd rather just create the final type, and ignore having to create the entire type + // hierarchy in order to compute this value, so make a temporary type on the stack. + + auto basic_type = type; + basic_type.array.clear(); + basic_type.array_size_literal.clear(); + uint32_t value_size = get_declared_type_size_msl(basic_type, is_packed, row_major); + + uint32_t dimensions = uint32_t(type.array.size()); + assert(dimensions > 0); + dimensions--; + + // Multiply together every dimension, except the last one. 
+ for (uint32_t dim = 0; dim < dimensions; dim++) + { + uint32_t array_size = to_array_size_literal(type, dim); + value_size *= max(array_size, 1u); + } + + return value_size; +} + +uint32_t CompilerMSL::get_declared_struct_member_array_stride_msl(const SPIRType &type, uint32_t index) const +{ + return get_declared_type_array_stride_msl(get_physical_member_type(type, index), + member_is_packed_physical_type(type, index), + has_member_decoration(type.self, index, DecorationRowMajor)); +} + +uint32_t CompilerMSL::get_declared_input_array_stride_msl(const SPIRType &type, uint32_t index) const +{ + return get_declared_type_array_stride_msl(get_presumed_input_type(type, index), false, + has_member_decoration(type.self, index, DecorationRowMajor)); +} + +uint32_t CompilerMSL::get_declared_type_matrix_stride_msl(const SPIRType &type, bool packed, bool row_major) const +{ + // For packed matrices, we just use the size of the vector type. + // Otherwise, MatrixStride == alignment, which is the size of the underlying vector type. + if (packed) + return (type.width / 8) * ((row_major && type.columns > 1) ? type.columns : type.vecsize); + else + return get_declared_type_alignment_msl(type, false, row_major); +} + +uint32_t CompilerMSL::get_declared_struct_member_matrix_stride_msl(const SPIRType &type, uint32_t index) const +{ + return get_declared_type_matrix_stride_msl(get_physical_member_type(type, index), + member_is_packed_physical_type(type, index), + has_member_decoration(type.self, index, DecorationRowMajor)); +} + +uint32_t CompilerMSL::get_declared_input_matrix_stride_msl(const SPIRType &type, uint32_t index) const +{ + return get_declared_type_matrix_stride_msl(get_presumed_input_type(type, index), false, + has_member_decoration(type.self, index, DecorationRowMajor)); +} + +uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type, bool ignore_alignment, + bool ignore_padding) const +{ + // If we have a target size, that is the declared size as well. + if (!ignore_padding && has_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget)) + return get_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget); + + if (struct_type.member_types.empty()) + return 0; + + uint32_t mbr_cnt = uint32_t(struct_type.member_types.size()); + + // In MSL, a struct's alignment is equal to the maximum alignment of any of its members. + uint32_t alignment = 1; + + if (!ignore_alignment) + { + for (uint32_t i = 0; i < mbr_cnt; i++) + { + uint32_t mbr_alignment = get_declared_struct_member_alignment_msl(struct_type, i); + alignment = max(alignment, mbr_alignment); + } + } + + // Last member will always be matched to the final Offset decoration, but size of struct in MSL now depends + // on physical size in MSL, and the size of the struct itself is then aligned to struct alignment. + uint32_t spirv_offset = type_struct_member_offset(struct_type, mbr_cnt - 1); + uint32_t msl_size = spirv_offset + get_declared_struct_member_size_msl(struct_type, mbr_cnt - 1); + msl_size = (msl_size + alignment - 1) & ~(alignment - 1); + return msl_size; +} + +uint32_t CompilerMSL::get_physical_type_stride(const SPIRType &type) const +{ + // This should only be relevant for plain types such as scalars and vectors? + // If we're pointing to a struct, it will recursively pick up packed/row-major state. + return get_declared_type_size_msl(type, false, false); +} + +// Returns the byte size of a struct member. 
+uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const +{ + // Pointers take 8 bytes each + // Match both pointer and array-of-pointer here. + if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer) + { + uint32_t type_size = 8; + + // Work our way through potentially layered arrays, + // stopping when we hit a pointer that is not also an array. + int32_t dim_idx = (int32_t)type.array.size() - 1; + auto *p_type = &type; + while (!is_pointer(*p_type) && dim_idx >= 0) + { + type_size *= to_array_size_literal(*p_type, dim_idx); + p_type = &get(p_type->parent_type); + dim_idx--; + } + + return type_size; + } + + switch (type.basetype) + { + case SPIRType::Unknown: + case SPIRType::Void: + case SPIRType::AtomicCounter: + case SPIRType::Image: + case SPIRType::SampledImage: + case SPIRType::Sampler: + SPIRV_CROSS_THROW("Querying size of opaque object."); + + default: + { + if (!type.array.empty()) + { + uint32_t array_size = to_array_size_literal(type); + return get_declared_type_array_stride_msl(type, is_packed, row_major) * max(array_size, 1u); + } + + if (type.basetype == SPIRType::Struct) + return get_declared_struct_size_msl(type); + + if (is_packed) + { + return type.vecsize * type.columns * (type.width / 8); + } + else + { + // An unpacked 3-element vector or matrix column is the same memory size as a 4-element. + uint32_t vecsize = type.vecsize; + uint32_t columns = type.columns; + + if (row_major && columns > 1) + swap(vecsize, columns); + + if (vecsize == 3) + vecsize = 4; + + return vecsize * columns * (type.width / 8); + } + } + } +} + +uint32_t CompilerMSL::get_declared_struct_member_size_msl(const SPIRType &type, uint32_t index) const +{ + return get_declared_type_size_msl(get_physical_member_type(type, index), + member_is_packed_physical_type(type, index), + has_member_decoration(type.self, index, DecorationRowMajor)); +} + +uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t index) const +{ + return get_declared_type_size_msl(get_presumed_input_type(type, index), false, + has_member_decoration(type.self, index, DecorationRowMajor)); +} + +// Returns the byte alignment of a type. +uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const +{ + // Pointers align on multiples of 8 bytes. + // Deliberately ignore array-ness here. It's not relevant for alignment. + if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer) + return 8; + + switch (type.basetype) + { + case SPIRType::Unknown: + case SPIRType::Void: + case SPIRType::AtomicCounter: + case SPIRType::Image: + case SPIRType::SampledImage: + case SPIRType::Sampler: + SPIRV_CROSS_THROW("Querying alignment of opaque object."); + + case SPIRType::Double: + SPIRV_CROSS_THROW("double types are not supported in buffers in MSL."); + + case SPIRType::Struct: + { + // In MSL, a struct's alignment is equal to the maximum alignment of any of its members. 
+ uint32_t alignment = 1; + for (uint32_t i = 0; i < type.member_types.size(); i++) + alignment = max(alignment, uint32_t(get_declared_struct_member_alignment_msl(type, i))); + return alignment; + } + + default: + { + if (type.basetype == SPIRType::Int64 && !msl_options.supports_msl_version(2, 3)) + SPIRV_CROSS_THROW("long types in buffers are only supported in MSL 2.3 and above."); + if (type.basetype == SPIRType::UInt64 && !msl_options.supports_msl_version(2, 3)) + SPIRV_CROSS_THROW("ulong types in buffers are only supported in MSL 2.3 and above."); + // Alignment of packed type is the same as the underlying component or column size. + // Alignment of unpacked type is the same as the vector size. + // Alignment of 3-elements vector is the same as 4-elements (including packed using column). + if (is_packed) + { + // If we have packed_T and friends, the alignment is always scalar. + return type.width / 8; + } + else + { + // This is the general rule for MSL. Size == alignment. + uint32_t vecsize = (row_major && type.columns > 1) ? type.columns : type.vecsize; + return (type.width / 8) * (vecsize == 3 ? 4 : vecsize); + } + } + } +} + +uint32_t CompilerMSL::get_declared_struct_member_alignment_msl(const SPIRType &type, uint32_t index) const +{ + return get_declared_type_alignment_msl(get_physical_member_type(type, index), + member_is_packed_physical_type(type, index), + has_member_decoration(type.self, index, DecorationRowMajor)); +} + +uint32_t CompilerMSL::get_declared_input_alignment_msl(const SPIRType &type, uint32_t index) const +{ + return get_declared_type_alignment_msl(get_presumed_input_type(type, index), false, + has_member_decoration(type.self, index, DecorationRowMajor)); +} + +bool CompilerMSL::skip_argument(uint32_t) const +{ + return false; +} + +void CompilerMSL::analyze_sampled_image_usage() +{ + if (msl_options.swizzle_texture_samples) + { + SampledImageScanner scanner(*this); + traverse_all_reachable_opcodes(get(ir.default_entry_point), scanner); + } +} + +bool CompilerMSL::SampledImageScanner::handle(spv::Op opcode, const uint32_t *args, uint32_t length) +{ + switch (opcode) + { + case OpLoad: + case OpImage: + case OpSampledImage: + { + if (length < 3) + return false; + + uint32_t result_type = args[0]; + auto &type = compiler.get(result_type); + if ((type.basetype != SPIRType::Image && type.basetype != SPIRType::SampledImage) || type.image.sampled != 1) + return true; + + uint32_t id = args[1]; + compiler.set(id, "", result_type, true); + break; + } + case OpImageSampleExplicitLod: + case OpImageSampleProjExplicitLod: + case OpImageSampleDrefExplicitLod: + case OpImageSampleProjDrefExplicitLod: + case OpImageSampleImplicitLod: + case OpImageSampleProjImplicitLod: + case OpImageSampleDrefImplicitLod: + case OpImageSampleProjDrefImplicitLod: + case OpImageFetch: + case OpImageGather: + case OpImageDrefGather: + compiler.has_sampled_images = + compiler.has_sampled_images || compiler.is_sampled_image_type(compiler.expression_type(args[2])); + compiler.needs_swizzle_buffer_def = compiler.needs_swizzle_buffer_def || compiler.has_sampled_images; + break; + default: + break; + } + return true; +} + +// If a needed custom function wasn't added before, add it and force a recompile. 
+void CompilerMSL::add_spv_func_and_recompile(SPVFuncImpl spv_func) +{ + if (spv_function_implementations.count(spv_func) == 0) + { + spv_function_implementations.insert(spv_func); + suppress_missing_prototypes = true; + force_recompile(); + } +} + +bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, uint32_t length) +{ + // Since MSL exists in a single execution scope, function prototype declarations are not + // needed, and clutter the output. If secondary functions are output (either as a SPIR-V + // function implementation or as indicated by the presence of OpFunctionCall), then set + // suppress_missing_prototypes to suppress compiler warnings of missing function prototypes. + + // Mark if the input requires the implementation of an SPIR-V function that does not exist in Metal. + SPVFuncImpl spv_func = get_spv_func_impl(opcode, args); + if (spv_func != SPVFuncImplNone) + { + compiler.spv_function_implementations.insert(spv_func); + suppress_missing_prototypes = true; + } + + switch (opcode) + { + + case OpFunctionCall: + suppress_missing_prototypes = true; + break; + + case OpDemoteToHelperInvocationEXT: + uses_discard = true; + break; + + // Emulate texture2D atomic operations + case OpImageTexelPointer: + { + if (!compiler.msl_options.supports_msl_version(3, 1)) + { + auto *var = compiler.maybe_get_backing_variable(args[2]); + image_pointers_emulated[args[1]] = var ? var->self : ID(0); + } + break; + } + + case OpImageWrite: + uses_image_write = true; + break; + + case OpStore: + check_resource_write(args[0]); + break; + + // Emulate texture2D atomic operations + case OpAtomicExchange: + case OpAtomicCompareExchange: + case OpAtomicCompareExchangeWeak: + case OpAtomicIIncrement: + case OpAtomicIDecrement: + case OpAtomicIAdd: + case OpAtomicFAddEXT: + case OpAtomicISub: + case OpAtomicSMin: + case OpAtomicUMin: + case OpAtomicSMax: + case OpAtomicUMax: + case OpAtomicAnd: + case OpAtomicOr: + case OpAtomicXor: + { + uses_atomics = true; + auto it = image_pointers_emulated.find(args[2]); + if (it != image_pointers_emulated.end()) + { + uses_image_write = true; + compiler.atomic_image_vars_emulated.insert(it->second); + } + else + check_resource_write(args[2]); + break; + } + + case OpAtomicStore: + { + uses_atomics = true; + auto it = image_pointers_emulated.find(args[0]); + if (it != image_pointers_emulated.end()) + { + compiler.atomic_image_vars_emulated.insert(it->second); + uses_image_write = true; + } + else + check_resource_write(args[0]); + break; + } + + case OpAtomicLoad: + { + uses_atomics = true; + auto it = image_pointers_emulated.find(args[2]); + if (it != image_pointers_emulated.end()) + { + compiler.atomic_image_vars_emulated.insert(it->second); + } + break; + } + + case OpGroupNonUniformInverseBallot: + needs_subgroup_invocation_id = true; + break; + + case OpGroupNonUniformBallotFindLSB: + case OpGroupNonUniformBallotFindMSB: + needs_subgroup_size = true; + break; + + case OpGroupNonUniformBallotBitCount: + if (args[3] == GroupOperationReduce) + needs_subgroup_size = true; + else + needs_subgroup_invocation_id = true; + break; + + case OpArrayLength: + { + auto *var = compiler.maybe_get_backing_variable(args[2]); + if (var != nullptr) + { + if (!compiler.is_var_runtime_size_array(*var)) + compiler.buffers_requiring_array_length.insert(var->self); + } + break; + } + + case OpInBoundsAccessChain: + case OpAccessChain: + case OpPtrAccessChain: + { + // OpArrayLength might want to know if taking ArrayLength of an array of SSBOs. 
+ uint32_t result_type = args[0]; + uint32_t id = args[1]; + uint32_t ptr = args[2]; + + compiler.set(id, "", result_type, true); + compiler.register_read(id, ptr, true); + compiler.ir.ids[id].set_allow_type_rewrite(); + break; + } + + case OpExtInst: + { + uint32_t extension_set = args[2]; + if (compiler.get(extension_set).ext == SPIRExtension::GLSL) + { + auto op_450 = static_cast(args[3]); + switch (op_450) + { + case GLSLstd450InterpolateAtCentroid: + case GLSLstd450InterpolateAtSample: + case GLSLstd450InterpolateAtOffset: + { + if (!compiler.msl_options.supports_msl_version(2, 3)) + SPIRV_CROSS_THROW("Pull-model interpolation requires MSL 2.3."); + // Fragment varyings used with pull-model interpolation need special handling, + // due to the way pull-model interpolation works in Metal. + auto *var = compiler.maybe_get_backing_variable(args[4]); + if (var) + { + compiler.pull_model_inputs.insert(var->self); + auto &var_type = compiler.get_variable_element_type(*var); + // In addition, if this variable has a 'Sample' decoration, we need the sample ID + // in order to do default interpolation. + if (compiler.has_decoration(var->self, DecorationSample)) + { + needs_sample_id = true; + } + else if (var_type.basetype == SPIRType::Struct) + { + // Now we need to check each member and see if it has this decoration. + for (uint32_t i = 0; i < var_type.member_types.size(); ++i) + { + if (compiler.has_member_decoration(var_type.self, i, DecorationSample)) + { + needs_sample_id = true; + break; + } + } + } + } + break; + } + default: + break; + } + } + break; + } + + case OpIsHelperInvocationEXT: + if (compiler.needs_manual_helper_invocation_updates()) + needs_helper_invocation = true; + break; + + default: + break; + } + + // If it has one, keep track of the instruction's result type, mapped by ID + uint32_t result_type, result_id; + if (compiler.instruction_to_result_type(result_type, result_id, opcode, args, length)) + result_types[result_id] = result_type; + + return true; +} + +// If the variable is a Uniform or StorageBuffer, mark that a resource has been written to. +void CompilerMSL::OpCodePreprocessor::check_resource_write(uint32_t var_id) +{ + auto *p_var = compiler.maybe_get_backing_variable(var_id); + StorageClass sc = p_var ? p_var->storage : StorageClassMax; + if (sc == StorageClassUniform || sc == StorageClassStorageBuffer) + uses_buffer_write = true; +} + +// Returns an enumeration of a SPIR-V function that needs to be output for certain Op codes. +CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op opcode, const uint32_t *args) +{ + switch (opcode) + { + case OpFMod: + return SPVFuncImplMod; + + case OpFAdd: + case OpFSub: + if (compiler.msl_options.invariant_float_math || + compiler.has_decoration(args[1], DecorationNoContraction)) + { + return opcode == OpFAdd ? 
SPVFuncImplFAdd : SPVFuncImplFSub; + } + break; + + case OpFMul: + case OpOuterProduct: + case OpMatrixTimesVector: + case OpVectorTimesMatrix: + case OpMatrixTimesMatrix: + if (compiler.msl_options.invariant_float_math || + compiler.has_decoration(args[1], DecorationNoContraction)) + { + return SPVFuncImplFMul; + } + break; + + case OpQuantizeToF16: + return SPVFuncImplQuantizeToF16; + + case OpTypeArray: + { + // Allow Metal to use the array template to make arrays a value type + return SPVFuncImplUnsafeArray; + } + + // Emulate texture2D atomic operations + case OpAtomicExchange: + case OpAtomicCompareExchange: + case OpAtomicCompareExchangeWeak: + case OpAtomicIIncrement: + case OpAtomicIDecrement: + case OpAtomicIAdd: + case OpAtomicFAddEXT: + case OpAtomicISub: + case OpAtomicSMin: + case OpAtomicUMin: + case OpAtomicSMax: + case OpAtomicUMax: + case OpAtomicAnd: + case OpAtomicOr: + case OpAtomicXor: + case OpAtomicLoad: + case OpAtomicStore: + { + auto it = image_pointers_emulated.find(args[opcode == OpAtomicStore ? 0 : 2]); + if (it != image_pointers_emulated.end()) + { + uint32_t tid = compiler.get(it->second).basetype; + if (tid && compiler.get(tid).image.dim == Dim2D) + return SPVFuncImplImage2DAtomicCoords; + } + break; + } + + case OpImageFetch: + case OpImageRead: + case OpImageWrite: + { + // Retrieve the image type, and if it's a Buffer, emit a texel coordinate function + uint32_t tid = result_types[args[opcode == OpImageWrite ? 0 : 2]]; + if (tid && compiler.get(tid).image.dim == DimBuffer && !compiler.msl_options.texture_buffer_native) + return SPVFuncImplTexelBufferCoords; + break; + } + + case OpExtInst: + { + uint32_t extension_set = args[2]; + if (compiler.get(extension_set).ext == SPIRExtension::GLSL) + { + auto op_450 = static_cast(args[3]); + switch (op_450) + { + case GLSLstd450Radians: + return SPVFuncImplRadians; + case GLSLstd450Degrees: + return SPVFuncImplDegrees; + case GLSLstd450FindILsb: + return SPVFuncImplFindILsb; + case GLSLstd450FindSMsb: + return SPVFuncImplFindSMsb; + case GLSLstd450FindUMsb: + return SPVFuncImplFindUMsb; + case GLSLstd450SSign: + return SPVFuncImplSSign; + case GLSLstd450Reflect: + { + auto &type = compiler.get(args[0]); + if (type.vecsize == 1) + return SPVFuncImplReflectScalar; + break; + } + case GLSLstd450Refract: + { + auto &type = compiler.get(args[0]); + if (type.vecsize == 1) + return SPVFuncImplRefractScalar; + break; + } + case GLSLstd450FaceForward: + { + auto &type = compiler.get(args[0]); + if (type.vecsize == 1) + return SPVFuncImplFaceForwardScalar; + break; + } + case GLSLstd450MatrixInverse: + { + auto &mat_type = compiler.get(args[0]); + switch (mat_type.columns) + { + case 2: + return SPVFuncImplInverse2x2; + case 3: + return SPVFuncImplInverse3x3; + case 4: + return SPVFuncImplInverse4x4; + default: + break; + } + break; + } + default: + break; + } + } + break; + } + + case OpGroupNonUniformBroadcast: + case OpSubgroupReadInvocationKHR: + return SPVFuncImplSubgroupBroadcast; + + case OpGroupNonUniformBroadcastFirst: + case OpSubgroupFirstInvocationKHR: + return SPVFuncImplSubgroupBroadcastFirst; + + case OpGroupNonUniformBallot: + case OpSubgroupBallotKHR: + return SPVFuncImplSubgroupBallot; + + case OpGroupNonUniformInverseBallot: + case OpGroupNonUniformBallotBitExtract: + return SPVFuncImplSubgroupBallotBitExtract; + + case OpGroupNonUniformBallotFindLSB: + return SPVFuncImplSubgroupBallotFindLSB; + + case OpGroupNonUniformBallotFindMSB: + return SPVFuncImplSubgroupBallotFindMSB; + + case 
OpGroupNonUniformBallotBitCount: + return SPVFuncImplSubgroupBallotBitCount; + + case OpGroupNonUniformAllEqual: + case OpSubgroupAllEqualKHR: + return SPVFuncImplSubgroupAllEqual; + + case OpGroupNonUniformShuffle: + return SPVFuncImplSubgroupShuffle; + + case OpGroupNonUniformShuffleXor: + return SPVFuncImplSubgroupShuffleXor; + + case OpGroupNonUniformShuffleUp: + return SPVFuncImplSubgroupShuffleUp; + + case OpGroupNonUniformShuffleDown: + return SPVFuncImplSubgroupShuffleDown; + + case OpGroupNonUniformQuadBroadcast: + return SPVFuncImplQuadBroadcast; + + case OpGroupNonUniformQuadSwap: + return SPVFuncImplQuadSwap; + + case OpSDot: + case OpUDot: + case OpSUDot: + case OpSDotAccSat: + case OpUDotAccSat: + case OpSUDotAccSat: + return SPVFuncImplReduceAdd; + + default: + break; + } + return SPVFuncImplNone; +} + +// Sort both type and meta member content based on builtin status (put builtins at end), +// then by the required sorting aspect. +void CompilerMSL::MemberSorter::sort() +{ + // Create a temporary array of consecutive member indices and sort it based on how + // the members should be reordered, based on builtin and sorting aspect meta info. + size_t mbr_cnt = type.member_types.size(); + SmallVector mbr_idxs(mbr_cnt); + std::iota(mbr_idxs.begin(), mbr_idxs.end(), 0); // Fill with consecutive indices + std::stable_sort(mbr_idxs.begin(), mbr_idxs.end(), *this); // Sort member indices based on sorting aspect + + bool sort_is_identity = true; + for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++) + { + if (mbr_idx != mbr_idxs[mbr_idx]) + { + sort_is_identity = false; + break; + } + } + + if (sort_is_identity) + return; + + if (meta.members.size() < type.member_types.size()) + { + // This should never trigger in normal circumstances, but to be safe. + meta.members.resize(type.member_types.size()); + } + + // Move type and meta member info to the order defined by the sorted member indices. + // This is done by creating temporary copies of both member types and meta, and then + // copying back to the original content at the sorted indices. + auto mbr_types_cpy = type.member_types; + auto mbr_meta_cpy = meta.members; + for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++) + { + type.member_types[mbr_idx] = mbr_types_cpy[mbr_idxs[mbr_idx]]; + meta.members[mbr_idx] = mbr_meta_cpy[mbr_idxs[mbr_idx]]; + } + + // If we're sorting by Offset, this might affect user code which accesses a buffer block. + // We will need to redirect member indices from defined index to sorted index using reverse lookup. + if (sort_aspect == SortAspect::Offset) + { + type.member_type_index_redirection.resize(mbr_cnt); + for (uint32_t map_idx = 0; map_idx < mbr_cnt; map_idx++) + type.member_type_index_redirection[mbr_idxs[map_idx]] = map_idx; + } +} + +bool CompilerMSL::MemberSorter::operator()(uint32_t mbr_idx1, uint32_t mbr_idx2) +{ + auto &mbr_meta1 = meta.members[mbr_idx1]; + auto &mbr_meta2 = meta.members[mbr_idx2]; + + if (sort_aspect == LocationThenBuiltInType) + { + // Sort first by builtin status (put builtins at end), then by the sorting aspect. 
+ if (mbr_meta1.builtin != mbr_meta2.builtin) + return mbr_meta2.builtin; + else if (mbr_meta1.builtin) + return mbr_meta1.builtin_type < mbr_meta2.builtin_type; + else if (mbr_meta1.location == mbr_meta2.location) + return mbr_meta1.component < mbr_meta2.component; + else + return mbr_meta1.location < mbr_meta2.location; + } + else + return mbr_meta1.offset < mbr_meta2.offset; +} + +CompilerMSL::MemberSorter::MemberSorter(SPIRType &t, Meta &m, SortAspect sa) + : type(t) + , meta(m) + , sort_aspect(sa) +{ + // Ensure enough meta info is available + meta.members.resize(max(type.member_types.size(), meta.members.size())); +} + +void CompilerMSL::remap_constexpr_sampler(VariableID id, const MSLConstexprSampler &sampler) +{ + auto &type = get(get(id).basetype); + if (type.basetype != SPIRType::SampledImage && type.basetype != SPIRType::Sampler) + SPIRV_CROSS_THROW("Can only remap SampledImage and Sampler type."); + if (!type.array.empty()) + SPIRV_CROSS_THROW("Can not remap array of samplers."); + constexpr_samplers_by_id[id] = sampler; +} + +void CompilerMSL::remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t binding, + const MSLConstexprSampler &sampler) +{ + constexpr_samplers_by_binding[{ desc_set, binding }] = sampler; +} + +void CompilerMSL::cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type) +{ + bool is_packed = has_extended_decoration(source_id, SPIRVCrossDecorationPhysicalTypePacked); + auto *source_expr = maybe_get(source_id); + auto *var = maybe_get_backing_variable(source_id); + const SPIRType *var_type = nullptr, *phys_type = nullptr; + + if (uint32_t phys_id = get_extended_decoration(source_id, SPIRVCrossDecorationPhysicalTypeID)) + phys_type = &get(phys_id); + else + phys_type = &expr_type; + + if (var) + { + source_id = var->self; + var_type = &get_variable_data_type(*var); + } + + bool rewrite_boolean_load = + expr_type.basetype == SPIRType::Boolean && + (var && (var->storage == StorageClassWorkgroup || var_type->basetype == SPIRType::Struct)); + + // Type fixups for workgroup variables if they are booleans. + if (rewrite_boolean_load) + { + if (is_array(expr_type)) + expr = to_rerolled_array_expression(expr_type, expr, expr_type); + else + expr = join(type_to_glsl(expr_type), "(", expr, ")"); + } + + // Type fixups for workgroup variables if they are matrices. + // Don't do fixup for packed types; those are handled specially. + // FIXME: Maybe use a type like spvStorageMatrix for packed matrices? + if (!msl_options.supports_msl_version(3, 0) && var && + (var->storage == StorageClassWorkgroup || + (var_type->basetype == SPIRType::Struct && + has_extended_decoration(var_type->self, SPIRVCrossDecorationWorkgroupStruct) && !is_packed)) && + expr_type.columns > 1) + { + SPIRType matrix_type = *phys_type; + if (source_expr && source_expr->need_transpose) + swap(matrix_type.vecsize, matrix_type.columns); + matrix_type.array.clear(); + matrix_type.array_size_literal.clear(); + expr = join(type_to_glsl(matrix_type), "(", expr, ")"); + } + + // Only interested in standalone builtin variables in the switch below. + if (!has_decoration(source_id, DecorationBuiltIn)) + { + // If the backing variable does not match our expected sign, we can fix it up here. + // See ensure_correct_input_type(). 
+ if (var && var->storage == StorageClassInput) + { + auto &base_type = get(var->basetype); + if (base_type.basetype != SPIRType::Struct && expr_type.basetype != base_type.basetype) + expr = join(type_to_glsl(expr_type), "(", expr, ")"); + } + return; + } + + auto builtin = static_cast(get_decoration(source_id, DecorationBuiltIn)); + auto expected_type = expr_type.basetype; + auto expected_width = expr_type.width; + switch (builtin) + { + case BuiltInGlobalInvocationId: + case BuiltInLocalInvocationId: + case BuiltInWorkgroupId: + case BuiltInLocalInvocationIndex: + case BuiltInWorkgroupSize: + case BuiltInNumWorkgroups: + case BuiltInLayer: + case BuiltInViewportIndex: + case BuiltInFragStencilRefEXT: + case BuiltInPrimitiveId: + case BuiltInSubgroupSize: + case BuiltInSubgroupLocalInvocationId: + case BuiltInViewIndex: + case BuiltInVertexIndex: + case BuiltInInstanceIndex: + case BuiltInBaseInstance: + case BuiltInBaseVertex: + case BuiltInSampleMask: + expected_type = SPIRType::UInt; + expected_width = 32; + break; + + case BuiltInTessLevelInner: + case BuiltInTessLevelOuter: + if (is_tesc_shader()) + { + expected_type = SPIRType::Half; + expected_width = 16; + } + break; + + default: + break; + } + + if (is_array(expr_type) && builtin == BuiltInSampleMask) + { + // Needs special handling. + auto wrap_expr = join(type_to_glsl(expr_type), "({ "); + wrap_expr += join(type_to_glsl(get(expr_type.parent_type)), "(", expr, ")"); + wrap_expr += " })"; + expr = std::move(wrap_expr); + } + else if (expected_type != expr_type.basetype) + { + if (is_array(expr_type) && (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter)) + { + // Triggers when loading TessLevel directly as an array. + // Need explicit padding + cast. + auto wrap_expr = join(type_to_glsl(expr_type), "({ "); + + uint32_t array_size = get_physical_tess_level_array_size(builtin); + for (uint32_t i = 0; i < array_size; i++) + { + if (array_size > 1) + wrap_expr += join("float(", expr, "[", i, "])"); + else + wrap_expr += join("float(", expr, ")"); + if (i + 1 < array_size) + wrap_expr += ", "; + } + + if (is_tessellating_triangles()) + wrap_expr += ", 0.0"; + + wrap_expr += " })"; + expr = std::move(wrap_expr); + } + else + { + // These are of different widths, so we cannot do a straight bitcast. + if (expected_width != expr_type.width) + expr = join(type_to_glsl(expr_type), "(", expr, ")"); + else + expr = bitcast_expression(expr_type, expected_type, expr); + } + } +} + +void CompilerMSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) +{ + bool is_packed = has_extended_decoration(target_id, SPIRVCrossDecorationPhysicalTypePacked); + auto *target_expr = maybe_get(target_id); + auto *var = maybe_get_backing_variable(target_id); + const SPIRType *var_type = nullptr, *phys_type = nullptr; + + if (uint32_t phys_id = get_extended_decoration(target_id, SPIRVCrossDecorationPhysicalTypeID)) + phys_type = &get(phys_id); + else + phys_type = &expr_type; + + if (var) + { + target_id = var->self; + var_type = &get_variable_data_type(*var); + } + + bool rewrite_boolean_store = + expr_type.basetype == SPIRType::Boolean && + (var && (var->storage == StorageClassWorkgroup || var_type->basetype == SPIRType::Struct)); + + // Type fixups for workgroup variables or struct members if they are booleans. 
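+	// Such booleans are backed by 'short' in the generated MSL, so the value being stored must be converted.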
+ if (rewrite_boolean_store) + { + if (is_array(expr_type)) + { + expr = to_rerolled_array_expression(*var_type, expr, expr_type); + } + else + { + auto short_type = expr_type; + short_type.basetype = SPIRType::Short; + expr = join(type_to_glsl(short_type), "(", expr, ")"); + } + } + + // Type fixups for workgroup variables if they are matrices. + // Don't do fixup for packed types; those are handled specially. + // FIXME: Maybe use a type like spvStorageMatrix for packed matrices? + if (!msl_options.supports_msl_version(3, 0) && var && + (var->storage == StorageClassWorkgroup || + (var_type->basetype == SPIRType::Struct && + has_extended_decoration(var_type->self, SPIRVCrossDecorationWorkgroupStruct) && !is_packed)) && + expr_type.columns > 1) + { + SPIRType matrix_type = *phys_type; + if (target_expr && target_expr->need_transpose) + swap(matrix_type.vecsize, matrix_type.columns); + expr = join("spvStorage_", type_to_glsl(matrix_type), "(", expr, ")"); + } + + // Only interested in standalone builtin variables. + if (!has_decoration(target_id, DecorationBuiltIn)) + return; + + auto builtin = static_cast(get_decoration(target_id, DecorationBuiltIn)); + auto expected_type = expr_type.basetype; + auto expected_width = expr_type.width; + switch (builtin) + { + case BuiltInLayer: + case BuiltInViewportIndex: + case BuiltInFragStencilRefEXT: + case BuiltInPrimitiveId: + case BuiltInViewIndex: + expected_type = SPIRType::UInt; + expected_width = 32; + break; + + case BuiltInTessLevelInner: + case BuiltInTessLevelOuter: + expected_type = SPIRType::Half; + expected_width = 16; + break; + + default: + break; + } + + if (expected_type != expr_type.basetype) + { + if (expected_width != expr_type.width) + { + // These are of different widths, so we cannot do a straight bitcast. + auto type = expr_type; + type.basetype = expected_type; + type.width = expected_width; + expr = join(type_to_glsl(type), "(", expr, ")"); + } + else + { + auto type = expr_type; + type.basetype = expected_type; + expr = bitcast_expression(type, expr_type.basetype, expr); + } + } +} + +string CompilerMSL::to_initializer_expression(const SPIRVariable &var) +{ + // We risk getting an array initializer here with MSL. If we have an array. + // FIXME: We cannot handle non-constant arrays being initialized. + // We will need to inject spvArrayCopy here somehow ... + auto &type = get(var.basetype); + string expr; + if (ir.ids[var.initializer].get_type() == TypeConstant && + (!type.array.empty() || type.basetype == SPIRType::Struct)) + expr = constant_expression(get(var.initializer)); + else + expr = CompilerGLSL::to_initializer_expression(var); + // If the initializer has more vector components than the variable, add a swizzle. + // FIXME: This can't handle arrays or structs. + auto &init_type = expression_type(var.initializer); + if (type.array.empty() && type.basetype != SPIRType::Struct && init_type.vecsize > type.vecsize) + expr = enclose_expression(expr + vector_swizzle(type.vecsize, 0)); + return expr; +} + +string CompilerMSL::to_zero_initialized_expression(uint32_t) +{ + return "{}"; +} + +bool CompilerMSL::descriptor_set_is_argument_buffer(uint32_t desc_set) const +{ + if (!msl_options.argument_buffers) + return false; + if (desc_set >= kMaxArgumentBuffers) + return false; + + return (argument_buffer_discrete_mask & (1u << desc_set)) == 0; +} + +bool CompilerMSL::is_supported_argument_buffer_type(const SPIRType &type) const +{ + // iOS Tier 1 argument buffers do not support writable images. 
+ // When the argument buffer is encoded, we don't know whether this image will have a + // NonWritable decoration, so just use discrete arguments for all storage images on iOS. + bool is_supported_type = !(type.basetype == SPIRType::Image && + type.image.sampled == 2 && + msl_options.is_ios() && + msl_options.argument_buffers_tier <= Options::ArgumentBuffersTier::Tier1); + return is_supported_type && !type_is_msl_framebuffer_fetch(type); +} + +void CompilerMSL::emit_argument_buffer_aliased_descriptor(const SPIRVariable &aliased_var, + const SPIRVariable &base_var) +{ + // To deal with buffer <-> image aliasing, we need to perform an unholy UB ritual. + // A texture type in Metal 3.0 is a pointer. However, we cannot simply cast a pointer to texture. + // What we *can* do is to cast pointer-to-pointer to pointer-to-texture. + + // We need to explicitly reach into the descriptor buffer lvalue, not any spvDescriptorArray wrapper. + auto *var_meta = ir.find_meta(base_var.self); + bool old_explicit_qualifier = var_meta && var_meta->decoration.qualified_alias_explicit_override; + if (var_meta) + var_meta->decoration.qualified_alias_explicit_override = false; + auto unqualified_name = to_name(base_var.self, false); + if (var_meta) + var_meta->decoration.qualified_alias_explicit_override = old_explicit_qualifier; + + // For non-arrayed buffers, we have already performed a de-reference. + // We need a proper lvalue to cast, so strip away the de-reference. + if (unqualified_name.size() > 2 && unqualified_name[0] == '(' && unqualified_name[1] == '*') + { + unqualified_name.erase(unqualified_name.begin(), unqualified_name.begin() + 2); + unqualified_name.pop_back(); + } + + string name; + + auto &var_type = get(aliased_var.basetype); + auto &data_type = get_variable_data_type(aliased_var); + string descriptor_storage = descriptor_address_space(aliased_var.self, aliased_var.storage, ""); + + if (aliased_var.storage == StorageClassUniformConstant) + { + if (is_var_runtime_size_array(aliased_var)) + { + // This becomes a plain pointer to spvDescriptor. + name = join("reinterpret_cast<", descriptor_storage, " ", + type_to_glsl(get_variable_data_type(aliased_var), aliased_var.self, true), ">(&", + unqualified_name, ")"); + } + else + { + name = join("reinterpret_cast<", descriptor_storage, " ", + type_to_glsl(get_variable_data_type(aliased_var), aliased_var.self, true), " &>(", + unqualified_name, ");"); + } + } + else + { + // Buffer types. + bool old_is_using_builtin_array = is_using_builtin_array; + is_using_builtin_array = true; + + bool needs_post_cast_deref = !is_array(data_type); + string ref_type = needs_post_cast_deref ? "&" : join("(&)", type_to_array_glsl(var_type, aliased_var.self)); + + if (is_var_runtime_size_array(aliased_var)) + { + name = join("reinterpret_cast<", + type_to_glsl(var_type, aliased_var.self, true), " ", descriptor_storage, " *>(&", + unqualified_name, ")"); + } + else + { + name = join(needs_post_cast_deref ? "*" : "", "reinterpret_cast<", + type_to_glsl(var_type, aliased_var.self, true), " ", descriptor_storage, " ", + ref_type, + ">(", unqualified_name, ");"); + } + + if (needs_post_cast_deref) + descriptor_storage = get_type_address_space(var_type, aliased_var.self, false); + + // These kinds of ridiculous casts trigger warnings in compiler. Just ignore them. 
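+		// Toggling this flag requires another compile pass so the suppression can be applied on the next iteration.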
+ if (!suppress_incompatible_pointer_types_discard_qualifiers) + { + suppress_incompatible_pointer_types_discard_qualifiers = true; + force_recompile_guarantee_forward_progress(); + } + + is_using_builtin_array = old_is_using_builtin_array; + } + + if (!is_var_runtime_size_array(aliased_var)) + { + // Lower to temporary, so drop the qualification. + set_qualified_name(aliased_var.self, ""); + statement(descriptor_storage, " auto &", to_name(aliased_var.self), " = ", name); + } + else + { + // This alias may have already been used to emit an entry point declaration. If there is a mismatch, we need a recompile. + // Moving this code to be run earlier will also conflict, + // because we need the qualified alias for the base resource, + // so forcing recompile until things sync up is the least invasive method for now. + if (ir.meta[aliased_var.self].decoration.qualified_alias != name) + force_recompile(); + + // This will get wrapped in a separate temporary when a spvDescriptorArray wrapper is emitted. + set_qualified_name(aliased_var.self, name); + } +} + +void CompilerMSL::analyze_argument_buffers() +{ + // Gather all used resources and sort them out into argument buffers. + // Each argument buffer corresponds to a descriptor set in SPIR-V. + // The [[id(N)]] values used correspond to the resource mapping we have for MSL. + // Otherwise, the binding number is used, but this is generally not safe some types like + // combined image samplers and arrays of resources. Metal needs different indices here, + // while SPIR-V can have one descriptor set binding. To use argument buffers in practice, + // you will need to use the remapping from the API. + for (auto &id : argument_buffer_ids) + id = 0; + + // Output resources, sorted by resource index & type. + struct Resource + { + SPIRVariable *var; + string name; + SPIRType::BaseType basetype; + uint32_t index; + uint32_t plane_count; + uint32_t plane; + uint32_t overlapping_var_id; + }; + SmallVector resources_in_set[kMaxArgumentBuffers]; + SmallVector inline_block_vars; + + bool set_needs_swizzle_buffer[kMaxArgumentBuffers] = {}; + bool set_needs_buffer_sizes[kMaxArgumentBuffers] = {}; + bool needs_buffer_sizes = false; + + ir.for_each_typed_id([&](uint32_t self, SPIRVariable &var) { + if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant || + var.storage == StorageClassStorageBuffer) && + !is_hidden_variable(var)) + { + uint32_t desc_set = get_decoration(self, DecorationDescriptorSet); + // Ignore if it's part of a push descriptor set. + if (!descriptor_set_is_argument_buffer(desc_set)) + return; + + uint32_t var_id = var.self; + auto &type = get_variable_data_type(var); + + if (desc_set >= kMaxArgumentBuffers) + SPIRV_CROSS_THROW("Descriptor set index is out of range."); + + const MSLConstexprSampler *constexpr_sampler = nullptr; + if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler) + { + constexpr_sampler = find_constexpr_sampler(var_id); + if (constexpr_sampler) + { + // Mark this ID as a constexpr sampler for later in case it came from set/bindings. 
+ constexpr_samplers_by_id[var_id] = *constexpr_sampler; + } + } + + uint32_t binding = get_decoration(var_id, DecorationBinding); + if (type.basetype == SPIRType::SampledImage) + { + add_resource_name(var_id); + + uint32_t plane_count = 1; + if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable) + plane_count = constexpr_sampler->planes; + + for (uint32_t i = 0; i < plane_count; i++) + { + uint32_t image_resource_index = get_metal_resource_index(var, SPIRType::Image, i); + resources_in_set[desc_set].push_back( + { &var, to_name(var_id), SPIRType::Image, image_resource_index, plane_count, i, 0 }); + } + + if (type.image.dim != DimBuffer && !constexpr_sampler) + { + uint32_t sampler_resource_index = get_metal_resource_index(var, SPIRType::Sampler); + resources_in_set[desc_set].push_back( + { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 1, 0, 0 }); + } + } + else if (inline_uniform_blocks.count(SetBindingPair{ desc_set, binding })) + { + inline_block_vars.push_back(var_id); + } + else if (!constexpr_sampler && is_supported_argument_buffer_type(type)) + { + // constexpr samplers are not declared as resources. + // Inline uniform blocks are always emitted at the end. + add_resource_name(var_id); + + uint32_t resource_index = get_metal_resource_index(var, type.basetype); + + resources_in_set[desc_set].push_back( + { &var, to_name(var_id), type.basetype, resource_index, 1, 0, 0 }); + + // Emulate texture2D atomic operations + if (atomic_image_vars_emulated.count(var.self)) + { + uint32_t buffer_resource_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0); + resources_in_set[desc_set].push_back( + { &var, to_name(var_id) + "_atomic", SPIRType::Struct, buffer_resource_index, 1, 0, 0 }); + } + } + + // Check if this descriptor set needs a swizzle buffer. + if (needs_swizzle_buffer_def && is_sampled_image_type(type)) + set_needs_swizzle_buffer[desc_set] = true; + else if (buffer_requires_array_length(var_id)) + { + set_needs_buffer_sizes[desc_set] = true; + needs_buffer_sizes = true; + } + } + }); + + if (needs_swizzle_buffer_def || needs_buffer_sizes) + { + uint32_t uint_ptr_type_id = 0; + + // We might have to add a swizzle buffer resource to the set. + for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++) + { + if (!set_needs_swizzle_buffer[desc_set] && !set_needs_buffer_sizes[desc_set]) + continue; + + if (uint_ptr_type_id == 0) + { + uint_ptr_type_id = ir.increase_bound_by(1); + + // Create a buffer to hold extra data, including the swizzle constants. 
+ SPIRType uint_type_pointer = get_uint_type(); + uint_type_pointer.op = OpTypePointer; + uint_type_pointer.pointer = true; + uint_type_pointer.pointer_depth++; + uint_type_pointer.parent_type = get_uint_type_id(); + uint_type_pointer.storage = StorageClassUniform; + set(uint_ptr_type_id, uint_type_pointer); + set_decoration(uint_ptr_type_id, DecorationArrayStride, 4); + } + + if (set_needs_swizzle_buffer[desc_set]) + { + uint32_t var_id = ir.increase_bound_by(1); + auto &var = set(var_id, uint_ptr_type_id, StorageClassUniformConstant); + set_name(var_id, "spvSwizzleConstants"); + set_decoration(var_id, DecorationDescriptorSet, desc_set); + set_decoration(var_id, DecorationBinding, kSwizzleBufferBinding); + resources_in_set[desc_set].push_back( + { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 1, 0, 0 }); + } + + if (set_needs_buffer_sizes[desc_set]) + { + uint32_t var_id = ir.increase_bound_by(1); + auto &var = set(var_id, uint_ptr_type_id, StorageClassUniformConstant); + set_name(var_id, "spvBufferSizeConstants"); + set_decoration(var_id, DecorationDescriptorSet, desc_set); + set_decoration(var_id, DecorationBinding, kBufferSizeBufferBinding); + resources_in_set[desc_set].push_back( + { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 1, 0, 0 }); + } + } + } + + // Now add inline uniform blocks. + for (uint32_t var_id : inline_block_vars) + { + auto &var = get(var_id); + uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet); + add_resource_name(var_id); + resources_in_set[desc_set].push_back( + { &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 1, 0, 0 }); + } + + for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++) + { + auto &resources = resources_in_set[desc_set]; + if (resources.empty()) + continue; + + assert(descriptor_set_is_argument_buffer(desc_set)); + + uint32_t next_id = ir.increase_bound_by(3); + uint32_t type_id = next_id + 1; + uint32_t ptr_type_id = next_id + 2; + argument_buffer_ids[desc_set] = next_id; + + auto &buffer_type = set(type_id, OpTypeStruct); + + buffer_type.basetype = SPIRType::Struct; + + if ((argument_buffer_device_storage_mask & (1u << desc_set)) != 0) + { + buffer_type.storage = StorageClassStorageBuffer; + // Make sure the argument buffer gets marked as const device. + set_decoration(next_id, DecorationNonWritable); + // Need to mark the type as a Block to enable this. + set_decoration(type_id, DecorationBlock); + } + else + buffer_type.storage = StorageClassUniform; + + auto buffer_type_name = join("spvDescriptorSetBuffer", desc_set); + set_name(type_id, buffer_type_name); + + auto &ptr_type = set(ptr_type_id, OpTypePointer); + ptr_type = buffer_type; + ptr_type.op = spv::OpTypePointer; + ptr_type.pointer = true; + ptr_type.pointer_depth++; + ptr_type.parent_type = type_id; + + uint32_t buffer_variable_id = next_id; + auto &buffer_var = set(buffer_variable_id, ptr_type_id, StorageClassUniform); + auto buffer_name = join("spvDescriptorSet", desc_set); + set_name(buffer_variable_id, buffer_name); + + // Ids must be emitted in ID order. 
+ stable_sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool { + return tie(lhs.index, lhs.basetype) < tie(rhs.index, rhs.basetype); + }); + + for (size_t i = 0; i < resources.size() - 1; i++) + { + auto &r1 = resources[i]; + auto &r2 = resources[i + 1]; + + if (r1.index == r2.index) + { + if (r1.overlapping_var_id) + r2.overlapping_var_id = r1.overlapping_var_id; + else + r2.overlapping_var_id = r1.var->self; + + set_extended_decoration(r2.var->self, SPIRVCrossDecorationOverlappingBinding, r2.overlapping_var_id); + } + } + + uint32_t member_index = 0; + uint32_t next_arg_buff_index = 0; + for (auto &resource : resources) + { + auto &var = *resource.var; + auto &type = get_variable_data_type(var); + + if (is_var_runtime_size_array(var) && (argument_buffer_device_storage_mask & (1u << desc_set)) == 0) + SPIRV_CROSS_THROW("Runtime sized variables must be in device storage argument buffers."); + + // If needed, synthesize and add padding members. + // member_index and next_arg_buff_index are incremented when padding members are added. + if (msl_options.pad_argument_buffer_resources && resource.plane == 0 && resource.overlapping_var_id == 0) + { + auto rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index); + while (resource.index > next_arg_buff_index) + { + switch (rez_bind.basetype) + { + case SPIRType::Void: + case SPIRType::Boolean: + case SPIRType::SByte: + case SPIRType::UByte: + case SPIRType::Short: + case SPIRType::UShort: + case SPIRType::Int: + case SPIRType::UInt: + case SPIRType::Int64: + case SPIRType::UInt64: + case SPIRType::AtomicCounter: + case SPIRType::Half: + case SPIRType::Float: + case SPIRType::Double: + add_argument_buffer_padding_buffer_type(buffer_type, member_index, next_arg_buff_index, rez_bind); + break; + case SPIRType::Image: + add_argument_buffer_padding_image_type(buffer_type, member_index, next_arg_buff_index, rez_bind); + break; + case SPIRType::Sampler: + add_argument_buffer_padding_sampler_type(buffer_type, member_index, next_arg_buff_index, rez_bind); + break; + case SPIRType::SampledImage: + if (next_arg_buff_index == rez_bind.msl_sampler) + add_argument_buffer_padding_sampler_type(buffer_type, member_index, next_arg_buff_index, rez_bind); + else + add_argument_buffer_padding_image_type(buffer_type, member_index, next_arg_buff_index, rez_bind); + break; + default: + break; + } + + // After padding, retrieve the resource again. It will either be more padding, or the actual resource. + rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index); + } + + // Adjust the number of slots consumed by current member itself. + // Use the count value from the app, instead of the shader, in case the + // shader is only accessing part, or even one element, of the array. + next_arg_buff_index += resource.plane_count * rez_bind.count; + } + + string mbr_name = ensure_valid_name(resource.name, "m"); + if (resource.plane > 0) + mbr_name += join(plane_name_suffix, resource.plane); + set_member_name(buffer_type.self, member_index, mbr_name); + + if (resource.basetype == SPIRType::Sampler && type.basetype != SPIRType::Sampler) + { + // Have to synthesize a sampler type here. + + bool type_is_array = !type.array.empty(); + uint32_t sampler_type_id = ir.increase_bound_by(type_is_array ? 
2 : 1); + auto &new_sampler_type = set(sampler_type_id, OpTypeSampler); + new_sampler_type.basetype = SPIRType::Sampler; + new_sampler_type.storage = StorageClassUniformConstant; + + if (type_is_array) + { + uint32_t sampler_type_array_id = sampler_type_id + 1; + auto &sampler_type_array = set(sampler_type_array_id, OpTypeArray); + sampler_type_array = new_sampler_type; + sampler_type_array.array = type.array; + sampler_type_array.array_size_literal = type.array_size_literal; + sampler_type_array.parent_type = sampler_type_id; + buffer_type.member_types.push_back(sampler_type_array_id); + } + else + buffer_type.member_types.push_back(sampler_type_id); + } + else + { + uint32_t binding = get_decoration(var.self, DecorationBinding); + SetBindingPair pair = { desc_set, binding }; + + if (resource.basetype == SPIRType::Image || resource.basetype == SPIRType::Sampler || + resource.basetype == SPIRType::SampledImage) + { + // Drop pointer information when we emit the resources into a struct. + buffer_type.member_types.push_back(get_variable_data_type_id(var)); + if (has_extended_decoration(var.self, SPIRVCrossDecorationOverlappingBinding)) + { + if (!msl_options.supports_msl_version(3, 0)) + SPIRV_CROSS_THROW("Full mutable aliasing of argument buffer descriptors only works on Metal 3+."); + + auto &entry_func = get(ir.default_entry_point); + entry_func.fixup_hooks_in.push_back([this, resource]() { + emit_argument_buffer_aliased_descriptor(*resource.var, this->get(resource.overlapping_var_id)); + }); + } + else if (resource.plane == 0) + { + set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name)); + } + } + else if (buffers_requiring_dynamic_offset.count(pair)) + { + // Don't set the qualified name here; we'll define a variable holding the corrected buffer address later. + buffer_type.member_types.push_back(var.basetype); + buffers_requiring_dynamic_offset[pair].second = var.self; + } + else if (inline_uniform_blocks.count(pair)) + { + // Put the buffer block itself into the argument buffer. + buffer_type.member_types.push_back(get_variable_data_type_id(var)); + set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name)); + } + else if (atomic_image_vars_emulated.count(var.self)) + { + // Emulate texture2D atomic operations. + // Don't set the qualified name: it's already set for this variable, + // and the code that references the buffer manually appends "_atomic" + // to the name. + uint32_t offset = ir.increase_bound_by(2); + uint32_t atomic_type_id = offset; + uint32_t type_ptr_id = offset + 1; + + SPIRType atomic_type { OpTypeInt }; + atomic_type.basetype = SPIRType::AtomicCounter; + atomic_type.width = 32; + atomic_type.vecsize = 1; + set(atomic_type_id, atomic_type); + + atomic_type.op = OpTypePointer; + atomic_type.pointer = true; + atomic_type.pointer_depth++; + atomic_type.parent_type = atomic_type_id; + atomic_type.storage = StorageClassStorageBuffer; + auto &atomic_ptr_type = set(type_ptr_id, atomic_type); + atomic_ptr_type.self = atomic_type_id; + + buffer_type.member_types.push_back(type_ptr_id); + } + else + { + buffer_type.member_types.push_back(var.basetype); + if (has_extended_decoration(var.self, SPIRVCrossDecorationOverlappingBinding)) + { + // Casting raw pointers is fine since their ABI is fixed, but anything opaque is deeply questionable on Metal 2. 
+				if (get<SPIRVariable>(resource.overlapping_var_id).storage == StorageClassUniformConstant &&
+				    !msl_options.supports_msl_version(3, 0))
+				{
+					SPIRV_CROSS_THROW("Full mutable aliasing of argument buffer descriptors only works on Metal 3+.");
+				}
+
+				auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
+
+				entry_func.fixup_hooks_in.push_back([this, resource]() {
+					emit_argument_buffer_aliased_descriptor(*resource.var, this->get<SPIRVariable>(resource.overlapping_var_id));
+				});
+			}
+			else if (type.array.empty())
+				set_qualified_name(var.self, join("(*", to_name(buffer_variable_id), ".", mbr_name, ")"));
+			else
+				set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
+		}
+
+		set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationResourceIndexPrimary,
+		                               resource.index);
+		set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationInterfaceOrigID,
+		                               var.self);
+		if (has_extended_decoration(var.self, SPIRVCrossDecorationOverlappingBinding))
+			set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationOverlappingBinding);
+		member_index++;
+	}
+
+	if (msl_options.replace_recursive_inputs && type_contains_recursion(buffer_type))
+	{
+		recursive_inputs.insert(type_id);
+		auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
+		auto addr_space = get_argument_address_space(buffer_var);
+		entry_func.fixup_hooks_in.push_back([this, addr_space, buffer_name, buffer_type_name]() {
+			statement(addr_space, " auto& ", buffer_name, " = *(", addr_space, " ", buffer_type_name, "*)", buffer_name, "_vp;");
+		});
+	}
+}
+
+// Return the resource type of the app-provided resources for the descriptor set
+// that matches the resource index of the argument buffer index.
+// This is a two-step lookup: first look up the resource binding number from the argument buffer index,
+// then look up the resource binding using that binding number.
+const MSLResourceBinding &CompilerMSL::get_argument_buffer_resource(uint32_t desc_set, uint32_t arg_idx) const
+{
+	auto stage = get_entry_point().model;
+	StageSetBinding arg_idx_tuple = { stage, desc_set, arg_idx };
+	auto arg_itr = resource_arg_buff_idx_to_binding_number.find(arg_idx_tuple);
+	if (arg_itr != end(resource_arg_buff_idx_to_binding_number))
+	{
+		StageSetBinding bind_tuple = { stage, desc_set, arg_itr->second };
+		auto bind_itr = resource_bindings.find(bind_tuple);
+		if (bind_itr != end(resource_bindings))
+			return bind_itr->second.first;
+	}
+	SPIRV_CROSS_THROW("Argument buffer resource base type could not be determined. When padding argument buffer "
+	                  "elements, all descriptor set resources must be supplied with a base type by the app.");
+}
+
+// Adds an argument buffer padding argument buffer type as one or more members of the struct type at the member index.
+// Metal does not support arrays of buffers, so these are emitted as multiple struct members.
+void CompilerMSL::add_argument_buffer_padding_buffer_type(SPIRType &struct_type, uint32_t &mbr_idx,
+                                                          uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
+{
+	if (!argument_buffer_padding_buffer_type_id)
+	{
+		uint32_t buff_type_id = ir.increase_bound_by(2);
+		auto &buff_type = set<SPIRType>(buff_type_id, OpNop);
+		buff_type.basetype = rez_bind.basetype;
+		buff_type.storage = StorageClassUniformConstant;
+
+		uint32_t ptr_type_id = buff_type_id + 1;
+		auto &ptr_type = set<SPIRType>(ptr_type_id, OpTypePointer);
+		ptr_type = buff_type;
+		ptr_type.op = spv::OpTypePointer;
+		ptr_type.pointer = true;
+		ptr_type.pointer_depth++;
+		ptr_type.parent_type = buff_type_id;
+
+		argument_buffer_padding_buffer_type_id = ptr_type_id;
+	}
+
+	add_argument_buffer_padding_type(argument_buffer_padding_buffer_type_id, struct_type, mbr_idx, arg_buff_index, rez_bind.count);
+}
+
+// Adds an argument buffer padding argument image type as a member of the struct type at the member index.
+void CompilerMSL::add_argument_buffer_padding_image_type(SPIRType &struct_type, uint32_t &mbr_idx,
+                                                         uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
+{
+	if (!argument_buffer_padding_image_type_id)
+	{
+		uint32_t base_type_id = ir.increase_bound_by(2);
+		auto &base_type = set<SPIRType>(base_type_id, OpTypeFloat);
+		base_type.basetype = SPIRType::Float;
+		base_type.width = 32;
+
+		uint32_t img_type_id = base_type_id + 1;
+		auto &img_type = set<SPIRType>(img_type_id, OpTypeImage);
+		img_type.basetype = SPIRType::Image;
+		img_type.storage = StorageClassUniformConstant;
+
+		img_type.image.type = base_type_id;
+		img_type.image.dim = Dim2D;
+		img_type.image.depth = false;
+		img_type.image.arrayed = false;
+		img_type.image.ms = false;
+		img_type.image.sampled = 1;
+		img_type.image.format = ImageFormatUnknown;
+		img_type.image.access = AccessQualifierMax;
+
+		argument_buffer_padding_image_type_id = img_type_id;
+	}
+
+	add_argument_buffer_padding_type(argument_buffer_padding_image_type_id, struct_type, mbr_idx, arg_buff_index, rez_bind.count);
+}
+
+// Adds an argument buffer padding argument sampler type as a member of the struct type at the member index.
+void CompilerMSL::add_argument_buffer_padding_sampler_type(SPIRType &struct_type, uint32_t &mbr_idx,
+                                                           uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
+{
+	if (!argument_buffer_padding_sampler_type_id)
+	{
+		uint32_t samp_type_id = ir.increase_bound_by(1);
+		auto &samp_type = set<SPIRType>(samp_type_id, OpTypeSampler);
+		samp_type.basetype = SPIRType::Sampler;
+		samp_type.storage = StorageClassUniformConstant;
+
+		argument_buffer_padding_sampler_type_id = samp_type_id;
+	}
+
+	add_argument_buffer_padding_type(argument_buffer_padding_sampler_type_id, struct_type, mbr_idx, arg_buff_index, rez_bind.count);
+}
+
+// Adds the argument buffer padding argument type as a member of the struct type at the member index.
+// Advances both arg_buff_index and mbr_idx to next argument slots.
+void CompilerMSL::add_argument_buffer_padding_type(uint32_t mbr_type_id, SPIRType &struct_type, uint32_t &mbr_idx,
+                                                   uint32_t &arg_buff_index, uint32_t count)
+{
+	uint32_t type_id = mbr_type_id;
+	if (count > 1)
+	{
+		uint32_t ary_type_id = ir.increase_bound_by(1);
+		auto &ary_type = set<SPIRType>(ary_type_id, get<SPIRType>(type_id));
+		ary_type.op = OpTypeArray;
+		ary_type.array.push_back(count);
+		ary_type.array_size_literal.push_back(true);
+		ary_type.parent_type = type_id;
+		type_id = ary_type_id;
+	}
+
+	set_member_name(struct_type.self, mbr_idx, join("_m", arg_buff_index, "_pad"));
+	set_extended_member_decoration(struct_type.self, mbr_idx, SPIRVCrossDecorationResourceIndexPrimary, arg_buff_index);
+	struct_type.member_types.push_back(type_id);
+
+	arg_buff_index += count;
+	mbr_idx++;
+}
+
+void CompilerMSL::activate_argument_buffer_resources()
+{
+	// For ABI compatibility, force-enable all resources which are part of argument buffers.
+	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, const SPIRVariable &) {
+		if (!has_decoration(self, DecorationDescriptorSet))
+			return;
+
+		uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
+		if (descriptor_set_is_argument_buffer(desc_set))
+			add_active_interface_variable(self);
+	});
+}
+
+bool CompilerMSL::using_builtin_array() const
+{
+	return msl_options.force_native_arrays || is_using_builtin_array;
+}
+
+void CompilerMSL::set_combined_sampler_suffix(const char *suffix)
+{
+	sampler_name_suffix = suffix;
+}
+
+const char *CompilerMSL::get_combined_sampler_suffix() const
+{
+	return sampler_name_suffix.c_str();
+}
+
+void CompilerMSL::emit_block_hints(const SPIRBlock &)
+{
+}
+
+string CompilerMSL::additional_fixed_sample_mask_str() const
+{
+	char print_buffer[32];
+#ifdef _MSC_VER
+	// snprintf does not exist or is buggy on older MSVC versions, some of
+	// them being used by MinGW. Use sprintf instead and disable
+	// corresponding warning.
+#pragma warning(push)
+#pragma warning(disable : 4996)
+#endif
+#if _WIN32
+	sprintf(print_buffer, "0x%x", msl_options.additional_fixed_sample_mask);
+#else
+	snprintf(print_buffer, sizeof(print_buffer), "0x%x", msl_options.additional_fixed_sample_mask);
+#endif
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+	return print_buffer;
+}
diff --git a/thirdparty/spirv-cross/spirv_msl.hpp b/thirdparty/spirv-cross/spirv_msl.hpp
new file mode 100644
index 00000000000..2d970c0da5b
--- /dev/null
+++ b/thirdparty/spirv-cross/spirv_msl.hpp
@@ -0,0 +1,1349 @@
+/*
+ * Copyright 2016-2021 The Brenwill Workshop Ltd.
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * At your option, you may choose to accept this material under either:
+ *  1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
+ *  2. The MIT License, found at <http://www.opensource.org/licenses/MIT>.
+ */
+
+#ifndef SPIRV_CROSS_MSL_HPP
+#define SPIRV_CROSS_MSL_HPP
+
+#include "spirv_glsl.hpp"
+#include <map>
+#include <set>
+#include <stddef.h>
+#include <unordered_map>
+#include <unordered_set>
+
+namespace SPIRV_CROSS_NAMESPACE
+{
+
+// Indicates the format of a shader interface variable.
Currently limited to specifying +// if the input is an 8-bit unsigned integer, 16-bit unsigned integer, or +// some other format. +enum MSLShaderVariableFormat +{ + MSL_SHADER_VARIABLE_FORMAT_OTHER = 0, + MSL_SHADER_VARIABLE_FORMAT_UINT8 = 1, + MSL_SHADER_VARIABLE_FORMAT_UINT16 = 2, + MSL_SHADER_VARIABLE_FORMAT_ANY16 = 3, + MSL_SHADER_VARIABLE_FORMAT_ANY32 = 4, + + // Deprecated aliases. + MSL_VERTEX_FORMAT_OTHER = MSL_SHADER_VARIABLE_FORMAT_OTHER, + MSL_VERTEX_FORMAT_UINT8 = MSL_SHADER_VARIABLE_FORMAT_UINT8, + MSL_VERTEX_FORMAT_UINT16 = MSL_SHADER_VARIABLE_FORMAT_UINT16, + MSL_SHADER_INPUT_FORMAT_OTHER = MSL_SHADER_VARIABLE_FORMAT_OTHER, + MSL_SHADER_INPUT_FORMAT_UINT8 = MSL_SHADER_VARIABLE_FORMAT_UINT8, + MSL_SHADER_INPUT_FORMAT_UINT16 = MSL_SHADER_VARIABLE_FORMAT_UINT16, + MSL_SHADER_INPUT_FORMAT_ANY16 = MSL_SHADER_VARIABLE_FORMAT_ANY16, + MSL_SHADER_INPUT_FORMAT_ANY32 = MSL_SHADER_VARIABLE_FORMAT_ANY32, + + MSL_SHADER_VARIABLE_FORMAT_INT_MAX = 0x7fffffff +}; + +// Indicates the rate at which a variable changes value, one of: per-vertex, +// per-primitive, or per-patch. +enum MSLShaderVariableRate +{ + MSL_SHADER_VARIABLE_RATE_PER_VERTEX = 0, + MSL_SHADER_VARIABLE_RATE_PER_PRIMITIVE = 1, + MSL_SHADER_VARIABLE_RATE_PER_PATCH = 2, + + MSL_SHADER_VARIABLE_RATE_INT_MAX = 0x7fffffff, +}; + +// Defines MSL characteristics of a shader interface variable at a particular location. +// After compilation, it is possible to query whether or not this location was used. +// If vecsize is nonzero, it must be greater than or equal to the vecsize declared in the shader, +// or behavior is undefined. +struct MSLShaderInterfaceVariable +{ + uint32_t location = 0; + uint32_t component = 0; + MSLShaderVariableFormat format = MSL_SHADER_VARIABLE_FORMAT_OTHER; + spv::BuiltIn builtin = spv::BuiltInMax; + uint32_t vecsize = 0; + MSLShaderVariableRate rate = MSL_SHADER_VARIABLE_RATE_PER_VERTEX; +}; + +// Matches the binding index of a MSL resource for a binding within a descriptor set. +// Taken together, the stage, desc_set and binding combine to form a reference to a resource +// descriptor used in a particular shading stage. The count field indicates the number of +// resources consumed by this binding, if the binding represents an array of resources. +// If the resource array is a run-time-sized array, which are legal in GLSL or SPIR-V, this value +// will be used to declare the array size in MSL, which does not support run-time-sized arrays. +// If pad_argument_buffer_resources is enabled, the base_type and count values are used to +// specify the base type and array size of the resource in the argument buffer, if that resource +// is not defined and used by the shader. With pad_argument_buffer_resources enabled, this +// information will be used to pad the argument buffer structure, in order to align that +// structure consistently for all uses, across all shaders, of the descriptor set represented +// by the arugment buffer. If pad_argument_buffer_resources is disabled, base_type does not +// need to be populated, and if the resource is also not a run-time sized array, the count +// field does not need to be populated. +// If using MSL 2.0 argument buffers, the descriptor set is not marked as a discrete descriptor set, +// and (for iOS only) the resource is not a storage image (sampled != 2), the binding reference we +// remap to will become an [[id(N)]] attribute within the "descriptor set" argument buffer structure. 
+// For resources which are bound in the "classic" MSL 1.0 way or discrete descriptors, the remap will +// become a [[buffer(N)]], [[texture(N)]] or [[sampler(N)]] depending on the resource types used. +struct MSLResourceBinding +{ + spv::ExecutionModel stage = spv::ExecutionModelMax; + SPIRType::BaseType basetype = SPIRType::Unknown; + uint32_t desc_set = 0; + uint32_t binding = 0; + uint32_t count = 0; + uint32_t msl_buffer = 0; + uint32_t msl_texture = 0; + uint32_t msl_sampler = 0; +}; + +enum MSLSamplerCoord +{ + MSL_SAMPLER_COORD_NORMALIZED = 0, + MSL_SAMPLER_COORD_PIXEL = 1, + MSL_SAMPLER_INT_MAX = 0x7fffffff +}; + +enum MSLSamplerFilter +{ + MSL_SAMPLER_FILTER_NEAREST = 0, + MSL_SAMPLER_FILTER_LINEAR = 1, + MSL_SAMPLER_FILTER_INT_MAX = 0x7fffffff +}; + +enum MSLSamplerMipFilter +{ + MSL_SAMPLER_MIP_FILTER_NONE = 0, + MSL_SAMPLER_MIP_FILTER_NEAREST = 1, + MSL_SAMPLER_MIP_FILTER_LINEAR = 2, + MSL_SAMPLER_MIP_FILTER_INT_MAX = 0x7fffffff +}; + +enum MSLSamplerAddress +{ + MSL_SAMPLER_ADDRESS_CLAMP_TO_ZERO = 0, + MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE = 1, + MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER = 2, + MSL_SAMPLER_ADDRESS_REPEAT = 3, + MSL_SAMPLER_ADDRESS_MIRRORED_REPEAT = 4, + MSL_SAMPLER_ADDRESS_INT_MAX = 0x7fffffff +}; + +enum MSLSamplerCompareFunc +{ + MSL_SAMPLER_COMPARE_FUNC_NEVER = 0, + MSL_SAMPLER_COMPARE_FUNC_LESS = 1, + MSL_SAMPLER_COMPARE_FUNC_LESS_EQUAL = 2, + MSL_SAMPLER_COMPARE_FUNC_GREATER = 3, + MSL_SAMPLER_COMPARE_FUNC_GREATER_EQUAL = 4, + MSL_SAMPLER_COMPARE_FUNC_EQUAL = 5, + MSL_SAMPLER_COMPARE_FUNC_NOT_EQUAL = 6, + MSL_SAMPLER_COMPARE_FUNC_ALWAYS = 7, + MSL_SAMPLER_COMPARE_FUNC_INT_MAX = 0x7fffffff +}; + +enum MSLSamplerBorderColor +{ + MSL_SAMPLER_BORDER_COLOR_TRANSPARENT_BLACK = 0, + MSL_SAMPLER_BORDER_COLOR_OPAQUE_BLACK = 1, + MSL_SAMPLER_BORDER_COLOR_OPAQUE_WHITE = 2, + MSL_SAMPLER_BORDER_COLOR_INT_MAX = 0x7fffffff +}; + +enum MSLFormatResolution +{ + MSL_FORMAT_RESOLUTION_444 = 0, + MSL_FORMAT_RESOLUTION_422, + MSL_FORMAT_RESOLUTION_420, + MSL_FORMAT_RESOLUTION_INT_MAX = 0x7fffffff +}; + +enum MSLChromaLocation +{ + MSL_CHROMA_LOCATION_COSITED_EVEN = 0, + MSL_CHROMA_LOCATION_MIDPOINT, + MSL_CHROMA_LOCATION_INT_MAX = 0x7fffffff +}; + +enum MSLComponentSwizzle +{ + MSL_COMPONENT_SWIZZLE_IDENTITY = 0, + MSL_COMPONENT_SWIZZLE_ZERO, + MSL_COMPONENT_SWIZZLE_ONE, + MSL_COMPONENT_SWIZZLE_R, + MSL_COMPONENT_SWIZZLE_G, + MSL_COMPONENT_SWIZZLE_B, + MSL_COMPONENT_SWIZZLE_A, + MSL_COMPONENT_SWIZZLE_INT_MAX = 0x7fffffff +}; + +enum MSLSamplerYCbCrModelConversion +{ + MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY = 0, + MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY, + MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709, + MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601, + MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020, + MSL_SAMPLER_YCBCR_MODEL_CONVERSION_INT_MAX = 0x7fffffff +}; + +enum MSLSamplerYCbCrRange +{ + MSL_SAMPLER_YCBCR_RANGE_ITU_FULL = 0, + MSL_SAMPLER_YCBCR_RANGE_ITU_NARROW, + MSL_SAMPLER_YCBCR_RANGE_INT_MAX = 0x7fffffff +}; + +struct MSLConstexprSampler +{ + MSLSamplerCoord coord = MSL_SAMPLER_COORD_NORMALIZED; + MSLSamplerFilter min_filter = MSL_SAMPLER_FILTER_NEAREST; + MSLSamplerFilter mag_filter = MSL_SAMPLER_FILTER_NEAREST; + MSLSamplerMipFilter mip_filter = MSL_SAMPLER_MIP_FILTER_NONE; + MSLSamplerAddress s_address = MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE; + MSLSamplerAddress t_address = MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE; + MSLSamplerAddress r_address = MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE; + MSLSamplerCompareFunc compare_func = MSL_SAMPLER_COMPARE_FUNC_NEVER; + 
MSLSamplerBorderColor border_color = MSL_SAMPLER_BORDER_COLOR_TRANSPARENT_BLACK; + float lod_clamp_min = 0.0f; + float lod_clamp_max = 1000.0f; + int max_anisotropy = 1; + + // Sampler Y'CbCr conversion parameters + uint32_t planes = 0; + MSLFormatResolution resolution = MSL_FORMAT_RESOLUTION_444; + MSLSamplerFilter chroma_filter = MSL_SAMPLER_FILTER_NEAREST; + MSLChromaLocation x_chroma_offset = MSL_CHROMA_LOCATION_COSITED_EVEN; + MSLChromaLocation y_chroma_offset = MSL_CHROMA_LOCATION_COSITED_EVEN; + MSLComponentSwizzle swizzle[4]; // IDENTITY, IDENTITY, IDENTITY, IDENTITY + MSLSamplerYCbCrModelConversion ycbcr_model = MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY; + MSLSamplerYCbCrRange ycbcr_range = MSL_SAMPLER_YCBCR_RANGE_ITU_FULL; + uint32_t bpc = 8; + + bool compare_enable = false; + bool lod_clamp_enable = false; + bool anisotropy_enable = false; + bool ycbcr_conversion_enable = false; + + MSLConstexprSampler() + { + for (uint32_t i = 0; i < 4; i++) + swizzle[i] = MSL_COMPONENT_SWIZZLE_IDENTITY; + } + bool swizzle_is_identity() const + { + return (swizzle[0] == MSL_COMPONENT_SWIZZLE_IDENTITY && swizzle[1] == MSL_COMPONENT_SWIZZLE_IDENTITY && + swizzle[2] == MSL_COMPONENT_SWIZZLE_IDENTITY && swizzle[3] == MSL_COMPONENT_SWIZZLE_IDENTITY); + } + bool swizzle_has_one_or_zero() const + { + return (swizzle[0] == MSL_COMPONENT_SWIZZLE_ZERO || swizzle[0] == MSL_COMPONENT_SWIZZLE_ONE || + swizzle[1] == MSL_COMPONENT_SWIZZLE_ZERO || swizzle[1] == MSL_COMPONENT_SWIZZLE_ONE || + swizzle[2] == MSL_COMPONENT_SWIZZLE_ZERO || swizzle[2] == MSL_COMPONENT_SWIZZLE_ONE || + swizzle[3] == MSL_COMPONENT_SWIZZLE_ZERO || swizzle[3] == MSL_COMPONENT_SWIZZLE_ONE); + } +}; + +// Special constant used in a MSLResourceBinding desc_set +// element to indicate the bindings for the push constants. +// Kinda deprecated. Just use ResourceBindingPushConstant{DescriptorSet,Binding} directly. +static const uint32_t kPushConstDescSet = ResourceBindingPushConstantDescriptorSet; + +// Special constant used in a MSLResourceBinding binding +// element to indicate the bindings for the push constants. +// Kinda deprecated. Just use ResourceBindingPushConstant{DescriptorSet,Binding} directly. +static const uint32_t kPushConstBinding = ResourceBindingPushConstantBinding; + +// Special constant used in a MSLResourceBinding binding +// element to indicate the buffer binding for swizzle buffers. +static const uint32_t kSwizzleBufferBinding = ~(1u); + +// Special constant used in a MSLResourceBinding binding +// element to indicate the buffer binding for buffer size buffers to support OpArrayLength. +static const uint32_t kBufferSizeBufferBinding = ~(2u); + +// Special constant used in a MSLResourceBinding binding +// element to indicate the buffer binding used for the argument buffer itself. +// This buffer binding should be kept as small as possible as all automatic bindings for buffers +// will start at max(kArgumentBufferBinding) + 1. 
+static const uint32_t kArgumentBufferBinding = ~(3u); + +static const uint32_t kMaxArgumentBuffers = 8; + +// Decompiles SPIR-V to Metal Shading Language +class CompilerMSL : public CompilerGLSL +{ +public: + // Options for compiling to Metal Shading Language + struct Options + { + typedef enum + { + iOS = 0, + macOS = 1 + } Platform; + + Platform platform = macOS; + uint32_t msl_version = make_msl_version(1, 2); + uint32_t texel_buffer_texture_width = 4096; // Width of 2D Metal textures used as 1D texel buffers + uint32_t r32ui_linear_texture_alignment = 4; + uint32_t r32ui_alignment_constant_id = 65535; + uint32_t swizzle_buffer_index = 30; + uint32_t indirect_params_buffer_index = 29; + uint32_t shader_output_buffer_index = 28; + uint32_t shader_patch_output_buffer_index = 27; + uint32_t shader_tess_factor_buffer_index = 26; + uint32_t buffer_size_buffer_index = 25; + uint32_t view_mask_buffer_index = 24; + uint32_t dynamic_offsets_buffer_index = 23; + uint32_t shader_input_buffer_index = 22; + uint32_t shader_index_buffer_index = 21; + uint32_t shader_patch_input_buffer_index = 20; + uint32_t shader_input_wg_index = 0; + uint32_t device_index = 0; + uint32_t enable_frag_output_mask = 0xffffffff; + // Metal doesn't allow setting a fixed sample mask directly in the pipeline. + // We can evade this restriction by ANDing the internal sample_mask output + // of the shader with the additional fixed sample mask. + uint32_t additional_fixed_sample_mask = 0xffffffff; + bool enable_point_size_builtin = true; + bool enable_frag_depth_builtin = true; + bool enable_frag_stencil_ref_builtin = true; + bool disable_rasterization = false; + bool capture_output_to_buffer = false; + bool swizzle_texture_samples = false; + bool tess_domain_origin_lower_left = false; + bool multiview = false; + bool multiview_layered_rendering = true; + bool view_index_from_device_index = false; + bool dispatch_base = false; + bool texture_1D_as_2D = false; + + // Enable use of Metal argument buffers. + // MSL 2.0 must also be enabled. + bool argument_buffers = false; + + // Defines Metal argument buffer tier levels. + // Uses same values as Metal MTLArgumentBuffersTier enumeration. + enum class ArgumentBuffersTier + { + Tier1 = 0, + Tier2 = 1, + }; + + // When using Metal argument buffers, indicates the Metal argument buffer tier level supported by the Metal platform. + // Ignored when Options::argument_buffers is disabled. + // - Tier1 supports writable images on macOS, but not on iOS. + // - Tier2 supports writable images on macOS and iOS, and higher resource count limits. + // Tier capabilities based on recommendations from Apple engineering. + ArgumentBuffersTier argument_buffers_tier = ArgumentBuffersTier::Tier1; + + // Enables specifick argument buffer format with extra information to track SSBO-length + bool runtime_array_rich_descriptor = false; + + // Ensures vertex and instance indices start at zero. This reflects the behavior of HLSL with SV_VertexID and SV_InstanceID. + bool enable_base_index_zero = false; + + // Fragment output in MSL must have at least as many components as the render pass. + // Add support to explicit pad out components. + bool pad_fragment_output_components = false; + + // Specifies whether the iOS target version supports the [[base_vertex]] and [[base_instance]] attributes. + bool ios_support_base_vertex_instance = false; + + // Use Metal's native frame-buffer fetch API for subpass inputs. 
+ bool use_framebuffer_fetch_subpasses = false; + + // Enables use of "fma" intrinsic for invariant float math + bool invariant_float_math = false; + + // Emulate texturecube_array with texture2d_array for iOS where this type is not available + bool emulate_cube_array = false; + + // Allow user to enable decoration binding + bool enable_decoration_binding = false; + + // Requires MSL 2.1, use the native support for texel buffers. + bool texture_buffer_native = false; + + // Forces all resources which are part of an argument buffer to be considered active. + // This ensures ABI compatibility between shaders where some resources might be unused, + // and would otherwise declare a different IAB. + bool force_active_argument_buffer_resources = false; + + // Aligns each resource in an argument buffer to its assigned index value, id(N), + // by adding synthetic padding members in the argument buffer struct for any resources + // in the argument buffer that are not defined and used by the shader. This allows + // the shader to index into the correct argument in a descriptor set argument buffer + // that is shared across shaders, where not all resources in the argument buffer are + // defined in each shader. For this to work, an MSLResourceBinding must be provided for + // all descriptors in any descriptor set held in an argument buffer in the shader, and + // that MSLResourceBinding must have the basetype and count members populated correctly. + // The implementation here assumes any inline blocks in the argument buffer is provided + // in a Metal buffer, and doesn't take into consideration inline blocks that are + // optionally embedded directly into the argument buffer via add_inline_uniform_block(). + bool pad_argument_buffer_resources = false; + + // Forces the use of plain arrays, which works around certain driver bugs on certain versions + // of Intel Macbooks. See https://github.com/KhronosGroup/SPIRV-Cross/issues/1210. + // May reduce performance in scenarios where arrays are copied around as value-types. + bool force_native_arrays = false; + + // If a shader writes clip distance, also emit user varyings which + // can be read in subsequent stages. + bool enable_clip_distance_user_varying = true; + + // In a tessellation control shader, assume that more than one patch can be processed in a + // single workgroup. This requires changes to the way the InvocationId and PrimitiveId + // builtins are processed, but should result in more efficient usage of the GPU. + bool multi_patch_workgroup = false; + + // Use storage buffers instead of vertex-style attributes for tessellation evaluation + // input. This may require conversion of inputs in the generated post-tessellation + // vertex shader, but allows the use of nested arrays. + bool raw_buffer_tese_input = false; + + // If set, a vertex shader will be compiled as part of a tessellation pipeline. + // It will be translated as a compute kernel, so it can use the global invocation ID + // to index the output buffer. + bool vertex_for_tessellation = false; + + // Assume that SubpassData images have multiple layers. Layered input attachments + // are addressed relative to the Layer output from the vertex pipeline. This option + // has no effect with multiview, since all input attachments are assumed to be layered + // and will be addressed using the current ViewIndex. + bool arrayed_subpass_input = false; + + // Whether to use SIMD-group or quadgroup functions to implement group non-uniform + // operations. 
Some GPUs on iOS do not support the SIMD-group functions, only the + // quadgroup functions. + bool ios_use_simdgroup_functions = false; + + // If set, the subgroup size will be assumed to be one, and subgroup-related + // builtins and operations will be emitted accordingly. This mode is intended to + // be used by MoltenVK on hardware/software configurations which do not provide + // sufficient support for subgroups. + bool emulate_subgroups = false; + + // If nonzero, a fixed subgroup size to assume. Metal, similarly to VK_EXT_subgroup_size_control, + // allows the SIMD-group size (aka thread execution width) to vary depending on + // register usage and requirements. In certain circumstances--for example, a pipeline + // in MoltenVK without VK_PIPELINE_SHADER_STAGE_CREATE_ALLOW_VARYING_SUBGROUP_SIZE_BIT_EXT-- + // this is undesirable. This fixes the value of the SubgroupSize builtin, instead of + // mapping it to the Metal builtin [[thread_execution_width]]. If the thread + // execution width is reduced, the extra invocations will appear to be inactive. + // If zero, the SubgroupSize will be allowed to vary, and the builtin will be mapped + // to the Metal [[thread_execution_width]] builtin. + uint32_t fixed_subgroup_size = 0; + + enum class IndexType + { + None = 0, + UInt16 = 1, + UInt32 = 2 + }; + + // The type of index in the index buffer, if present. For a compute shader, Metal + // requires specifying the indexing at pipeline creation, rather than at draw time + // as with graphics pipelines. This means we must create three different pipelines, + // for no indexing, 16-bit indices, and 32-bit indices. Each requires different + // handling for the gl_VertexIndex builtin. We may as well, then, create three + // different shaders for these three scenarios. + IndexType vertex_index_type = IndexType::None; + + // If set, a dummy [[sample_id]] input is added to a fragment shader if none is present. + // This will force the shader to run at sample rate, assuming Metal does not optimize + // the extra threads away. + bool force_sample_rate_shading = false; + + // If set, gl_HelperInvocation will be set manually whenever a fragment is discarded. + // Some Metal devices have a bug where simd_is_helper_thread() does not return true + // after a fragment has been discarded. This is a workaround that is only expected to be needed + // until the bug is fixed in Metal; it is provided as an option to allow disabling it when that occurs. + bool manual_helper_invocation_updates = true; + + // If set, extra checks will be emitted in fragment shaders to prevent writes + // from discarded fragments. Some Metal devices have a bug where writes to storage resources + // from discarded fragment threads continue to occur, despite the fragment being + // discarded. This is a workaround that is only expected to be needed until the + // bug is fixed in Metal; it is provided as an option so it can be enabled + // only when the bug is present. + bool check_discarded_frag_stores = false; + + // If set, Lod operands to OpImageSample*DrefExplicitLod for 1D and 2D array images + // will be implemented using a gradient instead of passing the level operand directly. + // Some Metal devices have a bug where the level() argument to depth2d_array::sample_compare() + // in a fragment shader is biased by some unknown amount, possibly dependent on the + // partial derivatives of the texture coordinates. 
This is a workaround that is only
+ // expected to be needed until the bug is fixed in Metal; it is provided as an option
+ // so it can be enabled only when the bug is present.
+ bool sample_dref_lod_array_as_grad = false;
+
+ // MSL doesn't guarantee coherence between writes and subsequent reads of read_write textures.
+ // This inserts fences before each read of a read_write texture to ensure coherency.
+ // If you're sure you never rely on this, you can set this to false for a possible performance improvement.
+ // Note: Only Apple's GPU compiler takes advantage of the lack of coherency, so make sure to test on Apple GPUs if you disable this.
+ bool readwrite_texture_fences = true;
+
+ // Metal 3.1 introduced a Metal regression bug which causes infinite recursion during
+ // Metal's analysis of an entry point input structure that is itself recursive. Enabling
+ // this option will replace the recursive input declaration with an alternate variable of
+ // type void*, and then cast to the correct type at the top of the entry point function.
+ // The bug has been reported to Apple, and will hopefully be fixed in future releases.
+ bool replace_recursive_inputs = false;
+
+ // If set, manual fixups of gradient vectors for cube texture lookups will be performed.
+ // All released Apple Silicon GPUs to date behave incorrectly when sampling a cube texture
+ // with explicit gradients. They will ignore one of the three partial derivatives based
+ // on the selected major axis, and expect the remaining derivatives to be partially
+ // transformed.
+ bool agx_manual_cube_grad_fixup = false;
+
+ // Metal will prematurely discard fragments with side effects under certain circumstances.
+ // Example: CTS test dEQP-VK.fragment_operations.early_fragment.discard_no_early_fragment_tests_depth
+ // The test will render a full screen quad with varying depth [0,1] for each fragment.
+ // Each fragment will do an operation with side effects, modify the depth value and
+ // discard the fragment. The test expects the fragment to be run due to:
+ // https://registry.khronos.org/vulkan/specs/1.0-extensions/html/vkspec.html#fragops-shader-depthreplacement
+ // which states that the fragment shader must be run due to replacing the depth in the shader.
+ // However, Metal may prematurely discard fragments without executing them
+ // (I believe this to be due to a greedy optimization on their end), making the test fail.
+ // This option enforces fragment execution for such cases where the fragment has operations
+ // with side effects. Provided as an option in the hope that Metal will fix this issue in the future.
+ bool force_fragment_with_side_effects_execution = false;
+
+ // If set, adds a depth pass-through statement to circumvent the following issue:
+ // When the same depth/stencil is used as input and depth/stencil attachment, we need to
+ // force Metal to perform the depth/stencil write after fragment execution. Otherwise,
+ // Metal will write to the depth attachment before fragment execution. This happens
+ // if the fragment does not modify the depth value.
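Several of the flags above exist purely to work around known Metal driver bugs, so an engine will typically flip them as a group when it detects affected hardware. A hedged sketch; which flags to enable on which GPUs is a policy decision, and the is_apple_gpu predicate is a hypothetical engine-side check:

```cpp
#include "spirv_msl.hpp"

// Sketch: enable the driver-bug workarounds described above for Apple GPUs.
// is_apple_gpu is a hypothetical flag supplied by the caller.
static void apply_apple_gpu_workarounds(spirv_cross::CompilerMSL::Options &opts, bool is_apple_gpu)
{
    if (!is_apple_gpu)
        return;
    opts.readwrite_texture_fences = true;         // keep read_write texture coherency fences
    opts.agx_manual_cube_grad_fixup = true;       // explicit-gradient cube sampling fixup
    opts.sample_dref_lod_array_as_grad = true;    // depth2d_array sample_compare level() bug
    opts.check_discarded_frag_stores = true;      // suppress stores from discarded fragments
    opts.manual_helper_invocation_updates = true; // already the default, listed for clarity
}
```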
+ bool input_attachment_is_ds_attachment = false; + + bool is_ios() const + { + return platform == iOS; + } + + bool is_macos() const + { + return platform == macOS; + } + + bool use_quadgroup_operation() const + { + return is_ios() && !ios_use_simdgroup_functions; + } + + void set_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0) + { + msl_version = make_msl_version(major, minor, patch); + } + + bool supports_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0) const + { + return msl_version >= make_msl_version(major, minor, patch); + } + + static uint32_t make_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0) + { + return (major * 10000) + (minor * 100) + patch; + } + }; + + const Options &get_msl_options() const + { + return msl_options; + } + + void set_msl_options(const Options &opts) + { + msl_options = opts; + } + + // Provide feedback to calling API to allow runtime to disable pipeline + // rasterization if vertex shader requires rasterization to be disabled. + bool get_is_rasterization_disabled() const + { + return is_rasterization_disabled && (get_entry_point().model == spv::ExecutionModelVertex || + get_entry_point().model == spv::ExecutionModelTessellationControl || + get_entry_point().model == spv::ExecutionModelTessellationEvaluation); + } + + // Provide feedback to calling API to allow it to pass an auxiliary + // swizzle buffer if the shader needs it. + bool needs_swizzle_buffer() const + { + return used_swizzle_buffer; + } + + // Provide feedback to calling API to allow it to pass a buffer + // containing STORAGE_BUFFER buffer sizes to support OpArrayLength. + bool needs_buffer_size_buffer() const + { + return !buffers_requiring_array_length.empty(); + } + + bool buffer_requires_array_length(VariableID id) const + { + return buffers_requiring_array_length.count(id) != 0; + } + + // Provide feedback to calling API to allow it to pass a buffer + // containing the view mask for the current multiview subpass. + bool needs_view_mask_buffer() const + { + return msl_options.multiview && !msl_options.view_index_from_device_index; + } + + // Provide feedback to calling API to allow it to pass a buffer + // containing the dispatch base workgroup ID. + bool needs_dispatch_base_buffer() const + { + return msl_options.dispatch_base && !msl_options.supports_msl_version(1, 2); + } + + // Provide feedback to calling API to allow it to pass an output + // buffer if the shader needs it. + bool needs_output_buffer() const + { + return capture_output_to_buffer && stage_out_var_id != ID(0); + } + + // Provide feedback to calling API to allow it to pass a patch output + // buffer if the shader needs it. + bool needs_patch_output_buffer() const + { + return capture_output_to_buffer && patch_stage_out_var_id != ID(0); + } + + // Provide feedback to calling API to allow it to pass an input threadgroup + // buffer if the shader needs it. + bool needs_input_threadgroup_mem() const + { + return capture_output_to_buffer && stage_in_var_id != ID(0); + } + + explicit CompilerMSL(std::vector spirv); + CompilerMSL(const uint32_t *ir, size_t word_count); + explicit CompilerMSL(const ParsedIR &ir); + explicit CompilerMSL(ParsedIR &&ir); + + // input is a shader interface variable description used to fix up shader input variables. + // If shader inputs are provided, is_msl_shader_input_used() will return true after + // calling ::compile() if the location were used by the MSL code. 
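The needs_*() queries above are the mechanism by which the calling runtime discovers which implicit buffers the generated MSL expects. A sketch of the usual flow, assuming the default SPIRV_CROSS_NAMESPACE; the surrounding engine logic is hypothetical:

```cpp
#include <cstdint>
#include <string>
#include <vector>
#include "spirv_msl.hpp"

// Sketch: compile, then use the feedback API to learn which auxiliary buffers
// must be bound alongside the user's own resources.
static std::string compile_and_report(std::vector<uint32_t> spirv)
{
    spirv_cross::CompilerMSL msl(std::move(spirv));
    std::string source = msl.compile();

    if (msl.needs_swizzle_buffer())
    {
        // Bind a texture-swizzle buffer at msl.get_msl_options().swizzle_buffer_index.
    }
    if (msl.needs_buffer_size_buffer())
    {
        // Bind SSBO sizes at msl.get_msl_options().buffer_size_buffer_index (OpArrayLength support).
    }
    if (msl.get_is_rasterization_disabled())
    {
        // Create the Metal pipeline with rasterization disabled.
    }
    return source;
}
```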
+ void add_msl_shader_input(const MSLShaderInterfaceVariable &input); + + // output is a shader interface variable description used to fix up shader output variables. + // If shader outputs are provided, is_msl_shader_output_used() will return true after + // calling ::compile() if the location were used by the MSL code. + void add_msl_shader_output(const MSLShaderInterfaceVariable &output); + + // resource is a resource binding to indicate the MSL buffer, + // texture or sampler index to use for a particular SPIR-V description set + // and binding. If resource bindings are provided, + // is_msl_resource_binding_used() will return true after calling ::compile() if + // the set/binding combination was used by the MSL code. + void add_msl_resource_binding(const MSLResourceBinding &resource); + + // desc_set and binding are the SPIR-V descriptor set and binding of a buffer resource + // in this shader. index is the index within the dynamic offset buffer to use. This + // function marks that resource as using a dynamic offset (VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC + // or VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC). This function only has any effect if argument buffers + // are enabled. If so, the buffer will have its address adjusted at the beginning of the shader with + // an offset taken from the dynamic offset buffer. + void add_dynamic_buffer(uint32_t desc_set, uint32_t binding, uint32_t index); + + // desc_set and binding are the SPIR-V descriptor set and binding of a buffer resource + // in this shader. This function marks that resource as an inline uniform block + // (VK_DESCRIPTOR_TYPE_INLINE_UNIFORM_BLOCK_EXT). This function only has any effect if argument buffers + // are enabled. If so, the buffer block will be directly embedded into the argument + // buffer, instead of being referenced indirectly via pointer. + void add_inline_uniform_block(uint32_t desc_set, uint32_t binding); + + // When using MSL argument buffers, we can force "classic" MSL 1.0 binding schemes for certain descriptor sets. + // This corresponds to VK_KHR_push_descriptor in Vulkan. + void add_discrete_descriptor_set(uint32_t desc_set); + + // If an argument buffer is large enough, it may need to be in the device storage space rather than + // constant. Opt-in to this behavior here on a per set basis. + void set_argument_buffer_device_address_space(uint32_t desc_set, bool device_storage); + + // Query after compilation is done. This allows you to check if an input location was used by the shader. + bool is_msl_shader_input_used(uint32_t location); + + // Query after compilation is done. This allows you to check if an output location were used by the shader. + bool is_msl_shader_output_used(uint32_t location); + + // If not using add_msl_shader_input, it's possible + // that certain builtin attributes need to be automatically assigned locations. + // This is typical for tessellation builtin inputs such as tess levels, gl_Position, etc. + // This returns k_unknown_location if the location was explicitly assigned with + // add_msl_shader_input or the builtin is not used, otherwise returns N in [[attribute(N)]]. + uint32_t get_automatic_builtin_input_location(spv::BuiltIn builtin) const; + + // If not using add_msl_shader_output, it's possible + // that certain builtin attributes need to be automatically assigned locations. + // This is typical for tessellation builtin outputs such as tess levels, gl_Position, etc. 
+ // This returns k_unknown_location if the location was explicitly assigned with
+ // add_msl_shader_output or the builtin is not used, otherwise returns N in [[attribute(N)]].
+ uint32_t get_automatic_builtin_output_location(spv::BuiltIn builtin) const;
+
+ // NOTE: Only resources which are remapped using add_msl_resource_binding will be reported here.
+ // Constexpr samplers are always assumed to be emitted.
+ // No specific MSLResourceBinding remapping is required for constexpr samplers as long as they are remapped
+ // by remap_constexpr_sampler(_by_binding).
+ bool is_msl_resource_binding_used(spv::ExecutionModel model, uint32_t set, uint32_t binding) const;
+
+ // This must only be called after a successful call to CompilerMSL::compile().
+ // For a variable resource ID obtained through reflection API, report the automatically assigned resource index.
+ // If the descriptor set was part of an argument buffer, report the [[id(N)]],
+ // or [[buffer/texture/sampler]] binding for other resources.
+ // If the resource was a combined image sampler, report the image binding here,
+ // use the _secondary version of this call to query the sampler half of the resource.
+ // If no binding exists, uint32_t(-1) is returned.
+ uint32_t get_automatic_msl_resource_binding(uint32_t id) const;
+
+ // Same as get_automatic_msl_resource_binding, but should only be used for combined image samplers, in which case the
+ // sampler's binding is returned instead. For any other resource type, -1 is returned.
+ // Secondary bindings are also used for the auxiliary image atomic buffer.
+ uint32_t get_automatic_msl_resource_binding_secondary(uint32_t id) const;
+
+ // Same as get_automatic_msl_resource_binding, but should only be used for combined image samplers for multiplanar images,
+ // in which case the second plane's binding is returned instead. For any other resource type, -1 is returned.
+ uint32_t get_automatic_msl_resource_binding_tertiary(uint32_t id) const;
+
+ // Same as get_automatic_msl_resource_binding, but should only be used for combined image samplers for triplanar images,
+ // in which case the third plane's binding is returned instead. For any other resource type, -1 is returned.
+ uint32_t get_automatic_msl_resource_binding_quaternary(uint32_t id) const;
+
+ // Compiles the SPIR-V code into Metal Shading Language.
+ std::string compile() override;
+
+ // Remap a sampler with ID to a constexpr sampler.
+ // Older iOS targets must use constexpr samplers in certain cases (PCF),
+ // so a static sampler must be used.
+ // The sampler will not consume a binding, but be declared in the entry point as a constexpr sampler.
+ // This can be used on both combined image/samplers (sampler2D) and standalone samplers.
+ // The remapped sampler must not be an array of samplers.
+ // Prefer remap_constexpr_sampler_by_binding unless you're also doing reflection anyway.
+ void remap_constexpr_sampler(VariableID id, const MSLConstexprSampler &sampler);
+
+ // Same as remap_constexpr_sampler, except you provide set/binding, rather than variable ID.
+ // Remaps based on ID take priority over set/binding remaps.
+ void remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t binding, const MSLConstexprSampler &sampler);
+
+ // If using CompilerMSL::Options::pad_fragment_output_components, override the number of components we expect
+ // to use for a particular location. The default is 4 if number of components is not overridden.
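Resource remapping follows the same pattern: describe the SPIR-V set/binding and the desired MSL slots in an MSLResourceBinding (declared earlier in this header), register it before compile(), and query afterwards whether it was consumed. A sketch; the slot numbers and helper names are illustrative:

```cpp
#include "spirv_msl.hpp"

// Sketch: pin a fragment-stage uniform buffer at set 0 / binding 0 to [[buffer(4)]].
// MSLResourceBinding's fields are declared earlier in this header.
static void bind_scene_ubo(spirv_cross::CompilerMSL &msl)
{
    spirv_cross::MSLResourceBinding b;
    b.stage = spv::ExecutionModelFragment;
    b.desc_set = 0;
    b.binding = 0;
    b.msl_buffer = 4;
    msl.add_msl_resource_binding(b);
}

// Only meaningful after msl.compile(); see is_msl_resource_binding_used() above.
static bool scene_ubo_was_used(const spirv_cross::CompilerMSL &msl)
{
    return msl.is_msl_resource_binding_used(spv::ExecutionModelFragment, 0, 0);
}
```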
+ void set_fragment_output_components(uint32_t location, uint32_t components); + + void set_combined_sampler_suffix(const char *suffix); + const char *get_combined_sampler_suffix() const; + +protected: + // An enum of SPIR-V functions that are implemented in additional + // source code that is added to the shader if necessary. + enum SPVFuncImpl : uint8_t + { + SPVFuncImplNone, + SPVFuncImplMod, + SPVFuncImplRadians, + SPVFuncImplDegrees, + SPVFuncImplFindILsb, + SPVFuncImplFindSMsb, + SPVFuncImplFindUMsb, + SPVFuncImplSSign, + SPVFuncImplArrayCopy, + SPVFuncImplArrayCopyMultidim, + SPVFuncImplTexelBufferCoords, + SPVFuncImplImage2DAtomicCoords, // Emulate texture2D atomic operations + SPVFuncImplGradientCube, + SPVFuncImplFMul, + SPVFuncImplFAdd, + SPVFuncImplFSub, + SPVFuncImplQuantizeToF16, + SPVFuncImplCubemapTo2DArrayFace, + SPVFuncImplUnsafeArray, // Allow Metal to use the array template to make arrays a value type + SPVFuncImplStorageMatrix, // Allow threadgroup construction of matrices + SPVFuncImplInverse4x4, + SPVFuncImplInverse3x3, + SPVFuncImplInverse2x2, + // It is very important that this come before *Swizzle and ChromaReconstruct*, to ensure it's + // emitted before them. + SPVFuncImplForwardArgs, + // Likewise, this must come before *Swizzle. + SPVFuncImplGetSwizzle, + SPVFuncImplTextureSwizzle, + SPVFuncImplGatherSwizzle, + SPVFuncImplGatherCompareSwizzle, + SPVFuncImplGatherConstOffsets, + SPVFuncImplGatherCompareConstOffsets, + SPVFuncImplSubgroupBroadcast, + SPVFuncImplSubgroupBroadcastFirst, + SPVFuncImplSubgroupBallot, + SPVFuncImplSubgroupBallotBitExtract, + SPVFuncImplSubgroupBallotFindLSB, + SPVFuncImplSubgroupBallotFindMSB, + SPVFuncImplSubgroupBallotBitCount, + SPVFuncImplSubgroupAllEqual, + SPVFuncImplSubgroupShuffle, + SPVFuncImplSubgroupShuffleXor, + SPVFuncImplSubgroupShuffleUp, + SPVFuncImplSubgroupShuffleDown, + SPVFuncImplQuadBroadcast, + SPVFuncImplQuadSwap, + SPVFuncImplReflectScalar, + SPVFuncImplRefractScalar, + SPVFuncImplFaceForwardScalar, + SPVFuncImplChromaReconstructNearest2Plane, + SPVFuncImplChromaReconstructNearest3Plane, + SPVFuncImplChromaReconstructLinear422CositedEven2Plane, + SPVFuncImplChromaReconstructLinear422CositedEven3Plane, + SPVFuncImplChromaReconstructLinear422Midpoint2Plane, + SPVFuncImplChromaReconstructLinear422Midpoint3Plane, + SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane, + SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane, + SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane, + SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane, + SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane, + SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane, + SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane, + SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane, + SPVFuncImplExpandITUFullRange, + SPVFuncImplExpandITUNarrowRange, + SPVFuncImplConvertYCbCrBT709, + SPVFuncImplConvertYCbCrBT601, + SPVFuncImplConvertYCbCrBT2020, + SPVFuncImplDynamicImageSampler, + SPVFuncImplRayQueryIntersectionParams, + SPVFuncImplVariableDescriptor, + SPVFuncImplVariableSizedDescriptor, + SPVFuncImplVariableDescriptorArray, + SPVFuncImplPaddedStd140, + SPVFuncImplReduceAdd, + SPVFuncImplImageFence, + SPVFuncImplTextureCast + }; + + // If the underlying resource has been used for comparison then duplicate loads of that resource must be too + // Use Metal's native frame-buffer fetch API for subpass inputs. 
+ void emit_texture_op(const Instruction &i, bool sparse) override; + void emit_binary_ptr_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op); + std::string to_ptr_expression(uint32_t id, bool register_expression_read = true); + void emit_binary_unord_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op); + void emit_instruction(const Instruction &instr) override; + void emit_glsl_op(uint32_t result_type, uint32_t result_id, uint32_t op, const uint32_t *args, + uint32_t count) override; + void emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type, uint32_t result_id, uint32_t op, + const uint32_t *args, uint32_t count) override; + void emit_header() override; + void emit_function_prototype(SPIRFunction &func, const Bitset &return_flags) override; + void emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id) override; + void emit_subgroup_op(const Instruction &i) override; + std::string to_texture_op(const Instruction &i, bool sparse, bool *forward, + SmallVector &inherited_expressions) override; + void emit_fixup() override; + std::string to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index, + const std::string &qualifier = ""); + void emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index, + const std::string &qualifier = "", uint32_t base_offset = 0) override; + void emit_struct_padding_target(const SPIRType &type) override; + std::string type_to_glsl(const SPIRType &type, uint32_t id, bool member); + std::string type_to_glsl(const SPIRType &type, uint32_t id = 0) override; + void emit_block_hints(const SPIRBlock &block) override; + + // Allow Metal to use the array template to make arrays a value type + std::string type_to_array_glsl(const SPIRType &type, uint32_t variable_id) override; + std::string constant_op_expression(const SPIRConstantOp &cop) override; + + bool variable_decl_is_remapped_storage(const SPIRVariable &variable, spv::StorageClass storage) const override; + + // GCC workaround of lambdas calling protected functions (for older GCC versions) + std::string variable_decl(const SPIRType &type, const std::string &name, uint32_t id = 0) override; + + std::string image_type_glsl(const SPIRType &type, uint32_t id, bool member) override; + std::string sampler_type(const SPIRType &type, uint32_t id, bool member); + std::string builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClass storage) override; + std::string to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id) override; + std::string to_name(uint32_t id, bool allow_alias = true) const override; + std::string to_function_name(const TextureFunctionNameArguments &args) override; + std::string to_function_args(const TextureFunctionArguments &args, bool *p_forward) override; + std::string to_initializer_expression(const SPIRVariable &var) override; + std::string to_zero_initialized_expression(uint32_t type_id) override; + + std::string unpack_expression_type(std::string expr_str, const SPIRType &type, uint32_t physical_type_id, + bool is_packed, bool row_major) override; + + // Returns true for BuiltInSampleMask because gl_SampleMask[] is an array in SPIR-V, but [[sample_mask]] is a scalar in Metal. 
+ bool builtin_translates_to_nonarray(spv::BuiltIn builtin) const override; + + std::string bitcast_glsl_op(const SPIRType &result_type, const SPIRType &argument_type) override; + bool emit_complex_bitcast(uint32_t result_id, uint32_t id, uint32_t op0) override; + bool skip_argument(uint32_t id) const override; + std::string to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain_is_resolved) override; + std::string to_qualifiers_glsl(uint32_t id) override; + void replace_illegal_names() override; + void declare_constant_arrays(); + + void replace_illegal_entry_point_names(); + void sync_entry_point_aliases_and_names(); + + static const std::unordered_set &get_reserved_keyword_set(); + static const std::unordered_set &get_illegal_func_names(); + + // Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries + void declare_complex_constant_arrays(); + + bool is_patch_block(const SPIRType &type); + bool is_non_native_row_major_matrix(uint32_t id) override; + bool member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index) override; + std::string convert_row_major_matrix(std::string exp_str, const SPIRType &exp_type, uint32_t physical_type_id, + bool is_packed, bool relaxed) override; + + bool is_tesc_shader() const; + bool is_tese_shader() const; + + void preprocess_op_codes(); + void localize_global_variables(); + void extract_global_variables_from_functions(); + void mark_packable_structs(); + void mark_as_packable(SPIRType &type); + void mark_as_workgroup_struct(SPIRType &type); + + std::unordered_map> function_global_vars; + void extract_global_variables_from_function(uint32_t func_id, std::set &added_arg_ids, + std::unordered_set &global_var_ids, + std::unordered_set &processed_func_ids); + uint32_t add_interface_block(spv::StorageClass storage, bool patch = false); + uint32_t add_interface_block_pointer(uint32_t ib_var_id, spv::StorageClass storage); + + struct InterfaceBlockMeta + { + struct LocationMeta + { + uint32_t base_type_id = 0; + uint32_t num_components = 0; + bool flat = false; + bool noperspective = false; + bool centroid = false; + bool sample = false; + }; + std::unordered_map location_meta; + bool strip_array = false; + bool allow_local_declaration = false; + }; + + std::string to_tesc_invocation_id(); + void emit_local_masked_variable(const SPIRVariable &masked_var, bool strip_array); + void add_variable_to_interface_block(spv::StorageClass storage, const std::string &ib_var_ref, SPIRType &ib_type, + SPIRVariable &var, InterfaceBlockMeta &meta); + void add_composite_variable_to_interface_block(spv::StorageClass storage, const std::string &ib_var_ref, + SPIRType &ib_type, SPIRVariable &var, InterfaceBlockMeta &meta); + void add_plain_variable_to_interface_block(spv::StorageClass storage, const std::string &ib_var_ref, + SPIRType &ib_type, SPIRVariable &var, InterfaceBlockMeta &meta); + bool add_component_variable_to_interface_block(spv::StorageClass storage, const std::string &ib_var_ref, + SPIRVariable &var, const SPIRType &type, + InterfaceBlockMeta &meta); + void add_plain_member_variable_to_interface_block(spv::StorageClass storage, + const std::string &ib_var_ref, SPIRType &ib_type, + SPIRVariable &var, SPIRType &var_type, + uint32_t mbr_idx, InterfaceBlockMeta &meta, + const std::string &mbr_name_qual, + const std::string &var_chain_qual, + uint32_t &location, uint32_t &var_mbr_idx); + void add_composite_member_variable_to_interface_block(spv::StorageClass storage, + const 
std::string &ib_var_ref, SPIRType &ib_type, + SPIRVariable &var, SPIRType &var_type, + uint32_t mbr_idx, InterfaceBlockMeta &meta, + const std::string &mbr_name_qual, + const std::string &var_chain_qual, + uint32_t &location, uint32_t &var_mbr_idx, + const Bitset &interpolation_qual); + void add_tess_level_input_to_interface_block(const std::string &ib_var_ref, SPIRType &ib_type, SPIRVariable &var); + void add_tess_level_input(const std::string &base_ref, const std::string &mbr_name, SPIRVariable &var); + + void fix_up_interface_member_indices(spv::StorageClass storage, uint32_t ib_type_id); + + void mark_location_as_used_by_shader(uint32_t location, const SPIRType &type, + spv::StorageClass storage, bool fallback = false); + uint32_t ensure_correct_builtin_type(uint32_t type_id, spv::BuiltIn builtin); + uint32_t ensure_correct_input_type(uint32_t type_id, uint32_t location, uint32_t component, + uint32_t num_components, bool strip_array); + + void emit_custom_templates(); + void emit_custom_functions(); + void emit_resources(); + void emit_specialization_constants_and_structs(); + void emit_interface_block(uint32_t ib_var_id); + bool maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs); + bool is_var_runtime_size_array(const SPIRVariable &var) const; + uint32_t get_resource_array_size(const SPIRType &type, uint32_t id) const; + + void fix_up_shader_inputs_outputs(); + + std::string func_type_decl(SPIRType &type); + std::string entry_point_args_classic(bool append_comma); + std::string entry_point_args_argument_buffer(bool append_comma); + std::string entry_point_arg_stage_in(); + void entry_point_args_builtin(std::string &args); + void entry_point_args_discrete_descriptors(std::string &args); + std::string append_member_name(const std::string &qualifier, const SPIRType &type, uint32_t index); + std::string ensure_valid_name(std::string name, std::string pfx); + std::string to_sampler_expression(uint32_t id); + std::string to_swizzle_expression(uint32_t id); + std::string to_buffer_size_expression(uint32_t id); + bool is_sample_rate() const; + bool is_intersection_query() const; + bool is_direct_input_builtin(spv::BuiltIn builtin); + std::string builtin_qualifier(spv::BuiltIn builtin); + std::string builtin_type_decl(spv::BuiltIn builtin, uint32_t id = 0); + std::string built_in_func_arg(spv::BuiltIn builtin, bool prefix_comma); + std::string member_attribute_qualifier(const SPIRType &type, uint32_t index); + std::string member_location_attribute_qualifier(const SPIRType &type, uint32_t index); + std::string argument_decl(const SPIRFunction::Parameter &arg); + const char *descriptor_address_space(uint32_t id, spv::StorageClass storage, const char *plain_address_space) const; + std::string round_fp_tex_coords(std::string tex_coords, bool coord_is_fp); + uint32_t get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype, uint32_t plane = 0); + uint32_t get_member_location(uint32_t type_id, uint32_t index, uint32_t *comp = nullptr) const; + uint32_t get_or_allocate_builtin_input_member_location(spv::BuiltIn builtin, + uint32_t type_id, uint32_t index, uint32_t *comp = nullptr); + uint32_t get_or_allocate_builtin_output_member_location(spv::BuiltIn builtin, + uint32_t type_id, uint32_t index, uint32_t *comp = nullptr); + + uint32_t get_physical_tess_level_array_size(spv::BuiltIn builtin) const; + + uint32_t get_physical_type_stride(const SPIRType &type) const override; + + // MSL packing rules. 
These compute the effective packing rules as observed by the MSL compiler in the MSL output. + // These values can change depending on various extended decorations which control packing rules. + // We need to make these rules match up with SPIR-V declared rules. + uint32_t get_declared_type_size_msl(const SPIRType &type, bool packed, bool row_major) const; + uint32_t get_declared_type_array_stride_msl(const SPIRType &type, bool packed, bool row_major) const; + uint32_t get_declared_type_matrix_stride_msl(const SPIRType &type, bool packed, bool row_major) const; + uint32_t get_declared_type_alignment_msl(const SPIRType &type, bool packed, bool row_major) const; + + uint32_t get_declared_struct_member_size_msl(const SPIRType &struct_type, uint32_t index) const; + uint32_t get_declared_struct_member_array_stride_msl(const SPIRType &struct_type, uint32_t index) const; + uint32_t get_declared_struct_member_matrix_stride_msl(const SPIRType &struct_type, uint32_t index) const; + uint32_t get_declared_struct_member_alignment_msl(const SPIRType &struct_type, uint32_t index) const; + + uint32_t get_declared_input_size_msl(const SPIRType &struct_type, uint32_t index) const; + uint32_t get_declared_input_array_stride_msl(const SPIRType &struct_type, uint32_t index) const; + uint32_t get_declared_input_matrix_stride_msl(const SPIRType &struct_type, uint32_t index) const; + uint32_t get_declared_input_alignment_msl(const SPIRType &struct_type, uint32_t index) const; + + const SPIRType &get_physical_member_type(const SPIRType &struct_type, uint32_t index) const; + SPIRType get_presumed_input_type(const SPIRType &struct_type, uint32_t index) const; + + uint32_t get_declared_struct_size_msl(const SPIRType &struct_type, bool ignore_alignment = false, + bool ignore_padding = false) const; + + std::string to_component_argument(uint32_t id); + void align_struct(SPIRType &ib_type, std::unordered_set &aligned_structs); + void mark_scalar_layout_structs(const SPIRType &ib_type); + void mark_struct_members_packed(const SPIRType &type); + void ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t index); + bool validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const; + std::string get_argument_address_space(const SPIRVariable &argument); + std::string get_type_address_space(const SPIRType &type, uint32_t id, bool argument = false); + static bool decoration_flags_signal_volatile(const Bitset &flags); + const char *to_restrict(uint32_t id, bool space); + SPIRType &get_stage_in_struct_type(); + SPIRType &get_stage_out_struct_type(); + SPIRType &get_patch_stage_in_struct_type(); + SPIRType &get_patch_stage_out_struct_type(); + std::string get_tess_factor_struct_name(); + SPIRType &get_uint_type(); + uint32_t get_uint_type_id(); + void emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, spv::Op opcode, + uint32_t mem_order_1, uint32_t mem_order_2, bool has_mem_order_2, uint32_t op0, uint32_t op1 = 0, + bool op1_is_pointer = false, bool op1_is_literal = false, uint32_t op2 = 0); + const char *get_memory_order(uint32_t spv_mem_sem); + void add_pragma_line(const std::string &line); + void add_typedef_line(const std::string &line); + void emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem); + bool emit_array_copy(const char *expr, uint32_t lhs_id, uint32_t rhs_id, + spv::StorageClass lhs_storage, spv::StorageClass rhs_storage) override; + void build_implicit_builtins(); + uint32_t build_constant_uint_array_pointer(); + void 
emit_entry_point_declarations() override; + bool uses_explicit_early_fragment_test(); + + uint32_t builtin_frag_coord_id = 0; + uint32_t builtin_sample_id_id = 0; + uint32_t builtin_sample_mask_id = 0; + uint32_t builtin_helper_invocation_id = 0; + uint32_t builtin_vertex_idx_id = 0; + uint32_t builtin_base_vertex_id = 0; + uint32_t builtin_instance_idx_id = 0; + uint32_t builtin_base_instance_id = 0; + uint32_t builtin_view_idx_id = 0; + uint32_t builtin_layer_id = 0; + uint32_t builtin_invocation_id_id = 0; + uint32_t builtin_primitive_id_id = 0; + uint32_t builtin_subgroup_invocation_id_id = 0; + uint32_t builtin_subgroup_size_id = 0; + uint32_t builtin_dispatch_base_id = 0; + uint32_t builtin_stage_input_size_id = 0; + uint32_t builtin_local_invocation_index_id = 0; + uint32_t builtin_workgroup_size_id = 0; + uint32_t builtin_frag_depth_id = 0; + uint32_t swizzle_buffer_id = 0; + uint32_t buffer_size_buffer_id = 0; + uint32_t view_mask_buffer_id = 0; + uint32_t dynamic_offsets_buffer_id = 0; + uint32_t uint_type_id = 0; + uint32_t argument_buffer_padding_buffer_type_id = 0; + uint32_t argument_buffer_padding_image_type_id = 0; + uint32_t argument_buffer_padding_sampler_type_id = 0; + + bool does_shader_write_sample_mask = false; + bool frag_shader_needs_discard_checks = false; + + void cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override; + void cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type) override; + void emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression) override; + + void analyze_sampled_image_usage(); + + bool access_chain_needs_stage_io_builtin_translation(uint32_t base) override; + bool prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type, spv::StorageClass storage, + bool &is_packed) override; + void fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length); + void check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type) override; + + bool emit_tessellation_access_chain(const uint32_t *ops, uint32_t length); + bool emit_tessellation_io_load(uint32_t result_type, uint32_t id, uint32_t ptr); + bool is_out_of_bounds_tessellation_level(uint32_t id_lhs); + + void ensure_builtin(spv::StorageClass storage, spv::BuiltIn builtin); + + void mark_implicit_builtin(spv::StorageClass storage, spv::BuiltIn builtin, uint32_t id); + + std::string convert_to_f32(const std::string &expr, uint32_t components); + + Options msl_options; + std::set spv_function_implementations; + // Must be ordered to ensure declarations are in a specific order. 
+ std::map inputs_by_location; + std::unordered_map inputs_by_builtin; + std::map outputs_by_location; + std::unordered_map outputs_by_builtin; + std::unordered_set location_inputs_in_use; + std::unordered_set location_inputs_in_use_fallback; + std::unordered_set location_outputs_in_use; + std::unordered_set location_outputs_in_use_fallback; + std::unordered_map fragment_output_components; + std::unordered_map builtin_to_automatic_input_location; + std::unordered_map builtin_to_automatic_output_location; + std::set pragma_lines; + std::set typedef_lines; + SmallVector vars_needing_early_declaration; + + std::unordered_map, InternalHasher> resource_bindings; + std::unordered_map resource_arg_buff_idx_to_binding_number; + + uint32_t next_metal_resource_index_buffer = 0; + uint32_t next_metal_resource_index_texture = 0; + uint32_t next_metal_resource_index_sampler = 0; + // Intentionally uninitialized, works around MSVC 2013 bug. + uint32_t next_metal_resource_ids[kMaxArgumentBuffers]; + + VariableID stage_in_var_id = 0; + VariableID stage_out_var_id = 0; + VariableID patch_stage_in_var_id = 0; + VariableID patch_stage_out_var_id = 0; + VariableID stage_in_ptr_var_id = 0; + VariableID stage_out_ptr_var_id = 0; + VariableID tess_level_inner_var_id = 0; + VariableID tess_level_outer_var_id = 0; + VariableID stage_out_masked_builtin_type_id = 0; + + // Handle HLSL-style 0-based vertex/instance index. + enum class TriState + { + Neutral, + No, + Yes + }; + TriState needs_base_vertex_arg = TriState::Neutral; + TriState needs_base_instance_arg = TriState::Neutral; + + bool has_sampled_images = false; + bool builtin_declaration = false; // Handle HLSL-style 0-based vertex/instance index. + + bool is_using_builtin_array = false; // Force the use of C style array declaration. + bool using_builtin_array() const; + + bool is_rasterization_disabled = false; + bool capture_output_to_buffer = false; + bool needs_swizzle_buffer_def = false; + bool used_swizzle_buffer = false; + bool added_builtin_tess_level = false; + bool needs_subgroup_invocation_id = false; + bool needs_subgroup_size = false; + bool needs_sample_id = false; + bool needs_helper_invocation = false; + bool writes_to_depth = false; + std::string qual_pos_var_name; + std::string stage_in_var_name = "in"; + std::string stage_out_var_name = "out"; + std::string patch_stage_in_var_name = "patchIn"; + std::string patch_stage_out_var_name = "patchOut"; + std::string sampler_name_suffix = "Smplr"; + std::string swizzle_name_suffix = "Swzl"; + std::string buffer_size_name_suffix = "BufferSize"; + std::string plane_name_suffix = "Plane"; + std::string input_wg_var_name = "gl_in"; + std::string input_buffer_var_name = "spvIn"; + std::string output_buffer_var_name = "spvOut"; + std::string patch_input_buffer_var_name = "spvPatchIn"; + std::string patch_output_buffer_var_name = "spvPatchOut"; + std::string tess_factor_buffer_var_name = "spvTessLevel"; + std::string index_buffer_var_name = "spvIndices"; + spv::Op previous_instruction_opcode = spv::OpNop; + + // Must be ordered since declaration is in a specific order. 
+ std::map constexpr_samplers_by_id; + std::unordered_map constexpr_samplers_by_binding; + const MSLConstexprSampler *find_constexpr_sampler(uint32_t id) const; + + std::unordered_set buffers_requiring_array_length; + SmallVector> buffer_aliases_argument; + SmallVector buffer_aliases_discrete; + std::unordered_set atomic_image_vars_emulated; // Emulate texture2D atomic operations + std::unordered_set pull_model_inputs; + std::unordered_set recursive_inputs; + + SmallVector entry_point_bindings; + + // Must be ordered since array is in a specific order. + std::map> buffers_requiring_dynamic_offset; + + SmallVector disabled_frag_outputs; + + std::unordered_set inline_uniform_blocks; + + uint32_t argument_buffer_ids[kMaxArgumentBuffers]; + uint32_t argument_buffer_discrete_mask = 0; + uint32_t argument_buffer_device_storage_mask = 0; + + void emit_argument_buffer_aliased_descriptor(const SPIRVariable &aliased_var, + const SPIRVariable &base_var); + + void analyze_argument_buffers(); + bool descriptor_set_is_argument_buffer(uint32_t desc_set) const; + const MSLResourceBinding &get_argument_buffer_resource(uint32_t desc_set, uint32_t arg_idx) const; + void add_argument_buffer_padding_buffer_type(SPIRType &struct_type, uint32_t &mbr_idx, uint32_t &arg_buff_index, MSLResourceBinding &rez_bind); + void add_argument_buffer_padding_image_type(SPIRType &struct_type, uint32_t &mbr_idx, uint32_t &arg_buff_index, MSLResourceBinding &rez_bind); + void add_argument_buffer_padding_sampler_type(SPIRType &struct_type, uint32_t &mbr_idx, uint32_t &arg_buff_index, MSLResourceBinding &rez_bind); + void add_argument_buffer_padding_type(uint32_t mbr_type_id, SPIRType &struct_type, uint32_t &mbr_idx, uint32_t &arg_buff_index, uint32_t count); + + uint32_t get_target_components_for_fragment_location(uint32_t location) const; + uint32_t build_extended_vector_type(uint32_t type_id, uint32_t components, + SPIRType::BaseType basetype = SPIRType::Unknown); + uint32_t build_msl_interpolant_type(uint32_t type_id, bool is_noperspective); + + bool suppress_missing_prototypes = false; + bool suppress_incompatible_pointer_types_discard_qualifiers = false; + + void add_spv_func_and_recompile(SPVFuncImpl spv_func); + + void activate_argument_buffer_resources(); + + bool type_is_msl_framebuffer_fetch(const SPIRType &type) const; + bool is_supported_argument_buffer_type(const SPIRType &type) const; + + bool variable_storage_requires_stage_io(spv::StorageClass storage) const; + + bool needs_manual_helper_invocation_updates() const + { + return msl_options.manual_helper_invocation_updates && msl_options.supports_msl_version(2, 3); + } + bool needs_frag_discard_checks() const + { + return get_execution_model() == spv::ExecutionModelFragment && msl_options.supports_msl_version(2, 3) && + msl_options.check_discarded_frag_stores && frag_shader_needs_discard_checks; + } + + bool has_additional_fixed_sample_mask() const { return msl_options.additional_fixed_sample_mask != 0xffffffff; } + std::string additional_fixed_sample_mask_str() const; + + // OpcodeHandler that handles several MSL preprocessing operations. 
+ struct OpCodePreprocessor : OpcodeHandler + { + OpCodePreprocessor(CompilerMSL &compiler_) + : compiler(compiler_) + { + } + + bool handle(spv::Op opcode, const uint32_t *args, uint32_t length) override; + CompilerMSL::SPVFuncImpl get_spv_func_impl(spv::Op opcode, const uint32_t *args); + void check_resource_write(uint32_t var_id); + + CompilerMSL &compiler; + std::unordered_map result_types; + std::unordered_map image_pointers_emulated; // Emulate texture2D atomic operations + bool suppress_missing_prototypes = false; + bool uses_atomics = false; + bool uses_image_write = false; + bool uses_buffer_write = false; + bool uses_discard = false; + bool needs_subgroup_invocation_id = false; + bool needs_subgroup_size = false; + bool needs_sample_id = false; + bool needs_helper_invocation = false; + }; + + // OpcodeHandler that scans for uses of sampled images + struct SampledImageScanner : OpcodeHandler + { + SampledImageScanner(CompilerMSL &compiler_) + : compiler(compiler_) + { + } + + bool handle(spv::Op opcode, const uint32_t *args, uint32_t) override; + + CompilerMSL &compiler; + }; + + // Sorts the members of a SPIRType and associated Meta info based on a settable sorting + // aspect, which defines which aspect of the struct members will be used to sort them. + // Regardless of the sorting aspect, built-in members always appear at the end of the struct. + struct MemberSorter + { + enum SortAspect + { + LocationThenBuiltInType, + Offset + }; + + void sort(); + bool operator()(uint32_t mbr_idx1, uint32_t mbr_idx2); + MemberSorter(SPIRType &t, Meta &m, SortAspect sa); + + SPIRType &type; + Meta &meta; + SortAspect sort_aspect; + }; +}; +} // namespace SPIRV_CROSS_NAMESPACE + +#endif diff --git a/thirdparty/spirv-cross/spirv_parser.cpp b/thirdparty/spirv-cross/spirv_parser.cpp new file mode 100644 index 00000000000..6108dbb6533 --- /dev/null +++ b/thirdparty/spirv-cross/spirv_parser.cpp @@ -0,0 +1,1337 @@ +/* + * Copyright 2018-2021 Arm Limited + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . 
+ */ + +#include "spirv_parser.hpp" +#include + +using namespace std; +using namespace spv; + +namespace SPIRV_CROSS_NAMESPACE +{ +Parser::Parser(vector spirv) +{ + ir.spirv = std::move(spirv); +} + +Parser::Parser(const uint32_t *spirv_data, size_t word_count) +{ + ir.spirv = vector(spirv_data, spirv_data + word_count); +} + +static bool decoration_is_string(Decoration decoration) +{ + switch (decoration) + { + case DecorationHlslSemanticGOOGLE: + return true; + + default: + return false; + } +} + +static inline uint32_t swap_endian(uint32_t v) +{ + return ((v >> 24) & 0x000000ffu) | ((v >> 8) & 0x0000ff00u) | ((v << 8) & 0x00ff0000u) | ((v << 24) & 0xff000000u); +} + +static bool is_valid_spirv_version(uint32_t version) +{ + switch (version) + { + // Allow v99 since it tends to just work. + case 99: + case 0x10000: // SPIR-V 1.0 + case 0x10100: // SPIR-V 1.1 + case 0x10200: // SPIR-V 1.2 + case 0x10300: // SPIR-V 1.3 + case 0x10400: // SPIR-V 1.4 + case 0x10500: // SPIR-V 1.5 + case 0x10600: // SPIR-V 1.6 + return true; + + default: + return false; + } +} + +void Parser::parse() +{ + auto &spirv = ir.spirv; + + auto len = spirv.size(); + if (len < 5) + SPIRV_CROSS_THROW("SPIRV file too small."); + + auto s = spirv.data(); + + // Endian-swap if we need to. + if (s[0] == swap_endian(MagicNumber)) + transform(begin(spirv), end(spirv), begin(spirv), [](uint32_t c) { return swap_endian(c); }); + + if (s[0] != MagicNumber || !is_valid_spirv_version(s[1])) + SPIRV_CROSS_THROW("Invalid SPIRV format."); + + uint32_t bound = s[3]; + + const uint32_t MaximumNumberOfIDs = 0x3fffff; + if (bound > MaximumNumberOfIDs) + SPIRV_CROSS_THROW("ID bound exceeds limit of 0x3fffff.\n"); + + ir.set_id_bounds(bound); + + uint32_t offset = 5; + + SmallVector instructions; + while (offset < len) + { + Instruction instr = {}; + instr.op = spirv[offset] & 0xffff; + instr.count = (spirv[offset] >> 16) & 0xffff; + + if (instr.count == 0) + SPIRV_CROSS_THROW("SPIR-V instructions cannot consume 0 words. Invalid SPIR-V file."); + + instr.offset = offset + 1; + instr.length = instr.count - 1; + + offset += instr.count; + + if (offset > spirv.size()) + SPIRV_CROSS_THROW("SPIR-V instruction goes out of bounds."); + + instructions.push_back(instr); + } + + for (auto &i : instructions) + parse(i); + + for (auto &fixup : forward_pointer_fixups) + { + auto &target = get(fixup.first); + auto &source = get(fixup.second); + target.member_types = source.member_types; + target.basetype = source.basetype; + target.self = source.self; + } + forward_pointer_fixups.clear(); + + if (current_function) + SPIRV_CROSS_THROW("Function was not terminated."); + if (current_block) + SPIRV_CROSS_THROW("Block was not terminated."); + if (ir.default_entry_point == 0) + SPIRV_CROSS_THROW("There is no entry point in the SPIR-V module."); +} + +const uint32_t *Parser::stream(const Instruction &instr) const +{ + // If we're not going to use any arguments, just return nullptr. + // We want to avoid case where we return an out of range pointer + // that trips debug assertions on some platforms. 
+ if (!instr.length) + return nullptr; + + if (instr.offset + instr.length > ir.spirv.size()) + SPIRV_CROSS_THROW("Compiler::stream() out of range."); + return &ir.spirv[instr.offset]; +} + +static string extract_string(const vector &spirv, uint32_t offset) +{ + string ret; + for (uint32_t i = offset; i < spirv.size(); i++) + { + uint32_t w = spirv[i]; + + for (uint32_t j = 0; j < 4; j++, w >>= 8) + { + char c = w & 0xff; + if (c == '\0') + return ret; + ret += c; + } + } + + SPIRV_CROSS_THROW("String was not terminated before EOF"); +} + +void Parser::parse(const Instruction &instruction) +{ + auto *ops = stream(instruction); + auto op = static_cast(instruction.op); + uint32_t length = instruction.length; + + // HACK for glslang that might emit OpEmitMeshTasksEXT followed by return / branch. + // Instead of failing hard, just ignore it. + if (ignore_trailing_block_opcodes) + { + ignore_trailing_block_opcodes = false; + if (op == OpReturn || op == OpBranch || op == OpUnreachable) + return; + } + + switch (op) + { + case OpSourceContinued: + case OpSourceExtension: + case OpNop: + case OpModuleProcessed: + break; + + case OpString: + { + set(ops[0], extract_string(ir.spirv, instruction.offset + 1)); + break; + } + + case OpMemoryModel: + ir.addressing_model = static_cast(ops[0]); + ir.memory_model = static_cast(ops[1]); + break; + + case OpSource: + { + auto lang = static_cast(ops[0]); + switch (lang) + { + case SourceLanguageESSL: + ir.source.es = true; + ir.source.version = ops[1]; + ir.source.known = true; + ir.source.hlsl = false; + break; + + case SourceLanguageGLSL: + ir.source.es = false; + ir.source.version = ops[1]; + ir.source.known = true; + ir.source.hlsl = false; + break; + + case SourceLanguageHLSL: + // For purposes of cross-compiling, this is GLSL 450. 
+ ir.source.es = false; + ir.source.version = 450; + ir.source.known = true; + ir.source.hlsl = true; + break; + + default: + ir.source.known = false; + break; + } + break; + } + + case OpUndef: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + set(id, result_type); + if (current_block) + current_block->ops.push_back(instruction); + break; + } + + case OpCapability: + { + uint32_t cap = ops[0]; + if (cap == CapabilityKernel) + SPIRV_CROSS_THROW("Kernel capability not supported."); + + ir.declared_capabilities.push_back(static_cast(ops[0])); + break; + } + + case OpExtension: + { + auto ext = extract_string(ir.spirv, instruction.offset); + ir.declared_extensions.push_back(std::move(ext)); + break; + } + + case OpExtInstImport: + { + uint32_t id = ops[0]; + + SPIRExtension::Extension spirv_ext = SPIRExtension::Unsupported; + + auto ext = extract_string(ir.spirv, instruction.offset + 1); + if (ext == "GLSL.std.450") + spirv_ext = SPIRExtension::GLSL; + else if (ext == "DebugInfo") + spirv_ext = SPIRExtension::SPV_debug_info; + else if (ext == "SPV_AMD_shader_ballot") + spirv_ext = SPIRExtension::SPV_AMD_shader_ballot; + else if (ext == "SPV_AMD_shader_explicit_vertex_parameter") + spirv_ext = SPIRExtension::SPV_AMD_shader_explicit_vertex_parameter; + else if (ext == "SPV_AMD_shader_trinary_minmax") + spirv_ext = SPIRExtension::SPV_AMD_shader_trinary_minmax; + else if (ext == "SPV_AMD_gcn_shader") + spirv_ext = SPIRExtension::SPV_AMD_gcn_shader; + else if (ext == "NonSemantic.DebugPrintf") + spirv_ext = SPIRExtension::NonSemanticDebugPrintf; + else if (ext == "NonSemantic.Shader.DebugInfo.100") + spirv_ext = SPIRExtension::NonSemanticShaderDebugInfo; + else if (ext.find("NonSemantic.") == 0) + spirv_ext = SPIRExtension::NonSemanticGeneric; + + set(id, spirv_ext); + // Other SPIR-V extensions which have ExtInstrs are currently not supported. + + break; + } + + case OpExtInst: + { + // The SPIR-V debug information extended instructions might come at global scope. + if (current_block) + { + current_block->ops.push_back(instruction); + if (length >= 2) + { + const auto *type = maybe_get(ops[0]); + if (type) + ir.load_type_width.insert({ ops[1], type->width }); + } + } + break; + } + + case OpEntryPoint: + { + auto itr = + ir.entry_points.insert(make_pair(ops[1], SPIREntryPoint(ops[1], static_cast(ops[0]), + extract_string(ir.spirv, instruction.offset + 2)))); + auto &e = itr.first->second; + + // Strings need nul-terminator and consume the whole word. + uint32_t strlen_words = uint32_t((e.name.size() + 1 + 3) >> 2); + + for (uint32_t i = strlen_words + 2; i < instruction.length; i++) + e.interface_variables.push_back(ops[i]); + + // Set the name of the entry point in case OpName is not provided later. + ir.set_name(ops[1], e.name); + + // If we don't have an entry, make the first one our "default". 
+ if (!ir.default_entry_point) + ir.default_entry_point = ops[1]; + break; + } + + case OpExecutionMode: + { + auto &execution = ir.entry_points[ops[0]]; + auto mode = static_cast(ops[1]); + execution.flags.set(mode); + + switch (mode) + { + case ExecutionModeInvocations: + execution.invocations = ops[2]; + break; + + case ExecutionModeLocalSize: + execution.workgroup_size.x = ops[2]; + execution.workgroup_size.y = ops[3]; + execution.workgroup_size.z = ops[4]; + break; + + case ExecutionModeOutputVertices: + execution.output_vertices = ops[2]; + break; + + case ExecutionModeOutputPrimitivesEXT: + execution.output_primitives = ops[2]; + break; + + default: + break; + } + break; + } + + case OpExecutionModeId: + { + auto &execution = ir.entry_points[ops[0]]; + auto mode = static_cast(ops[1]); + execution.flags.set(mode); + + if (mode == ExecutionModeLocalSizeId) + { + execution.workgroup_size.id_x = ops[2]; + execution.workgroup_size.id_y = ops[3]; + execution.workgroup_size.id_z = ops[4]; + } + + break; + } + + case OpName: + { + uint32_t id = ops[0]; + ir.set_name(id, extract_string(ir.spirv, instruction.offset + 1)); + break; + } + + case OpMemberName: + { + uint32_t id = ops[0]; + uint32_t member = ops[1]; + ir.set_member_name(id, member, extract_string(ir.spirv, instruction.offset + 2)); + break; + } + + case OpDecorationGroup: + { + // Noop, this simply means an ID should be a collector of decorations. + // The meta array is already a flat array of decorations which will contain the relevant decorations. + break; + } + + case OpGroupDecorate: + { + uint32_t group_id = ops[0]; + auto &decorations = ir.meta[group_id].decoration; + auto &flags = decorations.decoration_flags; + + // Copies decorations from one ID to another. Only copy decorations which are set in the group, + // i.e., we cannot just copy the meta structure directly. + for (uint32_t i = 1; i < length; i++) + { + uint32_t target = ops[i]; + flags.for_each_bit([&](uint32_t bit) { + auto decoration = static_cast(bit); + + if (decoration_is_string(decoration)) + { + ir.set_decoration_string(target, decoration, ir.get_decoration_string(group_id, decoration)); + } + else + { + ir.meta[target].decoration_word_offset[decoration] = + ir.meta[group_id].decoration_word_offset[decoration]; + ir.set_decoration(target, decoration, ir.get_decoration(group_id, decoration)); + } + }); + } + break; + } + + case OpGroupMemberDecorate: + { + uint32_t group_id = ops[0]; + auto &flags = ir.meta[group_id].decoration.decoration_flags; + + // Copies decorations from one ID to another. Only copy decorations which are set in the group, + // i.e., we cannot just copy the meta structure directly. + for (uint32_t i = 1; i + 1 < length; i += 2) + { + uint32_t target = ops[i + 0]; + uint32_t index = ops[i + 1]; + flags.for_each_bit([&](uint32_t bit) { + auto decoration = static_cast(bit); + + if (decoration_is_string(decoration)) + ir.set_member_decoration_string(target, index, decoration, + ir.get_decoration_string(group_id, decoration)); + else + ir.set_member_decoration(target, index, decoration, ir.get_decoration(group_id, decoration)); + }); + } + break; + } + + case OpDecorate: + case OpDecorateId: + { + // OpDecorateId technically supports an array of arguments, but our only supported decorations are single uint, + // so merge decorate and decorate-id here. 
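+ // For decorations that carry a literal (Location, Binding, ...), the word offset of
+ // that literal is also recorded so later passes can find it in the original binary.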
+ uint32_t id = ops[0]; + + auto decoration = static_cast(ops[1]); + if (length >= 3) + { + ir.meta[id].decoration_word_offset[decoration] = uint32_t(&ops[2] - ir.spirv.data()); + ir.set_decoration(id, decoration, ops[2]); + } + else + ir.set_decoration(id, decoration); + + break; + } + + case OpDecorateStringGOOGLE: + { + uint32_t id = ops[0]; + auto decoration = static_cast(ops[1]); + ir.set_decoration_string(id, decoration, extract_string(ir.spirv, instruction.offset + 2)); + break; + } + + case OpMemberDecorate: + { + uint32_t id = ops[0]; + uint32_t member = ops[1]; + auto decoration = static_cast(ops[2]); + if (length >= 4) + ir.set_member_decoration(id, member, decoration, ops[3]); + else + ir.set_member_decoration(id, member, decoration); + break; + } + + case OpMemberDecorateStringGOOGLE: + { + uint32_t id = ops[0]; + uint32_t member = ops[1]; + auto decoration = static_cast(ops[2]); + ir.set_member_decoration_string(id, member, decoration, extract_string(ir.spirv, instruction.offset + 3)); + break; + } + + // Build up basic types. + case OpTypeVoid: + { + uint32_t id = ops[0]; + auto &type = set(id, op); + type.basetype = SPIRType::Void; + break; + } + + case OpTypeBool: + { + uint32_t id = ops[0]; + auto &type = set(id, op); + type.basetype = SPIRType::Boolean; + type.width = 1; + break; + } + + case OpTypeFloat: + { + uint32_t id = ops[0]; + uint32_t width = ops[1]; + auto &type = set(id, op); + if (width == 64) + type.basetype = SPIRType::Double; + else if (width == 32) + type.basetype = SPIRType::Float; + else if (width == 16) + type.basetype = SPIRType::Half; + else + SPIRV_CROSS_THROW("Unrecognized bit-width of floating point type."); + type.width = width; + break; + } + + case OpTypeInt: + { + uint32_t id = ops[0]; + uint32_t width = ops[1]; + bool signedness = ops[2] != 0; + auto &type = set(id, op); + type.basetype = signedness ? to_signed_basetype(width) : to_unsigned_basetype(width); + type.width = width; + break; + } + + // Build composite types by "inheriting". + // NOTE: The self member is also copied! For pointers and array modifiers this is a good thing + // since we can refer to decorations on pointee classes which is needed for UBO/SSBO, I/O blocks in geometry/tess etc. + case OpTypeVector: + { + uint32_t id = ops[0]; + uint32_t vecsize = ops[2]; + + auto &base = get(ops[1]); + auto &vecbase = set(id, base); + + vecbase.op = op; + vecbase.vecsize = vecsize; + vecbase.self = id; + vecbase.parent_type = ops[1]; + break; + } + + case OpTypeMatrix: + { + uint32_t id = ops[0]; + uint32_t colcount = ops[2]; + + auto &base = get(ops[1]); + auto &matrixbase = set(id, base); + + matrixbase.op = op; + matrixbase.columns = colcount; + matrixbase.self = id; + matrixbase.parent_type = ops[1]; + break; + } + + case OpTypeArray: + { + uint32_t id = ops[0]; + uint32_t tid = ops[1]; + auto &base = get(tid); + auto &arraybase = set(id, base); + + arraybase.op = op; + arraybase.parent_type = tid; + + uint32_t cid = ops[2]; + ir.mark_used_as_array_length(cid); + auto *c = maybe_get(cid); + bool literal = c && !c->specialization; + + // We're copying type information into Array types, so we'll need a fixup for any physical pointer + // references. + if (base.forward_pointer) + forward_pointer_fixups.push_back({ id, tid }); + + arraybase.array_size_literal.push_back(literal); + arraybase.array.push_back(literal ? c->scalar() : cid); + + // .self resolves down to non-array/non-pointer type. 
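+ // (An array of arrays of a struct therefore still reports the struct's own ID here.)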
+ arraybase.self = base.self; + break; + } + + case OpTypeRuntimeArray: + { + uint32_t id = ops[0]; + + auto &base = get(ops[1]); + auto &arraybase = set(id, base); + + // We're copying type information into Array types, so we'll need a fixup for any physical pointer + // references. + if (base.forward_pointer) + forward_pointer_fixups.push_back({ id, ops[1] }); + + arraybase.op = op; + arraybase.array.push_back(0); + arraybase.array_size_literal.push_back(true); + arraybase.parent_type = ops[1]; + + // .self resolves down to non-array/non-pointer type. + arraybase.self = base.self; + break; + } + + case OpTypeImage: + { + uint32_t id = ops[0]; + auto &type = set(id, op); + type.basetype = SPIRType::Image; + type.image.type = ops[1]; + type.image.dim = static_cast(ops[2]); + type.image.depth = ops[3] == 1; + type.image.arrayed = ops[4] != 0; + type.image.ms = ops[5] != 0; + type.image.sampled = ops[6]; + type.image.format = static_cast(ops[7]); + type.image.access = (length >= 9) ? static_cast(ops[8]) : AccessQualifierMax; + break; + } + + case OpTypeSampledImage: + { + uint32_t id = ops[0]; + uint32_t imagetype = ops[1]; + auto &type = set(id, op); + type = get(imagetype); + type.basetype = SPIRType::SampledImage; + type.self = id; + break; + } + + case OpTypeSampler: + { + uint32_t id = ops[0]; + auto &type = set(id, op); + type.basetype = SPIRType::Sampler; + break; + } + + case OpTypePointer: + { + uint32_t id = ops[0]; + + // Very rarely, we might receive a FunctionPrototype here. + // We won't be able to compile it, but we shouldn't crash when parsing. + // We should be able to reflect. + auto *base = maybe_get(ops[2]); + auto &ptrbase = set(id, op); + + if (base) + { + ptrbase = *base; + ptrbase.op = op; + } + + ptrbase.pointer = true; + ptrbase.pointer_depth++; + ptrbase.storage = static_cast(ops[1]); + + if (ptrbase.storage == StorageClassAtomicCounter) + ptrbase.basetype = SPIRType::AtomicCounter; + + if (base && base->forward_pointer) + forward_pointer_fixups.push_back({ id, ops[2] }); + + ptrbase.parent_type = ops[2]; + + // Do NOT set ptrbase.self! + break; + } + + case OpTypeForwardPointer: + { + uint32_t id = ops[0]; + auto &ptrbase = set(id, op); + ptrbase.pointer = true; + ptrbase.pointer_depth++; + ptrbase.storage = static_cast(ops[1]); + ptrbase.forward_pointer = true; + + if (ptrbase.storage == StorageClassAtomicCounter) + ptrbase.basetype = SPIRType::AtomicCounter; + + break; + } + + case OpTypeStruct: + { + uint32_t id = ops[0]; + auto &type = set(id, op); + type.basetype = SPIRType::Struct; + for (uint32_t i = 1; i < length; i++) + type.member_types.push_back(ops[i]); + + // Check if we have seen this struct type before, with just different + // decorations. + // + // Add workaround for issue #17 as well by looking at OpName for the struct + // types, which we shouldn't normally do. + // We should not normally have to consider type aliases like this to begin with + // however ... glslang issues #304, #307 cover this. + + // For stripped names, never consider struct type aliasing. + // We risk declaring the same struct multiple times, but type-punning is not allowed + // so this is safe. 
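+ // An empty result from get_name() is what marks a stripped name below.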
+ bool consider_aliasing = !ir.get_name(type.self).empty(); + if (consider_aliasing) + { + for (auto &other : global_struct_cache) + { + if (ir.get_name(type.self) == ir.get_name(other) && + types_are_logically_equivalent(type, get(other))) + { + type.type_alias = other; + break; + } + } + + if (type.type_alias == TypeID(0)) + global_struct_cache.push_back(id); + } + break; + } + + case OpTypeFunction: + { + uint32_t id = ops[0]; + uint32_t ret = ops[1]; + + auto &func = set(id, ret); + for (uint32_t i = 2; i < length; i++) + func.parameter_types.push_back(ops[i]); + break; + } + + case OpTypeAccelerationStructureKHR: + { + uint32_t id = ops[0]; + auto &type = set(id, op); + type.basetype = SPIRType::AccelerationStructure; + break; + } + + case OpTypeRayQueryKHR: + { + uint32_t id = ops[0]; + auto &type = set(id, op); + type.basetype = SPIRType::RayQuery; + break; + } + + // Variable declaration + // All variables are essentially pointers with a storage qualifier. + case OpVariable: + { + uint32_t type = ops[0]; + uint32_t id = ops[1]; + auto storage = static_cast(ops[2]); + uint32_t initializer = length == 4 ? ops[3] : 0; + + if (storage == StorageClassFunction) + { + if (!current_function) + SPIRV_CROSS_THROW("No function currently in scope"); + current_function->add_local_variable(id); + } + + set(id, type, storage, initializer); + break; + } + + // OpPhi + // OpPhi is a fairly magical opcode. + // It selects temporary variables based on which parent block we *came from*. + // In high-level languages we can "de-SSA" by creating a function local, and flush out temporaries to this function-local + // variable to emulate SSA Phi. + case OpPhi: + { + if (!current_function) + SPIRV_CROSS_THROW("No function currently in scope"); + if (!current_block) + SPIRV_CROSS_THROW("No block currently in scope"); + + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + // Instead of a temporary, create a new function-wide temporary with this ID instead. + auto &var = set(id, result_type, spv::StorageClassFunction); + var.phi_variable = true; + + current_function->add_local_variable(id); + + for (uint32_t i = 2; i + 2 <= length; i += 2) + current_block->phi_variables.push_back({ ops[i], ops[i + 1], id }); + break; + } + + // Constants + case OpSpecConstant: + case OpConstant: + { + uint32_t id = ops[1]; + auto &type = get(ops[0]); + + if (type.width > 32) + set(id, ops[0], ops[2] | (uint64_t(ops[3]) << 32), op == OpSpecConstant); + else + set(id, ops[0], ops[2], op == OpSpecConstant); + break; + } + + case OpSpecConstantFalse: + case OpConstantFalse: + { + uint32_t id = ops[1]; + set(id, ops[0], uint32_t(0), op == OpSpecConstantFalse); + break; + } + + case OpSpecConstantTrue: + case OpConstantTrue: + { + uint32_t id = ops[1]; + set(id, ops[0], uint32_t(1), op == OpSpecConstantTrue); + break; + } + + case OpConstantNull: + { + uint32_t id = ops[1]; + uint32_t type = ops[0]; + ir.make_constant_null(id, type, true); + break; + } + + case OpSpecConstantComposite: + case OpConstantComposite: + { + uint32_t id = ops[1]; + uint32_t type = ops[0]; + + auto &ctype = get(type); + + // We can have constants which are structs and arrays. + // In this case, our SPIRConstant will be a list of other SPIRConstant ids which we + // can refer to. 
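+ // Scalar and vector composites are instead flattened into the constant itself,
+ // which is why that path below only accepts up to 4 elements.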
+ if (ctype.basetype == SPIRType::Struct || !ctype.array.empty()) + { + set(id, type, ops + 2, length - 2, op == OpSpecConstantComposite); + } + else + { + uint32_t elements = length - 2; + if (elements > 4) + SPIRV_CROSS_THROW("OpConstantComposite only supports 1, 2, 3 and 4 elements."); + + SPIRConstant remapped_constant_ops[4]; + const SPIRConstant *c[4]; + for (uint32_t i = 0; i < elements; i++) + { + // Specialization constants operations can also be part of this. + // We do not know their value, so any attempt to query SPIRConstant later + // will fail. We can only propagate the ID of the expression and use to_expression on it. + auto *constant_op = maybe_get(ops[2 + i]); + auto *undef_op = maybe_get(ops[2 + i]); + if (constant_op) + { + if (op == OpConstantComposite) + SPIRV_CROSS_THROW("Specialization constant operation used in OpConstantComposite."); + + remapped_constant_ops[i].make_null(get(constant_op->basetype)); + remapped_constant_ops[i].self = constant_op->self; + remapped_constant_ops[i].constant_type = constant_op->basetype; + remapped_constant_ops[i].specialization = true; + c[i] = &remapped_constant_ops[i]; + } + else if (undef_op) + { + // Undefined, just pick 0. + remapped_constant_ops[i].make_null(get(undef_op->basetype)); + remapped_constant_ops[i].constant_type = undef_op->basetype; + c[i] = &remapped_constant_ops[i]; + } + else + c[i] = &get(ops[2 + i]); + } + set(id, type, c, elements, op == OpSpecConstantComposite); + } + break; + } + + // Functions + case OpFunction: + { + uint32_t res = ops[0]; + uint32_t id = ops[1]; + // Control + uint32_t type = ops[3]; + + if (current_function) + SPIRV_CROSS_THROW("Must end a function before starting a new one!"); + + current_function = &set(id, res, type); + break; + } + + case OpFunctionParameter: + { + uint32_t type = ops[0]; + uint32_t id = ops[1]; + + if (!current_function) + SPIRV_CROSS_THROW("Must be in a function!"); + + current_function->add_parameter(type, id); + set(id, type, StorageClassFunction); + break; + } + + case OpFunctionEnd: + { + if (current_block) + { + // Very specific error message, but seems to come up quite often. + SPIRV_CROSS_THROW( + "Cannot end a function before ending the current block.\n" + "Likely cause: If this SPIR-V was created from glslang HLSL, make sure the entry point is valid."); + } + current_function = nullptr; + break; + } + + // Blocks + case OpLabel: + { + // OpLabel always starts a block. + if (!current_function) + SPIRV_CROSS_THROW("Blocks cannot exist outside functions!"); + + uint32_t id = ops[0]; + + current_function->blocks.push_back(id); + if (!current_function->entry_block) + current_function->entry_block = id; + + if (current_block) + SPIRV_CROSS_THROW("Cannot start a block before ending the current block."); + + current_block = &set(id); + break; + } + + // Branch instructions end blocks. + case OpBranch: + { + if (!current_block) + SPIRV_CROSS_THROW("Trying to end a non-existing block."); + + uint32_t target = ops[0]; + current_block->terminator = SPIRBlock::Direct; + current_block->next_block = target; + current_block = nullptr; + break; + } + + case OpBranchConditional: + { + if (!current_block) + SPIRV_CROSS_THROW("Trying to end a non-existing block."); + + current_block->condition = ops[0]; + current_block->true_block = ops[1]; + current_block->false_block = ops[2]; + + current_block->terminator = SPIRBlock::Select; + + if (current_block->true_block == current_block->false_block) + { + // Bogus conditional, translate to a direct branch. 
+ // Avoids some ugly edge cases later when analyzing CFGs. + + // There are some super jank cases where the merge block is different from the true/false, + // and later branches can "break" out of the selection construct this way. + // This is complete nonsense, but CTS hits this case. + // In this scenario, we should see the selection construct as more of a Switch with one default case. + // The problem here is that this breaks any attempt to break out of outer switch statements, + // but it's theoretically solvable if this ever comes up using the ladder breaking system ... + + if (current_block->true_block != current_block->next_block && + current_block->merge == SPIRBlock::MergeSelection) + { + uint32_t ids = ir.increase_bound_by(2); + + auto &type = set(ids, OpTypeInt); + type.basetype = SPIRType::Int; + type.width = 32; + auto &c = set(ids + 1, ids); + + current_block->condition = c.self; + current_block->default_block = current_block->true_block; + current_block->terminator = SPIRBlock::MultiSelect; + ir.block_meta[current_block->next_block] &= ~ParsedIR::BLOCK_META_SELECTION_MERGE_BIT; + ir.block_meta[current_block->next_block] |= ParsedIR::BLOCK_META_MULTISELECT_MERGE_BIT; + } + else + { + // Collapse loops if we have to. + bool collapsed_loop = current_block->true_block == current_block->merge_block && + current_block->merge == SPIRBlock::MergeLoop; + + if (collapsed_loop) + { + ir.block_meta[current_block->merge_block] &= ~ParsedIR::BLOCK_META_LOOP_MERGE_BIT; + ir.block_meta[current_block->continue_block] &= ~ParsedIR::BLOCK_META_CONTINUE_BIT; + } + + current_block->next_block = current_block->true_block; + current_block->condition = 0; + current_block->true_block = 0; + current_block->false_block = 0; + current_block->merge_block = 0; + current_block->merge = SPIRBlock::MergeNone; + current_block->terminator = SPIRBlock::Direct; + } + } + + current_block = nullptr; + break; + } + + case OpSwitch: + { + if (!current_block) + SPIRV_CROSS_THROW("Trying to end a non-existing block."); + + current_block->terminator = SPIRBlock::MultiSelect; + + current_block->condition = ops[0]; + current_block->default_block = ops[1]; + + uint32_t remaining_ops = length - 2; + if ((remaining_ops % 2) == 0) + { + for (uint32_t i = 2; i + 2 <= length; i += 2) + current_block->cases_32bit.push_back({ ops[i], ops[i + 1] }); + } + + if ((remaining_ops % 3) == 0) + { + for (uint32_t i = 2; i + 3 <= length; i += 3) + { + uint64_t value = (static_cast(ops[i + 1]) << 32) | ops[i]; + current_block->cases_64bit.push_back({ value, ops[i + 2] }); + } + } + + // If we jump to next block, make it break instead since we're inside a switch case block at that point. + ir.block_meta[current_block->next_block] |= ParsedIR::BLOCK_META_MULTISELECT_MERGE_BIT; + + current_block = nullptr; + break; + } + + case OpKill: + case OpTerminateInvocation: + { + if (!current_block) + SPIRV_CROSS_THROW("Trying to end a non-existing block."); + current_block->terminator = SPIRBlock::Kill; + current_block = nullptr; + break; + } + + case OpTerminateRayKHR: + // NV variant is not a terminator. + if (!current_block) + SPIRV_CROSS_THROW("Trying to end a non-existing block."); + current_block->terminator = SPIRBlock::TerminateRay; + current_block = nullptr; + break; + + case OpIgnoreIntersectionKHR: + // NV variant is not a terminator. 
+ if (!current_block) + SPIRV_CROSS_THROW("Trying to end a non-existing block."); + current_block->terminator = SPIRBlock::IgnoreIntersection; + current_block = nullptr; + break; + + case OpEmitMeshTasksEXT: + if (!current_block) + SPIRV_CROSS_THROW("Trying to end a non-existing block."); + current_block->terminator = SPIRBlock::EmitMeshTasks; + for (uint32_t i = 0; i < 3; i++) + current_block->mesh.groups[i] = ops[i]; + current_block->mesh.payload = length >= 4 ? ops[3] : 0; + current_block = nullptr; + // Currently glslang is bugged and does not treat EmitMeshTasksEXT as a terminator. + ignore_trailing_block_opcodes = true; + break; + + case OpReturn: + { + if (!current_block) + SPIRV_CROSS_THROW("Trying to end a non-existing block."); + current_block->terminator = SPIRBlock::Return; + current_block = nullptr; + break; + } + + case OpReturnValue: + { + if (!current_block) + SPIRV_CROSS_THROW("Trying to end a non-existing block."); + current_block->terminator = SPIRBlock::Return; + current_block->return_value = ops[0]; + current_block = nullptr; + break; + } + + case OpUnreachable: + { + if (!current_block) + SPIRV_CROSS_THROW("Trying to end a non-existing block."); + current_block->terminator = SPIRBlock::Unreachable; + current_block = nullptr; + break; + } + + case OpSelectionMerge: + { + if (!current_block) + SPIRV_CROSS_THROW("Trying to modify a non-existing block."); + + current_block->next_block = ops[0]; + current_block->merge = SPIRBlock::MergeSelection; + ir.block_meta[current_block->next_block] |= ParsedIR::BLOCK_META_SELECTION_MERGE_BIT; + + if (length >= 2) + { + if (ops[1] & SelectionControlFlattenMask) + current_block->hint = SPIRBlock::HintFlatten; + else if (ops[1] & SelectionControlDontFlattenMask) + current_block->hint = SPIRBlock::HintDontFlatten; + } + break; + } + + case OpLoopMerge: + { + if (!current_block) + SPIRV_CROSS_THROW("Trying to modify a non-existing block."); + + current_block->merge_block = ops[0]; + current_block->continue_block = ops[1]; + current_block->merge = SPIRBlock::MergeLoop; + + ir.block_meta[current_block->self] |= ParsedIR::BLOCK_META_LOOP_HEADER_BIT; + ir.block_meta[current_block->merge_block] |= ParsedIR::BLOCK_META_LOOP_MERGE_BIT; + + ir.continue_block_to_loop_header[current_block->continue_block] = BlockID(current_block->self); + + // Don't add loop headers to continue blocks, + // which would make it impossible branch into the loop header since + // they are treated as continues. + if (current_block->continue_block != BlockID(current_block->self)) + ir.block_meta[current_block->continue_block] |= ParsedIR::BLOCK_META_CONTINUE_BIT; + + if (length >= 3) + { + if (ops[2] & LoopControlUnrollMask) + current_block->hint = SPIRBlock::HintUnroll; + else if (ops[2] & LoopControlDontUnrollMask) + current_block->hint = SPIRBlock::HintDontUnroll; + } + break; + } + + case OpSpecConstantOp: + { + if (length < 3) + SPIRV_CROSS_THROW("OpSpecConstantOp not enough arguments."); + + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + auto spec_op = static_cast(ops[2]); + + set(id, result_type, spec_op, ops + 3, length - 3); + break; + } + + case OpLine: + { + // OpLine might come at global scope, but we don't care about those since they will not be declared in any + // meaningful correct order. + // Ignore all OpLine directives which live outside a function. + if (current_block) + current_block->ops.push_back(instruction); + + // Line directives may arrive before first OpLabel. 
+ // Treat this as the line of the function declaration, + // so warnings for arguments can propagate properly. + if (current_function) + { + // Store the first one we find and emit it before creating the function prototype. + if (current_function->entry_line.file_id == 0) + { + current_function->entry_line.file_id = ops[0]; + current_function->entry_line.line_literal = ops[1]; + } + } + break; + } + + case OpNoLine: + { + // OpNoLine might come at global scope. + if (current_block) + current_block->ops.push_back(instruction); + break; + } + + // Actual opcodes. + default: + { + if (length >= 2) + { + const auto *type = maybe_get(ops[0]); + if (type) + ir.load_type_width.insert({ ops[1], type->width }); + } + + if (!current_block) + SPIRV_CROSS_THROW("Currently no block to insert opcode."); + + current_block->ops.push_back(instruction); + break; + } + } +} + +bool Parser::types_are_logically_equivalent(const SPIRType &a, const SPIRType &b) const +{ + if (a.basetype != b.basetype) + return false; + if (a.width != b.width) + return false; + if (a.vecsize != b.vecsize) + return false; + if (a.columns != b.columns) + return false; + if (a.array.size() != b.array.size()) + return false; + + size_t array_count = a.array.size(); + if (array_count && memcmp(a.array.data(), b.array.data(), array_count * sizeof(uint32_t)) != 0) + return false; + + if (a.basetype == SPIRType::Image || a.basetype == SPIRType::SampledImage) + { + if (memcmp(&a.image, &b.image, sizeof(SPIRType::Image)) != 0) + return false; + } + + if (a.member_types.size() != b.member_types.size()) + return false; + + size_t member_types = a.member_types.size(); + for (size_t i = 0; i < member_types; i++) + { + if (!types_are_logically_equivalent(get(a.member_types[i]), get(b.member_types[i]))) + return false; + } + + return true; +} + +bool Parser::variable_storage_is_aliased(const SPIRVariable &v) const +{ + auto &type = get(v.basetype); + + auto *type_meta = ir.find_meta(type.self); + + bool ssbo = v.storage == StorageClassStorageBuffer || + (type_meta && type_meta->decoration.decoration_flags.get(DecorationBufferBlock)); + bool image = type.basetype == SPIRType::Image; + bool counter = type.basetype == SPIRType::AtomicCounter; + + bool is_restrict; + if (ssbo) + is_restrict = ir.get_buffer_block_flags(v).get(DecorationRestrict); + else + is_restrict = ir.has_decoration(v.self, DecorationRestrict); + + return !is_restrict && (ssbo || image || counter); +} +} // namespace SPIRV_CROSS_NAMESPACE diff --git a/thirdparty/spirv-cross/spirv_parser.hpp b/thirdparty/spirv-cross/spirv_parser.hpp new file mode 100644 index 00000000000..dabc0e22446 --- /dev/null +++ b/thirdparty/spirv-cross/spirv_parser.hpp @@ -0,0 +1,103 @@ +/* + * Copyright 2018-2021 Arm Limited + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . 
+ */ + +#ifndef SPIRV_CROSS_PARSER_HPP +#define SPIRV_CROSS_PARSER_HPP + +#include "spirv_cross_parsed_ir.hpp" +#include + +namespace SPIRV_CROSS_NAMESPACE +{ +class Parser +{ +public: + Parser(const uint32_t *spirv_data, size_t word_count); + Parser(std::vector spirv); + + void parse(); + + ParsedIR &get_parsed_ir() + { + return ir; + } + +private: + ParsedIR ir; + SPIRFunction *current_function = nullptr; + SPIRBlock *current_block = nullptr; + // For workarounds. + bool ignore_trailing_block_opcodes = false; + + void parse(const Instruction &instr); + const uint32_t *stream(const Instruction &instr) const; + + template + T &set(uint32_t id, P &&... args) + { + ir.add_typed_id(static_cast(T::type), id); + auto &var = variant_set(ir.ids[id], std::forward

(args)...); + var.self = id; + return var; + } + + template + T &get(uint32_t id) + { + return variant_get(ir.ids[id]); + } + + template + T *maybe_get(uint32_t id) + { + if (ir.ids[id].get_type() == static_cast(T::type)) + return &get(id); + else + return nullptr; + } + + template + const T &get(uint32_t id) const + { + return variant_get(ir.ids[id]); + } + + template + const T *maybe_get(uint32_t id) const + { + if (ir.ids[id].get_type() == T::type) + return &get(id); + else + return nullptr; + } + + // This must be an ordered data structure so we always pick the same type aliases. + SmallVector global_struct_cache; + SmallVector> forward_pointer_fixups; + + bool types_are_logically_equivalent(const SPIRType &a, const SPIRType &b) const; + bool variable_storage_is_aliased(const SPIRVariable &v) const; +}; +} // namespace SPIRV_CROSS_NAMESPACE + +#endif diff --git a/thirdparty/spirv-cross/spirv_reflect.cpp b/thirdparty/spirv-cross/spirv_reflect.cpp new file mode 100644 index 00000000000..633983bd30d --- /dev/null +++ b/thirdparty/spirv-cross/spirv_reflect.cpp @@ -0,0 +1,710 @@ +/* + * Copyright 2018-2021 Bradley Austin Davis + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . + */ + +#include "spirv_reflect.hpp" +#include "spirv_glsl.hpp" +#include + +using namespace spv; +using namespace SPIRV_CROSS_NAMESPACE; +using namespace std; + +namespace simple_json +{ +enum class Type +{ + Object, + Array, +}; + +using State = std::pair; +using Stack = std::stack; + +class Stream +{ + Stack stack; + StringStream<> buffer; + uint32_t indent{ 0 }; + char current_locale_radix_character = '.'; + +public: + void set_current_locale_radix_character(char c) + { + current_locale_radix_character = c; + } + + void begin_json_object(); + void end_json_object(); + void emit_json_key(const std::string &key); + void emit_json_key_value(const std::string &key, const std::string &value); + void emit_json_key_value(const std::string &key, bool value); + void emit_json_key_value(const std::string &key, uint32_t value); + void emit_json_key_value(const std::string &key, int32_t value); + void emit_json_key_value(const std::string &key, float value); + void emit_json_key_object(const std::string &key); + void emit_json_key_array(const std::string &key); + + void begin_json_array(); + void end_json_array(); + void emit_json_array_value(const std::string &value); + void emit_json_array_value(uint32_t value); + void emit_json_array_value(bool value); + + std::string str() const + { + return buffer.str(); + } + +private: + inline void statement_indent() + { + for (uint32_t i = 0; i < indent; i++) + buffer << " "; + } + + template + inline void statement_inner(T &&t) + { + buffer << std::forward(t); + } + + template + inline void statement_inner(T &&t, Ts &&... 
ts) + { + buffer << std::forward(t); + statement_inner(std::forward(ts)...); + } + + template + inline void statement(Ts &&... ts) + { + statement_indent(); + statement_inner(std::forward(ts)...); + buffer << '\n'; + } + + template + void statement_no_return(Ts &&... ts) + { + statement_indent(); + statement_inner(std::forward(ts)...); + } +}; +} // namespace simple_json + +using namespace simple_json; + +// Hackery to emit JSON without using nlohmann/json C++ library (which requires a +// higher level of compiler compliance than is required by SPIRV-Cross +void Stream::begin_json_array() +{ + if (!stack.empty() && stack.top().second) + { + statement_inner(",\n"); + } + statement("["); + ++indent; + stack.emplace(Type::Array, false); +} + +void Stream::end_json_array() +{ + if (stack.empty() || stack.top().first != Type::Array) + SPIRV_CROSS_THROW("Invalid JSON state"); + if (stack.top().second) + { + statement_inner("\n"); + } + --indent; + statement_no_return("]"); + stack.pop(); + if (!stack.empty()) + { + stack.top().second = true; + } +} + +void Stream::emit_json_array_value(const std::string &value) +{ + if (stack.empty() || stack.top().first != Type::Array) + SPIRV_CROSS_THROW("Invalid JSON state"); + + if (stack.top().second) + statement_inner(",\n"); + + statement_no_return("\"", value, "\""); + stack.top().second = true; +} + +void Stream::emit_json_array_value(uint32_t value) +{ + if (stack.empty() || stack.top().first != Type::Array) + SPIRV_CROSS_THROW("Invalid JSON state"); + if (stack.top().second) + statement_inner(",\n"); + statement_no_return(std::to_string(value)); + stack.top().second = true; +} + +void Stream::emit_json_array_value(bool value) +{ + if (stack.empty() || stack.top().first != Type::Array) + SPIRV_CROSS_THROW("Invalid JSON state"); + if (stack.top().second) + statement_inner(",\n"); + statement_no_return(value ? "true" : "false"); + stack.top().second = true; +} + +void Stream::begin_json_object() +{ + if (!stack.empty() && stack.top().second) + { + statement_inner(",\n"); + } + statement("{"); + ++indent; + stack.emplace(Type::Object, false); +} + +void Stream::end_json_object() +{ + if (stack.empty() || stack.top().first != Type::Object) + SPIRV_CROSS_THROW("Invalid JSON state"); + if (stack.top().second) + { + statement_inner("\n"); + } + --indent; + statement_no_return("}"); + stack.pop(); + if (!stack.empty()) + { + stack.top().second = true; + } +} + +void Stream::emit_json_key(const std::string &key) +{ + if (stack.empty() || stack.top().first != Type::Object) + SPIRV_CROSS_THROW("Invalid JSON state"); + + if (stack.top().second) + statement_inner(",\n"); + statement_no_return("\"", key, "\" : "); + stack.top().second = true; +} + +void Stream::emit_json_key_value(const std::string &key, const std::string &value) +{ + emit_json_key(key); + statement_inner("\"", value, "\""); +} + +void Stream::emit_json_key_value(const std::string &key, uint32_t value) +{ + emit_json_key(key); + statement_inner(value); +} + +void Stream::emit_json_key_value(const std::string &key, int32_t value) +{ + emit_json_key(key); + statement_inner(value); +} + +void Stream::emit_json_key_value(const std::string &key, float value) +{ + emit_json_key(key); + statement_inner(convert_to_string(value, current_locale_radix_character)); +} + +void Stream::emit_json_key_value(const std::string &key, bool value) +{ + emit_json_key(key); + statement_inner(value ? 
"true" : "false"); +} + +void Stream::emit_json_key_object(const std::string &key) +{ + emit_json_key(key); + statement_inner("{\n"); + ++indent; + stack.emplace(Type::Object, false); +} + +void Stream::emit_json_key_array(const std::string &key) +{ + emit_json_key(key); + statement_inner("[\n"); + ++indent; + stack.emplace(Type::Array, false); +} + +void CompilerReflection::set_format(const std::string &format) +{ + if (format != "json") + { + SPIRV_CROSS_THROW("Unsupported format"); + } +} + +string CompilerReflection::compile() +{ + json_stream = std::make_shared(); + json_stream->set_current_locale_radix_character(current_locale_radix_character); + json_stream->begin_json_object(); + reorder_type_alias(); + emit_entry_points(); + emit_types(); + emit_resources(); + emit_specialization_constants(); + json_stream->end_json_object(); + return json_stream->str(); +} + +static bool naturally_emit_type(const SPIRType &type) +{ + return type.basetype == SPIRType::Struct && !type.pointer && type.array.empty(); +} + +bool CompilerReflection::type_is_reference(const SPIRType &type) const +{ + // Physical pointers and arrays of physical pointers need to refer to the pointee's type. + return is_physical_pointer(type) || + (type_is_array_of_pointers(type) && type.storage == StorageClassPhysicalStorageBuffer); +} + +void CompilerReflection::emit_types() +{ + bool emitted_open_tag = false; + + SmallVector physical_pointee_types; + + // If we have physical pointers or arrays of physical pointers, it's also helpful to emit the pointee type + // and chain the type hierarchy. For POD, arrays can emit the entire type in-place. + ir.for_each_typed_id([&](uint32_t self, SPIRType &type) { + if (naturally_emit_type(type)) + { + emit_type(self, emitted_open_tag); + } + else if (type_is_reference(type)) + { + if (!naturally_emit_type(this->get(type.parent_type)) && + find(physical_pointee_types.begin(), physical_pointee_types.end(), type.parent_type) == + physical_pointee_types.end()) + { + physical_pointee_types.push_back(type.parent_type); + } + } + }); + + for (uint32_t pointee_type : physical_pointee_types) + emit_type(pointee_type, emitted_open_tag); + + if (emitted_open_tag) + { + json_stream->end_json_object(); + } +} + +void CompilerReflection::emit_type(uint32_t type_id, bool &emitted_open_tag) +{ + auto &type = get(type_id); + auto name = type_to_glsl(type); + + if (!emitted_open_tag) + { + json_stream->emit_json_key_object("types"); + emitted_open_tag = true; + } + json_stream->emit_json_key_object("_" + std::to_string(type_id)); + json_stream->emit_json_key_value("name", name); + + if (is_physical_pointer(type)) + { + json_stream->emit_json_key_value("type", "_" + std::to_string(type.parent_type)); + json_stream->emit_json_key_value("physical_pointer", true); + } + else if (!type.array.empty()) + { + emit_type_array(type); + json_stream->emit_json_key_value("type", "_" + std::to_string(type.parent_type)); + json_stream->emit_json_key_value("array_stride", get_decoration(type_id, DecorationArrayStride)); + } + else + { + json_stream->emit_json_key_array("members"); + // FIXME ideally we'd like to emit the size of a structure as a + // convenience to people parsing the reflected JSON. The problem + // is that there's no implicit size for a type. It's final size + // will be determined by the top level declaration in which it's + // included. So there might be one size for the struct if it's + // included in a std140 uniform block and another if it's included + // in a std430 uniform block. 
+ // The solution is to include *all* potential sizes as a map of + // layout type name to integer, but that will probably require + // some additional logic being written in this class, or in the + // parent CompilerGLSL class. + auto size = type.member_types.size(); + for (uint32_t i = 0; i < size; ++i) + { + emit_type_member(type, i); + } + json_stream->end_json_array(); + } + + json_stream->end_json_object(); +} + +void CompilerReflection::emit_type_member(const SPIRType &type, uint32_t index) +{ + auto &membertype = get(type.member_types[index]); + json_stream->begin_json_object(); + auto name = to_member_name(type, index); + // FIXME we'd like to emit the offset of each member, but such offsets are + // context dependent. See the comment above regarding structure sizes + json_stream->emit_json_key_value("name", name); + + if (type_is_reference(membertype)) + { + json_stream->emit_json_key_value("type", "_" + std::to_string(membertype.parent_type)); + } + else if (membertype.basetype == SPIRType::Struct) + { + json_stream->emit_json_key_value("type", "_" + std::to_string(membertype.self)); + } + else + { + json_stream->emit_json_key_value("type", type_to_glsl(membertype)); + } + emit_type_member_qualifiers(type, index); + json_stream->end_json_object(); +} + +void CompilerReflection::emit_type_array(const SPIRType &type) +{ + if (!is_physical_pointer(type) && !type.array.empty()) + { + json_stream->emit_json_key_array("array"); + // Note that we emit the zeros here as a means of identifying + // unbounded arrays. This is necessary as otherwise there would + // be no way of differentiating between float[4] and float[4][] + for (const auto &value : type.array) + json_stream->emit_json_array_value(value); + json_stream->end_json_array(); + + json_stream->emit_json_key_array("array_size_is_literal"); + for (const auto &value : type.array_size_literal) + json_stream->emit_json_array_value(value); + json_stream->end_json_array(); + } +} + +void CompilerReflection::emit_type_member_qualifiers(const SPIRType &type, uint32_t index) +{ + auto &membertype = get(type.member_types[index]); + emit_type_array(membertype); + auto &memb = ir.meta[type.self].members; + if (index < memb.size()) + { + auto &dec = memb[index]; + if (dec.decoration_flags.get(DecorationLocation)) + json_stream->emit_json_key_value("location", dec.location); + if (dec.decoration_flags.get(DecorationOffset)) + json_stream->emit_json_key_value("offset", dec.offset); + + // Array stride is a property of the array type, not the struct. 
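+ // so it is queried from the member's type ID rather than from the member decorations.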
+ if (has_decoration(type.member_types[index], DecorationArrayStride)) + json_stream->emit_json_key_value("array_stride", + get_decoration(type.member_types[index], DecorationArrayStride)); + + if (dec.decoration_flags.get(DecorationMatrixStride)) + json_stream->emit_json_key_value("matrix_stride", dec.matrix_stride); + if (dec.decoration_flags.get(DecorationRowMajor)) + json_stream->emit_json_key_value("row_major", true); + + if (is_physical_pointer(membertype)) + json_stream->emit_json_key_value("physical_pointer", true); + } +} + +string CompilerReflection::execution_model_to_str(spv::ExecutionModel model) +{ + switch (model) + { + case ExecutionModelVertex: + return "vert"; + case ExecutionModelTessellationControl: + return "tesc"; + case ExecutionModelTessellationEvaluation: + return "tese"; + case ExecutionModelGeometry: + return "geom"; + case ExecutionModelFragment: + return "frag"; + case ExecutionModelGLCompute: + return "comp"; + case ExecutionModelRayGenerationNV: + return "rgen"; + case ExecutionModelIntersectionNV: + return "rint"; + case ExecutionModelAnyHitNV: + return "rahit"; + case ExecutionModelClosestHitNV: + return "rchit"; + case ExecutionModelMissNV: + return "rmiss"; + case ExecutionModelCallableNV: + return "rcall"; + default: + return "???"; + } +} + +// FIXME include things like the local_size dimensions, geometry output vertex count, etc +void CompilerReflection::emit_entry_points() +{ + auto entries = get_entry_points_and_stages(); + if (!entries.empty()) + { + // Needed to make output deterministic. + sort(begin(entries), end(entries), [](const EntryPoint &a, const EntryPoint &b) -> bool { + if (a.execution_model < b.execution_model) + return true; + else if (a.execution_model > b.execution_model) + return false; + else + return a.name < b.name; + }); + + json_stream->emit_json_key_array("entryPoints"); + for (auto &e : entries) + { + json_stream->begin_json_object(); + json_stream->emit_json_key_value("name", e.name); + json_stream->emit_json_key_value("mode", execution_model_to_str(e.execution_model)); + if (e.execution_model == ExecutionModelGLCompute) + { + const auto &spv_entry = get_entry_point(e.name, e.execution_model); + + SpecializationConstant spec_x, spec_y, spec_z; + get_work_group_size_specialization_constants(spec_x, spec_y, spec_z); + + json_stream->emit_json_key_array("workgroup_size"); + json_stream->emit_json_array_value(spec_x.id != ID(0) ? spec_x.constant_id : + spv_entry.workgroup_size.x); + json_stream->emit_json_array_value(spec_y.id != ID(0) ? spec_y.constant_id : + spv_entry.workgroup_size.y); + json_stream->emit_json_array_value(spec_z.id != ID(0) ? 
spec_z.constant_id : + spv_entry.workgroup_size.z); + json_stream->end_json_array(); + + json_stream->emit_json_key_array("workgroup_size_is_spec_constant_id"); + json_stream->emit_json_array_value(spec_x.id != ID(0)); + json_stream->emit_json_array_value(spec_y.id != ID(0)); + json_stream->emit_json_array_value(spec_z.id != ID(0)); + json_stream->end_json_array(); + } + json_stream->end_json_object(); + } + json_stream->end_json_array(); + } +} + +void CompilerReflection::emit_resources() +{ + auto res = get_shader_resources(); + emit_resources("subpass_inputs", res.subpass_inputs); + emit_resources("inputs", res.stage_inputs); + emit_resources("outputs", res.stage_outputs); + emit_resources("textures", res.sampled_images); + emit_resources("separate_images", res.separate_images); + emit_resources("separate_samplers", res.separate_samplers); + emit_resources("images", res.storage_images); + emit_resources("ssbos", res.storage_buffers); + emit_resources("ubos", res.uniform_buffers); + emit_resources("push_constants", res.push_constant_buffers); + emit_resources("counters", res.atomic_counters); + emit_resources("acceleration_structures", res.acceleration_structures); +} + +void CompilerReflection::emit_resources(const char *tag, const SmallVector &resources) +{ + if (resources.empty()) + { + return; + } + + json_stream->emit_json_key_array(tag); + for (auto &res : resources) + { + auto &type = get_type(res.type_id); + auto typeflags = ir.meta[type.self].decoration.decoration_flags; + auto &mask = get_decoration_bitset(res.id); + + // If we don't have a name, use the fallback for the type instead of the variable + // for SSBOs and UBOs since those are the only meaningful names to use externally. + // Push constant blocks are still accessed by name and not block name, even though they are technically Blocks. + bool is_push_constant = get_storage_class(res.id) == StorageClassPushConstant; + bool is_block = get_decoration_bitset(type.self).get(DecorationBlock) || + get_decoration_bitset(type.self).get(DecorationBufferBlock); + + ID fallback_id = !is_push_constant && is_block ? ID(res.base_type_id) : ID(res.id); + + json_stream->begin_json_object(); + + if (type.basetype == SPIRType::Struct) + { + json_stream->emit_json_key_value("type", "_" + std::to_string(res.base_type_id)); + } + else + { + json_stream->emit_json_key_value("type", type_to_glsl(type)); + } + + json_stream->emit_json_key_value("name", !res.name.empty() ? res.name : get_fallback_name(fallback_id)); + { + bool ssbo_block = type.storage == StorageClassStorageBuffer || + (type.storage == StorageClassUniform && typeflags.get(DecorationBufferBlock)); + Bitset qualifier_mask = ssbo_block ? 
get_buffer_block_flags(res.id) : mask; + + if (qualifier_mask.get(DecorationNonReadable)) + json_stream->emit_json_key_value("writeonly", true); + if (qualifier_mask.get(DecorationNonWritable)) + json_stream->emit_json_key_value("readonly", true); + if (qualifier_mask.get(DecorationRestrict)) + json_stream->emit_json_key_value("restrict", true); + if (qualifier_mask.get(DecorationCoherent)) + json_stream->emit_json_key_value("coherent", true); + if (qualifier_mask.get(DecorationVolatile)) + json_stream->emit_json_key_value("volatile", true); + } + + emit_type_array(type); + + { + bool is_sized_block = is_block && (get_storage_class(res.id) == StorageClassUniform || + get_storage_class(res.id) == StorageClassUniformConstant || + get_storage_class(res.id) == StorageClassStorageBuffer); + if (is_sized_block) + { + uint32_t block_size = uint32_t(get_declared_struct_size(get_type(res.base_type_id))); + json_stream->emit_json_key_value("block_size", block_size); + } + } + + if (type.storage == StorageClassPushConstant) + json_stream->emit_json_key_value("push_constant", true); + if (mask.get(DecorationLocation)) + json_stream->emit_json_key_value("location", get_decoration(res.id, DecorationLocation)); + if (mask.get(DecorationRowMajor)) + json_stream->emit_json_key_value("row_major", true); + if (mask.get(DecorationColMajor)) + json_stream->emit_json_key_value("column_major", true); + if (mask.get(DecorationIndex)) + json_stream->emit_json_key_value("index", get_decoration(res.id, DecorationIndex)); + if (type.storage != StorageClassPushConstant && mask.get(DecorationDescriptorSet)) + json_stream->emit_json_key_value("set", get_decoration(res.id, DecorationDescriptorSet)); + if (mask.get(DecorationBinding)) + json_stream->emit_json_key_value("binding", get_decoration(res.id, DecorationBinding)); + if (mask.get(DecorationInputAttachmentIndex)) + json_stream->emit_json_key_value("input_attachment_index", + get_decoration(res.id, DecorationInputAttachmentIndex)); + if (mask.get(DecorationOffset)) + json_stream->emit_json_key_value("offset", get_decoration(res.id, DecorationOffset)); + if (mask.get(DecorationWeightTextureQCOM)) + json_stream->emit_json_key_value("WeightTextureQCOM", get_decoration(res.id, DecorationWeightTextureQCOM)); + if (mask.get(DecorationBlockMatchTextureQCOM)) + json_stream->emit_json_key_value("BlockMatchTextureQCOM", get_decoration(res.id, DecorationBlockMatchTextureQCOM)); + + // For images, the type itself adds a layout qualifer. + // Only emit the format for storage images. 
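+ // (In OpTypeImage, Sampled = 1 denotes a sampled image and Sampled = 2 a storage image.)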
+ if (type.basetype == SPIRType::Image && type.image.sampled == 2) + { + const char *fmt = format_to_glsl(type.image.format); + if (fmt != nullptr) + json_stream->emit_json_key_value("format", std::string(fmt)); + } + json_stream->end_json_object(); + } + json_stream->end_json_array(); +} + +void CompilerReflection::emit_specialization_constants() +{ + auto specialization_constants = get_specialization_constants(); + if (specialization_constants.empty()) + return; + + json_stream->emit_json_key_array("specialization_constants"); + for (const auto &spec_const : specialization_constants) + { + auto &c = get(spec_const.id); + auto type = get(c.constant_type); + json_stream->begin_json_object(); + json_stream->emit_json_key_value("name", get_name(spec_const.id)); + json_stream->emit_json_key_value("id", spec_const.constant_id); + json_stream->emit_json_key_value("type", type_to_glsl(type)); + json_stream->emit_json_key_value("variable_id", spec_const.id); + switch (type.basetype) + { + case SPIRType::UInt: + json_stream->emit_json_key_value("default_value", c.scalar()); + break; + + case SPIRType::Int: + json_stream->emit_json_key_value("default_value", c.scalar_i32()); + break; + + case SPIRType::Float: + json_stream->emit_json_key_value("default_value", c.scalar_f32()); + break; + + case SPIRType::Boolean: + json_stream->emit_json_key_value("default_value", c.scalar() != 0); + break; + + default: + break; + } + json_stream->end_json_object(); + } + json_stream->end_json_array(); +} + +string CompilerReflection::to_member_name(const SPIRType &type, uint32_t index) const +{ + auto *type_meta = ir.find_meta(type.self); + + if (type_meta) + { + auto &memb = type_meta->members; + if (index < memb.size() && !memb[index].alias.empty()) + return memb[index].alias; + else + return join("_m", index); + } + else + return join("_m", index); +} diff --git a/thirdparty/spirv-cross/spirv_reflect.hpp b/thirdparty/spirv-cross/spirv_reflect.hpp new file mode 100644 index 00000000000..a129ba54da5 --- /dev/null +++ b/thirdparty/spirv-cross/spirv_reflect.hpp @@ -0,0 +1,91 @@ +/* + * Copyright 2018-2021 Bradley Austin Davis + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * At your option, you may choose to accept this material under either: + * 1. The Apache License, Version 2.0, found at , or + * 2. The MIT License, found at . 
+ */ + +#ifndef SPIRV_CROSS_REFLECT_HPP +#define SPIRV_CROSS_REFLECT_HPP + +#include "spirv_glsl.hpp" +#include + +namespace simple_json +{ +class Stream; +} + +namespace SPIRV_CROSS_NAMESPACE +{ +class CompilerReflection : public CompilerGLSL +{ + using Parent = CompilerGLSL; + +public: + explicit CompilerReflection(std::vector spirv_) + : Parent(std::move(spirv_)) + { + options.vulkan_semantics = true; + } + + CompilerReflection(const uint32_t *ir_, size_t word_count) + : Parent(ir_, word_count) + { + options.vulkan_semantics = true; + } + + explicit CompilerReflection(const ParsedIR &ir_) + : CompilerGLSL(ir_) + { + options.vulkan_semantics = true; + } + + explicit CompilerReflection(ParsedIR &&ir_) + : CompilerGLSL(std::move(ir_)) + { + options.vulkan_semantics = true; + } + + void set_format(const std::string &format); + std::string compile() override; + +private: + static std::string execution_model_to_str(spv::ExecutionModel model); + + void emit_entry_points(); + void emit_types(); + void emit_resources(); + void emit_specialization_constants(); + + void emit_type(uint32_t type_id, bool &emitted_open_tag); + void emit_type_member(const SPIRType &type, uint32_t index); + void emit_type_member_qualifiers(const SPIRType &type, uint32_t index); + void emit_type_array(const SPIRType &type); + void emit_resources(const char *tag, const SmallVector &resources); + bool type_is_reference(const SPIRType &type) const; + + std::string to_member_name(const SPIRType &type, uint32_t index) const; + + std::shared_ptr json_stream; +}; + +} // namespace SPIRV_CROSS_NAMESPACE + +#endif
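
Taken together, these two vendored pieces form a small pipeline: Parser turns a raw SPIR-V word stream into ParsedIR, and CompilerReflection walks that IR to emit the JSON description. The sketch below shows the minimal glue between them, assuming the default spirv_cross namespace (SPIRV_CROSS_NAMESPACE) and a module that has already been loaded into a vector of 32-bit words; it is an illustration of how the classes compose, not code from the patch itself.

#include "spirv_parser.hpp"
#include "spirv_reflect.hpp"

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Minimal reflection pass over an already-loaded SPIR-V module.
std::string reflect_to_json(std::vector<uint32_t> spirv_words)
{
    // Parse the word stream into the intermediate representation.
    spirv_cross::Parser parser(std::move(spirv_words));
    parser.parse();

    // Hand the parsed IR to the reflection backend and emit JSON.
    spirv_cross::CompilerReflection reflector(std::move(parser.get_parsed_ir()));
    reflector.set_format("json"); // "json" is the only format set_format() accepts.
    return reflector.compile();
}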