mirror of
https://github.com/godotengine/godot.git
synced 2024-11-10 14:12:51 +00:00
Remove denoise module and thirdparty OIDN.
This is replaced by a much lighter weight and faster JNLM denoiser. OIDN is still much more accurate, and may be provided as an optional backend in the future, but the JNLM denoiser seems good enough for most use cases and removing OIDN reduces the build system complexity, binary size, and build times very significantly.
This commit is contained in:
parent
1b2b726502
commit
ab65effed0
@ -425,11 +425,6 @@ Comment: Stripped down version of "nvapi.h" from the NVIDIA NVAPI SDK
|
|||||||
Copyright: 2019-2022, NVIDIA Corporation
|
Copyright: 2019-2022, NVIDIA Corporation
|
||||||
License: Expat
|
License: Expat
|
||||||
|
|
||||||
Files: ./thirdparty/oidn/
|
|
||||||
Comment: Intel Open Image Denoise
|
|
||||||
Copyright: 2009-2019, Intel Corporation
|
|
||||||
License: Apache-2.0
|
|
||||||
|
|
||||||
Files: ./thirdparty/openxr/
|
Files: ./thirdparty/openxr/
|
||||||
Comment: OpenXR Loader
|
Comment: OpenXR Loader
|
||||||
Copyright: 2020-2023, The Khronos Group Inc.
|
Copyright: 2020-2023, The Khronos Group Inc.
|
||||||
|
@ -1,138 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
import resource_to_cpp
|
|
||||||
|
|
||||||
Import("env")
|
|
||||||
Import("env_modules")
|
|
||||||
|
|
||||||
env_oidn = env_modules.Clone()
|
|
||||||
|
|
||||||
# Thirdparty source files
|
|
||||||
|
|
||||||
thirdparty_obj = []
|
|
||||||
|
|
||||||
thirdparty_dir = "#thirdparty/oidn/"
|
|
||||||
thirdparty_sources = [
|
|
||||||
"core/api.cpp",
|
|
||||||
"core/device.cpp",
|
|
||||||
"core/filter.cpp",
|
|
||||||
"core/network.cpp",
|
|
||||||
"core/autoencoder.cpp",
|
|
||||||
"core/transfer_function.cpp",
|
|
||||||
"weights/rtlightmap_hdr.gen.cpp",
|
|
||||||
"mkl-dnn/src/common/batch_normalization.cpp",
|
|
||||||
"mkl-dnn/src/common/concat.cpp",
|
|
||||||
"mkl-dnn/src/common/convolution.cpp",
|
|
||||||
"mkl-dnn/src/common/convolution_pd.cpp",
|
|
||||||
"mkl-dnn/src/common/deconvolution.cpp",
|
|
||||||
"mkl-dnn/src/common/eltwise.cpp",
|
|
||||||
"mkl-dnn/src/common/engine.cpp",
|
|
||||||
"mkl-dnn/src/common/inner_product.cpp",
|
|
||||||
"mkl-dnn/src/common/inner_product_pd.cpp",
|
|
||||||
"mkl-dnn/src/common/lrn.cpp",
|
|
||||||
"mkl-dnn/src/common/memory.cpp",
|
|
||||||
"mkl-dnn/src/common/memory_desc_wrapper.cpp",
|
|
||||||
"mkl-dnn/src/common/mkldnn_debug.cpp",
|
|
||||||
"mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp",
|
|
||||||
"mkl-dnn/src/common/pooling.cpp",
|
|
||||||
"mkl-dnn/src/common/primitive.cpp",
|
|
||||||
"mkl-dnn/src/common/primitive_attr.cpp",
|
|
||||||
"mkl-dnn/src/common/primitive_desc.cpp",
|
|
||||||
"mkl-dnn/src/common/primitive_exec_types.cpp",
|
|
||||||
"mkl-dnn/src/common/primitive_iterator.cpp",
|
|
||||||
"mkl-dnn/src/common/query.cpp",
|
|
||||||
"mkl-dnn/src/common/reorder.cpp",
|
|
||||||
"mkl-dnn/src/common/rnn.cpp",
|
|
||||||
"mkl-dnn/src/common/scratchpad.cpp",
|
|
||||||
"mkl-dnn/src/common/shuffle.cpp",
|
|
||||||
"mkl-dnn/src/common/softmax.cpp",
|
|
||||||
"mkl-dnn/src/common/stream.cpp",
|
|
||||||
"mkl-dnn/src/common/sum.cpp",
|
|
||||||
"mkl-dnn/src/common/utils.cpp",
|
|
||||||
"mkl-dnn/src/common/verbose.cpp",
|
|
||||||
"mkl-dnn/src/cpu/cpu_barrier.cpp",
|
|
||||||
"mkl-dnn/src/cpu/cpu_concat.cpp",
|
|
||||||
"mkl-dnn/src/cpu/cpu_engine.cpp",
|
|
||||||
"mkl-dnn/src/cpu/cpu_memory.cpp",
|
|
||||||
"mkl-dnn/src/cpu/cpu_reducer.cpp",
|
|
||||||
"mkl-dnn/src/cpu/cpu_reorder.cpp",
|
|
||||||
"mkl-dnn/src/cpu/cpu_sum.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_avx2_convolution.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_sse42_convolution.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_transpose_src_utils.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_uni_eltwise.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_uni_pooling.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_uni_reorder.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_utils/jit_utils.cpp",
|
|
||||||
"mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c",
|
|
||||||
"common/platform.cpp",
|
|
||||||
"common/thread.cpp",
|
|
||||||
"common/tensor.cpp",
|
|
||||||
]
|
|
||||||
thirdparty_sources = [thirdparty_dir + file for file in thirdparty_sources]
|
|
||||||
|
|
||||||
thirdparty_include_dirs = [
|
|
||||||
"",
|
|
||||||
"include",
|
|
||||||
"mkl-dnn/include",
|
|
||||||
"mkl-dnn/src",
|
|
||||||
"mkl-dnn/src/common",
|
|
||||||
"mkl-dnn/src/cpu/xbyak",
|
|
||||||
"mkl-dnn/src/cpu",
|
|
||||||
]
|
|
||||||
thirdparty_include_dirs = [thirdparty_dir + file for file in thirdparty_include_dirs]
|
|
||||||
|
|
||||||
|
|
||||||
env_oidn.Prepend(CPPPATH=thirdparty_include_dirs)
|
|
||||||
env_oidn.Append(
|
|
||||||
CPPDEFINES=[
|
|
||||||
"MKLDNN_THR=MKLDNN_THR_SEQ",
|
|
||||||
"OIDN_STATIC_LIB",
|
|
||||||
"__STDC_CONSTANT_MACROS",
|
|
||||||
"__STDC_LIMIT_MACROS",
|
|
||||||
"DISABLE_VERBOSE",
|
|
||||||
"MKLDNN_ENABLE_CONCURRENT_EXEC",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
env_oidn.AppendUnique(CPPDEFINES=["NDEBUG"]) # No assert() even in debug builds.
|
|
||||||
|
|
||||||
env_thirdparty = env_oidn.Clone()
|
|
||||||
env_thirdparty.disable_warnings()
|
|
||||||
|
|
||||||
if env["disable_exceptions"]:
|
|
||||||
# OIDN hard-requires exceptions, so we re-enable them here.
|
|
||||||
if env.msvc and ("_HAS_EXCEPTIONS", 0) in env_thirdparty["CPPDEFINES"]:
|
|
||||||
env_thirdparty["CPPDEFINES"].remove(("_HAS_EXCEPTIONS", 0))
|
|
||||||
env_thirdparty.AppendUnique(CCFLAGS=["/EHsc"])
|
|
||||||
elif not env.msvc and "-fno-exceptions" in env_thirdparty["CCFLAGS"]:
|
|
||||||
env_thirdparty["CCFLAGS"].remove("-fno-exceptions")
|
|
||||||
|
|
||||||
env_thirdparty.add_source_files(thirdparty_obj, thirdparty_sources)
|
|
||||||
env.modules_sources += thirdparty_obj
|
|
||||||
|
|
||||||
weights_in_path = thirdparty_dir + "weights/rtlightmap_hdr.tza"
|
|
||||||
weights_out_path = thirdparty_dir + "weights/rtlightmap_hdr.gen.cpp"
|
|
||||||
|
|
||||||
env_thirdparty.Depends(weights_out_path, weights_in_path)
|
|
||||||
env_thirdparty.CommandNoCache(weights_out_path, weights_in_path, resource_to_cpp.tza_to_cpp)
|
|
||||||
|
|
||||||
# Godot source files
|
|
||||||
|
|
||||||
module_obj = []
|
|
||||||
|
|
||||||
env_oidn.add_source_files(module_obj, "*.cpp")
|
|
||||||
env.modules_sources += module_obj
|
|
||||||
|
|
||||||
# Needed to force rebuilding the module files when the thirdparty library is updated.
|
|
||||||
env.Depends(module_obj, thirdparty_obj)
|
|
@ -1,12 +0,0 @@
|
|||||||
def can_build(env, platform):
|
|
||||||
# Thirdparty dependency OpenImage Denoise includes oneDNN library
|
|
||||||
# and the version we use only supports x86_64.
|
|
||||||
# It's also only relevant for tools build and desktop platforms,
|
|
||||||
# as doing lightmap generation and denoising on Android or Web
|
|
||||||
# would be a bit far-fetched.
|
|
||||||
desktop_platforms = ["linuxbsd", "macos", "windows"]
|
|
||||||
return env.editor_build and platform in desktop_platforms and env["arch"] == "x86_64"
|
|
||||||
|
|
||||||
|
|
||||||
def configure(env):
|
|
||||||
pass
|
|
@ -1,66 +0,0 @@
|
|||||||
/**************************************************************************/
|
|
||||||
/* denoise_wrapper.cpp */
|
|
||||||
/**************************************************************************/
|
|
||||||
/* This file is part of: */
|
|
||||||
/* GODOT ENGINE */
|
|
||||||
/* https://godotengine.org */
|
|
||||||
/**************************************************************************/
|
|
||||||
/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
|
|
||||||
/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
|
|
||||||
/* */
|
|
||||||
/* Permission is hereby granted, free of charge, to any person obtaining */
|
|
||||||
/* a copy of this software and associated documentation files (the */
|
|
||||||
/* "Software"), to deal in the Software without restriction, including */
|
|
||||||
/* without limitation the rights to use, copy, modify, merge, publish, */
|
|
||||||
/* distribute, sublicense, and/or sell copies of the Software, and to */
|
|
||||||
/* permit persons to whom the Software is furnished to do so, subject to */
|
|
||||||
/* the following conditions: */
|
|
||||||
/* */
|
|
||||||
/* The above copyright notice and this permission notice shall be */
|
|
||||||
/* included in all copies or substantial portions of the Software. */
|
|
||||||
/* */
|
|
||||||
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
|
|
||||||
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
|
|
||||||
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
|
|
||||||
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
|
|
||||||
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
|
|
||||||
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
|
|
||||||
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
|
|
||||||
/**************************************************************************/
|
|
||||||
|
|
||||||
#include "denoise_wrapper.h"
|
|
||||||
|
|
||||||
#include <OpenImageDenoise/oidn.h>
|
|
||||||
|
|
||||||
#include <stdio.h>
|
|
||||||
|
|
||||||
void *oidn_denoiser_init() {
|
|
||||||
OIDNDeviceImpl *device = oidnNewDevice(OIDN_DEVICE_TYPE_CPU);
|
|
||||||
oidnCommitDevice(device);
|
|
||||||
return device;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool oidn_denoise(void *deviceptr, float *p_floats, int p_width, int p_height) {
|
|
||||||
OIDNDeviceImpl *device = (OIDNDeviceImpl *)deviceptr;
|
|
||||||
OIDNFilter filter = oidnNewFilter(device, "RTLightmap");
|
|
||||||
oidnSetSharedFilterImage(filter, "color", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0);
|
|
||||||
oidnSetSharedFilterImage(filter, "output", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0);
|
|
||||||
oidnSetFilter1b(filter, "hdr", true);
|
|
||||||
//oidnSetFilter1f(filter, "hdrScale", 1.0f);
|
|
||||||
oidnCommitFilter(filter);
|
|
||||||
oidnExecuteFilter(filter);
|
|
||||||
|
|
||||||
const char *msg;
|
|
||||||
bool success = true;
|
|
||||||
if (oidnGetDeviceError(device, &msg) != OIDN_ERROR_NONE) {
|
|
||||||
printf("LightmapDenoiser: %s\n", msg);
|
|
||||||
success = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
oidnReleaseFilter(filter);
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
void oidn_denoiser_finish(void *device) {
|
|
||||||
oidnReleaseDevice((OIDNDeviceImpl *)device);
|
|
||||||
}
|
|
@ -1,38 +0,0 @@
|
|||||||
/**************************************************************************/
|
|
||||||
/* denoise_wrapper.h */
|
|
||||||
/**************************************************************************/
|
|
||||||
/* This file is part of: */
|
|
||||||
/* GODOT ENGINE */
|
|
||||||
/* https://godotengine.org */
|
|
||||||
/**************************************************************************/
|
|
||||||
/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
|
|
||||||
/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
|
|
||||||
/* */
|
|
||||||
/* Permission is hereby granted, free of charge, to any person obtaining */
|
|
||||||
/* a copy of this software and associated documentation files (the */
|
|
||||||
/* "Software"), to deal in the Software without restriction, including */
|
|
||||||
/* without limitation the rights to use, copy, modify, merge, publish, */
|
|
||||||
/* distribute, sublicense, and/or sell copies of the Software, and to */
|
|
||||||
/* permit persons to whom the Software is furnished to do so, subject to */
|
|
||||||
/* the following conditions: */
|
|
||||||
/* */
|
|
||||||
/* The above copyright notice and this permission notice shall be */
|
|
||||||
/* included in all copies or substantial portions of the Software. */
|
|
||||||
/* */
|
|
||||||
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
|
|
||||||
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
|
|
||||||
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
|
|
||||||
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
|
|
||||||
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
|
|
||||||
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
|
|
||||||
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
|
|
||||||
/**************************************************************************/
|
|
||||||
|
|
||||||
#ifndef DENOISE_WRAPPER_H
|
|
||||||
#define DENOISE_WRAPPER_H
|
|
||||||
|
|
||||||
void *oidn_denoiser_init();
|
|
||||||
bool oidn_denoise(void *device, float *p_floats, int p_width, int p_height);
|
|
||||||
void oidn_denoiser_finish(void *device);
|
|
||||||
|
|
||||||
#endif // DENOISE_WRAPPER_H
|
|
@ -1,65 +0,0 @@
|
|||||||
/**************************************************************************/
|
|
||||||
/* lightmap_denoiser.cpp */
|
|
||||||
/**************************************************************************/
|
|
||||||
/* This file is part of: */
|
|
||||||
/* GODOT ENGINE */
|
|
||||||
/* https://godotengine.org */
|
|
||||||
/**************************************************************************/
|
|
||||||
/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
|
|
||||||
/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
|
|
||||||
/* */
|
|
||||||
/* Permission is hereby granted, free of charge, to any person obtaining */
|
|
||||||
/* a copy of this software and associated documentation files (the */
|
|
||||||
/* "Software"), to deal in the Software without restriction, including */
|
|
||||||
/* without limitation the rights to use, copy, modify, merge, publish, */
|
|
||||||
/* distribute, sublicense, and/or sell copies of the Software, and to */
|
|
||||||
/* permit persons to whom the Software is furnished to do so, subject to */
|
|
||||||
/* the following conditions: */
|
|
||||||
/* */
|
|
||||||
/* The above copyright notice and this permission notice shall be */
|
|
||||||
/* included in all copies or substantial portions of the Software. */
|
|
||||||
/* */
|
|
||||||
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
|
|
||||||
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
|
|
||||||
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
|
|
||||||
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
|
|
||||||
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
|
|
||||||
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
|
|
||||||
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
|
|
||||||
/**************************************************************************/
|
|
||||||
|
|
||||||
#include "lightmap_denoiser.h"
|
|
||||||
|
|
||||||
#include "denoise_wrapper.h"
|
|
||||||
|
|
||||||
#include "core/io/image.h"
|
|
||||||
|
|
||||||
LightmapDenoiser *LightmapDenoiserOIDN::create_oidn_denoiser() {
|
|
||||||
return memnew(LightmapDenoiserOIDN);
|
|
||||||
}
|
|
||||||
|
|
||||||
void LightmapDenoiserOIDN::make_default_denoiser() {
|
|
||||||
create_function = create_oidn_denoiser;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ref<Image> LightmapDenoiserOIDN::denoise_image(const Ref<Image> &p_image) {
|
|
||||||
Ref<Image> img = p_image->duplicate();
|
|
||||||
|
|
||||||
img->convert(Image::FORMAT_RGBF);
|
|
||||||
|
|
||||||
Vector<uint8_t> data = img->get_data();
|
|
||||||
if (!oidn_denoise(device, (float *)data.ptrw(), img->get_width(), img->get_height())) {
|
|
||||||
return p_image;
|
|
||||||
}
|
|
||||||
|
|
||||||
img->set_data(img->get_width(), img->get_height(), false, img->get_format(), data);
|
|
||||||
return img;
|
|
||||||
}
|
|
||||||
|
|
||||||
LightmapDenoiserOIDN::LightmapDenoiserOIDN() {
|
|
||||||
device = oidn_denoiser_init();
|
|
||||||
}
|
|
||||||
|
|
||||||
LightmapDenoiserOIDN::~LightmapDenoiserOIDN() {
|
|
||||||
oidn_denoiser_finish(device);
|
|
||||||
}
|
|
@ -1,56 +0,0 @@
|
|||||||
/**************************************************************************/
|
|
||||||
/* lightmap_denoiser.h */
|
|
||||||
/**************************************************************************/
|
|
||||||
/* This file is part of: */
|
|
||||||
/* GODOT ENGINE */
|
|
||||||
/* https://godotengine.org */
|
|
||||||
/**************************************************************************/
|
|
||||||
/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
|
|
||||||
/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
|
|
||||||
/* */
|
|
||||||
/* Permission is hereby granted, free of charge, to any person obtaining */
|
|
||||||
/* a copy of this software and associated documentation files (the */
|
|
||||||
/* "Software"), to deal in the Software without restriction, including */
|
|
||||||
/* without limitation the rights to use, copy, modify, merge, publish, */
|
|
||||||
/* distribute, sublicense, and/or sell copies of the Software, and to */
|
|
||||||
/* permit persons to whom the Software is furnished to do so, subject to */
|
|
||||||
/* the following conditions: */
|
|
||||||
/* */
|
|
||||||
/* The above copyright notice and this permission notice shall be */
|
|
||||||
/* included in all copies or substantial portions of the Software. */
|
|
||||||
/* */
|
|
||||||
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
|
|
||||||
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
|
|
||||||
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
|
|
||||||
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
|
|
||||||
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
|
|
||||||
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
|
|
||||||
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
|
|
||||||
/**************************************************************************/
|
|
||||||
|
|
||||||
#ifndef LIGHTMAP_DENOISER_H
|
|
||||||
#define LIGHTMAP_DENOISER_H
|
|
||||||
|
|
||||||
#include "core/object/class_db.h"
|
|
||||||
#include "scene/3d/lightmapper.h"
|
|
||||||
|
|
||||||
struct OIDNDeviceImpl;
|
|
||||||
|
|
||||||
class LightmapDenoiserOIDN : public LightmapDenoiser {
|
|
||||||
GDCLASS(LightmapDenoiserOIDN, LightmapDenoiser);
|
|
||||||
|
|
||||||
protected:
|
|
||||||
void *device = nullptr;
|
|
||||||
|
|
||||||
public:
|
|
||||||
static LightmapDenoiser *create_oidn_denoiser();
|
|
||||||
|
|
||||||
Ref<Image> denoise_image(const Ref<Image> &p_image) override;
|
|
||||||
|
|
||||||
static void make_default_denoiser();
|
|
||||||
|
|
||||||
LightmapDenoiserOIDN();
|
|
||||||
~LightmapDenoiserOIDN();
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // LIGHTMAP_DENOISER_H
|
|
@ -1,49 +0,0 @@
|
|||||||
/**************************************************************************/
|
|
||||||
/* register_types.cpp */
|
|
||||||
/**************************************************************************/
|
|
||||||
/* This file is part of: */
|
|
||||||
/* GODOT ENGINE */
|
|
||||||
/* https://godotengine.org */
|
|
||||||
/**************************************************************************/
|
|
||||||
/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
|
|
||||||
/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
|
|
||||||
/* */
|
|
||||||
/* Permission is hereby granted, free of charge, to any person obtaining */
|
|
||||||
/* a copy of this software and associated documentation files (the */
|
|
||||||
/* "Software"), to deal in the Software without restriction, including */
|
|
||||||
/* without limitation the rights to use, copy, modify, merge, publish, */
|
|
||||||
/* distribute, sublicense, and/or sell copies of the Software, and to */
|
|
||||||
/* permit persons to whom the Software is furnished to do so, subject to */
|
|
||||||
/* the following conditions: */
|
|
||||||
/* */
|
|
||||||
/* The above copyright notice and this permission notice shall be */
|
|
||||||
/* included in all copies or substantial portions of the Software. */
|
|
||||||
/* */
|
|
||||||
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
|
|
||||||
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
|
|
||||||
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
|
|
||||||
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
|
|
||||||
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
|
|
||||||
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
|
|
||||||
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
|
|
||||||
/**************************************************************************/
|
|
||||||
|
|
||||||
#include "register_types.h"
|
|
||||||
|
|
||||||
#include "lightmap_denoiser.h"
|
|
||||||
|
|
||||||
#include "core/config/engine.h"
|
|
||||||
|
|
||||||
void initialize_denoise_module(ModuleInitializationLevel p_level) {
|
|
||||||
if (p_level != MODULE_INITIALIZATION_LEVEL_SCENE) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
LightmapDenoiserOIDN::make_default_denoiser();
|
|
||||||
}
|
|
||||||
|
|
||||||
void uninitialize_denoise_module(ModuleInitializationLevel p_level) {
|
|
||||||
if (p_level != MODULE_INITIALIZATION_LEVEL_SCENE) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,39 +0,0 @@
|
|||||||
/**************************************************************************/
|
|
||||||
/* register_types.h */
|
|
||||||
/**************************************************************************/
|
|
||||||
/* This file is part of: */
|
|
||||||
/* GODOT ENGINE */
|
|
||||||
/* https://godotengine.org */
|
|
||||||
/**************************************************************************/
|
|
||||||
/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
|
|
||||||
/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
|
|
||||||
/* */
|
|
||||||
/* Permission is hereby granted, free of charge, to any person obtaining */
|
|
||||||
/* a copy of this software and associated documentation files (the */
|
|
||||||
/* "Software"), to deal in the Software without restriction, including */
|
|
||||||
/* without limitation the rights to use, copy, modify, merge, publish, */
|
|
||||||
/* distribute, sublicense, and/or sell copies of the Software, and to */
|
|
||||||
/* permit persons to whom the Software is furnished to do so, subject to */
|
|
||||||
/* the following conditions: */
|
|
||||||
/* */
|
|
||||||
/* The above copyright notice and this permission notice shall be */
|
|
||||||
/* included in all copies or substantial portions of the Software. */
|
|
||||||
/* */
|
|
||||||
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
|
|
||||||
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
|
|
||||||
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
|
|
||||||
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
|
|
||||||
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
|
|
||||||
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
|
|
||||||
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
|
|
||||||
/**************************************************************************/
|
|
||||||
|
|
||||||
#ifndef DENOISE_REGISTER_TYPES_H
|
|
||||||
#define DENOISE_REGISTER_TYPES_H
|
|
||||||
|
|
||||||
#include "modules/register_module_types.h"
|
|
||||||
|
|
||||||
void initialize_denoise_module(ModuleInitializationLevel p_level);
|
|
||||||
void uninitialize_denoise_module(ModuleInitializationLevel p_level);
|
|
||||||
|
|
||||||
#endif // DENOISE_REGISTER_TYPES_H
|
|
@ -1,68 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
## ======================================================================== ##
|
|
||||||
## Copyright 2009-2019 Intel Corporation ##
|
|
||||||
## ##
|
|
||||||
## Licensed under the Apache License, Version 2.0 (the "License"); ##
|
|
||||||
## you may not use this file except in compliance with the License. ##
|
|
||||||
## You may obtain a copy of the License at ##
|
|
||||||
## ##
|
|
||||||
## http://www.apache.org/licenses/LICENSE-2.0 ##
|
|
||||||
## ##
|
|
||||||
## Unless required by applicable law or agreed to in writing, software ##
|
|
||||||
## distributed under the License is distributed on an "AS IS" BASIS, ##
|
|
||||||
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ##
|
|
||||||
## See the License for the specific language governing permissions and ##
|
|
||||||
## limitations under the License. ##
|
|
||||||
## ======================================================================== ##
|
|
||||||
|
|
||||||
import os
|
|
||||||
from array import array
|
|
||||||
|
|
||||||
|
|
||||||
# Generates a C++ file from the specified binary resource file
|
|
||||||
def generate(in_path, out_path):
|
|
||||||
namespace = "oidn::weights"
|
|
||||||
scopes = namespace.split("::")
|
|
||||||
|
|
||||||
file_name = os.path.basename(in_path)
|
|
||||||
var_name = os.path.splitext(file_name)[0]
|
|
||||||
|
|
||||||
with open(in_path, "rb") as in_file, open(out_path, "w") as out_file:
|
|
||||||
# Header
|
|
||||||
out_file.write("// Generated from: %s\n" % file_name)
|
|
||||||
out_file.write("#include <cstddef>\n\n")
|
|
||||||
|
|
||||||
# Open the namespaces
|
|
||||||
for s in scopes:
|
|
||||||
out_file.write("namespace %s {\n" % s)
|
|
||||||
if scopes:
|
|
||||||
out_file.write("\n")
|
|
||||||
|
|
||||||
# Read the file
|
|
||||||
in_data = array("B", in_file.read())
|
|
||||||
|
|
||||||
# Write the size
|
|
||||||
out_file.write("//const size_t %s_size = %d;\n\n" % (var_name, len(in_data)))
|
|
||||||
|
|
||||||
# Write the data
|
|
||||||
out_file.write("unsigned char %s[] = {" % var_name)
|
|
||||||
for i in range(len(in_data)):
|
|
||||||
c = in_data[i]
|
|
||||||
if i > 0:
|
|
||||||
out_file.write(",")
|
|
||||||
if (i + 1) % 20 == 1:
|
|
||||||
out_file.write("\n")
|
|
||||||
out_file.write("%d" % c)
|
|
||||||
out_file.write("\n};\n")
|
|
||||||
|
|
||||||
# Close the namespaces
|
|
||||||
if scopes:
|
|
||||||
out_file.write("\n")
|
|
||||||
for scope in reversed(scopes):
|
|
||||||
out_file.write("} // namespace %s\n" % scope)
|
|
||||||
|
|
||||||
|
|
||||||
def tza_to_cpp(target, source, env):
|
|
||||||
for x in zip(source, target):
|
|
||||||
generate(str(x[0]), str(x[1]))
|
|
31
thirdparty/README.md
vendored
31
thirdparty/README.md
vendored
@ -642,37 +642,6 @@ Files extracted from the upstream source:
|
|||||||
- `nvapi_minimal.h` was created by using `nvapi.h` from upstream and removing unnecessary code.
|
- `nvapi_minimal.h` was created by using `nvapi.h` from upstream and removing unnecessary code.
|
||||||
|
|
||||||
|
|
||||||
## oidn
|
|
||||||
|
|
||||||
- Upstream: https://github.com/OpenImageDenoise/oidn
|
|
||||||
- Version: 1.1.0 (c58c5216db05ceef4cde5a096862f2eeffd14c06, 2019)
|
|
||||||
- License: Apache 2.0
|
|
||||||
|
|
||||||
Files extracted from upstream source:
|
|
||||||
|
|
||||||
- common/* (except tasking.* and CMakeLists.txt)
|
|
||||||
- core/*
|
|
||||||
- include/OpenImageDenoise/* (except version.h.in)
|
|
||||||
- LICENSE.txt
|
|
||||||
- mkl-dnn/include/*
|
|
||||||
- mkl-dnn/src/* (except CMakeLists.txt)
|
|
||||||
- weights/rtlightmap_hdr.tza
|
|
||||||
- scripts/resource_to_cpp.py
|
|
||||||
|
|
||||||
Modified files:
|
|
||||||
Modifications are marked with `// -- GODOT start --` and `// -- GODOT end --`.
|
|
||||||
Patch files are provided in `oidn/patches/`.
|
|
||||||
|
|
||||||
- core/autoencoder.cpp
|
|
||||||
- core/autoencoder.h
|
|
||||||
- core/common.h
|
|
||||||
- core/device.cpp
|
|
||||||
- core/device.h
|
|
||||||
- core/transfer_function.cpp
|
|
||||||
|
|
||||||
- scripts/resource_to_cpp.py (used in modules/denoise/resource_to_cpp.py)
|
|
||||||
|
|
||||||
|
|
||||||
## openxr
|
## openxr
|
||||||
|
|
||||||
- Upstream: https://github.com/KhronosGroup/OpenXR-SDK
|
- Upstream: https://github.com/KhronosGroup/OpenXR-SDK
|
||||||
|
202
thirdparty/oidn/LICENSE.txt
vendored
202
thirdparty/oidn/LICENSE.txt
vendored
@ -1,202 +0,0 @@
|
|||||||
|
|
||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
52
thirdparty/oidn/common/barrier.h
vendored
52
thirdparty/oidn/common/barrier.h
vendored
@ -1,52 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "platform.h"
|
|
||||||
#include <mutex>
|
|
||||||
#include <condition_variable>
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
class Barrier
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
std::mutex m;
|
|
||||||
std::condition_variable cv;
|
|
||||||
volatile int count;
|
|
||||||
|
|
||||||
public:
|
|
||||||
Barrier(int count) : count(count) {}
|
|
||||||
|
|
||||||
void wait()
|
|
||||||
{
|
|
||||||
std::unique_lock<std::mutex> lk(m);
|
|
||||||
count--;
|
|
||||||
|
|
||||||
if (count == 0)
|
|
||||||
{
|
|
||||||
lk.unlock();
|
|
||||||
cv.notify_all();
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
cv.wait(lk, [&]{ return count == 0; });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
45
thirdparty/oidn/common/exception.h
vendored
45
thirdparty/oidn/common/exception.h
vendored
@ -1,45 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <exception>
|
|
||||||
#include "platform.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
class Exception : public std::exception
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
Error error;
|
|
||||||
const char* message;
|
|
||||||
|
|
||||||
public:
|
|
||||||
Exception(Error error, const char* message)
|
|
||||||
: error(error), message(message) {}
|
|
||||||
|
|
||||||
Error code() const noexcept
|
|
||||||
{
|
|
||||||
return error;
|
|
||||||
}
|
|
||||||
|
|
||||||
const char* what() const noexcept override
|
|
||||||
{
|
|
||||||
return message;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
114
thirdparty/oidn/common/platform.cpp
vendored
114
thirdparty/oidn/common/platform.cpp
vendored
@ -1,114 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#include "platform.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
// Common functions
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
void* alignedMalloc(size_t size, size_t alignment)
|
|
||||||
{
|
|
||||||
if (size == 0)
|
|
||||||
return nullptr;
|
|
||||||
|
|
||||||
assert((alignment & (alignment-1)) == 0);
|
|
||||||
void* ptr = _mm_malloc(size, alignment);
|
|
||||||
|
|
||||||
if (ptr == nullptr)
|
|
||||||
throw std::bad_alloc();
|
|
||||||
|
|
||||||
return ptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
void alignedFree(void* ptr)
|
|
||||||
{
|
|
||||||
if (ptr)
|
|
||||||
_mm_free(ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
// System information
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
std::string getPlatformName()
|
|
||||||
{
|
|
||||||
std::string name;
|
|
||||||
|
|
||||||
#if defined(__linux__)
|
|
||||||
name = "Linux";
|
|
||||||
#elif defined(__FreeBSD__)
|
|
||||||
name = "FreeBSD";
|
|
||||||
#elif defined(__CYGWIN__)
|
|
||||||
name = "Cygwin";
|
|
||||||
#elif defined(_WIN32)
|
|
||||||
name = "Windows";
|
|
||||||
#elif defined(__APPLE__)
|
|
||||||
name = "macOS";
|
|
||||||
#elif defined(__unix__)
|
|
||||||
name = "Unix";
|
|
||||||
#else
|
|
||||||
return "Unknown";
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined(__x86_64__) || defined(_M_X64) || defined(__ia64__) || defined(__aarch64__)
|
|
||||||
name += " (64-bit)";
|
|
||||||
#else
|
|
||||||
name += " (32-bit)";
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return name;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string getCompilerName()
|
|
||||||
{
|
|
||||||
#if defined(__INTEL_COMPILER)
|
|
||||||
int mayor = __INTEL_COMPILER / 100 % 100;
|
|
||||||
int minor = __INTEL_COMPILER % 100;
|
|
||||||
std::string version = "Intel Compiler ";
|
|
||||||
version += toString(mayor);
|
|
||||||
version += "." + toString(minor);
|
|
||||||
#if defined(__INTEL_COMPILER_UPDATE)
|
|
||||||
version += "." + toString(__INTEL_COMPILER_UPDATE);
|
|
||||||
#endif
|
|
||||||
return version;
|
|
||||||
#elif defined(__clang__)
|
|
||||||
return "Clang " __clang_version__;
|
|
||||||
#elif defined(__GNUC__)
|
|
||||||
return "GCC " __VERSION__;
|
|
||||||
#elif defined(_MSC_VER)
|
|
||||||
std::string version = toString(_MSC_FULL_VER);
|
|
||||||
version.insert(4, ".");
|
|
||||||
version.insert(9, ".");
|
|
||||||
version.insert(2, ".");
|
|
||||||
return "Visual C++ Compiler " + version;
|
|
||||||
#else
|
|
||||||
return "Unknown";
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string getBuildName()
|
|
||||||
{
|
|
||||||
#if defined(NDEBUG)
|
|
||||||
return "Release";
|
|
||||||
#else
|
|
||||||
return "Debug";
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
131
thirdparty/oidn/common/platform.h
vendored
131
thirdparty/oidn/common/platform.h
vendored
@ -1,131 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#if defined(_WIN32)
|
|
||||||
#define WIN32_LEAN_AND_MEAN
|
|
||||||
#define NOMINMAX
|
|
||||||
#include <windows.h>
|
|
||||||
#elif defined(__APPLE__)
|
|
||||||
#include <sys/sysctl.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include <xmmintrin.h>
|
|
||||||
#include <cstdint>
|
|
||||||
#include <climits>
|
|
||||||
#include <limits>
|
|
||||||
#include <atomic>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <memory>
|
|
||||||
#include <cmath>
|
|
||||||
#include <string>
|
|
||||||
#include <sstream>
|
|
||||||
#include <iostream>
|
|
||||||
#include <cassert>
|
|
||||||
#include "include/OpenImageDenoise/oidn.hpp"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
// Macros
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
#if defined(_WIN32)
|
|
||||||
// Windows
|
|
||||||
#if !defined(__noinline)
|
|
||||||
#define __noinline __declspec(noinline)
|
|
||||||
#endif
|
|
||||||
#else
|
|
||||||
// Unix
|
|
||||||
#if !defined(__forceinline)
|
|
||||||
#define __forceinline inline __attribute__((always_inline))
|
|
||||||
#endif
|
|
||||||
#if !defined(__noinline)
|
|
||||||
#define __noinline __attribute__((noinline))
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifndef UNUSED
|
|
||||||
#define UNUSED(x) ((void)x)
|
|
||||||
#endif
|
|
||||||
#ifndef MAYBE_UNUSED
|
|
||||||
#define MAYBE_UNUSED(x) UNUSED(x)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
// Error handling and debugging
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
struct Verbose
|
|
||||||
{
|
|
||||||
int verbose;
|
|
||||||
|
|
||||||
Verbose(int v = 0) : verbose(v) {}
|
|
||||||
__forceinline bool isVerbose(int v = 1) const { return v <= verbose; }
|
|
||||||
};
|
|
||||||
|
|
||||||
#define OIDN_WARNING(message) { if (isVerbose()) std::cerr << "Warning: " << message << std::endl; }
|
|
||||||
#define OIDN_FATAL(message) throw std::runtime_error(message);
|
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
// Common functions
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
using std::min;
|
|
||||||
using std::max;
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
__forceinline T clamp(const T& value, const T& minValue, const T& maxValue)
|
|
||||||
{
|
|
||||||
return min(max(value, minValue), maxValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
void* alignedMalloc(size_t size, size_t alignment);
|
|
||||||
void alignedFree(void* ptr);
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
inline std::string toString(const T& a)
|
|
||||||
{
|
|
||||||
std::stringstream sm;
|
|
||||||
sm << a;
|
|
||||||
return sm.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
#if defined(__APPLE__)
|
|
||||||
template<typename T>
|
|
||||||
bool getSysctl(const char* name, T& value)
|
|
||||||
{
|
|
||||||
int64_t result = 0;
|
|
||||||
size_t size = sizeof(result);
|
|
||||||
|
|
||||||
if (sysctlbyname(name, &result, &size, nullptr, 0) != 0)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
value = T(result);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
// System information
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
std::string getPlatformName();
|
|
||||||
std::string getCompilerName();
|
|
||||||
std::string getBuildName();
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
163
thirdparty/oidn/common/ref.h
vendored
163
thirdparty/oidn/common/ref.h
vendored
@ -1,163 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "platform.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
class RefCount
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
std::atomic<size_t> count;
|
|
||||||
|
|
||||||
public:
|
|
||||||
__forceinline RefCount(int count = 0) noexcept : count(count) {}
|
|
||||||
|
|
||||||
__forceinline size_t incRef() noexcept
|
|
||||||
{
|
|
||||||
return count.fetch_add(1) + 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline size_t decRef()
|
|
||||||
{
|
|
||||||
const size_t newCount = decRefKeep();
|
|
||||||
if (newCount == 0)
|
|
||||||
destroy();
|
|
||||||
return newCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline size_t decRefKeep() noexcept
|
|
||||||
{
|
|
||||||
return count.fetch_add(-1) - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline void destroy()
|
|
||||||
{
|
|
||||||
delete this;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
// Disable copying
|
|
||||||
RefCount(const RefCount&) = delete;
|
|
||||||
RefCount& operator =(const RefCount&) = delete;
|
|
||||||
|
|
||||||
virtual ~RefCount() noexcept = default;
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
class Ref
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
T* ptr;
|
|
||||||
|
|
||||||
public:
|
|
||||||
__forceinline Ref() noexcept : ptr(nullptr) {}
|
|
||||||
__forceinline Ref(std::nullptr_t) noexcept : ptr(nullptr) {}
|
|
||||||
__forceinline Ref(const Ref& other) noexcept : ptr(other.ptr) { if (ptr) ptr->incRef(); }
|
|
||||||
__forceinline Ref(Ref&& other) noexcept : ptr(other.ptr) { other.ptr = nullptr; }
|
|
||||||
__forceinline Ref(T* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); }
|
|
||||||
|
|
||||||
template<typename Y>
|
|
||||||
__forceinline Ref(const Ref<Y>& other) noexcept : ptr(other.get()) { if (ptr) ptr->incRef(); }
|
|
||||||
|
|
||||||
template<typename Y>
|
|
||||||
__forceinline explicit Ref(Y* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); }
|
|
||||||
|
|
||||||
__forceinline ~Ref() { if (ptr) ptr->decRef(); }
|
|
||||||
|
|
||||||
__forceinline Ref& operator =(const Ref& other)
|
|
||||||
{
|
|
||||||
if (other.ptr)
|
|
||||||
other.ptr->incRef();
|
|
||||||
if (ptr)
|
|
||||||
ptr->decRef();
|
|
||||||
ptr = other.ptr;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline Ref& operator =(Ref&& other)
|
|
||||||
{
|
|
||||||
if (ptr)
|
|
||||||
ptr->decRef();
|
|
||||||
ptr = other.ptr;
|
|
||||||
other.ptr = nullptr;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline Ref& operator =(T* other)
|
|
||||||
{
|
|
||||||
if (other)
|
|
||||||
other->incRef();
|
|
||||||
if (ptr)
|
|
||||||
ptr->decRef();
|
|
||||||
ptr = other;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline Ref& operator =(std::nullptr_t)
|
|
||||||
{
|
|
||||||
if (ptr)
|
|
||||||
ptr->decRef();
|
|
||||||
ptr = nullptr;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline operator bool() const noexcept { return ptr != nullptr; }
|
|
||||||
|
|
||||||
__forceinline T& operator *() const noexcept { return *ptr; }
|
|
||||||
__forceinline T* operator ->() const noexcept { return ptr; }
|
|
||||||
|
|
||||||
__forceinline T* get() const noexcept { return ptr; }
|
|
||||||
|
|
||||||
__forceinline T* detach() noexcept
|
|
||||||
{
|
|
||||||
T* res = ptr;
|
|
||||||
ptr = nullptr;
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename T> __forceinline bool operator < (const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr < b.ptr; }
|
|
||||||
|
|
||||||
template<typename T> __forceinline bool operator ==(const Ref<T>& a, std::nullptr_t) noexcept { return a.ptr == nullptr; }
|
|
||||||
template<typename T> __forceinline bool operator ==(std::nullptr_t, const Ref<T>& b) noexcept { return nullptr == b.ptr; }
|
|
||||||
template<typename T> __forceinline bool operator ==(const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr == b.ptr; }
|
|
||||||
|
|
||||||
template<typename T> __forceinline bool operator !=(const Ref<T>& a, std::nullptr_t) noexcept { return a.ptr != nullptr; }
|
|
||||||
template<typename T> __forceinline bool operator !=(std::nullptr_t, const Ref<T>& b) noexcept { return nullptr != b.ptr; }
|
|
||||||
template<typename T> __forceinline bool operator !=(const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr != b.ptr; }
|
|
||||||
|
|
||||||
template<typename T, typename... Args>
|
|
||||||
__forceinline Ref<T> makeRef(Args&&... args)
|
|
||||||
{
|
|
||||||
return Ref<T>(new T(std::forward<Args>(args)...));
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T, typename Y>
|
|
||||||
__forceinline Ref<Y> staticRefCast(const Ref<T>& a)
|
|
||||||
{
|
|
||||||
return Ref<Y>(static_cast<Y*>(a.get()));
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T, typename Y>
|
|
||||||
__forceinline Ref<Y> dynamicRefCast(const Ref<T>& a)
|
|
||||||
{
|
|
||||||
return Ref<Y>(dynamic_cast<Y*>(a.get()));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
83
thirdparty/oidn/common/tensor.cpp
vendored
83
thirdparty/oidn/common/tensor.cpp
vendored
@ -1,83 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#include "exception.h"
|
|
||||||
#include "tensor.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
std::map<std::string, Tensor> parseTensors(void* buffer)
|
|
||||||
{
|
|
||||||
char* input = (char*)buffer;
|
|
||||||
|
|
||||||
// Parse the magic value
|
|
||||||
const int magic = *(unsigned short*)input;
|
|
||||||
if (magic != 0x41D7)
|
|
||||||
throw Exception(Error::InvalidOperation, "invalid tensor archive");
|
|
||||||
input += sizeof(unsigned short);
|
|
||||||
|
|
||||||
// Parse the version
|
|
||||||
const int majorVersion = *(unsigned char*)input++;
|
|
||||||
const int minorVersion = *(unsigned char*)input++;
|
|
||||||
UNUSED(minorVersion);
|
|
||||||
if (majorVersion > 1)
|
|
||||||
throw Exception(Error::InvalidOperation, "unsupported tensor archive version");
|
|
||||||
|
|
||||||
// Parse the number of tensors
|
|
||||||
const int numTensors = *(int*)input;
|
|
||||||
input += sizeof(int);
|
|
||||||
|
|
||||||
// Parse the tensors
|
|
||||||
std::map<std::string, Tensor> tensorMap;
|
|
||||||
for (int i = 0; i < numTensors; ++i)
|
|
||||||
{
|
|
||||||
Tensor tensor;
|
|
||||||
|
|
||||||
// Parse the name
|
|
||||||
const int nameLen = *(unsigned char*)input++;
|
|
||||||
std::string name(input, nameLen);
|
|
||||||
input += nameLen;
|
|
||||||
|
|
||||||
// Parse the number of dimensions
|
|
||||||
const int ndims = *(unsigned char*)input++;
|
|
||||||
|
|
||||||
// Parse the shape of the tensor
|
|
||||||
tensor.dims.resize(ndims);
|
|
||||||
for (int i = 0; i < ndims; ++i)
|
|
||||||
tensor.dims[i] = ((int*)input)[i];
|
|
||||||
input += ndims * sizeof(int);
|
|
||||||
|
|
||||||
// Parse the format of the tensor
|
|
||||||
tensor.format = std::string(input, input + ndims);
|
|
||||||
input += ndims;
|
|
||||||
|
|
||||||
// Parse the data type of the tensor
|
|
||||||
const char type = *(unsigned char*)input++;
|
|
||||||
if (type != 'f') // only float32 is supported
|
|
||||||
throw Exception(Error::InvalidOperation, "unsupported tensor data type");
|
|
||||||
|
|
||||||
// Skip the data
|
|
||||||
tensor.data = (float*)input;
|
|
||||||
input += tensor.size() * sizeof(float);
|
|
||||||
|
|
||||||
// Add the tensor to the map
|
|
||||||
tensorMap.emplace(name, std::move(tensor));
|
|
||||||
}
|
|
||||||
|
|
||||||
return tensorMap;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
66
thirdparty/oidn/common/tensor.h
vendored
66
thirdparty/oidn/common/tensor.h
vendored
@ -1,66 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "platform.h"
|
|
||||||
#include <vector>
|
|
||||||
#include <map>
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
using shared_vector = std::shared_ptr<std::vector<T>>;
|
|
||||||
|
|
||||||
// Generic tensor
|
|
||||||
struct Tensor
|
|
||||||
{
|
|
||||||
float* data;
|
|
||||||
std::vector<int64_t> dims;
|
|
||||||
std::string format;
|
|
||||||
shared_vector<char> buffer; // optional, only for reference counting
|
|
||||||
|
|
||||||
__forceinline Tensor() : data(nullptr) {}
|
|
||||||
|
|
||||||
__forceinline Tensor(const std::vector<int64_t>& dims, const std::string& format)
|
|
||||||
: dims(dims),
|
|
||||||
format(format)
|
|
||||||
{
|
|
||||||
buffer = std::make_shared<std::vector<char>>(size() * sizeof(float));
|
|
||||||
data = (float*)buffer->data();
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline operator bool() const { return data != nullptr; }
|
|
||||||
|
|
||||||
__forceinline int ndims() const { return (int)dims.size(); }
|
|
||||||
|
|
||||||
// Returns the number of values
|
|
||||||
__forceinline size_t size() const
|
|
||||||
{
|
|
||||||
size_t size = 1;
|
|
||||||
for (int i = 0; i < ndims(); ++i)
|
|
||||||
size *= dims[i];
|
|
||||||
return size;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline float& operator [](size_t i) { return data[i]; }
|
|
||||||
__forceinline const float& operator [](size_t i) const { return data[i]; }
|
|
||||||
};
|
|
||||||
|
|
||||||
// Parses tensors from a buffer
|
|
||||||
std::map<std::string, Tensor> parseTensors(void* buffer);
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
297
thirdparty/oidn/common/thread.cpp
vendored
297
thirdparty/oidn/common/thread.cpp
vendored
@ -1,297 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#if defined(_MSC_VER)
|
|
||||||
#pragma warning (disable : 4146) // unary minus operator applied to unsigned type, result still unsigned
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined(__APPLE__)
|
|
||||||
#include <mach/thread_act.h>
|
|
||||||
#include <mach/mach_init.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "thread.h"
|
|
||||||
#include <fstream>
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
#if defined(_WIN32)
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// ThreadAffinity - Windows
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
|
|
||||||
: Verbose(verbose)
|
|
||||||
{
|
|
||||||
HMODULE hLib = GetModuleHandle(TEXT("kernel32"));
|
|
||||||
pGetLogicalProcessorInformationEx = (GetLogicalProcessorInformationExFunc)GetProcAddress(hLib, "GetLogicalProcessorInformationEx");
|
|
||||||
pSetThreadGroupAffinity = (SetThreadGroupAffinityFunc)GetProcAddress(hLib, "SetThreadGroupAffinity");
|
|
||||||
|
|
||||||
if (pGetLogicalProcessorInformationEx && pSetThreadGroupAffinity)
|
|
||||||
{
|
|
||||||
// Get logical processor information
|
|
||||||
PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX buffer = nullptr;
|
|
||||||
DWORD bufferSize = 0;
|
|
||||||
|
|
||||||
// First call the function with an empty buffer to get the required buffer size
|
|
||||||
BOOL result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize);
|
|
||||||
if (result || GetLastError() != ERROR_INSUFFICIENT_BUFFER)
|
|
||||||
{
|
|
||||||
OIDN_WARNING("GetLogicalProcessorInformationEx failed");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allocate the buffer
|
|
||||||
buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)malloc(bufferSize);
|
|
||||||
if (!buffer)
|
|
||||||
{
|
|
||||||
OIDN_WARNING("SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX allocation failed");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call again the function but now with the properly sized buffer
|
|
||||||
result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize);
|
|
||||||
if (!result)
|
|
||||||
{
|
|
||||||
OIDN_WARNING("GetLogicalProcessorInformationEx failed");
|
|
||||||
free(buffer);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over the logical processor information structures
|
|
||||||
// There should be one structure for each physical core
|
|
||||||
char* ptr = (char*)buffer;
|
|
||||||
while (ptr < (char*)buffer + bufferSize)
|
|
||||||
{
|
|
||||||
PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX item = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)ptr;
|
|
||||||
if (item->Relationship == RelationProcessorCore && item->Processor.GroupCount > 0)
|
|
||||||
{
|
|
||||||
// Iterate over the groups
|
|
||||||
int numThreads = 0;
|
|
||||||
for (int group = 0; (group < item->Processor.GroupCount) && (numThreads < numThreadsPerCore); ++group)
|
|
||||||
{
|
|
||||||
GROUP_AFFINITY coreAffinity = item->Processor.GroupMask[group];
|
|
||||||
while ((coreAffinity.Mask != 0) && (numThreads < numThreadsPerCore))
|
|
||||||
{
|
|
||||||
// Extract the next set bit/thread from the mask
|
|
||||||
GROUP_AFFINITY threadAffinity = coreAffinity;
|
|
||||||
threadAffinity.Mask = threadAffinity.Mask & -threadAffinity.Mask;
|
|
||||||
|
|
||||||
// Push the affinity for this thread
|
|
||||||
affinities.push_back(threadAffinity);
|
|
||||||
oldAffinities.push_back(threadAffinity);
|
|
||||||
numThreads++;
|
|
||||||
|
|
||||||
// Remove this bit/thread from the mask
|
|
||||||
coreAffinity.Mask ^= threadAffinity.Mask;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next structure
|
|
||||||
ptr += item->Size;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Free the buffer
|
|
||||||
free(buffer);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ThreadAffinity::set(int threadIndex)
|
|
||||||
{
|
|
||||||
if (threadIndex >= (int)affinities.size())
|
|
||||||
return;
|
|
||||||
|
|
||||||
// Save the current affinity and set the new one
|
|
||||||
const HANDLE thread = GetCurrentThread();
|
|
||||||
if (!pSetThreadGroupAffinity(thread, &affinities[threadIndex], &oldAffinities[threadIndex]))
|
|
||||||
OIDN_WARNING("SetThreadGroupAffinity failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
void ThreadAffinity::restore(int threadIndex)
|
|
||||||
{
|
|
||||||
if (threadIndex >= (int)affinities.size())
|
|
||||||
return;
|
|
||||||
|
|
||||||
// Restore the original affinity
|
|
||||||
const HANDLE thread = GetCurrentThread();
|
|
||||||
if (!pSetThreadGroupAffinity(thread, &oldAffinities[threadIndex], nullptr))
|
|
||||||
OIDN_WARNING("SetThreadGroupAffinity failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
#elif defined(__linux__)
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// ThreadAffinity - Linux
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
|
|
||||||
: Verbose(verbose)
|
|
||||||
{
|
|
||||||
std::vector<int> threadIds;
|
|
||||||
|
|
||||||
// Parse the thread/CPU topology
|
|
||||||
for (int cpuId = 0; ; cpuId++)
|
|
||||||
{
|
|
||||||
std::fstream fs;
|
|
||||||
std::string cpu = std::string("/sys/devices/system/cpu/cpu") + std::to_string(cpuId) + std::string("/topology/thread_siblings_list");
|
|
||||||
fs.open(cpu.c_str(), std::fstream::in);
|
|
||||||
if (fs.fail()) break;
|
|
||||||
|
|
||||||
int i;
|
|
||||||
int j = 0;
|
|
||||||
while ((j < numThreadsPerCore) && (fs >> i))
|
|
||||||
{
|
|
||||||
if (std::none_of(threadIds.begin(), threadIds.end(), [&](int id) { return id == i; }))
|
|
||||||
threadIds.push_back(i);
|
|
||||||
|
|
||||||
if (fs.peek() == ',')
|
|
||||||
fs.ignore();
|
|
||||||
j++;
|
|
||||||
}
|
|
||||||
|
|
||||||
fs.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
#if 0
|
|
||||||
for (size_t i = 0; i < thread_ids.size(); ++i)
|
|
||||||
std::cout << "thread " << i << " -> " << thread_ids[i] << std::endl;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Create the affinity structures
|
|
||||||
affinities.resize(threadIds.size());
|
|
||||||
oldAffinities.resize(threadIds.size());
|
|
||||||
|
|
||||||
for (size_t i = 0; i < threadIds.size(); ++i)
|
|
||||||
{
|
|
||||||
cpu_set_t affinity;
|
|
||||||
CPU_ZERO(&affinity);
|
|
||||||
CPU_SET(threadIds[i], &affinity);
|
|
||||||
|
|
||||||
affinities[i] = affinity;
|
|
||||||
oldAffinities[i] = affinity;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ThreadAffinity::set(int threadIndex)
|
|
||||||
{
|
|
||||||
if (threadIndex >= (int)affinities.size())
|
|
||||||
return;
|
|
||||||
|
|
||||||
const pthread_t thread = pthread_self();
|
|
||||||
|
|
||||||
// Save the current affinity
|
|
||||||
if (pthread_getaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0)
|
|
||||||
{
|
|
||||||
OIDN_WARNING("pthread_getaffinity_np failed");
|
|
||||||
oldAffinities[threadIndex] = affinities[threadIndex];
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the new affinity
|
|
||||||
if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &affinities[threadIndex]) != 0)
|
|
||||||
OIDN_WARNING("pthread_setaffinity_np failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
void ThreadAffinity::restore(int threadIndex)
|
|
||||||
{
|
|
||||||
if (threadIndex >= (int)affinities.size())
|
|
||||||
return;
|
|
||||||
|
|
||||||
const pthread_t thread = pthread_self();
|
|
||||||
|
|
||||||
// Restore the original affinity
|
|
||||||
if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0)
|
|
||||||
OIDN_WARNING("pthread_setaffinity_np failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
#elif defined(__APPLE__)
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// ThreadAffinity - macOS
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
|
|
||||||
: Verbose(verbose)
|
|
||||||
{
|
|
||||||
// Query the thread/CPU topology
|
|
||||||
int numPhysicalCpus;
|
|
||||||
int numLogicalCpus;
|
|
||||||
|
|
||||||
if (!getSysctl("hw.physicalcpu", numPhysicalCpus) || !getSysctl("hw.logicalcpu", numLogicalCpus))
|
|
||||||
{
|
|
||||||
OIDN_WARNING("sysctlbyname failed");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((numLogicalCpus % numPhysicalCpus != 0) && (numThreadsPerCore > 1))
|
|
||||||
return; // this shouldn't happen
|
|
||||||
const int maxThreadsPerCore = numLogicalCpus / numPhysicalCpus;
|
|
||||||
|
|
||||||
// Create the affinity structures
|
|
||||||
// macOS doesn't support binding a thread to a specific core, but we can at least group threads which
|
|
||||||
// should be on the same core together
|
|
||||||
for (int core = 1; core <= numPhysicalCpus; ++core) // tags start from 1!
|
|
||||||
{
|
|
||||||
thread_affinity_policy affinity;
|
|
||||||
affinity.affinity_tag = core;
|
|
||||||
|
|
||||||
for (int thread = 0; thread < min(numThreadsPerCore, maxThreadsPerCore); ++thread)
|
|
||||||
{
|
|
||||||
affinities.push_back(affinity);
|
|
||||||
oldAffinities.push_back(affinity);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ThreadAffinity::set(int threadIndex)
|
|
||||||
{
|
|
||||||
if (threadIndex >= (int)affinities.size())
|
|
||||||
return;
|
|
||||||
|
|
||||||
const auto thread = mach_thread_self();
|
|
||||||
|
|
||||||
// Save the current affinity
|
|
||||||
mach_msg_type_number_t policyCount = THREAD_AFFINITY_POLICY_COUNT;
|
|
||||||
boolean_t getDefault = FALSE;
|
|
||||||
if (thread_policy_get(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], &policyCount, &getDefault) != KERN_SUCCESS)
|
|
||||||
{
|
|
||||||
OIDN_WARNING("thread_policy_get failed");
|
|
||||||
oldAffinities[threadIndex] = affinities[threadIndex];
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the new affinity
|
|
||||||
if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&affinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS)
|
|
||||||
OIDN_WARNING("thread_policy_set failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
void ThreadAffinity::restore(int threadIndex)
|
|
||||||
{
|
|
||||||
if (threadIndex >= (int)affinities.size())
|
|
||||||
return;
|
|
||||||
|
|
||||||
const auto thread = mach_thread_self();
|
|
||||||
|
|
||||||
// Restore the original affinity
|
|
||||||
if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS)
|
|
||||||
OIDN_WARNING("thread_policy_set failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
202
thirdparty/oidn/common/thread.h
vendored
202
thirdparty/oidn/common/thread.h
vendored
@ -1,202 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "platform.h"
|
|
||||||
|
|
||||||
#if !defined(_WIN32)
|
|
||||||
#include <pthread.h>
|
|
||||||
#include <sched.h>
|
|
||||||
#if defined(__APPLE__)
|
|
||||||
#include <mach/thread_policy.h>
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <mutex>
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// ThreadLocal
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
// Wrapper which makes any variable thread-local
|
|
||||||
template<typename T>
|
|
||||||
class ThreadLocal : public Verbose
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
#if defined(_WIN32)
|
|
||||||
DWORD key;
|
|
||||||
#else
|
|
||||||
pthread_key_t key;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
std::vector<T*> instances;
|
|
||||||
std::mutex mutex;
|
|
||||||
|
|
||||||
public:
|
|
||||||
ThreadLocal(int verbose = 0)
|
|
||||||
: Verbose(verbose)
|
|
||||||
{
|
|
||||||
#if defined(_WIN32)
|
|
||||||
key = TlsAlloc();
|
|
||||||
if (key == TLS_OUT_OF_INDEXES)
|
|
||||||
OIDN_FATAL("TlsAlloc failed");
|
|
||||||
#else
|
|
||||||
if (pthread_key_create(&key, nullptr) != 0)
|
|
||||||
OIDN_FATAL("pthread_key_create failed");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
~ThreadLocal()
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(mutex);
|
|
||||||
for (T* ptr : instances)
|
|
||||||
delete ptr;
|
|
||||||
|
|
||||||
#if defined(_WIN32)
|
|
||||||
if (!TlsFree(key))
|
|
||||||
OIDN_WARNING("TlsFree failed");
|
|
||||||
#else
|
|
||||||
if (pthread_key_delete(key) != 0)
|
|
||||||
OIDN_WARNING("pthread_key_delete failed");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
T& get()
|
|
||||||
{
|
|
||||||
#if defined(_WIN32)
|
|
||||||
T* ptr = (T*)TlsGetValue(key);
|
|
||||||
#else
|
|
||||||
T* ptr = (T*)pthread_getspecific(key);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
if (ptr)
|
|
||||||
return *ptr;
|
|
||||||
|
|
||||||
ptr = new T;
|
|
||||||
std::lock_guard<std::mutex> lock(mutex);
|
|
||||||
instances.push_back(ptr);
|
|
||||||
|
|
||||||
#if defined(_WIN32)
|
|
||||||
if (!TlsSetValue(key, ptr))
|
|
||||||
OIDN_FATAL("TlsSetValue failed");
|
|
||||||
#else
|
|
||||||
if (pthread_setspecific(key, ptr) != 0)
|
|
||||||
OIDN_FATAL("pthread_setspecific failed");
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return *ptr;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
#if defined(_WIN32)
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// ThreadAffinity - Windows
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class ThreadAffinity : public Verbose
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
typedef BOOL (WINAPI *GetLogicalProcessorInformationExFunc)(LOGICAL_PROCESSOR_RELATIONSHIP,
|
|
||||||
PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX,
|
|
||||||
PDWORD);
|
|
||||||
|
|
||||||
typedef BOOL (WINAPI *SetThreadGroupAffinityFunc)(HANDLE,
|
|
||||||
CONST GROUP_AFFINITY*,
|
|
||||||
PGROUP_AFFINITY);
|
|
||||||
|
|
||||||
GetLogicalProcessorInformationExFunc pGetLogicalProcessorInformationEx = nullptr;
|
|
||||||
SetThreadGroupAffinityFunc pSetThreadGroupAffinity = nullptr;
|
|
||||||
|
|
||||||
std::vector<GROUP_AFFINITY> affinities; // thread affinities
|
|
||||||
std::vector<GROUP_AFFINITY> oldAffinities; // original thread affinities
|
|
||||||
|
|
||||||
public:
|
|
||||||
ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
|
|
||||||
|
|
||||||
int getNumThreads() const
|
|
||||||
{
|
|
||||||
return (int)affinities.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
|
|
||||||
void set(int threadIndex);
|
|
||||||
|
|
||||||
// Restores the affinity of the thread
|
|
||||||
void restore(int threadIndex);
|
|
||||||
};
|
|
||||||
|
|
||||||
#elif defined(__linux__)
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// ThreadAffinity - Linux
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class ThreadAffinity : public Verbose
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
std::vector<cpu_set_t> affinities; // thread affinities
|
|
||||||
std::vector<cpu_set_t> oldAffinities; // original thread affinities
|
|
||||||
|
|
||||||
public:
|
|
||||||
ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
|
|
||||||
|
|
||||||
int getNumThreads() const
|
|
||||||
{
|
|
||||||
return (int)affinities.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
|
|
||||||
void set(int threadIndex);
|
|
||||||
|
|
||||||
// Restores the affinity of the thread
|
|
||||||
void restore(int threadIndex);
|
|
||||||
};
|
|
||||||
|
|
||||||
#elif defined(__APPLE__)
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// ThreadAffinity - macOS
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class ThreadAffinity : public Verbose
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
std::vector<thread_affinity_policy> affinities; // thread affinities
|
|
||||||
std::vector<thread_affinity_policy> oldAffinities; // original thread affinities
|
|
||||||
|
|
||||||
public:
|
|
||||||
ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
|
|
||||||
|
|
||||||
int getNumThreads() const
|
|
||||||
{
|
|
||||||
return (int)affinities.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
|
|
||||||
void set(int threadIndex);
|
|
||||||
|
|
||||||
// Restores the affinity of the thread
|
|
||||||
void restore(int threadIndex);
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
49
thirdparty/oidn/common/timer.h
vendored
49
thirdparty/oidn/common/timer.h
vendored
@ -1,49 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "platform.h"
|
|
||||||
#include <chrono>
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
class Timer
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
using clock = std::chrono::high_resolution_clock;
|
|
||||||
|
|
||||||
std::chrono::time_point<clock> start;
|
|
||||||
|
|
||||||
public:
|
|
||||||
Timer()
|
|
||||||
{
|
|
||||||
reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
void reset()
|
|
||||||
{
|
|
||||||
start = clock::now();
|
|
||||||
}
|
|
||||||
|
|
||||||
double query() const
|
|
||||||
{
|
|
||||||
auto end = clock::now();
|
|
||||||
return std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
408
thirdparty/oidn/core/api.cpp
vendored
408
thirdparty/oidn/core/api.cpp
vendored
@ -1,408 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#ifdef _WIN32
|
|
||||||
# define OIDN_API extern "C" __declspec(dllexport)
|
|
||||||
#else
|
|
||||||
# define OIDN_API extern "C" __attribute__ ((visibility ("default")))
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Locks the device that owns the specified object
|
|
||||||
// Use *only* inside OIDN_TRY/CATCH!
|
|
||||||
#define OIDN_LOCK(obj) \
|
|
||||||
std::lock_guard<std::mutex> lock(obj->getDevice()->getMutex());
|
|
||||||
|
|
||||||
// Try/catch for converting exceptions to errors
|
|
||||||
#define OIDN_TRY \
|
|
||||||
try {
|
|
||||||
|
|
||||||
#define OIDN_CATCH(obj) \
|
|
||||||
} catch (Exception& e) { \
|
|
||||||
Device::setError(obj ? obj->getDevice() : nullptr, e.code(), e.what()); \
|
|
||||||
} catch (std::bad_alloc&) { \
|
|
||||||
Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \
|
|
||||||
} catch (mkldnn::error& e) { \
|
|
||||||
if (e.status == mkldnn_out_of_memory) \
|
|
||||||
Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \
|
|
||||||
else \
|
|
||||||
Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.message); \
|
|
||||||
} catch (std::exception& e) { \
|
|
||||||
Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.what()); \
|
|
||||||
} catch (...) { \
|
|
||||||
Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, "unknown exception caught"); \
|
|
||||||
}
|
|
||||||
|
|
||||||
#include "device.h"
|
|
||||||
#include "filter.h"
|
|
||||||
#include <mutex>
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
namespace
|
|
||||||
{
|
|
||||||
__forceinline void checkHandle(void* handle)
|
|
||||||
{
|
|
||||||
if (handle == nullptr)
|
|
||||||
throw Exception(Error::InvalidArgument, "invalid handle");
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
__forceinline void retainObject(T* obj)
|
|
||||||
{
|
|
||||||
if (obj)
|
|
||||||
{
|
|
||||||
obj->incRef();
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(obj);
|
|
||||||
OIDN_CATCH(obj)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
__forceinline void releaseObject(T* obj)
|
|
||||||
{
|
|
||||||
if (obj == nullptr || obj->decRefKeep() == 0)
|
|
||||||
{
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(obj);
|
|
||||||
OIDN_LOCK(obj);
|
|
||||||
obj->destroy();
|
|
||||||
OIDN_CATCH(obj)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
__forceinline void releaseObject(Device* obj)
|
|
||||||
{
|
|
||||||
if (obj == nullptr || obj->decRefKeep() == 0)
|
|
||||||
{
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(obj);
|
|
||||||
// Do NOT lock the device because it owns the mutex
|
|
||||||
obj->destroy();
|
|
||||||
OIDN_CATCH(obj)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type)
|
|
||||||
{
|
|
||||||
Ref<Device> device = nullptr;
|
|
||||||
OIDN_TRY
|
|
||||||
if (type == OIDN_DEVICE_TYPE_CPU || type == OIDN_DEVICE_TYPE_DEFAULT)
|
|
||||||
device = makeRef<Device>();
|
|
||||||
else
|
|
||||||
throw Exception(Error::InvalidArgument, "invalid device type");
|
|
||||||
OIDN_CATCH(device)
|
|
||||||
return (OIDNDevice)device.detach();
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnRetainDevice(OIDNDevice hDevice)
|
|
||||||
{
|
|
||||||
Device* device = (Device*)hDevice;
|
|
||||||
retainObject(device);
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnReleaseDevice(OIDNDevice hDevice)
|
|
||||||
{
|
|
||||||
Device* device = (Device*)hDevice;
|
|
||||||
releaseObject(device);
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnSetDevice1b(OIDNDevice hDevice, const char* name, bool value)
|
|
||||||
{
|
|
||||||
Device* device = (Device*)hDevice;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hDevice);
|
|
||||||
OIDN_LOCK(device);
|
|
||||||
device->set1i(name, value);
|
|
||||||
OIDN_CATCH(device)
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnSetDevice1i(OIDNDevice hDevice, const char* name, int value)
|
|
||||||
{
|
|
||||||
Device* device = (Device*)hDevice;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hDevice);
|
|
||||||
OIDN_LOCK(device);
|
|
||||||
device->set1i(name, value);
|
|
||||||
OIDN_CATCH(device)
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API bool oidnGetDevice1b(OIDNDevice hDevice, const char* name)
|
|
||||||
{
|
|
||||||
Device* device = (Device*)hDevice;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hDevice);
|
|
||||||
OIDN_LOCK(device);
|
|
||||||
return device->get1i(name);
|
|
||||||
OIDN_CATCH(device)
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API int oidnGetDevice1i(OIDNDevice hDevice, const char* name)
|
|
||||||
{
|
|
||||||
Device* device = (Device*)hDevice;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hDevice);
|
|
||||||
OIDN_LOCK(device);
|
|
||||||
return device->get1i(name);
|
|
||||||
OIDN_CATCH(device)
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice hDevice, OIDNErrorFunction func, void* userPtr)
|
|
||||||
{
|
|
||||||
Device* device = (Device*)hDevice;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hDevice);
|
|
||||||
OIDN_LOCK(device);
|
|
||||||
device->setErrorFunction((ErrorFunction)func, userPtr);
|
|
||||||
OIDN_CATCH(device)
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API OIDNError oidnGetDeviceError(OIDNDevice hDevice, const char** outMessage)
|
|
||||||
{
|
|
||||||
Device* device = (Device*)hDevice;
|
|
||||||
OIDN_TRY
|
|
||||||
return (OIDNError)Device::getError(device, outMessage);
|
|
||||||
OIDN_CATCH(device)
|
|
||||||
if (outMessage) *outMessage = "";
|
|
||||||
return OIDN_ERROR_UNKNOWN;
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnCommitDevice(OIDNDevice hDevice)
|
|
||||||
{
|
|
||||||
Device* device = (Device*)hDevice;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hDevice);
|
|
||||||
OIDN_LOCK(device);
|
|
||||||
device->commit();
|
|
||||||
OIDN_CATCH(device)
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice hDevice, size_t byteSize)
|
|
||||||
{
|
|
||||||
Device* device = (Device*)hDevice;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hDevice);
|
|
||||||
OIDN_LOCK(device);
|
|
||||||
Ref<Buffer> buffer = device->newBuffer(byteSize);
|
|
||||||
return (OIDNBuffer)buffer.detach();
|
|
||||||
OIDN_CATCH(device)
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice hDevice, void* ptr, size_t byteSize)
|
|
||||||
{
|
|
||||||
Device* device = (Device*)hDevice;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hDevice);
|
|
||||||
OIDN_LOCK(device);
|
|
||||||
Ref<Buffer> buffer = device->newBuffer(ptr, byteSize);
|
|
||||||
return (OIDNBuffer)buffer.detach();
|
|
||||||
OIDN_CATCH(device)
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnRetainBuffer(OIDNBuffer hBuffer)
|
|
||||||
{
|
|
||||||
Buffer* buffer = (Buffer*)hBuffer;
|
|
||||||
retainObject(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnReleaseBuffer(OIDNBuffer hBuffer)
|
|
||||||
{
|
|
||||||
Buffer* buffer = (Buffer*)hBuffer;
|
|
||||||
releaseObject(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void* oidnMapBuffer(OIDNBuffer hBuffer, OIDNAccess access, size_t byteOffset, size_t byteSize)
|
|
||||||
{
|
|
||||||
Buffer* buffer = (Buffer*)hBuffer;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hBuffer);
|
|
||||||
OIDN_LOCK(buffer);
|
|
||||||
return buffer->map(byteOffset, byteSize);
|
|
||||||
OIDN_CATCH(buffer)
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnUnmapBuffer(OIDNBuffer hBuffer, void* mappedPtr)
|
|
||||||
{
|
|
||||||
Buffer* buffer = (Buffer*)hBuffer;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hBuffer);
|
|
||||||
OIDN_LOCK(buffer);
|
|
||||||
return buffer->unmap(mappedPtr);
|
|
||||||
OIDN_CATCH(buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API OIDNFilter oidnNewFilter(OIDNDevice hDevice, const char* type)
|
|
||||||
{
|
|
||||||
Device* device = (Device*)hDevice;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hDevice);
|
|
||||||
OIDN_LOCK(device);
|
|
||||||
Ref<Filter> filter = device->newFilter(type);
|
|
||||||
return (OIDNFilter)filter.detach();
|
|
||||||
OIDN_CATCH(device)
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnRetainFilter(OIDNFilter hFilter)
|
|
||||||
{
|
|
||||||
Filter* filter = (Filter*)hFilter;
|
|
||||||
retainObject(filter);
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnReleaseFilter(OIDNFilter hFilter)
|
|
||||||
{
|
|
||||||
Filter* filter = (Filter*)hFilter;
|
|
||||||
releaseObject(filter);
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnSetFilterImage(OIDNFilter hFilter, const char* name,
|
|
||||||
OIDNBuffer hBuffer, OIDNFormat format,
|
|
||||||
size_t width, size_t height,
|
|
||||||
size_t byteOffset,
|
|
||||||
size_t bytePixelStride, size_t byteRowStride)
|
|
||||||
{
|
|
||||||
Filter* filter = (Filter*)hFilter;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hFilter);
|
|
||||||
checkHandle(hBuffer);
|
|
||||||
OIDN_LOCK(filter);
|
|
||||||
Ref<Buffer> buffer = (Buffer*)hBuffer;
|
|
||||||
if (buffer->getDevice() != filter->getDevice())
|
|
||||||
throw Exception(Error::InvalidArgument, "the specified objects are bound to different devices");
|
|
||||||
Image data(buffer, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride);
|
|
||||||
filter->setImage(name, data);
|
|
||||||
OIDN_CATCH(filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnSetSharedFilterImage(OIDNFilter hFilter, const char* name,
|
|
||||||
void* ptr, OIDNFormat format,
|
|
||||||
size_t width, size_t height,
|
|
||||||
size_t byteOffset,
|
|
||||||
size_t bytePixelStride, size_t byteRowStride)
|
|
||||||
{
|
|
||||||
Filter* filter = (Filter*)hFilter;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hFilter);
|
|
||||||
OIDN_LOCK(filter);
|
|
||||||
Image data(ptr, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride);
|
|
||||||
filter->setImage(name, data);
|
|
||||||
OIDN_CATCH(filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnSetFilter1b(OIDNFilter hFilter, const char* name, bool value)
|
|
||||||
{
|
|
||||||
Filter* filter = (Filter*)hFilter;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hFilter);
|
|
||||||
OIDN_LOCK(filter);
|
|
||||||
filter->set1i(name, int(value));
|
|
||||||
OIDN_CATCH(filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API bool oidnGetFilter1b(OIDNFilter hFilter, const char* name)
|
|
||||||
{
|
|
||||||
Filter* filter = (Filter*)hFilter;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hFilter);
|
|
||||||
OIDN_LOCK(filter);
|
|
||||||
return filter->get1i(name);
|
|
||||||
OIDN_CATCH(filter)
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnSetFilter1i(OIDNFilter hFilter, const char* name, int value)
|
|
||||||
{
|
|
||||||
Filter* filter = (Filter*)hFilter;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hFilter);
|
|
||||||
OIDN_LOCK(filter);
|
|
||||||
filter->set1i(name, value);
|
|
||||||
OIDN_CATCH(filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API int oidnGetFilter1i(OIDNFilter hFilter, const char* name)
|
|
||||||
{
|
|
||||||
Filter* filter = (Filter*)hFilter;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hFilter);
|
|
||||||
OIDN_LOCK(filter);
|
|
||||||
return filter->get1i(name);
|
|
||||||
OIDN_CATCH(filter)
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnSetFilter1f(OIDNFilter hFilter, const char* name, float value)
|
|
||||||
{
|
|
||||||
Filter* filter = (Filter*)hFilter;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hFilter);
|
|
||||||
OIDN_LOCK(filter);
|
|
||||||
filter->set1f(name, value);
|
|
||||||
OIDN_CATCH(filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API float oidnGetFilter1f(OIDNFilter hFilter, const char* name)
|
|
||||||
{
|
|
||||||
Filter* filter = (Filter*)hFilter;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hFilter);
|
|
||||||
OIDN_LOCK(filter);
|
|
||||||
return filter->get1f(name);
|
|
||||||
OIDN_CATCH(filter)
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter hFilter, OIDNProgressMonitorFunction func, void* userPtr)
|
|
||||||
{
|
|
||||||
Filter* filter = (Filter*)hFilter;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hFilter);
|
|
||||||
OIDN_LOCK(filter);
|
|
||||||
filter->setProgressMonitorFunction(func, userPtr);
|
|
||||||
OIDN_CATCH(filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnCommitFilter(OIDNFilter hFilter)
|
|
||||||
{
|
|
||||||
Filter* filter = (Filter*)hFilter;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hFilter);
|
|
||||||
OIDN_LOCK(filter);
|
|
||||||
filter->commit();
|
|
||||||
OIDN_CATCH(filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDN_API void oidnExecuteFilter(OIDNFilter hFilter)
|
|
||||||
{
|
|
||||||
Filter* filter = (Filter*)hFilter;
|
|
||||||
OIDN_TRY
|
|
||||||
checkHandle(hFilter);
|
|
||||||
OIDN_LOCK(filter);
|
|
||||||
filter->execute();
|
|
||||||
OIDN_CATCH(filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
535
thirdparty/oidn/core/autoencoder.cpp
vendored
535
thirdparty/oidn/core/autoencoder.cpp
vendored
@ -1,535 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#include "autoencoder.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// AutoencoderFilter
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
AutoencoderFilter::AutoencoderFilter(const Ref<Device>& device)
|
|
||||||
: Filter(device)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
void AutoencoderFilter::setImage(const std::string& name, const Image& data)
|
|
||||||
{
|
|
||||||
if (name == "color")
|
|
||||||
color = data;
|
|
||||||
else if (name == "albedo")
|
|
||||||
albedo = data;
|
|
||||||
else if (name == "normal")
|
|
||||||
normal = data;
|
|
||||||
else if (name == "output")
|
|
||||||
output = data;
|
|
||||||
|
|
||||||
dirty = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void AutoencoderFilter::set1i(const std::string& name, int value)
|
|
||||||
{
|
|
||||||
if (name == "hdr")
|
|
||||||
hdr = value;
|
|
||||||
else if (name == "srgb")
|
|
||||||
srgb = value;
|
|
||||||
else if (name == "maxMemoryMB")
|
|
||||||
maxMemoryMB = value;
|
|
||||||
|
|
||||||
dirty = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
int AutoencoderFilter::get1i(const std::string& name)
|
|
||||||
{
|
|
||||||
if (name == "hdr")
|
|
||||||
return hdr;
|
|
||||||
else if (name == "srgb")
|
|
||||||
return srgb;
|
|
||||||
else if (name == "maxMemoryMB")
|
|
||||||
return maxMemoryMB;
|
|
||||||
else if (name == "alignment")
|
|
||||||
return alignment;
|
|
||||||
else if (name == "overlap")
|
|
||||||
return overlap;
|
|
||||||
else
|
|
||||||
throw Exception(Error::InvalidArgument, "invalid parameter");
|
|
||||||
}
|
|
||||||
|
|
||||||
void AutoencoderFilter::set1f(const std::string& name, float value)
|
|
||||||
{
|
|
||||||
if (name == "hdrScale")
|
|
||||||
hdrScale = value;
|
|
||||||
|
|
||||||
dirty = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
float AutoencoderFilter::get1f(const std::string& name)
|
|
||||||
{
|
|
||||||
if (name == "hdrScale")
|
|
||||||
return hdrScale;
|
|
||||||
else
|
|
||||||
throw Exception(Error::InvalidArgument, "invalid parameter");
|
|
||||||
}
|
|
||||||
|
|
||||||
void AutoencoderFilter::commit()
|
|
||||||
{
|
|
||||||
if (!dirty)
|
|
||||||
return;
|
|
||||||
|
|
||||||
// -- GODOT start --
|
|
||||||
//device->executeTask([&]()
|
|
||||||
//{
|
|
||||||
// GODOT end --
|
|
||||||
|
|
||||||
if (mayiuse(avx512_common))
|
|
||||||
net = buildNet<16>();
|
|
||||||
else
|
|
||||||
net = buildNet<8>();
|
|
||||||
|
|
||||||
// GODOT start --
|
|
||||||
//});
|
|
||||||
// GODOT end --
|
|
||||||
|
|
||||||
dirty = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void AutoencoderFilter::execute()
|
|
||||||
{
|
|
||||||
if (dirty)
|
|
||||||
throw Exception(Error::InvalidOperation, "changes to the filter are not committed");
|
|
||||||
|
|
||||||
if (!net)
|
|
||||||
return;
|
|
||||||
// -- GODOT start --
|
|
||||||
//device->executeTask([&]()
|
|
||||||
//{
|
|
||||||
// -- GODOT end --
|
|
||||||
Progress progress;
|
|
||||||
progress.func = progressFunc;
|
|
||||||
progress.userPtr = progressUserPtr;
|
|
||||||
progress.taskCount = tileCountH * tileCountW;
|
|
||||||
|
|
||||||
// Iterate over the tiles
|
|
||||||
int tileIndex = 0;
|
|
||||||
|
|
||||||
for (int i = 0; i < tileCountH; ++i)
|
|
||||||
{
|
|
||||||
const int h = i * (tileH - 2*overlap); // input tile position (including overlap)
|
|
||||||
const int overlapBeginH = i > 0 ? overlap : 0; // overlap on the top
|
|
||||||
const int overlapEndH = i < tileCountH-1 ? overlap : 0; // overlap on the bottom
|
|
||||||
const int tileH1 = min(H - h, tileH); // input tile size (including overlap)
|
|
||||||
const int tileH2 = tileH1 - overlapBeginH - overlapEndH; // output tile size
|
|
||||||
const int alignOffsetH = tileH - roundUp(tileH1, alignment); // align to the bottom in the tile buffer
|
|
||||||
|
|
||||||
for (int j = 0; j < tileCountW; ++j)
|
|
||||||
{
|
|
||||||
const int w = j * (tileW - 2*overlap); // input tile position (including overlap)
|
|
||||||
const int overlapBeginW = j > 0 ? overlap : 0; // overlap on the left
|
|
||||||
const int overlapEndW = j < tileCountW-1 ? overlap : 0; // overlap on the right
|
|
||||||
const int tileW1 = min(W - w, tileW); // input tile size (including overlap)
|
|
||||||
const int tileW2 = tileW1 - overlapBeginW - overlapEndW; // output tile size
|
|
||||||
const int alignOffsetW = tileW - roundUp(tileW1, alignment); // align to the right in the tile buffer
|
|
||||||
|
|
||||||
// Set the input tile
|
|
||||||
inputReorder->setTile(h, w,
|
|
||||||
alignOffsetH, alignOffsetW,
|
|
||||||
tileH1, tileW1);
|
|
||||||
|
|
||||||
// Set the output tile
|
|
||||||
outputReorder->setTile(alignOffsetH + overlapBeginH, alignOffsetW + overlapBeginW,
|
|
||||||
h + overlapBeginH, w + overlapBeginW,
|
|
||||||
tileH2, tileW2);
|
|
||||||
|
|
||||||
//printf("Tile: %d %d -> %d %d\n", w+overlapBeginW, h+overlapBeginH, w+overlapBeginW+tileW2, h+overlapBeginH+tileH2);
|
|
||||||
|
|
||||||
// Denoise the tile
|
|
||||||
net->execute(progress, tileIndex);
|
|
||||||
|
|
||||||
// Next tile
|
|
||||||
tileIndex++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// -- GODOT start --
|
|
||||||
//});
|
|
||||||
// -- GODOT end --
|
|
||||||
}
|
|
||||||
|
|
||||||
void AutoencoderFilter::computeTileSize()
|
|
||||||
{
|
|
||||||
const int minTileSize = 3*overlap;
|
|
||||||
const int estimatedBytesPerPixel = mayiuse(avx512_common) ? estimatedBytesPerPixel16 : estimatedBytesPerPixel8;
|
|
||||||
const int64_t maxTilePixels = (int64_t(maxMemoryMB)*1024*1024 - estimatedBytesBase) / estimatedBytesPerPixel;
|
|
||||||
|
|
||||||
tileCountH = 1;
|
|
||||||
tileCountW = 1;
|
|
||||||
tileH = roundUp(H, alignment);
|
|
||||||
tileW = roundUp(W, alignment);
|
|
||||||
|
|
||||||
// Divide the image into tiles until the tile size gets below the threshold
|
|
||||||
while (int64_t(tileH) * tileW > maxTilePixels)
|
|
||||||
{
|
|
||||||
if (tileH > minTileSize && tileH > tileW)
|
|
||||||
{
|
|
||||||
tileCountH++;
|
|
||||||
tileH = max(roundUp(ceilDiv(H - 2*overlap, tileCountH), alignment) + 2*overlap, minTileSize);
|
|
||||||
}
|
|
||||||
else if (tileW > minTileSize)
|
|
||||||
{
|
|
||||||
tileCountW++;
|
|
||||||
tileW = max(roundUp(ceilDiv(W - 2*overlap, tileCountW), alignment) + 2*overlap, minTileSize);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute the final number of tiles
|
|
||||||
tileCountH = (H > tileH) ? ceilDiv(H - 2*overlap, tileH - 2*overlap) : 1;
|
|
||||||
tileCountW = (W > tileW) ? ceilDiv(W - 2*overlap, tileW - 2*overlap) : 1;
|
|
||||||
|
|
||||||
if (device->isVerbose(2))
|
|
||||||
{
|
|
||||||
std::cout << "Tile size : " << tileW << "x" << tileH << std::endl;
|
|
||||||
std::cout << "Tile count: " << tileCountW << "x" << tileCountH << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
std::shared_ptr<Executable> AutoencoderFilter::buildNet()
|
|
||||||
{
|
|
||||||
H = color.height;
|
|
||||||
W = color.width;
|
|
||||||
|
|
||||||
// Configure the network
|
|
||||||
int inputC;
|
|
||||||
void* weightPtr;
|
|
||||||
|
|
||||||
if (srgb && hdr)
|
|
||||||
throw Exception(Error::InvalidOperation, "srgb and hdr modes cannot be enabled at the same time");
|
|
||||||
|
|
||||||
if (color && !albedo && !normal && weightData.hdr)
|
|
||||||
{
|
|
||||||
inputC = 3;
|
|
||||||
weightPtr = hdr ? weightData.hdr : weightData.ldr;
|
|
||||||
}
|
|
||||||
else if (color && albedo && !normal && weightData.hdr_alb)
|
|
||||||
{
|
|
||||||
inputC = 6;
|
|
||||||
weightPtr = hdr ? weightData.hdr_alb : weightData.ldr_alb;
|
|
||||||
}
|
|
||||||
else if (color && albedo && normal && weightData.hdr_alb_nrm)
|
|
||||||
{
|
|
||||||
inputC = 9;
|
|
||||||
weightPtr = hdr ? weightData.hdr_alb_nrm : weightData.ldr_alb_nrm;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
throw Exception(Error::InvalidOperation, "unsupported combination of input features");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!output)
|
|
||||||
throw Exception(Error::InvalidOperation, "output image not specified");
|
|
||||||
|
|
||||||
if ((color.format != Format::Float3)
|
|
||||||
|| (albedo && albedo.format != Format::Float3)
|
|
||||||
|| (normal && normal.format != Format::Float3)
|
|
||||||
|| (output.format != Format::Float3))
|
|
||||||
throw Exception(Error::InvalidOperation, "unsupported image format");
|
|
||||||
|
|
||||||
if ((albedo && (albedo.width != W || albedo.height != H))
|
|
||||||
|| (normal && (normal.width != W || normal.height != H))
|
|
||||||
|| (output.width != W || output.height != H))
|
|
||||||
throw Exception(Error::InvalidOperation, "image size mismatch");
|
|
||||||
|
|
||||||
// Compute the tile size
|
|
||||||
computeTileSize();
|
|
||||||
|
|
||||||
// If the image size is zero, there is nothing else to do
|
|
||||||
if (H <= 0 || W <= 0)
|
|
||||||
return nullptr;
|
|
||||||
|
|
||||||
// Parse the weights
|
|
||||||
const auto weightMap = parseTensors(weightPtr);
|
|
||||||
|
|
||||||
// Create the network
|
|
||||||
std::shared_ptr<Network<K>> net = std::make_shared<Network<K>>(device, weightMap);
|
|
||||||
|
|
||||||
// Compute the tensor sizes
|
|
||||||
const auto inputDims = memory::dims({1, inputC, tileH, tileW});
|
|
||||||
const auto inputReorderDims = net->getInputReorderDims(inputDims, alignment); //-> concat0
|
|
||||||
|
|
||||||
const auto conv1Dims = net->getConvDims("conv1", inputReorderDims); //-> temp0
|
|
||||||
const auto conv1bDims = net->getConvDims("conv1b", conv1Dims); //-> temp1
|
|
||||||
const auto pool1Dims = net->getPoolDims(conv1bDims); //-> concat1
|
|
||||||
const auto conv2Dims = net->getConvDims("conv2", pool1Dims); //-> temp0
|
|
||||||
const auto pool2Dims = net->getPoolDims(conv2Dims); //-> concat2
|
|
||||||
const auto conv3Dims = net->getConvDims("conv3", pool2Dims); //-> temp0
|
|
||||||
const auto pool3Dims = net->getPoolDims(conv3Dims); //-> concat3
|
|
||||||
const auto conv4Dims = net->getConvDims("conv4", pool3Dims); //-> temp0
|
|
||||||
const auto pool4Dims = net->getPoolDims(conv4Dims); //-> concat4
|
|
||||||
const auto conv5Dims = net->getConvDims("conv5", pool4Dims); //-> temp0
|
|
||||||
const auto pool5Dims = net->getPoolDims(conv5Dims); //-> temp1
|
|
||||||
const auto upsample4Dims = net->getUpsampleDims(pool5Dims); //-> concat4
|
|
||||||
const auto concat4Dims = net->getConcatDims(upsample4Dims, pool4Dims);
|
|
||||||
const auto conv6Dims = net->getConvDims("conv6", concat4Dims); //-> temp0
|
|
||||||
const auto conv6bDims = net->getConvDims("conv6b", conv6Dims); //-> temp1
|
|
||||||
const auto upsample3Dims = net->getUpsampleDims(conv6bDims); //-> concat3
|
|
||||||
const auto concat3Dims = net->getConcatDims(upsample3Dims, pool3Dims);
|
|
||||||
const auto conv7Dims = net->getConvDims("conv7", concat3Dims); //-> temp0
|
|
||||||
const auto conv7bDims = net->getConvDims("conv7b", conv7Dims); //-> temp1
|
|
||||||
const auto upsample2Dims = net->getUpsampleDims(conv7bDims); //-> concat2
|
|
||||||
const auto concat2Dims = net->getConcatDims(upsample2Dims, pool2Dims);
|
|
||||||
const auto conv8Dims = net->getConvDims("conv8", concat2Dims); //-> temp0
|
|
||||||
const auto conv8bDims = net->getConvDims("conv8b", conv8Dims); //-> temp1
|
|
||||||
const auto upsample1Dims = net->getUpsampleDims(conv8bDims); //-> concat1
|
|
||||||
const auto concat1Dims = net->getConcatDims(upsample1Dims, pool1Dims);
|
|
||||||
const auto conv9Dims = net->getConvDims("conv9", concat1Dims); //-> temp0
|
|
||||||
const auto conv9bDims = net->getConvDims("conv9b", conv9Dims); //-> temp1
|
|
||||||
const auto upsample0Dims = net->getUpsampleDims(conv9bDims); //-> concat0
|
|
||||||
const auto concat0Dims = net->getConcatDims(upsample0Dims, inputReorderDims);
|
|
||||||
const auto conv10Dims = net->getConvDims("conv10", concat0Dims); //-> temp0
|
|
||||||
const auto conv10bDims = net->getConvDims("conv10b", conv10Dims); //-> temp1
|
|
||||||
const auto conv11Dims = net->getConvDims("conv11", conv10bDims); //-> temp0
|
|
||||||
|
|
||||||
const auto outputDims = memory::dims({1, 3, tileH, tileW});
|
|
||||||
|
|
||||||
// Allocate two temporary ping-pong buffers to decrease memory usage
|
|
||||||
const auto temp0Dims = getMaxTensorDims({
|
|
||||||
conv1Dims,
|
|
||||||
conv2Dims,
|
|
||||||
conv3Dims,
|
|
||||||
conv4Dims,
|
|
||||||
conv5Dims,
|
|
||||||
conv6Dims,
|
|
||||||
conv7Dims,
|
|
||||||
conv8Dims,
|
|
||||||
conv9Dims,
|
|
||||||
conv10Dims,
|
|
||||||
conv11Dims
|
|
||||||
});
|
|
||||||
|
|
||||||
const auto temp1Dims = getMaxTensorDims({
|
|
||||||
conv1bDims,
|
|
||||||
pool5Dims,
|
|
||||||
conv6bDims,
|
|
||||||
conv7bDims,
|
|
||||||
conv8bDims,
|
|
||||||
conv9bDims,
|
|
||||||
conv10bDims,
|
|
||||||
});
|
|
||||||
|
|
||||||
auto temp0 = net->allocTensor(temp0Dims);
|
|
||||||
auto temp1 = net->allocTensor(temp1Dims);
|
|
||||||
|
|
||||||
// Allocate enough memory to hold the concat outputs. Then use the first
|
|
||||||
// half to hold the previous conv output and the second half to hold the
|
|
||||||
// pool/orig image output. This works because everything is C dimension
|
|
||||||
// outermost, padded to K floats, and all the concats are on the C dimension.
|
|
||||||
auto concat0Dst = net->allocTensor(concat0Dims);
|
|
||||||
auto concat1Dst = net->allocTensor(concat1Dims);
|
|
||||||
auto concat2Dst = net->allocTensor(concat2Dims);
|
|
||||||
auto concat3Dst = net->allocTensor(concat3Dims);
|
|
||||||
auto concat4Dst = net->allocTensor(concat4Dims);
|
|
||||||
|
|
||||||
// Transfer function
|
|
||||||
std::shared_ptr<TransferFunction> transferFunc = makeTransferFunc();
|
|
||||||
|
|
||||||
// Autoexposure
|
|
||||||
if (auto tf = std::dynamic_pointer_cast<HDRTransferFunction>(transferFunc))
|
|
||||||
{
|
|
||||||
if (isnan(hdrScale))
|
|
||||||
net->addAutoexposure(color, tf);
|
|
||||||
else
|
|
||||||
tf->setExposure(hdrScale);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Input reorder
|
|
||||||
auto inputReorderDst = net->castTensor(inputReorderDims, concat0Dst, upsample0Dims);
|
|
||||||
inputReorder = net->addInputReorder(color, albedo, normal,
|
|
||||||
transferFunc,
|
|
||||||
alignment, inputReorderDst);
|
|
||||||
|
|
||||||
// conv1
|
|
||||||
auto conv1 = net->addConv("conv1", inputReorder->getDst(), temp0);
|
|
||||||
|
|
||||||
// conv1b
|
|
||||||
auto conv1b = net->addConv("conv1b", conv1->getDst(), temp1);
|
|
||||||
|
|
||||||
// pool1
|
|
||||||
// Adjust pointer for pool1 to eliminate concat1
|
|
||||||
auto pool1Dst = net->castTensor(pool1Dims, concat1Dst, upsample1Dims);
|
|
||||||
auto pool1 = net->addPool(conv1b->getDst(), pool1Dst);
|
|
||||||
|
|
||||||
// conv2
|
|
||||||
auto conv2 = net->addConv("conv2", pool1->getDst(), temp0);
|
|
||||||
|
|
||||||
// pool2
|
|
||||||
// Adjust pointer for pool2 to eliminate concat2
|
|
||||||
auto pool2Dst = net->castTensor(pool2Dims, concat2Dst, upsample2Dims);
|
|
||||||
auto pool2 = net->addPool(conv2->getDst(), pool2Dst);
|
|
||||||
|
|
||||||
// conv3
|
|
||||||
auto conv3 = net->addConv("conv3", pool2->getDst(), temp0);
|
|
||||||
|
|
||||||
// pool3
|
|
||||||
// Adjust pointer for pool3 to eliminate concat3
|
|
||||||
auto pool3Dst = net->castTensor(pool3Dims, concat3Dst, upsample3Dims);
|
|
||||||
auto pool3 = net->addPool(conv3->getDst(), pool3Dst);
|
|
||||||
|
|
||||||
// conv4
|
|
||||||
auto conv4 = net->addConv("conv4", pool3->getDst(), temp0);
|
|
||||||
|
|
||||||
// pool4
|
|
||||||
// Adjust pointer for pool4 to eliminate concat4
|
|
||||||
auto pool4Dst = net->castTensor(pool4Dims, concat4Dst, upsample4Dims);
|
|
||||||
auto pool4 = net->addPool(conv4->getDst(), pool4Dst);
|
|
||||||
|
|
||||||
// conv5
|
|
||||||
auto conv5 = net->addConv("conv5", pool4->getDst(), temp0);
|
|
||||||
|
|
||||||
// pool5
|
|
||||||
auto pool5 = net->addPool(conv5->getDst(), temp1);
|
|
||||||
|
|
||||||
// upsample4
|
|
||||||
auto upsample4Dst = net->castTensor(upsample4Dims, concat4Dst);
|
|
||||||
auto upsample4 = net->addUpsample(pool5->getDst(), upsample4Dst);
|
|
||||||
|
|
||||||
// conv6
|
|
||||||
auto conv6 = net->addConv("conv6", concat4Dst, temp0);
|
|
||||||
|
|
||||||
// conv6b
|
|
||||||
auto conv6b = net->addConv("conv6b", conv6->getDst(), temp1);
|
|
||||||
|
|
||||||
// upsample3
|
|
||||||
auto upsample3Dst = net->castTensor(upsample3Dims, concat3Dst);
|
|
||||||
auto upsample3 = net->addUpsample(conv6b->getDst(), upsample3Dst);
|
|
||||||
|
|
||||||
// conv7
|
|
||||||
auto conv7 = net->addConv("conv7", concat3Dst, temp0);
|
|
||||||
|
|
||||||
// conv7b
|
|
||||||
auto conv7b = net->addConv("conv7b", conv7->getDst(), temp1);
|
|
||||||
|
|
||||||
// upsample2
|
|
||||||
auto upsample2Dst = net->castTensor(upsample2Dims, concat2Dst);
|
|
||||||
auto upsample2 = net->addUpsample(conv7b->getDst(), upsample2Dst);
|
|
||||||
|
|
||||||
// conv8
|
|
||||||
auto conv8 = net->addConv("conv8", concat2Dst, temp0);
|
|
||||||
|
|
||||||
// conv8b
|
|
||||||
auto conv8b = net->addConv("conv8b", conv8->getDst(), temp1);
|
|
||||||
|
|
||||||
// upsample1
|
|
||||||
auto upsample1Dst = net->castTensor(upsample1Dims, concat1Dst);
|
|
||||||
auto upsample1 = net->addUpsample(conv8b->getDst(), upsample1Dst);
|
|
||||||
|
|
||||||
// conv9
|
|
||||||
auto conv9 = net->addConv("conv9", concat1Dst, temp0);
|
|
||||||
|
|
||||||
// conv9b
|
|
||||||
auto conv9b = net->addConv("conv9b", conv9->getDst(), temp1);
|
|
||||||
|
|
||||||
// upsample0
|
|
||||||
auto upsample0Dst = net->castTensor(upsample0Dims, concat0Dst);
|
|
||||||
auto upsample0 = net->addUpsample(conv9b->getDst(), upsample0Dst);
|
|
||||||
|
|
||||||
// conv10
|
|
||||||
auto conv10 = net->addConv("conv10", concat0Dst, temp0);
|
|
||||||
|
|
||||||
// conv10b
|
|
||||||
auto conv10b = net->addConv("conv10b", conv10->getDst(), temp1);
|
|
||||||
|
|
||||||
// conv11
|
|
||||||
auto conv11 = net->addConv("conv11", conv10b->getDst(), temp0, false /* no relu */);
|
|
||||||
|
|
||||||
// Output reorder
|
|
||||||
outputReorder = net->addOutputReorder(conv11->getDst(), transferFunc, output);
|
|
||||||
|
|
||||||
net->finalize();
|
|
||||||
return net;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<TransferFunction> AutoencoderFilter::makeTransferFunc()
|
|
||||||
{
|
|
||||||
if (hdr)
|
|
||||||
return std::make_shared<PQXTransferFunction>();
|
|
||||||
else if (srgb)
|
|
||||||
return std::make_shared<LinearTransferFunction>();
|
|
||||||
else
|
|
||||||
return std::make_shared<GammaTransferFunction>();
|
|
||||||
}
|
|
||||||
|
|
||||||
// -- GODOT start --
|
|
||||||
// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
|
|
||||||
#if 0
|
|
||||||
// -- GODOT end --
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// RTFilter
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
namespace weights
|
|
||||||
{
|
|
||||||
// LDR
|
|
||||||
extern unsigned char rt_ldr[]; // color
|
|
||||||
extern unsigned char rt_ldr_alb[]; // color, albedo
|
|
||||||
extern unsigned char rt_ldr_alb_nrm[]; // color, albedo, normal
|
|
||||||
|
|
||||||
// HDR
|
|
||||||
extern unsigned char rt_hdr[]; // color
|
|
||||||
extern unsigned char rt_hdr_alb[]; // color, albedo
|
|
||||||
extern unsigned char rt_hdr_alb_nrm[]; // color, albedo, normal
|
|
||||||
}
|
|
||||||
|
|
||||||
RTFilter::RTFilter(const Ref<Device>& device)
|
|
||||||
: AutoencoderFilter(device)
|
|
||||||
{
|
|
||||||
weightData.ldr = weights::rt_ldr;
|
|
||||||
weightData.ldr_alb = weights::rt_ldr_alb;
|
|
||||||
weightData.ldr_alb_nrm = weights::rt_ldr_alb_nrm;
|
|
||||||
weightData.hdr = weights::rt_hdr;
|
|
||||||
weightData.hdr_alb = weights::rt_hdr_alb;
|
|
||||||
weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm;
|
|
||||||
}
|
|
||||||
// -- GODOT start --
|
|
||||||
#endif
|
|
||||||
// -- GODOT end --
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// RTLightmapFilter
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
namespace weights
|
|
||||||
{
|
|
||||||
// HDR
|
|
||||||
extern unsigned char rtlightmap_hdr[]; // color
|
|
||||||
}
|
|
||||||
|
|
||||||
RTLightmapFilter::RTLightmapFilter(const Ref<Device>& device)
|
|
||||||
: AutoencoderFilter(device)
|
|
||||||
{
|
|
||||||
weightData.hdr = weights::rtlightmap_hdr;
|
|
||||||
|
|
||||||
hdr = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<TransferFunction> RTLightmapFilter::makeTransferFunc()
|
|
||||||
{
|
|
||||||
return std::make_shared<LogTransferFunction>();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
120
thirdparty/oidn/core/autoencoder.h
vendored
120
thirdparty/oidn/core/autoencoder.h
vendored
@ -1,120 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "filter.h"
|
|
||||||
#include "network.h"
|
|
||||||
#include "transfer_function.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// AutoencoderFilter - Direct-predicting autoencoder
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class AutoencoderFilter : public Filter
|
|
||||||
{
|
|
||||||
protected:
|
|
||||||
static constexpr int alignment = 32; // required spatial alignment in pixels (padding may be necessary)
|
|
||||||
static constexpr int receptiveField = 222; // receptive field in pixels
|
|
||||||
static constexpr int overlap = roundUp(receptiveField / 2, alignment); // required spatial overlap between tiles in pixels
|
|
||||||
|
|
||||||
static constexpr int estimatedBytesBase = 16*1024*1024; // estimated base memory usage
|
|
||||||
static constexpr int estimatedBytesPerPixel8 = 889; // estimated memory usage per pixel for K=8
|
|
||||||
static constexpr int estimatedBytesPerPixel16 = 2185; // estimated memory usage per pixel for K=16
|
|
||||||
|
|
||||||
Image color;
|
|
||||||
Image albedo;
|
|
||||||
Image normal;
|
|
||||||
Image output;
|
|
||||||
bool hdr = false;
|
|
||||||
float hdrScale = std::numeric_limits<float>::quiet_NaN();
|
|
||||||
bool srgb = false;
|
|
||||||
int maxMemoryMB = 6000; // approximate maximum memory usage in MBs
|
|
||||||
|
|
||||||
int H = 0; // image height
|
|
||||||
int W = 0; // image width
|
|
||||||
int tileH = 0; // tile height
|
|
||||||
int tileW = 0; // tile width
|
|
||||||
int tileCountH = 1; // number of tiles in H dimension
|
|
||||||
int tileCountW = 1; // number of tiles in W dimension
|
|
||||||
|
|
||||||
std::shared_ptr<Executable> net;
|
|
||||||
std::shared_ptr<Node> inputReorder;
|
|
||||||
std::shared_ptr<Node> outputReorder;
|
|
||||||
|
|
||||||
struct
|
|
||||||
{
|
|
||||||
void* ldr = nullptr;
|
|
||||||
void* ldr_alb = nullptr;
|
|
||||||
void* ldr_alb_nrm = nullptr;
|
|
||||||
void* hdr = nullptr;
|
|
||||||
void* hdr_alb = nullptr;
|
|
||||||
void* hdr_alb_nrm = nullptr;
|
|
||||||
} weightData;
|
|
||||||
|
|
||||||
explicit AutoencoderFilter(const Ref<Device>& device);
|
|
||||||
virtual std::shared_ptr<TransferFunction> makeTransferFunc();
|
|
||||||
|
|
||||||
public:
|
|
||||||
void setImage(const std::string& name, const Image& data) override;
|
|
||||||
void set1i(const std::string& name, int value) override;
|
|
||||||
int get1i(const std::string& name) override;
|
|
||||||
void set1f(const std::string& name, float value) override;
|
|
||||||
float get1f(const std::string& name) override;
|
|
||||||
|
|
||||||
void commit() override;
|
|
||||||
void execute() override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
void computeTileSize();
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
std::shared_ptr<Executable> buildNet();
|
|
||||||
|
|
||||||
bool isCommitted() const { return bool(net); }
|
|
||||||
};
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// RTFilter - Generic ray tracing denoiser
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
// -- GODOT start --
|
|
||||||
// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
|
|
||||||
#if 0
|
|
||||||
// -- GODOT end --
|
|
||||||
class RTFilter : public AutoencoderFilter
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
explicit RTFilter(const Ref<Device>& device);
|
|
||||||
};
|
|
||||||
// -- GODOT start --
|
|
||||||
#endif
|
|
||||||
// -- GODOT end --
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// RTLightmapFilter - Ray traced lightmap denoiser
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class RTLightmapFilter : public AutoencoderFilter
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
explicit RTLightmapFilter(const Ref<Device>& device);
|
|
||||||
std::shared_ptr<TransferFunction> makeTransferFunc() override;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
75
thirdparty/oidn/core/buffer.h
vendored
75
thirdparty/oidn/core/buffer.h
vendored
@ -1,75 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
#include "device.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
class Device;
|
|
||||||
|
|
||||||
// Buffer which may or may not own its data
|
|
||||||
class Buffer : public RefCount
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
char* ptr;
|
|
||||||
size_t byteSize;
|
|
||||||
bool shared;
|
|
||||||
Ref<Device> device;
|
|
||||||
|
|
||||||
public:
|
|
||||||
__forceinline Buffer(const Ref<Device>& device, size_t size)
|
|
||||||
: ptr((char*)alignedMalloc(size, 64)),
|
|
||||||
byteSize(size),
|
|
||||||
shared(false),
|
|
||||||
device(device) {}
|
|
||||||
|
|
||||||
__forceinline Buffer(const Ref<Device>& device, void* data, size_t size)
|
|
||||||
: ptr((char*)data),
|
|
||||||
byteSize(size),
|
|
||||||
shared(true),
|
|
||||||
device(device)
|
|
||||||
{
|
|
||||||
if (data == nullptr)
|
|
||||||
throw Exception(Error::InvalidArgument, "buffer pointer null");
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline ~Buffer()
|
|
||||||
{
|
|
||||||
if (!shared)
|
|
||||||
alignedFree(ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline char* data() { return ptr; }
|
|
||||||
__forceinline const char* data() const { return ptr; }
|
|
||||||
__forceinline size_t size() const { return byteSize; }
|
|
||||||
|
|
||||||
void* map(size_t offset, size_t size)
|
|
||||||
{
|
|
||||||
if (offset + size > byteSize)
|
|
||||||
throw Exception(Error::InvalidArgument, "buffer region out of range");
|
|
||||||
|
|
||||||
return ptr + offset;
|
|
||||||
}
|
|
||||||
|
|
||||||
void unmap(void* mappedPtr) {}
|
|
||||||
|
|
||||||
Device* getDevice() { return device.get(); }
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
136
thirdparty/oidn/core/common.h
vendored
136
thirdparty/oidn/core/common.h
vendored
@ -1,136 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "common/platform.h"
|
|
||||||
|
|
||||||
#include "mkl-dnn/include/mkldnn.hpp"
|
|
||||||
#include "mkl-dnn/include/mkldnn_debug.h"
|
|
||||||
#include "mkl-dnn/src/common/mkldnn_thread.hpp"
|
|
||||||
#include "mkl-dnn/src/common/type_helpers.hpp"
|
|
||||||
#include "mkl-dnn/src/cpu/jit_generator.hpp"
|
|
||||||
|
|
||||||
#include "common/ref.h"
|
|
||||||
#include "common/exception.h"
|
|
||||||
#include "common/thread.h"
|
|
||||||
// -- GODOT start --
|
|
||||||
//#include "common/tasking.h"
|
|
||||||
// -- GODOT end --
|
|
||||||
#include "math.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
using namespace mkldnn;
|
|
||||||
using namespace mkldnn::impl::cpu;
|
|
||||||
using mkldnn::impl::parallel_nd;
|
|
||||||
using mkldnn::impl::memory_desc_matches_tag;
|
|
||||||
|
|
||||||
|
|
||||||
inline size_t getFormatBytes(Format format)
|
|
||||||
{
|
|
||||||
switch (format)
|
|
||||||
{
|
|
||||||
case Format::Undefined: return 1;
|
|
||||||
case Format::Float: return sizeof(float);
|
|
||||||
case Format::Float2: return sizeof(float)*2;
|
|
||||||
case Format::Float3: return sizeof(float)*3;
|
|
||||||
case Format::Float4: return sizeof(float)*4;
|
|
||||||
}
|
|
||||||
assert(0);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
inline memory::dims getTensorDims(const std::shared_ptr<memory>& mem)
|
|
||||||
{
|
|
||||||
const mkldnn_memory_desc_t& desc = mem->get_desc().data;
|
|
||||||
return memory::dims(&desc.dims[0], &desc.dims[desc.ndims]);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline memory::data_type getTensorType(const std::shared_ptr<memory>& mem)
|
|
||||||
{
|
|
||||||
const mkldnn_memory_desc_t& desc = mem->get_desc().data;
|
|
||||||
return memory::data_type(desc.data_type);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the number of values in a tensor
|
|
||||||
inline size_t getTensorSize(const memory::dims& dims)
|
|
||||||
{
|
|
||||||
size_t res = 1;
|
|
||||||
for (int i = 0; i < (int)dims.size(); ++i)
|
|
||||||
res *= dims[i];
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline memory::dims getMaxTensorDims(const std::vector<memory::dims>& dims)
|
|
||||||
{
|
|
||||||
memory::dims result;
|
|
||||||
size_t maxSize = 0;
|
|
||||||
|
|
||||||
for (const auto& d : dims)
|
|
||||||
{
|
|
||||||
const size_t size = getTensorSize(d);
|
|
||||||
if (size > maxSize)
|
|
||||||
{
|
|
||||||
result = d;
|
|
||||||
maxSize = size;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline size_t getTensorSize(const std::shared_ptr<memory>& mem)
|
|
||||||
{
|
|
||||||
return getTensorSize(getTensorDims(mem));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
inline int getPadded(int dim)
|
|
||||||
{
|
|
||||||
return (dim + (K-1)) & ~(K-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
inline memory::dims getPadded_nchw(const memory::dims& dims)
|
|
||||||
{
|
|
||||||
assert(dims.size() == 4);
|
|
||||||
memory::dims padDims = dims;
|
|
||||||
padDims[1] = getPadded<K>(dims[1]); // pad C
|
|
||||||
return padDims;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
struct BlockedFormat;
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct BlockedFormat<8>
|
|
||||||
{
|
|
||||||
static constexpr memory::format_tag nChwKc = memory::format_tag::nChw8c;
|
|
||||||
static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw8i8o;
|
|
||||||
};
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct BlockedFormat<16>
|
|
||||||
{
|
|
||||||
static constexpr memory::format_tag nChwKc = memory::format_tag::nChw16c;
|
|
||||||
static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw16i16o;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
238
thirdparty/oidn/core/device.cpp
vendored
238
thirdparty/oidn/core/device.cpp
vendored
@ -1,238 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#include "device.h"
|
|
||||||
#include "autoencoder.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
thread_local Device::ErrorState Device::globalError;
|
|
||||||
|
|
||||||
Device::Device()
|
|
||||||
{
|
|
||||||
if (!mayiuse(sse41))
|
|
||||||
throw Exception(Error::UnsupportedHardware, "SSE4.1 support is required at minimum");
|
|
||||||
}
|
|
||||||
|
|
||||||
Device::~Device()
|
|
||||||
{
|
|
||||||
// -- GODOT start --
|
|
||||||
//observer.reset();
|
|
||||||
// -- GODOT end --
|
|
||||||
}
|
|
||||||
|
|
||||||
void Device::setError(Device* device, Error code, const std::string& message)
|
|
||||||
{
|
|
||||||
// Update the stored error only if the previous error was queried
|
|
||||||
if (device)
|
|
||||||
{
|
|
||||||
ErrorState& curError = device->error.get();
|
|
||||||
|
|
||||||
if (curError.code == Error::None)
|
|
||||||
{
|
|
||||||
curError.code = code;
|
|
||||||
curError.message = message;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print the error message in verbose mode
|
|
||||||
if (device->isVerbose())
|
|
||||||
std::cerr << "Error: " << message << std::endl;
|
|
||||||
|
|
||||||
// Call the error callback function
|
|
||||||
ErrorFunction errorFunc;
|
|
||||||
void* errorUserPtr;
|
|
||||||
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(device->mutex);
|
|
||||||
errorFunc = device->errorFunc;
|
|
||||||
errorUserPtr = device->errorUserPtr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (errorFunc)
|
|
||||||
errorFunc(errorUserPtr, code, (code == Error::None) ? nullptr : message.c_str());
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
if (globalError.code == Error::None)
|
|
||||||
{
|
|
||||||
globalError.code = code;
|
|
||||||
globalError.message = message;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Error Device::getError(Device* device, const char** outMessage)
|
|
||||||
{
|
|
||||||
// Return and clear the stored error code, but keep the error message so pointers to it will
|
|
||||||
// remain valid until the next getError call
|
|
||||||
if (device)
|
|
||||||
{
|
|
||||||
ErrorState& curError = device->error.get();
|
|
||||||
const Error code = curError.code;
|
|
||||||
if (outMessage)
|
|
||||||
*outMessage = (code == Error::None) ? nullptr : curError.message.c_str();
|
|
||||||
curError.code = Error::None;
|
|
||||||
return code;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
const Error code = globalError.code;
|
|
||||||
if (outMessage)
|
|
||||||
*outMessage = (code == Error::None) ? nullptr : globalError.message.c_str();
|
|
||||||
globalError.code = Error::None;
|
|
||||||
return code;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Device::setErrorFunction(ErrorFunction func, void* userPtr)
|
|
||||||
{
|
|
||||||
errorFunc = func;
|
|
||||||
errorUserPtr = userPtr;
|
|
||||||
}
|
|
||||||
|
|
||||||
int Device::get1i(const std::string& name)
|
|
||||||
{
|
|
||||||
if (name == "numThreads")
|
|
||||||
return numThreads;
|
|
||||||
else if (name == "setAffinity")
|
|
||||||
return setAffinity;
|
|
||||||
else if (name == "verbose")
|
|
||||||
return verbose;
|
|
||||||
else if (name == "version")
|
|
||||||
return OIDN_VERSION;
|
|
||||||
else if (name == "versionMajor")
|
|
||||||
return OIDN_VERSION_MAJOR;
|
|
||||||
else if (name == "versionMinor")
|
|
||||||
return OIDN_VERSION_MINOR;
|
|
||||||
else if (name == "versionPatch")
|
|
||||||
return OIDN_VERSION_PATCH;
|
|
||||||
else
|
|
||||||
throw Exception(Error::InvalidArgument, "invalid parameter");
|
|
||||||
}
|
|
||||||
|
|
||||||
void Device::set1i(const std::string& name, int value)
|
|
||||||
{
|
|
||||||
if (name == "numThreads")
|
|
||||||
numThreads = value;
|
|
||||||
else if (name == "setAffinity")
|
|
||||||
setAffinity = value;
|
|
||||||
else if (name == "verbose")
|
|
||||||
{
|
|
||||||
verbose = value;
|
|
||||||
error.verbose = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
dirty = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Device::commit()
|
|
||||||
{
|
|
||||||
if (isCommitted())
|
|
||||||
throw Exception(Error::InvalidOperation, "device can be committed only once");
|
|
||||||
|
|
||||||
// -- GODOT start --
|
|
||||||
#if 0
|
|
||||||
// -- GODOT end --
|
|
||||||
// Get the optimal thread affinities
|
|
||||||
if (setAffinity)
|
|
||||||
{
|
|
||||||
affinity = std::make_shared<ThreadAffinity>(1, verbose); // one thread per core
|
|
||||||
if (affinity->getNumThreads() == 0)
|
|
||||||
affinity.reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the task arena
|
|
||||||
const int maxNumThreads = affinity ? affinity->getNumThreads() : tbb::this_task_arena::max_concurrency();
|
|
||||||
numThreads = (numThreads > 0) ? min(numThreads, maxNumThreads) : maxNumThreads;
|
|
||||||
arena = std::make_shared<tbb::task_arena>(numThreads);
|
|
||||||
|
|
||||||
// Automatically set the thread affinities
|
|
||||||
if (affinity)
|
|
||||||
observer = std::make_shared<PinningObserver>(affinity, *arena);
|
|
||||||
// -- GODOT start --
|
|
||||||
#endif
|
|
||||||
numThreads = 1;
|
|
||||||
// -- GODOT end --
|
|
||||||
dirty = false;
|
|
||||||
|
|
||||||
if (isVerbose())
|
|
||||||
print();
|
|
||||||
}
|
|
||||||
|
|
||||||
void Device::checkCommitted()
|
|
||||||
{
|
|
||||||
if (dirty)
|
|
||||||
throw Exception(Error::InvalidOperation, "changes to the device are not committed");
|
|
||||||
}
|
|
||||||
|
|
||||||
Ref<Buffer> Device::newBuffer(size_t byteSize)
|
|
||||||
{
|
|
||||||
checkCommitted();
|
|
||||||
return makeRef<Buffer>(Ref<Device>(this), byteSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ref<Buffer> Device::newBuffer(void* ptr, size_t byteSize)
|
|
||||||
{
|
|
||||||
checkCommitted();
|
|
||||||
return makeRef<Buffer>(Ref<Device>(this), ptr, byteSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ref<Filter> Device::newFilter(const std::string& type)
|
|
||||||
{
|
|
||||||
checkCommitted();
|
|
||||||
|
|
||||||
if (isVerbose())
|
|
||||||
std::cout << "Filter: " << type << std::endl;
|
|
||||||
|
|
||||||
Ref<Filter> filter;
|
|
||||||
|
|
||||||
// -- GODOT start --
|
|
||||||
// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
|
|
||||||
#if 0
|
|
||||||
// -- GODOT end --
|
|
||||||
if (type == "RT")
|
|
||||||
filter = makeRef<RTFilter>(Ref<Device>(this));
|
|
||||||
// -- GODOT start --
|
|
||||||
// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
|
|
||||||
#endif
|
|
||||||
if (type == "RTLightmap")
|
|
||||||
// -- GODOT end --
|
|
||||||
filter = makeRef<RTLightmapFilter>(Ref<Device>(this));
|
|
||||||
else
|
|
||||||
throw Exception(Error::InvalidArgument, "unknown filter type");
|
|
||||||
|
|
||||||
return filter;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Device::print()
|
|
||||||
{
|
|
||||||
std::cout << std::endl;
|
|
||||||
|
|
||||||
std::cout << "Intel(R) Open Image Denoise " << OIDN_VERSION_STRING << std::endl;
|
|
||||||
std::cout << " Compiler: " << getCompilerName() << std::endl;
|
|
||||||
std::cout << " Build : " << getBuildName() << std::endl;
|
|
||||||
std::cout << " Platform: " << getPlatformName() << std::endl;
|
|
||||||
|
|
||||||
// -- GODOT start --
|
|
||||||
// std::cout << " Tasking :";
|
|
||||||
// std::cout << " TBB" << TBB_VERSION_MAJOR << "." << TBB_VERSION_MINOR;
|
|
||||||
// std::cout << " TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << tbb::TBB_runtime_interface_version();
|
|
||||||
// std::cout << std::endl;
|
|
||||||
// -- GODOT end --
|
|
||||||
std::cout << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
102
thirdparty/oidn/core/device.h
vendored
102
thirdparty/oidn/core/device.h
vendored
@ -1,102 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
class Buffer;
|
|
||||||
class Filter;
|
|
||||||
|
|
||||||
class Device : public RefCount, public Verbose
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
// Thread-safety
|
|
||||||
std::mutex mutex;
|
|
||||||
|
|
||||||
// Error handling
|
|
||||||
struct ErrorState
|
|
||||||
{
|
|
||||||
Error code = Error::None;
|
|
||||||
std::string message;
|
|
||||||
};
|
|
||||||
|
|
||||||
static thread_local ErrorState globalError;
|
|
||||||
ThreadLocal<ErrorState> error;
|
|
||||||
ErrorFunction errorFunc = nullptr;
|
|
||||||
void* errorUserPtr = nullptr;
|
|
||||||
|
|
||||||
// -- GODOT start --
|
|
||||||
// // Tasking
|
|
||||||
// std::shared_ptr<tbb::task_arena> arena;
|
|
||||||
// std::shared_ptr<PinningObserver> observer;
|
|
||||||
// std::shared_ptr<ThreadAffinity> affinity;
|
|
||||||
// -- GODOT end --
|
|
||||||
|
|
||||||
// Parameters
|
|
||||||
int numThreads = 0; // autodetect by default
|
|
||||||
bool setAffinity = true;
|
|
||||||
|
|
||||||
bool dirty = true;
|
|
||||||
|
|
||||||
public:
|
|
||||||
Device();
|
|
||||||
~Device();
|
|
||||||
|
|
||||||
static void setError(Device* device, Error code, const std::string& message);
|
|
||||||
static Error getError(Device* device, const char** outMessage);
|
|
||||||
|
|
||||||
void setErrorFunction(ErrorFunction func, void* userPtr);
|
|
||||||
|
|
||||||
int get1i(const std::string& name);
|
|
||||||
void set1i(const std::string& name, int value);
|
|
||||||
|
|
||||||
void commit();
|
|
||||||
|
|
||||||
// -- GODOT start --
|
|
||||||
// template<typename F>
|
|
||||||
// void executeTask(F& f)
|
|
||||||
// {
|
|
||||||
// arena->execute(f);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// template<typename F>
|
|
||||||
// void executeTask(const F& f)
|
|
||||||
// {
|
|
||||||
// arena->execute(f);
|
|
||||||
// }
|
|
||||||
// -- GODOT end --
|
|
||||||
|
|
||||||
Ref<Buffer> newBuffer(size_t byteSize);
|
|
||||||
Ref<Buffer> newBuffer(void* ptr, size_t byteSize);
|
|
||||||
Ref<Filter> newFilter(const std::string& type);
|
|
||||||
|
|
||||||
__forceinline Device* getDevice() { return this; }
|
|
||||||
__forceinline std::mutex& getMutex() { return mutex; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
// -- GODOT start --
|
|
||||||
//bool isCommitted() const { return bool(arena); }
|
|
||||||
bool isCommitted() const { return false; }
|
|
||||||
// -- GODOT end --
|
|
||||||
void checkCommitted();
|
|
||||||
|
|
||||||
void print();
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
27
thirdparty/oidn/core/filter.cpp
vendored
27
thirdparty/oidn/core/filter.cpp
vendored
@ -1,27 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#include "filter.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
void Filter::setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr)
|
|
||||||
{
|
|
||||||
progressFunc = func;
|
|
||||||
progressUserPtr = userPtr;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
52
thirdparty/oidn/core/filter.h
vendored
52
thirdparty/oidn/core/filter.h
vendored
@ -1,52 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
#include "device.h"
|
|
||||||
#include "image.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
class Filter : public RefCount
|
|
||||||
{
|
|
||||||
protected:
|
|
||||||
Ref<Device> device;
|
|
||||||
|
|
||||||
ProgressMonitorFunction progressFunc = nullptr;
|
|
||||||
void* progressUserPtr = nullptr;
|
|
||||||
|
|
||||||
bool dirty = true;
|
|
||||||
|
|
||||||
public:
|
|
||||||
explicit Filter(const Ref<Device>& device) : device(device) {}
|
|
||||||
|
|
||||||
virtual void setImage(const std::string& name, const Image& data) = 0;
|
|
||||||
virtual void set1i(const std::string& name, int value) = 0;
|
|
||||||
virtual int get1i(const std::string& name) = 0;
|
|
||||||
virtual void set1f(const std::string& name, float value) = 0;
|
|
||||||
virtual float get1f(const std::string& name) = 0;
|
|
||||||
|
|
||||||
void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr);
|
|
||||||
|
|
||||||
virtual void commit() = 0;
|
|
||||||
virtual void execute() = 0;
|
|
||||||
|
|
||||||
Device* getDevice() { return device.get(); }
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
111
thirdparty/oidn/core/image.h
vendored
111
thirdparty/oidn/core/image.h
vendored
@ -1,111 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
#include "buffer.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
struct Image
|
|
||||||
{
|
|
||||||
static constexpr int maxSize = 65536;
|
|
||||||
|
|
||||||
char* ptr; // pointer to the first pixel
|
|
||||||
int width; // width in number of pixels
|
|
||||||
int height; // height in number of pixels
|
|
||||||
size_t bytePixelStride; // pixel stride in number of *bytes*
|
|
||||||
size_t rowStride; // row stride in number of *pixel strides*
|
|
||||||
Format format; // pixel format
|
|
||||||
Ref<Buffer> buffer; // buffer containing the image data
|
|
||||||
|
|
||||||
Image() : ptr(nullptr), width(0), height(0), bytePixelStride(0), rowStride(0), format(Format::Undefined) {}
|
|
||||||
|
|
||||||
Image(void* ptr, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride)
|
|
||||||
{
|
|
||||||
if (ptr == nullptr)
|
|
||||||
throw Exception(Error::InvalidArgument, "buffer pointer null");
|
|
||||||
|
|
||||||
init((char*)ptr + byteOffset, format, width, height, inBytePixelStride, inByteRowStride);
|
|
||||||
}
|
|
||||||
|
|
||||||
Image(const Ref<Buffer>& buffer, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride)
|
|
||||||
{
|
|
||||||
init(buffer->data() + byteOffset, format, width, height, inBytePixelStride, inByteRowStride);
|
|
||||||
|
|
||||||
if (byteOffset + height * rowStride * bytePixelStride > buffer->size())
|
|
||||||
throw Exception(Error::InvalidArgument, "buffer region out of range");
|
|
||||||
}
|
|
||||||
|
|
||||||
void init(char* ptr, Format format, int width, int height, size_t inBytePixelStride, size_t inByteRowStride)
|
|
||||||
{
|
|
||||||
assert(width >= 0);
|
|
||||||
assert(height >= 0);
|
|
||||||
if (width > maxSize || height > maxSize)
|
|
||||||
throw Exception(Error::InvalidArgument, "image size too large");
|
|
||||||
|
|
||||||
this->ptr = ptr;
|
|
||||||
this->width = width;
|
|
||||||
this->height = height;
|
|
||||||
|
|
||||||
const size_t pixelSize = getFormatBytes(format);
|
|
||||||
if (inBytePixelStride != 0)
|
|
||||||
{
|
|
||||||
if (inBytePixelStride < pixelSize)
|
|
||||||
throw Exception(Error::InvalidArgument, "pixel stride smaller than pixel size");
|
|
||||||
|
|
||||||
this->bytePixelStride = inBytePixelStride;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
this->bytePixelStride = pixelSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (inByteRowStride != 0)
|
|
||||||
{
|
|
||||||
if (inByteRowStride < width * this->bytePixelStride)
|
|
||||||
throw Exception(Error::InvalidArgument, "row stride smaller than width * pixel stride");
|
|
||||||
if (inByteRowStride % this->bytePixelStride != 0)
|
|
||||||
throw Exception(Error::InvalidArgument, "row stride not integer multiple of pixel stride");
|
|
||||||
|
|
||||||
this->rowStride = inByteRowStride / this->bytePixelStride;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
this->rowStride = width;
|
|
||||||
}
|
|
||||||
|
|
||||||
this->format = format;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline char* get(int y, int x)
|
|
||||||
{
|
|
||||||
return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline const char* get(int y, int x) const
|
|
||||||
{
|
|
||||||
return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride);
|
|
||||||
}
|
|
||||||
|
|
||||||
operator bool() const
|
|
||||||
{
|
|
||||||
return ptr != nullptr;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
232
thirdparty/oidn/core/input_reorder.h
vendored
232
thirdparty/oidn/core/input_reorder.h
vendored
@ -1,232 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "node.h"
|
|
||||||
#include "image.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
// Input reorder node
|
|
||||||
template<int K, class TransferFunction>
|
|
||||||
class InputReorderNode : public Node
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
// Source
|
|
||||||
Image color;
|
|
||||||
Image albedo;
|
|
||||||
Image normal;
|
|
||||||
|
|
||||||
// Destination
|
|
||||||
std::shared_ptr<memory> dst;
|
|
||||||
float* dstPtr;
|
|
||||||
int C2;
|
|
||||||
int H2;
|
|
||||||
int W2;
|
|
||||||
|
|
||||||
// Tile
|
|
||||||
int h1Begin;
|
|
||||||
int w1Begin;
|
|
||||||
int h2Begin;
|
|
||||||
int w2Begin;
|
|
||||||
int H;
|
|
||||||
int W;
|
|
||||||
|
|
||||||
std::shared_ptr<TransferFunction> transferFunc;
|
|
||||||
|
|
||||||
public:
|
|
||||||
InputReorderNode(const Image& color,
|
|
||||||
const Image& albedo,
|
|
||||||
const Image& normal,
|
|
||||||
const std::shared_ptr<memory>& dst,
|
|
||||||
const std::shared_ptr<TransferFunction>& transferFunc)
|
|
||||||
: color(color), albedo(albedo), normal(normal),
|
|
||||||
dst(dst),
|
|
||||||
h1Begin(0), w1Begin(0),
|
|
||||||
H(color.height), W(color.width),
|
|
||||||
transferFunc(transferFunc)
|
|
||||||
{
|
|
||||||
const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
|
|
||||||
assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
|
|
||||||
assert(dstDesc.ndims == 4);
|
|
||||||
assert(dstDesc.data_type == memory::data_type::f32);
|
|
||||||
assert(dstDesc.dims[0] == 1);
|
|
||||||
//assert(dstDesc.dims[1] >= getPadded<K>(C1));
|
|
||||||
|
|
||||||
dstPtr = (float*)dst->get_data_handle();
|
|
||||||
C2 = dstDesc.dims[1];
|
|
||||||
H2 = dstDesc.dims[2];
|
|
||||||
W2 = dstDesc.dims[3];
|
|
||||||
}
|
|
||||||
|
|
||||||
void setTile(int h1, int w1, int h2, int w2, int H, int W) override
|
|
||||||
{
|
|
||||||
h1Begin = h1;
|
|
||||||
w1Begin = w1;
|
|
||||||
h2Begin = h2;
|
|
||||||
w2Begin = w2;
|
|
||||||
this->H = H;
|
|
||||||
this->W = W;
|
|
||||||
}
|
|
||||||
|
|
||||||
void execute(stream& sm) override
|
|
||||||
{
|
|
||||||
assert(H + h1Begin <= color.height);
|
|
||||||
assert(W + w1Begin <= color.width);
|
|
||||||
assert(H + h2Begin <= H2);
|
|
||||||
assert(W + w2Begin <= W2);
|
|
||||||
|
|
||||||
parallel_nd(H2, [&](int h2)
|
|
||||||
{
|
|
||||||
const int h = h2 - h2Begin;
|
|
||||||
|
|
||||||
if (h >= 0 && h < H)
|
|
||||||
{
|
|
||||||
const int h1 = h + h1Begin;
|
|
||||||
|
|
||||||
// Zero pad
|
|
||||||
for (int w2 = 0; w2 < w2Begin; ++w2)
|
|
||||||
{
|
|
||||||
int c = 0;
|
|
||||||
while (c < C2)
|
|
||||||
store(h2, w2, c, 0.f);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reorder
|
|
||||||
for (int w = 0; w < W; ++w)
|
|
||||||
{
|
|
||||||
const int w1 = w + w1Begin;
|
|
||||||
const int w2 = w + w2Begin;
|
|
||||||
|
|
||||||
int c = 0;
|
|
||||||
storeColor(h2, w2, c, (float*)color.get(h1, w1));
|
|
||||||
if (albedo)
|
|
||||||
storeAlbedo(h2, w2, c, (float*)albedo.get(h1, w1));
|
|
||||||
if (normal)
|
|
||||||
storeNormal(h2, w2, c, (float*)normal.get(h1, w1));
|
|
||||||
while (c < C2)
|
|
||||||
store(h2, w2, c, 0.f);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Zero pad
|
|
||||||
for (int w2 = W + w2Begin; w2 < W2; ++w2)
|
|
||||||
{
|
|
||||||
int c = 0;
|
|
||||||
while (c < C2)
|
|
||||||
store(h2, w2, c, 0.f);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// Zero pad
|
|
||||||
for (int w2 = 0; w2 < W2; ++w2)
|
|
||||||
{
|
|
||||||
int c = 0;
|
|
||||||
while (c < C2)
|
|
||||||
store(h2, w2, c, 0.f);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<memory> getDst() const override { return dst; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
// Stores a single value
|
|
||||||
__forceinline void store(int h, int w, int& c, float value)
|
|
||||||
{
|
|
||||||
// Destination is in nChwKc format
|
|
||||||
float* dst_c = dstPtr + (H2*W2*K*(c/K)) + h*W2*K + w*K + (c%K);
|
|
||||||
*dst_c = value;
|
|
||||||
c++;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stores a color
|
|
||||||
__forceinline void storeColor(int h, int w, int& c, const float* values)
|
|
||||||
{
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 3; ++i)
|
|
||||||
{
|
|
||||||
// Load the value
|
|
||||||
float x = values[i];
|
|
||||||
|
|
||||||
// Sanitize the value
|
|
||||||
x = maxSafe(x, 0.f);
|
|
||||||
|
|
||||||
// Apply the transfer function
|
|
||||||
x = transferFunc->forward(x);
|
|
||||||
|
|
||||||
// Store the value
|
|
||||||
store(h, w, c, x);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stores an albedo
|
|
||||||
__forceinline void storeAlbedo(int h, int w, int& c, const float* values)
|
|
||||||
{
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 3; ++i)
|
|
||||||
{
|
|
||||||
// Load the value
|
|
||||||
float x = values[i];
|
|
||||||
|
|
||||||
// Sanitize the value
|
|
||||||
x = clampSafe(x, 0.f, 1.f);
|
|
||||||
|
|
||||||
// Store the value
|
|
||||||
store(h, w, c, x);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stores a normal
|
|
||||||
__forceinline void storeNormal(int h, int w, int& c, const float* values)
|
|
||||||
{
|
|
||||||
// Load the normal
|
|
||||||
float x = values[0];
|
|
||||||
float y = values[1];
|
|
||||||
float z = values[2];
|
|
||||||
|
|
||||||
// Compute the length of the normal
|
|
||||||
const float lengthSqr = sqr(x) + sqr(y) + sqr(z);
|
|
||||||
|
|
||||||
// Normalize the normal and transform it to [0..1]
|
|
||||||
if (isfinite(lengthSqr))
|
|
||||||
{
|
|
||||||
const float invLength = (lengthSqr > minVectorLengthSqr) ? rsqrt(lengthSqr) : 1.f;
|
|
||||||
|
|
||||||
const float scale = invLength * 0.5f;
|
|
||||||
const float offset = 0.5f;
|
|
||||||
|
|
||||||
x = x * scale + offset;
|
|
||||||
y = y * scale + offset;
|
|
||||||
z = z * scale + offset;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
x = 0.f;
|
|
||||||
y = 0.f;
|
|
||||||
z = 0.f;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store the normal
|
|
||||||
store(h, w, c, x);
|
|
||||||
store(h, w, c, y);
|
|
||||||
store(h, w, c, z);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
78
thirdparty/oidn/core/math.h
vendored
78
thirdparty/oidn/core/math.h
vendored
@ -1,78 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "common/platform.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
constexpr float minVectorLength = 1e-10f;
|
|
||||||
constexpr float minVectorLengthSqr = minVectorLength * minVectorLength;
|
|
||||||
|
|
||||||
using std::log;
|
|
||||||
using std::log2;
|
|
||||||
using std::exp;
|
|
||||||
using std::exp2;
|
|
||||||
using std::pow;
|
|
||||||
using std::isfinite;
|
|
||||||
using std::isnan;
|
|
||||||
|
|
||||||
__forceinline float sqr(float x)
|
|
||||||
{
|
|
||||||
return x * x;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline float rcp(float x)
|
|
||||||
{
|
|
||||||
__m128 r = _mm_rcp_ss(_mm_set_ss(x));
|
|
||||||
return _mm_cvtss_f32(_mm_sub_ss(_mm_add_ss(r, r), _mm_mul_ss(_mm_mul_ss(r, r), _mm_set_ss(x))));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline float rsqrt(float x)
|
|
||||||
{
|
|
||||||
__m128 r = _mm_rsqrt_ss(_mm_set_ss(x));
|
|
||||||
return _mm_cvtss_f32(_mm_add_ss(_mm_mul_ss(_mm_set_ss(1.5f), r),
|
|
||||||
_mm_mul_ss(_mm_mul_ss(_mm_mul_ss(_mm_set_ss(x), _mm_set_ss(-0.5f)), r), _mm_mul_ss(r, r))));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline float maxSafe(float value, float minValue)
|
|
||||||
{
|
|
||||||
return isfinite(value) ? max(value, minValue) : minValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline float clampSafe(float value, float minValue, float maxValue)
|
|
||||||
{
|
|
||||||
return isfinite(value) ? clamp(value, minValue, maxValue) : minValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns ceil(a / b) for non-negative integers
|
|
||||||
template<class Int>
|
|
||||||
__forceinline constexpr Int ceilDiv(Int a, Int b)
|
|
||||||
{
|
|
||||||
//assert(a >= 0);
|
|
||||||
//assert(b > 0);
|
|
||||||
return (a + b - 1) / b;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a rounded up to multiple of b
|
|
||||||
template<class Int>
|
|
||||||
__forceinline constexpr Int roundUp(Int a, Int b)
|
|
||||||
{
|
|
||||||
return ceilDiv(a, b) * b;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
436
thirdparty/oidn/core/network.cpp
vendored
436
thirdparty/oidn/core/network.cpp
vendored
@ -1,436 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#include "upsample.h"
|
|
||||||
#include "weights_reorder.h"
|
|
||||||
#include "network.h"
|
|
||||||
// -- GODOT start --
|
|
||||||
#include <cstring>
|
|
||||||
// -- GODOT end --
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
Network<K>::Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap)
|
|
||||||
: device(device),
|
|
||||||
eng(engine::cpu, 0),
|
|
||||||
sm(eng),
|
|
||||||
weightMap(weightMap)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
void Network<K>::execute(const Progress& progress, int taskIndex)
|
|
||||||
{
|
|
||||||
if (progress.func)
|
|
||||||
{
|
|
||||||
const double value = double(taskIndex) / double(progress.taskCount);
|
|
||||||
if (!progress.func(progress.userPtr, value))
|
|
||||||
throw Exception(Error::Cancelled, "execution was cancelled");
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < nodes.size(); ++i)
|
|
||||||
{
|
|
||||||
nodes[i]->execute(sm);
|
|
||||||
|
|
||||||
if (progress.func)
|
|
||||||
{
|
|
||||||
const double value = (double(taskIndex) + double(i+1) / double(nodes.size())) / double(progress.taskCount);
|
|
||||||
if (!progress.func(progress.userPtr, value))
|
|
||||||
throw Exception(Error::Cancelled, "execution was cancelled");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
std::shared_ptr<memory> Network<K>::allocTensor(const memory::dims& dims,
|
|
||||||
memory::format_tag format,
|
|
||||||
void* data)
|
|
||||||
{
|
|
||||||
if (format == memory::format_tag::any)
|
|
||||||
{
|
|
||||||
if (dims.size() == 4)
|
|
||||||
format = BlockedFormat<K>::nChwKc;
|
|
||||||
else if (dims.size() == 1)
|
|
||||||
format = memory::format_tag::x;
|
|
||||||
else
|
|
||||||
assert(0);
|
|
||||||
}
|
|
||||||
memory::desc desc(dims, memory::data_type::f32, format);
|
|
||||||
if (data == nullptr)
|
|
||||||
{
|
|
||||||
const size_t bytes = getTensorSize(dims) * sizeof(float);
|
|
||||||
if (format == BlockedFormat<K>::nChwKc)
|
|
||||||
activationAllocBytes += bytes;
|
|
||||||
totalAllocBytes += bytes;
|
|
||||||
|
|
||||||
return std::make_shared<memory>(desc, eng);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
return std::make_shared<memory>(desc, eng, data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
std::shared_ptr<memory> Network<K>::castTensor(const memory::dims& dims,
|
|
||||||
const std::shared_ptr<memory>& src,
|
|
||||||
size_t srcOffset,
|
|
||||||
memory::format_tag format)
|
|
||||||
{
|
|
||||||
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
|
|
||||||
MAYBE_UNUSED(srcDesc);
|
|
||||||
assert(srcDesc.data_type == memory::data_type::f32);
|
|
||||||
assert(getTensorSize(src) >= srcOffset + getTensorSize(dims));
|
|
||||||
|
|
||||||
if (format == memory::format_tag::any)
|
|
||||||
{
|
|
||||||
if (dims.size() == 4)
|
|
||||||
format = BlockedFormat<K>::nChwKc;
|
|
||||||
else if (dims.size() == 1)
|
|
||||||
format = memory::format_tag::x;
|
|
||||||
else
|
|
||||||
assert(0);
|
|
||||||
}
|
|
||||||
memory::desc desc(dims, memory::data_type::f32, format);
|
|
||||||
float* srcPtr = (float*)src->get_data_handle() + srcOffset;
|
|
||||||
return std::make_shared<memory>(desc, eng, srcPtr);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
std::shared_ptr<memory> Network<K>::castTensor(const memory::dims& dims,
|
|
||||||
const std::shared_ptr<memory>& src,
|
|
||||||
const memory::dims& srcOffset)
|
|
||||||
{
|
|
||||||
return castTensor(dims, src, getTensorSize(srcOffset));
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
void Network<K>::zeroTensor(const std::shared_ptr<memory>& dst)
|
|
||||||
{
|
|
||||||
assert(getTensorType(dst) == memory::data_type::f32);
|
|
||||||
memset(dst->get_data_handle(), 0, getTensorSize(dst)*sizeof(float));
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
memory::dims Network<K>::getInputReorderDims(const memory::dims& srcDims, int alignment)
|
|
||||||
{
|
|
||||||
memory::dims dstDims = srcDims;
|
|
||||||
dstDims[1] = getPadded<K>(srcDims[1]); // round up C
|
|
||||||
dstDims[2] = roundUp(srcDims[2], memory::dim(alignment)); // round up H
|
|
||||||
dstDims[3] = roundUp(srcDims[3], memory::dim(alignment)); // round up W
|
|
||||||
return dstDims;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
std::shared_ptr<Node> Network<K>::addInputReorder(const Image& color,
|
|
||||||
const Image& albedo,
|
|
||||||
const Image& normal,
|
|
||||||
const std::shared_ptr<TransferFunction>& transferFunc,
|
|
||||||
int alignment,
|
|
||||||
const std::shared_ptr<memory>& userDst)
|
|
||||||
{
|
|
||||||
assert(color);
|
|
||||||
int inputC = 3;
|
|
||||||
if (albedo) inputC += 3;
|
|
||||||
if (normal) inputC += 3;
|
|
||||||
|
|
||||||
memory::dims srcDims = {1, inputC, color.height, color.width};
|
|
||||||
memory::dims dstDims = getInputReorderDims(srcDims, alignment);
|
|
||||||
|
|
||||||
// Allocate padded memory
|
|
||||||
auto dst = userDst;
|
|
||||||
if (!dst)
|
|
||||||
dst = allocTensor(dstDims);
|
|
||||||
|
|
||||||
// Push node
|
|
||||||
std::shared_ptr<Node> node;
|
|
||||||
|
|
||||||
if (auto tf = std::dynamic_pointer_cast<LinearTransferFunction>(transferFunc))
|
|
||||||
node = std::make_shared<InputReorderNode<K, LinearTransferFunction>>(color, albedo, normal, dst, tf);
|
|
||||||
else if (auto tf = std::dynamic_pointer_cast<GammaTransferFunction>(transferFunc))
|
|
||||||
node = std::make_shared<InputReorderNode<K, GammaTransferFunction>>(color, albedo, normal, dst, tf);
|
|
||||||
else if (auto tf = std::dynamic_pointer_cast<LogTransferFunction>(transferFunc))
|
|
||||||
node = std::make_shared<InputReorderNode<K, LogTransferFunction>>(color, albedo, normal, dst, tf);
|
|
||||||
else if (auto tf = std::dynamic_pointer_cast<PQXTransferFunction>(transferFunc))
|
|
||||||
node = std::make_shared<InputReorderNode<K, PQXTransferFunction>>(color, albedo, normal, dst, tf);
|
|
||||||
else
|
|
||||||
assert(0);
|
|
||||||
|
|
||||||
nodes.push_back(node);
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
std::shared_ptr<Node> Network<K>::addOutputReorder(const std::shared_ptr<memory>& src,
|
|
||||||
const std::shared_ptr<TransferFunction>& transferFunc,
|
|
||||||
const Image& output)
|
|
||||||
{
|
|
||||||
memory::dims srcDims = getTensorDims(src);
|
|
||||||
assert(srcDims[1] == K);
|
|
||||||
|
|
||||||
// Push node
|
|
||||||
std::shared_ptr<Node> node;
|
|
||||||
|
|
||||||
if (auto tf = std::dynamic_pointer_cast<LinearTransferFunction>(transferFunc))
|
|
||||||
node = std::make_shared<OutputReorderNode<K, LinearTransferFunction>>(src, output, tf);
|
|
||||||
else if (auto tf = std::dynamic_pointer_cast<GammaTransferFunction>(transferFunc))
|
|
||||||
node = std::make_shared<OutputReorderNode<K, GammaTransferFunction>>(src, output, tf);
|
|
||||||
else if (auto tf = std::dynamic_pointer_cast<LogTransferFunction>(transferFunc))
|
|
||||||
node = std::make_shared<OutputReorderNode<K, LogTransferFunction>>(src, output, tf);
|
|
||||||
else if (auto tf = std::dynamic_pointer_cast<PQXTransferFunction>(transferFunc))
|
|
||||||
node = std::make_shared<OutputReorderNode<K, PQXTransferFunction>>(src, output, tf);
|
|
||||||
else
|
|
||||||
assert(0);
|
|
||||||
|
|
||||||
nodes.push_back(node);
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
memory::dims Network<K>::getConvDims(const std::string& name, const memory::dims& srcDims)
|
|
||||||
{
|
|
||||||
auto b = weightMap[name + "/b"];
|
|
||||||
memory::dims dstDims = srcDims;
|
|
||||||
dstDims[1] = getPadded<K>(b.dims[0]); // dstDims[C] = getPadded(OC)
|
|
||||||
return dstDims;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
std::shared_ptr<Node> Network<K>::addConv(const std::string& name,
|
|
||||||
const std::shared_ptr<memory>& src,
|
|
||||||
const std::shared_ptr<memory>& userDst,
|
|
||||||
bool relu)
|
|
||||||
{
|
|
||||||
const memory::dims strides = {1, 1};
|
|
||||||
const memory::dims padding = {1, 1};
|
|
||||||
|
|
||||||
memory::dims srcDims = getTensorDims(src);
|
|
||||||
|
|
||||||
// Get the weights
|
|
||||||
const auto& W = weightMap[name + "/W"];
|
|
||||||
if (W.ndims() != 4 || W.format != "oihw")
|
|
||||||
throw Exception(Error::InvalidOperation, "invalid convolution weights");
|
|
||||||
memory::dims weightsDims = W.dims;
|
|
||||||
auto userWeights = allocTensor(weightsDims, memory::format_tag::oihw, W.data);
|
|
||||||
|
|
||||||
// Pad the weights
|
|
||||||
memory::dims weightsPadDims = weightsDims;
|
|
||||||
weightsPadDims[1] = getPadded<K>(weightsDims[1]); // IC
|
|
||||||
weightsPadDims[0] = getPadded<K>(weightsDims[0]); // OC
|
|
||||||
assert(srcDims[1] == weightsPadDims[1]); // srcDims[C] == weightsPadDims[IC]
|
|
||||||
auto weightsPad = allocTensor(weightsPadDims, memory::format_tag::oihw);
|
|
||||||
WeightsReorderNode<K>(userWeights, weightsPad).execute(sm);
|
|
||||||
|
|
||||||
// Get the biases
|
|
||||||
const auto& b = weightMap[name + "/b"];
|
|
||||||
if (b.ndims() != 1)
|
|
||||||
throw Exception(Error::InvalidOperation, "invalid convolution biases");
|
|
||||||
memory::dims biasDims = b.dims;
|
|
||||||
|
|
||||||
// Copy/pad the biases
|
|
||||||
memory::dims biasPadDims = {getPadded<K>(biasDims[0])};
|
|
||||||
auto bias = allocTensor(biasPadDims);
|
|
||||||
if (biasDims[0] != biasPadDims[0])
|
|
||||||
memset(bias->get_data_handle(), 0, biasPadDims[0]*sizeof(float));
|
|
||||||
memcpy(bias->get_data_handle(), b.data, biasDims[0]*sizeof(float));
|
|
||||||
|
|
||||||
// Allocate memory for destination
|
|
||||||
memory::dims dstDims = srcDims;
|
|
||||||
dstDims[1] = weightsPadDims[0]; // dstDims[C] = weightsPadDims[OC]
|
|
||||||
|
|
||||||
std::shared_ptr<memory> dst;
|
|
||||||
if (!userDst)
|
|
||||||
dst = allocTensor(dstDims);
|
|
||||||
else if (getTensorDims(userDst) == dstDims)
|
|
||||||
dst = userDst;
|
|
||||||
else
|
|
||||||
dst = castTensor(dstDims, userDst);
|
|
||||||
|
|
||||||
// Create a convolution
|
|
||||||
// Let the convolution primitive choose the weights format
|
|
||||||
auto weightsDesc = memory::desc({ weightsPadDims }, memory::data_type::f32, memory::format_tag::any);
|
|
||||||
|
|
||||||
auto convAlgo = (K == 16) ? convolution_winograd : convolution_direct;
|
|
||||||
auto convDesc = convolution_forward::desc(
|
|
||||||
prop_kind::forward_inference, convAlgo,
|
|
||||||
src->get_desc(),
|
|
||||||
weightsDesc,
|
|
||||||
bias->get_desc(),
|
|
||||||
dst->get_desc(),
|
|
||||||
strides, padding, padding, padding_kind::zero);
|
|
||||||
|
|
||||||
// Incorporate relu
|
|
||||||
mkldnn::primitive_attr convAttr;
|
|
||||||
if (relu)
|
|
||||||
{
|
|
||||||
mkldnn::post_ops ops;
|
|
||||||
ops.append_eltwise(
|
|
||||||
1.f, // scale factor, not used
|
|
||||||
algorithm::eltwise_relu,
|
|
||||||
0.f, // max with
|
|
||||||
0.f // unused
|
|
||||||
);
|
|
||||||
convAttr.set_post_ops(ops);
|
|
||||||
}
|
|
||||||
convAttr.set_scratchpad_mode(scratchpad_mode_user);
|
|
||||||
|
|
||||||
auto convPrimDesc = convolution_forward::primitive_desc(convDesc, convAttr, eng);
|
|
||||||
|
|
||||||
// Reorder the weights to the final format, if necessary
|
|
||||||
auto weights = weightsPad;
|
|
||||||
if (convPrimDesc.weights_desc() != weightsPad->get_desc())
|
|
||||||
{
|
|
||||||
weights = std::make_shared<memory>(convPrimDesc.weights_desc(), eng);
|
|
||||||
ReorderNode(weightsPad, weights).execute(sm);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create convolution node and add it to the net
|
|
||||||
auto node = std::make_shared<ConvNode>(convPrimDesc, src, weights, bias, dst);
|
|
||||||
nodes.push_back(node);
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
memory::dims Network<K>::getPoolDims(const memory::dims& srcDims)
|
|
||||||
{
|
|
||||||
memory::dims dstDims = srcDims;
|
|
||||||
dstDims[2] /= 2; // H/2
|
|
||||||
dstDims[3] /= 2; // W/2
|
|
||||||
return dstDims;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
std::shared_ptr<Node> Network<K>::addPool(const std::shared_ptr<memory>& src,
|
|
||||||
const std::shared_ptr<memory>& userDst)
|
|
||||||
{
|
|
||||||
const memory::dims kernel = {2, 2};
|
|
||||||
const memory::dims strides = {2, 2};
|
|
||||||
const memory::dims padding = {0, 0};
|
|
||||||
|
|
||||||
memory::dims srcDims = getTensorDims(src);
|
|
||||||
memory::dims dstDims = getPoolDims(srcDims);
|
|
||||||
|
|
||||||
std::shared_ptr<memory> dst;
|
|
||||||
if (!userDst)
|
|
||||||
dst = allocTensor(dstDims);
|
|
||||||
else if (getTensorDims(userDst) == dstDims)
|
|
||||||
dst = userDst;
|
|
||||||
else
|
|
||||||
dst = castTensor(dstDims, userDst);
|
|
||||||
|
|
||||||
auto poolDesc = pooling_forward::desc(
|
|
||||||
prop_kind::forward_inference, pooling_max,
|
|
||||||
src->get_desc(),
|
|
||||||
dst->get_desc(),
|
|
||||||
strides, kernel, padding, padding, padding_kind::zero);
|
|
||||||
|
|
||||||
mkldnn::primitive_attr poolAttr;
|
|
||||||
poolAttr.set_scratchpad_mode(scratchpad_mode_user);
|
|
||||||
|
|
||||||
auto poolPrimDesc = pooling_forward::primitive_desc(poolDesc, poolAttr, eng);
|
|
||||||
|
|
||||||
auto node = std::make_shared<PoolNode>(poolPrimDesc, src, dst);
|
|
||||||
nodes.push_back(node);
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
memory::dims Network<K>::getUpsampleDims(const memory::dims& srcDims)
|
|
||||||
{
|
|
||||||
memory::dims dstDims = srcDims;
|
|
||||||
dstDims[2] *= 2; // H*2
|
|
||||||
dstDims[3] *= 2; // W*2
|
|
||||||
return dstDims;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
std::shared_ptr<Node> Network<K>::addUpsample(const std::shared_ptr<memory>& src,
|
|
||||||
const std::shared_ptr<memory>& userDst)
|
|
||||||
{
|
|
||||||
memory::dims srcDims = getTensorDims(src);
|
|
||||||
memory::dims dstDims = getUpsampleDims(srcDims);
|
|
||||||
|
|
||||||
std::shared_ptr<memory> dst;
|
|
||||||
if (!userDst)
|
|
||||||
dst = allocTensor(dstDims);
|
|
||||||
else if (getTensorDims(userDst) == dstDims)
|
|
||||||
dst = userDst;
|
|
||||||
else
|
|
||||||
dst = castTensor(dstDims, userDst);
|
|
||||||
|
|
||||||
// Create upsampling node and add it to net
|
|
||||||
auto node = std::make_shared<UpsampleNode<K>>(src, dst);
|
|
||||||
nodes.push_back(node);
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
memory::dims Network<K>::getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims)
|
|
||||||
{
|
|
||||||
assert(src1Dims[0] == src2Dims[0]); // N
|
|
||||||
assert(src1Dims[2] == src2Dims[2]); // H
|
|
||||||
assert(src1Dims[3] == src2Dims[3]); // W
|
|
||||||
|
|
||||||
memory::dims dstDims = src1Dims;
|
|
||||||
dstDims[1] += src2Dims[1]; // C
|
|
||||||
return dstDims;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
std::shared_ptr<Node> Network<K>::addAutoexposure(const Image& color,
|
|
||||||
const std::shared_ptr<HDRTransferFunction>& transferFunc)
|
|
||||||
{
|
|
||||||
auto node = std::make_shared<AutoexposureNode>(color, transferFunc);
|
|
||||||
nodes.push_back(node);
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int K>
|
|
||||||
void Network<K>::finalize()
|
|
||||||
{
|
|
||||||
// Compute the size of the scratchpad
|
|
||||||
size_t scratchpadSize = 0;
|
|
||||||
for (const auto& node : nodes)
|
|
||||||
scratchpadSize = max(scratchpadSize, node->getScratchpadSize());
|
|
||||||
|
|
||||||
// Allocate the scratchpad
|
|
||||||
memory::dims scratchpadDims = { memory::dim(scratchpadSize) };
|
|
||||||
memory::desc scratchpadDesc(scratchpadDims, memory::data_type::u8, memory::format_tag::x);
|
|
||||||
auto scratchpad = std::make_shared<memory>(scratchpadDesc, eng);
|
|
||||||
activationAllocBytes += scratchpadSize;
|
|
||||||
totalAllocBytes += scratchpadSize;
|
|
||||||
|
|
||||||
// Set the scratchpad for the nodes
|
|
||||||
for (auto& node : nodes)
|
|
||||||
node->setScratchpad(scratchpad);
|
|
||||||
|
|
||||||
// Free the weights
|
|
||||||
weightMap.clear();
|
|
||||||
|
|
||||||
// Print statistics
|
|
||||||
if (device->isVerbose(2))
|
|
||||||
{
|
|
||||||
std::cout << "Activation bytes: " << activationAllocBytes << std::endl;
|
|
||||||
std::cout << "Scratchpad bytes: " << scratchpadSize << std::endl;
|
|
||||||
std::cout << "Total bytes : " << totalAllocBytes << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template class Network<8>;
|
|
||||||
template class Network<16>;
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
112
thirdparty/oidn/core/network.h
vendored
112
thirdparty/oidn/core/network.h
vendored
@ -1,112 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#include "common/tensor.h"
|
|
||||||
#include "image.h"
|
|
||||||
#include "node.h"
|
|
||||||
#include "input_reorder.h"
|
|
||||||
#include "output_reorder.h"
|
|
||||||
#include "transfer_function.h"
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
// Progress state
|
|
||||||
struct Progress
|
|
||||||
{
|
|
||||||
ProgressMonitorFunction func;
|
|
||||||
void* userPtr;
|
|
||||||
int taskCount;
|
|
||||||
};
|
|
||||||
|
|
||||||
class Executable
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
virtual ~Executable() {}
|
|
||||||
virtual void execute(const Progress& progress, int taskIndex) = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
template<int K>
|
|
||||||
class Network : public Executable
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap);
|
|
||||||
|
|
||||||
void execute(const Progress& progress, int taskIndex) override;
|
|
||||||
|
|
||||||
std::shared_ptr<memory> allocTensor(const memory::dims& dims,
|
|
||||||
memory::format_tag format = memory::format_tag::any,
|
|
||||||
void* data = nullptr);
|
|
||||||
|
|
||||||
std::shared_ptr<memory> castTensor(const memory::dims& dims,
|
|
||||||
const std::shared_ptr<memory>& src,
|
|
||||||
size_t srcOffset = 0,
|
|
||||||
memory::format_tag format = memory::format_tag::any);
|
|
||||||
|
|
||||||
std::shared_ptr<memory> castTensor(const memory::dims& dims,
|
|
||||||
const std::shared_ptr<memory>& src,
|
|
||||||
const memory::dims& srcOffset);
|
|
||||||
|
|
||||||
void zeroTensor(const std::shared_ptr<memory>& dst);
|
|
||||||
|
|
||||||
memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment);
|
|
||||||
|
|
||||||
std::shared_ptr<Node> addInputReorder(const Image& color,
|
|
||||||
const Image& albedo,
|
|
||||||
const Image& normal,
|
|
||||||
const std::shared_ptr<TransferFunction>& transferFunc,
|
|
||||||
int alignment,
|
|
||||||
const std::shared_ptr<memory>& userDst = nullptr);
|
|
||||||
|
|
||||||
std::shared_ptr<Node> addOutputReorder(const std::shared_ptr<memory>& src,
|
|
||||||
const std::shared_ptr<TransferFunction>& transferFunc,
|
|
||||||
const Image& output);
|
|
||||||
|
|
||||||
memory::dims getConvDims(const std::string& name, const memory::dims& srcDims);
|
|
||||||
std::shared_ptr<Node> addConv(const std::string& name,
|
|
||||||
const std::shared_ptr<memory>& src,
|
|
||||||
const std::shared_ptr<memory>& userDst = nullptr,
|
|
||||||
bool relu = true);
|
|
||||||
|
|
||||||
memory::dims getPoolDims(const memory::dims& srcDims);
|
|
||||||
std::shared_ptr<Node> addPool(const std::shared_ptr<memory>& src,
|
|
||||||
const std::shared_ptr<memory>& userDst = nullptr);
|
|
||||||
|
|
||||||
memory::dims getUpsampleDims(const memory::dims& srcDims);
|
|
||||||
std::shared_ptr<Node> addUpsample(const std::shared_ptr<memory>& src,
|
|
||||||
const std::shared_ptr<memory>& userDst = nullptr);
|
|
||||||
|
|
||||||
memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims);
|
|
||||||
|
|
||||||
std::shared_ptr<Node> addAutoexposure(const Image& color,
|
|
||||||
const std::shared_ptr<HDRTransferFunction>& transferFunc);
|
|
||||||
|
|
||||||
void finalize();
|
|
||||||
|
|
||||||
private:
|
|
||||||
Ref<Device> device;
|
|
||||||
engine eng;
|
|
||||||
stream sm;
|
|
||||||
std::vector<std::shared_ptr<Node>> nodes;
|
|
||||||
std::map<std::string, Tensor> weightMap;
|
|
||||||
|
|
||||||
// Memory allocation statistics
|
|
||||||
size_t activationAllocBytes = 0; // number of allocated activation bytes
|
|
||||||
size_t totalAllocBytes = 0; // total number of allocated bytes
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
142
thirdparty/oidn/core/node.h
vendored
142
thirdparty/oidn/core/node.h
vendored
@ -1,142 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
class Node
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
virtual ~Node() = default;
|
|
||||||
|
|
||||||
virtual void execute(stream& sm) = 0;
|
|
||||||
|
|
||||||
virtual std::shared_ptr<memory> getDst() const { return nullptr; }
|
|
||||||
|
|
||||||
virtual size_t getScratchpadSize() const { return 0; }
|
|
||||||
virtual void setScratchpad(const std::shared_ptr<memory>& mem) {}
|
|
||||||
|
|
||||||
virtual void setTile(int h1, int w1, int h2, int w2, int H, int W)
|
|
||||||
{
|
|
||||||
assert(0); // not supported
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Node wrapping an MKL-DNN primitive
|
|
||||||
class MklNode : public Node
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
primitive prim;
|
|
||||||
std::unordered_map<int, memory> args;
|
|
||||||
std::shared_ptr<memory> scratchpad;
|
|
||||||
|
|
||||||
public:
|
|
||||||
MklNode(const primitive& prim, const std::unordered_map<int, memory>& args)
|
|
||||||
: prim(prim),
|
|
||||||
args(args)
|
|
||||||
{}
|
|
||||||
|
|
||||||
size_t getScratchpadSize() const override
|
|
||||||
{
|
|
||||||
const auto primDesc = prim.get_primitive_desc();
|
|
||||||
const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0);
|
|
||||||
if (scratchpadDesc == nullptr)
|
|
||||||
return 0;
|
|
||||||
return mkldnn_memory_desc_get_size(scratchpadDesc);
|
|
||||||
}
|
|
||||||
|
|
||||||
void setScratchpad(const std::shared_ptr<memory>& mem) override
|
|
||||||
{
|
|
||||||
scratchpad = mem;
|
|
||||||
args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad));
|
|
||||||
}
|
|
||||||
|
|
||||||
void execute(stream& sm) override
|
|
||||||
{
|
|
||||||
prim.execute(sm, args);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Convolution node
|
|
||||||
class ConvNode : public MklNode
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
std::shared_ptr<memory> src;
|
|
||||||
std::shared_ptr<memory> weights;
|
|
||||||
std::shared_ptr<memory> bias;
|
|
||||||
std::shared_ptr<memory> dst;
|
|
||||||
|
|
||||||
public:
|
|
||||||
ConvNode(const convolution_forward::primitive_desc& desc,
|
|
||||||
const std::shared_ptr<memory>& src,
|
|
||||||
const std::shared_ptr<memory>& weights,
|
|
||||||
const std::shared_ptr<memory>& bias,
|
|
||||||
const std::shared_ptr<memory>& dst)
|
|
||||||
: MklNode(convolution_forward(desc),
|
|
||||||
{ { MKLDNN_ARG_SRC, *src },
|
|
||||||
{ MKLDNN_ARG_WEIGHTS, *weights },
|
|
||||||
{ MKLDNN_ARG_BIAS, *bias },
|
|
||||||
{ MKLDNN_ARG_DST, *dst } }),
|
|
||||||
src(src), weights(weights), bias(bias), dst(dst)
|
|
||||||
{}
|
|
||||||
|
|
||||||
std::shared_ptr<memory> getDst() const override { return dst; }
|
|
||||||
};
|
|
||||||
|
|
||||||
// Pooling node
|
|
||||||
class PoolNode : public MklNode
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
std::shared_ptr<memory> src;
|
|
||||||
std::shared_ptr<memory> dst;
|
|
||||||
|
|
||||||
public:
|
|
||||||
PoolNode(const pooling_forward::primitive_desc& desc,
|
|
||||||
const std::shared_ptr<memory>& src,
|
|
||||||
const std::shared_ptr<memory>& dst)
|
|
||||||
: MklNode(pooling_forward(desc),
|
|
||||||
{ { MKLDNN_ARG_SRC, *src },
|
|
||||||
{ MKLDNN_ARG_DST, *dst } }),
|
|
||||||
src(src), dst(dst)
|
|
||||||
{}
|
|
||||||
|
|
||||||
std::shared_ptr<memory> getDst() const override { return dst; }
|
|
||||||
};
|
|
||||||
|
|
||||||
// Reorder node
|
|
||||||
class ReorderNode : public MklNode
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
std::shared_ptr<memory> src;
|
|
||||||
std::shared_ptr<memory> dst;
|
|
||||||
|
|
||||||
public:
|
|
||||||
ReorderNode(const std::shared_ptr<memory>& src,
|
|
||||||
const std::shared_ptr<memory>& dst)
|
|
||||||
: MklNode(reorder(reorder::primitive_desc(*src, *dst)),
|
|
||||||
{ { MKLDNN_ARG_SRC, *src },
|
|
||||||
{ MKLDNN_ARG_DST, *dst } }),
|
|
||||||
src(src), dst(dst)
|
|
||||||
{}
|
|
||||||
|
|
||||||
std::shared_ptr<memory> getDst() const override { return dst; }
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
126
thirdparty/oidn/core/output_reorder.h
vendored
126
thirdparty/oidn/core/output_reorder.h
vendored
@ -1,126 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "node.h"
|
|
||||||
#include "image.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
// Output reorder node
|
|
||||||
template<int K, class TransferFunction>
|
|
||||||
class OutputReorderNode : public Node
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
// Source
|
|
||||||
std::shared_ptr<memory> src;
|
|
||||||
const float* srcPtr;
|
|
||||||
int H1;
|
|
||||||
int W1;
|
|
||||||
|
|
||||||
// Destination
|
|
||||||
Image output;
|
|
||||||
|
|
||||||
// Tile
|
|
||||||
int h1Begin;
|
|
||||||
int w1Begin;
|
|
||||||
int h2Begin;
|
|
||||||
int w2Begin;
|
|
||||||
int H;
|
|
||||||
int W;
|
|
||||||
|
|
||||||
std::shared_ptr<TransferFunction> transferFunc;
|
|
||||||
|
|
||||||
public:
|
|
||||||
OutputReorderNode(const std::shared_ptr<memory>& src,
|
|
||||||
const Image& output,
|
|
||||||
const std::shared_ptr<TransferFunction>& transferFunc)
|
|
||||||
: src(src),
|
|
||||||
output(output),
|
|
||||||
h1Begin(0), w1Begin(0),
|
|
||||||
h2Begin(0), w2Begin(0),
|
|
||||||
H(output.height), W(output.width),
|
|
||||||
transferFunc(transferFunc)
|
|
||||||
{
|
|
||||||
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
|
|
||||||
MAYBE_UNUSED(srcDesc);
|
|
||||||
assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
|
|
||||||
assert(srcDesc.ndims == 4);
|
|
||||||
assert(srcDesc.data_type == memory::data_type::f32);
|
|
||||||
assert(srcDesc.dims[0] == 1);
|
|
||||||
// We assume output data is <= K OC
|
|
||||||
assert(srcDesc.dims[1] == K);
|
|
||||||
|
|
||||||
srcPtr = (float*)src->get_data_handle();
|
|
||||||
H1 = srcDesc.dims[2];
|
|
||||||
W1 = srcDesc.dims[3];
|
|
||||||
}
|
|
||||||
|
|
||||||
void setTile(int h1, int w1, int h2, int w2, int H, int W) override
|
|
||||||
{
|
|
||||||
h1Begin = h1;
|
|
||||||
w1Begin = w1;
|
|
||||||
h2Begin = h2;
|
|
||||||
w2Begin = w2;
|
|
||||||
this->H = H;
|
|
||||||
this->W = W;
|
|
||||||
}
|
|
||||||
|
|
||||||
void execute(stream& sm) override
|
|
||||||
{
|
|
||||||
assert(h1Begin + H <= H1);
|
|
||||||
assert(w1Begin + W <= W1);
|
|
||||||
assert(h2Begin + H <= output.height);
|
|
||||||
assert(w2Begin + W <= output.width);
|
|
||||||
|
|
||||||
const int C1 = K;
|
|
||||||
|
|
||||||
parallel_nd(H, [&](int h)
|
|
||||||
{
|
|
||||||
const int h1 = h + h1Begin;
|
|
||||||
const int h2 = h + h2Begin;
|
|
||||||
|
|
||||||
for (int w = 0; w < W; ++w)
|
|
||||||
{
|
|
||||||
const int w1 = w + w1Begin;
|
|
||||||
const int w2 = w + w2Begin;
|
|
||||||
float* dstPtr_C = (float*)output.get(h2, w2);
|
|
||||||
|
|
||||||
// Source is in nChwKc format. In this case C is 1 so this is really nhwc
|
|
||||||
const float* srcPtr_C = srcPtr + h1*W1*C1 + w1*C1;
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 3; ++i)
|
|
||||||
{
|
|
||||||
// Load the value
|
|
||||||
float x = srcPtr_C[i];
|
|
||||||
|
|
||||||
// The CNN output may contain negative values or even NaNs, so it must be sanitized
|
|
||||||
x = maxSafe(x, 0.f);
|
|
||||||
|
|
||||||
// Apply the inverse transfer function
|
|
||||||
x = transferFunc->inverse(x);
|
|
||||||
|
|
||||||
// Sanitize and store the final value
|
|
||||||
dstPtr_C[i] = max(x, 0.f);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
103
thirdparty/oidn/core/transfer_function.cpp
vendored
103
thirdparty/oidn/core/transfer_function.cpp
vendored
@ -1,103 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#include "transfer_function.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
const float LogTransferFunction::xScale = 1.f / log(LogTransferFunction::yMax + 1.f);
|
|
||||||
const float PQXTransferFunction::xScale = 1.f / PQXTransferFunction::pqxForward(PQXTransferFunction::yMax * PQXTransferFunction::yScale);
|
|
||||||
|
|
||||||
float AutoexposureNode::autoexposure(const Image& color)
|
|
||||||
{
|
|
||||||
assert(color.format == Format::Float3);
|
|
||||||
|
|
||||||
constexpr float key = 0.18f;
|
|
||||||
constexpr float eps = 1e-8f;
|
|
||||||
constexpr int K = 16; // downsampling amount
|
|
||||||
|
|
||||||
// Downsample the image to minimize sensitivity to noise
|
|
||||||
const int H = color.height; // original height
|
|
||||||
const int W = color.width; // original width
|
|
||||||
const int HK = (H + K/2) / K; // downsampled height
|
|
||||||
const int WK = (W + K/2) / K; // downsampled width
|
|
||||||
|
|
||||||
// Compute the average log luminance of the downsampled image
|
|
||||||
using Sum = std::pair<float, int>;
|
|
||||||
|
|
||||||
// -- GODOT start --
|
|
||||||
// Sum sum =
|
|
||||||
// tbb::parallel_reduce(
|
|
||||||
// tbb::blocked_range2d<int>(0, HK, 0, WK),
|
|
||||||
// Sum(0.f, 0),
|
|
||||||
// [&](const tbb::blocked_range2d<int>& r, Sum sum) -> Sum
|
|
||||||
// {
|
|
||||||
// // Iterate over blocks
|
|
||||||
// for (int i = r.rows().begin(); i != r.rows().end(); ++i)
|
|
||||||
// {
|
|
||||||
// for (int j = r.cols().begin(); j != r.cols().end(); ++j)
|
|
||||||
// {
|
|
||||||
|
|
||||||
Sum sum = Sum(0.0f, 0);
|
|
||||||
|
|
||||||
for (int i = 0; i != HK; ++i)
|
|
||||||
{
|
|
||||||
for (int j = 0; j != WK; ++j)
|
|
||||||
{
|
|
||||||
// Compute the average luminance in the current block
|
|
||||||
const int beginH = int(ptrdiff_t(i) * H / HK);
|
|
||||||
const int beginW = int(ptrdiff_t(j) * W / WK);
|
|
||||||
const int endH = int(ptrdiff_t(i+1) * H / HK);
|
|
||||||
const int endW = int(ptrdiff_t(j+1) * W / WK);
|
|
||||||
|
|
||||||
float L = 0.f;
|
|
||||||
|
|
||||||
for (int h = beginH; h < endH; ++h)
|
|
||||||
{
|
|
||||||
for (int w = beginW; w < endW; ++w)
|
|
||||||
{
|
|
||||||
const float* rgb = (const float*)color.get(h, w);
|
|
||||||
|
|
||||||
const float r = maxSafe(rgb[0], 0.f);
|
|
||||||
const float g = maxSafe(rgb[1], 0.f);
|
|
||||||
const float b = maxSafe(rgb[2], 0.f);
|
|
||||||
|
|
||||||
L += luminance(r, g, b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
L /= (endH - beginH) * (endW - beginW);
|
|
||||||
|
|
||||||
// Accumulate the log luminance
|
|
||||||
if (L > eps)
|
|
||||||
{
|
|
||||||
sum.first += log2(L);
|
|
||||||
sum.second++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// return sum;
|
|
||||||
// },
|
|
||||||
// [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); },
|
|
||||||
// tbb::static_partitioner()
|
|
||||||
// );
|
|
||||||
// -- GODOT end --
|
|
||||||
|
|
||||||
return (sum.second > 0) ? (key / exp2(sum.first / float(sum.second))) : 1.f;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
201
thirdparty/oidn/core/transfer_function.h
vendored
201
thirdparty/oidn/core/transfer_function.h
vendored
@ -1,201 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "image.h"
|
|
||||||
#include "node.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
__forceinline float luminance(float r, float g, float b)
|
|
||||||
{
|
|
||||||
return 0.212671f * r + 0.715160f * g + 0.072169f * b;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Color transfer function base class
|
|
||||||
class TransferFunction
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
virtual ~TransferFunction() = default;
|
|
||||||
|
|
||||||
virtual float forward(float y) const = 0;
|
|
||||||
virtual float inverse(float x) const = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
// HDR transfer function base class
|
|
||||||
class HDRTransferFunction : public TransferFunction
|
|
||||||
{
|
|
||||||
protected:
|
|
||||||
static constexpr float yMax = 65504.f;
|
|
||||||
|
|
||||||
float exposure;
|
|
||||||
float rcpExposure;
|
|
||||||
|
|
||||||
public:
|
|
||||||
HDRTransferFunction(float exposure = 1.f)
|
|
||||||
{
|
|
||||||
setExposure(exposure);
|
|
||||||
}
|
|
||||||
|
|
||||||
void setExposure(float exposure)
|
|
||||||
{
|
|
||||||
this->exposure = exposure;
|
|
||||||
this->rcpExposure = (exposure != 0.f) ? (1.f / exposure) : 0.f;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Linear transfer function (LDR)
|
|
||||||
class LinearTransferFunction : public TransferFunction
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
__forceinline float forward(float y) const override
|
|
||||||
{
|
|
||||||
return min(y, 1.f);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline float inverse(float x) const override
|
|
||||||
{
|
|
||||||
return min(x, 1.f);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// 2.2 gamma transfer function (LDR)
|
|
||||||
class GammaTransferFunction : public TransferFunction
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
__forceinline float forward(float y) const override
|
|
||||||
{
|
|
||||||
return min(pow(y, 1.f/2.2f), 1.f);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline float inverse(float x) const override
|
|
||||||
{
|
|
||||||
return min(pow(x, 2.2f), 1.f);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Logarithmic transfer function (HDR)
|
|
||||||
// Compresses [0..65504] to [0..1]
|
|
||||||
class LogTransferFunction : public HDRTransferFunction
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
static const float xScale;
|
|
||||||
|
|
||||||
public:
|
|
||||||
LogTransferFunction(float exposure = 1.f)
|
|
||||||
: HDRTransferFunction(exposure)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline float forward(float y) const override
|
|
||||||
{
|
|
||||||
return log(y * exposure + 1.f) * xScale;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline float inverse(float x) const override
|
|
||||||
{
|
|
||||||
return (exp(x * (1.f/xScale)) - 1.f) * rcpExposure;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// PQX transfer function (HDR)
|
|
||||||
// Compresses [0..65504] to [0..1]
|
|
||||||
class PQXTransferFunction : public HDRTransferFunction
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
static constexpr float m1 = 2610.f / 4096.f / 4.f;
|
|
||||||
static constexpr float m2 = 2523.f / 4096.f * 128.f;
|
|
||||||
static constexpr float c1 = 3424.f / 4096.f;
|
|
||||||
static constexpr float c2 = 2413.f / 4096.f * 32.f;
|
|
||||||
static constexpr float c3 = 2392.f / 4096.f * 32.f;
|
|
||||||
static constexpr float a = 3711.f / 4096.f / 8.f;
|
|
||||||
|
|
||||||
static constexpr float yScale = 100.f / 10000.f;
|
|
||||||
static const float xScale;
|
|
||||||
|
|
||||||
public:
|
|
||||||
PQXTransferFunction(float exposure = 1.f)
|
|
||||||
: HDRTransferFunction(exposure)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline float forward(float y) const override
|
|
||||||
{
|
|
||||||
return pqxForward(y * exposure * yScale) * xScale;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline float inverse(float x) const override
|
|
||||||
{
|
|
||||||
return pqxInverse(x * (1.f/xScale)) * (1.f/yScale) * rcpExposure;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
static __forceinline float pqForward(float y)
|
|
||||||
{
|
|
||||||
const float yp = pow(y, m1);
|
|
||||||
return pow((c1 + c2 * yp) * rcp(1.f + c3 * yp), m2);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __forceinline float pqxForward(float y)
|
|
||||||
{
|
|
||||||
if (y <= 1.f)
|
|
||||||
return pqForward(y);
|
|
||||||
else
|
|
||||||
return a * log(y) + 1.f;
|
|
||||||
}
|
|
||||||
|
|
||||||
static __forceinline float pqInverse(float x)
|
|
||||||
{
|
|
||||||
const float xp = pow(x, 1.f/m2);
|
|
||||||
return pow(max((xp - c1) * rcp(c2 - c3 * xp), 0.f), 1.f/m1);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __forceinline float pqxInverse(float x)
|
|
||||||
{
|
|
||||||
if (x <= 1.f)
|
|
||||||
return pqInverse(x);
|
|
||||||
else
|
|
||||||
return exp((x - 1.f) * (1.f/a));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Autoexposure node
|
|
||||||
class AutoexposureNode : public Node
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
Image color;
|
|
||||||
std::shared_ptr<HDRTransferFunction> transferFunc;
|
|
||||||
|
|
||||||
public:
|
|
||||||
AutoexposureNode(const Image& color,
|
|
||||||
const std::shared_ptr<HDRTransferFunction>& transferFunc)
|
|
||||||
: color(color),
|
|
||||||
transferFunc(transferFunc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
void execute(stream& sm) override
|
|
||||||
{
|
|
||||||
const float exposure = autoexposure(color);
|
|
||||||
//printf("exposure = %f\n", exposure);
|
|
||||||
transferFunc->setExposure(exposure);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
static float autoexposure(const Image& color);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
92
thirdparty/oidn/core/upsample.h
vendored
92
thirdparty/oidn/core/upsample.h
vendored
@ -1,92 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "node.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
// 2x2 nearest-neighbor upsampling node
|
|
||||||
template<int K>
|
|
||||||
class UpsampleNode : public Node
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
std::shared_ptr<memory> src;
|
|
||||||
std::shared_ptr<memory> dst;
|
|
||||||
|
|
||||||
public:
|
|
||||||
UpsampleNode(const std::shared_ptr<memory>& src,
|
|
||||||
const std::shared_ptr<memory>& dst)
|
|
||||||
: src(src),
|
|
||||||
dst(dst)
|
|
||||||
{
|
|
||||||
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
|
|
||||||
const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
|
|
||||||
MAYBE_UNUSED(srcDesc);
|
|
||||||
MAYBE_UNUSED(dstDesc);
|
|
||||||
assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
|
|
||||||
assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
|
|
||||||
assert(srcDesc.ndims == 4);
|
|
||||||
assert(dstDesc.ndims == 4);
|
|
||||||
assert(srcDesc.data_type == memory::data_type::f32);
|
|
||||||
assert(dstDesc.data_type == memory::data_type::f32);
|
|
||||||
assert(srcDesc.dims[0] == 1);
|
|
||||||
assert(dstDesc.dims[0] == 1);
|
|
||||||
// 2x2 upsampling
|
|
||||||
assert(dstDesc.dims[2] == srcDesc.dims[2] * 2);
|
|
||||||
assert(dstDesc.dims[3] == srcDesc.dims[3] * 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void execute(stream& sm) override
|
|
||||||
{
|
|
||||||
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
|
|
||||||
|
|
||||||
const float* srcPtr = (float*)src->get_data_handle();
|
|
||||||
float* dstPtr = (float*)dst->get_data_handle();
|
|
||||||
|
|
||||||
const int C = srcDesc.dims[1];
|
|
||||||
const int H = srcDesc.dims[2];
|
|
||||||
const int W = srcDesc.dims[3];
|
|
||||||
const int CK = C / K;
|
|
||||||
|
|
||||||
parallel_nd(CK, H, [&](int ck, int h)
|
|
||||||
{
|
|
||||||
const size_t offset = ck*H*W*K + h*W*K;
|
|
||||||
const float* srcPtr_line = srcPtr + offset;
|
|
||||||
float* dstPtr_line0 = dstPtr + offset * 4;
|
|
||||||
float* dstPtr_line1 = dstPtr_line0 + W*2*K; // next line
|
|
||||||
|
|
||||||
for (int w = 0; w < W; ++w)
|
|
||||||
{
|
|
||||||
#pragma unroll
|
|
||||||
for (int k = 0; k < K; k += 4)
|
|
||||||
{
|
|
||||||
const __m128 m = _mm_load_ps(&srcPtr_line[w*K + k]);
|
|
||||||
|
|
||||||
_mm_stream_ps(&dstPtr_line0[w*2*K + k], m);
|
|
||||||
_mm_stream_ps(&dstPtr_line0[w*2*K+K + k], m);
|
|
||||||
_mm_stream_ps(&dstPtr_line1[w*2*K + k], m);
|
|
||||||
_mm_stream_ps(&dstPtr_line1[w*2*K+K + k], m);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<memory> getDst() const override { return dst; }
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
99
thirdparty/oidn/core/weights_reorder.h
vendored
99
thirdparty/oidn/core/weights_reorder.h
vendored
@ -1,99 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "node.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
// Reorders weights from oihw to padded oihw format
|
|
||||||
template<int K>
|
|
||||||
class WeightsReorderNode : public Node
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
std::shared_ptr<memory> src;
|
|
||||||
std::shared_ptr<memory> dst;
|
|
||||||
|
|
||||||
public:
|
|
||||||
WeightsReorderNode(const std::shared_ptr<memory>& src,
|
|
||||||
const std::shared_ptr<memory>& dst)
|
|
||||||
: src(src),
|
|
||||||
dst(dst)
|
|
||||||
{
|
|
||||||
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
|
|
||||||
const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
|
|
||||||
MAYBE_UNUSED(srcDesc);
|
|
||||||
MAYBE_UNUSED(dstDesc);
|
|
||||||
assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(memory::format_tag::oihw)));
|
|
||||||
assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(memory::format_tag::oihw)));
|
|
||||||
assert(srcDesc.ndims == 4);
|
|
||||||
assert(dstDesc.ndims == 4);
|
|
||||||
assert(srcDesc.data_type == memory::data_type::f32);
|
|
||||||
assert(dstDesc.data_type == memory::data_type::f32);
|
|
||||||
assert(getPadded<K>(srcDesc.dims[0]) == dstDesc.dims[0]); // OC
|
|
||||||
assert(getPadded<K>(srcDesc.dims[1]) == dstDesc.dims[1]); // IC
|
|
||||||
assert(srcDesc.dims[2] == dstDesc.dims[2]);
|
|
||||||
assert(srcDesc.dims[3] == dstDesc.dims[3]);
|
|
||||||
}
|
|
||||||
|
|
||||||
void execute(stream& sm) override
|
|
||||||
{
|
|
||||||
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
|
|
||||||
const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
|
|
||||||
|
|
||||||
const float* srcPtr = (float*)src->get_data_handle();
|
|
||||||
float* dstPtr = (float*)dst->get_data_handle();
|
|
||||||
|
|
||||||
const int OC1 = srcDesc.dims[0];
|
|
||||||
const int OC2 = dstDesc.dims[0];
|
|
||||||
const int IC1 = srcDesc.dims[1];
|
|
||||||
const int IC2 = dstDesc.dims[1];
|
|
||||||
const int H = dstDesc.dims[2];
|
|
||||||
const int W = dstDesc.dims[3];
|
|
||||||
|
|
||||||
for (int oc = 0; oc < OC2; ++oc)
|
|
||||||
{
|
|
||||||
for (int ic = 0; ic < IC2; ++ic)
|
|
||||||
{
|
|
||||||
for (int h = 0; h < H; ++h)
|
|
||||||
{
|
|
||||||
for (int w = 0; w < W; ++w)
|
|
||||||
{
|
|
||||||
// Output is in oihw format
|
|
||||||
float* dstPtr_c = dstPtr + oc*IC2*H*W + ic*H*W + h*W + w;
|
|
||||||
|
|
||||||
if (oc < OC1 && ic < IC1)
|
|
||||||
{
|
|
||||||
// Input is in oihw format
|
|
||||||
const float* srcPtr_c = srcPtr + oc*IC1*H*W + ic*H*W + h*W + w;
|
|
||||||
*dstPtr_c = *srcPtr_c;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// padding
|
|
||||||
*dstPtr_c = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<memory> getDst() const override { return dst; }
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
214
thirdparty/oidn/include/OpenImageDenoise/oidn.h
vendored
214
thirdparty/oidn/include/OpenImageDenoise/oidn.h
vendored
@ -1,214 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <stddef.h>
|
|
||||||
#include <stdbool.h>
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#include "version.h"
|
|
||||||
|
|
||||||
#if defined(__cplusplus)
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifndef OIDN_API
|
|
||||||
#if defined(_WIN32) && !defined(OIDN_STATIC_LIB)
|
|
||||||
# define OIDN_API __declspec(dllimport)
|
|
||||||
#else
|
|
||||||
# define OIDN_API
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
// Device
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
// Device types
|
|
||||||
typedef enum
|
|
||||||
{
|
|
||||||
OIDN_DEVICE_TYPE_DEFAULT = 0, // select device automatically
|
|
||||||
|
|
||||||
OIDN_DEVICE_TYPE_CPU = 1, // CPU device
|
|
||||||
} OIDNDeviceType;
|
|
||||||
|
|
||||||
// Error codes
|
|
||||||
typedef enum
|
|
||||||
{
|
|
||||||
OIDN_ERROR_NONE = 0, // no error occurred
|
|
||||||
OIDN_ERROR_UNKNOWN = 1, // an unknown error occurred
|
|
||||||
OIDN_ERROR_INVALID_ARGUMENT = 2, // an invalid argument was specified
|
|
||||||
OIDN_ERROR_INVALID_OPERATION = 3, // the operation is not allowed
|
|
||||||
OIDN_ERROR_OUT_OF_MEMORY = 4, // not enough memory to execute the operation
|
|
||||||
OIDN_ERROR_UNSUPPORTED_HARDWARE = 5, // the hardware (e.g. CPU) is not supported
|
|
||||||
OIDN_ERROR_CANCELLED = 6, // the operation was cancelled by the user
|
|
||||||
} OIDNError;
|
|
||||||
|
|
||||||
// Error callback function
|
|
||||||
typedef void (*OIDNErrorFunction)(void* userPtr, OIDNError code, const char* message);
|
|
||||||
|
|
||||||
// Device handle
|
|
||||||
typedef struct OIDNDeviceImpl* OIDNDevice;
|
|
||||||
|
|
||||||
// Creates a new device.
|
|
||||||
OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type);
|
|
||||||
|
|
||||||
// Retains the device (increments the reference count).
|
|
||||||
OIDN_API void oidnRetainDevice(OIDNDevice device);
|
|
||||||
|
|
||||||
// Releases the device (decrements the reference count).
|
|
||||||
OIDN_API void oidnReleaseDevice(OIDNDevice device);
|
|
||||||
|
|
||||||
// Sets a boolean parameter of the device.
|
|
||||||
OIDN_API void oidnSetDevice1b(OIDNDevice device, const char* name, bool value);
|
|
||||||
|
|
||||||
// Sets an integer parameter of the device.
|
|
||||||
OIDN_API void oidnSetDevice1i(OIDNDevice device, const char* name, int value);
|
|
||||||
|
|
||||||
// Gets a boolean parameter of the device.
|
|
||||||
OIDN_API bool oidnGetDevice1b(OIDNDevice device, const char* name);
|
|
||||||
|
|
||||||
// Gets an integer parameter of the device (e.g. "version").
|
|
||||||
OIDN_API int oidnGetDevice1i(OIDNDevice device, const char* name);
|
|
||||||
|
|
||||||
// Sets the error callback function of the device.
|
|
||||||
OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice device, OIDNErrorFunction func, void* userPtr);
|
|
||||||
|
|
||||||
// Returns the first unqueried error code stored in the device for the current
|
|
||||||
// thread, optionally also returning a string message (if not NULL), and clears
|
|
||||||
// the stored error. Can be called with a NULL device as well to check why a
|
|
||||||
// device creation failed.
|
|
||||||
OIDN_API OIDNError oidnGetDeviceError(OIDNDevice device, const char** outMessage);
|
|
||||||
|
|
||||||
// Commits all previous changes to the device.
|
|
||||||
// Must be called before first using the device (e.g. creating filters).
|
|
||||||
OIDN_API void oidnCommitDevice(OIDNDevice device);
|
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
// Buffer
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
// Formats for images and other data stored in buffers
|
|
||||||
typedef enum
|
|
||||||
{
|
|
||||||
OIDN_FORMAT_UNDEFINED = 0,
|
|
||||||
|
|
||||||
// 32-bit single-precision floating point scalar and vector formats
|
|
||||||
OIDN_FORMAT_FLOAT = 1,
|
|
||||||
OIDN_FORMAT_FLOAT2 = 2,
|
|
||||||
OIDN_FORMAT_FLOAT3 = 3,
|
|
||||||
OIDN_FORMAT_FLOAT4 = 4,
|
|
||||||
} OIDNFormat;
|
|
||||||
|
|
||||||
// Access modes for mapping buffers
|
|
||||||
typedef enum
|
|
||||||
{
|
|
||||||
OIDN_ACCESS_READ = 0, // read-only access
|
|
||||||
OIDN_ACCESS_WRITE = 1, // write-only access
|
|
||||||
OIDN_ACCESS_READ_WRITE = 2, // read and write access
|
|
||||||
OIDN_ACCESS_WRITE_DISCARD = 3, // write-only access, previous contents discarded
|
|
||||||
} OIDNAccess;
|
|
||||||
|
|
||||||
// Buffer handle
|
|
||||||
typedef struct OIDNBufferImpl* OIDNBuffer;
|
|
||||||
|
|
||||||
// Creates a new buffer (data allocated and owned by the device).
|
|
||||||
OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice device, size_t byteSize);
|
|
||||||
|
|
||||||
// Creates a new shared buffer (data allocated and owned by the user).
|
|
||||||
OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice device, void* ptr, size_t byteSize);
|
|
||||||
|
|
||||||
// Maps a region of the buffer to host memory.
|
|
||||||
// If byteSize is 0, the maximum available amount of memory will be mapped.
|
|
||||||
OIDN_API void* oidnMapBuffer(OIDNBuffer buffer, OIDNAccess access, size_t byteOffset, size_t byteSize);
|
|
||||||
|
|
||||||
// Unmaps a region of the buffer.
|
|
||||||
// mappedPtr must be a pointer returned by a previous call to oidnMapBuffer.
|
|
||||||
OIDN_API void oidnUnmapBuffer(OIDNBuffer buffer, void* mappedPtr);
|
|
||||||
|
|
||||||
// Retains the buffer (increments the reference count).
|
|
||||||
OIDN_API void oidnRetainBuffer(OIDNBuffer buffer);
|
|
||||||
|
|
||||||
// Releases the buffer (decrements the reference count).
|
|
||||||
OIDN_API void oidnReleaseBuffer(OIDNBuffer buffer);
|
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
// Filter
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
// Progress monitor callback function
|
|
||||||
typedef bool (*OIDNProgressMonitorFunction)(void* userPtr, double n);
|
|
||||||
|
|
||||||
// Filter handle
|
|
||||||
typedef struct OIDNFilterImpl* OIDNFilter;
|
|
||||||
|
|
||||||
// Creates a new filter of the specified type (e.g. "RT").
|
|
||||||
OIDN_API OIDNFilter oidnNewFilter(OIDNDevice device, const char* type);
|
|
||||||
|
|
||||||
// Retains the filter (increments the reference count).
|
|
||||||
OIDN_API void oidnRetainFilter(OIDNFilter filter);
|
|
||||||
|
|
||||||
// Releases the filter (decrements the reference count).
|
|
||||||
OIDN_API void oidnReleaseFilter(OIDNFilter filter);
|
|
||||||
|
|
||||||
// Sets an image parameter of the filter (stored in a buffer).
|
|
||||||
// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically.
|
|
||||||
OIDN_API void oidnSetFilterImage(OIDNFilter filter, const char* name,
|
|
||||||
OIDNBuffer buffer, OIDNFormat format,
|
|
||||||
size_t width, size_t height,
|
|
||||||
size_t byteOffset,
|
|
||||||
size_t bytePixelStride, size_t byteRowStride);
|
|
||||||
|
|
||||||
// Sets an image parameter of the filter (owned by the user).
|
|
||||||
// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically.
|
|
||||||
OIDN_API void oidnSetSharedFilterImage(OIDNFilter filter, const char* name,
|
|
||||||
void* ptr, OIDNFormat format,
|
|
||||||
size_t width, size_t height,
|
|
||||||
size_t byteOffset,
|
|
||||||
size_t bytePixelStride, size_t byteRowStride);
|
|
||||||
|
|
||||||
// Sets a boolean parameter of the filter.
|
|
||||||
OIDN_API void oidnSetFilter1b(OIDNFilter filter, const char* name, bool value);
|
|
||||||
|
|
||||||
// Gets a boolean parameter of the filter.
|
|
||||||
OIDN_API bool oidnGetFilter1b(OIDNFilter filter, const char* name);
|
|
||||||
|
|
||||||
// Sets an integer parameter of the filter.
|
|
||||||
OIDN_API void oidnSetFilter1i(OIDNFilter filter, const char* name, int value);
|
|
||||||
|
|
||||||
// Gets an integer parameter of the filter.
|
|
||||||
OIDN_API int oidnGetFilter1i(OIDNFilter filter, const char* name);
|
|
||||||
|
|
||||||
// Sets a float parameter of the filter.
|
|
||||||
OIDN_API void oidnSetFilter1f(OIDNFilter filter, const char* name, float value);
|
|
||||||
|
|
||||||
// Gets a float parameter of the filter.
|
|
||||||
OIDN_API float oidnGetFilter1f(OIDNFilter filter, const char* name);
|
|
||||||
|
|
||||||
// Sets the progress monitor callback function of the filter.
|
|
||||||
OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter filter, OIDNProgressMonitorFunction func, void* userPtr);
|
|
||||||
|
|
||||||
// Commits all previous changes to the filter.
|
|
||||||
// Must be called before first executing the filter.
|
|
||||||
OIDN_API void oidnCommitFilter(OIDNFilter filter);
|
|
||||||
|
|
||||||
// Executes the filter.
|
|
||||||
OIDN_API void oidnExecuteFilter(OIDNFilter filter);
|
|
||||||
|
|
||||||
#if defined(__cplusplus)
|
|
||||||
}
|
|
||||||
#endif
|
|
468
thirdparty/oidn/include/OpenImageDenoise/oidn.hpp
vendored
468
thirdparty/oidn/include/OpenImageDenoise/oidn.hpp
vendored
@ -1,468 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include "oidn.h"
|
|
||||||
|
|
||||||
namespace oidn {
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// Buffer
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
// Formats for images and other data stored in buffers
|
|
||||||
enum class Format
|
|
||||||
{
|
|
||||||
Undefined = OIDN_FORMAT_UNDEFINED,
|
|
||||||
|
|
||||||
// 32-bit single-precision floating point scalar and vector formats
|
|
||||||
Float = OIDN_FORMAT_FLOAT,
|
|
||||||
Float2 = OIDN_FORMAT_FLOAT2,
|
|
||||||
Float3 = OIDN_FORMAT_FLOAT3,
|
|
||||||
Float4 = OIDN_FORMAT_FLOAT4,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Access modes for mapping buffers
|
|
||||||
enum class Access
|
|
||||||
{
|
|
||||||
Read = OIDN_ACCESS_READ, // read-only access
|
|
||||||
Write = OIDN_ACCESS_WRITE, // write-only access
|
|
||||||
ReadWrite = OIDN_ACCESS_READ_WRITE, // read and write access
|
|
||||||
WriteDiscard = OIDN_ACCESS_WRITE_DISCARD, // write-only access, previous contents discarded
|
|
||||||
};
|
|
||||||
|
|
||||||
// Buffer object with automatic reference counting
|
|
||||||
class BufferRef
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
OIDNBuffer handle;
|
|
||||||
|
|
||||||
public:
|
|
||||||
BufferRef() : handle(nullptr) {}
|
|
||||||
BufferRef(OIDNBuffer handle) : handle(handle) {}
|
|
||||||
|
|
||||||
BufferRef(const BufferRef& other) : handle(other.handle)
|
|
||||||
{
|
|
||||||
if (handle)
|
|
||||||
oidnRetainBuffer(handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
BufferRef(BufferRef&& other) : handle(other.handle)
|
|
||||||
{
|
|
||||||
other.handle = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
BufferRef& operator =(const BufferRef& other)
|
|
||||||
{
|
|
||||||
if (&other != this)
|
|
||||||
{
|
|
||||||
if (other.handle)
|
|
||||||
oidnRetainBuffer(other.handle);
|
|
||||||
if (handle)
|
|
||||||
oidnReleaseBuffer(handle);
|
|
||||||
handle = other.handle;
|
|
||||||
}
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
BufferRef& operator =(BufferRef&& other)
|
|
||||||
{
|
|
||||||
std::swap(handle, other.handle);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
BufferRef& operator =(OIDNBuffer other)
|
|
||||||
{
|
|
||||||
if (other)
|
|
||||||
oidnRetainBuffer(other);
|
|
||||||
if (handle)
|
|
||||||
oidnReleaseBuffer(handle);
|
|
||||||
handle = other;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
~BufferRef()
|
|
||||||
{
|
|
||||||
if (handle)
|
|
||||||
oidnReleaseBuffer(handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDNBuffer getHandle() const
|
|
||||||
{
|
|
||||||
return handle;
|
|
||||||
}
|
|
||||||
|
|
||||||
operator bool() const
|
|
||||||
{
|
|
||||||
return handle != nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Maps a region of the buffer to host memory.
|
|
||||||
// If byteSize is 0, the maximum available amount of memory will be mapped.
|
|
||||||
void* map(Access access = Access::ReadWrite, size_t byteOffset = 0, size_t byteSize = 0)
|
|
||||||
{
|
|
||||||
return oidnMapBuffer(handle, (OIDNAccess)access, byteOffset, byteSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unmaps a region of the buffer.
|
|
||||||
// mappedPtr must be a pointer returned by a previous call to map.
|
|
||||||
void unmap(void* mappedPtr)
|
|
||||||
{
|
|
||||||
oidnUnmapBuffer(handle, mappedPtr);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// Filter
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
// Progress monitor callback function
|
|
||||||
typedef bool (*ProgressMonitorFunction)(void* userPtr, double n);
|
|
||||||
|
|
||||||
// Filter object with automatic reference counting
|
|
||||||
class FilterRef
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
OIDNFilter handle;
|
|
||||||
|
|
||||||
public:
|
|
||||||
FilterRef() : handle(nullptr) {}
|
|
||||||
FilterRef(OIDNFilter handle) : handle(handle) {}
|
|
||||||
|
|
||||||
FilterRef(const FilterRef& other) : handle(other.handle)
|
|
||||||
{
|
|
||||||
if (handle)
|
|
||||||
oidnRetainFilter(handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
FilterRef(FilterRef&& other) : handle(other.handle)
|
|
||||||
{
|
|
||||||
other.handle = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
FilterRef& operator =(const FilterRef& other)
|
|
||||||
{
|
|
||||||
if (&other != this)
|
|
||||||
{
|
|
||||||
if (other.handle)
|
|
||||||
oidnRetainFilter(other.handle);
|
|
||||||
if (handle)
|
|
||||||
oidnReleaseFilter(handle);
|
|
||||||
handle = other.handle;
|
|
||||||
}
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
FilterRef& operator =(FilterRef&& other)
|
|
||||||
{
|
|
||||||
std::swap(handle, other.handle);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
FilterRef& operator =(OIDNFilter other)
|
|
||||||
{
|
|
||||||
if (other)
|
|
||||||
oidnRetainFilter(other);
|
|
||||||
if (handle)
|
|
||||||
oidnReleaseFilter(handle);
|
|
||||||
handle = other;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
~FilterRef()
|
|
||||||
{
|
|
||||||
if (handle)
|
|
||||||
oidnReleaseFilter(handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDNFilter getHandle() const
|
|
||||||
{
|
|
||||||
return handle;
|
|
||||||
}
|
|
||||||
|
|
||||||
operator bool() const
|
|
||||||
{
|
|
||||||
return handle != nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets an image parameter of the filter (stored in a buffer).
|
|
||||||
void setImage(const char* name,
|
|
||||||
const BufferRef& buffer, Format format,
|
|
||||||
size_t width, size_t height,
|
|
||||||
size_t byteOffset = 0,
|
|
||||||
size_t bytePixelStride = 0, size_t byteRowStride = 0)
|
|
||||||
{
|
|
||||||
oidnSetFilterImage(handle, name,
|
|
||||||
buffer.getHandle(), (OIDNFormat)format,
|
|
||||||
width, height,
|
|
||||||
byteOffset,
|
|
||||||
bytePixelStride, byteRowStride);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets an image parameter of the filter (owned by the user).
|
|
||||||
void setImage(const char* name,
|
|
||||||
void* ptr, Format format,
|
|
||||||
size_t width, size_t height,
|
|
||||||
size_t byteOffset = 0,
|
|
||||||
size_t bytePixelStride = 0, size_t byteRowStride = 0)
|
|
||||||
{
|
|
||||||
oidnSetSharedFilterImage(handle, name,
|
|
||||||
ptr, (OIDNFormat)format,
|
|
||||||
width, height,
|
|
||||||
byteOffset,
|
|
||||||
bytePixelStride, byteRowStride);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets a boolean parameter of the filter.
|
|
||||||
void set(const char* name, bool value)
|
|
||||||
{
|
|
||||||
oidnSetFilter1b(handle, name, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets an integer parameter of the filter.
|
|
||||||
void set(const char* name, int value)
|
|
||||||
{
|
|
||||||
oidnSetFilter1i(handle, name, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets a float parameter of the filter.
|
|
||||||
void set(const char* name, float value)
|
|
||||||
{
|
|
||||||
oidnSetFilter1f(handle, name, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gets a parameter of the filter.
|
|
||||||
template<typename T>
|
|
||||||
T get(const char* name);
|
|
||||||
|
|
||||||
// Sets the progress monitor callback function of the filter.
|
|
||||||
void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr = nullptr)
|
|
||||||
{
|
|
||||||
oidnSetFilterProgressMonitorFunction(handle, (OIDNProgressMonitorFunction)func, userPtr);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Commits all previous changes to the filter.
|
|
||||||
void commit()
|
|
||||||
{
|
|
||||||
oidnCommitFilter(handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Executes the filter.
|
|
||||||
void execute()
|
|
||||||
{
|
|
||||||
oidnExecuteFilter(handle);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Gets a boolean parameter of the filter.
|
|
||||||
template<>
|
|
||||||
inline bool FilterRef::get(const char* name)
|
|
||||||
{
|
|
||||||
return oidnGetFilter1b(handle, name);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gets an integer parameter of the filter.
|
|
||||||
template<>
|
|
||||||
inline int FilterRef::get(const char* name)
|
|
||||||
{
|
|
||||||
return oidnGetFilter1i(handle, name);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gets a float parameter of the filter.
|
|
||||||
template<>
|
|
||||||
inline float FilterRef::get(const char* name)
|
|
||||||
{
|
|
||||||
return oidnGetFilter1f(handle, name);
|
|
||||||
}
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// Device
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
// Device types
|
|
||||||
enum class DeviceType
|
|
||||||
{
|
|
||||||
Default = OIDN_DEVICE_TYPE_DEFAULT, // select device automatically
|
|
||||||
|
|
||||||
CPU = OIDN_DEVICE_TYPE_CPU, // CPU device
|
|
||||||
};
|
|
||||||
|
|
||||||
// Error codes
|
|
||||||
enum class Error
|
|
||||||
{
|
|
||||||
None = OIDN_ERROR_NONE, // no error occurred
|
|
||||||
Unknown = OIDN_ERROR_UNKNOWN, // an unknown error occurred
|
|
||||||
InvalidArgument = OIDN_ERROR_INVALID_ARGUMENT, // an invalid argument was specified
|
|
||||||
InvalidOperation = OIDN_ERROR_INVALID_OPERATION, // the operation is not allowed
|
|
||||||
OutOfMemory = OIDN_ERROR_OUT_OF_MEMORY, // not enough memory to execute the operation
|
|
||||||
UnsupportedHardware = OIDN_ERROR_UNSUPPORTED_HARDWARE, // the hardware (e.g. CPU) is not supported
|
|
||||||
Cancelled = OIDN_ERROR_CANCELLED, // the operation was cancelled by the user
|
|
||||||
};
|
|
||||||
|
|
||||||
// Error callback function
|
|
||||||
typedef void (*ErrorFunction)(void* userPtr, Error code, const char* message);
|
|
||||||
|
|
||||||
// Device object with automatic reference counting
|
|
||||||
class DeviceRef
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
OIDNDevice handle;
|
|
||||||
|
|
||||||
public:
|
|
||||||
DeviceRef() : handle(nullptr) {}
|
|
||||||
DeviceRef(OIDNDevice handle) : handle(handle) {}
|
|
||||||
|
|
||||||
DeviceRef(const DeviceRef& other) : handle(other.handle)
|
|
||||||
{
|
|
||||||
if (handle)
|
|
||||||
oidnRetainDevice(handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
DeviceRef(DeviceRef&& other) : handle(other.handle)
|
|
||||||
{
|
|
||||||
other.handle = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
DeviceRef& operator =(const DeviceRef& other)
|
|
||||||
{
|
|
||||||
if (&other != this)
|
|
||||||
{
|
|
||||||
if (other.handle)
|
|
||||||
oidnRetainDevice(other.handle);
|
|
||||||
if (handle)
|
|
||||||
oidnReleaseDevice(handle);
|
|
||||||
handle = other.handle;
|
|
||||||
}
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
DeviceRef& operator =(DeviceRef&& other)
|
|
||||||
{
|
|
||||||
std::swap(handle, other.handle);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
DeviceRef& operator =(OIDNDevice other)
|
|
||||||
{
|
|
||||||
if (other)
|
|
||||||
oidnRetainDevice(other);
|
|
||||||
if (handle)
|
|
||||||
oidnReleaseDevice(handle);
|
|
||||||
handle = other;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
~DeviceRef()
|
|
||||||
{
|
|
||||||
if (handle)
|
|
||||||
oidnReleaseDevice(handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
OIDNDevice getHandle() const
|
|
||||||
{
|
|
||||||
return handle;
|
|
||||||
}
|
|
||||||
|
|
||||||
operator bool() const
|
|
||||||
{
|
|
||||||
return handle != nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets a boolean parameter of the device.
|
|
||||||
void set(const char* name, bool value)
|
|
||||||
{
|
|
||||||
oidnSetDevice1b(handle, name, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets an integer parameter of the device.
|
|
||||||
void set(const char* name, int value)
|
|
||||||
{
|
|
||||||
oidnSetDevice1i(handle, name, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gets a parameter of the device.
|
|
||||||
template<typename T>
|
|
||||||
T get(const char* name);
|
|
||||||
|
|
||||||
// Sets the error callback function of the device.
|
|
||||||
void setErrorFunction(ErrorFunction func, void* userPtr = nullptr)
|
|
||||||
{
|
|
||||||
oidnSetDeviceErrorFunction(handle, (OIDNErrorFunction)func, userPtr);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the first unqueried error code and clears the stored error.
|
|
||||||
// Can be called for a null device as well to check why a device creation failed.
|
|
||||||
Error getError()
|
|
||||||
{
|
|
||||||
return (Error)oidnGetDeviceError(handle, nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the first unqueried error code and string message, and clears the stored error.
|
|
||||||
// Can be called for a null device as well to check why a device creation failed.
|
|
||||||
Error getError(const char*& outMessage)
|
|
||||||
{
|
|
||||||
return (Error)oidnGetDeviceError(handle, &outMessage);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Commits all previous changes to the device.
|
|
||||||
// Must be called before first using the device (e.g. creating filters).
|
|
||||||
void commit()
|
|
||||||
{
|
|
||||||
oidnCommitDevice(handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a new buffer (data allocated and owned by the device).
|
|
||||||
BufferRef newBuffer(size_t byteSize)
|
|
||||||
{
|
|
||||||
return oidnNewBuffer(handle, byteSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a new shared buffer (data allocated and owned by the user).
|
|
||||||
BufferRef newBuffer(void* ptr, size_t byteSize)
|
|
||||||
{
|
|
||||||
return oidnNewSharedBuffer(handle, ptr, byteSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a new filter of the specified type (e.g. "RT").
|
|
||||||
FilterRef newFilter(const char* type)
|
|
||||||
{
|
|
||||||
return oidnNewFilter(handle, type);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Gets a boolean parameter of the device.
|
|
||||||
template<>
|
|
||||||
inline bool DeviceRef::get(const char* name)
|
|
||||||
{
|
|
||||||
return oidnGetDevice1b(handle, name);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gets an integer parameter of the device (e.g. "version").
|
|
||||||
template<>
|
|
||||||
inline int DeviceRef::get(const char* name)
|
|
||||||
{
|
|
||||||
return oidnGetDevice1i(handle, name);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a new device.
|
|
||||||
inline DeviceRef newDevice(DeviceType type = DeviceType::Default)
|
|
||||||
{
|
|
||||||
return DeviceRef(oidnNewDevice((OIDNDeviceType)type));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace oidn
|
|
@ -1,23 +0,0 @@
|
|||||||
// ======================================================================== //
|
|
||||||
// Copyright 2009-2019 Intel Corporation //
|
|
||||||
// //
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
||||||
// you may not use this file except in compliance with the License. //
|
|
||||||
// You may obtain a copy of the License at //
|
|
||||||
// //
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
||||||
// //
|
|
||||||
// Unless required by applicable law or agreed to in writing, software //
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
||||||
// See the License for the specific language governing permissions and //
|
|
||||||
// limitations under the License. //
|
|
||||||
// ======================================================================== //
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#define OIDN_VERSION_MAJOR 1
|
|
||||||
#define OIDN_VERSION_MINOR 1
|
|
||||||
#define OIDN_VERSION_PATCH 0
|
|
||||||
#define OIDN_VERSION 10100
|
|
||||||
#define OIDN_VERSION_STRING "1.1.0"
|
|
214
thirdparty/oidn/mkl-dnn/LICENSE
vendored
214
thirdparty/oidn/mkl-dnn/LICENSE
vendored
@ -1,214 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright {yyyy} {name of copyright owner}
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
|
|
||||||
============================================================================
|
|
||||||
|
|
||||||
Intel MKL-DNN includes components with separate copyright
|
|
||||||
notices and license terms.
|
|
||||||
|
|
||||||
XByak, 3-clause BSD license
|
|
||||||
Copyright (c) 2007 MITSUNARI Shigeo
|
|
||||||
See full copyright notice and license text in src/cpu/xbyak/COPYRIGHT
|
|
||||||
|
|
||||||
gtest, 3-clause BSD license
|
|
||||||
Copyright 2008, Google Inc.
|
|
||||||
See full copyright notice and license text in tests/gtests/gtest/LICENSE
|
|
1771
thirdparty/oidn/mkl-dnn/include/mkldnn.h
vendored
1771
thirdparty/oidn/mkl-dnn/include/mkldnn.h
vendored
File diff suppressed because it is too large
Load Diff
2615
thirdparty/oidn/mkl-dnn/include/mkldnn.hpp
vendored
2615
thirdparty/oidn/mkl-dnn/include/mkldnn.hpp
vendored
File diff suppressed because it is too large
Load Diff
98
thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h
vendored
98
thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h
vendored
@ -1,98 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018-2019 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
/* DO NOT EDIT, AUTO-GENERATED */
|
|
||||||
|
|
||||||
#ifndef MKLDNN_DEBUG_H
|
|
||||||
#define MKLDNN_DEBUG_H
|
|
||||||
|
|
||||||
#ifndef DOXYGEN_SHOULD_SKIP_THIS
|
|
||||||
|
|
||||||
/* All symbols shall be internal unless marked as MKLDNN_API */
|
|
||||||
#if defined _WIN32 || defined __CYGWIN__
|
|
||||||
# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport)
|
|
||||||
# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport)
|
|
||||||
#else
|
|
||||||
# if __GNUC__ >= 4
|
|
||||||
# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default")))
|
|
||||||
# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default")))
|
|
||||||
# else
|
|
||||||
# define MKLDNN_HELPER_DLL_IMPORT
|
|
||||||
# define MKLDNN_HELPER_DLL_EXPORT
|
|
||||||
# endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef MKLDNN_DLL
|
|
||||||
# ifdef MKLDNN_DLL_EXPORTS
|
|
||||||
# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT
|
|
||||||
# else
|
|
||||||
# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT
|
|
||||||
# endif
|
|
||||||
#else
|
|
||||||
# define MKLDNN_API
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined (__GNUC__)
|
|
||||||
# define MKLDNN_DEPRECATED __attribute__((deprecated))
|
|
||||||
#elif defined(_MSC_VER)
|
|
||||||
# define MKLDNN_DEPRECATED __declspec(deprecated)
|
|
||||||
#else
|
|
||||||
# define MKLDNN_DEPRECATED
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "mkldnn_types.h"
|
|
||||||
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
const char MKLDNN_API *mkldnn_status2str(mkldnn_status_t v);
|
|
||||||
const char MKLDNN_API *mkldnn_dt2str(mkldnn_data_type_t v);
|
|
||||||
const char MKLDNN_API *mkldnn_fmt_kind2str(mkldnn_format_kind_t v);
|
|
||||||
const char MKLDNN_API *mkldnn_fmt_tag2str(mkldnn_format_tag_t v);
|
|
||||||
const char MKLDNN_API *mkldnn_prop_kind2str(mkldnn_prop_kind_t v);
|
|
||||||
const char MKLDNN_API *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v);
|
|
||||||
const char MKLDNN_API *mkldnn_alg_kind2str(mkldnn_alg_kind_t v);
|
|
||||||
const char MKLDNN_API *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v);
|
|
||||||
|
|
||||||
/** Forms a format string for a given memory descriptor.
|
|
||||||
*
|
|
||||||
* The format is defined as: 'dt:[p|o|0]:fmt_kind:fmt:extra'.
|
|
||||||
* Here:
|
|
||||||
* - dt -- data type
|
|
||||||
* - p -- indicates there is non-trivial padding
|
|
||||||
* - o -- indicates there is non-trivial padding offset
|
|
||||||
* - 0 -- indicates there is non-trivial offset0
|
|
||||||
* - fmt_kind -- format kind (blocked, wino, etc...)
|
|
||||||
* - fmt -- extended format string (format_kind specific)
|
|
||||||
* - extra -- shows extra fields (underspecified)
|
|
||||||
*/
|
|
||||||
int MKLDNN_API mkldnn_md2fmt_str(char *fmt_str, size_t fmt_str_len,
|
|
||||||
const mkldnn_memory_desc_t *md);
|
|
||||||
|
|
||||||
/** Forms a dimension string for a given memory descriptor.
|
|
||||||
*
|
|
||||||
* The format is defined as: 'dim0xdim1x...xdimN
|
|
||||||
*/
|
|
||||||
int MKLDNN_API mkldnn_md2dim_str(char *dim_str, size_t dim_str_len,
|
|
||||||
const mkldnn_memory_desc_t *md);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#endif
|
|
1415
thirdparty/oidn/mkl-dnn/include/mkldnn_types.h
vendored
1415
thirdparty/oidn/mkl-dnn/include/mkldnn_types.h
vendored
File diff suppressed because it is too large
Load Diff
32
thirdparty/oidn/mkl-dnn/include/mkldnn_version.h
vendored
32
thirdparty/oidn/mkl-dnn/include/mkldnn_version.h
vendored
@ -1,32 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2019 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef MKLDNN_VERSION_H
|
|
||||||
#define MKLDNN_VERSION_H
|
|
||||||
|
|
||||||
/* Major version of MKL-DNN */
|
|
||||||
#define MKLDNN_VERSION_MAJOR 0
|
|
||||||
|
|
||||||
/* Minor version of MKL-DNN */
|
|
||||||
#define MKLDNN_VERSION_MINOR 90
|
|
||||||
|
|
||||||
/* Patch version of MKL-DNN */
|
|
||||||
#define MKLDNN_VERSION_PATCH 0
|
|
||||||
|
|
||||||
/* Git Commit Hash of MKL-DNN */
|
|
||||||
#define MKLDNN_VERSION_HASH "096bda1ca23324879f2df5a129e610e4405f775c"
|
|
||||||
|
|
||||||
#endif
|
|
@ -1,32 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2019 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef MKLDNN_VERSION_H
|
|
||||||
#define MKLDNN_VERSION_H
|
|
||||||
|
|
||||||
/* Major version of MKL-DNN */
|
|
||||||
#define MKLDNN_VERSION_MAJOR @MKLDNN_VERSION_MAJOR@
|
|
||||||
|
|
||||||
/* Minor version of MKL-DNN */
|
|
||||||
#define MKLDNN_VERSION_MINOR @MKLDNN_VERSION_MINOR@
|
|
||||||
|
|
||||||
/* Patch version of MKL-DNN */
|
|
||||||
#define MKLDNN_VERSION_PATCH @MKLDNN_VERSION_PATCH@
|
|
||||||
|
|
||||||
/* Git Commit Hash of MKL-DNN */
|
|
||||||
#define MKLDNN_VERSION_HASH "@MKLDNN_VERSION_HASH@"
|
|
||||||
|
|
||||||
#endif
|
|
@ -1,104 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::utils;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
using namespace mkldnn::impl::prop_kind;
|
|
||||||
using namespace mkldnn::impl::alg_kind;
|
|
||||||
using namespace mkldnn::impl::types;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc,
|
|
||||||
prop_kind_t prop_kind, const memory_desc_t *data_desc,
|
|
||||||
const memory_desc_t *diff_data_desc, float epsilon, unsigned flags) {
|
|
||||||
bool args_ok = true
|
|
||||||
&& !any_null(bnrm_desc, data_desc)
|
|
||||||
&& one_of(prop_kind, forward_training, forward_inference,
|
|
||||||
backward_data, backward)
|
|
||||||
&& IMPLICATION(prop_kind & backward, diff_data_desc != nullptr);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
auto bd = batch_normalization_desc_t();
|
|
||||||
bd.primitive_kind = primitive_kind::batch_normalization;
|
|
||||||
bd.prop_kind = prop_kind;
|
|
||||||
|
|
||||||
bd.data_desc = *data_desc;
|
|
||||||
bd.diff_data_desc = zero_md();
|
|
||||||
if ( one_of(bd.prop_kind,backward_data, backward) )
|
|
||||||
bd.diff_data_desc = *diff_data_desc;
|
|
||||||
|
|
||||||
dims_t scaleshift_dims = { 2, data_desc->dims[1] };
|
|
||||||
mkldnn_memory_desc_init_by_tag(&bd.data_scaleshift_desc, 2,
|
|
||||||
scaleshift_dims, data_type::f32, mkldnn_nc);
|
|
||||||
bd.diff_data_scaleshift_desc = zero_md();
|
|
||||||
if (bd.prop_kind == backward) {
|
|
||||||
bd.diff_data_scaleshift_desc = bd.data_scaleshift_desc;
|
|
||||||
}
|
|
||||||
|
|
||||||
dims_t stats_dims = { data_desc->dims[1] };
|
|
||||||
mkldnn_memory_desc_init_by_tag(&bd.mean_desc, 1, stats_dims,
|
|
||||||
data_type::f32, mkldnn_x);
|
|
||||||
bd.variance_desc = bd.mean_desc;
|
|
||||||
bd.batch_norm_epsilon = epsilon;
|
|
||||||
|
|
||||||
unsigned bnorm_flags =
|
|
||||||
mkldnn_use_global_stats | mkldnn_use_scaleshift | mkldnn_fuse_bn_relu;
|
|
||||||
if ((~bnorm_flags & flags) != 0) return invalid_arguments;
|
|
||||||
|
|
||||||
bd.flags = flags;
|
|
||||||
|
|
||||||
bool consistency = true
|
|
||||||
&& utils::one_of(bd.data_desc.ndims, 2, 4, 5);
|
|
||||||
if (bd.prop_kind == backward_data)
|
|
||||||
consistency = consistency
|
|
||||||
&& utils::one_of(bd.diff_data_desc.ndims, 2, 4, 5)
|
|
||||||
&& array_cmp(bd.diff_data_desc.dims, bd.data_desc.dims,
|
|
||||||
bd.diff_data_desc.ndims);
|
|
||||||
if (!consistency) return invalid_arguments;
|
|
||||||
|
|
||||||
*bnrm_desc = bd;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_batch_normalization_forward_desc_init(
|
|
||||||
batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
|
|
||||||
const memory_desc_t *data_desc, float epsilon, unsigned flags) {
|
|
||||||
if (!one_of(prop_kind, forward_training, forward_inference))
|
|
||||||
return invalid_arguments;
|
|
||||||
return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, nullptr,
|
|
||||||
epsilon, flags);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_batch_normalization_backward_desc_init(
|
|
||||||
batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
|
|
||||||
const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc,
|
|
||||||
float epsilon, unsigned flags) {
|
|
||||||
if (!one_of(prop_kind, backward, backward_data))
|
|
||||||
return invalid_arguments;
|
|
||||||
return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, diff_data_desc,
|
|
||||||
epsilon, flags);
|
|
||||||
}
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
@ -1,240 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef BATCH_NORMALIZATION_PD_HPP
|
|
||||||
#define BATCH_NORMALIZATION_PD_HPP
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
struct batch_normalization_fwd_pd_t;
|
|
||||||
|
|
||||||
struct batch_normalization_pd_t: public primitive_desc_t {
|
|
||||||
static constexpr auto base_pkind = primitive_kind::batch_normalization;
|
|
||||||
|
|
||||||
batch_normalization_pd_t(engine_t *engine,
|
|
||||||
const batch_normalization_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const batch_normalization_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: primitive_desc_t(engine, attr, base_pkind)
|
|
||||||
, desc_(*adesc)
|
|
||||||
, hint_fwd_pd_(hint_fwd_pd)
|
|
||||||
, data_md_(desc_.data_desc)
|
|
||||||
, stat_md_(desc_.mean_desc)
|
|
||||||
, scaleshift_md_(desc_.data_scaleshift_desc)
|
|
||||||
, ws_md_()
|
|
||||||
{}
|
|
||||||
|
|
||||||
const batch_normalization_desc_t *desc() const { return &desc_; }
|
|
||||||
virtual const op_desc_t *op_desc() const override
|
|
||||||
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
|
||||||
virtual void init_info() override { impl::init_info(this, this->info_); }
|
|
||||||
|
|
||||||
virtual status_t query(query_t what, int idx, void *result) const override {
|
|
||||||
switch (what) {
|
|
||||||
case query::batch_normalization_d:
|
|
||||||
*(const batch_normalization_desc_t**)result = desc(); break;
|
|
||||||
default: return primitive_desc_t::query(what, idx, result);
|
|
||||||
}
|
|
||||||
return status::success;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* common batch_normalization aux functions */
|
|
||||||
|
|
||||||
dim_t MB() const { return data_desc().dims[0]; }
|
|
||||||
dim_t C() const { return data_desc().dims[1]; }
|
|
||||||
dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
|
|
||||||
dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
|
|
||||||
dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
|
|
||||||
|
|
||||||
int ndims() const { return desc_.data_desc.ndims; }
|
|
||||||
|
|
||||||
bool stats_is_src() const { return desc_.flags & mkldnn_use_global_stats; }
|
|
||||||
bool use_scaleshift() const { return desc_.flags & mkldnn_use_scaleshift; }
|
|
||||||
bool use_global_stats() const
|
|
||||||
{ return desc_.flags & mkldnn_use_global_stats; }
|
|
||||||
bool fuse_bn_relu() const { return desc_.flags & mkldnn_fuse_bn_relu; }
|
|
||||||
bool with_relu_post_op() const {
|
|
||||||
const auto &p = this->attr()->post_ops_;
|
|
||||||
return p.len_ == 1 && p.entry_[0].is_relu(true, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_fwd() const {
|
|
||||||
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
|
||||||
prop_kind::forward_inference);
|
|
||||||
}
|
|
||||||
bool is_bwd() const { return !this->is_fwd(); }
|
|
||||||
bool is_training() const
|
|
||||||
{ return desc_.prop_kind == prop_kind::forward_training; }
|
|
||||||
|
|
||||||
bool has_zero_dim_memory() const
|
|
||||||
{ return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
batch_normalization_desc_t desc_;
|
|
||||||
const batch_normalization_fwd_pd_t *hint_fwd_pd_;
|
|
||||||
|
|
||||||
memory_desc_t data_md_;
|
|
||||||
memory_desc_t stat_md_;
|
|
||||||
memory_desc_t scaleshift_md_;
|
|
||||||
|
|
||||||
memory_desc_t ws_md_;
|
|
||||||
|
|
||||||
void init_default_ws(size_t bits_per_element) {
|
|
||||||
const auto data_mdw = memory_desc_wrapper(data_md_);
|
|
||||||
|
|
||||||
const dim_t data_nelems = data_mdw.nelems(true);
|
|
||||||
const dim_t bits_per_byte = 8;
|
|
||||||
const dims_t ws_sz = { (dim_t)utils::div_up(
|
|
||||||
data_nelems * bits_per_element, bits_per_byte) };
|
|
||||||
mkldnn_memory_desc_init_by_tag(&ws_md_, 1, ws_sz, impl::data_type::u8,
|
|
||||||
format_tag::x);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
const memory_desc_t &data_desc() const { return desc_.data_desc; }
|
|
||||||
};
|
|
||||||
|
|
||||||
struct batch_normalization_fwd_pd_t: public batch_normalization_pd_t {
|
|
||||||
typedef batch_normalization_fwd_pd_t base_class;
|
|
||||||
typedef batch_normalization_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
batch_normalization_fwd_pd_t(engine_t *engine,
|
|
||||||
const batch_normalization_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const batch_normalization_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (arg == MKLDNN_ARG_SRC) return arg_usage_t::input;
|
|
||||||
if (arg == MKLDNN_ARG_DST) return arg_usage_t::output;
|
|
||||||
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_MEAN, MKLDNN_ARG_VARIANCE)) {
|
|
||||||
if (stats_is_src()) return arg_usage_t::input;
|
|
||||||
if (!stats_is_src() && is_training()) return arg_usage_t::output;
|
|
||||||
return arg_usage_t::unused;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift())
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_WORKSPACE && is_training() && fuse_bn_relu())
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override {
|
|
||||||
if (index == 0) return &data_md_;
|
|
||||||
if (stats_is_src() && (index == 1 || index == 2)) return &stat_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *dst_md(int index = 0) const override {
|
|
||||||
if (index == 0) return &data_md_;
|
|
||||||
if (!stats_is_src() && is_training() && (index == 1 || index == 2))
|
|
||||||
return &stat_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *weights_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &scaleshift_md_ : nullptr; }
|
|
||||||
|
|
||||||
virtual const memory_desc_t *workspace_md(int index = 0) const override
|
|
||||||
{ return index == 0 && is_training() && fuse_bn_relu() ? &ws_md_ : nullptr; }
|
|
||||||
|
|
||||||
const memory_desc_t *stat_md() const
|
|
||||||
{ return stats_is_src() ? src_md(1) : dst_md(1); }
|
|
||||||
|
|
||||||
virtual int n_inputs() const override
|
|
||||||
{ return 1 + 2 * stats_is_src() + use_scaleshift(); }
|
|
||||||
virtual int n_outputs() const override
|
|
||||||
{ return 1 + (fuse_bn_relu() + 2 * (!stats_is_src())) * is_training(); }
|
|
||||||
};
|
|
||||||
|
|
||||||
struct batch_normalization_bwd_pd_t: public batch_normalization_pd_t {
|
|
||||||
typedef batch_normalization_bwd_pd_t base_class;
|
|
||||||
typedef batch_normalization_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
batch_normalization_bwd_pd_t(engine_t *engine,
|
|
||||||
const batch_normalization_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const batch_normalization_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, diff_data_md_(desc_.diff_data_desc)
|
|
||||||
, diff_scaleshift_md_(desc_.diff_data_scaleshift_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_MEAN,
|
|
||||||
MKLDNN_ARG_VARIANCE, MKLDNN_ARG_DIFF_DST))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift())
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_WORKSPACE && fuse_bn_relu())
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_SRC)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_SCALE_SHIFT && use_scaleshift())
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &data_md_ : index <= 2 ? &stat_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_data_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_data_md_ : nullptr; }
|
|
||||||
|
|
||||||
virtual const memory_desc_t *weights_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &scaleshift_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_weights_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_scaleshift_md_ : nullptr; }
|
|
||||||
|
|
||||||
virtual const memory_desc_t *workspace_md(int index = 0) const override
|
|
||||||
{ return index == 0 && fuse_bn_relu() ? &ws_md_ : nullptr; }
|
|
||||||
|
|
||||||
const memory_desc_t *stat_md() const { return src_md(1); }
|
|
||||||
|
|
||||||
virtual int n_inputs() const override
|
|
||||||
{ return 4 + use_scaleshift() + fuse_bn_relu(); }
|
|
||||||
virtual int n_outputs() const override
|
|
||||||
{ return 1 + (desc_.prop_kind == prop_kind::backward); }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t diff_data_md_;
|
|
||||||
memory_desc_t diff_scaleshift_md_;
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
550
thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp
vendored
550
thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp
vendored
@ -1,550 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef TYPE_MAPPING_HPP
|
|
||||||
#define TYPE_MAPPING_HPP
|
|
||||||
|
|
||||||
#include "mkldnn_types.h"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
// TODO: autogenerate this
|
|
||||||
|
|
||||||
using dim_t = mkldnn_dim_t;
|
|
||||||
using dims_t = mkldnn_dims_t;
|
|
||||||
using stride_t = mkldnn_dim_t;
|
|
||||||
using strides_t = mkldnn_strides_t;
|
|
||||||
|
|
||||||
using status_t = mkldnn_status_t;
|
|
||||||
namespace status {
|
|
||||||
const status_t success = mkldnn_success;
|
|
||||||
const status_t out_of_memory = mkldnn_out_of_memory;
|
|
||||||
const status_t try_again = mkldnn_try_again;
|
|
||||||
const status_t invalid_arguments = mkldnn_invalid_arguments;
|
|
||||||
const status_t not_ready = mkldnn_not_ready;
|
|
||||||
const status_t unimplemented = mkldnn_unimplemented;
|
|
||||||
const status_t iterator_ends = mkldnn_iterator_ends;
|
|
||||||
const status_t runtime_error = mkldnn_runtime_error;
|
|
||||||
const status_t not_required = mkldnn_not_required;
|
|
||||||
}
|
|
||||||
|
|
||||||
using prop_kind_t = mkldnn_prop_kind_t;
|
|
||||||
namespace prop_kind {
|
|
||||||
const prop_kind_t undef = mkldnn_prop_kind_undef;
|
|
||||||
const prop_kind_t forward_training = mkldnn_forward_training;
|
|
||||||
const prop_kind_t forward_inference = mkldnn_forward_inference;
|
|
||||||
const prop_kind_t forward_scoring = mkldnn_forward_scoring;
|
|
||||||
const prop_kind_t forward = mkldnn_forward;
|
|
||||||
const prop_kind_t backward = mkldnn_backward;
|
|
||||||
const prop_kind_t backward_data = mkldnn_backward_data;
|
|
||||||
const prop_kind_t backward_weights = mkldnn_backward_weights;
|
|
||||||
const prop_kind_t backward_bias = mkldnn_backward_bias;
|
|
||||||
}
|
|
||||||
|
|
||||||
using alg_kind_t = mkldnn_alg_kind_t;
|
|
||||||
namespace alg_kind {
|
|
||||||
const alg_kind_t undef = mkldnn_alg_kind_undef;
|
|
||||||
const alg_kind_t convolution_auto = mkldnn_convolution_auto;
|
|
||||||
const alg_kind_t convolution_direct = mkldnn_convolution_direct;
|
|
||||||
const alg_kind_t convolution_winograd = mkldnn_convolution_winograd;
|
|
||||||
const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct;
|
|
||||||
const alg_kind_t deconvolution_winograd = mkldnn_deconvolution_winograd;
|
|
||||||
const alg_kind_t eltwise_relu = mkldnn_eltwise_relu;
|
|
||||||
const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh;
|
|
||||||
const alg_kind_t eltwise_elu = mkldnn_eltwise_elu;
|
|
||||||
const alg_kind_t eltwise_square = mkldnn_eltwise_square;
|
|
||||||
const alg_kind_t eltwise_abs = mkldnn_eltwise_abs;
|
|
||||||
const alg_kind_t eltwise_sqrt = mkldnn_eltwise_sqrt;
|
|
||||||
const alg_kind_t eltwise_linear = mkldnn_eltwise_linear;
|
|
||||||
const alg_kind_t eltwise_bounded_relu = mkldnn_eltwise_bounded_relu;
|
|
||||||
const alg_kind_t eltwise_soft_relu = mkldnn_eltwise_soft_relu;
|
|
||||||
const alg_kind_t eltwise_logistic = mkldnn_eltwise_logistic;
|
|
||||||
const alg_kind_t pooling_max = mkldnn_pooling_max;
|
|
||||||
const alg_kind_t pooling_avg = mkldnn_pooling_avg;
|
|
||||||
const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding;
|
|
||||||
const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding;
|
|
||||||
const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels;
|
|
||||||
const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel;
|
|
||||||
const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn;
|
|
||||||
const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm;
|
|
||||||
const alg_kind_t vanilla_gru = mkldnn_vanilla_gru;
|
|
||||||
const alg_kind_t gru_linear_before_reset = mkldnn_gru_linear_before_reset;
|
|
||||||
}
|
|
||||||
|
|
||||||
using data_type_t = mkldnn_data_type_t;
|
|
||||||
namespace data_type {
|
|
||||||
const data_type_t undef = mkldnn_data_type_undef;
|
|
||||||
const data_type_t f32 = mkldnn_f32;
|
|
||||||
const data_type_t s32 = mkldnn_s32;
|
|
||||||
const data_type_t s8 = mkldnn_s8;
|
|
||||||
const data_type_t u8 = mkldnn_u8;
|
|
||||||
}
|
|
||||||
|
|
||||||
using scratchpad_mode_t = mkldnn_scratchpad_mode_t;
|
|
||||||
namespace scratchpad_mode {
|
|
||||||
const scratchpad_mode_t library = mkldnn_scratchpad_mode_library;
|
|
||||||
const scratchpad_mode_t user = mkldnn_scratchpad_mode_user;
|
|
||||||
}
|
|
||||||
|
|
||||||
using rnn_packed_format_t = mkldnn_rnn_packed_memory_format_t;
|
|
||||||
namespace rnn_packed_format {
|
|
||||||
const rnn_packed_format_t undef = mkldnn_packed_format_undef;
|
|
||||||
const rnn_packed_format_t ldigo_p = mkldnn_ldigo_p;
|
|
||||||
const rnn_packed_format_t ldgoi_p = mkldnn_ldgoi_p;
|
|
||||||
}
|
|
||||||
|
|
||||||
using format_kind_t = mkldnn_format_kind_t;
|
|
||||||
namespace format_kind {
|
|
||||||
const format_kind_t undef = mkldnn_format_kind_undef;
|
|
||||||
const format_kind_t any = mkldnn_format_kind_any;
|
|
||||||
const format_kind_t blocked = mkldnn_blocked;
|
|
||||||
const format_kind_t wino = mkldnn_format_kind_wino;
|
|
||||||
const format_kind_t rnn_packed = mkldnn_format_kind_rnn_packed;
|
|
||||||
}
|
|
||||||
|
|
||||||
using format_tag_t = mkldnn_format_tag_t;
|
|
||||||
namespace format_tag {
|
|
||||||
const format_tag_t undef = mkldnn_format_tag_undef;
|
|
||||||
const format_tag_t any = mkldnn_format_tag_any;
|
|
||||||
const format_tag_t a = mkldnn_a;
|
|
||||||
const format_tag_t ab = mkldnn_ab;
|
|
||||||
const format_tag_t abc = mkldnn_abc;
|
|
||||||
const format_tag_t abcd = mkldnn_abcd;
|
|
||||||
const format_tag_t abcde = mkldnn_abcde;
|
|
||||||
const format_tag_t abcdef = mkldnn_abcdef;
|
|
||||||
const format_tag_t abdec = mkldnn_abdec;
|
|
||||||
const format_tag_t acb = mkldnn_acb;
|
|
||||||
const format_tag_t acbde = mkldnn_acbde;
|
|
||||||
const format_tag_t acdb = mkldnn_acdb;
|
|
||||||
const format_tag_t acdeb = mkldnn_acdeb;
|
|
||||||
const format_tag_t ba = mkldnn_ba;
|
|
||||||
const format_tag_t bac = mkldnn_bac;
|
|
||||||
const format_tag_t bacd = mkldnn_bacd;
|
|
||||||
const format_tag_t bcda = mkldnn_bcda;
|
|
||||||
const format_tag_t cba = mkldnn_cba;
|
|
||||||
const format_tag_t cdba = mkldnn_cdba;
|
|
||||||
const format_tag_t cdeba = mkldnn_cdeba;
|
|
||||||
const format_tag_t decab = mkldnn_decab;
|
|
||||||
const format_tag_t Abc16a = mkldnn_Abc16a;
|
|
||||||
const format_tag_t ABc16a16b = mkldnn_ABc16a16b;
|
|
||||||
const format_tag_t aBc16b = mkldnn_aBc16b;
|
|
||||||
const format_tag_t ABc16b16a = mkldnn_ABc16b16a;
|
|
||||||
const format_tag_t Abc4a = mkldnn_Abc4a;
|
|
||||||
const format_tag_t aBc4b = mkldnn_aBc4b;
|
|
||||||
const format_tag_t ABc4b16a4b = mkldnn_ABc4b16a4b;
|
|
||||||
const format_tag_t ABc4b4a = mkldnn_ABc4b4a;
|
|
||||||
const format_tag_t ABc8a16b2a = mkldnn_ABc8a16b2a;
|
|
||||||
const format_tag_t ABc8a8b = mkldnn_ABc8a8b;
|
|
||||||
const format_tag_t aBc8b = mkldnn_aBc8b;
|
|
||||||
const format_tag_t ABc8b16a2b = mkldnn_ABc8b16a2b;
|
|
||||||
const format_tag_t ABc8b8a = mkldnn_ABc8b8a;
|
|
||||||
const format_tag_t Abcd16a = mkldnn_Abcd16a;
|
|
||||||
const format_tag_t ABcd16a16b = mkldnn_ABcd16a16b;
|
|
||||||
const format_tag_t aBcd16b = mkldnn_aBcd16b;
|
|
||||||
const format_tag_t ABcd16b16a = mkldnn_ABcd16b16a;
|
|
||||||
const format_tag_t aBCd16b16c = mkldnn_aBCd16b16c;
|
|
||||||
const format_tag_t aBCd16c16b = mkldnn_aBCd16c16b;
|
|
||||||
const format_tag_t Abcd4a = mkldnn_Abcd4a;
|
|
||||||
const format_tag_t aBcd4b = mkldnn_aBcd4b;
|
|
||||||
const format_tag_t ABcd4b16a4b = mkldnn_ABcd4b16a4b;
|
|
||||||
const format_tag_t ABcd4b4a = mkldnn_ABcd4b4a;
|
|
||||||
const format_tag_t aBCd4c16b4c = mkldnn_aBCd4c16b4c;
|
|
||||||
const format_tag_t aBCd4c4b = mkldnn_aBCd4c4b;
|
|
||||||
const format_tag_t ABcd8a16b2a = mkldnn_ABcd8a16b2a;
|
|
||||||
const format_tag_t ABcd8a8b = mkldnn_ABcd8a8b;
|
|
||||||
const format_tag_t aBcd8b = mkldnn_aBcd8b;
|
|
||||||
const format_tag_t ABcd8b16a2b = mkldnn_ABcd8b16a2b;
|
|
||||||
const format_tag_t aBCd8b16c2b = mkldnn_aBCd8b16c2b;
|
|
||||||
const format_tag_t ABcd8b8a = mkldnn_ABcd8b8a;
|
|
||||||
const format_tag_t aBCd8b8c = mkldnn_aBCd8b8c;
|
|
||||||
const format_tag_t aBCd8c16b2c = mkldnn_aBCd8c16b2c;
|
|
||||||
const format_tag_t aBCd8c8b = mkldnn_aBCd8c8b;
|
|
||||||
const format_tag_t Abcde16a = mkldnn_Abcde16a;
|
|
||||||
const format_tag_t ABcde16a16b = mkldnn_ABcde16a16b;
|
|
||||||
const format_tag_t aBcde16b = mkldnn_aBcde16b;
|
|
||||||
const format_tag_t ABcde16b16a = mkldnn_ABcde16b16a;
|
|
||||||
const format_tag_t aBCde16b16c = mkldnn_aBCde16b16c;
|
|
||||||
const format_tag_t aBCde16c16b = mkldnn_aBCde16c16b;
|
|
||||||
const format_tag_t aBCde2c8b4c = mkldnn_aBCde2c8b4c;
|
|
||||||
const format_tag_t Abcde4a = mkldnn_Abcde4a;
|
|
||||||
const format_tag_t aBcde4b = mkldnn_aBcde4b;
|
|
||||||
const format_tag_t ABcde4b4a = mkldnn_ABcde4b4a;
|
|
||||||
const format_tag_t aBCde4b4c = mkldnn_aBCde4b4c;
|
|
||||||
const format_tag_t aBCde4c16b4c = mkldnn_aBCde4c16b4c;
|
|
||||||
const format_tag_t aBCde4c4b = mkldnn_aBCde4c4b;
|
|
||||||
const format_tag_t Abcde8a = mkldnn_Abcde8a;
|
|
||||||
const format_tag_t ABcde8a8b = mkldnn_ABcde8a8b;
|
|
||||||
const format_tag_t aBcde8b = mkldnn_aBcde8b;
|
|
||||||
const format_tag_t ABcde8b16a2b = mkldnn_ABcde8b16a2b;
|
|
||||||
const format_tag_t aBCde8b16c2b = mkldnn_aBCde8b16c2b;
|
|
||||||
const format_tag_t ABcde8b8a = mkldnn_ABcde8b8a;
|
|
||||||
const format_tag_t aBCde8b8c = mkldnn_aBCde8b8c;
|
|
||||||
const format_tag_t aBCde8c16b2c = mkldnn_aBCde8c16b2c;
|
|
||||||
const format_tag_t aBCde8c8b = mkldnn_aBCde8c8b;
|
|
||||||
const format_tag_t aBcdef16b = mkldnn_aBcdef16b;
|
|
||||||
const format_tag_t aBCdef16b16c = mkldnn_aBCdef16b16c;
|
|
||||||
const format_tag_t aBCdef16c16b = mkldnn_aBCdef16c16b;
|
|
||||||
const format_tag_t aBcdef4b = mkldnn_aBcdef4b;
|
|
||||||
const format_tag_t aBCdef4c4b = mkldnn_aBCdef4c4b;
|
|
||||||
const format_tag_t aBCdef8b8c = mkldnn_aBCdef8b8c;
|
|
||||||
const format_tag_t aBCdef8c16b2c = mkldnn_aBCdef8c16b2c;
|
|
||||||
const format_tag_t aBCdef8c8b = mkldnn_aBCdef8c8b;
|
|
||||||
const format_tag_t aBdc16b = mkldnn_aBdc16b;
|
|
||||||
const format_tag_t aBdc4b = mkldnn_aBdc4b;
|
|
||||||
const format_tag_t aBdc8b = mkldnn_aBdc8b;
|
|
||||||
const format_tag_t aBdec16b = mkldnn_aBdec16b;
|
|
||||||
const format_tag_t aBdec4b = mkldnn_aBdec4b;
|
|
||||||
const format_tag_t aBdec8b = mkldnn_aBdec8b;
|
|
||||||
const format_tag_t aBdefc16b = mkldnn_aBdefc16b;
|
|
||||||
const format_tag_t aBdefc4b = mkldnn_aBdefc4b;
|
|
||||||
const format_tag_t aBdefc8b = mkldnn_aBdefc8b;
|
|
||||||
const format_tag_t Acb16a = mkldnn_Acb16a;
|
|
||||||
const format_tag_t Acb4a = mkldnn_Acb4a;
|
|
||||||
const format_tag_t Acb8a = mkldnn_Acb8a;
|
|
||||||
const format_tag_t aCBd16b16c = mkldnn_aCBd16b16c;
|
|
||||||
const format_tag_t aCBde16b16c = mkldnn_aCBde16b16c;
|
|
||||||
const format_tag_t Acdb16a = mkldnn_Acdb16a;
|
|
||||||
const format_tag_t Acdb4a = mkldnn_Acdb4a;
|
|
||||||
const format_tag_t Acdb8a = mkldnn_Acdb8a;
|
|
||||||
const format_tag_t Acdeb16a = mkldnn_Acdeb16a;
|
|
||||||
const format_tag_t Acdeb4a = mkldnn_Acdeb4a;
|
|
||||||
const format_tag_t Acdeb8a = mkldnn_Acdeb8a;
|
|
||||||
const format_tag_t BAc16a16b = mkldnn_BAc16a16b;
|
|
||||||
const format_tag_t BAcd16a16b = mkldnn_BAcd16a16b;
|
|
||||||
const format_tag_t last = mkldnn_format_tag_last;
|
|
||||||
|
|
||||||
const format_tag_t x = mkldnn_x;
|
|
||||||
const format_tag_t nc = mkldnn_nc;
|
|
||||||
const format_tag_t cn = mkldnn_cn;
|
|
||||||
const format_tag_t ncw = mkldnn_ncw;
|
|
||||||
const format_tag_t nwc = mkldnn_nwc;
|
|
||||||
const format_tag_t nchw = mkldnn_nchw;
|
|
||||||
const format_tag_t nhwc = mkldnn_nhwc;
|
|
||||||
const format_tag_t chwn = mkldnn_chwn;
|
|
||||||
const format_tag_t ncdhw = mkldnn_ncdhw;
|
|
||||||
const format_tag_t ndhwc = mkldnn_ndhwc;
|
|
||||||
const format_tag_t oi = mkldnn_oi;
|
|
||||||
const format_tag_t io = mkldnn_io;
|
|
||||||
const format_tag_t oiw = mkldnn_oiw;
|
|
||||||
const format_tag_t wio = mkldnn_wio;
|
|
||||||
const format_tag_t oihw = mkldnn_oihw;
|
|
||||||
const format_tag_t hwio = mkldnn_hwio;
|
|
||||||
const format_tag_t ihwo = mkldnn_ihwo;
|
|
||||||
const format_tag_t iohw = mkldnn_iohw;
|
|
||||||
const format_tag_t oidhw = mkldnn_oidhw;
|
|
||||||
const format_tag_t dhwio = mkldnn_dhwio;
|
|
||||||
const format_tag_t goiw = mkldnn_goiw;
|
|
||||||
const format_tag_t goihw = mkldnn_goihw;
|
|
||||||
const format_tag_t hwigo = mkldnn_hwigo;
|
|
||||||
const format_tag_t giohw = mkldnn_giohw;
|
|
||||||
const format_tag_t goidhw = mkldnn_goidhw;
|
|
||||||
const format_tag_t tnc = mkldnn_tnc;
|
|
||||||
const format_tag_t ntc = mkldnn_ntc;
|
|
||||||
const format_tag_t ldsnc = mkldnn_ldsnc;
|
|
||||||
const format_tag_t ldigo = mkldnn_ldigo;
|
|
||||||
const format_tag_t ldgoi = mkldnn_ldgoi;
|
|
||||||
const format_tag_t ldgo = mkldnn_ldgo;
|
|
||||||
const format_tag_t nCdhw16c = mkldnn_nCdhw16c;
|
|
||||||
const format_tag_t nCdhw4c = mkldnn_nCdhw4c;
|
|
||||||
const format_tag_t nCdhw8c = mkldnn_nCdhw8c;
|
|
||||||
const format_tag_t nChw16c = mkldnn_nChw16c;
|
|
||||||
const format_tag_t nChw4c = mkldnn_nChw4c;
|
|
||||||
const format_tag_t nChw8c = mkldnn_nChw8c;
|
|
||||||
const format_tag_t nCw16c = mkldnn_nCw16c;
|
|
||||||
const format_tag_t nCw4c = mkldnn_nCw4c;
|
|
||||||
const format_tag_t nCw8c = mkldnn_nCw8c;
|
|
||||||
const format_tag_t IOw16o16i = mkldnn_IOw16o16i;
|
|
||||||
const format_tag_t OIw16i16o = mkldnn_OIw16i16o;
|
|
||||||
const format_tag_t OIw16o16i = mkldnn_OIw16o16i;
|
|
||||||
const format_tag_t Oiw16o = mkldnn_Oiw16o;
|
|
||||||
const format_tag_t OIw4i16o4i = mkldnn_OIw4i16o4i;
|
|
||||||
const format_tag_t OIw4i4o = mkldnn_OIw4i4o;
|
|
||||||
const format_tag_t Oiw4o = mkldnn_Oiw4o;
|
|
||||||
const format_tag_t OIw8i16o2i = mkldnn_OIw8i16o2i;
|
|
||||||
const format_tag_t OIw8i8o = mkldnn_OIw8i8o;
|
|
||||||
const format_tag_t OIw8o16i2o = mkldnn_OIw8o16i2o;
|
|
||||||
const format_tag_t OIw8o8i = mkldnn_OIw8o8i;
|
|
||||||
const format_tag_t Owi16o = mkldnn_Owi16o;
|
|
||||||
const format_tag_t Owi4o = mkldnn_Owi4o;
|
|
||||||
const format_tag_t Owi8o = mkldnn_Owi8o;
|
|
||||||
const format_tag_t IOhw16o16i = mkldnn_IOhw16o16i;
|
|
||||||
const format_tag_t Ohwi16o = mkldnn_Ohwi16o;
|
|
||||||
const format_tag_t Ohwi4o = mkldnn_Ohwi4o;
|
|
||||||
const format_tag_t Ohwi8o = mkldnn_Ohwi8o;
|
|
||||||
const format_tag_t OIhw16i16o = mkldnn_OIhw16i16o;
|
|
||||||
const format_tag_t OIhw16o16i = mkldnn_OIhw16o16i;
|
|
||||||
const format_tag_t Oihw16o = mkldnn_Oihw16o;
|
|
||||||
const format_tag_t OIhw4i16o4i = mkldnn_OIhw4i16o4i;
|
|
||||||
const format_tag_t OIhw4i4o = mkldnn_OIhw4i4o;
|
|
||||||
const format_tag_t Oihw4o = mkldnn_Oihw4o;
|
|
||||||
const format_tag_t OIhw8i16o2i = mkldnn_OIhw8i16o2i;
|
|
||||||
const format_tag_t OIhw8i8o = mkldnn_OIhw8i8o;
|
|
||||||
const format_tag_t OIhw8o16i2o = mkldnn_OIhw8o16i2o;
|
|
||||||
const format_tag_t OIhw8o8i = mkldnn_OIhw8o8i;
|
|
||||||
const format_tag_t Odhwi16o = mkldnn_Odhwi16o;
|
|
||||||
const format_tag_t Odhwi4o = mkldnn_Odhwi4o;
|
|
||||||
const format_tag_t Odhwi8o = mkldnn_Odhwi8o;
|
|
||||||
const format_tag_t OIdhw16i16o = mkldnn_OIdhw16i16o;
|
|
||||||
const format_tag_t OIdhw16o16i = mkldnn_OIdhw16o16i;
|
|
||||||
const format_tag_t Oidhw16o = mkldnn_Oidhw16o;
|
|
||||||
const format_tag_t OIdhw4i4o = mkldnn_OIdhw4i4o;
|
|
||||||
const format_tag_t Oidhw4o = mkldnn_Oidhw4o;
|
|
||||||
const format_tag_t OIdhw8i16o2i = mkldnn_OIdhw8i16o2i;
|
|
||||||
const format_tag_t OIdhw8i8o = mkldnn_OIdhw8i8o;
|
|
||||||
const format_tag_t OIdhw8o8i = mkldnn_OIdhw8o8i;
|
|
||||||
const format_tag_t gIOw16o16i = mkldnn_gIOw16o16i;
|
|
||||||
const format_tag_t Goiw16g = mkldnn_Goiw16g;
|
|
||||||
const format_tag_t gOIw16i16o = mkldnn_gOIw16i16o;
|
|
||||||
const format_tag_t gOIw16o16i = mkldnn_gOIw16o16i;
|
|
||||||
const format_tag_t gOiw16o = mkldnn_gOiw16o;
|
|
||||||
const format_tag_t gOIw4i16o4i = mkldnn_gOIw4i16o4i;
|
|
||||||
const format_tag_t gOIw4i4o = mkldnn_gOIw4i4o;
|
|
||||||
const format_tag_t gOiw4o = mkldnn_gOiw4o;
|
|
||||||
const format_tag_t gOIw8i16o2i = mkldnn_gOIw8i16o2i;
|
|
||||||
const format_tag_t gOIw8i8o = mkldnn_gOIw8i8o;
|
|
||||||
const format_tag_t gOIw8o16i2o = mkldnn_gOIw8o16i2o;
|
|
||||||
const format_tag_t gOIw8o8i = mkldnn_gOIw8o8i;
|
|
||||||
const format_tag_t gOwi16o = mkldnn_gOwi16o;
|
|
||||||
const format_tag_t gOwi4o = mkldnn_gOwi4o;
|
|
||||||
const format_tag_t gOwi8o = mkldnn_gOwi8o;
|
|
||||||
const format_tag_t gIOhw16o16i = mkldnn_gIOhw16o16i;
|
|
||||||
const format_tag_t gOhwi16o = mkldnn_gOhwi16o;
|
|
||||||
const format_tag_t gOhwi4o = mkldnn_gOhwi4o;
|
|
||||||
const format_tag_t gOhwi8o = mkldnn_gOhwi8o;
|
|
||||||
const format_tag_t Goihw16g = mkldnn_Goihw16g;
|
|
||||||
const format_tag_t gOIhw16i16o = mkldnn_gOIhw16i16o;
|
|
||||||
const format_tag_t gOIhw16o16i = mkldnn_gOIhw16o16i;
|
|
||||||
const format_tag_t gOihw16o = mkldnn_gOihw16o;
|
|
||||||
const format_tag_t gOIhw2i8o4i = mkldnn_gOIhw2i8o4i;
|
|
||||||
const format_tag_t gOIhw4i16o4i = mkldnn_gOIhw4i16o4i;
|
|
||||||
const format_tag_t gOIhw4i4o = mkldnn_gOIhw4i4o;
|
|
||||||
const format_tag_t gOIhw4o4i = mkldnn_gOIhw4o4i;
|
|
||||||
const format_tag_t gOihw4o = mkldnn_gOihw4o;
|
|
||||||
const format_tag_t Goihw8g = mkldnn_Goihw8g;
|
|
||||||
const format_tag_t gOIhw8i16o2i = mkldnn_gOIhw8i16o2i;
|
|
||||||
const format_tag_t gOIhw8i8o = mkldnn_gOIhw8i8o;
|
|
||||||
const format_tag_t gOIhw8o16i2o = mkldnn_gOIhw8o16i2o;
|
|
||||||
const format_tag_t gOIhw8o8i = mkldnn_gOIhw8o8i;
|
|
||||||
const format_tag_t gOdhwi16o = mkldnn_gOdhwi16o;
|
|
||||||
const format_tag_t gOdhwi4o = mkldnn_gOdhwi4o;
|
|
||||||
const format_tag_t gOdhwi8o = mkldnn_gOdhwi8o;
|
|
||||||
const format_tag_t gOIdhw16i16o = mkldnn_gOIdhw16i16o;
|
|
||||||
const format_tag_t gOIdhw16o16i = mkldnn_gOIdhw16o16i;
|
|
||||||
const format_tag_t gOidhw16o = mkldnn_gOidhw16o;
|
|
||||||
const format_tag_t gOIdhw4i4o = mkldnn_gOIdhw4i4o;
|
|
||||||
const format_tag_t gOidhw4o = mkldnn_gOidhw4o;
|
|
||||||
const format_tag_t gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i;
|
|
||||||
const format_tag_t gOIdhw8i8o = mkldnn_gOIdhw8i8o;
|
|
||||||
const format_tag_t gOIdhw8o8i = mkldnn_gOIdhw8o8i;
|
|
||||||
}
|
|
||||||
|
|
||||||
using memory_extra_flags_t = mkldnn_memory_extra_flags_t;
|
|
||||||
namespace memory_extra_flags {
|
|
||||||
const memory_extra_flags_t none = mkldnn_memory_extra_flag_none;
|
|
||||||
const memory_extra_flags_t compensation_conv_s8s8 = mkldnn_memory_extra_flag_compensation_conv_s8s8;
|
|
||||||
const memory_extra_flags_t scale_adjust = mkldnn_memory_extra_flag_scale_adjust;
|
|
||||||
}
|
|
||||||
|
|
||||||
using padding_kind_t = mkldnn_padding_kind_t;
|
|
||||||
namespace padding_kind {
|
|
||||||
const padding_kind_t padding_zero = mkldnn_padding_zero;
|
|
||||||
}
|
|
||||||
|
|
||||||
using engine_kind_t = mkldnn_engine_kind_t;
|
|
||||||
namespace engine_kind {
|
|
||||||
const engine_kind_t any_engine = mkldnn_any_engine;
|
|
||||||
const engine_kind_t cpu = mkldnn_cpu;
|
|
||||||
}
|
|
||||||
|
|
||||||
using primitive_kind_t = mkldnn_primitive_kind_t;
|
|
||||||
namespace primitive_kind {
|
|
||||||
const primitive_kind_t undefined = mkldnn_undefined_primitive;
|
|
||||||
const primitive_kind_t reorder = mkldnn_reorder;
|
|
||||||
const primitive_kind_t concat = mkldnn_concat;
|
|
||||||
const primitive_kind_t sum = mkldnn_sum;
|
|
||||||
const primitive_kind_t convolution = mkldnn_convolution;
|
|
||||||
const primitive_kind_t deconvolution = mkldnn_deconvolution;
|
|
||||||
const primitive_kind_t shuffle = mkldnn_shuffle;
|
|
||||||
const primitive_kind_t eltwise = mkldnn_eltwise;
|
|
||||||
const primitive_kind_t softmax = mkldnn_softmax;
|
|
||||||
const primitive_kind_t pooling = mkldnn_pooling;
|
|
||||||
const primitive_kind_t lrn = mkldnn_lrn;
|
|
||||||
const primitive_kind_t batch_normalization = mkldnn_batch_normalization;
|
|
||||||
const primitive_kind_t inner_product = mkldnn_inner_product;
|
|
||||||
const primitive_kind_t rnn = mkldnn_rnn;
|
|
||||||
}
|
|
||||||
|
|
||||||
using query_t = mkldnn_query_t;
|
|
||||||
namespace query {
|
|
||||||
const query_t undef = mkldnn_query_undef;
|
|
||||||
|
|
||||||
const query_t engine = mkldnn_query_engine;
|
|
||||||
const query_t primitive_kind = mkldnn_query_primitive_kind;
|
|
||||||
|
|
||||||
const query_t num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32;
|
|
||||||
const query_t num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32;
|
|
||||||
|
|
||||||
const query_t time_estimate_f64 = mkldnn_query_time_estimate_f64;
|
|
||||||
const query_t memory_consumption_s64 = mkldnn_query_memory_consumption_s64;
|
|
||||||
|
|
||||||
const query_t scratchpad_engine = mkldnn_query_scratchpad_engine;
|
|
||||||
|
|
||||||
const query_t impl_info_str = mkldnn_query_impl_info_str;
|
|
||||||
|
|
||||||
const query_t some_d = mkldnn_query_some_d;
|
|
||||||
const query_t op_d = mkldnn_query_op_d;
|
|
||||||
const query_t convolution_d = mkldnn_query_convolution_d;
|
|
||||||
const query_t deconvolution_d = mkldnn_query_deconvolution_d;
|
|
||||||
const query_t shuffle_d = mkldnn_query_shuffle_d;
|
|
||||||
const query_t eltwise_d = mkldnn_query_eltwise_d;
|
|
||||||
const query_t softmax_d = mkldnn_query_softmax_d;
|
|
||||||
const query_t pooling_d = mkldnn_query_pooling_d;
|
|
||||||
const query_t lrn_d = mkldnn_query_lrn_d;
|
|
||||||
const query_t batch_normalization_d = mkldnn_query_batch_normalization_d;
|
|
||||||
const query_t inner_product_d = mkldnn_query_inner_product_d;
|
|
||||||
const query_t rnn_d = mkldnn_query_rnn_d;
|
|
||||||
|
|
||||||
const query_t some_md = mkldnn_query_some_md;
|
|
||||||
const query_t src_md = mkldnn_query_src_md;
|
|
||||||
const query_t diff_src_md = mkldnn_query_diff_src_md;
|
|
||||||
const query_t weights_md = mkldnn_query_weights_md;
|
|
||||||
const query_t diff_weights_md = mkldnn_query_diff_weights_md;
|
|
||||||
const query_t dst_md = mkldnn_query_dst_md;
|
|
||||||
const query_t diff_dst_md = mkldnn_query_diff_dst_md;
|
|
||||||
|
|
||||||
const query_t workspace_md = mkldnn_query_workspace_md;
|
|
||||||
const query_t scratchpad_md = mkldnn_query_scratchpad_md;
|
|
||||||
}
|
|
||||||
|
|
||||||
using blocking_desc_t = mkldnn_blocking_desc_t;
|
|
||||||
using rnn_packed_desc_t = mkldnn_rnn_packed_desc_t;
|
|
||||||
using wino_desc_t = mkldnn_wino_desc_t;
|
|
||||||
using memory_extra_desc_t = mkldnn_memory_extra_desc_t;
|
|
||||||
using memory_desc_t = mkldnn_memory_desc_t;
|
|
||||||
using convolution_desc_t = mkldnn_convolution_desc_t;
|
|
||||||
using deconvolution_desc_t = mkldnn_deconvolution_desc_t;
|
|
||||||
using shuffle_desc_t = mkldnn_shuffle_desc_t;
|
|
||||||
using pooling_desc_t = mkldnn_pooling_desc_t;
|
|
||||||
using eltwise_desc_t = mkldnn_eltwise_desc_t;
|
|
||||||
using softmax_desc_t = mkldnn_softmax_desc_t;
|
|
||||||
using lrn_desc_t = mkldnn_lrn_desc_t;
|
|
||||||
using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t;
|
|
||||||
using inner_product_desc_t = mkldnn_inner_product_desc_t;
|
|
||||||
|
|
||||||
using rnn_direction_t = mkldnn_rnn_direction_t;
|
|
||||||
using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t;
|
|
||||||
using rnn_desc_t = mkldnn_rnn_desc_t;
|
|
||||||
|
|
||||||
/* C op_desc_t, which eventually are just (void*) */
|
|
||||||
using c_op_desc_t = mkldnn_op_desc_t;
|
|
||||||
using const_c_op_desc_t = const_mkldnn_op_desc_t;
|
|
||||||
|
|
||||||
struct op_desc_t {
|
|
||||||
union {
|
|
||||||
primitive_kind_t kind;
|
|
||||||
convolution_desc_t convolution;
|
|
||||||
deconvolution_desc_t deconvolution;
|
|
||||||
shuffle_desc_t shuffle;
|
|
||||||
pooling_desc_t pooling;
|
|
||||||
eltwise_desc_t eltwise;
|
|
||||||
softmax_desc_t softmax;
|
|
||||||
lrn_desc_t lrn;
|
|
||||||
batch_normalization_desc_t batch_normalization;
|
|
||||||
inner_product_desc_t inner_product;
|
|
||||||
rnn_desc_t rnn;
|
|
||||||
};
|
|
||||||
|
|
||||||
op_desc_t(const primitive_kind_t &_): kind(_) {}
|
|
||||||
|
|
||||||
# define DECL_CTOR_AND_CONVERTERS(c_type, name) \
|
|
||||||
op_desc_t(const c_type &_): name(_) {} \
|
|
||||||
static op_desc_t *convert_from_c(c_type *_) \
|
|
||||||
{ return reinterpret_cast<op_desc_t*>(_); } \
|
|
||||||
static const op_desc_t *convert_from_c(const c_type *_) \
|
|
||||||
{ return reinterpret_cast<const op_desc_t*>(_); }
|
|
||||||
|
|
||||||
DECL_CTOR_AND_CONVERTERS(convolution_desc_t, convolution);
|
|
||||||
DECL_CTOR_AND_CONVERTERS(shuffle_desc_t, shuffle);
|
|
||||||
DECL_CTOR_AND_CONVERTERS(pooling_desc_t, pooling);
|
|
||||||
DECL_CTOR_AND_CONVERTERS(eltwise_desc_t, eltwise);
|
|
||||||
DECL_CTOR_AND_CONVERTERS(softmax_desc_t, softmax);
|
|
||||||
DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn);
|
|
||||||
DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t, batch_normalization);
|
|
||||||
DECL_CTOR_AND_CONVERTERS(inner_product_desc_t, inner_product);
|
|
||||||
DECL_CTOR_AND_CONVERTERS(rnn_desc_t, rnn);
|
|
||||||
|
|
||||||
# undef DECL_CTOR_AND_CONVERTERS
|
|
||||||
};
|
|
||||||
|
|
||||||
using engine_t = mkldnn_engine;
|
|
||||||
using primitive_desc_iterator_t = mkldnn_primitive_desc_iterator;
|
|
||||||
using primitive_desc_t = mkldnn_primitive_desc;
|
|
||||||
using primitive_attr_t = mkldnn_primitive_attr;
|
|
||||||
using post_ops_t = mkldnn_post_ops;
|
|
||||||
using memory_t = mkldnn_memory;
|
|
||||||
using primitive_t = mkldnn_primitive;
|
|
||||||
|
|
||||||
using primitive_arg_index_t = int;
|
|
||||||
|
|
||||||
using stream_flags_t = mkldnn_stream_flags_t;
|
|
||||||
namespace stream_flags {
|
|
||||||
const stream_flags_t default_flags = mkldnn_stream_default_flags;
|
|
||||||
}
|
|
||||||
using stream_t = mkldnn_stream;
|
|
||||||
|
|
||||||
/* forward declaration of the internal primitive_desc types */
|
|
||||||
struct batch_normalization_bwd_pd_t;
|
|
||||||
struct batch_normalization_fwd_pd_t;
|
|
||||||
struct batch_normalization_pd_t;
|
|
||||||
struct concat_pd_t;
|
|
||||||
struct convolution_bwd_data_pd_t;
|
|
||||||
struct convolution_bwd_weights_pd_t;
|
|
||||||
struct convolution_fwd_pd_t;
|
|
||||||
struct convolution_pd_t;
|
|
||||||
struct deconvolution_bwd_data_pd_t;
|
|
||||||
struct deconvolution_bwd_weights_pd_t;
|
|
||||||
struct deconvolution_fwd_pd_t;
|
|
||||||
struct deconvolution_pd_t;
|
|
||||||
struct eltwise_bwd_pd_t;
|
|
||||||
struct eltwise_fwd_pd_t;
|
|
||||||
struct eltwise_pd_t;
|
|
||||||
struct inner_product_bwd_data_pd_t;
|
|
||||||
struct inner_product_bwd_weights_pd_t;
|
|
||||||
struct inner_product_fwd_pd_t;
|
|
||||||
struct inner_product_pd_t;
|
|
||||||
struct lrn_bwd_pd_t;
|
|
||||||
struct lrn_fwd_pd_t;
|
|
||||||
struct lrn_pd_t;
|
|
||||||
struct pooling_bwd_pd_t;
|
|
||||||
struct pooling_fwd_pd_t;
|
|
||||||
struct pooling_pd_t;
|
|
||||||
struct reorder_pd_t;
|
|
||||||
struct rnn_bwd_pd_t;
|
|
||||||
struct rnn_fwd_pd_t;
|
|
||||||
struct rnn_pd_t;
|
|
||||||
struct shuffle_pd_t;
|
|
||||||
struct softmax_bwd_pd_t;
|
|
||||||
struct softmax_fwd_pd_t;
|
|
||||||
struct softmax_pd_t;
|
|
||||||
struct sum_pd_t;
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
86
thirdparty/oidn/mkl-dnn/src/common/concat.cpp
vendored
86
thirdparty/oidn/mkl-dnn/src/common/concat.cpp
vendored
@ -1,86 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "engine.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
#include "concat_pd.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::utils;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
|
|
||||||
status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd,
|
|
||||||
const memory_desc_t *dst_md, int n, int concat_dim,
|
|
||||||
const memory_desc_t *src_mds,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
engine_t *engine) {
|
|
||||||
bool args_ok = !any_null(concat_pd, src_mds) && n > 0;
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
const primitive_attr_t dummy_attr;
|
|
||||||
if (attr == NULL)
|
|
||||||
attr = &dummy_attr;
|
|
||||||
|
|
||||||
const int ndims = src_mds[0].ndims;
|
|
||||||
const dims_t &dims = src_mds[0].dims;
|
|
||||||
const data_type_t dt = src_mds[0].data_type;
|
|
||||||
|
|
||||||
int concat_dim_sz = dims[concat_dim];
|
|
||||||
for (int i = 1; i < n; ++i) {
|
|
||||||
if (src_mds[i].ndims != ndims) return invalid_arguments;
|
|
||||||
for (int d = 0; d < ndims; ++d) {
|
|
||||||
if (d == concat_dim) continue;
|
|
||||||
if (src_mds[i].dims[d] != dims[d])
|
|
||||||
return invalid_arguments;
|
|
||||||
}
|
|
||||||
if (src_mds[i].data_type != dt) return invalid_arguments;
|
|
||||||
concat_dim_sz += src_mds[i].dims[concat_dim];
|
|
||||||
}
|
|
||||||
|
|
||||||
memory_desc_t dummy_dst_md;
|
|
||||||
if (dst_md) {
|
|
||||||
if (dst_md->ndims != ndims) return invalid_arguments;
|
|
||||||
for (int d = 0; d < ndims; ++d) {
|
|
||||||
if (dst_md->dims[d] !=
|
|
||||||
(d == concat_dim ? concat_dim_sz : dims[d]))
|
|
||||||
return invalid_arguments;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
dummy_dst_md = src_mds[0];
|
|
||||||
dummy_dst_md.dims[concat_dim] = concat_dim_sz;
|
|
||||||
dummy_dst_md.format_kind = format_kind::any;
|
|
||||||
dst_md = &dummy_dst_md;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto c_pd = reinterpret_cast<concat_pd_t **>(concat_pd);
|
|
||||||
|
|
||||||
for (auto c = engine->get_concat_implementation_list(); *c; ++c) {
|
|
||||||
if ((*c)(c_pd, engine, attr, dst_md, n, concat_dim, src_mds)
|
|
||||||
== success) {
|
|
||||||
(*c_pd)->init_info();
|
|
||||||
(*c_pd)->init_scratchpad_md();
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return unimplemented;
|
|
||||||
}
|
|
211
thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp
vendored
211
thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp
vendored
@ -1,211 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2019 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef CONCAT_PD_HPP
|
|
||||||
#define CONCAT_PD_HPP
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "nstl.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
struct concat_pd_t: public primitive_desc_t {
|
|
||||||
concat_pd_t(engine_t *engine, const primitive_attr_t *attr,
|
|
||||||
const memory_desc_t *dst_md, int n, int concat_dim,
|
|
||||||
const memory_desc_t *src_mds)
|
|
||||||
: primitive_desc_t(engine, attr, primitive_kind::concat)
|
|
||||||
, n_(n), concat_dim_(concat_dim), dst_md_(*dst_md)
|
|
||||||
{
|
|
||||||
src_mds_.reserve(n_);
|
|
||||||
for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
concat_pd_t(const concat_pd_t &rhs) = default;
|
|
||||||
|
|
||||||
virtual void init_info() override { impl::init_info(this, this->info_); }
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (arg >= MKLDNN_ARG_MULTIPLE_SRC
|
|
||||||
&& arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs())
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DST)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override
|
|
||||||
{ return index < n_inputs() ? &src_mds_[index] : nullptr; }
|
|
||||||
virtual const memory_desc_t *dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &dst_md_ : nullptr; }
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return n_; }
|
|
||||||
virtual int n_outputs() const override { return 1; }
|
|
||||||
|
|
||||||
int concat_dim() const { return concat_dim_; }
|
|
||||||
|
|
||||||
const memory_desc_t *src_image_md(int index = 0) const
|
|
||||||
{ return index < n_inputs() ? &src_image_mds_[index] : nullptr; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
int n_, concat_dim_;
|
|
||||||
memory_desc_t dst_md_;
|
|
||||||
nstl::vector<memory_desc_t> src_mds_;
|
|
||||||
|
|
||||||
/* contains images of srcs in the dst memory (if possible)
|
|
||||||
* Lives here to simplify some implementations. An implementation might
|
|
||||||
* use this auxiliary array iff init() returned success */
|
|
||||||
nstl::vector<memory_desc_t> src_image_mds_;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
/* inits src_image_mds_ and dst_md_ in simple cases. The call may fail */
|
|
||||||
status_t init() {
|
|
||||||
bool ok = true
|
|
||||||
&& set_default_params() == status::success
|
|
||||||
&& attr()->has_default_values();
|
|
||||||
if (!ok) return status::unimplemented;
|
|
||||||
|
|
||||||
for (int i = 0; i < n_; ++i) {
|
|
||||||
const memory_desc_wrapper i_d(&src_mds_[i]);
|
|
||||||
if (!i_d.is_blocking_desc() || i_d.is_additional_buffer())
|
|
||||||
return status::unimplemented;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int ndims = dst_md_.ndims;
|
|
||||||
int current_concat_dim_offset = 0;
|
|
||||||
for (int i = 0; i < n_; ++i) {
|
|
||||||
const int dim = src_mds_[i].dims[concat_dim_];
|
|
||||||
dims_t dims, offsets = {};
|
|
||||||
utils::array_copy(dims, dst_md_.dims, ndims);
|
|
||||||
dims[concat_dim_] = dim;
|
|
||||||
offsets[concat_dim_] = current_concat_dim_offset;
|
|
||||||
|
|
||||||
memory_desc_t src_img_d;
|
|
||||||
status_t status = mkldnn_memory_desc_init_submemory(&src_img_d,
|
|
||||||
&dst_md_, dims, offsets);
|
|
||||||
if (status != status::success) return status;
|
|
||||||
src_image_mds_.push_back(src_img_d);
|
|
||||||
current_concat_dim_offset += dim;
|
|
||||||
}
|
|
||||||
|
|
||||||
return status::success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t set_default_params() {
|
|
||||||
if (dst_md_.format_kind != format_kind::any)
|
|
||||||
return status::success;
|
|
||||||
|
|
||||||
const int ndims = dst_md_.ndims;
|
|
||||||
|
|
||||||
/* The stupidest ever heuristics (but not the same as we had before):
|
|
||||||
* - Pick the first non-plain format;
|
|
||||||
* - If all formats are plain or it is not possible to create a
|
|
||||||
* blocked format for the output, pick the format of the plain input
|
|
||||||
* - If this fails as well, use plain layout (abcd...)
|
|
||||||
*/
|
|
||||||
status_t status = status::unimplemented;
|
|
||||||
for (int i = 0; i < n_; ++i) {
|
|
||||||
const memory_desc_wrapper src_d(src_mds_[i]);
|
|
||||||
if (src_d.is_blocking_desc() && !src_d.is_plain()) {
|
|
||||||
status = memory_desc_init_by_blocking_desc(dst_md_,
|
|
||||||
src_d.blocking_desc());
|
|
||||||
if (status == status::success) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (status == status::success) {
|
|
||||||
/* check if we can create a sub-memory for the dst */
|
|
||||||
bool desired_format_ok = true;
|
|
||||||
int current_concat_dim_offset = 0;
|
|
||||||
for (int i = 0; i < n_; ++i) {
|
|
||||||
const int dim = src_mds_[i].dims[concat_dim_];
|
|
||||||
dims_t dims, offsets = {};
|
|
||||||
utils::array_copy(dims, dst_md_.dims, ndims);
|
|
||||||
dims[concat_dim_] = dim;
|
|
||||||
offsets[concat_dim_] = current_concat_dim_offset;
|
|
||||||
|
|
||||||
memory_desc_t src_img_d;
|
|
||||||
status_t status = mkldnn_memory_desc_init_submemory(&src_img_d,
|
|
||||||
&dst_md_, dims, offsets);
|
|
||||||
if (status != status::success) {
|
|
||||||
desired_format_ok = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
current_concat_dim_offset += dim;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!desired_format_ok)
|
|
||||||
status = status::unimplemented;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* if no success so far, try using the format of the first plain input */
|
|
||||||
if (status != status::success) {
|
|
||||||
for (int i = 0; i < n_; ++i) {
|
|
||||||
const memory_desc_wrapper src_d(src_mds_[i]);
|
|
||||||
if (src_d.is_blocking_desc() && src_d.is_plain()) {
|
|
||||||
status = memory_desc_init_by_blocking_desc(dst_md_,
|
|
||||||
memory_desc_wrapper(src_mds_[0]).blocking_desc());
|
|
||||||
if (status == status::success) return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* the last line of defense: use plain abcd... format */
|
|
||||||
if (status != status::success)
|
|
||||||
status = memory_desc_init_by_strides(dst_md_, nullptr);
|
|
||||||
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
#define DECLARE_CONCAT_PD_t(impl_name, ...) \
|
|
||||||
static status_t create(concat_pd_t **concat_pd, \
|
|
||||||
engine_t *engine, const primitive_attr_t *attr, \
|
|
||||||
const memory_desc_t *dst_md, int n, int concat_dim, \
|
|
||||||
const memory_desc_t *src_mds) { \
|
|
||||||
using namespace status; \
|
|
||||||
auto _pd = new pd_t(engine, attr, dst_md, n, concat_dim, src_mds); \
|
|
||||||
if (_pd == nullptr) return out_of_memory; \
|
|
||||||
if (_pd->init() != success) { delete _pd; return unimplemented; } \
|
|
||||||
return safe_ptr_assign<concat_pd_t>(*concat_pd, _pd); \
|
|
||||||
} \
|
|
||||||
virtual status_t create_primitive(primitive_t **p) const override { \
|
|
||||||
double ms = get_msec(); \
|
|
||||||
auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
|
|
||||||
ms = get_msec() - ms; \
|
|
||||||
if (mkldnn_verbose()->level >= 2) { \
|
|
||||||
printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
|
|
||||||
fflush(0); \
|
|
||||||
} \
|
|
||||||
return ret; \
|
|
||||||
} \
|
|
||||||
virtual pd_t *clone() const override { return new pd_t(*this); } \
|
|
||||||
virtual const char *name() const override { return impl_name; } \
|
|
||||||
|
|
||||||
#define DECLARE_CONCAT_PD_T(impl_name, ...) \
|
|
||||||
DECLARE_CONCAT_PD_t(impl_name, __VA_ARGS__)
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
200
thirdparty/oidn/mkl-dnn/src/common/convolution.cpp
vendored
200
thirdparty/oidn/mkl-dnn/src/common/convolution.cpp
vendored
@ -1,200 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::utils;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
using namespace mkldnn::impl::prop_kind;
|
|
||||||
using namespace mkldnn::impl::alg_kind;
|
|
||||||
using namespace mkldnn::impl::types;
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
status_t conv_desc_init(convolution_desc_t *conv_desc,
|
|
||||||
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
|
|
||||||
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
|
|
||||||
const dims_t strides, const dims_t dilates,
|
|
||||||
const dims_t padding_l, const dims_t padding_r,
|
|
||||||
padding_kind_t padding_kind) {
|
|
||||||
bool args_ok = true
|
|
||||||
&& !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides,
|
|
||||||
padding_l)
|
|
||||||
&& one_of(alg_kind, convolution_auto, convolution_direct, convolution_winograd)
|
|
||||||
&& one_of(padding_kind, padding_kind::padding_zero);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
if (padding_r == nullptr) padding_r = padding_l;
|
|
||||||
|
|
||||||
auto cd = convolution_desc_t();
|
|
||||||
cd.primitive_kind = primitive_kind::convolution;
|
|
||||||
cd.prop_kind = prop_kind;
|
|
||||||
cd.alg_kind = alg_kind;
|
|
||||||
|
|
||||||
cd.diff_src_desc = cd.src_desc = zero_md();
|
|
||||||
cd.diff_dst_desc = cd.dst_desc = zero_md();
|
|
||||||
cd.diff_weights_desc = cd.weights_desc = zero_md();
|
|
||||||
cd.diff_bias_desc = cd.bias_desc = zero_md();
|
|
||||||
|
|
||||||
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
|
|
||||||
const bool with_bias =
|
|
||||||
bias_desc && bias_desc->format_kind != format_kind::undef;
|
|
||||||
const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
|
|
||||||
|
|
||||||
(prop_kind == backward_data ? cd.diff_src_desc : cd.src_desc) = *src_desc;
|
|
||||||
(is_fwd ? cd.dst_desc : cd.diff_dst_desc) = *dst_desc;
|
|
||||||
(prop_kind == backward_weights ? cd.diff_weights_desc : cd.weights_desc) =
|
|
||||||
*weights_desc;
|
|
||||||
if (with_bias)
|
|
||||||
(prop_kind == backward_weights ? cd.diff_bias_desc : cd.bias_desc) =
|
|
||||||
*bias_desc;
|
|
||||||
|
|
||||||
int sp_dims = src_desc->ndims - 2;
|
|
||||||
utils::array_copy(cd.strides, strides, sp_dims);
|
|
||||||
utils::array_copy(cd.padding[0], padding_l, sp_dims);
|
|
||||||
utils::array_copy(cd.padding[1], padding_r, sp_dims);
|
|
||||||
if (dilates)
|
|
||||||
utils::array_copy(cd.dilates, dilates, sp_dims);
|
|
||||||
else
|
|
||||||
utils::array_set(cd.dilates, 0, sp_dims);
|
|
||||||
|
|
||||||
cd.padding_kind = padding_kind;
|
|
||||||
cd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
|
|
||||||
weights_desc->data_type, dst_desc->data_type, prop_kind);
|
|
||||||
|
|
||||||
const int g = with_groups ? weights_desc->dims[0] : 1;
|
|
||||||
const int bias_dim = prop_kind == backward_data
|
|
||||||
? src_desc->dims[1]
|
|
||||||
: dst_desc->dims[1];
|
|
||||||
|
|
||||||
bool consistency = true
|
|
||||||
&& memory_desc_wrapper(weights_desc).nelems()
|
|
||||||
&& src_desc->ndims == dst_desc->ndims
|
|
||||||
&& utils::one_of(src_desc->ndims, 3, 4, 5)
|
|
||||||
&& utils::one_of(weights_desc->ndims, src_desc->ndims,
|
|
||||||
src_desc->ndims + 1)
|
|
||||||
&& (with_bias ? bias_desc->ndims == 1 : true)
|
|
||||||
&& (with_bias ? bias_desc->dims[0] == bias_dim : true)
|
|
||||||
&& src_desc->dims[0] == dst_desc->dims[0]
|
|
||||||
&& src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
|
|
||||||
&& dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
|
|
||||||
for (int i = 2; i < src_desc->ndims; ++i)
|
|
||||||
{
|
|
||||||
int src = src_desc->dims[i];
|
|
||||||
int ker = weights_desc->dims[with_groups + i];
|
|
||||||
int dil = cd.dilates[i - 2];
|
|
||||||
int pad_l = padding_l[i - 2];
|
|
||||||
int pad_r = padding_r[i - 2];
|
|
||||||
int str = strides[i - 2];
|
|
||||||
int dst = dst_desc->dims[i];
|
|
||||||
int ker_range = 1 + (ker - 1) * (dil + 1);
|
|
||||||
|
|
||||||
if (str < 1) return invalid_arguments;
|
|
||||||
consistency = consistency
|
|
||||||
&& dil >= 0
|
|
||||||
&& pad_l >= 0
|
|
||||||
&& pad_r + str > 0
|
|
||||||
&& (src - ker_range + pad_l + pad_r) / str + 1 == dst;
|
|
||||||
}
|
|
||||||
if (!consistency) return invalid_arguments;
|
|
||||||
|
|
||||||
*conv_desc = cd;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_convolution_forward_desc_init(convolution_desc_t *conv_desc,
|
|
||||||
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
|
|
||||||
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
|
|
||||||
const dims_t strides, const dims_t padding_l, const dims_t padding_r,
|
|
||||||
padding_kind_t padding_kind) {
|
|
||||||
if (!one_of(prop_kind, forward_training, forward_inference))
|
|
||||||
return invalid_arguments;
|
|
||||||
return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc,
|
|
||||||
weights_desc, bias_desc, dst_desc, strides, nullptr,
|
|
||||||
padding_l, padding_r, padding_kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_dilated_convolution_forward_desc_init(
|
|
||||||
convolution_desc_t *conv_desc, prop_kind_t prop_kind,
|
|
||||||
alg_kind_t alg_kind, const memory_desc_t *src_desc,
|
|
||||||
const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
|
|
||||||
const memory_desc_t *dst_desc, const dims_t strides,
|
|
||||||
const dims_t dilates, const dims_t padding_l,
|
|
||||||
const dims_t padding_r, padding_kind_t padding_kind) {
|
|
||||||
if (!one_of(prop_kind, forward_training, forward_inference))
|
|
||||||
return invalid_arguments;
|
|
||||||
return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc,
|
|
||||||
weights_desc, bias_desc, dst_desc, strides, dilates,
|
|
||||||
padding_l, padding_r, padding_kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_convolution_backward_data_desc_init(
|
|
||||||
convolution_desc_t *conv_desc, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
|
|
||||||
const memory_desc_t *diff_dst_desc, const dims_t strides,
|
|
||||||
const dims_t padding_l, const dims_t padding_r,
|
|
||||||
padding_kind_t padding_kind) {
|
|
||||||
return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc,
|
|
||||||
weights_desc, nullptr, diff_dst_desc, strides, nullptr,
|
|
||||||
padding_l, padding_r, padding_kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_dilated_convolution_backward_data_desc_init(
|
|
||||||
convolution_desc_t *conv_desc, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
|
|
||||||
const memory_desc_t *diff_dst_desc, const dims_t strides,
|
|
||||||
const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
|
|
||||||
padding_kind_t padding_kind) {
|
|
||||||
return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc,
|
|
||||||
weights_desc, nullptr, diff_dst_desc, strides, dilates,
|
|
||||||
padding_l, padding_r, padding_kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_convolution_backward_weights_desc_init(
|
|
||||||
convolution_desc_t *conv_desc, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
|
|
||||||
const memory_desc_t *diff_bias_desc,
|
|
||||||
const memory_desc_t *diff_dst_desc, const dims_t strides,
|
|
||||||
const dims_t padding_l, const dims_t padding_r,
|
|
||||||
padding_kind_t padding_kind) {
|
|
||||||
return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc,
|
|
||||||
diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
|
|
||||||
nullptr, padding_l, padding_r, padding_kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_dilated_convolution_backward_weights_desc_init(
|
|
||||||
convolution_desc_t *conv_desc, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
|
|
||||||
const memory_desc_t *diff_bias_desc,
|
|
||||||
const memory_desc_t *diff_dst_desc, const dims_t strides,
|
|
||||||
const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
|
|
||||||
padding_kind_t padding_kind) {
|
|
||||||
return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc,
|
|
||||||
diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
|
|
||||||
dilates, padding_l, padding_r, padding_kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
@ -1,56 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
#include "convolution_pd.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
using namespace prop_kind;
|
|
||||||
|
|
||||||
memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) {
|
|
||||||
return desc->prop_kind == backward_data
|
|
||||||
? &desc->diff_src_desc : &desc->src_desc;
|
|
||||||
}
|
|
||||||
|
|
||||||
memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) {
|
|
||||||
return desc->prop_kind == backward_weights
|
|
||||||
? &desc->diff_weights_desc : &desc->weights_desc;
|
|
||||||
}
|
|
||||||
|
|
||||||
memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) {
|
|
||||||
return desc->prop_kind == backward_weights
|
|
||||||
? &desc->diff_bias_desc : &desc->bias_desc;
|
|
||||||
}
|
|
||||||
|
|
||||||
memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) {
|
|
||||||
return utils::one_of(desc->prop_kind, forward_inference, forward_training)
|
|
||||||
? &desc->dst_desc : &desc->diff_dst_desc;
|
|
||||||
}
|
|
||||||
|
|
||||||
const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc)
|
|
||||||
{ return conv_prop_invariant_src_d(const_cast<convolution_desc_t *>(desc)); }
|
|
||||||
const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc)
|
|
||||||
{ return conv_prop_invariant_wei_d(const_cast<convolution_desc_t *>(desc)); }
|
|
||||||
const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc)
|
|
||||||
{ return conv_prop_invariant_bia_d(const_cast<convolution_desc_t *>(desc)); }
|
|
||||||
const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc)
|
|
||||||
{ return conv_prop_invariant_dst_d(const_cast<convolution_desc_t *>(desc)); }
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,348 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef CONVOLUTION_PD_HPP
|
|
||||||
#define CONVOLUTION_PD_HPP
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
status_t conv_desc_init(convolution_desc_t *conv_desc,
|
|
||||||
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
|
|
||||||
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
|
|
||||||
const dims_t strides, const dims_t dilates,
|
|
||||||
const dims_t padding_l, const dims_t padding_r,
|
|
||||||
padding_kind_t padding_kind);
|
|
||||||
|
|
||||||
memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc);
|
|
||||||
memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc);
|
|
||||||
memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc);
|
|
||||||
memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc);
|
|
||||||
const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc);
|
|
||||||
const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc);
|
|
||||||
const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc);
|
|
||||||
const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc);
|
|
||||||
|
|
||||||
struct convolution_fwd_pd_t;
|
|
||||||
|
|
||||||
struct convolution_pd_t: public primitive_desc_t {
|
|
||||||
static constexpr auto base_pkind = primitive_kind::convolution;
|
|
||||||
|
|
||||||
convolution_pd_t(engine_t *engine,
|
|
||||||
const convolution_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const convolution_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: primitive_desc_t(engine, attr, base_pkind)
|
|
||||||
, desc_(*adesc)
|
|
||||||
, hint_fwd_pd_(hint_fwd_pd)
|
|
||||||
{}
|
|
||||||
|
|
||||||
const convolution_desc_t *desc() const { return &desc_; }
|
|
||||||
virtual const op_desc_t *op_desc() const override
|
|
||||||
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
|
||||||
virtual void init_info() override { impl::init_info(this, this->info_); }
|
|
||||||
|
|
||||||
virtual status_t query(query_t what, int idx, void *result) const override {
|
|
||||||
switch (what) {
|
|
||||||
case pkind_traits<base_pkind>::query_d:
|
|
||||||
*(const convolution_desc_t**)result = desc(); break;
|
|
||||||
default: return primitive_desc_t::query(what, idx, result);
|
|
||||||
}
|
|
||||||
return status::success;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* common conv aux functions */
|
|
||||||
|
|
||||||
dim_t MB() const { return _src_md()->dims[0]; }
|
|
||||||
|
|
||||||
dim_t IC() const { return _src_md()->dims[1]; }
|
|
||||||
dim_t OC() const { return _dst_md()->dims[1]; }
|
|
||||||
dim_t G() const { return with_groups() ? _wei_md()->dims[0] : 1; }
|
|
||||||
|
|
||||||
dim_t ID() const { return ndims() >= 5 ? _src_md()->dims[ndims() - 3] : 1; }
|
|
||||||
dim_t IH() const { return ndims() >= 4 ? _src_md()->dims[ndims() - 2] : 1; }
|
|
||||||
dim_t IW() const { return _src_md()->dims[ndims() - 1]; }
|
|
||||||
|
|
||||||
dim_t OD() const { return ndims() >= 5 ? _dst_md()->dims[ndims() - 3] : 1; }
|
|
||||||
dim_t OH() const { return ndims() >= 4 ? _dst_md()->dims[ndims() - 2] : 1; }
|
|
||||||
dim_t OW() const { return _dst_md()->dims[ndims() - 1]; }
|
|
||||||
|
|
||||||
dim_t KD() const { return ndims() >= 5 ? _wei_md()->dims[ndims() + with_groups() - 3] : 1; }
|
|
||||||
dim_t KH() const { return ndims() >= 4 ? _wei_md()->dims[ndims() + with_groups() - 2] : 1; }
|
|
||||||
dim_t KW() const { return _wei_md()->dims[ndims() + with_groups() - 1]; }
|
|
||||||
|
|
||||||
dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
|
|
||||||
dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
|
|
||||||
dim_t KSW() const { return desc_.strides[ndims() - 3]; }
|
|
||||||
|
|
||||||
dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
|
|
||||||
dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
|
|
||||||
dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
|
|
||||||
|
|
||||||
dim_t padFront() const { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
|
|
||||||
dim_t padBack() const { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
|
|
||||||
dim_t padT() const { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
|
|
||||||
dim_t padB() const { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
|
|
||||||
dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
|
|
||||||
dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
|
|
||||||
|
|
||||||
int ndims() const { return _src_md()->ndims; }
|
|
||||||
|
|
||||||
bool with_bias() const { return !memory_desc_wrapper(*_bia_md()).is_zero(); }
|
|
||||||
bool with_groups() const { return _wei_md()->ndims == ndims() + 1; }
|
|
||||||
|
|
||||||
bool is_fwd() const {
|
|
||||||
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
|
||||||
prop_kind::forward_inference);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool has_zero_dim_memory() const {
|
|
||||||
const auto s_d = memory_desc_wrapper(*_src_md());
|
|
||||||
const auto d_d = memory_desc_wrapper(*_dst_md());
|
|
||||||
return s_d.has_zero_dim() || d_d.has_zero_dim();
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
convolution_desc_t desc_;
|
|
||||||
const convolution_fwd_pd_t *hint_fwd_pd_;
|
|
||||||
|
|
||||||
bool set_default_formats_common_template(
|
|
||||||
memory_desc_t &src_md, format_tag_t src_tag,
|
|
||||||
memory_desc_t &wei_md, format_tag_t wei_tag,
|
|
||||||
memory_desc_t &dst_md, format_tag_t dst_tag,
|
|
||||||
memory_desc_t &bia_md) {
|
|
||||||
using namespace format_tag;
|
|
||||||
|
|
||||||
# define IS_OK(f) \
|
|
||||||
do { if ((f) != status::success) return false; } while(0)
|
|
||||||
if (src_md.format_kind == format_kind::any
|
|
||||||
&& !utils::one_of(src_tag, any, undef))
|
|
||||||
IS_OK(memory_desc_init_by_tag(src_md, src_tag));
|
|
||||||
if (dst_md.format_kind == format_kind::any
|
|
||||||
&& !utils::one_of(dst_tag, any, undef))
|
|
||||||
IS_OK(memory_desc_init_by_tag(dst_md, dst_tag));
|
|
||||||
if (wei_md.format_kind == format_kind::any
|
|
||||||
&& !utils::one_of(wei_tag, any, undef))
|
|
||||||
IS_OK(memory_desc_init_by_tag(wei_md, wei_tag));
|
|
||||||
if (with_bias() && bia_md.format_kind == format_kind::any)
|
|
||||||
IS_OK(memory_desc_init_by_tag(bia_md, x));
|
|
||||||
# undef IS_OK
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool set_default_alg_kind(alg_kind_t alg_kind) {
|
|
||||||
assert(utils::one_of(alg_kind, alg_kind::convolution_direct,
|
|
||||||
alg_kind::convolution_winograd));
|
|
||||||
if (desc_.alg_kind == alg_kind::convolution_auto)
|
|
||||||
desc_.alg_kind = alg_kind;
|
|
||||||
return desc_.alg_kind == alg_kind;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool expect_data_types(data_type_t src_dt, data_type_t wei_dt,
|
|
||||||
data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const {
|
|
||||||
bool ok = true
|
|
||||||
&& (src_dt == data_type::undef || _src_md()->data_type == src_dt)
|
|
||||||
&& (wei_dt == data_type::undef || _wei_md()->data_type == wei_dt)
|
|
||||||
&& (dst_dt == data_type::undef || _dst_md()->data_type == dst_dt)
|
|
||||||
&& (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt);
|
|
||||||
if (with_bias() && bia_dt != data_type::undef)
|
|
||||||
ok = ok && _bia_md()->data_type == bia_dt;
|
|
||||||
return ok;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
const memory_desc_t *_src_md() const { return conv_prop_invariant_src_d(&desc_); }
|
|
||||||
const memory_desc_t *_wei_md() const { return conv_prop_invariant_wei_d(&desc_); }
|
|
||||||
const memory_desc_t *_bia_md() const { return conv_prop_invariant_bia_d(&desc_); }
|
|
||||||
const memory_desc_t *_dst_md() const { return conv_prop_invariant_dst_d(&desc_); }
|
|
||||||
};
|
|
||||||
|
|
||||||
struct convolution_fwd_pd_t: public convolution_pd_t {
|
|
||||||
typedef convolution_fwd_pd_t base_class;
|
|
||||||
typedef convolution_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
convolution_fwd_pd_t(engine_t *engine,
|
|
||||||
const convolution_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const convolution_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, src_md_(desc_.src_desc)
|
|
||||||
, weights_md_(desc_.weights_desc)
|
|
||||||
, bias_md_(desc_.bias_desc)
|
|
||||||
, dst_md_(desc_.dst_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_BIAS && with_bias())
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DST)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &src_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &dst_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *weights_md(int index = 0) const override {
|
|
||||||
if (index == 0) return &weights_md_;
|
|
||||||
if (index == 1 && with_bias()) return &bias_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return 2 + with_bias(); }
|
|
||||||
virtual int n_outputs() const override { return 1; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t src_md_;
|
|
||||||
memory_desc_t weights_md_;
|
|
||||||
memory_desc_t bias_md_;
|
|
||||||
memory_desc_t dst_md_;
|
|
||||||
|
|
||||||
bool set_default_formats_common(format_tag_t src_tag,
|
|
||||||
format_tag_t wei_tag, format_tag_t dst_tag) {
|
|
||||||
return set_default_formats_common_template(src_md_, src_tag,
|
|
||||||
weights_md_, wei_tag, dst_md_, dst_tag, bias_md_);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct convolution_bwd_data_pd_t: public convolution_pd_t {
|
|
||||||
typedef convolution_bwd_data_pd_t base_class;
|
|
||||||
typedef convolution_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
convolution_bwd_data_pd_t(engine_t *engine,
|
|
||||||
const convolution_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const convolution_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, diff_src_md_(desc_.diff_src_desc)
|
|
||||||
, weights_md_(desc_.weights_desc)
|
|
||||||
, bias_md_(desc_.bias_desc)
|
|
||||||
, diff_dst_md_(desc_.diff_dst_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_SRC)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *diff_src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_src_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_dst_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *weights_md(int index = 0) const override {
|
|
||||||
if (index == 0) return &weights_md_;
|
|
||||||
if (index == 1 && with_bias()) return &bias_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return 2 + with_bias(); }
|
|
||||||
virtual int n_outputs() const override { return 1; }
|
|
||||||
|
|
||||||
virtual bool support_bias() const { return false; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t diff_src_md_;
|
|
||||||
memory_desc_t weights_md_;
|
|
||||||
memory_desc_t bias_md_;
|
|
||||||
memory_desc_t diff_dst_md_;
|
|
||||||
|
|
||||||
bool set_default_formats_common(format_tag_t diff_src_tag,
|
|
||||||
format_tag_t wei_tag, format_tag_t diff_dst_tag) {
|
|
||||||
return set_default_formats_common_template(diff_src_md_, diff_src_tag,
|
|
||||||
weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct convolution_bwd_weights_pd_t: public convolution_pd_t {
|
|
||||||
typedef convolution_bwd_weights_pd_t base_class;
|
|
||||||
typedef convolution_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
convolution_bwd_weights_pd_t(engine_t *engine,
|
|
||||||
const convolution_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const convolution_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, src_md_(desc_.src_desc)
|
|
||||||
, diff_weights_md_(desc_.diff_weights_desc)
|
|
||||||
, diff_bias_md_(desc_.diff_bias_desc)
|
|
||||||
, diff_dst_md_(desc_.diff_dst_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &src_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_dst_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
|
|
||||||
if (index == 0) return &diff_weights_md_;
|
|
||||||
if (index == 1 && with_bias()) return &diff_bias_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return 2; }
|
|
||||||
virtual int n_outputs() const override { return 1 + with_bias(); }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t src_md_;
|
|
||||||
memory_desc_t diff_weights_md_;
|
|
||||||
memory_desc_t diff_bias_md_;
|
|
||||||
memory_desc_t diff_dst_md_;
|
|
||||||
|
|
||||||
bool set_default_formats_common(format_tag_t src_tag,
|
|
||||||
format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) {
|
|
||||||
return set_default_formats_common_template(src_md_, src_tag,
|
|
||||||
diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag,
|
|
||||||
diff_bias_md_);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
188
thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp
vendored
188
thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp
vendored
@ -1,188 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
#include <assert.h>
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::utils;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
using namespace mkldnn::impl::prop_kind;
|
|
||||||
using namespace mkldnn::impl::alg_kind;
|
|
||||||
using namespace mkldnn::impl::types;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
status_t deconv_desc_init(deconvolution_desc_t *deconv_desc,
|
|
||||||
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
|
|
||||||
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
|
|
||||||
const dims_t strides, const dims_t dilates, const dims_t padding_l,
|
|
||||||
const dims_t padding_r, padding_kind_t padding_kind) {
|
|
||||||
bool args_ok = true
|
|
||||||
&& !any_null(deconv_desc, src_desc, weights_desc, dst_desc, strides,
|
|
||||||
padding_l)
|
|
||||||
&& one_of(alg_kind, deconvolution_direct, deconvolution_winograd)
|
|
||||||
&& one_of(padding_kind, padding_kind::padding_zero);
|
|
||||||
if (!args_ok)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
if (padding_r == nullptr)
|
|
||||||
padding_r = padding_l;
|
|
||||||
|
|
||||||
auto dd = deconvolution_desc_t();
|
|
||||||
dd.primitive_kind = primitive_kind::deconvolution;
|
|
||||||
dd.prop_kind = prop_kind;
|
|
||||||
dd.alg_kind = alg_kind;
|
|
||||||
|
|
||||||
dd.diff_src_desc = dd.src_desc = zero_md();
|
|
||||||
dd.diff_dst_desc = dd.dst_desc = zero_md();
|
|
||||||
dd.diff_weights_desc = dd.weights_desc = zero_md();
|
|
||||||
dd.diff_bias_desc = dd.bias_desc = zero_md();
|
|
||||||
|
|
||||||
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
|
|
||||||
const bool with_bias
|
|
||||||
= bias_desc && bias_desc->format_kind != format_kind::undef;
|
|
||||||
const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
|
|
||||||
|
|
||||||
(prop_kind == backward_data ? dd.diff_src_desc : dd.src_desc) = *src_desc;
|
|
||||||
(is_fwd ? dd.dst_desc : dd.diff_dst_desc) = *dst_desc;
|
|
||||||
(prop_kind == backward_weights ? dd.diff_weights_desc : dd.weights_desc)
|
|
||||||
= *weights_desc;
|
|
||||||
if (with_bias)
|
|
||||||
(prop_kind == backward_weights ? dd.diff_bias_desc : dd.bias_desc)
|
|
||||||
= *bias_desc;
|
|
||||||
|
|
||||||
int sp_dims = src_desc->ndims - 2;
|
|
||||||
utils::array_copy(dd.strides, strides, sp_dims);
|
|
||||||
utils::array_copy(dd.padding[0], padding_l, sp_dims);
|
|
||||||
utils::array_copy(dd.padding[1], padding_r, sp_dims);
|
|
||||||
if (dilates)
|
|
||||||
utils::array_copy(dd.dilates, dilates, sp_dims);
|
|
||||||
else
|
|
||||||
utils::array_set(dd.dilates, 0, sp_dims);
|
|
||||||
|
|
||||||
dd.padding_kind = padding_kind;
|
|
||||||
dd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
|
|
||||||
weights_desc->data_type, dst_desc->data_type, prop_kind);
|
|
||||||
|
|
||||||
const int g = with_groups ? weights_desc->dims[0] : 1;
|
|
||||||
bool consistency = true
|
|
||||||
&& src_desc->ndims == dst_desc->ndims
|
|
||||||
&& utils::one_of(src_desc->ndims, 3, 4, 5)
|
|
||||||
&& utils::one_of(weights_desc->ndims, src_desc->ndims,
|
|
||||||
src_desc->ndims + 1)
|
|
||||||
&& (with_bias ? bias_desc->ndims == 1 : true)
|
|
||||||
&& (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
|
|
||||||
&& src_desc->dims[0] == dst_desc->dims[0]
|
|
||||||
&& src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
|
|
||||||
&& dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
|
|
||||||
for (int i = 2; i < src_desc->ndims; ++i) {
|
|
||||||
int src = src_desc->dims[i];
|
|
||||||
int ker = weights_desc->dims[with_groups + i];
|
|
||||||
int dil = dd.dilates[i - 2];
|
|
||||||
int pad = padding_l[i - 2] + padding_r[i - 2];
|
|
||||||
int str = strides[i - 2];
|
|
||||||
int dst = dst_desc->dims[i];
|
|
||||||
int ker_range = 1 + (ker - 1) * (dil + 1);
|
|
||||||
|
|
||||||
consistency
|
|
||||||
= consistency && (dst - ker_range + pad) / str + 1 == src;
|
|
||||||
}
|
|
||||||
if (!consistency)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
*deconv_desc = dd;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_deconvolution_forward_desc_init(
|
|
||||||
deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind,
|
|
||||||
alg_kind_t alg_kind, const memory_desc_t *src_desc,
|
|
||||||
const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
|
|
||||||
const memory_desc_t *dst_desc, const dims_t strides,
|
|
||||||
const dims_t padding_l, const dims_t padding_r,
|
|
||||||
padding_kind_t padding_kind) {
|
|
||||||
if (!one_of(prop_kind, forward_training, forward_inference))
|
|
||||||
return invalid_arguments;
|
|
||||||
return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc,
|
|
||||||
weights_desc, bias_desc, dst_desc, strides, nullptr, padding_l,
|
|
||||||
padding_r, padding_kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_dilated_deconvolution_forward_desc_init(
|
|
||||||
deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind,
|
|
||||||
alg_kind_t alg_kind, const memory_desc_t *src_desc,
|
|
||||||
const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
|
|
||||||
const memory_desc_t *dst_desc, const dims_t strides,
|
|
||||||
const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
|
|
||||||
padding_kind_t padding_kind) {
|
|
||||||
if (!one_of(prop_kind, forward_training, forward_inference))
|
|
||||||
return invalid_arguments;
|
|
||||||
return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc,
|
|
||||||
weights_desc, bias_desc, dst_desc, strides, dilates, padding_l,
|
|
||||||
padding_r, padding_kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_deconvolution_backward_data_desc_init(
|
|
||||||
deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
|
|
||||||
const memory_desc_t *diff_dst_desc, const dims_t strides,
|
|
||||||
const dims_t padding_l, const dims_t padding_r,
|
|
||||||
padding_kind_t padding_kind) {
|
|
||||||
return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc,
|
|
||||||
weights_desc, nullptr, diff_dst_desc, strides, nullptr, padding_l,
|
|
||||||
padding_r, padding_kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_dilated_deconvolution_backward_data_desc_init(
|
|
||||||
deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
|
|
||||||
const memory_desc_t *diff_dst_desc, const dims_t strides,
|
|
||||||
const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
|
|
||||||
padding_kind_t padding_kind) {
|
|
||||||
return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc,
|
|
||||||
weights_desc, nullptr, diff_dst_desc, strides,dilates, padding_l,
|
|
||||||
padding_r, padding_kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_deconvolution_backward_weights_desc_init(
|
|
||||||
deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
|
|
||||||
const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
|
|
||||||
const dims_t strides, const dims_t padding_l, const dims_t padding_r,
|
|
||||||
padding_kind_t padding_kind) {
|
|
||||||
return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc,
|
|
||||||
diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, nullptr,
|
|
||||||
padding_l, padding_r, padding_kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_dilated_deconvolution_backward_weights_desc_init(
|
|
||||||
deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
|
|
||||||
const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
|
|
||||||
const dims_t strides, const dims_t dilates, const dims_t padding_l,
|
|
||||||
const dims_t padding_r, padding_kind_t padding_kind) {
|
|
||||||
return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc,
|
|
||||||
diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, dilates,
|
|
||||||
padding_l, padding_r, padding_kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
@ -1,293 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef DECONVOLUTION_PD_HPP
|
|
||||||
#define DECONVOLUTION_PD_HPP
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "convolution_pd.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
struct deconvolution_fwd_pd_t;
|
|
||||||
|
|
||||||
struct deconvolution_pd_t: public primitive_desc_t {
|
|
||||||
static constexpr auto base_pkind = primitive_kind::deconvolution;
|
|
||||||
|
|
||||||
deconvolution_pd_t(engine_t *engine,
|
|
||||||
const deconvolution_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const deconvolution_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: primitive_desc_t(engine, attr, base_pkind)
|
|
||||||
, desc_(*adesc)
|
|
||||||
, hint_fwd_pd_(hint_fwd_pd)
|
|
||||||
{}
|
|
||||||
|
|
||||||
const deconvolution_desc_t *desc() const { return &desc_; }
|
|
||||||
virtual const op_desc_t *op_desc() const override
|
|
||||||
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
|
||||||
virtual void init_info() override { impl::init_info(this, this->info_); }
|
|
||||||
|
|
||||||
virtual status_t query(query_t what, int idx, void *result) const override {
|
|
||||||
switch (what) {
|
|
||||||
case pkind_traits<base_pkind>::query_d:
|
|
||||||
*(const deconvolution_desc_t **)result = desc();
|
|
||||||
break;
|
|
||||||
default: return primitive_desc_t::query(what, idx, result);
|
|
||||||
}
|
|
||||||
return status::success;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */
|
|
||||||
|
|
||||||
dim_t MB() const { return conv_prop_invariant_src_d(&desc_)->dims[0]; }
|
|
||||||
|
|
||||||
dim_t IC() const { return conv_prop_invariant_src_d(&desc_)->dims[1]; }
|
|
||||||
dim_t OC() const { return conv_prop_invariant_dst_d(&desc_)->dims[1]; }
|
|
||||||
dim_t G() const
|
|
||||||
{ return with_groups() ? conv_prop_invariant_wei_d(&desc_)->dims[0] : 1; }
|
|
||||||
|
|
||||||
dim_t ID() const {
|
|
||||||
return ndims() >= 5
|
|
||||||
? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
|
|
||||||
}
|
|
||||||
dim_t IH() const {
|
|
||||||
return ndims() >= 4
|
|
||||||
? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
|
|
||||||
}
|
|
||||||
dim_t IW() const {
|
|
||||||
return conv_prop_invariant_src_d(&desc_)->dims[ndims() - 1];
|
|
||||||
}
|
|
||||||
|
|
||||||
dim_t OD() const {
|
|
||||||
return ndims() >= 5
|
|
||||||
? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
|
|
||||||
}
|
|
||||||
dim_t OH() const {
|
|
||||||
return ndims() >= 4
|
|
||||||
? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
|
|
||||||
}
|
|
||||||
dim_t OW() const {
|
|
||||||
return conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 1];
|
|
||||||
}
|
|
||||||
|
|
||||||
dim_t KD() const {
|
|
||||||
const int w_ndims = ndims() + with_groups();
|
|
||||||
return ndims() >= 5
|
|
||||||
? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 3] : 1;
|
|
||||||
}
|
|
||||||
dim_t KH() const {
|
|
||||||
const int w_ndims = ndims() + with_groups();
|
|
||||||
return ndims() >= 4
|
|
||||||
? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 2] : 1;
|
|
||||||
}
|
|
||||||
dim_t KW() const {
|
|
||||||
const int w_ndims = ndims() + with_groups();
|
|
||||||
return conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 1];
|
|
||||||
}
|
|
||||||
|
|
||||||
dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
|
|
||||||
dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
|
|
||||||
dim_t KSW() const { return desc_.strides[ndims() - 3]; }
|
|
||||||
|
|
||||||
dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
|
|
||||||
dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
|
|
||||||
dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
|
|
||||||
|
|
||||||
dim_t padFront() const
|
|
||||||
{ return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
|
|
||||||
dim_t padBack() const
|
|
||||||
{ return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
|
|
||||||
dim_t padT() const
|
|
||||||
{ return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
|
|
||||||
dim_t padB() const
|
|
||||||
{ return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
|
|
||||||
dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
|
|
||||||
dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
|
|
||||||
|
|
||||||
bool with_bias() const {
|
|
||||||
return
|
|
||||||
!memory_desc_wrapper(*conv_prop_invariant_bia_d(&desc_)).is_zero();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool with_groups() const
|
|
||||||
{ return conv_prop_invariant_wei_d(&desc_)->ndims == ndims() + 1; }
|
|
||||||
|
|
||||||
int ndims() const { return conv_prop_invariant_src_d(&desc_)->ndims; }
|
|
||||||
|
|
||||||
bool is_fwd() const {
|
|
||||||
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
|
||||||
prop_kind::forward_inference);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool has_zero_dim_memory() const {
|
|
||||||
const auto s_d = memory_desc_wrapper(*conv_prop_invariant_src_d(&desc_));
|
|
||||||
const auto d_d = memory_desc_wrapper(*conv_prop_invariant_dst_d(&desc_));
|
|
||||||
return s_d.has_zero_dim() || d_d.has_zero_dim();
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
deconvolution_desc_t desc_;
|
|
||||||
const deconvolution_fwd_pd_t *hint_fwd_pd_;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct deconvolution_fwd_pd_t: public deconvolution_pd_t {
|
|
||||||
typedef deconvolution_fwd_pd_t base_class;
|
|
||||||
typedef deconvolution_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
deconvolution_fwd_pd_t(engine_t *engine,
|
|
||||||
const deconvolution_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const deconvolution_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, src_md_(desc_.src_desc)
|
|
||||||
, weights_md_(desc_.weights_desc)
|
|
||||||
, bias_md_(desc_.bias_desc)
|
|
||||||
, dst_md_(desc_.dst_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_BIAS && with_bias())
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DST)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &src_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &dst_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *weights_md(int index = 0) const override {
|
|
||||||
if (index == 0) return &weights_md_;
|
|
||||||
if (index == 1 && with_bias()) return &bias_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return 2 + with_bias(); }
|
|
||||||
virtual int n_outputs() const override { return 1; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t src_md_;
|
|
||||||
memory_desc_t weights_md_;
|
|
||||||
memory_desc_t bias_md_;
|
|
||||||
memory_desc_t dst_md_;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct deconvolution_bwd_data_pd_t: public deconvolution_pd_t {
|
|
||||||
typedef deconvolution_bwd_data_pd_t base_class;
|
|
||||||
typedef deconvolution_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
deconvolution_bwd_data_pd_t(engine_t *engine,
|
|
||||||
const deconvolution_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const deconvolution_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, diff_src_md_(desc_.diff_src_desc)
|
|
||||||
, weights_md_(desc_.weights_desc)
|
|
||||||
, diff_dst_md_(desc_.diff_dst_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_SRC)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *diff_src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_src_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_dst_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *weights_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &weights_md_ : nullptr; }
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return 2; }
|
|
||||||
virtual int n_outputs() const override { return 1; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t diff_src_md_;
|
|
||||||
memory_desc_t weights_md_;
|
|
||||||
memory_desc_t diff_dst_md_;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct deconvolution_bwd_weights_pd_t: public deconvolution_pd_t {
|
|
||||||
typedef deconvolution_bwd_weights_pd_t base_class;
|
|
||||||
typedef deconvolution_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
deconvolution_bwd_weights_pd_t(engine_t *engine,
|
|
||||||
const deconvolution_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const deconvolution_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, src_md_(desc_.src_desc)
|
|
||||||
, diff_weights_md_(desc_.diff_weights_desc)
|
|
||||||
, diff_bias_md_(desc_.diff_bias_desc)
|
|
||||||
, diff_dst_md_(desc_.diff_dst_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &src_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_dst_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
|
|
||||||
if (index == 0) return &diff_weights_md_;
|
|
||||||
if (index == 1 && with_bias()) return &diff_bias_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return 2; }
|
|
||||||
virtual int n_outputs() const override { return 1 + with_bias(); }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t src_md_;
|
|
||||||
memory_desc_t diff_weights_md_;
|
|
||||||
memory_desc_t diff_bias_md_;
|
|
||||||
memory_desc_t diff_dst_md_;
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
84
thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp
vendored
84
thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp
vendored
@ -1,84 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::utils;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
using namespace mkldnn::impl::prop_kind;
|
|
||||||
using namespace mkldnn::impl::alg_kind;
|
|
||||||
using namespace mkldnn::impl::types;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind,
|
|
||||||
alg_kind_t alg_kind, const memory_desc_t *data_desc,
|
|
||||||
const memory_desc_t *diff_data_desc, float alpha, float beta) {
|
|
||||||
bool args_ok = true
|
|
||||||
&& !any_null(eltwise_desc, data_desc)
|
|
||||||
&& one_of(prop_kind, forward_training, forward_inference,
|
|
||||||
backward_data)
|
|
||||||
&& one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu,
|
|
||||||
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
|
|
||||||
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)
|
|
||||||
&& IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
auto ed = eltwise_desc_t();
|
|
||||||
ed.primitive_kind = primitive_kind::eltwise;
|
|
||||||
ed.prop_kind = prop_kind;
|
|
||||||
ed.alg_kind = alg_kind;
|
|
||||||
|
|
||||||
ed.data_desc = *data_desc;
|
|
||||||
ed.diff_data_desc =
|
|
||||||
(ed.prop_kind == backward_data) ? *diff_data_desc : zero_md();
|
|
||||||
|
|
||||||
ed.alpha = alpha;
|
|
||||||
ed.beta = beta;
|
|
||||||
|
|
||||||
bool consistency = true
|
|
||||||
&& IMPLICATION(ed.prop_kind == backward_data,
|
|
||||||
array_cmp(ed.diff_data_desc.dims, ed.data_desc.dims,
|
|
||||||
ed.diff_data_desc.ndims));
|
|
||||||
if (!consistency) return invalid_arguments;
|
|
||||||
|
|
||||||
*eltwise_desc = ed;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_eltwise_forward_desc_init(eltwise_desc_t *eltwise_desc,
|
|
||||||
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *data_desc, float alpha, float beta) {
|
|
||||||
if (!one_of(prop_kind, forward_training, forward_inference))
|
|
||||||
return invalid_arguments;
|
|
||||||
return eltwise_desc_init(eltwise_desc, prop_kind, alg_kind, data_desc,
|
|
||||||
nullptr, alpha, beta);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_eltwise_backward_desc_init(eltwise_desc_t *eltwise_desc,
|
|
||||||
alg_kind_t alg_kind, const memory_desc_t *diff_data_desc,
|
|
||||||
const memory_desc_t *data_desc, float alpha, float beta) {
|
|
||||||
return eltwise_desc_init(eltwise_desc, backward_data, alg_kind, data_desc,
|
|
||||||
diff_data_desc, alpha, beta);
|
|
||||||
}
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
161
thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp
vendored
161
thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp
vendored
@ -1,161 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef ELTWISE_PD_HPP
|
|
||||||
#define ELTWISE_PD_HPP
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
struct eltwise_fwd_pd_t;
|
|
||||||
|
|
||||||
struct eltwise_pd_t: public primitive_desc_t {
|
|
||||||
static constexpr auto base_pkind = primitive_kind::eltwise;
|
|
||||||
|
|
||||||
eltwise_pd_t(mkldnn::impl::engine_t *engine,
|
|
||||||
const eltwise_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const eltwise_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: primitive_desc_t(engine, attr, base_pkind)
|
|
||||||
, desc_(*adesc)
|
|
||||||
, hint_fwd_pd_(hint_fwd_pd)
|
|
||||||
, data_md_(desc_.data_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
const eltwise_desc_t *desc() const { return &desc_; }
|
|
||||||
virtual const op_desc_t *op_desc() const override
|
|
||||||
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
|
||||||
virtual void init_info() override { impl::init_info(this, this->info_); }
|
|
||||||
|
|
||||||
virtual status_t query(query_t what, int idx, void *result) const override {
|
|
||||||
switch (what) {
|
|
||||||
case query::eltwise_d:
|
|
||||||
*(const eltwise_desc_t**)result = desc(); break;
|
|
||||||
default: return primitive_desc_t::query(what, idx, result);
|
|
||||||
}
|
|
||||||
return status::success;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* common eltwise aux functions */
|
|
||||||
|
|
||||||
dim_t MB() const { return data_desc().dims[0]; }
|
|
||||||
dim_t C() const { return data_desc().dims[1]; }
|
|
||||||
dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
|
|
||||||
dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
|
|
||||||
dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
|
|
||||||
|
|
||||||
int ndims() const { return data_desc().ndims; }
|
|
||||||
|
|
||||||
bool is_fwd() const {
|
|
||||||
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
|
||||||
prop_kind::forward_inference);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool has_zero_dim_memory() const
|
|
||||||
{ return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
eltwise_desc_t desc_;
|
|
||||||
const eltwise_fwd_pd_t *hint_fwd_pd_;
|
|
||||||
|
|
||||||
memory_desc_t data_md_;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const memory_desc_t &data_desc() const { return desc_.data_desc; }
|
|
||||||
};
|
|
||||||
|
|
||||||
struct eltwise_fwd_pd_t: public eltwise_pd_t {
|
|
||||||
typedef eltwise_fwd_pd_t base_class;
|
|
||||||
typedef eltwise_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
eltwise_fwd_pd_t(mkldnn::impl::engine_t *engine,
|
|
||||||
const eltwise_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const eltwise_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: eltwise_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (arg == MKLDNN_ARG_SRC)
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DST)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &data_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &data_md_ : nullptr; }
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return 1; }
|
|
||||||
virtual int n_outputs() const override { return 1; }
|
|
||||||
|
|
||||||
bool is_zero_preserved() const
|
|
||||||
{ return math::eltwise_fwd_preserves_zero(desc_.alg_kind); }
|
|
||||||
};
|
|
||||||
|
|
||||||
struct eltwise_bwd_pd_t: public eltwise_pd_t {
|
|
||||||
typedef eltwise_bwd_pd_t base_class;
|
|
||||||
typedef eltwise_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
eltwise_bwd_pd_t(engine_t *engine,
|
|
||||||
const eltwise_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const eltwise_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: eltwise_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, diff_data_md_(desc_.diff_data_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_SRC)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &data_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_data_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_data_md_ : nullptr; }
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return 2; }
|
|
||||||
virtual int n_outputs() const override { return 1; }
|
|
||||||
|
|
||||||
bool is_zero_preserved() const { return true; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t diff_data_md_;
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
75
thirdparty/oidn/mkl-dnn/src/common/engine.cpp
vendored
75
thirdparty/oidn/mkl-dnn/src/common/engine.cpp
vendored
@ -1,75 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
#include "engine.hpp"
|
|
||||||
#include "nstl.hpp"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "../cpu/cpu_engine.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
engine_factory_t *engine_factories[] = {
|
|
||||||
&cpu::engine_factory,
|
|
||||||
nullptr,
|
|
||||||
};
|
|
||||||
|
|
||||||
static inline engine_factory_t *get_engine_factory(engine_kind_t kind) {
|
|
||||||
for (engine_factory_t **ef = engine_factories; *ef; ef++)
|
|
||||||
if ((*ef)->kind() == kind)
|
|
||||||
return *ef;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
|
|
||||||
size_t mkldnn_engine_get_count(engine_kind_t kind) {
|
|
||||||
engine_factory_t *ef = get_engine_factory(kind);
|
|
||||||
return ef != nullptr ? ef->count() : 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_engine_create(engine_t **engine,
|
|
||||||
engine_kind_t kind, size_t index) {
|
|
||||||
if (engine == nullptr)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
engine_factory_t *ef = get_engine_factory(kind);
|
|
||||||
if (ef == nullptr || index >= ef->count())
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
return ef->engine_create(engine, index);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_engine_get_kind(engine_t *engine, engine_kind_t *kind) {
|
|
||||||
if (engine == nullptr)
|
|
||||||
return invalid_arguments;
|
|
||||||
*kind = engine->kind();
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_engine_destroy(engine_t *engine) {
|
|
||||||
/* TODO: engine->dec_ref_count(); */
|
|
||||||
delete engine;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
119
thirdparty/oidn/mkl-dnn/src/common/engine.hpp
vendored
119
thirdparty/oidn/mkl-dnn/src/common/engine.hpp
vendored
@ -1,119 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef ENGINE_HPP
|
|
||||||
#define ENGINE_HPP
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "primitive.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
/** \brief An abstraction of an execution unit with shared resources
|
|
||||||
*
|
|
||||||
* Responsibilities:
|
|
||||||
* - Provide engine specific memory allocation
|
|
||||||
* - Provide engine specific primitive_desc_t creators
|
|
||||||
*/
|
|
||||||
struct mkldnn_engine: public mkldnn::impl::c_compatible {
|
|
||||||
mkldnn_engine(mkldnn::impl::engine_kind_t kind)
|
|
||||||
: kind_(kind)
|
|
||||||
{}
|
|
||||||
virtual ~mkldnn_engine() {}
|
|
||||||
|
|
||||||
/** get kind of the current engine */
|
|
||||||
virtual mkldnn::impl::engine_kind_t kind() const { return kind_; }
|
|
||||||
|
|
||||||
/** allocate memory */
|
|
||||||
virtual mkldnn::impl::status_t memory_create(
|
|
||||||
mkldnn::impl::memory_t **memory,
|
|
||||||
const mkldnn::impl::memory_desc_t *md,
|
|
||||||
void *handle) = 0;
|
|
||||||
|
|
||||||
/** implementation section (typedefs) */
|
|
||||||
|
|
||||||
// TODO: remove engine?
|
|
||||||
typedef mkldnn::impl::status_t (*reorder_primitive_desc_create_f)(
|
|
||||||
mkldnn::impl::reorder_pd_t **reorder_pd,
|
|
||||||
mkldnn::impl::engine_t *engine,
|
|
||||||
const mkldnn::impl::primitive_attr_t *attr,
|
|
||||||
mkldnn::impl::engine_t *src_engine,
|
|
||||||
const mkldnn::impl::memory_desc_t *src_md,
|
|
||||||
mkldnn::impl::engine_t *dst_engine,
|
|
||||||
const mkldnn::impl::memory_desc_t *dst_md);
|
|
||||||
|
|
||||||
typedef mkldnn::impl::status_t (*concat_primitive_desc_create_f)(
|
|
||||||
mkldnn::impl::concat_pd_t **concat_pd,
|
|
||||||
mkldnn::impl::engine_t *engine,
|
|
||||||
const mkldnn::impl::primitive_attr_t *attr,
|
|
||||||
const mkldnn::impl::memory_desc_t *dst_md,
|
|
||||||
int n, int concat_dim,
|
|
||||||
const mkldnn::impl::memory_desc_t *src_mds);
|
|
||||||
|
|
||||||
typedef mkldnn::impl::status_t (*sum_primitive_desc_create_f)(
|
|
||||||
mkldnn::impl::sum_pd_t **sum_pd,
|
|
||||||
mkldnn::impl::engine_t *engine,
|
|
||||||
const mkldnn::impl::primitive_attr_t *attr,
|
|
||||||
const mkldnn::impl::memory_desc_t *dst_md,
|
|
||||||
int n, const float *scales,
|
|
||||||
const mkldnn::impl::memory_desc_t *src_mds);
|
|
||||||
|
|
||||||
typedef mkldnn::impl::status_t (*primitive_desc_create_f)(
|
|
||||||
mkldnn::impl::primitive_desc_t **, const mkldnn::impl::op_desc_t *,
|
|
||||||
const mkldnn::impl::primitive_attr_t *attr,
|
|
||||||
mkldnn::impl::engine_t *, const mkldnn::impl::primitive_desc_t *);
|
|
||||||
|
|
||||||
/* implementation section */
|
|
||||||
|
|
||||||
/** return the list of reorder implementations. engine guarantees to return
|
|
||||||
* a NULL-terminated list */
|
|
||||||
virtual const reorder_primitive_desc_create_f*
|
|
||||||
get_reorder_implementation_list() const = 0;
|
|
||||||
|
|
||||||
/** return the list of concat implementations. engine guarantees to return
|
|
||||||
* a NULL-terminated list */
|
|
||||||
virtual const concat_primitive_desc_create_f*
|
|
||||||
get_concat_implementation_list() const = 0;
|
|
||||||
|
|
||||||
/** return the list of sum implementations. engine guarantees to return
|
|
||||||
* a NULL-terminated list */
|
|
||||||
virtual const sum_primitive_desc_create_f*
|
|
||||||
get_sum_implementation_list() const = 0;
|
|
||||||
|
|
||||||
/** return the list of implementations. engine guarantees to return a
|
|
||||||
* NULL-terminated list */
|
|
||||||
virtual const primitive_desc_create_f* get_implementation_list() const = 0;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
mkldnn::impl::engine_kind_t kind_;
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
struct engine_factory_t: public c_compatible {
|
|
||||||
virtual size_t count() const = 0;
|
|
||||||
virtual engine_kind_t kind() const = 0;
|
|
||||||
virtual status_t engine_create(engine_t **engine, size_t index) const = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
106
thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp
vendored
106
thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp
vendored
@ -1,106 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::utils;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
using namespace mkldnn::impl::prop_kind;
|
|
||||||
using namespace mkldnn::impl::types;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind,
|
|
||||||
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
|
|
||||||
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) {
|
|
||||||
bool args_ok = !any_null(ip_desc, src_desc, weights_desc, dst_desc);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
auto id = inner_product_desc_t();
|
|
||||||
id.primitive_kind = primitive_kind::inner_product;
|
|
||||||
id.prop_kind = prop_kind;
|
|
||||||
|
|
||||||
id.diff_src_desc = id.src_desc = zero_md();
|
|
||||||
id.diff_dst_desc = id.dst_desc = zero_md();
|
|
||||||
id.diff_weights_desc = id.weights_desc = zero_md();
|
|
||||||
id.diff_bias_desc = id.bias_desc = zero_md();
|
|
||||||
|
|
||||||
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
|
|
||||||
const bool with_bias =
|
|
||||||
bias_desc && bias_desc->format_kind != format_kind::undef;
|
|
||||||
|
|
||||||
(prop_kind == backward_data ? id.diff_src_desc : id.src_desc) = *src_desc;
|
|
||||||
(is_fwd ? id.dst_desc : id.diff_dst_desc) = *dst_desc;
|
|
||||||
(prop_kind == backward_weights ? id.diff_weights_desc : id.weights_desc) =
|
|
||||||
*weights_desc;
|
|
||||||
if (with_bias)
|
|
||||||
(prop_kind == backward_weights ? id.diff_bias_desc : id.bias_desc) =
|
|
||||||
*bias_desc;
|
|
||||||
|
|
||||||
id.accum_data_type = types::default_accum_data_type(src_desc->data_type,
|
|
||||||
weights_desc->data_type, dst_desc->data_type, prop_kind);
|
|
||||||
|
|
||||||
bool consistency = true
|
|
||||||
&& memory_desc_wrapper(weights_desc).nelems()
|
|
||||||
&& one_of(src_desc->ndims, 2, 3, 4, 5)
|
|
||||||
&& dst_desc->ndims == 2
|
|
||||||
&& weights_desc->ndims == src_desc->ndims
|
|
||||||
&& (with_bias ? bias_desc->ndims == 1 : true)
|
|
||||||
&& (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
|
|
||||||
&& src_desc->dims[0] == dst_desc->dims[0]
|
|
||||||
&& array_cmp(&src_desc->dims[1], &weights_desc->dims[1],
|
|
||||||
src_desc->ndims - 1)
|
|
||||||
&& dst_desc->dims[1] == weights_desc->dims[0];
|
|
||||||
if (!consistency) return invalid_arguments;
|
|
||||||
|
|
||||||
*ip_desc = id;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_inner_product_forward_desc_init(inner_product_desc_t *ip_desc,
|
|
||||||
prop_kind_t prop_kind, const memory_desc_t *src_desc,
|
|
||||||
const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
|
|
||||||
const memory_desc_t *dst_desc) {
|
|
||||||
if (!one_of(prop_kind, forward_training, forward_inference))
|
|
||||||
return invalid_arguments;
|
|
||||||
return ip_desc_init(ip_desc, prop_kind, src_desc, weights_desc, bias_desc,
|
|
||||||
dst_desc);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_inner_product_backward_data_desc_init(
|
|
||||||
inner_product_desc_t *ip_desc, const memory_desc_t *diff_src_desc,
|
|
||||||
const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc)
|
|
||||||
{
|
|
||||||
return ip_desc_init(ip_desc, backward_data, diff_src_desc, weights_desc,
|
|
||||||
nullptr, diff_dst_desc);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_inner_product_backward_weights_desc_init(
|
|
||||||
inner_product_desc_t *ip_desc, const memory_desc_t *src_desc,
|
|
||||||
const memory_desc_t *diff_weights_desc,
|
|
||||||
const memory_desc_t *diff_bias_desc,
|
|
||||||
const memory_desc_t *diff_dst_desc) {
|
|
||||||
return ip_desc_init(ip_desc, backward_weights, src_desc, diff_weights_desc,
|
|
||||||
diff_bias_desc, diff_dst_desc);
|
|
||||||
}
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
@ -1,56 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
#include "inner_product_pd.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
using namespace prop_kind;
|
|
||||||
|
|
||||||
memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc) {
|
|
||||||
return desc->prop_kind == backward_data
|
|
||||||
? &desc->diff_src_desc : &desc->src_desc;
|
|
||||||
}
|
|
||||||
|
|
||||||
memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc) {
|
|
||||||
return desc->prop_kind == backward_weights
|
|
||||||
? &desc->diff_weights_desc : &desc->weights_desc;
|
|
||||||
}
|
|
||||||
|
|
||||||
memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc) {
|
|
||||||
return desc->prop_kind == backward_weights
|
|
||||||
? &desc->diff_bias_desc : &desc->bias_desc;
|
|
||||||
}
|
|
||||||
|
|
||||||
memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc) {
|
|
||||||
return utils::one_of(desc->prop_kind, forward_inference, forward_training)
|
|
||||||
? &desc->dst_desc : &desc->diff_dst_desc;
|
|
||||||
}
|
|
||||||
|
|
||||||
const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc)
|
|
||||||
{ return ip_prop_invariant_src_d(const_cast<inner_product_desc_t *>(desc)); }
|
|
||||||
const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc)
|
|
||||||
{ return ip_prop_invariant_wei_d(const_cast<inner_product_desc_t *>(desc)); }
|
|
||||||
const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc)
|
|
||||||
{ return ip_prop_invariant_bia_d(const_cast<inner_product_desc_t *>(desc)); }
|
|
||||||
const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc)
|
|
||||||
{ return ip_prop_invariant_dst_d(const_cast<inner_product_desc_t *>(desc)); }
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,321 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef INNER_PRODUCT_PD_HPP
|
|
||||||
#define INNER_PRODUCT_PD_HPP
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc);
|
|
||||||
memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc);
|
|
||||||
memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc);
|
|
||||||
memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc);
|
|
||||||
const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc);
|
|
||||||
const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc);
|
|
||||||
const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc);
|
|
||||||
const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc);
|
|
||||||
|
|
||||||
struct inner_product_fwd_pd_t;
|
|
||||||
|
|
||||||
struct inner_product_pd_t: public primitive_desc_t {
|
|
||||||
static constexpr auto base_pkind = primitive_kind::inner_product;
|
|
||||||
|
|
||||||
inner_product_pd_t(engine_t *engine,
|
|
||||||
const inner_product_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const inner_product_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: primitive_desc_t(engine, attr, base_pkind)
|
|
||||||
, desc_(*adesc)
|
|
||||||
, hint_fwd_pd_(hint_fwd_pd)
|
|
||||||
{}
|
|
||||||
|
|
||||||
const inner_product_desc_t *desc() const { return &desc_; }
|
|
||||||
virtual const op_desc_t *op_desc() const override
|
|
||||||
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
|
||||||
virtual void init_info() override { impl::init_info(this, this->info_); }
|
|
||||||
|
|
||||||
virtual status_t query(query_t what, int idx, void *result) const override {
|
|
||||||
switch (what) {
|
|
||||||
case query::inner_product_d:
|
|
||||||
*(const inner_product_desc_t**)result = desc(); break;
|
|
||||||
default: return primitive_desc_t::query(what, idx, result);
|
|
||||||
}
|
|
||||||
return status::success;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* common inner_product aux functions */
|
|
||||||
|
|
||||||
dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; }
|
|
||||||
dim_t IC() const { return ip_prop_invariant_src_d(&desc_)->dims[1]; }
|
|
||||||
dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; }
|
|
||||||
|
|
||||||
dim_t ID() const {
|
|
||||||
return ndims() >= 5
|
|
||||||
? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
|
|
||||||
}
|
|
||||||
dim_t IH() const {
|
|
||||||
return ndims() >= 4
|
|
||||||
? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
|
|
||||||
}
|
|
||||||
dim_t IW() const {
|
|
||||||
return ndims() >= 3
|
|
||||||
? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
dim_t OD() const {
|
|
||||||
return ndims() >= 5
|
|
||||||
? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
|
|
||||||
}
|
|
||||||
dim_t OH() const {
|
|
||||||
return ndims() >= 4
|
|
||||||
? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
|
|
||||||
}
|
|
||||||
dim_t OW() const {
|
|
||||||
return ndims() >= 3
|
|
||||||
? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
dim_t KD() const {
|
|
||||||
return ndims() >= 5
|
|
||||||
? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1;
|
|
||||||
}
|
|
||||||
dim_t KH() const {
|
|
||||||
return ndims() >= 4
|
|
||||||
? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1;
|
|
||||||
}
|
|
||||||
dim_t KW() const {
|
|
||||||
return ndims() >= 3
|
|
||||||
? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
dim_t IC_total() const {
|
|
||||||
return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1],
|
|
||||||
ndims() - 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
dim_t IC_total_padded() const {
|
|
||||||
auto src_d = desc()->prop_kind == prop_kind::backward_data
|
|
||||||
? memory_desc_wrapper(diff_src_md())
|
|
||||||
: memory_desc_wrapper(src_md());
|
|
||||||
assert(src_d.is_blocking_desc());
|
|
||||||
if (!src_d.is_blocking_desc()) return -1;
|
|
||||||
return utils::array_product(src_d.padded_dims() + 1, ndims() - 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; }
|
|
||||||
|
|
||||||
bool with_bias() const
|
|
||||||
{ return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); }
|
|
||||||
|
|
||||||
bool has_zero_dim_memory() const {
|
|
||||||
const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_));
|
|
||||||
const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_));
|
|
||||||
return s_d.has_zero_dim() || d_d.has_zero_dim();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_fwd() const {
|
|
||||||
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
|
||||||
prop_kind::forward_inference);
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
inner_product_desc_t desc_;
|
|
||||||
const inner_product_fwd_pd_t *hint_fwd_pd_;
|
|
||||||
|
|
||||||
status_t template_set_default_params(memory_desc_t &src_md,
|
|
||||||
memory_desc_t &weights_md, memory_desc_t &dst_md,
|
|
||||||
memory_desc_t *bias_md) {
|
|
||||||
using namespace format_tag;
|
|
||||||
if (src_md.format_kind == format_kind::any) {
|
|
||||||
CHECK(memory_desc_init_by_tag(src_md,
|
|
||||||
utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw)));
|
|
||||||
}
|
|
||||||
if (dst_md.format_kind == format_kind::any)
|
|
||||||
CHECK(memory_desc_init_by_tag(dst_md, nc));
|
|
||||||
if (weights_md.format_kind == format_kind::any) {
|
|
||||||
CHECK(memory_desc_init_by_tag(weights_md,
|
|
||||||
utils::pick(ndims() - 2, oi, oiw, oihw, oidhw)));
|
|
||||||
}
|
|
||||||
if (bias_md && bias_md->format_kind == format_kind::any)
|
|
||||||
CHECK(memory_desc_init_by_tag(*bias_md, x));
|
|
||||||
return status::success;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct inner_product_fwd_pd_t: public inner_product_pd_t {
|
|
||||||
typedef inner_product_fwd_pd_t base_class;
|
|
||||||
typedef inner_product_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
inner_product_fwd_pd_t(engine_t *engine,
|
|
||||||
const inner_product_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const inner_product_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, src_md_(desc_.src_desc)
|
|
||||||
, weights_md_(desc_.weights_desc)
|
|
||||||
, bias_md_(desc_.bias_desc)
|
|
||||||
, dst_md_(desc_.dst_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_BIAS && with_bias())
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DST)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &src_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &dst_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *weights_md(int index = 0) const override {
|
|
||||||
if (index == 0) return &weights_md_;
|
|
||||||
if (index == 1 && with_bias()) return &bias_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return 2 + with_bias(); }
|
|
||||||
virtual int n_outputs() const override { return 1; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t src_md_;
|
|
||||||
memory_desc_t weights_md_;
|
|
||||||
memory_desc_t bias_md_;
|
|
||||||
memory_desc_t dst_md_;
|
|
||||||
|
|
||||||
status_t set_default_params() {
|
|
||||||
return template_set_default_params(src_md_, weights_md_, dst_md_,
|
|
||||||
&bias_md_);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct inner_product_bwd_data_pd_t: public inner_product_pd_t {
|
|
||||||
typedef inner_product_bwd_data_pd_t base_class;
|
|
||||||
typedef inner_product_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
inner_product_bwd_data_pd_t(engine_t *engine,
|
|
||||||
const inner_product_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const inner_product_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, diff_src_md_(desc_.diff_src_desc)
|
|
||||||
, weights_md_(desc_.weights_desc)
|
|
||||||
, diff_dst_md_(desc_.diff_dst_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_SRC)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *diff_src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_src_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_dst_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *weights_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &weights_md_ : nullptr; }
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return 2; }
|
|
||||||
virtual int n_outputs() const override { return 1; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t diff_src_md_;
|
|
||||||
memory_desc_t weights_md_;
|
|
||||||
memory_desc_t diff_dst_md_;
|
|
||||||
|
|
||||||
status_t set_default_params() {
|
|
||||||
return template_set_default_params(diff_src_md_, weights_md_,
|
|
||||||
diff_dst_md_, nullptr);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct inner_product_bwd_weights_pd_t: public inner_product_pd_t {
|
|
||||||
typedef inner_product_bwd_weights_pd_t base_class;
|
|
||||||
typedef inner_product_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
inner_product_bwd_weights_pd_t(engine_t *engine,
|
|
||||||
const inner_product_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const inner_product_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, src_md_(desc_.src_desc)
|
|
||||||
, diff_weights_md_(desc_.diff_weights_desc)
|
|
||||||
, diff_bias_md_(desc_.diff_bias_desc)
|
|
||||||
, diff_dst_md_(desc_.diff_dst_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &src_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_dst_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
|
|
||||||
if (index == 0) return &diff_weights_md_;
|
|
||||||
if (index == 1 && with_bias()) return &diff_bias_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return 2; }
|
|
||||||
virtual int n_outputs() const override { return 1 + with_bias(); }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t src_md_;
|
|
||||||
memory_desc_t diff_weights_md_;
|
|
||||||
memory_desc_t diff_bias_md_;
|
|
||||||
memory_desc_t diff_dst_md_;
|
|
||||||
|
|
||||||
status_t set_default_params() {
|
|
||||||
return template_set_default_params(src_md_, diff_weights_md_,
|
|
||||||
diff_dst_md_, &diff_bias_md_);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
91
thirdparty/oidn/mkl-dnn/src/common/lrn.cpp
vendored
91
thirdparty/oidn/mkl-dnn/src/common/lrn.cpp
vendored
@ -1,91 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::utils;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
using namespace mkldnn::impl::prop_kind;
|
|
||||||
using namespace mkldnn::impl::alg_kind;
|
|
||||||
using namespace mkldnn::impl::types;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
status_t lrn_desc_init(lrn_desc_t *lrn_desc,
|
|
||||||
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *data_desc, const memory_desc_t *diff_data_desc,
|
|
||||||
dim_t local_size, float alpha, float beta, float k) {
|
|
||||||
bool args_ok = true
|
|
||||||
&& !any_null(lrn_desc, data_desc)
|
|
||||||
&& one_of(alg_kind, lrn_within_channel, lrn_across_channels)
|
|
||||||
&& one_of(prop_kind, forward_training, forward_inference, backward_data)
|
|
||||||
&& IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
auto ld = lrn_desc_t();
|
|
||||||
ld.primitive_kind = primitive_kind::lrn;
|
|
||||||
ld.prop_kind = prop_kind;
|
|
||||||
ld.alg_kind = alg_kind;
|
|
||||||
|
|
||||||
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
|
|
||||||
|
|
||||||
ld.data_desc = *data_desc;
|
|
||||||
if (!is_fwd)
|
|
||||||
ld.diff_data_desc = *diff_data_desc;
|
|
||||||
else
|
|
||||||
ld.diff_data_desc = zero_md();
|
|
||||||
ld.local_size = local_size;
|
|
||||||
ld.lrn_alpha = alpha;
|
|
||||||
ld.lrn_beta = beta;
|
|
||||||
ld.lrn_k = k;
|
|
||||||
|
|
||||||
bool consistency = true
|
|
||||||
&& ld.data_desc.ndims == 4;
|
|
||||||
if (ld.prop_kind == backward_data)
|
|
||||||
consistency = consistency
|
|
||||||
&& ld.diff_data_desc.ndims == 4
|
|
||||||
&& array_cmp(ld.diff_data_desc.dims, ld.data_desc.dims, 4);
|
|
||||||
if (!consistency) return invalid_arguments;
|
|
||||||
|
|
||||||
*lrn_desc = ld;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_lrn_forward_desc_init(lrn_desc_t *lrn_desc,
|
|
||||||
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *data_desc, dim_t local_size, float alpha,
|
|
||||||
float beta, float k) {
|
|
||||||
if (!one_of(prop_kind, forward_training, forward_inference))
|
|
||||||
return invalid_arguments;
|
|
||||||
return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr,
|
|
||||||
local_size, alpha, beta, k);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_lrn_backward_desc_init(lrn_desc_t *lrn_desc,
|
|
||||||
alg_kind_t alg_kind, const memory_desc_t *data_desc,
|
|
||||||
const memory_desc_t *diff_data_desc, dim_t local_size, float alpha,
|
|
||||||
float beta, float k) {
|
|
||||||
return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc,
|
|
||||||
diff_data_desc, local_size, alpha, beta, k);
|
|
||||||
}
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
170
thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp
vendored
170
thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp
vendored
@ -1,170 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef LRN_PD_HPP
|
|
||||||
#define LRN_PD_HPP
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
struct lrn_fwd_pd_t;
|
|
||||||
|
|
||||||
struct lrn_pd_t: public primitive_desc_t {
|
|
||||||
static constexpr auto base_pkind = primitive_kind::lrn;
|
|
||||||
|
|
||||||
lrn_pd_t(engine_t *engine,
|
|
||||||
const lrn_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const lrn_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: primitive_desc_t(engine, attr, base_pkind)
|
|
||||||
, desc_(*adesc)
|
|
||||||
, hint_fwd_pd_(hint_fwd_pd)
|
|
||||||
, data_md_(desc_.data_desc)
|
|
||||||
, ws_md_()
|
|
||||||
{}
|
|
||||||
|
|
||||||
const lrn_desc_t *desc() const { return &desc_; }
|
|
||||||
virtual const op_desc_t *op_desc() const override
|
|
||||||
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
|
||||||
virtual void init_info() override { impl::init_info(this, this->info_); }
|
|
||||||
|
|
||||||
virtual status_t query(query_t what, int idx, void *result) const override {
|
|
||||||
switch (what) {
|
|
||||||
case query::lrn_d:
|
|
||||||
*(const lrn_desc_t**)result = desc(); break;
|
|
||||||
default: return primitive_desc_t::query(what, idx, result);
|
|
||||||
}
|
|
||||||
return status::success;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* common lrn aux functions */
|
|
||||||
|
|
||||||
dim_t MB() const { return data_desc().dims[0]; }
|
|
||||||
dim_t C() const { return data_desc().dims[1]; }
|
|
||||||
dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
|
|
||||||
dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
|
|
||||||
dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
|
|
||||||
|
|
||||||
int ndims() const { return data_desc().ndims; }
|
|
||||||
|
|
||||||
bool has_zero_dim_memory() const
|
|
||||||
{ return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
|
|
||||||
|
|
||||||
bool is_fwd() const {
|
|
||||||
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
|
||||||
prop_kind::forward_inference);
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
lrn_desc_t desc_;
|
|
||||||
const lrn_fwd_pd_t *hint_fwd_pd_;
|
|
||||||
|
|
||||||
memory_desc_t data_md_;
|
|
||||||
memory_desc_t ws_md_;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const memory_desc_t &data_desc() const { return desc_.data_desc; }
|
|
||||||
};
|
|
||||||
|
|
||||||
struct lrn_fwd_pd_t: public lrn_pd_t {
|
|
||||||
typedef lrn_fwd_pd_t base_class;
|
|
||||||
typedef lrn_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
lrn_fwd_pd_t(engine_t *engine,
|
|
||||||
const lrn_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const lrn_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: lrn_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (arg == MKLDNN_ARG_SRC)
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DST)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &data_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &data_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *workspace_md(int index = 0) const override
|
|
||||||
{ return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return 1; }
|
|
||||||
virtual int n_outputs() const override
|
|
||||||
{ return 1 + (workspace_md() != nullptr); }
|
|
||||||
};
|
|
||||||
|
|
||||||
struct lrn_bwd_pd_t: public lrn_pd_t {
|
|
||||||
typedef lrn_bwd_pd_t base_class;
|
|
||||||
typedef lrn_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
lrn_bwd_pd_t(engine_t *engine,
|
|
||||||
const lrn_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const lrn_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: lrn_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, diff_data_md_(desc_.diff_data_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_SRC)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &data_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_data_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_data_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *workspace_md(int index = 0) const override
|
|
||||||
{ return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
|
|
||||||
|
|
||||||
virtual int n_inputs() const override
|
|
||||||
{ return 2 + (workspace_md() != nullptr); }
|
|
||||||
virtual int n_outputs() const override { return 1; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t diff_data_md_;
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
280
thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp
vendored
280
thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp
vendored
@ -1,280 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2017-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef MATH_UTILS_HPP
|
|
||||||
#define MATH_UTILS_HPP
|
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
#include <math.h>
|
|
||||||
|
|
||||||
#include "utils.hpp"
|
|
||||||
#include "nstl.hpp"
|
|
||||||
#include "mkldnn_traits.hpp"
|
|
||||||
|
|
||||||
#if defined(MKLDNN_X86_64)
|
|
||||||
#include "immintrin.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
namespace math {
|
|
||||||
|
|
||||||
/** rounds @p f to an integer according to the mxcsr register */
|
|
||||||
inline int mxcsr_round(float f) {
|
|
||||||
#if defined(MKLDNN_X86_64)
|
|
||||||
return _mm_cvtss_si32(_mm_load_ss(&f));
|
|
||||||
#else
|
|
||||||
return (int)nearbyintf(f); // optimism
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename data_t, typename acc_t>
|
|
||||||
inline typename utils::enable_if<!nstl::is_integral<data_t>::value,
|
|
||||||
typename utils::remove_reference<data_t>::type>::type
|
|
||||||
saturate(const acc_t &x) {
|
|
||||||
return (typename utils::remove_reference<data_t>::type)x;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename data_t, typename acc_t>
|
|
||||||
inline typename utils::enable_if<nstl::is_integral<data_t>::value,
|
|
||||||
typename utils::remove_reference<data_t>::type>::type
|
|
||||||
saturate(const acc_t &x) {
|
|
||||||
acc_t v = x;
|
|
||||||
if (v < (acc_t)nstl::numeric_limits<data_t>::lowest())
|
|
||||||
v = (acc_t)nstl::numeric_limits<data_t>::lowest();
|
|
||||||
if (v > (acc_t)nstl::numeric_limits<data_t>::max())
|
|
||||||
v = (acc_t)nstl::numeric_limits<data_t>::max();
|
|
||||||
return (typename utils::remove_reference<data_t>::type)v;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename data_t>
|
|
||||||
double saturate(const double &x) {
|
|
||||||
double v = x;
|
|
||||||
if (v < (double)nstl::numeric_limits<data_t>::lowest())
|
|
||||||
v = (double)nstl::numeric_limits<data_t>::lowest();
|
|
||||||
if (v > (double)nstl::numeric_limits<data_t>::max())
|
|
||||||
v = (double)nstl::numeric_limits<data_t>::max();
|
|
||||||
return v;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <> inline int8_t saturate<int8_t, uint8_t>(const uint8_t &x) {
|
|
||||||
return x <= 127u ? x : 127;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <> inline uint8_t saturate<uint8_t, int8_t>(const int8_t &x) {
|
|
||||||
return x >= 0 ? x : 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename out_t>
|
|
||||||
typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
|
|
||||||
out_round(float v) { return (out_t)mxcsr_round(v); }
|
|
||||||
|
|
||||||
template <typename out_t>
|
|
||||||
typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
|
|
||||||
out_round(double v) { return (out_t)mxcsr_round((float)v); }
|
|
||||||
|
|
||||||
template <typename out_t>
|
|
||||||
typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type
|
|
||||||
out_round(float v) { return v; }
|
|
||||||
|
|
||||||
inline int gcd(int a, int b) {
|
|
||||||
a = impl::nstl::abs(a);
|
|
||||||
b = impl::nstl::abs(b);
|
|
||||||
if (a < b) { int x = a; a = b; b = x; }
|
|
||||||
|
|
||||||
if (b == 0) return a;
|
|
||||||
|
|
||||||
int r;
|
|
||||||
while ((r = a % b) != 0) { a = b; b = r; }
|
|
||||||
|
|
||||||
return b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
inline bool is_pow2(const T& v) { return (v & (v - 1)) == 0; }
|
|
||||||
|
|
||||||
/** returns floor(log2(v)), aka the position of the leftmost non-0 bit */
|
|
||||||
inline int ilog2q(size_t v) {
|
|
||||||
if (v == 0)
|
|
||||||
return -1;
|
|
||||||
|
|
||||||
int p = 0;
|
|
||||||
# define CP(pw) do { if (v >= (1ull << pw)) { v >>= pw; p += pw; } } while(0)
|
|
||||||
CP(32); CP(16); CP(8); CP(4); CP(2); CP(1);
|
|
||||||
# undef CP
|
|
||||||
return p;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U one_m_square(T x) {
|
|
||||||
return (U)(1 - x) * (1 + x);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U x_m_square(T x) {
|
|
||||||
return (U)(1 - x) * x;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* activation */
|
|
||||||
template <typename T, typename A,
|
|
||||||
typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U relu_fwd(T s, A alpha) {
|
|
||||||
return s > 0 ? s : (U)(s * alpha);
|
|
||||||
}
|
|
||||||
template <typename T, typename A,
|
|
||||||
typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U relu_bwd(T dd, T s, A alpha) {
|
|
||||||
return s > 0 ? dd : (U)(dd * alpha);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U tanh_fwd(T s) {
|
|
||||||
const float e = tanhf((float) s);
|
|
||||||
return (U)e;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U tanh_bwd(T dd, T s) {
|
|
||||||
const float e = tanh_fwd<float>((float) s);
|
|
||||||
return (U)(dd * (1 - e) * (1 + e));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename A,
|
|
||||||
typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U elu_fwd(T s, A alpha) {
|
|
||||||
return s > 0 ? s : (U)(alpha * (::expm1f((float)s)));
|
|
||||||
}
|
|
||||||
template <typename T, typename A,
|
|
||||||
typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U elu_bwd(T dd, T s, A alpha) {
|
|
||||||
return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s)));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U square_fwd(T s) {
|
|
||||||
return s * s;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U square_bwd(T dd, T s) {
|
|
||||||
return dd * 2 * s;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U abs_fwd(T s) {
|
|
||||||
return s > 0 ? s : -s;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U abs_bwd(T dd, T s) {
|
|
||||||
return s > 0 ? dd : s < 0 ? -dd : 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U sqrt_fwd(T s) {
|
|
||||||
return s > 0 ? (U)(::sqrtf((float)(s))) : 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U sqrt_bwd(T dd, T s) {
|
|
||||||
return s > 0
|
|
||||||
? (U)(dd / (2 * ::sqrtf((float)(s))))
|
|
||||||
: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename A,
|
|
||||||
typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U linear_fwd(T s, A alpha, A beta) {
|
|
||||||
return (U)(alpha * s + beta);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename A,
|
|
||||||
typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U linear_bwd(T dd, T s, A alpha, A beta) {
|
|
||||||
(void) s;
|
|
||||||
(void) beta;
|
|
||||||
return (U)(dd * alpha);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename A,
|
|
||||||
typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U bounded_relu_fwd(T s, A alpha) {
|
|
||||||
s = s > 0 ? s : 0;
|
|
||||||
return s > alpha ? (U)(alpha) : s;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename A,
|
|
||||||
typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U bounded_relu_bwd(T dd, T s, A alpha) {
|
|
||||||
return dd * (0 < s && s < alpha ? 1 : 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U soft_relu_fwd(T s) {
|
|
||||||
float max_logf = 8.872284e+01; //::logf(FLT_MAX)
|
|
||||||
return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U soft_relu_bwd(T dd, T s) {
|
|
||||||
return (U)(dd / (1 + ::expf((float)(-s))));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U logistic_fwd(T s) {
|
|
||||||
U v = (U)(::expf((float) -s));
|
|
||||||
return 1 / (1 + v);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
|
||||||
inline U logistic_bwd(T dd, T s) {
|
|
||||||
U v = logistic_fwd<T, U>(s);
|
|
||||||
return dd * v * (1 - v);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) {
|
|
||||||
using namespace alg_kind;
|
|
||||||
using namespace utils;
|
|
||||||
const bool preserves_zero = true
|
|
||||||
&& !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic)
|
|
||||||
&& IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh));
|
|
||||||
return preserves_zero;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline float get_bias(const char *bias, size_t offset, data_type_t data_type)
|
|
||||||
{
|
|
||||||
if (!bias)
|
|
||||||
return 0.0f;
|
|
||||||
|
|
||||||
#define CASE(dt) \
|
|
||||||
case dt: return (float)((const prec_traits<dt>::type *)bias)[offset]
|
|
||||||
|
|
||||||
switch (data_type) {
|
|
||||||
CASE(data_type::s8);
|
|
||||||
CASE(data_type::u8);
|
|
||||||
CASE(data_type::s32);
|
|
||||||
CASE(data_type::f32);
|
|
||||||
default: assert(!"unimplemented");
|
|
||||||
}
|
|
||||||
return 0; // never happens (should probably be a NaN)
|
|
||||||
#undef CASE
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
238
thirdparty/oidn/mkl-dnn/src/common/memory.cpp
vendored
238
thirdparty/oidn/mkl-dnn/src/common/memory.cpp
vendored
@ -1,238 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include <stddef.h>
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "engine.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::utils;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
using namespace mkldnn::impl::data_type;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
bool memory_desc_sanity_check(int ndims,const dims_t dims,
|
|
||||||
data_type_t data_type, format_kind_t format_kind) {
|
|
||||||
if (ndims == 0) return true;
|
|
||||||
|
|
||||||
bool ok = true
|
|
||||||
&& dims != nullptr
|
|
||||||
&& 0 < ndims && ndims <= MKLDNN_MAX_NDIMS
|
|
||||||
&& one_of(data_type, f32, s32, s8, u8)
|
|
||||||
&& format_kind != format_kind::undef;
|
|
||||||
if (!ok) return false;
|
|
||||||
for (int d = 0; d < ndims; ++d)
|
|
||||||
if (dims[d] < 0) return false;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool memory_desc_sanity_check(const memory_desc_t *md) {
|
|
||||||
if (md == nullptr) return false;
|
|
||||||
return memory_desc_sanity_check(md->ndims, md->dims, md->data_type,
|
|
||||||
format_kind::any);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims,
|
|
||||||
const dims_t dims, data_type_t data_type, format_tag_t tag) {
|
|
||||||
if (any_null(memory_desc)) return invalid_arguments;
|
|
||||||
if (ndims == 0 || tag == format_tag::undef) {
|
|
||||||
*memory_desc = types::zero_md();
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
format_kind_t format_kind = types::format_tag_to_kind(tag);
|
|
||||||
|
|
||||||
/* memory_desc != 0 */
|
|
||||||
bool args_ok = !any_null(memory_desc)
|
|
||||||
&& memory_desc_sanity_check(ndims, dims, data_type, format_kind);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
auto md = memory_desc_t();
|
|
||||||
md.ndims = ndims;
|
|
||||||
array_copy(md.dims, dims, ndims);
|
|
||||||
md.data_type = data_type;
|
|
||||||
array_copy(md.padded_dims, dims, ndims);
|
|
||||||
md.format_kind = format_kind;
|
|
||||||
|
|
||||||
status_t status = success;
|
|
||||||
if (tag == format_tag::undef) {
|
|
||||||
status = invalid_arguments;
|
|
||||||
} else if (tag == format_tag::any) {
|
|
||||||
// nop
|
|
||||||
} else if (format_kind == format_kind::blocked) {
|
|
||||||
status = memory_desc_wrapper::compute_blocking(md, tag);
|
|
||||||
} else {
|
|
||||||
assert(!"unreachable");
|
|
||||||
status = invalid_arguments;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (status == success)
|
|
||||||
*memory_desc = md;
|
|
||||||
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc,
|
|
||||||
int ndims, const dims_t dims, data_type_t data_type,
|
|
||||||
const dims_t strides) {
|
|
||||||
if (any_null(memory_desc)) return invalid_arguments;
|
|
||||||
if (ndims == 0) {
|
|
||||||
*memory_desc = types::zero_md();
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* memory_desc != 0 */
|
|
||||||
bool args_ok = !any_null(memory_desc)
|
|
||||||
&& memory_desc_sanity_check(ndims, dims, data_type, format_kind::any);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
auto md = memory_desc_t();
|
|
||||||
md.ndims = ndims;
|
|
||||||
array_copy(md.dims, dims, ndims);
|
|
||||||
md.data_type = data_type;
|
|
||||||
array_copy(md.padded_dims, dims, ndims);
|
|
||||||
md.format_kind = format_kind::blocked;
|
|
||||||
|
|
||||||
dims_t default_strides = {0};
|
|
||||||
if (strides == nullptr) {
|
|
||||||
default_strides[md.ndims - 1] = 1;
|
|
||||||
for (int d = md.ndims - 2; d >= 0; --d)
|
|
||||||
default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1];
|
|
||||||
strides = default_strides;
|
|
||||||
} else {
|
|
||||||
/* TODO: add sanity check for the provided strides */
|
|
||||||
}
|
|
||||||
|
|
||||||
array_copy(md.format_desc.blocking.strides, strides, md.ndims);
|
|
||||||
|
|
||||||
*memory_desc = md;
|
|
||||||
|
|
||||||
return status::success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md,
|
|
||||||
const memory_desc_t *parent_md, const dims_t dims,
|
|
||||||
const dims_t offsets) {
|
|
||||||
if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md))
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
const memory_desc_wrapper src_d(parent_md);
|
|
||||||
|
|
||||||
for (int d = 0; d < src_d.ndims(); ++d) {
|
|
||||||
if (dims[d] < 0 || offsets[d] < 0
|
|
||||||
|| (offsets[d] + dims[d] > src_d.dims()[d]))
|
|
||||||
return invalid_arguments;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (src_d.format_kind() != format_kind::blocked)
|
|
||||||
return unimplemented;
|
|
||||||
|
|
||||||
dims_t blocks;
|
|
||||||
src_d.compute_blocks(blocks);
|
|
||||||
|
|
||||||
memory_desc_t dst_d = *parent_md;
|
|
||||||
auto &dst_d_blk = dst_d.format_desc.blocking;
|
|
||||||
|
|
||||||
/* TODO: put this into memory_desc_wrapper */
|
|
||||||
for (int d = 0; d < src_d.ndims(); ++d) {
|
|
||||||
/* very limited functionality for now */
|
|
||||||
const bool ok = true
|
|
||||||
&& offsets[d] % blocks[d] == 0 /* [r1] */
|
|
||||||
&& src_d.padded_offsets()[d] == 0
|
|
||||||
&& (false
|
|
||||||
|| dims[d] % blocks[d] == 0
|
|
||||||
|| dims[d] < blocks[d]);
|
|
||||||
if (!ok)
|
|
||||||
return unimplemented;
|
|
||||||
|
|
||||||
const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d];
|
|
||||||
|
|
||||||
dst_d.dims[d] = dims[d];
|
|
||||||
dst_d.padded_dims[d] = is_right_border
|
|
||||||
? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d];
|
|
||||||
dst_d.padded_offsets[d] = src_d.padded_offsets()[d];
|
|
||||||
dst_d.offset0 += /* [r1] */
|
|
||||||
offsets[d] / blocks[d] * dst_d_blk.strides[d];
|
|
||||||
}
|
|
||||||
|
|
||||||
*md = dst_d;
|
|
||||||
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
int mkldnn_memory_desc_equal(const memory_desc_t *lhs,
|
|
||||||
const memory_desc_t *rhs) {
|
|
||||||
if (lhs == rhs) return 1;
|
|
||||||
if (any_null(lhs, rhs)) return 0;
|
|
||||||
return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) {
|
|
||||||
if (md == nullptr) return 0;
|
|
||||||
return memory_desc_wrapper(*md).size();
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md,
|
|
||||||
engine_t *engine, void *handle) {
|
|
||||||
if (any_null(memory, engine)) return invalid_arguments;
|
|
||||||
memory_desc_t z_md = types::zero_md();
|
|
||||||
return engine->memory_create(memory, md ? md : &z_md, handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_memory_get_memory_desc(const memory_t *memory,
|
|
||||||
const memory_desc_t **md) {
|
|
||||||
if (any_null(memory, md)) return invalid_arguments;
|
|
||||||
*md = memory->md();
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) {
|
|
||||||
if (any_null(memory, engine)) return invalid_arguments;
|
|
||||||
*engine = memory->engine();
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_memory_get_data_handle(const memory_t *memory,
|
|
||||||
void **handle) {
|
|
||||||
if (any_null(handle))
|
|
||||||
return invalid_arguments;
|
|
||||||
if (memory == nullptr) {
|
|
||||||
*handle = nullptr;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
return memory->get_data_handle(handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) {
|
|
||||||
if (any_null(memory)) return invalid_arguments;
|
|
||||||
return memory->set_data_handle(handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_memory_destroy(memory_t *memory) {
|
|
||||||
delete memory;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
63
thirdparty/oidn/mkl-dnn/src/common/memory.hpp
vendored
63
thirdparty/oidn/mkl-dnn/src/common/memory.hpp
vendored
@ -1,63 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef MEMORY_HPP
|
|
||||||
#define MEMORY_HPP
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "nstl.hpp"
|
|
||||||
|
|
||||||
struct mkldnn_memory: public mkldnn::impl::c_compatible {
|
|
||||||
mkldnn_memory(mkldnn::impl::engine_t *engine,
|
|
||||||
const mkldnn::impl::memory_desc_t *md)
|
|
||||||
: engine_(engine), md_(*md) {}
|
|
||||||
virtual ~mkldnn_memory() {}
|
|
||||||
|
|
||||||
/** allocates/initializes memory */
|
|
||||||
virtual mkldnn::impl::status_t init() = 0;
|
|
||||||
|
|
||||||
/** returns memory's engine */
|
|
||||||
mkldnn::impl::engine_t *engine() const { return engine_; }
|
|
||||||
/** returns memory's description */
|
|
||||||
const mkldnn::impl::memory_desc_t *md() const { return &md_; }
|
|
||||||
|
|
||||||
/** returns data handle */
|
|
||||||
virtual mkldnn::impl::status_t get_data_handle(void **handle) const = 0;
|
|
||||||
|
|
||||||
/** sets data handle */
|
|
||||||
virtual mkldnn::impl::status_t set_data_handle(void *handle) = 0;
|
|
||||||
|
|
||||||
/** zeros padding */
|
|
||||||
virtual mkldnn::impl::status_t zero_pad() const
|
|
||||||
{ return mkldnn::impl::status::success; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
mkldnn::impl::engine_t *engine_;
|
|
||||||
const mkldnn::impl::memory_desc_t md_;
|
|
||||||
|
|
||||||
private:
|
|
||||||
mkldnn_memory() = delete;
|
|
||||||
mkldnn_memory(const mkldnn_memory &) = delete;
|
|
||||||
mkldnn_memory(mkldnn_memory &&) = delete;
|
|
||||||
mkldnn_memory &operator=(const mkldnn_memory &) = delete;
|
|
||||||
mkldnn_memory &operator=(mkldnn_memory &&) = delete;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif
|
|
@ -1,212 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
|
|
||||||
#include <initializer_list>
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "memory_desc_wrapper.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
status_t fill_blocked(memory_desc_t &md,
|
|
||||||
std::initializer_list<int> perm,
|
|
||||||
std::initializer_list<int> inner_blks,
|
|
||||||
std::initializer_list<int> inner_idxs) {
|
|
||||||
const bool ok = true
|
|
||||||
&& perm.size() == (size_t)md.ndims
|
|
||||||
&& inner_blks.size() == inner_idxs.size();
|
|
||||||
if (!ok) return status::invalid_arguments;
|
|
||||||
|
|
||||||
md.offset0 = 0;
|
|
||||||
|
|
||||||
blocking_desc_t &blk = md.format_desc.blocking;
|
|
||||||
|
|
||||||
dim_t block_size = 1;
|
|
||||||
dims_t blocks = {0};
|
|
||||||
utils::array_set(blocks, 1, md.ndims);
|
|
||||||
|
|
||||||
blk.inner_nblks = (int)inner_blks.size();
|
|
||||||
|
|
||||||
int iblk = 0;
|
|
||||||
for (const auto &b: inner_idxs)
|
|
||||||
blk.inner_idxs[iblk++] = b;
|
|
||||||
|
|
||||||
iblk = 0;
|
|
||||||
for (const auto &b: inner_blks) {
|
|
||||||
int dim = blk.inner_idxs[iblk];
|
|
||||||
block_size *= b;
|
|
||||||
blocks[dim] *= b;
|
|
||||||
blk.inner_blks[iblk++] = b;
|
|
||||||
}
|
|
||||||
|
|
||||||
utils::array_set(md.padded_offsets, 0, md.ndims);
|
|
||||||
for (int d = 0; d < md.ndims; ++d)
|
|
||||||
md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]);
|
|
||||||
|
|
||||||
dim_t stride = block_size;
|
|
||||||
// if only we use C++14, the initializer_list would have rbegin()/rend()...
|
|
||||||
for (int d = 0; d < md.ndims; ++d)
|
|
||||||
stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d];
|
|
||||||
|
|
||||||
for (const auto &d: perm) {
|
|
||||||
if (md.padded_dims[d] == 0) {
|
|
||||||
blk.strides[d] = 1;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
stride /= md.padded_dims[d] / blocks[d];
|
|
||||||
blk.strides[d] = stride;
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(stride == block_size);
|
|
||||||
|
|
||||||
return status::success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc,
|
|
||||||
format_tag_t tag)
|
|
||||||
{
|
|
||||||
using namespace format_tag;
|
|
||||||
|
|
||||||
if (memory_desc.ndims == 0) return status::invalid_arguments;
|
|
||||||
|
|
||||||
# define C(tag, ... /* perm, inner_blks, inner_idxs */) \
|
|
||||||
case tag: return fill_blocked(memory_desc, __VA_ARGS__)
|
|
||||||
|
|
||||||
switch (tag) {
|
|
||||||
C(a, {0}, {}, {});
|
|
||||||
C(ab, {0, 1}, {}, {});
|
|
||||||
C(abc, {0, 1, 2}, {}, {});
|
|
||||||
C(abcd, {0, 1, 2, 3}, {}, {});
|
|
||||||
C(abcde, {0, 1, 2, 3, 4}, {}, {});
|
|
||||||
C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {});
|
|
||||||
C(abdec, {0, 1, 3, 4, 2}, {}, {});
|
|
||||||
C(acb, {0, 2, 1}, {}, {});
|
|
||||||
C(acbde, {0, 2, 1, 3, 4}, {}, {});
|
|
||||||
C(acdb, {0, 2, 3, 1}, {}, {});
|
|
||||||
C(acdeb, {0, 2, 3, 4, 1}, {}, {});
|
|
||||||
C(ba, {1, 0}, {}, {});
|
|
||||||
C(bac, {1, 0, 2}, {}, {});
|
|
||||||
C(bacd, {1, 0, 2, 3}, {}, {});
|
|
||||||
C(bcda, {1, 2, 3, 0}, {}, {});
|
|
||||||
C(cba, {2, 1, 0}, {}, {});
|
|
||||||
C(cdba, {2, 3, 1, 0}, {}, {});
|
|
||||||
C(cdeba, {2, 3, 4, 1, 0}, {}, {});
|
|
||||||
C(decab, {3, 4, 2, 0, 1}, {}, {});
|
|
||||||
|
|
||||||
C(Abc4a, {0, 1, 2}, {4}, {0});
|
|
||||||
C(aBc4b, {0, 1, 2}, {4}, {1});
|
|
||||||
C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1});
|
|
||||||
C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0});
|
|
||||||
C(Abcd4a, {0, 1, 2, 3}, {4}, {0});
|
|
||||||
C(aBcd4b, {0, 1, 2, 3}, {4}, {1});
|
|
||||||
C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0});
|
|
||||||
C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2});
|
|
||||||
C(aBCd4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1});
|
|
||||||
C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0});
|
|
||||||
C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1});
|
|
||||||
C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0});
|
|
||||||
C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1});
|
|
||||||
C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1});
|
|
||||||
C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1});
|
|
||||||
C(aBdc4b, {0, 1, 3, 2}, {4}, {1});
|
|
||||||
C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1});
|
|
||||||
C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1});
|
|
||||||
C(Acb4a, {0, 2, 1}, {4}, {0});
|
|
||||||
C(Acdb4a, {0, 2, 3, 1}, {4}, {0});
|
|
||||||
C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0});
|
|
||||||
|
|
||||||
C(Abc16a, {0, 1, 2}, {16}, {0});
|
|
||||||
C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1});
|
|
||||||
C(aBc16b, {0, 1, 2}, {16}, {1});
|
|
||||||
C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0});
|
|
||||||
C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0});
|
|
||||||
C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1});
|
|
||||||
C(aBc8b, {0, 1, 2}, {8}, {1});
|
|
||||||
C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1});
|
|
||||||
C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0});
|
|
||||||
C(Abcd16a, {0, 1, 2, 3}, {16}, {0});
|
|
||||||
C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1});
|
|
||||||
C(aBcd16b, {0, 1, 2, 3}, {16}, {1});
|
|
||||||
C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0});
|
|
||||||
C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2});
|
|
||||||
C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1});
|
|
||||||
C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1});
|
|
||||||
C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0});
|
|
||||||
C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1});
|
|
||||||
C(aBcd8b, {0, 1, 2, 3}, {8}, {1});
|
|
||||||
C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1});
|
|
||||||
C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1});
|
|
||||||
C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0});
|
|
||||||
C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2});
|
|
||||||
C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2});
|
|
||||||
C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1});
|
|
||||||
C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0});
|
|
||||||
C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1});
|
|
||||||
C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1});
|
|
||||||
C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0});
|
|
||||||
C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2});
|
|
||||||
C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1});
|
|
||||||
C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2});
|
|
||||||
C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2});
|
|
||||||
C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2});
|
|
||||||
C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0});
|
|
||||||
C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1});
|
|
||||||
C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1});
|
|
||||||
C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1});
|
|
||||||
C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1});
|
|
||||||
C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0});
|
|
||||||
C(aBCde8b8c, {0, 1, 2, 3, 4}, {8, 8}, {1, 2});
|
|
||||||
C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2});
|
|
||||||
C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1});
|
|
||||||
C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1});
|
|
||||||
C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2});
|
|
||||||
C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1});
|
|
||||||
C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2});
|
|
||||||
C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2});
|
|
||||||
C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1});
|
|
||||||
C(aBdc16b, {0, 1, 3, 2}, {16}, {1});
|
|
||||||
C(aBdc8b, {0, 1, 3, 2}, {8}, {1});
|
|
||||||
C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1});
|
|
||||||
C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1});
|
|
||||||
C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1});
|
|
||||||
C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1});
|
|
||||||
C(Acb16a, {0, 2, 1}, {16}, {0});
|
|
||||||
C(Acb8a, {0, 2, 1}, {8}, {0});
|
|
||||||
C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2});
|
|
||||||
C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2});
|
|
||||||
C(Acdb16a, {0, 2, 3, 1}, {16}, {0});
|
|
||||||
C(Acdb8a, {0, 2, 3, 1}, {8}, {0});
|
|
||||||
C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0});
|
|
||||||
C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0});
|
|
||||||
C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1});
|
|
||||||
C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1});
|
|
||||||
default: break;
|
|
||||||
}
|
|
||||||
|
|
||||||
#undef C
|
|
||||||
|
|
||||||
return status::invalid_arguments;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
@ -1,400 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef MEMORY_DESC_WRAPPER_HPP
|
|
||||||
#define MEMORY_DESC_WRAPPER_HPP
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "nstl.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
/** thin wrapper class over \struct memory_desc_t which allows easy
|
|
||||||
* manipulations with underlying C structure, which is taken by reference */
|
|
||||||
struct memory_desc_wrapper: public c_compatible {
|
|
||||||
const memory_desc_t *md_;
|
|
||||||
|
|
||||||
/** constructor which takes a reference to a constant underlying C memory
|
|
||||||
* descriptor \param md */
|
|
||||||
memory_desc_wrapper(const memory_desc_t *md): md_(md) {}
|
|
||||||
memory_desc_wrapper(const memory_desc_t &md): memory_desc_wrapper(&md) {}
|
|
||||||
|
|
||||||
/* implementing attributes */
|
|
||||||
int ndims() const { return md_->ndims; }
|
|
||||||
const dims_t &dims() const { return md_->dims; }
|
|
||||||
data_type_t data_type() const { return md_->data_type; }
|
|
||||||
|
|
||||||
const dims_t &padded_dims() const { return md_->padded_dims; }
|
|
||||||
const dims_t &padded_offsets() const { return md_->padded_offsets; }
|
|
||||||
dim_t offset0() const { return md_->offset0; }
|
|
||||||
|
|
||||||
format_kind_t format_kind() const { return md_->format_kind; }
|
|
||||||
|
|
||||||
bool is_blocking_desc() const
|
|
||||||
{ return format_kind() == format_kind::blocked; }
|
|
||||||
bool is_wino_desc() const
|
|
||||||
{ return format_kind() == format_kind::wino; }
|
|
||||||
bool is_rnn_packed_desc() const
|
|
||||||
{ return format_kind() == format_kind::rnn_packed; }
|
|
||||||
|
|
||||||
const blocking_desc_t &blocking_desc() const {
|
|
||||||
assert(is_blocking_desc());
|
|
||||||
return md_->format_desc.blocking;
|
|
||||||
}
|
|
||||||
const wino_desc_t &wino_desc() const {
|
|
||||||
assert(is_wino_desc());
|
|
||||||
return md_->format_desc.wino_desc;
|
|
||||||
}
|
|
||||||
const rnn_packed_desc_t &rnn_packed_desc() const {
|
|
||||||
assert(is_rnn_packed_desc());
|
|
||||||
return md_->format_desc.rnn_packed_desc;
|
|
||||||
}
|
|
||||||
|
|
||||||
const memory_extra_desc_t &extra() const { return md_->extra; }
|
|
||||||
|
|
||||||
/* some useful function */
|
|
||||||
|
|
||||||
/** returns the number of elements including padding if \param with_padding
|
|
||||||
* is true, and the number of data elements otherwise */
|
|
||||||
dim_t nelems(bool with_padding = false) const {
|
|
||||||
if (is_zero()) return 0;
|
|
||||||
return utils::array_product(
|
|
||||||
with_padding ? padded_dims() : dims(), ndims());
|
|
||||||
}
|
|
||||||
|
|
||||||
/** returns true if memory descriptor is zero */
|
|
||||||
bool is_zero() const { return ndims() == 0; }
|
|
||||||
|
|
||||||
/** returns true if memory descriptor contains zero as one of its dim */
|
|
||||||
bool has_zero_dim() const { return nelems() == 0; }
|
|
||||||
|
|
||||||
/** return the size of data type (a shortcut) */
|
|
||||||
size_t data_type_size() const
|
|
||||||
{ return types::data_type_size(data_type()); }
|
|
||||||
|
|
||||||
/** return the size of data type of additional buffer */
|
|
||||||
size_t additional_buffer_data_size() const {
|
|
||||||
if (extra().flags & memory_extra_flags::compensation_conv_s8s8)
|
|
||||||
return sizeof(int32_t);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** return true if memory format has additional buffer */
|
|
||||||
bool is_additional_buffer() const {
|
|
||||||
return (extra().flags & memory_extra_flags::compensation_conv_s8s8);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** returns the size of additional buffer */
|
|
||||||
size_t additional_buffer_size() const {
|
|
||||||
if (extra().flags & memory_extra_flags::compensation_conv_s8s8) {
|
|
||||||
int cmask = extra().compensation_mask;
|
|
||||||
assert(cmask == 1 || cmask == 3);
|
|
||||||
dim_t prod = 1;
|
|
||||||
for (int d = 0; d < ndims(); ++d)
|
|
||||||
if (cmask & (1<<d)) prod *= padded_dims()[d];
|
|
||||||
return prod * additional_buffer_data_size();
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** returns the size required to store described memory
|
|
||||||
* note: if offset0 != 0 returns 0 (need to specify the behavior) */
|
|
||||||
size_t size() const {
|
|
||||||
if (is_zero() || has_zero_dim() || format_kind() == format_kind::any)
|
|
||||||
return 0;
|
|
||||||
|
|
||||||
if (format_kind() == format_kind::wino) {
|
|
||||||
return wino_desc().size;
|
|
||||||
} else if (format_kind() == format_kind::rnn_packed) {
|
|
||||||
return rnn_packed_desc().size;
|
|
||||||
} else {
|
|
||||||
if (offset0() != 0) return 0;
|
|
||||||
|
|
||||||
dims_t blocks = {0};
|
|
||||||
compute_blocks(blocks);
|
|
||||||
|
|
||||||
const auto &bd = blocking_desc();
|
|
||||||
|
|
||||||
size_t max_size = 0;
|
|
||||||
for (int d = 0; d < ndims(); ++d)
|
|
||||||
max_size = nstl::max<size_t>(max_size,
|
|
||||||
padded_dims()[d] / blocks[d] * bd.strides[d]);
|
|
||||||
|
|
||||||
if (max_size == 1 && bd.inner_nblks != 0) {
|
|
||||||
max_size = utils::array_product(bd.inner_blks, bd.inner_nblks);
|
|
||||||
}
|
|
||||||
|
|
||||||
return max_size * data_type_size() + additional_buffer_size();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** returns true if data is dense in memory */
|
|
||||||
bool is_dense(bool with_padding = false) const {
|
|
||||||
if (utils::one_of(format_kind(), format_kind::undef, format_kind::any))
|
|
||||||
return false;
|
|
||||||
return nelems(with_padding) * data_type_size() == size();
|
|
||||||
}
|
|
||||||
|
|
||||||
/** returns true if memory desc is fully defined */
|
|
||||||
bool is_defined() const { return format_kind() != format_kind::any; }
|
|
||||||
|
|
||||||
/** returns true if the only (potentially) padded dim is \param dim */
|
|
||||||
bool only_padded_dim(int dim) const {
|
|
||||||
for (int d = 0; d < ndims(); ++d)
|
|
||||||
if (d != dim && dims()[d] != padded_dims()[d])
|
|
||||||
return false;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** returns true if memory desc has blocked layout and block dims are 1s */
|
|
||||||
bool is_plain() const {
|
|
||||||
if (!is_blocking_desc()) return false;
|
|
||||||
return blocking_desc().inner_nblks == 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** returns overall block sizes */
|
|
||||||
void compute_blocks(dims_t blocks) const {
|
|
||||||
if (!is_blocking_desc()) {
|
|
||||||
utils::array_set(blocks, 0, ndims());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
utils::array_set(blocks, 1, ndims());
|
|
||||||
|
|
||||||
const auto &bd = blocking_desc();
|
|
||||||
for (int iblk = 0; iblk < bd.inner_nblks; ++iblk)
|
|
||||||
blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk];
|
|
||||||
}
|
|
||||||
|
|
||||||
/* comparison section */
|
|
||||||
|
|
||||||
bool operator==(const memory_desc_wrapper &rhs) const
|
|
||||||
{ return *this->md_ == *rhs.md_; }
|
|
||||||
bool operator!=(const memory_desc_wrapper &rhs) const
|
|
||||||
{ return !operator==(rhs); }
|
|
||||||
bool operator==(const memory_desc_t &rhs) const
|
|
||||||
{ return operator==(memory_desc_wrapper(rhs)); }
|
|
||||||
bool operator!=(const memory_desc_t &rhs) const
|
|
||||||
{ return !operator==(rhs); }
|
|
||||||
|
|
||||||
/** returns true if data (w/o padding if with_padding == false and w/
|
|
||||||
* padding otherwise) have the same physical structure, i.e. dimensions,
|
|
||||||
* strides, and blocked structure. Depending on with_data_type flag
|
|
||||||
* data_type is taken or not taken into account. dim_start allows to check
|
|
||||||
* similarity for the logical part of data [dim_start .. ndims()].
|
|
||||||
* CAUTION: format kind any and undef are not similar to whatever, hence the
|
|
||||||
* following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
|
|
||||||
/* TODO: revise */
|
|
||||||
bool similar_to(const memory_desc_wrapper &rhs,
|
|
||||||
bool with_padding = true, bool with_data_type = true,
|
|
||||||
int dim_start = 0) const;
|
|
||||||
|
|
||||||
/** returns true if one memory can be reordered to another */
|
|
||||||
bool consistent_with(const memory_desc_wrapper &rhs) const;
|
|
||||||
|
|
||||||
/** returns true if the memory desc corresponds to the given format tag and
|
|
||||||
* strides.
|
|
||||||
* @sa memory_desc_matches_tag */
|
|
||||||
bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const {
|
|
||||||
return memory_desc_matches_tag(*md_, tag, strides);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** returns matching tag (or undef if match is not found)
|
|
||||||
* XXX: This is a workaround that eventually should go away! */
|
|
||||||
template <typename... Tags>
|
|
||||||
format_tag_t matches_one_of_tag(Tags ...tags) const {
|
|
||||||
for (const auto tag: {tags...}) {
|
|
||||||
if (memory_desc_matches_tag(*md_, tag))
|
|
||||||
return tag;
|
|
||||||
}
|
|
||||||
return format_tag::undef;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* offset section */
|
|
||||||
|
|
||||||
/** returns physical offset by logical one. logical offset is represented by
|
|
||||||
* an array \param pos. if \param is_pos_padded is true \param pos
|
|
||||||
* represents the position in already padded area */
|
|
||||||
dim_t off_v(const dims_t pos, bool is_pos_padded = false) const {
|
|
||||||
assert(is_blocking_desc());
|
|
||||||
const blocking_desc_t &blk = blocking_desc();
|
|
||||||
|
|
||||||
dims_t pos_copy = {0};
|
|
||||||
for (int d = 0; d < ndims(); ++d)
|
|
||||||
pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]);
|
|
||||||
|
|
||||||
dim_t phys_offset = offset0();
|
|
||||||
|
|
||||||
if (blk.inner_nblks > 0) {
|
|
||||||
dim_t blk_stride = 1;
|
|
||||||
for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) {
|
|
||||||
const int d = blk.inner_idxs[iblk];
|
|
||||||
const dim_t p = pos_copy[d] % blk.inner_blks[iblk];
|
|
||||||
|
|
||||||
phys_offset += p * blk_stride;
|
|
||||||
|
|
||||||
pos_copy[d] /= blk.inner_blks[iblk];
|
|
||||||
|
|
||||||
blk_stride *= blk.inner_blks[iblk];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int d = 0; d < ndims(); ++d) {
|
|
||||||
const dim_t p = pos_copy[d];
|
|
||||||
phys_offset += p * blk.strides[d];
|
|
||||||
}
|
|
||||||
|
|
||||||
return phys_offset;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** returns physical offset by logical one. logical offset is represented by
|
|
||||||
* a scalar \param l_offset. if \param is_pos_padded is true, \param
|
|
||||||
* l_offset represents logical offset in already padded area */
|
|
||||||
dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const {
|
|
||||||
assert(is_blocking_desc());
|
|
||||||
dims_t pos;
|
|
||||||
for (int rd = 0; rd < ndims(); ++rd) {
|
|
||||||
const int d = ndims() - 1 - rd;
|
|
||||||
const dim_t cur_dim = is_pos_padded ? padded_dims()[d] : dims()[d];
|
|
||||||
pos[d] = l_offset % cur_dim;
|
|
||||||
l_offset /= cur_dim;
|
|
||||||
}
|
|
||||||
return off_v(pos, is_pos_padded);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** returns physical offset by logical one. logical offset is represented by
|
|
||||||
* a tuple of indices (\param xn, ..., \param x1, \param x0) */
|
|
||||||
template<typename... Args>
|
|
||||||
dim_t off(Args... args) const {
|
|
||||||
assert(sizeof...(args) == ndims());
|
|
||||||
dims_t pos = { args... };
|
|
||||||
return off_v(pos, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** returns physical offset by logical one. logical offset is represented by
|
|
||||||
* a tuple of indices (\param xn, ..., \param x1, \param x0) in already
|
|
||||||
* padded area */
|
|
||||||
template<typename... Args>
|
|
||||||
dim_t off_padding(Args... args) const {
|
|
||||||
assert(sizeof...(args) == ndims());
|
|
||||||
dims_t pos = { args... };
|
|
||||||
return off_v(pos, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** returns physical offset by logical one. Logical offset is represented by
|
|
||||||
* a tuple of block indices (\param bn, ..., \param b1, \param b0). It is a
|
|
||||||
* user responsibility to adjust the result to get offset within blocks */
|
|
||||||
template<typename ...Args>
|
|
||||||
dim_t blk_off(Args... args) const {
|
|
||||||
return _blk_off<sizeof...(args), Args...>(args...);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<bool skip_first, typename T, typename ...Args>
|
|
||||||
dim_t blk_off(T xn, Args... args) const {
|
|
||||||
return skip_first
|
|
||||||
? blk_off<Args...>(args...)
|
|
||||||
: blk_off<T, Args...>(xn, args...);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* static functions section */
|
|
||||||
/* TODO: replace with non-static, once md_ becomes non-const ref */
|
|
||||||
|
|
||||||
static status_t compute_blocking(memory_desc_t &memory_desc,
|
|
||||||
format_tag_t tag);
|
|
||||||
|
|
||||||
private:
|
|
||||||
/* TODO: put logical_offset in utils */
|
|
||||||
template<typename T>
|
|
||||||
dim_t logical_offset(T x0) const { return x0; }
|
|
||||||
|
|
||||||
template<typename T, typename... Args>
|
|
||||||
dim_t logical_offset(T xn, Args... args) const {
|
|
||||||
const size_t n_args = sizeof...(args);
|
|
||||||
return xn * utils::array_product<n_args>(
|
|
||||||
&dims()[ndims() - n_args]) + logical_offset(args...);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int ORIG_LEN, typename ...Void>
|
|
||||||
dim_t _blk_off() const { return offset0(); }
|
|
||||||
|
|
||||||
template<int ORIG_LEN, typename T, typename ...Args>
|
|
||||||
dim_t _blk_off(T xc, Args ...args) const {
|
|
||||||
assert(is_blocking_desc());
|
|
||||||
constexpr int dc = ORIG_LEN - sizeof...(args) - 1;
|
|
||||||
return xc * blocking_desc().strides[dc]
|
|
||||||
+ _blk_off<ORIG_LEN, Args...>(args...);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
|
|
||||||
bool with_padding, bool with_data_type, int dim_start) const {
|
|
||||||
using namespace utils;
|
|
||||||
|
|
||||||
if (one_of(format_kind(), format_kind::undef, format_kind::any))
|
|
||||||
return false;
|
|
||||||
if (is_wino_desc() || is_rnn_packed_desc())
|
|
||||||
return false;
|
|
||||||
|
|
||||||
const int ds = dim_start;
|
|
||||||
const auto &blk = blocking_desc();
|
|
||||||
const auto &r_blk = rhs.blocking_desc();
|
|
||||||
|
|
||||||
return ndims() == rhs.ndims()
|
|
||||||
&& dim_start <= ndims() /* guard */
|
|
||||||
&& format_kind() == rhs.format_kind()
|
|
||||||
&& IMPLICATION(with_data_type, data_type() == rhs.data_type())
|
|
||||||
&& array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds)
|
|
||||||
&& array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds)
|
|
||||||
&& blk.inner_nblks == r_blk.inner_nblks
|
|
||||||
&& array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
|
|
||||||
&& array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)
|
|
||||||
&& IMPLICATION(with_padding, true
|
|
||||||
&& array_cmp(padded_dims() + ds, rhs.padded_dims() + ds,
|
|
||||||
ndims() - ds)
|
|
||||||
&& array_cmp(padded_offsets() + ds, rhs.padded_offsets() + ds,
|
|
||||||
ndims() - ds));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool memory_desc_wrapper::consistent_with(
|
|
||||||
const memory_desc_wrapper &rhs) const {
|
|
||||||
if (ndims() == rhs.ndims()) {
|
|
||||||
for (int d = 0; d < ndims(); ++d) {
|
|
||||||
if (dims()[d] != rhs.dims()[d]) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
/* TODO: revise.
|
|
||||||
* is the following possible?
|
|
||||||
* [1, a, b] <--reorder--> [a, b]
|
|
||||||
* [a, 1, b] <--reorder--> [a, b]
|
|
||||||
* not, at least for now */
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
@ -1,295 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef MEMORY_TRACKING_HPP
|
|
||||||
#define MEMORY_TRACKING_HPP
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include "nstl.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
namespace memory_tracking {
|
|
||||||
|
|
||||||
/* Memory tracking capabilities
|
|
||||||
*
|
|
||||||
* The main purpose of this header file is to provide uniform way to register
|
|
||||||
* required memory for a scratchpad at a primitive descriptor creation time
|
|
||||||
* and then easily access it having only the base address of the scratchpad.
|
|
||||||
*
|
|
||||||
* Primitives might contain multiple disjoint parts that require temporary
|
|
||||||
* buffers (known as scratchpad) during their execution. A primitive descriptor
|
|
||||||
* should summarize all the needs into one single number -- the buffer size
|
|
||||||
* that would be requested from a user. At execution time, the corresponding
|
|
||||||
* primitive will receive a base pointer to a scratchpad. It then needs to
|
|
||||||
* provide each part of algorithm the corresponding piece of memory. Three main
|
|
||||||
* challenges here are:
|
|
||||||
* 1. Track correct offset (from the base scratchpad address) for each piece
|
|
||||||
* 2. Algorithm might require that different memory pieces to be aligned, so
|
|
||||||
* the scratchpad size is no more just a sum of size of the corresponding
|
|
||||||
* subparts.
|
|
||||||
* 3. While a primitive is responsible for its scratchpad, the implementation
|
|
||||||
* might use some other basic blocks (e.g. cpu_reducer) that also require
|
|
||||||
* scratchpad memory. So there should be a simple way of passing the
|
|
||||||
* information back and force between the main algorithm (a primitive) and
|
|
||||||
* auxiliary stuff that lives completely separately from it (e.g. reducer).
|
|
||||||
*
|
|
||||||
* To address these challenges this header file provides 3 structures:
|
|
||||||
* 1. registry_t -- the class the stores the information about requested
|
|
||||||
* memory. The information includes required size and desired
|
|
||||||
* alignment for each piece. This class is also responsible
|
|
||||||
* for computing the right offset to a given piece using the
|
|
||||||
* base pointer.
|
|
||||||
* This class is basically a ledger with all entries.
|
|
||||||
* Lives in primitive descriptors.
|
|
||||||
*
|
|
||||||
* 2. registrar_t -- the interface to a registry_t to book memory. Used at
|
|
||||||
* primitive descriptor creation time only. Contains a
|
|
||||||
* reference to the corresponding *mutable* registry.
|
|
||||||
* Always modifiable.
|
|
||||||
* Allows chaining (using prefixes).
|
|
||||||
*
|
|
||||||
* 3. grantor_t -- the interface to a registry_t to access memory. Used at
|
|
||||||
* primitive execution time only. Contains a reference to
|
|
||||||
* the corresponding *constant* registry and base pointer.
|
|
||||||
* Always constant.
|
|
||||||
* Allows chaining (using prefixes).
|
|
||||||
*
|
|
||||||
* Both registrar_t and grantor_t allow chaining with extra prefix provided.
|
|
||||||
* The feature is useful when a primitive offload a part of computations to
|
|
||||||
* some other primitives which require their own scratchpad space
|
|
||||||
* (e.g. reducer). Prefixes are used to avoid key collision in cases when
|
|
||||||
* multiple sub-primitive (e.g. multiple reducers) are used.
|
|
||||||
*
|
|
||||||
* A short example below demonstrates how to use aforementioned classes. In it
|
|
||||||
* the main primitive is convolution that uses scratchpad for keeping padded
|
|
||||||
* bias. It also needs a reducer, that needs its own space as well.
|
|
||||||
*
|
|
||||||
* ``` c++
|
|
||||||
* struct reducer_t {
|
|
||||||
* static void init(registrar_t &scratchpad) {
|
|
||||||
* // preserve space for the reduction (one page aligned)
|
|
||||||
* scratchpad.book(key_space, sizeof(float) * 980 * 1024, 4096);
|
|
||||||
* }
|
|
||||||
*
|
|
||||||
* void exec(const grantor_t &scratchpad) {
|
|
||||||
* // get the pointer to preserved space. scratchpad came from
|
|
||||||
* // upper primitive (convolution in this example)
|
|
||||||
* auto space = scratchpad.get<float>(key_reducer_space);
|
|
||||||
*
|
|
||||||
* space[:] += ...;
|
|
||||||
* }
|
|
||||||
* };
|
|
||||||
*
|
|
||||||
* struct conv_t {
|
|
||||||
* struct pd_t {
|
|
||||||
* void init() {
|
|
||||||
* registrar_t scratchpad(scratchpad_registry_);
|
|
||||||
*
|
|
||||||
* // preserve a space for padded bias (using default alignment)
|
|
||||||
* scratchpad.book(key_conv_padded_bias, 128);
|
|
||||||
*
|
|
||||||
* // create a proxy registrar for the reducer All entries made
|
|
||||||
* // by reducer would live in convolution's registry, but would
|
|
||||||
* // have their own `prefix`, so no interference with conv's
|
|
||||||
* // buffers.
|
|
||||||
* registrar_t reducer_scratchpad(scratchpad, prefix_reducer);
|
|
||||||
*
|
|
||||||
* reducer_t::init(reducer_scratchpad);
|
|
||||||
* }
|
|
||||||
*
|
|
||||||
* registry_t scratchpad_registry_;
|
|
||||||
* }
|
|
||||||
*
|
|
||||||
* void exec() {
|
|
||||||
* // get the base pointer to a scratchpad memory from a user
|
|
||||||
* void *scratchpad_ptr = this->input(MKLDNN_MEM_SCRATCHPAD);
|
|
||||||
*
|
|
||||||
* // create a grantor to the scratchpad (and provide the base
|
|
||||||
* // pointer).
|
|
||||||
* grantor_t scratchpad(pd()->scratchpad_registry_, scratchpad_ptr);
|
|
||||||
*
|
|
||||||
* // access the padded_bias (need only key name and the grantor)
|
|
||||||
* auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
|
|
||||||
*
|
|
||||||
* // to give the `right` grantor to reducer we need to add the
|
|
||||||
* // corresponding prefix, so that reducer would be able to access
|
|
||||||
* // its keys. The call is very similar to the one in pd_t::init
|
|
||||||
* // with only difference in types: grantor_t vs registrar_t.
|
|
||||||
* grantor_t reducer_scratchpad(scratchpad, prefix_reducer);
|
|
||||||
* reducer->exec(reducer_scratchpad);
|
|
||||||
* }
|
|
||||||
* };
|
|
||||||
* ```
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
/* namespace with common keys and prefixes */
|
|
||||||
namespace names {
|
|
||||||
enum {
|
|
||||||
key_none = 0,
|
|
||||||
key_bnorm_tmp_mean,
|
|
||||||
key_bnorm_tmp_var,
|
|
||||||
key_bnorm_tmp_diff_ss,
|
|
||||||
key_bnorm_tmp_stats,
|
|
||||||
key_bnorm_reduction,
|
|
||||||
key_concat_iptrs,
|
|
||||||
key_concat_istrides,
|
|
||||||
key_concat_nelems,
|
|
||||||
key_concat_optrs,
|
|
||||||
key_conv_adjusted_scales,
|
|
||||||
key_conv_bia_reduction,
|
|
||||||
key_conv_gemm_col,
|
|
||||||
key_conv_gemm_imtr,
|
|
||||||
key_conv_int_dat_in_acc_dt,
|
|
||||||
key_conv_padded_bias,
|
|
||||||
key_conv_rtus_space,
|
|
||||||
key_conv_tr_diff_dst,
|
|
||||||
key_conv_tr_diff_dst_bctx,
|
|
||||||
key_conv_tr_src,
|
|
||||||
key_conv_tr_src_bctx,
|
|
||||||
key_conv_wei_reduction,
|
|
||||||
key_conv_wei_bia_reduction,
|
|
||||||
key_conv_wei_bia_reduction_bctx,
|
|
||||||
key_iprod_int_dat_in_acc_dt,
|
|
||||||
key_reducer_space,
|
|
||||||
key_reducer_space_bctx,
|
|
||||||
key_reorder_wino_plain,
|
|
||||||
key_reorder_wino_transform_space,
|
|
||||||
key_reorder_rnn_weights_quantization,
|
|
||||||
key_reorder_rnn_weights_reduction,
|
|
||||||
key_rnn_space,
|
|
||||||
key_rnn_ptrs_bia,
|
|
||||||
key_rnn_ptrs_wei_layer,
|
|
||||||
key_rnn_ptrs_wei_iter,
|
|
||||||
key_softmax_reduction,
|
|
||||||
key_wino_U,
|
|
||||||
key_wino_V,
|
|
||||||
key_wino_M,
|
|
||||||
key_barrier,
|
|
||||||
};
|
|
||||||
|
|
||||||
enum {
|
|
||||||
prefix_none = 0,
|
|
||||||
prefix_reducer_bia,
|
|
||||||
prefix_reducer_wei,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// level 0: 00 00 00 xxx
|
|
||||||
// level 1: 00 00 aa xxx
|
|
||||||
// level 2: 00 aa bb xxx
|
|
||||||
// level 3: aa bb cc xxx
|
|
||||||
// max # of levels: 3 + 1 (base_level)
|
|
||||||
// here:
|
|
||||||
// xxx : [1 .. MAX_KEY) : key
|
|
||||||
// aa, bb, cc : [1 .. MAX_PREFIX) : prefixes for levels 1, 2, and 3
|
|
||||||
|
|
||||||
using key_t = uint32_t;
|
|
||||||
enum { MAX_KEY = (1u << 10), MAX_PREFIX = (1u << 7), };
|
|
||||||
|
|
||||||
/// generates global key based on a prefix and a local key
|
|
||||||
inline key_t make_key(key_t prefix, key_t key) { return prefix + key; }
|
|
||||||
|
|
||||||
/// generates global prefix based on the global parent and the local ones
|
|
||||||
inline key_t make_prefix(key_t parent_prefix, key_t prefix)
|
|
||||||
{ return MAX_PREFIX * parent_prefix + MAX_KEY * prefix; }
|
|
||||||
|
|
||||||
struct registrar_t;
|
|
||||||
struct grantor_t;
|
|
||||||
|
|
||||||
struct registry_t {
|
|
||||||
void book(const key_t &key, size_t size, size_t alignment) {
|
|
||||||
if (size == 0) return;
|
|
||||||
assert(offset_map_.count(key) == 0);
|
|
||||||
|
|
||||||
size = utils::rnd_up(size, minimal_alignment);
|
|
||||||
alignment = nstl::max<size_t>(alignment, minimal_alignment);
|
|
||||||
offset_map_[key] = entry_t{size_, size, alignment};
|
|
||||||
|
|
||||||
size_ += size + alignment - minimal_alignment;
|
|
||||||
}
|
|
||||||
|
|
||||||
void *get(const key_t &key, void *base_ptr) const {
|
|
||||||
if (base_ptr == nullptr) { assert(size() == 0); return nullptr; }
|
|
||||||
if (offset_map_.count(key) != 1) return nullptr;
|
|
||||||
|
|
||||||
const auto &e = offset_map_.at(key);
|
|
||||||
base_ptr = utils::align_ptr<void>(base_ptr, minimal_alignment);
|
|
||||||
char *ptr = (char *)base_ptr + e.offset;
|
|
||||||
return utils::align_ptr<void>(ptr, e.alignment);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t size() const
|
|
||||||
{ return size_ > 0 ? size_ + minimal_alignment - 1 : 0; }
|
|
||||||
|
|
||||||
registrar_t registrar();
|
|
||||||
grantor_t grantor(void *base_ptr) const;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
enum { minimal_alignment = 64 };
|
|
||||||
struct entry_t { size_t offset, size, alignment; };
|
|
||||||
|
|
||||||
std::unordered_map<key_t, entry_t> offset_map_;
|
|
||||||
size_t size_ = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct registrar_t {
|
|
||||||
enum { default_alignment = 64 };
|
|
||||||
|
|
||||||
registrar_t(registry_t ®istry): registry_(registry), prefix_(0) {}
|
|
||||||
registrar_t(registrar_t &parent, const key_t &prefix)
|
|
||||||
: registry_(parent.registry_)
|
|
||||||
, prefix_(make_prefix(parent.prefix_, prefix)) {}
|
|
||||||
|
|
||||||
void book(const key_t &key, size_t size,
|
|
||||||
size_t alignment = default_alignment)
|
|
||||||
{ registry_.book(make_key(prefix_, key), size, alignment); }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
registry_t ®istry_;
|
|
||||||
const key_t prefix_;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct grantor_t {
|
|
||||||
grantor_t(const registry_t ®istry, void *base_ptr)
|
|
||||||
: registry_(registry), prefix_(0), base_ptr_(base_ptr) {}
|
|
||||||
grantor_t(const grantor_t &parent, const key_t &prefix)
|
|
||||||
: registry_(parent.registry_)
|
|
||||||
, prefix_(make_prefix(parent.prefix_, prefix))
|
|
||||||
, base_ptr_(parent.base_ptr_) {}
|
|
||||||
|
|
||||||
template <typename T = void> T *get(const key_t &key) const
|
|
||||||
{ return (T *)registry_.get(make_key(prefix_, key), base_ptr_); }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
const registry_t ®istry_;
|
|
||||||
const key_t prefix_;
|
|
||||||
void *base_ptr_;
|
|
||||||
};
|
|
||||||
|
|
||||||
inline registrar_t registry_t::registrar() { return registrar_t(*this); }
|
|
||||||
inline grantor_t registry_t::grantor(void *base_ptr) const
|
|
||||||
{ return grantor_t(*this, base_ptr); }
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
131
thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp
vendored
131
thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp
vendored
@ -1,131 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2019 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <cinttypes>
|
|
||||||
|
|
||||||
#include "mkldnn_debug.h"
|
|
||||||
#include "mkldnn_types.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
#define DPRINT(...) do { \
|
|
||||||
int l = snprintf(str + written_len, str_len, __VA_ARGS__); \
|
|
||||||
if (l < 0) return l; \
|
|
||||||
if ((size_t)l >= str_len) return -1; \
|
|
||||||
written_len += l; str_len -= l; \
|
|
||||||
} while(0)
|
|
||||||
|
|
||||||
int mkldnn_md2fmt_str(char *str, size_t str_len,
|
|
||||||
const mkldnn_memory_desc_t *mdesc) {
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
|
|
||||||
if (str == nullptr || str_len <= 1u)
|
|
||||||
return -1;
|
|
||||||
|
|
||||||
int written_len = 0;
|
|
||||||
|
|
||||||
if (mdesc == nullptr) {
|
|
||||||
DPRINT("%s::%s::",
|
|
||||||
mkldnn_dt2str(data_type::undef),
|
|
||||||
mkldnn_fmt_kind2str(format_kind::undef));
|
|
||||||
return written_len;
|
|
||||||
}
|
|
||||||
|
|
||||||
memory_desc_wrapper md(mdesc);
|
|
||||||
|
|
||||||
DPRINT("%s:", mkldnn_dt2str(md.data_type()));
|
|
||||||
|
|
||||||
bool padded_dims = false, padded_offsets = false;
|
|
||||||
for (int d = 0; d < md.ndims(); ++d) {
|
|
||||||
if (md.dims()[d] != md.padded_dims()[d]) padded_dims = true;
|
|
||||||
if (md.padded_offsets()[d] != 0) padded_offsets = true;
|
|
||||||
}
|
|
||||||
bool offset0 = md.offset0();
|
|
||||||
DPRINT("%s%s%s:",
|
|
||||||
padded_dims ? "p" : "",
|
|
||||||
padded_offsets ? "o" : "",
|
|
||||||
offset0 ? "0" : "");
|
|
||||||
|
|
||||||
DPRINT("%s:", mkldnn_fmt_kind2str(md.format_kind()));
|
|
||||||
|
|
||||||
if (!md.is_blocking_desc()) {
|
|
||||||
/* TODO: extend */
|
|
||||||
DPRINT("%s:", "");
|
|
||||||
} else {
|
|
||||||
const auto &blk = md.blocking_desc();
|
|
||||||
|
|
||||||
dims_t blocks;
|
|
||||||
md.compute_blocks(blocks);
|
|
||||||
|
|
||||||
char dim_chars[MKLDNN_MAX_NDIMS + 1];
|
|
||||||
|
|
||||||
bool plain = true;
|
|
||||||
for (int d = 0; d < md.ndims(); ++d) {
|
|
||||||
dim_chars[d] = (blocks[d] == 1 ? 'a' : 'A') + (char)d;
|
|
||||||
if (blocks[d] != 1) plain = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
dims_t strides;
|
|
||||||
utils::array_copy(strides, blk.strides, md.ndims());
|
|
||||||
utils::simultaneous_sort(strides, dim_chars, md.ndims(),
|
|
||||||
[](dim_t a, dim_t b) { return b - a; });
|
|
||||||
|
|
||||||
dim_chars[md.ndims()] = '\0';
|
|
||||||
DPRINT("%s", dim_chars);
|
|
||||||
|
|
||||||
if (!plain) {
|
|
||||||
for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
|
|
||||||
DPRINT("%d%c", (int)blk.inner_blks[iblk],
|
|
||||||
'a' + (char)blk.inner_idxs[iblk]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
DPRINT("%s", ":");
|
|
||||||
}
|
|
||||||
|
|
||||||
DPRINT("f%lx", (long)md.extra().flags);
|
|
||||||
|
|
||||||
return written_len;
|
|
||||||
}
|
|
||||||
|
|
||||||
int mkldnn_md2dim_str(char *str, size_t str_len,
|
|
||||||
const mkldnn_memory_desc_t *mdesc) {
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
|
|
||||||
if (str == nullptr || str_len <= 1)
|
|
||||||
return -1;
|
|
||||||
|
|
||||||
int written_len = 0;
|
|
||||||
|
|
||||||
if (mdesc == nullptr || mdesc->ndims == 0) {
|
|
||||||
DPRINT("%s", "");
|
|
||||||
return written_len;
|
|
||||||
}
|
|
||||||
|
|
||||||
memory_desc_wrapper md(mdesc);
|
|
||||||
|
|
||||||
for (int d = 0; d < md.ndims() - 1; ++d)
|
|
||||||
DPRINT("%" PRId64 "x", md.dims()[d]);
|
|
||||||
DPRINT("%" PRId64, md.dims()[md.ndims() - 1]);
|
|
||||||
|
|
||||||
return written_len;
|
|
||||||
}
|
|
||||||
|
|
||||||
#undef DPRINT
|
|
@ -1,365 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018-2019 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
/* DO NOT EDIT, AUTO-GENERATED */
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
|
|
||||||
#include "mkldnn_debug.h"
|
|
||||||
#include "mkldnn_types.h"
|
|
||||||
|
|
||||||
const char *mkldnn_status2str(mkldnn_status_t v) {
|
|
||||||
if (v == mkldnn_success) return "success";
|
|
||||||
if (v == mkldnn_out_of_memory) return "out_of_memory";
|
|
||||||
if (v == mkldnn_try_again) return "try_again";
|
|
||||||
if (v == mkldnn_invalid_arguments) return "invalid_arguments";
|
|
||||||
if (v == mkldnn_not_ready) return "not_ready";
|
|
||||||
if (v == mkldnn_unimplemented) return "unimplemented";
|
|
||||||
if (v == mkldnn_iterator_ends) return "iterator_ends";
|
|
||||||
if (v == mkldnn_runtime_error) return "runtime_error";
|
|
||||||
if (v == mkldnn_not_required) return "not_required";
|
|
||||||
assert(!"unknown status");
|
|
||||||
return "unknown status";
|
|
||||||
}
|
|
||||||
|
|
||||||
const char *mkldnn_dt2str(mkldnn_data_type_t v) {
|
|
||||||
if (v == mkldnn_data_type_undef) return "undef";
|
|
||||||
if (v == mkldnn_f32) return "f32";
|
|
||||||
if (v == mkldnn_s32) return "s32";
|
|
||||||
if (v == mkldnn_s8) return "s8";
|
|
||||||
if (v == mkldnn_u8) return "u8";
|
|
||||||
assert(!"unknown dt");
|
|
||||||
return "unknown dt";
|
|
||||||
}
|
|
||||||
|
|
||||||
const char *mkldnn_fmt_kind2str(mkldnn_format_kind_t v) {
|
|
||||||
if (v == mkldnn_format_kind_undef) return "undef";
|
|
||||||
if (v == mkldnn_format_kind_any) return "any";
|
|
||||||
if (v == mkldnn_blocked) return "blocked";
|
|
||||||
if (v == mkldnn_format_kind_wino) return "wino";
|
|
||||||
if (v == mkldnn_format_kind_rnn_packed) return "rnn_packed";
|
|
||||||
assert(!"unknown fmt_kind");
|
|
||||||
return "unknown fmt_kind";
|
|
||||||
}
|
|
||||||
|
|
||||||
const char *mkldnn_fmt_tag2str(mkldnn_format_tag_t v) {
|
|
||||||
if (v == mkldnn_format_tag_undef) return "undef";
|
|
||||||
if (v == mkldnn_format_tag_any) return "format_tag_any";
|
|
||||||
if (v == mkldnn_a) return "a";
|
|
||||||
if (v == mkldnn_ab) return "ab";
|
|
||||||
if (v == mkldnn_abc) return "abc";
|
|
||||||
if (v == mkldnn_abcd) return "abcd";
|
|
||||||
if (v == mkldnn_abcde) return "abcde";
|
|
||||||
if (v == mkldnn_abcdef) return "abcdef";
|
|
||||||
if (v == mkldnn_abdec) return "abdec";
|
|
||||||
if (v == mkldnn_acb) return "acb";
|
|
||||||
if (v == mkldnn_acbde) return "acbde";
|
|
||||||
if (v == mkldnn_acdb) return "acdb";
|
|
||||||
if (v == mkldnn_acdeb) return "acdeb";
|
|
||||||
if (v == mkldnn_ba) return "ba";
|
|
||||||
if (v == mkldnn_bac) return "bac";
|
|
||||||
if (v == mkldnn_bacd) return "bacd";
|
|
||||||
if (v == mkldnn_bcda) return "bcda";
|
|
||||||
if (v == mkldnn_cba) return "cba";
|
|
||||||
if (v == mkldnn_cdba) return "cdba";
|
|
||||||
if (v == mkldnn_cdeba) return "cdeba";
|
|
||||||
if (v == mkldnn_decab) return "decab";
|
|
||||||
if (v == mkldnn_Abc16a) return "Abc16a";
|
|
||||||
if (v == mkldnn_ABc16a16b) return "ABc16a16b";
|
|
||||||
if (v == mkldnn_aBc16b) return "aBc16b";
|
|
||||||
if (v == mkldnn_ABc16b16a) return "ABc16b16a";
|
|
||||||
if (v == mkldnn_Abc4a) return "Abc4a";
|
|
||||||
if (v == mkldnn_aBc4b) return "aBc4b";
|
|
||||||
if (v == mkldnn_ABc4b16a4b) return "ABc4b16a4b";
|
|
||||||
if (v == mkldnn_ABc4b4a) return "ABc4b4a";
|
|
||||||
if (v == mkldnn_ABc8a16b2a) return "ABc8a16b2a";
|
|
||||||
if (v == mkldnn_ABc8a8b) return "ABc8a8b";
|
|
||||||
if (v == mkldnn_aBc8b) return "aBc8b";
|
|
||||||
if (v == mkldnn_ABc8b16a2b) return "ABc8b16a2b";
|
|
||||||
if (v == mkldnn_ABc8b8a) return "ABc8b8a";
|
|
||||||
if (v == mkldnn_Abcd16a) return "Abcd16a";
|
|
||||||
if (v == mkldnn_ABcd16a16b) return "ABcd16a16b";
|
|
||||||
if (v == mkldnn_aBcd16b) return "aBcd16b";
|
|
||||||
if (v == mkldnn_ABcd16b16a) return "ABcd16b16a";
|
|
||||||
if (v == mkldnn_aBCd16b16c) return "aBCd16b16c";
|
|
||||||
if (v == mkldnn_aBCd16c16b) return "aBCd16c16b";
|
|
||||||
if (v == mkldnn_Abcd4a) return "Abcd4a";
|
|
||||||
if (v == mkldnn_aBcd4b) return "aBcd4b";
|
|
||||||
if (v == mkldnn_ABcd4b16a4b) return "ABcd4b16a4b";
|
|
||||||
if (v == mkldnn_ABcd4b4a) return "ABcd4b4a";
|
|
||||||
if (v == mkldnn_aBCd4c16b4c) return "aBCd4c16b4c";
|
|
||||||
if (v == mkldnn_aBCd4c4b) return "aBCd4c4b";
|
|
||||||
if (v == mkldnn_ABcd8a16b2a) return "ABcd8a16b2a";
|
|
||||||
if (v == mkldnn_ABcd8a8b) return "ABcd8a8b";
|
|
||||||
if (v == mkldnn_aBcd8b) return "aBcd8b";
|
|
||||||
if (v == mkldnn_ABcd8b16a2b) return "ABcd8b16a2b";
|
|
||||||
if (v == mkldnn_aBCd8b16c2b) return "aBCd8b16c2b";
|
|
||||||
if (v == mkldnn_ABcd8b8a) return "ABcd8b8a";
|
|
||||||
if (v == mkldnn_aBCd8b8c) return "aBCd8b8c";
|
|
||||||
if (v == mkldnn_aBCd8c16b2c) return "aBCd8c16b2c";
|
|
||||||
if (v == mkldnn_aBCd8c8b) return "aBCd8c8b";
|
|
||||||
if (v == mkldnn_Abcde16a) return "Abcde16a";
|
|
||||||
if (v == mkldnn_ABcde16a16b) return "ABcde16a16b";
|
|
||||||
if (v == mkldnn_aBcde16b) return "aBcde16b";
|
|
||||||
if (v == mkldnn_ABcde16b16a) return "ABcde16b16a";
|
|
||||||
if (v == mkldnn_aBCde16b16c) return "aBCde16b16c";
|
|
||||||
if (v == mkldnn_aBCde16c16b) return "aBCde16c16b";
|
|
||||||
if (v == mkldnn_aBCde2c8b4c) return "aBCde2c8b4c";
|
|
||||||
if (v == mkldnn_Abcde4a) return "Abcde4a";
|
|
||||||
if (v == mkldnn_aBcde4b) return "aBcde4b";
|
|
||||||
if (v == mkldnn_ABcde4b4a) return "ABcde4b4a";
|
|
||||||
if (v == mkldnn_aBCde4b4c) return "aBCde4b4c";
|
|
||||||
if (v == mkldnn_aBCde4c16b4c) return "aBCde4c16b4c";
|
|
||||||
if (v == mkldnn_aBCde4c4b) return "aBCde4c4b";
|
|
||||||
if (v == mkldnn_Abcde8a) return "Abcde8a";
|
|
||||||
if (v == mkldnn_ABcde8a8b) return "ABcde8a8b";
|
|
||||||
if (v == mkldnn_ABcde8b16a2b) return "ABcde8b16a2b";
|
|
||||||
if (v == mkldnn_aBCde8b16c2b) return "aBCde8b16c2b";
|
|
||||||
if (v == mkldnn_ABcde8b8a) return "ABcde8b8a";
|
|
||||||
if (v == mkldnn_aBCde8b8c) return "aBCde8b8c";
|
|
||||||
if (v == mkldnn_aBCde8c16b2c) return "aBCde8c16b2c";
|
|
||||||
if (v == mkldnn_aBCde8c8b) return "aBCde8c8b";
|
|
||||||
if (v == mkldnn_aBcdef16b) return "aBcdef16b";
|
|
||||||
if (v == mkldnn_aBCdef16b16c) return "aBCdef16b16c";
|
|
||||||
if (v == mkldnn_aBCdef16c16b) return "aBCdef16c16b";
|
|
||||||
if (v == mkldnn_aBcdef4b) return "aBcdef4b";
|
|
||||||
if (v == mkldnn_aBCdef4c4b) return "aBCdef4c4b";
|
|
||||||
if (v == mkldnn_aBCdef8b8c) return "aBCdef8b8c";
|
|
||||||
if (v == mkldnn_aBCdef8c16b2c) return "aBCdef8c16b2c";
|
|
||||||
if (v == mkldnn_aBCdef8c8b) return "aBCdef8c8b";
|
|
||||||
if (v == mkldnn_aBdc16b) return "aBdc16b";
|
|
||||||
if (v == mkldnn_aBdc4b) return "aBdc4b";
|
|
||||||
if (v == mkldnn_aBdc8b) return "aBdc8b";
|
|
||||||
if (v == mkldnn_aBdec16b) return "aBdec16b";
|
|
||||||
if (v == mkldnn_aBdec4b) return "aBdec4b";
|
|
||||||
if (v == mkldnn_aBdec8b) return "aBdec8b";
|
|
||||||
if (v == mkldnn_aBdefc16b) return "aBdefc16b";
|
|
||||||
if (v == mkldnn_aBdefc4b) return "aBdefc4b";
|
|
||||||
if (v == mkldnn_aBdefc8b) return "aBdefc8b";
|
|
||||||
if (v == mkldnn_Acb16a) return "Acb16a";
|
|
||||||
if (v == mkldnn_Acb4a) return "Acb4a";
|
|
||||||
if (v == mkldnn_Acb8a) return "Acb8a";
|
|
||||||
if (v == mkldnn_aCBd16b16c) return "aCBd16b16c";
|
|
||||||
if (v == mkldnn_aCBde16b16c) return "aCBde16b16c";
|
|
||||||
if (v == mkldnn_Acdb16a) return "Acdb16a";
|
|
||||||
if (v == mkldnn_Acdb4a) return "Acdb4a";
|
|
||||||
if (v == mkldnn_Acdb8a) return "Acdb8a";
|
|
||||||
if (v == mkldnn_Acdeb16a) return "Acdeb16a";
|
|
||||||
if (v == mkldnn_Acdeb4a) return "Acdeb4a";
|
|
||||||
if (v == mkldnn_Acdeb8a) return "Acdeb8a";
|
|
||||||
if (v == mkldnn_BAc16a16b) return "BAc16a16b";
|
|
||||||
if (v == mkldnn_BAcd16a16b) return "BAcd16a16b";
|
|
||||||
if (v == mkldnn_format_tag_last) return "format_tag_last";
|
|
||||||
if (v == mkldnn_x) return "x";
|
|
||||||
if (v == mkldnn_nc) return "nc";
|
|
||||||
if (v == mkldnn_cn) return "cn";
|
|
||||||
if (v == mkldnn_ncw) return "ncw";
|
|
||||||
if (v == mkldnn_nwc) return "nwc";
|
|
||||||
if (v == mkldnn_nchw) return "nchw";
|
|
||||||
if (v == mkldnn_nhwc) return "nhwc";
|
|
||||||
if (v == mkldnn_chwn) return "chwn";
|
|
||||||
if (v == mkldnn_ncdhw) return "ncdhw";
|
|
||||||
if (v == mkldnn_ndhwc) return "ndhwc";
|
|
||||||
if (v == mkldnn_oi) return "oi";
|
|
||||||
if (v == mkldnn_io) return "io";
|
|
||||||
if (v == mkldnn_oiw) return "oiw";
|
|
||||||
if (v == mkldnn_wio) return "wio";
|
|
||||||
if (v == mkldnn_oihw) return "oihw";
|
|
||||||
if (v == mkldnn_hwio) return "hwio";
|
|
||||||
if (v == mkldnn_ihwo) return "ihwo";
|
|
||||||
if (v == mkldnn_iohw) return "iohw";
|
|
||||||
if (v == mkldnn_oidhw) return "oidhw";
|
|
||||||
if (v == mkldnn_dhwio) return "dhwio";
|
|
||||||
if (v == mkldnn_goiw) return "goiw";
|
|
||||||
if (v == mkldnn_goihw) return "goihw";
|
|
||||||
if (v == mkldnn_hwigo) return "hwigo";
|
|
||||||
if (v == mkldnn_giohw) return "giohw";
|
|
||||||
if (v == mkldnn_goidhw) return "goidhw";
|
|
||||||
if (v == mkldnn_tnc) return "tnc";
|
|
||||||
if (v == mkldnn_ntc) return "ntc";
|
|
||||||
if (v == mkldnn_ldsnc) return "ldsnc";
|
|
||||||
if (v == mkldnn_ldigo) return "ldigo";
|
|
||||||
if (v == mkldnn_ldgoi) return "ldgoi";
|
|
||||||
if (v == mkldnn_ldgo) return "ldgo";
|
|
||||||
if (v == mkldnn_nCdhw16c) return "nCdhw16c";
|
|
||||||
if (v == mkldnn_nCdhw4c) return "nCdhw4c";
|
|
||||||
if (v == mkldnn_nCdhw8c) return "nCdhw8c";
|
|
||||||
if (v == mkldnn_nChw16c) return "nChw16c";
|
|
||||||
if (v == mkldnn_nChw4c) return "nChw4c";
|
|
||||||
if (v == mkldnn_nChw8c) return "nChw8c";
|
|
||||||
if (v == mkldnn_nCw16c) return "nCw16c";
|
|
||||||
if (v == mkldnn_nCw4c) return "nCw4c";
|
|
||||||
if (v == mkldnn_nCw8c) return "nCw8c";
|
|
||||||
if (v == mkldnn_IOw16o16i) return "IOw16o16i";
|
|
||||||
if (v == mkldnn_OIw16i16o) return "OIw16i16o";
|
|
||||||
if (v == mkldnn_OIw16o16i) return "OIw16o16i";
|
|
||||||
if (v == mkldnn_Oiw16o) return "Oiw16o";
|
|
||||||
if (v == mkldnn_OIw4i16o4i) return "OIw4i16o4i";
|
|
||||||
if (v == mkldnn_OIw4i4o) return "OIw4i4o";
|
|
||||||
if (v == mkldnn_Oiw4o) return "Oiw4o";
|
|
||||||
if (v == mkldnn_OIw8i16o2i) return "OIw8i16o2i";
|
|
||||||
if (v == mkldnn_OIw8i8o) return "OIw8i8o";
|
|
||||||
if (v == mkldnn_OIw8o16i2o) return "OIw8o16i2o";
|
|
||||||
if (v == mkldnn_OIw8o8i) return "OIw8o8i";
|
|
||||||
if (v == mkldnn_Owi16o) return "Owi16o";
|
|
||||||
if (v == mkldnn_Owi4o) return "Owi4o";
|
|
||||||
if (v == mkldnn_Owi8o) return "Owi8o";
|
|
||||||
if (v == mkldnn_IOhw16o16i) return "IOhw16o16i";
|
|
||||||
if (v == mkldnn_Ohwi16o) return "Ohwi16o";
|
|
||||||
if (v == mkldnn_Ohwi4o) return "Ohwi4o";
|
|
||||||
if (v == mkldnn_Ohwi8o) return "Ohwi8o";
|
|
||||||
if (v == mkldnn_OIhw16i16o) return "OIhw16i16o";
|
|
||||||
if (v == mkldnn_OIhw16o16i) return "OIhw16o16i";
|
|
||||||
if (v == mkldnn_Oihw16o) return "Oihw16o";
|
|
||||||
if (v == mkldnn_OIhw4i16o4i) return "OIhw4i16o4i";
|
|
||||||
if (v == mkldnn_OIhw4i4o) return "OIhw4i4o";
|
|
||||||
if (v == mkldnn_Oihw4o) return "Oihw4o";
|
|
||||||
if (v == mkldnn_OIhw8i16o2i) return "OIhw8i16o2i";
|
|
||||||
if (v == mkldnn_OIhw8i8o) return "OIhw8i8o";
|
|
||||||
if (v == mkldnn_OIhw8o16i2o) return "OIhw8o16i2o";
|
|
||||||
if (v == mkldnn_OIhw8o8i) return "OIhw8o8i";
|
|
||||||
if (v == mkldnn_Odhwi16o) return "Odhwi16o";
|
|
||||||
if (v == mkldnn_Odhwi4o) return "Odhwi4o";
|
|
||||||
if (v == mkldnn_Odhwi8o) return "Odhwi8o";
|
|
||||||
if (v == mkldnn_OIdhw16i16o) return "OIdhw16i16o";
|
|
||||||
if (v == mkldnn_OIdhw16o16i) return "OIdhw16o16i";
|
|
||||||
if (v == mkldnn_Oidhw16o) return "Oidhw16o";
|
|
||||||
if (v == mkldnn_OIdhw4i4o) return "OIdhw4i4o";
|
|
||||||
if (v == mkldnn_Oidhw4o) return "Oidhw4o";
|
|
||||||
if (v == mkldnn_OIdhw8i16o2i) return "OIdhw8i16o2i";
|
|
||||||
if (v == mkldnn_OIdhw8i8o) return "OIdhw8i8o";
|
|
||||||
if (v == mkldnn_OIdhw8o8i) return "OIdhw8o8i";
|
|
||||||
if (v == mkldnn_Goiw16g) return "Goiw16g";
|
|
||||||
if (v == mkldnn_gIOw16o16i) return "gIOw16o16i";
|
|
||||||
if (v == mkldnn_gOIw16i16o) return "gOIw16i16o";
|
|
||||||
if (v == mkldnn_gOIw16o16i) return "gOIw16o16i";
|
|
||||||
if (v == mkldnn_gOiw16o) return "gOiw16o";
|
|
||||||
if (v == mkldnn_gOIw4i16o4i) return "gOIw4i16o4i";
|
|
||||||
if (v == mkldnn_gOIw4i4o) return "gOIw4i4o";
|
|
||||||
if (v == mkldnn_gOiw4o) return "gOiw4o";
|
|
||||||
if (v == mkldnn_gOIw8i16o2i) return "gOIw8i16o2i";
|
|
||||||
if (v == mkldnn_gOIw8i8o) return "gOIw8i8o";
|
|
||||||
if (v == mkldnn_gOIw8o16i2o) return "gOIw8o16i2o";
|
|
||||||
if (v == mkldnn_gOIw8o8i) return "gOIw8o8i";
|
|
||||||
if (v == mkldnn_gOwi16o) return "gOwi16o";
|
|
||||||
if (v == mkldnn_gOwi4o) return "gOwi4o";
|
|
||||||
if (v == mkldnn_gOwi8o) return "gOwi8o";
|
|
||||||
if (v == mkldnn_gIOhw16o16i) return "gIOhw16o16i";
|
|
||||||
if (v == mkldnn_gOhwi16o) return "gOhwi16o";
|
|
||||||
if (v == mkldnn_gOhwi4o) return "gOhwi4o";
|
|
||||||
if (v == mkldnn_gOhwi8o) return "gOhwi8o";
|
|
||||||
if (v == mkldnn_Goihw16g) return "Goihw16g";
|
|
||||||
if (v == mkldnn_gOIhw16i16o) return "gOIhw16i16o";
|
|
||||||
if (v == mkldnn_gOIhw16o16i) return "gOIhw16o16i";
|
|
||||||
if (v == mkldnn_gOihw16o) return "gOihw16o";
|
|
||||||
if (v == mkldnn_gOIhw2i8o4i) return "gOIhw2i8o4i";
|
|
||||||
if (v == mkldnn_gOIhw4i16o4i) return "gOIhw4i16o4i";
|
|
||||||
if (v == mkldnn_gOIhw4i4o) return "gOIhw4i4o";
|
|
||||||
if (v == mkldnn_gOIhw4o4i) return "gOIhw4o4i";
|
|
||||||
if (v == mkldnn_gOihw4o) return "gOihw4o";
|
|
||||||
if (v == mkldnn_Goihw8g) return "Goihw8g";
|
|
||||||
if (v == mkldnn_gOIhw8i16o2i) return "gOIhw8i16o2i";
|
|
||||||
if (v == mkldnn_gOIhw8i8o) return "gOIhw8i8o";
|
|
||||||
if (v == mkldnn_gOIhw8o16i2o) return "gOIhw8o16i2o";
|
|
||||||
if (v == mkldnn_gOIhw8o8i) return "gOIhw8o8i";
|
|
||||||
if (v == mkldnn_gOdhwi16o) return "gOdhwi16o";
|
|
||||||
if (v == mkldnn_gOdhwi4o) return "gOdhwi4o";
|
|
||||||
if (v == mkldnn_gOdhwi8o) return "gOdhwi8o";
|
|
||||||
if (v == mkldnn_gOIdhw16i16o) return "gOIdhw16i16o";
|
|
||||||
if (v == mkldnn_gOIdhw16o16i) return "gOIdhw16o16i";
|
|
||||||
if (v == mkldnn_gOidhw16o) return "gOidhw16o";
|
|
||||||
if (v == mkldnn_gOIdhw4i4o) return "gOIdhw4i4o";
|
|
||||||
if (v == mkldnn_gOidhw4o) return "gOidhw4o";
|
|
||||||
if (v == mkldnn_gOIdhw8i16o2i) return "gOIdhw8i16o2i";
|
|
||||||
if (v == mkldnn_gOIdhw8i8o) return "gOIdhw8i8o";
|
|
||||||
if (v == mkldnn_gOIdhw8o8i) return "gOIdhw8o8i";
|
|
||||||
assert(!"unknown fmt_tag");
|
|
||||||
return "unknown fmt_tag";
|
|
||||||
}
|
|
||||||
|
|
||||||
const char *mkldnn_prop_kind2str(mkldnn_prop_kind_t v) {
|
|
||||||
if (v == mkldnn_prop_kind_undef) return "undef";
|
|
||||||
if (v == mkldnn_forward_training) return "forward_training";
|
|
||||||
if (v == mkldnn_forward_inference) return "forward_inference";
|
|
||||||
if (v == mkldnn_forward_scoring) return "forward_scoring";
|
|
||||||
if (v == mkldnn_forward) return "forward";
|
|
||||||
if (v == mkldnn_backward) return "backward";
|
|
||||||
if (v == mkldnn_backward_data) return "backward_data";
|
|
||||||
if (v == mkldnn_backward_weights) return "backward_weights";
|
|
||||||
if (v == mkldnn_backward_bias) return "backward_bias";
|
|
||||||
assert(!"unknown prop_kind");
|
|
||||||
return "unknown prop_kind";
|
|
||||||
}
|
|
||||||
|
|
||||||
const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) {
|
|
||||||
if (v == mkldnn_undefined_primitive) return "undef";
|
|
||||||
if (v == mkldnn_reorder) return "reorder";
|
|
||||||
if (v == mkldnn_shuffle) return "shuffle";
|
|
||||||
if (v == mkldnn_concat) return "concat";
|
|
||||||
if (v == mkldnn_sum) return "sum";
|
|
||||||
if (v == mkldnn_convolution) return "convolution";
|
|
||||||
if (v == mkldnn_deconvolution) return "deconvolution";
|
|
||||||
if (v == mkldnn_eltwise) return "eltwise";
|
|
||||||
if (v == mkldnn_softmax) return "softmax";
|
|
||||||
if (v == mkldnn_pooling) return "pooling";
|
|
||||||
if (v == mkldnn_lrn) return "lrn";
|
|
||||||
if (v == mkldnn_batch_normalization) return "batch_normalization";
|
|
||||||
if (v == mkldnn_inner_product) return "inner_product";
|
|
||||||
if (v == mkldnn_rnn) return "rnn";
|
|
||||||
assert(!"unknown prim_kind");
|
|
||||||
return "unknown prim_kind";
|
|
||||||
}
|
|
||||||
|
|
||||||
const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) {
|
|
||||||
if (v == mkldnn_alg_kind_undef) return "undef";
|
|
||||||
if (v == mkldnn_convolution_direct) return "convolution_direct";
|
|
||||||
if (v == mkldnn_convolution_winograd) return "convolution_winograd";
|
|
||||||
if (v == mkldnn_convolution_auto) return "convolution_auto";
|
|
||||||
if (v == mkldnn_deconvolution_direct) return "deconvolution_direct";
|
|
||||||
if (v == mkldnn_deconvolution_winograd) return "deconvolution_winograd";
|
|
||||||
if (v == mkldnn_eltwise_relu) return "eltwise_relu";
|
|
||||||
if (v == mkldnn_eltwise_tanh) return "eltwise_tanh";
|
|
||||||
if (v == mkldnn_eltwise_elu) return "eltwise_elu";
|
|
||||||
if (v == mkldnn_eltwise_square) return "eltwise_square";
|
|
||||||
if (v == mkldnn_eltwise_abs) return "eltwise_abs";
|
|
||||||
if (v == mkldnn_eltwise_sqrt) return "eltwise_sqrt";
|
|
||||||
if (v == mkldnn_eltwise_linear) return "eltwise_linear";
|
|
||||||
if (v == mkldnn_eltwise_bounded_relu) return "eltwise_bounded_relu";
|
|
||||||
if (v == mkldnn_eltwise_soft_relu) return "eltwise_soft_relu";
|
|
||||||
if (v == mkldnn_eltwise_logistic) return "eltwise_logistic";
|
|
||||||
if (v == mkldnn_pooling_max) return "pooling_max";
|
|
||||||
if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding";
|
|
||||||
if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding";
|
|
||||||
if (v == mkldnn_pooling_avg) return "pooling_avg";
|
|
||||||
if (v == mkldnn_lrn_across_channels) return "lrn_across_channels";
|
|
||||||
if (v == mkldnn_lrn_within_channel) return "lrn_within_channel";
|
|
||||||
if (v == mkldnn_vanilla_rnn) return "vanilla_rnn";
|
|
||||||
if (v == mkldnn_vanilla_lstm) return "vanilla_lstm";
|
|
||||||
if (v == mkldnn_vanilla_gru) return "vanilla_gru";
|
|
||||||
if (v == mkldnn_gru_linear_before_reset) return "gru_linear_before_reset";
|
|
||||||
assert(!"unknown alg_kind");
|
|
||||||
return "unknown alg_kind";
|
|
||||||
}
|
|
||||||
|
|
||||||
const char *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v) {
|
|
||||||
if (v == mkldnn_unidirectional_left2right) return "unidirectional_left2right";
|
|
||||||
if (v == mkldnn_unidirectional_right2left) return "unidirectional_right2left";
|
|
||||||
if (v == mkldnn_bidirectional_concat) return "bidirectional_concat";
|
|
||||||
if (v == mkldnn_bidirectional_sum) return "bidirectional_sum";
|
|
||||||
if (v == mkldnn_unidirectional) return "unidirectional";
|
|
||||||
assert(!"unknown rnn_direction");
|
|
||||||
return "unknown rnn_direction";
|
|
||||||
}
|
|
115
thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp
vendored
115
thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp
vendored
@ -1,115 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2017-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef MKLDNN_THREAD_HPP
|
|
||||||
#define MKLDNN_THREAD_HPP
|
|
||||||
|
|
||||||
#include "utils.hpp"
|
|
||||||
#include "z_magic.hpp"
|
|
||||||
|
|
||||||
#define MKLDNN_THR_SEQ 0
|
|
||||||
#define MKLDNN_THR_OMP 1
|
|
||||||
#define MKLDNN_THR_TBB 2
|
|
||||||
|
|
||||||
/* Ideally this condition below should never happen (if the library is built
|
|
||||||
* using regular cmake). For the 3rd-party projects that build the library
|
|
||||||
* from the sources on their own try to guess the right threading... */
|
|
||||||
#if !defined(MKLDNN_THR)
|
|
||||||
# define MKLDNN_THR MKLDNN_THR_TBB
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if MKLDNN_THR == MKLDNN_THR_SEQ
|
|
||||||
#define MKLDNN_THR_SYNC 1
|
|
||||||
inline int mkldnn_get_max_threads() { return 1; }
|
|
||||||
inline int mkldnn_get_num_threads() { return 1; }
|
|
||||||
inline int mkldnn_get_thread_num() { return 0; }
|
|
||||||
inline int mkldnn_in_parallel() { return 0; }
|
|
||||||
inline void mkldnn_thr_barrier() {}
|
|
||||||
|
|
||||||
#define PRAGMA_OMP(...)
|
|
||||||
|
|
||||||
#elif MKLDNN_THR == MKLDNN_THR_OMP
|
|
||||||
#include <omp.h>
|
|
||||||
#define MKLDNN_THR_SYNC 1
|
|
||||||
|
|
||||||
inline int mkldnn_get_max_threads() { return omp_get_max_threads(); }
|
|
||||||
inline int mkldnn_get_num_threads() { return omp_get_num_threads(); }
|
|
||||||
inline int mkldnn_get_thread_num() { return omp_get_thread_num(); }
|
|
||||||
inline int mkldnn_in_parallel() { return omp_in_parallel(); }
|
|
||||||
inline void mkldnn_thr_barrier() {
|
|
||||||
# pragma omp barrier
|
|
||||||
}
|
|
||||||
|
|
||||||
#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__))
|
|
||||||
|
|
||||||
#elif MKLDNN_THR == MKLDNN_THR_TBB
|
|
||||||
#include "tbb/task_arena.h"
|
|
||||||
#include "tbb/parallel_for.h"
|
|
||||||
#define MKLDNN_THR_SYNC 0
|
|
||||||
|
|
||||||
inline int mkldnn_get_max_threads()
|
|
||||||
{ return tbb::this_task_arena::max_concurrency(); }
|
|
||||||
inline int mkldnn_get_num_threads() { return mkldnn_get_max_threads(); }
|
|
||||||
inline int mkldnn_get_thread_num()
|
|
||||||
{ return tbb::this_task_arena::current_thread_index(); }
|
|
||||||
inline int mkldnn_in_parallel() { return 0; }
|
|
||||||
inline void mkldnn_thr_barrier() { assert(!"no barrier in TBB"); }
|
|
||||||
|
|
||||||
#define PRAGMA_OMP(...)
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/* MSVC still supports omp 2.0 only */
|
|
||||||
#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
|
|
||||||
# define collapse(x)
|
|
||||||
# define PRAGMA_OMP_SIMD(...)
|
|
||||||
#else
|
|
||||||
# define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__))
|
|
||||||
#endif // defined(_MSC_VER) && !defined(__INTEL_COMPILER)
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
inline bool mkldnn_thr_syncable() { return MKLDNN_THR_SYNC == 1; }
|
|
||||||
|
|
||||||
template <typename T, typename U>
|
|
||||||
inline void balance211(T n, U team, U tid, T &n_start, T &n_end) {
|
|
||||||
T n_min = 1;
|
|
||||||
T &n_my = n_end;
|
|
||||||
if (team <= 1 || n == 0) {
|
|
||||||
n_start = 0;
|
|
||||||
n_my = n;
|
|
||||||
} else if (n_min == 1) {
|
|
||||||
// team = T1 + T2
|
|
||||||
// n = T1*n1 + T2*n2 (n1 - n2 = 1)
|
|
||||||
T n1 = utils::div_up(n, (T)team);
|
|
||||||
T n2 = n1 - 1;
|
|
||||||
T T1 = n - n2 * (T)team;
|
|
||||||
n_my = (T)tid < T1 ? n1 : n2;
|
|
||||||
n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2;
|
|
||||||
}
|
|
||||||
|
|
||||||
n_end += n_start;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace impl
|
|
||||||
} // namespace mkldnn
|
|
||||||
|
|
||||||
#include "mkldnn_thread_parallel_nd.hpp"
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
@ -1,277 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef MKLDNN_THREAD_PARALLEL_ND_HPP
|
|
||||||
#define MKLDNN_THREAD_PARALLEL_ND_HPP
|
|
||||||
|
|
||||||
/* This header must be included by mkldnn_thread.hpp only */
|
|
||||||
|
|
||||||
/* Functions:
|
|
||||||
* - parallel(nthr, f) - executes f in parallel using at most
|
|
||||||
* nthr threads. If nthr equals 0
|
|
||||||
* mkldnn_get_max_threads() threads is
|
|
||||||
* used
|
|
||||||
* - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for already
|
|
||||||
* created threads
|
|
||||||
* - parallel_nd(dims..., f) - creates a parallel section and then
|
|
||||||
* calls for_nd
|
|
||||||
* - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and then
|
|
||||||
* calls for_nd (mostly for convenience)
|
|
||||||
*/
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
/* general parallelization */
|
|
||||||
template <typename F>
|
|
||||||
void parallel(int nthr, F f) {
|
|
||||||
if (nthr == 0) nthr = mkldnn_get_max_threads();
|
|
||||||
#if MKLDNN_THR == MKLDNN_THR_SEQ
|
|
||||||
assert(nthr == 1);
|
|
||||||
f(0, 1);
|
|
||||||
#elif MKLDNN_THR == MKLDNN_THR_OMP
|
|
||||||
if (nthr == 1) { f(0, 1); return; }
|
|
||||||
# pragma omp parallel num_threads(nthr)
|
|
||||||
f(mkldnn_get_thread_num(), mkldnn_get_num_threads());
|
|
||||||
#elif MKLDNN_THR == MKLDNN_THR_TBB
|
|
||||||
if (nthr == 1) { f(0, 1); return; }
|
|
||||||
tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); }, tbb::static_partitioner());
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
/* for_nd section */
|
|
||||||
|
|
||||||
template <typename T0, typename F>
|
|
||||||
void for_nd(const int ithr, const int nthr, const T0 &D0, F f) {
|
|
||||||
T0 start{0}, end{0};
|
|
||||||
balance211(D0, nthr, ithr, start, end);
|
|
||||||
for (T0 d0 = start; d0 < end; ++d0) f(d0);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T0, typename T1, typename F>
|
|
||||||
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) {
|
|
||||||
const size_t work_amount = (size_t)D0 * D1;
|
|
||||||
if (work_amount == 0) return;
|
|
||||||
size_t start{0}, end{0};
|
|
||||||
balance211(work_amount, nthr, ithr, start, end);
|
|
||||||
|
|
||||||
T0 d0{0}; T1 d1{0};
|
|
||||||
utils::nd_iterator_init(start, d0, D0, d1, D1);
|
|
||||||
for (size_t iwork = start; iwork < end; ++iwork) {
|
|
||||||
f(d0, d1);
|
|
||||||
utils::nd_iterator_step(d0, D0, d1, D1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T0, typename T1, typename T2, typename F>
|
|
||||||
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
|
|
||||||
const T2 &D2, F f) {
|
|
||||||
const size_t work_amount = (size_t)D0 * D1 * D2;
|
|
||||||
if (work_amount == 0) return;
|
|
||||||
size_t start{0}, end{0};
|
|
||||||
balance211(work_amount, nthr, ithr, start, end);
|
|
||||||
|
|
||||||
T0 d0{0}; T1 d1{0}; T2 d2{0};
|
|
||||||
utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2);
|
|
||||||
for (size_t iwork = start; iwork < end; ++iwork) {
|
|
||||||
f(d0, d1, d2);
|
|
||||||
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T0, typename T1, typename T2, typename T3, typename F>
|
|
||||||
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
|
|
||||||
const T2 &D2, const T3 &D3, F f) {
|
|
||||||
const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
|
|
||||||
if (work_amount == 0) return;
|
|
||||||
size_t start{0}, end{0};
|
|
||||||
balance211(work_amount, nthr, ithr, start, end);
|
|
||||||
|
|
||||||
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
|
|
||||||
utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3);
|
|
||||||
for (size_t iwork = start; iwork < end; ++iwork) {
|
|
||||||
f(d0, d1, d2, d3);
|
|
||||||
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T0, typename T1, typename T2, typename T3, typename T4,
|
|
||||||
typename F>
|
|
||||||
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
|
|
||||||
const T2 &D2, const T3 &D3, const T4 &D4, F f) {
|
|
||||||
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
|
|
||||||
if (work_amount == 0) return;
|
|
||||||
size_t start{0}, end{0};
|
|
||||||
balance211(work_amount, nthr, ithr, start, end);
|
|
||||||
|
|
||||||
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
|
|
||||||
utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
|
|
||||||
for (size_t iwork = start; iwork < end; ++iwork) {
|
|
||||||
f(d0, d1, d2, d3, d4);
|
|
||||||
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T0, typename T1, typename T2, typename T3, typename T4,
|
|
||||||
typename T5, typename F>
|
|
||||||
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
|
|
||||||
const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) {
|
|
||||||
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
|
|
||||||
if (work_amount == 0) return;
|
|
||||||
size_t start{0}, end{0};
|
|
||||||
balance211(work_amount, nthr, ithr, start, end);
|
|
||||||
|
|
||||||
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
|
|
||||||
utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
|
|
||||||
d5, D5);
|
|
||||||
for (size_t iwork = start; iwork < end; ++iwork) {
|
|
||||||
f(d0, d1, d2, d3, d4, d5);
|
|
||||||
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip a lambda function in the parameter pack.
|
|
||||||
template <typename T>
|
|
||||||
constexpr size_t get_work_amount(const T &v) { return 1; }
|
|
||||||
template <typename T, typename ...Args>
|
|
||||||
constexpr size_t get_work_amount(const T &v, Args &&...args)
|
|
||||||
{ return (size_t)v * get_work_amount(utils::forward<Args>(args)...); }
|
|
||||||
|
|
||||||
/* parallel_nd and parallel_nd_in_omp section */
|
|
||||||
|
|
||||||
#if MKLDNN_THR != MKLDNN_THR_TBB
|
|
||||||
template <typename ...Args>
|
|
||||||
void parallel_nd(Args &&...args) {
|
|
||||||
#if MKLDNN_THR == MKLDNN_THR_SEQ
|
|
||||||
for_nd(0, 1, utils::forward<Args>(args)...);
|
|
||||||
#elif MKLDNN_THR == MKLDNN_THR_OMP
|
|
||||||
const bool do_parallel = get_work_amount(utils::forward<Args>(args)...) > 1;
|
|
||||||
# pragma omp parallel if (do_parallel)
|
|
||||||
{
|
|
||||||
const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads();
|
|
||||||
const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num();
|
|
||||||
for_nd(ithr, nthr, utils::forward<Args>(args)...);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
#else // MKLDNN_THR != MKLDNN_THR_TBB
|
|
||||||
|
|
||||||
// gcc 4.8 has a bug with passing parameter pack to lambdas.
|
|
||||||
// So have to explicitly instantiate all the cases.
|
|
||||||
|
|
||||||
template <typename T0, typename F>
|
|
||||||
void parallel_nd(const T0 &D0, F f) {
|
|
||||||
const size_t work_amount = (size_t)D0;
|
|
||||||
if (work_amount == 0) return;
|
|
||||||
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
|
|
||||||
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
|
|
||||||
f(T0(iwork));
|
|
||||||
}
|
|
||||||
}, tbb::static_partitioner());
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T0, typename T1, typename F>
|
|
||||||
void parallel_nd(const T0 &D0, const T1 &D1, F f) {
|
|
||||||
const size_t work_amount = (size_t)D0 * D1;
|
|
||||||
if (work_amount == 0) return;
|
|
||||||
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
|
|
||||||
T0 d0{0}; T1 d1{0};
|
|
||||||
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1);
|
|
||||||
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
|
|
||||||
f(d0, d1);
|
|
||||||
utils::nd_iterator_step(d0, D0, d1, D1);
|
|
||||||
}
|
|
||||||
}, tbb::static_partitioner());
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T0, typename T1, typename T2, typename F>
|
|
||||||
void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) {
|
|
||||||
const size_t work_amount = (size_t)D0 * D1 * D2;
|
|
||||||
if (work_amount == 0) return;
|
|
||||||
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
|
|
||||||
T0 d0{0}; T1 d1{0}; T2 d2{0};
|
|
||||||
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2);
|
|
||||||
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
|
|
||||||
f(d0, d1, d2);
|
|
||||||
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
|
|
||||||
}
|
|
||||||
}, tbb::static_partitioner());
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T0, typename T1, typename T2, typename T3, typename F>
|
|
||||||
void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) {
|
|
||||||
const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
|
|
||||||
if (work_amount == 0) return;
|
|
||||||
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
|
|
||||||
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
|
|
||||||
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3);
|
|
||||||
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
|
|
||||||
f(d0, d1, d2, d3);
|
|
||||||
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
|
|
||||||
}
|
|
||||||
}, tbb::static_partitioner());
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T0, typename T1, typename T2, typename T3, typename T4,
|
|
||||||
typename F>
|
|
||||||
void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
|
|
||||||
const T4 &D4, F f) {
|
|
||||||
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
|
|
||||||
if (work_amount == 0) return;
|
|
||||||
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
|
|
||||||
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
|
|
||||||
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
|
|
||||||
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
|
|
||||||
f(d0, d1, d2, d3, d4);
|
|
||||||
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
|
|
||||||
}
|
|
||||||
}, tbb::static_partitioner());
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T0, typename T1, typename T2, typename T3, typename T4,
|
|
||||||
typename T5, typename F>
|
|
||||||
void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
|
|
||||||
const T4 &D4, const T5 &D5, F f) {
|
|
||||||
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
|
|
||||||
if (work_amount == 0) return;
|
|
||||||
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
|
|
||||||
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
|
|
||||||
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
|
|
||||||
d5, D5);
|
|
||||||
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
|
|
||||||
f(d0, d1, d2, d3, d4, d5);
|
|
||||||
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
|
|
||||||
}
|
|
||||||
}, tbb::static_partitioner());
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
template <typename ...Args>
|
|
||||||
void parallel_nd_in_omp(Args &&...args) {
|
|
||||||
#if MKLDNN_THR == MKLDNN_THR_SEQ
|
|
||||||
for_nd(0, 1, utils::forward<Args>(args)...);
|
|
||||||
#elif MKLDNN_THR == MKLDNN_THR_OMP
|
|
||||||
for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(),
|
|
||||||
utils::forward<Args>(args)...);
|
|
||||||
#elif MKLDNN_THR == MKLDNN_THR_TBB
|
|
||||||
assert(!"unsupported parallel_nd_in_omp()");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace impl
|
|
||||||
} // namespace mkldnn
|
|
||||||
|
|
||||||
#endif
|
|
@ -1,77 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef MKLDNN_TRAITS_HPP
|
|
||||||
#define MKLDNN_TRAITS_HPP
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "nstl.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
#include "z_magic.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
template <data_type_t> struct prec_traits {}; /* ::type -> float */
|
|
||||||
template <typename> struct data_traits {}; /* ::data_type -> f32 */
|
|
||||||
template <int> struct typesize_traits {}; /* ::data_type_size -> f32 */
|
|
||||||
template <primitive_kind_t> struct pkind_traits {}; /* ::desc_type, ::query_d */
|
|
||||||
|
|
||||||
template <> struct prec_traits<data_type::f32> { typedef float type; };
|
|
||||||
template <> struct prec_traits<data_type::s32> { typedef int32_t type; };
|
|
||||||
template <> struct prec_traits<data_type::s8> { typedef int8_t type; };
|
|
||||||
template <> struct prec_traits<data_type::u8> { typedef uint8_t type; };
|
|
||||||
|
|
||||||
template <> struct data_traits<float>
|
|
||||||
{ static constexpr data_type_t data_type = data_type::f32; };
|
|
||||||
template <> struct data_traits<int32_t>
|
|
||||||
{ static constexpr data_type_t data_type = data_type::s32; };
|
|
||||||
template <> struct data_traits<int8_t>
|
|
||||||
{ static constexpr data_type_t data_type = data_type::s8; };
|
|
||||||
template <> struct data_traits<uint8_t>
|
|
||||||
{ static constexpr data_type_t data_type = data_type::u8; };
|
|
||||||
|
|
||||||
template <> struct typesize_traits<4> { typedef float type; };
|
|
||||||
template <> struct typesize_traits<2> { typedef int16_t type; };
|
|
||||||
template <> struct typesize_traits<1> { typedef uint8_t type; };
|
|
||||||
|
|
||||||
#define PKIND_TRAITS_INST(op) \
|
|
||||||
template <> struct pkind_traits<primitive_kind::op> { \
|
|
||||||
typedef CONCAT2(op, _desc_t) desc_type; \
|
|
||||||
static constexpr query_t query_d = query::CONCAT2(op, _d); \
|
|
||||||
}
|
|
||||||
PKIND_TRAITS_INST(convolution);
|
|
||||||
PKIND_TRAITS_INST(deconvolution);
|
|
||||||
PKIND_TRAITS_INST(shuffle);
|
|
||||||
PKIND_TRAITS_INST(eltwise);
|
|
||||||
PKIND_TRAITS_INST(softmax);
|
|
||||||
PKIND_TRAITS_INST(pooling);
|
|
||||||
PKIND_TRAITS_INST(lrn);
|
|
||||||
PKIND_TRAITS_INST(batch_normalization);
|
|
||||||
PKIND_TRAITS_INST(inner_product);
|
|
||||||
PKIND_TRAITS_INST(rnn);
|
|
||||||
#undef PKIND_TRAITS_INST
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
193
thirdparty/oidn/mkl-dnn/src/common/nstl.hpp
vendored
193
thirdparty/oidn/mkl-dnn/src/common/nstl.hpp
vendored
@ -1,193 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef NSTL_HPP
|
|
||||||
#define NSTL_HPP
|
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
#include <limits.h>
|
|
||||||
#include <float.h>
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <map>
|
|
||||||
|
|
||||||
#include "z_magic.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
void *malloc(size_t size, int alignment);
|
|
||||||
void free(void *p);
|
|
||||||
|
|
||||||
struct c_compatible {
|
|
||||||
enum { default_alignment = 64 };
|
|
||||||
static void *operator new(size_t sz) {
|
|
||||||
return malloc(sz, default_alignment);
|
|
||||||
}
|
|
||||||
static void *operator new(size_t sz, void *p) { UNUSED(sz); return p; }
|
|
||||||
static void *operator new[](size_t sz) {
|
|
||||||
return malloc(sz, default_alignment);
|
|
||||||
}
|
|
||||||
static void operator delete(void *p) { free(p); }
|
|
||||||
static void operator delete[](void *p) { free(p); }
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace nstl {
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
inline const T abs(const T& a) {
|
|
||||||
return a >= 0 ? a : -a;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
inline const T& max(const T& a, const T& b) {
|
|
||||||
return a > b ? a : b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
inline const T& min(const T& a, const T& b) {
|
|
||||||
return a < b ? a : b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T> void swap(T& t1, T& t2) {
|
|
||||||
T tmp(t1);
|
|
||||||
t1 = t2;
|
|
||||||
t2 = tmp;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rationale: MKL-DNN needs numeric limits implementation that does not
|
|
||||||
// generate dependencies on C++ run-time libraries.
|
|
||||||
|
|
||||||
template<typename T> struct numeric_limits;
|
|
||||||
|
|
||||||
template<> struct numeric_limits<float> {
|
|
||||||
static constexpr float lowest() { return -FLT_MAX; }
|
|
||||||
static constexpr float max() { return FLT_MAX; }
|
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct numeric_limits<int32_t> {
|
|
||||||
static constexpr int lowest() { return INT32_MIN; }
|
|
||||||
static constexpr int max() { return INT32_MAX; }
|
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct numeric_limits<int16_t> {
|
|
||||||
static constexpr int16_t lowest() { return INT16_MIN; }
|
|
||||||
static constexpr int16_t max() { return INT16_MAX; }
|
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct numeric_limits<int8_t> {
|
|
||||||
static constexpr int8_t lowest() { return INT8_MIN; }
|
|
||||||
static constexpr int8_t max() { return INT8_MAX; }
|
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct numeric_limits<uint8_t> {
|
|
||||||
static constexpr uint8_t lowest() { return 0; }
|
|
||||||
static constexpr uint8_t max() { return UINT8_MAX; }
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename T> struct is_integral
|
|
||||||
{ static constexpr bool value = false; };
|
|
||||||
template<> struct is_integral<int32_t> { static constexpr bool value = true; };
|
|
||||||
template<> struct is_integral<int16_t> { static constexpr bool value = true; };
|
|
||||||
template<> struct is_integral<int8_t> { static constexpr bool value = true; };
|
|
||||||
template<> struct is_integral<uint8_t> { static constexpr bool value = true; };
|
|
||||||
|
|
||||||
template <typename T, typename U> struct is_same
|
|
||||||
{ static constexpr bool value = false; };
|
|
||||||
template <typename T> struct is_same<T, T>
|
|
||||||
{ static constexpr bool value = true; };
|
|
||||||
|
|
||||||
// Rationale: MKL-DNN needs container implementations that do not generate
|
|
||||||
// dependencies on C++ run-time libraries.
|
|
||||||
//
|
|
||||||
// Implementation philosophy: caller is responsible to check if the operation
|
|
||||||
// is valid. The only functions that have to return status are those that
|
|
||||||
// depend on memory allocation or similar operations.
|
|
||||||
//
|
|
||||||
// This means that e.g. an operator [] does not have to check for boundaries.
|
|
||||||
// The caller should have checked the boundaries. If it did not we crash and
|
|
||||||
// burn: this is a bug in MKL-DNN and throwing an exception would not have been
|
|
||||||
// recoverable.
|
|
||||||
//
|
|
||||||
// On the other hand, insert() or resize() or a similar operation needs to
|
|
||||||
// return a status because the outcome depends on factors external to the
|
|
||||||
// caller. The situation is probably also not recoverable also, but MKL-DNN
|
|
||||||
// needs to be nice and report "out of memory" to the users.
|
|
||||||
|
|
||||||
enum nstl_status_t {
|
|
||||||
success = 0,
|
|
||||||
out_of_memory
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T> class vector: public c_compatible {
|
|
||||||
private:
|
|
||||||
std::vector<T> _impl;
|
|
||||||
public:
|
|
||||||
typedef typename std::vector<T>::iterator iterator;
|
|
||||||
typedef typename std::vector<T>::const_iterator const_iterator;
|
|
||||||
typedef typename std::vector<T>::size_type size_type;
|
|
||||||
vector() {}
|
|
||||||
vector(size_type n): _impl(n) {}
|
|
||||||
vector(size_type n, const T &value): _impl(n, value) {}
|
|
||||||
template <typename input_iterator>
|
|
||||||
vector(input_iterator first, input_iterator last): _impl(first, last) {}
|
|
||||||
~vector() {}
|
|
||||||
size_type size() const { return _impl.size(); }
|
|
||||||
T& operator[] (size_type i) { return _impl[i]; }
|
|
||||||
const T& operator[] (size_type i) const { return _impl[i]; }
|
|
||||||
iterator begin() { return _impl.begin(); }
|
|
||||||
const_iterator begin() const { return _impl.begin(); }
|
|
||||||
iterator end() { return _impl.end(); }
|
|
||||||
const_iterator end() const { return _impl.end(); }
|
|
||||||
template <typename input_iterator>
|
|
||||||
nstl_status_t insert(iterator pos, input_iterator begin, input_iterator end)
|
|
||||||
{
|
|
||||||
_impl.insert(pos, begin, end);
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
void clear() { _impl.clear(); }
|
|
||||||
void push_back(const T& t) { _impl.push_back(t); }
|
|
||||||
void resize(size_type count) { _impl.resize(count); }
|
|
||||||
void reserve(size_type count) { _impl.reserve(count); }
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename Key, typename T> class map: public c_compatible {
|
|
||||||
private:
|
|
||||||
std::map<Key, T> _impl;
|
|
||||||
public:
|
|
||||||
typedef typename std::map<Key, T>::iterator iterator;
|
|
||||||
typedef typename std::map<Key, T>::const_iterator const_iterator;
|
|
||||||
typedef typename std::map<Key, T>::size_type size_type;
|
|
||||||
map() {}
|
|
||||||
~map() {}
|
|
||||||
size_type size() const { return _impl.size(); }
|
|
||||||
T& operator[](const Key &k) { return _impl[k]; }
|
|
||||||
const T& operator[](const Key &k) const { return _impl[k]; }
|
|
||||||
iterator begin() { return _impl.begin(); }
|
|
||||||
const_iterator begin() const { return _impl.begin(); }
|
|
||||||
iterator end() { return _impl.end(); }
|
|
||||||
const_iterator end() const { return _impl.end(); }
|
|
||||||
template <typename input_iterator>
|
|
||||||
void clear() { _impl.clear(); }
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
114
thirdparty/oidn/mkl-dnn/src/common/pooling.cpp
vendored
114
thirdparty/oidn/mkl-dnn/src/common/pooling.cpp
vendored
@ -1,114 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::utils;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
using namespace mkldnn::impl::prop_kind;
|
|
||||||
using namespace mkldnn::impl::alg_kind;
|
|
||||||
using namespace mkldnn::impl::types;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
status_t pooling_desc_init(pooling_desc_t *pool_desc,
|
|
||||||
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
|
|
||||||
const dims_t strides, const dims_t kernel, const dims_t padding_l,
|
|
||||||
const dims_t padding_r, padding_kind_t padding_kind) {
|
|
||||||
bool args_ok = true
|
|
||||||
&& !any_null(pool_desc, src_desc, dst_desc, strides, kernel, padding_l)
|
|
||||||
&& one_of(alg_kind, pooling_max,
|
|
||||||
pooling_avg_include_padding,
|
|
||||||
pooling_avg_exclude_padding)
|
|
||||||
&& one_of(padding_kind, padding_kind::padding_zero);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
if (padding_r == nullptr) padding_r = padding_l;
|
|
||||||
|
|
||||||
auto pd = pooling_desc_t();
|
|
||||||
pd.primitive_kind = primitive_kind::pooling;
|
|
||||||
pd.prop_kind = prop_kind;
|
|
||||||
pd.alg_kind = alg_kind;
|
|
||||||
pd.src_desc.ndims = src_desc->ndims;
|
|
||||||
|
|
||||||
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
|
|
||||||
|
|
||||||
pd.diff_src_desc = pd.src_desc = zero_md();
|
|
||||||
pd.diff_dst_desc = pd.dst_desc = zero_md();
|
|
||||||
|
|
||||||
(is_fwd ? pd.src_desc : pd.diff_src_desc) = *src_desc;
|
|
||||||
(is_fwd ? pd.dst_desc : pd.diff_dst_desc) = *dst_desc;
|
|
||||||
|
|
||||||
int sp_dims = src_desc->ndims - 2;
|
|
||||||
utils::array_copy(pd.strides, strides, sp_dims);
|
|
||||||
utils::array_copy(pd.kernel, kernel, sp_dims);
|
|
||||||
utils::array_copy(pd.padding[0], padding_l, sp_dims);
|
|
||||||
utils::array_copy(pd.padding[1], padding_r, sp_dims);
|
|
||||||
|
|
||||||
pd.padding_kind = padding_kind;
|
|
||||||
if (one_of(alg_kind, pooling_max, pooling_avg_include_padding,
|
|
||||||
pooling_avg_exclude_padding)) {
|
|
||||||
pd.accum_data_type = types::default_accum_data_type(
|
|
||||||
src_desc->data_type, dst_desc->data_type);
|
|
||||||
} else {
|
|
||||||
pd.accum_data_type = dst_desc->data_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool consistency = true
|
|
||||||
&& utils::one_of(src_desc->ndims, 4, 5)
|
|
||||||
&& utils::one_of(dst_desc->ndims, 4, 5)
|
|
||||||
&& src_desc->dims[0] == dst_desc->dims[0]
|
|
||||||
&& src_desc->dims[1] == dst_desc->dims[1];
|
|
||||||
for (int i = 2; i < src_desc->ndims; ++i)
|
|
||||||
consistency = consistency && (
|
|
||||||
(src_desc->dims[i] - kernel[i - 2] + padding_l[i - 2]
|
|
||||||
+ padding_r[i - 2]) / strides[i - 2] + 1
|
|
||||||
== dst_desc->dims[i]);
|
|
||||||
if (!consistency) return invalid_arguments;
|
|
||||||
|
|
||||||
*pool_desc = pd;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_pooling_forward_desc_init(pooling_desc_t *pool_desc,
|
|
||||||
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
|
||||||
const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
|
|
||||||
const dims_t strides, const dims_t kernel, const dims_t padding_l,
|
|
||||||
const dims_t padding_r, padding_kind_t padding_kind) {
|
|
||||||
if (!one_of(prop_kind, forward_training, forward_inference))
|
|
||||||
return invalid_arguments;
|
|
||||||
return pooling_desc_init(pool_desc, prop_kind, alg_kind, src_desc,
|
|
||||||
dst_desc, strides, kernel, padding_l, padding_r, padding_kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_pooling_backward_desc_init(pooling_desc_t *pool_desc,
|
|
||||||
alg_kind_t alg_kind, const memory_desc_t *diff_src_desc,
|
|
||||||
const memory_desc_t *diff_dst_desc, const dims_t strides,
|
|
||||||
const dims_t kernel, const dims_t padding_l, const dims_t padding_r,
|
|
||||||
padding_kind_t padding_kind) {
|
|
||||||
return pooling_desc_init(pool_desc, prop_kind::backward_data, alg_kind,
|
|
||||||
diff_src_desc, diff_dst_desc, strides, kernel, padding_l,
|
|
||||||
padding_r, padding_kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
238
thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp
vendored
238
thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp
vendored
@ -1,238 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef POOLING_PD_HPP
|
|
||||||
#define POOLING_PD_HPP
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
struct pooling_fwd_pd_t;
|
|
||||||
|
|
||||||
struct pooling_pd_t: public primitive_desc_t {
|
|
||||||
static constexpr auto base_pkind = primitive_kind::pooling;
|
|
||||||
|
|
||||||
pooling_pd_t(engine_t *engine,
|
|
||||||
const pooling_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const pooling_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: primitive_desc_t(engine, attr, base_pkind)
|
|
||||||
, desc_(*adesc)
|
|
||||||
, hint_fwd_pd_(hint_fwd_pd)
|
|
||||||
, ws_md_()
|
|
||||||
{}
|
|
||||||
|
|
||||||
const pooling_desc_t *desc() const { return &desc_; }
|
|
||||||
virtual const op_desc_t *op_desc() const override
|
|
||||||
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
|
||||||
virtual void init_info() override { impl::init_info(this, this->info_); }
|
|
||||||
|
|
||||||
virtual status_t query(query_t what, int idx, void *result) const override {
|
|
||||||
switch (what) {
|
|
||||||
case query::pooling_d:
|
|
||||||
*(const pooling_desc_t**)result = desc(); break;
|
|
||||||
default: return primitive_desc_t::query(what, idx, result);
|
|
||||||
}
|
|
||||||
return status::success;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* common pooling aux functions */
|
|
||||||
|
|
||||||
dim_t MB() const { return src_desc().dims[0]; }
|
|
||||||
dim_t C() const { return src_desc().dims[1]; }
|
|
||||||
|
|
||||||
dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; }
|
|
||||||
dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; }
|
|
||||||
dim_t IW() const { return src_desc().dims[ndims() - 1]; }
|
|
||||||
|
|
||||||
dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; }
|
|
||||||
dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; }
|
|
||||||
dim_t OW() const { return dst_desc().dims[ndims() - 1]; }
|
|
||||||
|
|
||||||
dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; }
|
|
||||||
dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; }
|
|
||||||
dim_t KW() const { return desc_.kernel[ndims() - 3]; }
|
|
||||||
|
|
||||||
dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
|
|
||||||
dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
|
|
||||||
dim_t KSW() const { return desc_.strides[ndims() - 3]; }
|
|
||||||
|
|
||||||
dim_t padFront() const
|
|
||||||
{ return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
|
|
||||||
dim_t padBack() const
|
|
||||||
{ return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
|
|
||||||
dim_t padT() const
|
|
||||||
{ return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
|
|
||||||
dim_t padB() const
|
|
||||||
{ return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
|
|
||||||
dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
|
|
||||||
dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
|
|
||||||
|
|
||||||
int ndims() const { return src_desc().ndims; }
|
|
||||||
bool is_3d() const { return ndims() == 5; }
|
|
||||||
|
|
||||||
bool has_zero_dim_memory() const
|
|
||||||
{ return memory_desc_wrapper(src_desc()).has_zero_dim(); }
|
|
||||||
|
|
||||||
bool is_fwd() const {
|
|
||||||
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
|
||||||
prop_kind::forward_inference);
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
pooling_desc_t desc_;
|
|
||||||
const pooling_fwd_pd_t *hint_fwd_pd_;
|
|
||||||
|
|
||||||
memory_desc_t ws_md_;
|
|
||||||
|
|
||||||
void init_default_ws() {
|
|
||||||
ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md();
|
|
||||||
ws_md_.data_type = indices_data_type();
|
|
||||||
}
|
|
||||||
|
|
||||||
data_type_t indices_data_type() const {
|
|
||||||
/* the simplest way to express 256... */
|
|
||||||
const int u8_max = nstl::numeric_limits<
|
|
||||||
typename prec_traits<data_type::u8>::type>::max();
|
|
||||||
return utils::array_product(desc()->kernel, ndims()) <= u8_max
|
|
||||||
? data_type::u8 : data_type::s32;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
const memory_desc_t &src_desc() const
|
|
||||||
{ return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; }
|
|
||||||
const memory_desc_t &dst_desc() const
|
|
||||||
{ return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc; }
|
|
||||||
};
|
|
||||||
|
|
||||||
struct pooling_fwd_pd_t: public pooling_pd_t {
|
|
||||||
typedef pooling_fwd_pd_t base_class;
|
|
||||||
typedef pooling_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
pooling_fwd_pd_t(engine_t *engine,
|
|
||||||
const pooling_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const pooling_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: pooling_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, src_md_(desc_.src_desc)
|
|
||||||
, dst_md_(desc_.dst_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (arg == MKLDNN_ARG_SRC)
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DST)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &src_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &dst_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *workspace_md(int index = 0) const override
|
|
||||||
{ return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return 1; }
|
|
||||||
virtual int n_outputs() const override
|
|
||||||
{ return 1 + (workspace_md() != nullptr); }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t src_md_;
|
|
||||||
memory_desc_t dst_md_;
|
|
||||||
|
|
||||||
virtual status_t set_default_params() {
|
|
||||||
if (dst_md()->format_kind != format_kind::any)
|
|
||||||
return status::success;
|
|
||||||
|
|
||||||
if (src_md()->format_kind != format_kind::blocked)
|
|
||||||
return status::unimplemented;
|
|
||||||
|
|
||||||
return memory_desc_init_by_blocking_desc(dst_md_,
|
|
||||||
src_md_.format_desc.blocking);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct pooling_bwd_pd_t: public pooling_pd_t {
|
|
||||||
typedef pooling_bwd_pd_t base_class;
|
|
||||||
typedef pooling_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
pooling_bwd_pd_t(engine_t *engine,
|
|
||||||
const pooling_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const pooling_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: pooling_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, diff_src_md_(desc_.diff_src_desc)
|
|
||||||
, diff_dst_md_(desc_.diff_dst_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_DST)
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_SRC)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *diff_src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_src_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &diff_dst_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *workspace_md(int index = 0) const override
|
|
||||||
{ return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
|
|
||||||
|
|
||||||
virtual int n_inputs() const override
|
|
||||||
{ return 1 + (workspace_md() != nullptr); }
|
|
||||||
virtual int n_outputs() const override { return 1; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t diff_src_md_;
|
|
||||||
memory_desc_t diff_dst_md_;
|
|
||||||
|
|
||||||
virtual status_t set_default_params() {
|
|
||||||
if (diff_src_md()->format_kind != format_kind::any)
|
|
||||||
return status::success;
|
|
||||||
|
|
||||||
if (diff_dst_md()->format_kind != format_kind::blocked)
|
|
||||||
return status::unimplemented;
|
|
||||||
|
|
||||||
return memory_desc_init_by_blocking_desc(diff_src_md_,
|
|
||||||
diff_dst_md_.format_desc.blocking);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
103
thirdparty/oidn/mkl-dnn/src/common/primitive.cpp
vendored
103
thirdparty/oidn/mkl-dnn/src/common/primitive.cpp
vendored
@ -1,103 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "engine.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
#include "primitive.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "stream.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
using namespace mkldnn::impl::primitive_kind;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
// XXX: this is a huge hammer. This disables all and any msan checks on
|
|
||||||
// primitives outputs.
|
|
||||||
//
|
|
||||||
// A proper approach would be an implementation-specific unpoisoning.
|
|
||||||
void unpoison_outputs(const exec_args_t &args) {
|
|
||||||
for(const auto &arg: args) {
|
|
||||||
if (arg.second.is_const) continue;
|
|
||||||
auto *mem = arg.second.mem;
|
|
||||||
void *p;
|
|
||||||
mem->get_data_handle(&p);
|
|
||||||
size_t s = memory_desc_wrapper(*mem->md()).size();
|
|
||||||
msan_unpoison(p, s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_desc_destroy(primitive_desc_t *primitive_desc) {
|
|
||||||
if (primitive_desc) delete primitive_desc;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_create(primitive_t **primitive,
|
|
||||||
const primitive_desc_t *primitive_desc) {
|
|
||||||
if (utils::any_null(primitive, primitive_desc))
|
|
||||||
return invalid_arguments;
|
|
||||||
return primitive_desc->create_primitive(primitive);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_execute(const primitive_t *primitive,
|
|
||||||
stream_t *stream, int nargs, const mkldnn_exec_arg_t *c_args) {
|
|
||||||
bool ok = true
|
|
||||||
&& !utils::any_null(primitive, stream)
|
|
||||||
&& primitive->engine() == stream->engine()
|
|
||||||
&& IMPLICATION(nargs > 0, c_args != nullptr);
|
|
||||||
if (!ok) return invalid_arguments;
|
|
||||||
|
|
||||||
exec_args_t args;
|
|
||||||
status_t status = cvt_primtive_args(primitive->pd(), nargs, c_args, args);
|
|
||||||
if (status != status::success) return status;
|
|
||||||
|
|
||||||
exec_ctx_t ctx(stream, std::move(args));
|
|
||||||
|
|
||||||
if (mkldnn_verbose()->level) {
|
|
||||||
double ms = get_msec();
|
|
||||||
status = primitive->execute(ctx);
|
|
||||||
ms = get_msec() - ms;
|
|
||||||
printf("mkldnn_verbose,exec,%s,%g\n", primitive->pd()->info(), ms);
|
|
||||||
fflush(0);
|
|
||||||
} else {
|
|
||||||
status = primitive->execute(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (msan_enabled) unpoison_outputs(ctx.args());
|
|
||||||
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_get_primitive_desc(const primitive_t *primitive,
|
|
||||||
const primitive_desc_t **primitive_desc) {
|
|
||||||
if (utils::any_null(primitive, primitive_desc))
|
|
||||||
return invalid_arguments;
|
|
||||||
return safe_ptr_assign<const primitive_desc_t>(*primitive_desc,
|
|
||||||
primitive->pd());
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_destroy(primitive_t *primitive) {
|
|
||||||
if (primitive != nullptr)
|
|
||||||
delete primitive;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
76
thirdparty/oidn/mkl-dnn/src/common/primitive.hpp
vendored
76
thirdparty/oidn/mkl-dnn/src/common/primitive.hpp
vendored
@ -1,76 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef PRIMITIVE_HPP
|
|
||||||
#define PRIMITIVE_HPP
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "nstl.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
#include "primitive_exec_types.hpp"
|
|
||||||
|
|
||||||
/** \brief A pure virtual primitive class
|
|
||||||
*
|
|
||||||
* Primitive contains links to its inputs & outputs, though it does not track
|
|
||||||
* their readiness on execution step.
|
|
||||||
*
|
|
||||||
* @remark @b Rational.
|
|
||||||
* Dependencies are essential through-out the whole MKL-DNN library, so it
|
|
||||||
* makes sense to include them on the very low level. On the other hand,
|
|
||||||
* tracking them should be a task for corresponding essence, like scheduler,
|
|
||||||
* stream or whatever. Primitive itself should know nothing about the
|
|
||||||
* environment it is running in.
|
|
||||||
*
|
|
||||||
* @note
|
|
||||||
* To make user experience better we should provide API which allows
|
|
||||||
* achieving the best (or good enough) performance when creating primitives
|
|
||||||
* in natural order: i.e. from bottom to top for forward pass and from top to
|
|
||||||
* bottom for backward pass. Please consider restriction [1] in Level 0.
|
|
||||||
*/
|
|
||||||
struct mkldnn_primitive: public mkldnn::impl::c_compatible {
|
|
||||||
mkldnn_primitive(const mkldnn::impl::primitive_desc_t *pd)
|
|
||||||
: pd_(pd->clone()) {}
|
|
||||||
virtual ~mkldnn_primitive() { delete pd_; }
|
|
||||||
|
|
||||||
/** returns primitive's engine */
|
|
||||||
mkldnn::impl::engine_t *engine() const { return pd_->engine(); }
|
|
||||||
/** returns primitive's inputs */
|
|
||||||
const mkldnn::impl::primitive_desc_t *pd() const { return pd_; }
|
|
||||||
/** returns primitive's kind */
|
|
||||||
mkldnn::impl::primitive_kind_t kind() const { return pd_->kind(); }
|
|
||||||
|
|
||||||
/** executes primitive with execution context @p ctx */
|
|
||||||
virtual mkldnn::impl::status_t execute(const mkldnn::impl::exec_ctx_t &ctx)
|
|
||||||
const = 0;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
const mkldnn::impl::primitive_desc_t *pd_;
|
|
||||||
|
|
||||||
private:
|
|
||||||
mkldnn_primitive() = delete;
|
|
||||||
mkldnn_primitive(const mkldnn_primitive &) = delete;
|
|
||||||
mkldnn_primitive(mkldnn_primitive &&) = delete;
|
|
||||||
mkldnn_primitive &operator=(const mkldnn_primitive &) = delete;
|
|
||||||
mkldnn_primitive &operator=(mkldnn_primitive &&) = delete;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
@ -1,290 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2017-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "primitive_attr.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
using namespace mkldnn::impl::utils;
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
status_t scales_t::set(dim_t count, int mask, const float *scales) {
|
|
||||||
cleanup();
|
|
||||||
|
|
||||||
count_ = count;
|
|
||||||
mask_ = mask;
|
|
||||||
|
|
||||||
if (count_ == 1) {
|
|
||||||
scales_ = scales_buf_;
|
|
||||||
utils::array_set(scales_, scales[0], scales_buf_size);
|
|
||||||
} else {
|
|
||||||
scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64);
|
|
||||||
if (scales_ == nullptr)
|
|
||||||
return status::out_of_memory;
|
|
||||||
|
|
||||||
for (dim_t c = 0; c < count_; ++c)
|
|
||||||
scales_[c] = scales[c];
|
|
||||||
}
|
|
||||||
|
|
||||||
return status::success;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t post_ops_t::append_sum(float scale) {
|
|
||||||
if (len_ == capacity)
|
|
||||||
return out_of_memory;
|
|
||||||
|
|
||||||
entry_[len_].kind = primitive_kind::sum;
|
|
||||||
entry_[len_].sum.scale = scale;
|
|
||||||
|
|
||||||
len_++;
|
|
||||||
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha,
|
|
||||||
float beta) {
|
|
||||||
using namespace mkldnn::impl::alg_kind;
|
|
||||||
bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
|
|
||||||
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
|
|
||||||
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic);
|
|
||||||
if (!known_alg)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
if (len_ == capacity)
|
|
||||||
return out_of_memory;
|
|
||||||
|
|
||||||
entry_[len_].kind = primitive_kind::eltwise;
|
|
||||||
entry_[len_].eltwise.scale = scale;
|
|
||||||
entry_[len_].eltwise.alg = alg;
|
|
||||||
entry_[len_].eltwise.alpha = alpha;
|
|
||||||
entry_[len_].eltwise.beta = beta;
|
|
||||||
|
|
||||||
len_++;
|
|
||||||
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t primitive_attr_t::set_scratchpad_mode(
|
|
||||||
scratchpad_mode_t scratchpad_mode) {
|
|
||||||
using namespace mkldnn::impl::scratchpad_mode;
|
|
||||||
|
|
||||||
const bool ok = one_of(scratchpad_mode, library, user);
|
|
||||||
if (!ok)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
scratchpad_mode_ = scratchpad_mode;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) {
|
|
||||||
this->post_ops_ = post_ops;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Public C API */
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) {
|
|
||||||
if (attr == nullptr)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
|
|
||||||
new mkldnn_primitive_attr);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr,
|
|
||||||
const primitive_attr_t *existing_attr) {
|
|
||||||
if (any_null(attr, existing_attr))
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
|
|
||||||
existing_attr->clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) {
|
|
||||||
if (attr)
|
|
||||||
delete attr;
|
|
||||||
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_attr_get_scratchpad_mode(
|
|
||||||
const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) {
|
|
||||||
if (any_null(attr, scratchpad_mode))
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
*scratchpad_mode = attr->scratchpad_mode_;
|
|
||||||
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_attr_set_scratchpad_mode(
|
|
||||||
primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) {
|
|
||||||
if (any_null(attr))
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
return attr->set_scratchpad_mode(scratchpad_mode);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr,
|
|
||||||
dim_t *count, int *mask, const float **scales) {
|
|
||||||
if (any_null(attr, count, mask, scales))
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
*count = attr->output_scales_.count_;
|
|
||||||
*mask = attr->output_scales_.mask_;
|
|
||||||
*scales = attr->output_scales_.scales_;
|
|
||||||
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr,
|
|
||||||
dim_t count, int mask, const float *scales) {
|
|
||||||
bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
|
|
||||||
if (!ok)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
return attr->output_scales_.set(count, mask, scales);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr,
|
|
||||||
const post_ops_t **post_ops) {
|
|
||||||
if (any_null(attr, post_ops))
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
*post_ops = &attr->post_ops_;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr,
|
|
||||||
const post_ops_t *post_ops) {
|
|
||||||
if (any_null(attr, post_ops))
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
return attr->set_post_ops(*post_ops);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_post_ops_create(post_ops_t **post_ops) {
|
|
||||||
if (post_ops == nullptr)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
return safe_ptr_assign<mkldnn_post_ops>(*post_ops, new mkldnn_post_ops);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) {
|
|
||||||
if (post_ops)
|
|
||||||
delete post_ops;
|
|
||||||
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
int mkldnn_post_ops_len(const post_ops_t *post_ops) {
|
|
||||||
if (post_ops)
|
|
||||||
return post_ops->len_;
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops,
|
|
||||||
int index) {
|
|
||||||
bool ok = post_ops && 0 <= index && index < post_ops->len_;
|
|
||||||
if (!ok)
|
|
||||||
return primitive_kind::undefined;
|
|
||||||
|
|
||||||
return post_ops->entry_[index].kind;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) {
|
|
||||||
if (post_ops == nullptr)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
return post_ops->append_sum(scale);
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
bool simple_get_params_check(const post_ops_t *post_ops, int index,
|
|
||||||
primitive_kind_t kind) {
|
|
||||||
bool ok = true
|
|
||||||
&& post_ops != nullptr
|
|
||||||
&& 0 <= index
|
|
||||||
&& index < post_ops->len_
|
|
||||||
&& post_ops->entry_[index].kind == kind;
|
|
||||||
return ok;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index,
|
|
||||||
float *scale) {
|
|
||||||
bool ok = true
|
|
||||||
&& simple_get_params_check(post_ops, index, primitive_kind::sum)
|
|
||||||
&& !any_null(scale);
|
|
||||||
if (!ok)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
*scale = post_ops->entry_[index].sum.scale;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale,
|
|
||||||
alg_kind_t kind, float alpha, float beta) {
|
|
||||||
if (post_ops == nullptr)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
return post_ops->append_eltwise(scale, kind, alpha, beta);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops,
|
|
||||||
int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) {
|
|
||||||
bool ok = true
|
|
||||||
&& simple_get_params_check(post_ops, index, primitive_kind::eltwise)
|
|
||||||
&& !any_null(scale, alpha, beta);
|
|
||||||
if (!ok)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
const auto &e = post_ops->entry_[index].eltwise;
|
|
||||||
*scale = e.scale;
|
|
||||||
*alg = e.alg;
|
|
||||||
*alpha = e.alpha;
|
|
||||||
*beta = e.beta;
|
|
||||||
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_attr_set_rnn_data_qparams(
|
|
||||||
primitive_attr_t *attr, const float scale, const float shift) {
|
|
||||||
if (attr == nullptr)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
return attr->rnn_data_qparams_.set(scale, shift);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_attr_set_rnn_weights_qparams(
|
|
||||||
primitive_attr_t *attr, dim_t count, int mask, const float *scales) {
|
|
||||||
bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
|
|
||||||
if (!ok)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
return attr->rnn_weights_qparams_.set(count, mask, scales);
|
|
||||||
}
|
|
@ -1,183 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2017-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef PRIMITIVE_ATTR_HPP
|
|
||||||
#define PRIMITIVE_ATTR_HPP
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "nstl.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
struct rnn_data_qparams_t : public c_compatible {
|
|
||||||
rnn_data_qparams_t() : scale_(1.), shift_(0.) {}
|
|
||||||
bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); }
|
|
||||||
|
|
||||||
status_t set(float scale, float shift) {
|
|
||||||
scale_ = scale;
|
|
||||||
shift_ = shift;
|
|
||||||
return status::success;
|
|
||||||
}
|
|
||||||
|
|
||||||
float scale_;
|
|
||||||
float shift_;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct scales_t: public c_compatible {
|
|
||||||
scales_t(): count_(1), mask_(0), scales_(scales_buf_)
|
|
||||||
{ set(1.); }
|
|
||||||
|
|
||||||
scales_t(const scales_t &rhs): scales_t()
|
|
||||||
{ set(rhs.count_, rhs.mask_, rhs.scales_); }
|
|
||||||
|
|
||||||
~scales_t() { cleanup(); }
|
|
||||||
|
|
||||||
scales_t &operator=(const scales_t &rhs) {
|
|
||||||
if (&rhs == this)
|
|
||||||
return *this;
|
|
||||||
status_t status = set(rhs.count_, rhs.mask_, rhs.scales_);
|
|
||||||
assert(status == status::success);
|
|
||||||
(void)status;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool has_default_values() const {
|
|
||||||
for (dim_t c = 0; c < count_; ++c) {
|
|
||||||
if(scales_[c] != 1.) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t set(dim_t count, int mask, const float *scales);
|
|
||||||
status_t set(float single_scale) { return this->set(1, 0, &single_scale); }
|
|
||||||
|
|
||||||
dim_t count_;
|
|
||||||
int mask_;
|
|
||||||
float *scales_;
|
|
||||||
|
|
||||||
private:
|
|
||||||
enum { scales_buf_size = 16 };
|
|
||||||
float scales_buf_[scales_buf_size];
|
|
||||||
|
|
||||||
void cleanup() {
|
|
||||||
if (scales_ != scales_buf_ && scales_ != nullptr)
|
|
||||||
impl::free(scales_);
|
|
||||||
|
|
||||||
count_ = 1;
|
|
||||||
mask_ = 0;
|
|
||||||
scales_ = scales_buf_;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
|
|
||||||
struct entry_t {
|
|
||||||
struct eltwise_t {
|
|
||||||
mkldnn::impl::alg_kind_t alg;
|
|
||||||
float scale, alpha, beta;
|
|
||||||
};
|
|
||||||
|
|
||||||
mkldnn::impl::primitive_kind_t kind;
|
|
||||||
union {
|
|
||||||
struct { float scale; } sum;
|
|
||||||
eltwise_t eltwise;
|
|
||||||
};
|
|
||||||
|
|
||||||
bool is_eltwise(bool require_scale_one = true) const {
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
return kind == primitive_kind::eltwise
|
|
||||||
&& IMPLICATION(require_scale_one, eltwise.scale == 1.f);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_relu(bool require_scale_one = true,
|
|
||||||
bool require_nslope_zero = true) const {
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
return is_eltwise(require_scale_one)
|
|
||||||
&& eltwise.alg == alg_kind::eltwise_relu
|
|
||||||
&& IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_sum(bool require_scale_one = true) const {
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
return kind == primitive_kind::sum
|
|
||||||
&& IMPLICATION(require_scale_one, sum.scale == 1.f);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
mkldnn_post_ops(): len_(0) {}
|
|
||||||
|
|
||||||
mkldnn::impl::status_t append_sum(float scale);
|
|
||||||
mkldnn::impl::status_t append_eltwise(float scale,
|
|
||||||
mkldnn::impl::alg_kind_t alg, float alpha, float beta);
|
|
||||||
|
|
||||||
int find(mkldnn::impl::primitive_kind_t kind, int start = 0,
|
|
||||||
int stop = -1) const {
|
|
||||||
if (stop == -1) stop = len_;
|
|
||||||
stop = mkldnn::impl::nstl::min(stop, len_);
|
|
||||||
for (int idx = start; idx < stop; ++idx)
|
|
||||||
if (entry_[idx].kind == kind) return idx;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool has_default_values() const { return len_ == 0; }
|
|
||||||
|
|
||||||
bool contain(mkldnn::impl::primitive_kind_t kind, int index) const
|
|
||||||
{ return find(kind, index, index + 1) == index; }
|
|
||||||
|
|
||||||
enum { capacity = 4 };
|
|
||||||
|
|
||||||
int len_;
|
|
||||||
entry_t entry_[capacity];
|
|
||||||
};
|
|
||||||
|
|
||||||
struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible {
|
|
||||||
mkldnn_primitive_attr()
|
|
||||||
: scratchpad_mode_(mkldnn::impl::scratchpad_mode::library)
|
|
||||||
{}
|
|
||||||
|
|
||||||
mkldnn_primitive_attr *clone() const
|
|
||||||
{ return new mkldnn_primitive_attr(*this); }
|
|
||||||
|
|
||||||
/** Returns true if the attributes have default values.
|
|
||||||
*
|
|
||||||
* @note The scratchpad_mode_ is not take into account */
|
|
||||||
bool has_default_values() const {
|
|
||||||
return true
|
|
||||||
&& output_scales_.has_default_values()
|
|
||||||
&& post_ops_.has_default_values()
|
|
||||||
&& rnn_data_qparams_.has_default_values()
|
|
||||||
&& rnn_weights_qparams_.has_default_values();
|
|
||||||
}
|
|
||||||
|
|
||||||
mkldnn::impl::status_t set_scratchpad_mode(
|
|
||||||
mkldnn::impl::scratchpad_mode_t scratchpad_mode);
|
|
||||||
mkldnn::impl::status_t set_post_ops(
|
|
||||||
const mkldnn::impl::post_ops_t &post_ops);
|
|
||||||
|
|
||||||
mkldnn::impl::scratchpad_mode_t scratchpad_mode_;
|
|
||||||
mkldnn::impl::scales_t output_scales_;
|
|
||||||
mkldnn::impl::post_ops_t post_ops_;
|
|
||||||
mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_;
|
|
||||||
mkldnn::impl::scales_t rnn_weights_qparams_;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif
|
|
@ -1,78 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "nstl.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
|
|
||||||
status_t primitive_desc_t::query(query_t what, int idx, void *result) const {
|
|
||||||
auto safe_ret_md = [&](const memory_desc_t *_) {
|
|
||||||
if (_ == nullptr) return not_required;
|
|
||||||
*(const memory_desc_t **)result = _;
|
|
||||||
return success;
|
|
||||||
};
|
|
||||||
|
|
||||||
switch (what) {
|
|
||||||
case query::engine: *(engine_t**)result = engine(); break;
|
|
||||||
case query::primitive_kind: *(primitive_kind_t*)result = kind(); break;
|
|
||||||
|
|
||||||
case query::scratchpad_engine:
|
|
||||||
*(engine_t**)result = scratchpad_engine(); break;
|
|
||||||
|
|
||||||
case query::memory_consumption_s64:
|
|
||||||
*(dim_t *)result = scratchpad_size(scratchpad_mode::library); break;
|
|
||||||
|
|
||||||
case query::op_d:
|
|
||||||
if (idx != 0 || op_desc() == nullptr) return invalid_arguments;
|
|
||||||
*(const_c_op_desc_t *)result
|
|
||||||
= static_cast<const_c_op_desc_t>(op_desc()); break;
|
|
||||||
|
|
||||||
case query::src_md: return safe_ret_md(src_md(idx));
|
|
||||||
case query::diff_src_md: return safe_ret_md(diff_src_md(idx));
|
|
||||||
case query::dst_md: return safe_ret_md(dst_md(idx));
|
|
||||||
case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx));
|
|
||||||
case query::weights_md: return safe_ret_md(weights_md(idx));
|
|
||||||
case query::diff_weights_md: return safe_ret_md(diff_weights_md(idx));
|
|
||||||
case query::workspace_md:
|
|
||||||
if (idx != 0) return status::invalid_arguments;
|
|
||||||
return safe_ret_md(workspace_md(idx));
|
|
||||||
case query::scratchpad_md:
|
|
||||||
if (idx != 0) return status::invalid_arguments;
|
|
||||||
return safe_ret_md(scratchpad_md(idx));
|
|
||||||
|
|
||||||
case query::num_of_inputs_s32: *(int*)result = n_inputs(); break;
|
|
||||||
case query::num_of_outputs_s32: *(int*)result = n_outputs(); break;
|
|
||||||
|
|
||||||
case query::impl_info_str: *(const char **)result = name(); break;
|
|
||||||
|
|
||||||
default: return unimplemented;
|
|
||||||
}
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc,
|
|
||||||
const primitive_attr_t **attr) {
|
|
||||||
if (utils::any_null(primitive_desc, attr))
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
*attr = primitive_desc->attr();
|
|
||||||
return success;
|
|
||||||
}
|
|
@ -1,174 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef PRIMITIVE_DESC_HPP
|
|
||||||
#define PRIMITIVE_DESC_HPP
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "memory_tracking.hpp"
|
|
||||||
#include "nstl.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "primitive_attr.hpp"
|
|
||||||
#include "verbose.hpp"
|
|
||||||
|
|
||||||
struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible {
|
|
||||||
using md_t = mkldnn::impl::memory_desc_t;
|
|
||||||
|
|
||||||
mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
|
|
||||||
const mkldnn::impl::primitive_attr_t *attr,
|
|
||||||
mkldnn::impl::primitive_kind_t kind)
|
|
||||||
: engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; }
|
|
||||||
|
|
||||||
mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
|
|
||||||
mkldnn::impl::primitive_kind_t kind)
|
|
||||||
: engine_(engine), kind_(kind) { info_[0] = '\0'; }
|
|
||||||
|
|
||||||
virtual mkldnn_primitive_desc *clone() const = 0;
|
|
||||||
virtual ~mkldnn_primitive_desc() {}
|
|
||||||
|
|
||||||
const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; }
|
|
||||||
mkldnn::impl::engine_t *engine() const { return engine_; }
|
|
||||||
mkldnn::impl::primitive_kind_t kind() const { return kind_; }
|
|
||||||
|
|
||||||
virtual void init_info() {}
|
|
||||||
const char *info() const { return info_; }
|
|
||||||
|
|
||||||
mkldnn::impl::memory_tracking::registry_t &scratchpad_registry()
|
|
||||||
{ return scratchpad_registry_; }
|
|
||||||
const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const
|
|
||||||
{ return scratchpad_registry_; }
|
|
||||||
virtual mkldnn::impl::engine_t *scratchpad_engine() const
|
|
||||||
{ return engine_; }
|
|
||||||
|
|
||||||
virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; }
|
|
||||||
|
|
||||||
enum class arg_usage_t { unused, input, output };
|
|
||||||
virtual arg_usage_t arg_usage(
|
|
||||||
mkldnn::impl::primitive_arg_index_t arg) const {
|
|
||||||
using mkldnn::impl::types::is_zero_md;
|
|
||||||
if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md()))
|
|
||||||
return arg_usage_t::output;
|
|
||||||
return arg_usage_t::unused;
|
|
||||||
}
|
|
||||||
|
|
||||||
# define DECLARE_MD_STUB(stub) \
|
|
||||||
virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \
|
|
||||||
{ return nullptr; }
|
|
||||||
|
|
||||||
DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md);
|
|
||||||
DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md);
|
|
||||||
DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md);
|
|
||||||
DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md);
|
|
||||||
DECLARE_MD_STUB(workspace_md);
|
|
||||||
# undef DECLARE_MD_STUB
|
|
||||||
|
|
||||||
const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const {
|
|
||||||
return idx == 0 ? &scratchpad_md_ : nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual void init_scratchpad_md() {
|
|
||||||
auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user);
|
|
||||||
mkldnn::impl::dims_t dims = { size };
|
|
||||||
mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims,
|
|
||||||
mkldnn::impl::data_type::u8, mkldnn_x);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** returns the scratchpad size for the given scratchpad mode. */
|
|
||||||
mkldnn::impl::dim_t scratchpad_size(
|
|
||||||
mkldnn::impl::scratchpad_mode_t mode) const {
|
|
||||||
if (mode != attr_.scratchpad_mode_) return 0;
|
|
||||||
return scratchpad_registry().size();
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual int n_inputs() const { return 0; }
|
|
||||||
virtual int n_outputs() const { return 0; }
|
|
||||||
|
|
||||||
virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx,
|
|
||||||
void *result) const;
|
|
||||||
|
|
||||||
virtual mkldnn::impl::status_t create_primitive(
|
|
||||||
mkldnn::impl::primitive_t **primitive) const = 0;
|
|
||||||
|
|
||||||
virtual const char *name() const { return "mkldnn_primitive_desc"; }
|
|
||||||
|
|
||||||
/* static magic */
|
|
||||||
|
|
||||||
template<typename pd_t>
|
|
||||||
static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd,
|
|
||||||
const mkldnn::impl::op_desc_t *adesc,
|
|
||||||
const mkldnn::impl::primitive_attr_t *attr,
|
|
||||||
mkldnn::impl::engine_t *engine,
|
|
||||||
const mkldnn::impl::primitive_desc_t *hint_fwd) {
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type;
|
|
||||||
if (adesc->kind != pd_t::base_pkind) return invalid_arguments;
|
|
||||||
assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true);
|
|
||||||
auto hint =
|
|
||||||
reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd);
|
|
||||||
auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint);
|
|
||||||
if (_pd == nullptr) return out_of_memory;
|
|
||||||
if (_pd->init() != success) { delete _pd; return unimplemented; }
|
|
||||||
_pd->init_info();
|
|
||||||
_pd->init_scratchpad_md();
|
|
||||||
*pd = _pd;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
mkldnn::impl::engine_t *engine_;
|
|
||||||
mkldnn::impl::primitive_attr_t attr_;
|
|
||||||
mkldnn::impl::primitive_kind_t kind_;
|
|
||||||
|
|
||||||
mkldnn::impl::memory_desc_t scratchpad_md_;
|
|
||||||
|
|
||||||
char info_[MKLDNN_VERBOSE_BUF_LEN];
|
|
||||||
|
|
||||||
mkldnn::impl::memory_tracking::registry_t scratchpad_registry_;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
/** compares ws between fwd_pd and this (make sense to use for bwd_pd)
|
|
||||||
* Expectation: this already set workspace, and this workspace should
|
|
||||||
* exactly match the one from fwd_pd */
|
|
||||||
bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const {
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
if (!workspace_md()) return true; // the impl lives fine w/o workspace
|
|
||||||
return fwd_pd && fwd_pd->workspace_md()
|
|
||||||
&& *fwd_pd->workspace_md() == *workspace_md();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
#define DECLARE_COMMON_PD_t(impl_name, ...) \
|
|
||||||
virtual pd_t *clone() const override { return new pd_t(*this); } \
|
|
||||||
virtual status_t create_primitive(primitive_t **p) const override { \
|
|
||||||
double ms = get_msec(); \
|
|
||||||
auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
|
|
||||||
ms = get_msec() - ms; \
|
|
||||||
if (mkldnn_verbose()->level >= 2) { \
|
|
||||||
printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
|
|
||||||
fflush(0); \
|
|
||||||
} \
|
|
||||||
return ret; \
|
|
||||||
} \
|
|
||||||
virtual const char *name() const override { return impl_name; }
|
|
||||||
#define DECLARE_COMMON_PD_T(impl_name, ...) \
|
|
||||||
DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__)
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
@ -1,90 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include "memory.hpp"
|
|
||||||
#include "primitive.hpp"
|
|
||||||
#include "primitive_exec_types.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
|
|
||||||
const mkldnn_exec_arg_t *c_args, exec_args_t &args) {
|
|
||||||
using namespace status;
|
|
||||||
|
|
||||||
if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments;
|
|
||||||
|
|
||||||
int n_inputs = 0;
|
|
||||||
int n_outputs = 0;
|
|
||||||
|
|
||||||
for (int i = 0; i < nargs; ++i) {
|
|
||||||
primitive_arg_index_t arg = c_args[i].arg;
|
|
||||||
auto *mem = c_args[i].memory;
|
|
||||||
|
|
||||||
switch (pd->arg_usage(arg)) {
|
|
||||||
case primitive_desc_t::arg_usage_t::input:
|
|
||||||
if (args.count(arg) != 0) return invalid_arguments;
|
|
||||||
args[arg] = {mem, true};
|
|
||||||
n_inputs++;
|
|
||||||
break;
|
|
||||||
case primitive_desc_t::arg_usage_t::output:
|
|
||||||
if (args.count(arg) != 0) return invalid_arguments;
|
|
||||||
args[arg] = {mem, false};
|
|
||||||
n_outputs++;
|
|
||||||
break;
|
|
||||||
case primitive_desc_t::arg_usage_t::unused:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool scratchpad_required = !types::is_zero_md(pd->scratchpad_md());
|
|
||||||
|
|
||||||
if (n_inputs != pd->n_inputs()) return invalid_arguments;
|
|
||||||
if (n_outputs != pd->n_outputs() + (scratchpad_required ? 1 : 0))
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
const void *exec_ctx_t::input(primitive_arg_index_t arg) const {
|
|
||||||
if (args_.count(arg) != 1) return nullptr;
|
|
||||||
const auto ma = args_.at(arg);
|
|
||||||
assert(ma.is_const);
|
|
||||||
void *ptr;
|
|
||||||
status_t status = ma.mem->get_data_handle(&ptr);
|
|
||||||
assert(status == status::success); MAYBE_UNUSED(status);
|
|
||||||
return ptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
void *exec_ctx_t::output(primitive_arg_index_t arg) const {
|
|
||||||
if (args_.count(arg) != 1) return nullptr;
|
|
||||||
const auto ma = args_.at(arg);
|
|
||||||
assert(!ma.is_const);
|
|
||||||
void *ptr;
|
|
||||||
status_t status = ma.mem->get_data_handle(&ptr);
|
|
||||||
assert(status == status::success); MAYBE_UNUSED(status);
|
|
||||||
return ptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
const memory_t *exec_ctx_t::memory(primitive_arg_index_t arg) const {
|
|
||||||
assert(args_.count(arg) == 1);
|
|
||||||
const auto ma = args_.at(arg);
|
|
||||||
assert(!ma.is_const);
|
|
||||||
return ma.mem;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,68 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef PRIMITIVE_EXEC_TYPES_HPP
|
|
||||||
#define PRIMITIVE_EXEC_TYPES_HPP
|
|
||||||
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include "mkldnn_types.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "memory.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
struct memory_arg_t {
|
|
||||||
memory_t *mem;
|
|
||||||
bool is_const;
|
|
||||||
};
|
|
||||||
|
|
||||||
using exec_args_t = std::unordered_map<primitive_arg_index_t, memory_arg_t>;
|
|
||||||
|
|
||||||
status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
|
|
||||||
const mkldnn_exec_arg_t *c_args, exec_args_t &args);
|
|
||||||
|
|
||||||
/** Primitive execution context (helps passing stream, memories, and events. */
|
|
||||||
struct exec_ctx_t {
|
|
||||||
exec_ctx_t(const exec_ctx_t &) = default;
|
|
||||||
exec_ctx_t(exec_ctx_t &&) = default;
|
|
||||||
|
|
||||||
exec_ctx_t(stream_t *stream): stream_(stream) {}
|
|
||||||
exec_ctx_t(stream_t *stream, exec_args_t &&args)
|
|
||||||
: stream_(stream)
|
|
||||||
, args_(std::move(args)) {}
|
|
||||||
|
|
||||||
stream_t *stream() const { return stream_; }
|
|
||||||
const exec_args_t &args() const { return args_; }
|
|
||||||
|
|
||||||
/* tentative solution... TODO: replace with functions return memory_t */
|
|
||||||
const void *input(primitive_arg_index_t arg) const;
|
|
||||||
void *output(primitive_arg_index_t arg) const;
|
|
||||||
|
|
||||||
const memory_t *memory(primitive_arg_index_t arg) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
stream_t *stream_;
|
|
||||||
exec_args_t args_;
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
@ -1,89 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "engine.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "primitive_iterator.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_desc_iterator_create(
|
|
||||||
primitive_desc_iterator_t **iterator, const_c_op_desc_t c_op_desc,
|
|
||||||
const primitive_attr_t *attr, engine_t *engine,
|
|
||||||
const primitive_desc_t *hint_fwd_pd) {
|
|
||||||
const op_desc_t *op_desc = (const op_desc_t *)c_op_desc;
|
|
||||||
|
|
||||||
auto it = new primitive_desc_iterator_t(engine, op_desc, attr, hint_fwd_pd);
|
|
||||||
if (it == nullptr) return out_of_memory;
|
|
||||||
|
|
||||||
++(*it);
|
|
||||||
if (*it == it->end()) {
|
|
||||||
delete it;
|
|
||||||
return unimplemented;
|
|
||||||
}
|
|
||||||
|
|
||||||
*iterator = it;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_desc_iterator_next(
|
|
||||||
primitive_desc_iterator_t *iterator) {
|
|
||||||
if (iterator == nullptr) return invalid_arguments;
|
|
||||||
++(*iterator);
|
|
||||||
return *iterator == iterator->end() ? iterator_ends : success;
|
|
||||||
}
|
|
||||||
|
|
||||||
primitive_desc_t *mkldnn_primitive_desc_iterator_fetch(
|
|
||||||
const primitive_desc_iterator_t *iterator) {
|
|
||||||
if (iterator == nullptr) return nullptr;
|
|
||||||
return *(*iterator);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_desc_clone(primitive_desc_t **primitive_desc,
|
|
||||||
const primitive_desc_t *existing_primitive_desc) {
|
|
||||||
if (utils::any_null(primitive_desc, existing_primitive_desc))
|
|
||||||
return invalid_arguments;
|
|
||||||
return safe_ptr_assign<primitive_desc_t>(*primitive_desc,
|
|
||||||
existing_primitive_desc->clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_desc_iterator_destroy(
|
|
||||||
primitive_desc_iterator_t *iterator) {
|
|
||||||
if (iterator != nullptr)
|
|
||||||
delete iterator;
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_desc_create(primitive_desc_t **primitive_desc,
|
|
||||||
const_c_op_desc_t c_op_desc, const primitive_attr_t *attr,
|
|
||||||
engine_t *engine, const primitive_desc_t *hint_fwd_pd) {
|
|
||||||
const op_desc_t *op_desc = (const op_desc_t *)c_op_desc;
|
|
||||||
|
|
||||||
mkldnn_primitive_desc_iterator it(engine, op_desc, attr, hint_fwd_pd);
|
|
||||||
++it;
|
|
||||||
if (it == it.end()) return unimplemented;
|
|
||||||
|
|
||||||
return safe_ptr_assign<primitive_desc_t>(*primitive_desc, *it);
|
|
||||||
}
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
@ -1,79 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
#ifndef PRIMITIVE_ITERATOR_HPP
|
|
||||||
#define PRIMITIVE_ITERATOR_HPP
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "engine.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
|
|
||||||
struct mkldnn_primitive_desc_iterator: public mkldnn::impl::c_compatible {
|
|
||||||
using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f;
|
|
||||||
|
|
||||||
mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, const mkldnn::impl::op_desc_t *op_desc,
|
|
||||||
const mkldnn::impl::primitive_attr_t *attr, const mkldnn::impl::primitive_desc_t *hint_fwd_pd)
|
|
||||||
: idx_(-1), engine_(engine), pd_(nullptr), op_desc_(op_desc)
|
|
||||||
, attr_(attr ? *attr : mkldnn::impl::primitive_attr_t()), hint_fwd_pd_(hint_fwd_pd)
|
|
||||||
, impl_list_(engine_->get_implementation_list()), last_idx_(0)
|
|
||||||
{
|
|
||||||
while (impl_list_[last_idx_] != nullptr) ++last_idx_;
|
|
||||||
}
|
|
||||||
~mkldnn_primitive_desc_iterator() { if (pd_) delete pd_; }
|
|
||||||
|
|
||||||
bool operator==(const mkldnn::impl::primitive_desc_iterator_t& rhs) const
|
|
||||||
{ return idx_ == rhs.idx_ && engine_ == rhs.engine_; }
|
|
||||||
bool operator!=(const mkldnn::impl::primitive_desc_iterator_t& rhs) const
|
|
||||||
{ return !operator==(rhs); }
|
|
||||||
|
|
||||||
mkldnn::impl::primitive_desc_iterator_t end() const
|
|
||||||
{ return mkldnn_primitive_desc_iterator(engine_, last_idx_); }
|
|
||||||
|
|
||||||
mkldnn::impl::primitive_desc_iterator_t &operator++() {
|
|
||||||
if (pd_) { delete pd_; pd_ = nullptr; }
|
|
||||||
while (++idx_ != last_idx_) {
|
|
||||||
auto s = impl_list_[idx_](&pd_, op_desc_, &attr_, engine_,
|
|
||||||
hint_fwd_pd_);
|
|
||||||
if (s == mkldnn::impl::status::success) break;
|
|
||||||
}
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
mkldnn::impl::primitive_desc_t *operator*() const {
|
|
||||||
if (*this == end() || pd_ == nullptr) return nullptr;
|
|
||||||
return pd_->clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
int idx_;
|
|
||||||
mkldnn::impl::engine_t *engine_;
|
|
||||||
mkldnn::impl::primitive_desc_t *pd_;
|
|
||||||
const mkldnn::impl::op_desc_t *op_desc_;
|
|
||||||
const mkldnn::impl::primitive_attr_t attr_;
|
|
||||||
const mkldnn::impl::primitive_desc_t *hint_fwd_pd_;
|
|
||||||
const pd_create_f *impl_list_;
|
|
||||||
int last_idx_;
|
|
||||||
|
|
||||||
private:
|
|
||||||
mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, int last_idx)
|
|
||||||
: idx_(last_idx), engine_(engine), pd_(nullptr)
|
|
||||||
, op_desc_(nullptr), hint_fwd_pd_(nullptr)
|
|
||||||
, impl_list_(nullptr), last_idx_(last_idx) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif
|
|
59
thirdparty/oidn/mkl-dnn/src/common/query.cpp
vendored
59
thirdparty/oidn/mkl-dnn/src/common/query.cpp
vendored
@ -1,59 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "engine.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::utils;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
|
|
||||||
status_t mkldnn_primitive_desc_query(const primitive_desc_t *primitive_desc,
|
|
||||||
query_t what, int index, void *result) {
|
|
||||||
if (any_null(primitive_desc, result))
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
return primitive_desc->query(what, index, result);
|
|
||||||
}
|
|
||||||
|
|
||||||
const memory_desc_t *mkldnn_primitive_desc_query_md(
|
|
||||||
const primitive_desc_t *primitive_desc, query_t what, int index) {
|
|
||||||
const memory_desc_t *res_md = nullptr;
|
|
||||||
bool args_ok = true
|
|
||||||
&& primitive_desc != nullptr
|
|
||||||
&& (what & query::some_md) == query::some_md
|
|
||||||
&& what != query::some_md
|
|
||||||
&& mkldnn_primitive_desc_query(primitive_desc,
|
|
||||||
what, index, &res_md) == success;
|
|
||||||
return args_ok ? res_md : nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
int mkldnn_primitive_desc_query_s32(const primitive_desc_t *primitive_desc,
|
|
||||||
query_t what, int index) {
|
|
||||||
int res_s32;
|
|
||||||
bool args_ok = primitive_desc != nullptr
|
|
||||||
&& one_of(what, query::num_of_inputs_s32, query::num_of_outputs_s32)
|
|
||||||
&& mkldnn_primitive_desc_query(primitive_desc, what, index, &res_s32)
|
|
||||||
== success;
|
|
||||||
return args_ok ? res_s32 : 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
68
thirdparty/oidn/mkl-dnn/src/common/reorder.cpp
vendored
68
thirdparty/oidn/mkl-dnn/src/common/reorder.cpp
vendored
@ -1,68 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "engine.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
#include "reorder_pd.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::utils;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
|
|
||||||
status_t mkldnn_reorder_primitive_desc_create(
|
|
||||||
primitive_desc_t **reorder_pd,
|
|
||||||
engine_t *src_engine, const memory_desc_t *src_md,
|
|
||||||
engine_t *dst_engine, const memory_desc_t *dst_md,
|
|
||||||
const primitive_attr_t *attr) {
|
|
||||||
if (any_null(reorder_pd, src_engine, src_md, dst_engine, dst_md))
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
auto s_ek = src_engine->kind();
|
|
||||||
auto d_ek = dst_engine->kind();
|
|
||||||
if (!IMPLICATION(s_ek != d_ek, one_of(engine_kind::cpu, s_ek, d_ek)))
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
auto r_pd = reinterpret_cast<reorder_pd_t **>(reorder_pd);
|
|
||||||
auto s_mdw = memory_desc_wrapper(*src_md);
|
|
||||||
auto d_mdw = memory_desc_wrapper(*dst_md);
|
|
||||||
|
|
||||||
if (!s_mdw.consistent_with(d_mdw))
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
auto e = (s_ek != engine_kind::cpu) ? src_engine : dst_engine;
|
|
||||||
|
|
||||||
const primitive_attr_t dummy_attr;
|
|
||||||
if (attr == NULL)
|
|
||||||
attr = &dummy_attr;
|
|
||||||
|
|
||||||
for (auto r = e->get_reorder_implementation_list(); *r; ++r) {
|
|
||||||
if ((*r)(r_pd, e, attr, src_engine, src_md, dst_engine, dst_md)
|
|
||||||
== success) {
|
|
||||||
(*r_pd)->init_info();
|
|
||||||
(*r_pd)->init_scratchpad_md();
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return unimplemented;
|
|
||||||
}
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
@ -1,85 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2016-2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef REORDER_PD_HPP
|
|
||||||
#define REORDER_PD_HPP
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "primitive_attr.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
struct reorder_pd_t: public primitive_desc_t {
|
|
||||||
reorder_pd_t(engine_t *engine, const primitive_attr_t *attr,
|
|
||||||
engine_t *src_engine, const memory_desc_t *src_md,
|
|
||||||
engine_t *dst_engine, const memory_desc_t *dst_md)
|
|
||||||
: primitive_desc_t(engine, attr, primitive_kind::reorder)
|
|
||||||
, src_engine_(src_engine)
|
|
||||||
, dst_engine_(dst_engine)
|
|
||||||
, scratchpad_engine_(nullptr)
|
|
||||||
, src_md_(*src_md)
|
|
||||||
, dst_md_(*dst_md)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual const op_desc_t *op_desc() const override { return nullptr; }
|
|
||||||
virtual void init_info() override { impl::init_info(this, this->info_); }
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (arg == MKLDNN_ARG_FROM)
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_TO)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &src_md_ : nullptr; }
|
|
||||||
virtual const memory_desc_t *dst_md(int index = 0) const override
|
|
||||||
{ return index == 0 ? &dst_md_ : nullptr; }
|
|
||||||
|
|
||||||
virtual int n_inputs() const override { return 1; }
|
|
||||||
virtual int n_outputs() const override { return 1; }
|
|
||||||
|
|
||||||
float alpha() const { return attr()->output_scales_.scales_[0]; }
|
|
||||||
float beta() const {
|
|
||||||
const int sum_idx = attr()->post_ops_.find(primitive_kind::sum);
|
|
||||||
return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale;
|
|
||||||
}
|
|
||||||
virtual mkldnn::impl::engine_t *scratchpad_engine() const override
|
|
||||||
{ return scratchpad_engine_; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
engine_t *src_engine_;
|
|
||||||
engine_t *dst_engine_;
|
|
||||||
engine_t *scratchpad_engine_;
|
|
||||||
|
|
||||||
memory_desc_t src_md_;
|
|
||||||
memory_desc_t dst_md_;
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
|
400
thirdparty/oidn/mkl-dnn/src/common/rnn.cpp
vendored
400
thirdparty/oidn/mkl-dnn/src/common/rnn.cpp
vendored
@ -1,400 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
#include "cpu/gemm/os_blas.hpp"
|
|
||||||
|
|
||||||
using namespace mkldnn::impl;
|
|
||||||
using namespace mkldnn::impl::status;
|
|
||||||
using namespace mkldnn::impl::types;
|
|
||||||
using namespace mkldnn::impl::utils;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
memory_desc_t copy_maybe_null(const memory_desc_t *md) {
|
|
||||||
return md ? *md : zero_md();
|
|
||||||
}
|
|
||||||
|
|
||||||
rnn_desc_t zero_rnn_desc() {
|
|
||||||
auto rd = rnn_desc_t();
|
|
||||||
rd.src_layer_desc = zero_md();
|
|
||||||
rd.src_iter_desc = zero_md();
|
|
||||||
rd.weights_layer_desc = zero_md();
|
|
||||||
rd.weights_iter_desc = zero_md();
|
|
||||||
rd.bias_desc = zero_md();
|
|
||||||
rd.dst_layer_desc = zero_md();
|
|
||||||
rd.dst_iter_desc = zero_md();
|
|
||||||
rd.diff_src_layer_desc = zero_md();
|
|
||||||
rd.diff_src_iter_desc = zero_md();
|
|
||||||
rd.diff_weights_layer_desc = zero_md();
|
|
||||||
rd.diff_weights_iter_desc = zero_md();
|
|
||||||
rd.diff_bias_desc = zero_md();
|
|
||||||
rd.diff_dst_layer_desc = zero_md();
|
|
||||||
rd.diff_dst_iter_desc = zero_md();
|
|
||||||
return rd;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Public C Api */
|
|
||||||
|
|
||||||
status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc,
|
|
||||||
mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f,
|
|
||||||
unsigned int flags, float alpha, float clipping) {
|
|
||||||
using namespace mkldnn::impl::alg_kind;
|
|
||||||
|
|
||||||
bool args_ok = true
|
|
||||||
&& one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru,
|
|
||||||
gru_linear_before_reset)
|
|
||||||
&& IMPLICATION(cell_kind == vanilla_rnn,
|
|
||||||
one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic));
|
|
||||||
if (!args_ok)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
auto rcd = mkldnn_rnn_cell_desc_t();
|
|
||||||
|
|
||||||
rcd.cell_kind = cell_kind;
|
|
||||||
rcd.activation_kind = act_f;
|
|
||||||
rcd.flags = flags;
|
|
||||||
rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0;
|
|
||||||
rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0;
|
|
||||||
|
|
||||||
*rnn_cell_desc = rcd;
|
|
||||||
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) {
|
|
||||||
switch (rnn_cell_desc->cell_kind) {
|
|
||||||
case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
|
|
||||||
case mkldnn::impl::alg_kind::vanilla_gru: return 3;
|
|
||||||
case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3;
|
|
||||||
case mkldnn::impl::alg_kind::vanilla_lstm: return 4;
|
|
||||||
default: assert(!"unknown cell kind"); return 0;
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) {
|
|
||||||
switch (rnn_cell_desc->cell_kind) {
|
|
||||||
case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
|
|
||||||
case mkldnn::impl::alg_kind::vanilla_gru: return 1;
|
|
||||||
case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1;
|
|
||||||
case mkldnn::impl::alg_kind::vanilla_lstm: return 2;
|
|
||||||
default: assert(!"unknown cell kind"); return 0;
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc,
|
|
||||||
prop_kind_t prop_kind, const memory_desc_t *src_layer_desc,
|
|
||||||
const memory_desc_t *src_iter_desc,
|
|
||||||
const memory_desc_t *weights_layer_desc,
|
|
||||||
const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
|
|
||||||
const memory_desc_t *dst_layer_desc,
|
|
||||||
const memory_desc_t *dst_iter_desc) {
|
|
||||||
using namespace data_type;
|
|
||||||
data_type_t src_layer_dt = src_layer_desc->data_type;
|
|
||||||
data_type_t dst_layer_dt = dst_layer_desc->data_type;
|
|
||||||
data_type_t weights_iter_dt = weights_iter_desc->data_type;
|
|
||||||
data_type_t weights_layer_dt = weights_layer_desc->data_type;
|
|
||||||
|
|
||||||
bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt,
|
|
||||||
weights_layer_dt)
|
|
||||||
&& IMPLICATION(!is_zero_md(src_iter_desc),
|
|
||||||
src_iter_desc->data_type == f32)
|
|
||||||
&& IMPLICATION(!is_zero_md(dst_iter_desc),
|
|
||||||
dst_iter_desc->data_type == f32)
|
|
||||||
&& IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
|
|
||||||
|
|
||||||
#if USE_MKL_PACKED_GEMM
|
|
||||||
bool is_u8u8u8 = src_layer_dt == u8
|
|
||||||
&& IMPLICATION(!is_zero_md(src_iter_desc),
|
|
||||||
src_iter_desc->data_type == u8)
|
|
||||||
&& IMPLICATION(!is_zero_md(dst_iter_desc),
|
|
||||||
dst_iter_desc->data_type == u8)
|
|
||||||
&& one_of(dst_layer_dt, u8, f32)
|
|
||||||
&& everyone_is(s8, weights_iter_dt, weights_layer_dt)
|
|
||||||
&& IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
|
|
||||||
|
|
||||||
bool is_f32u8f32 = src_layer_dt == u8
|
|
||||||
&& IMPLICATION(!is_zero_md(src_iter_desc),
|
|
||||||
src_iter_desc->data_type == f32)
|
|
||||||
&& IMPLICATION(!is_zero_md(dst_iter_desc),
|
|
||||||
dst_iter_desc->data_type == f32)
|
|
||||||
&& one_of(dst_layer_dt, u8, f32)
|
|
||||||
&& everyone_is(s8, weights_iter_dt, weights_layer_dt)
|
|
||||||
&& IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
|
|
||||||
|
|
||||||
bool is_inference = prop_kind == prop_kind::forward_inference;
|
|
||||||
bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm;
|
|
||||||
|
|
||||||
return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference))
|
|
||||||
? success
|
|
||||||
: unimplemented;
|
|
||||||
#else
|
|
||||||
return is_f32 ? success : unimplemented;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc,
|
|
||||||
rnn_direction_t direction, int L, int D, int T, int N, int S, int G,
|
|
||||||
int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc,
|
|
||||||
const memory_desc_t *src_iter_desc,
|
|
||||||
const memory_desc_t *weights_layer_desc,
|
|
||||||
const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
|
|
||||||
const memory_desc_t *dst_layer_desc,
|
|
||||||
const memory_desc_t *dst_iter_desc) {
|
|
||||||
bool args_ok;
|
|
||||||
|
|
||||||
// * algorithm specific
|
|
||||||
args_ok = true
|
|
||||||
&& IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru,
|
|
||||||
DIC == SIC);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
int extra_bias =
|
|
||||||
rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset;
|
|
||||||
|
|
||||||
// * on num layers
|
|
||||||
args_ok = true
|
|
||||||
&& L == weights_layer_desc->dims[0]
|
|
||||||
&& L == weights_iter_desc->dims[0]
|
|
||||||
&& IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0])
|
|
||||||
&& IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0])
|
|
||||||
&& IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
// * on num directions
|
|
||||||
args_ok = true
|
|
||||||
&& D == weights_layer_desc->dims[1]
|
|
||||||
&& D == weights_iter_desc->dims[1]
|
|
||||||
&& IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1])
|
|
||||||
&& IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1])
|
|
||||||
&& IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
// * on num iterations
|
|
||||||
args_ok = true
|
|
||||||
&& T == src_layer_desc->dims[0]
|
|
||||||
&& T == dst_layer_desc->dims[0];
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
// * on mb
|
|
||||||
args_ok = true
|
|
||||||
&& N == src_layer_desc->dims[1]
|
|
||||||
&& N == dst_layer_desc->dims[1]
|
|
||||||
&& IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3])
|
|
||||||
&& IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
// * on num gates
|
|
||||||
args_ok = true
|
|
||||||
&& G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc)
|
|
||||||
&& G == weights_layer_desc->dims[3]
|
|
||||||
&& G == weights_iter_desc->dims[3]
|
|
||||||
&& IMPLICATION(!is_zero_md(bias_desc),
|
|
||||||
G + extra_bias == bias_desc->dims[2]);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
// * on num states
|
|
||||||
args_ok = true
|
|
||||||
&& S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc)
|
|
||||||
&& IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2])
|
|
||||||
&& IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
// * on slc
|
|
||||||
args_ok = true
|
|
||||||
&& SLC == weights_layer_desc->dims[2]
|
|
||||||
&& SLC == src_layer_desc->dims[2];
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
// * on sic
|
|
||||||
args_ok = true
|
|
||||||
&& SIC == weights_iter_desc->dims[2]
|
|
||||||
&& IMPLICATION(!is_zero_md(src_iter_desc),
|
|
||||||
SIC == src_iter_desc->dims[4]);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
// * on dlc
|
|
||||||
int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 2 : 1;
|
|
||||||
args_ok = true
|
|
||||||
&& DLC == dlc_multiplier * DIC
|
|
||||||
&& DLC == dst_layer_desc->dims[2];
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
// * on dic
|
|
||||||
args_ok = true
|
|
||||||
&& DIC == weights_layer_desc->dims[4]
|
|
||||||
&& DIC == weights_iter_desc->dims[4]
|
|
||||||
&& IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3])
|
|
||||||
&& IMPLICATION(!is_zero_md(dst_iter_desc),
|
|
||||||
DIC == dst_iter_desc->dims[4]);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
// * unrolling/fusion conditions
|
|
||||||
args_ok = true
|
|
||||||
&& IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC)
|
|
||||||
&& IMPLICATION(T > 1, SIC == DIC);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
|
|
||||||
prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
|
|
||||||
const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
|
|
||||||
const memory_desc_t *src_iter_desc,
|
|
||||||
const memory_desc_t *weights_layer_desc,
|
|
||||||
const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
|
|
||||||
const memory_desc_t *dst_layer_desc,
|
|
||||||
const memory_desc_t *dst_iter_desc) {
|
|
||||||
bool args_ok = true && rnn_cell_desc != nullptr
|
|
||||||
&& !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
|
|
||||||
dst_layer_desc);
|
|
||||||
if (!args_ok) return invalid_arguments;
|
|
||||||
|
|
||||||
//check dimensions consistency
|
|
||||||
int L = weights_layer_desc->dims[0];
|
|
||||||
int T = src_layer_desc->dims[0];
|
|
||||||
int N = src_layer_desc->dims[1];
|
|
||||||
const int D = one_of(direction, mkldnn_unidirectional_left2right,
|
|
||||||
mkldnn_unidirectional_right2left) ?
|
|
||||||
1 :
|
|
||||||
2;
|
|
||||||
int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
|
|
||||||
int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
|
|
||||||
int SLC = src_layer_desc->dims[2];
|
|
||||||
int SIC = weights_iter_desc->dims[2];
|
|
||||||
int DLC = dst_layer_desc->dims[2];
|
|
||||||
int DIC = weights_layer_desc->dims[4];
|
|
||||||
|
|
||||||
CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
|
|
||||||
G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
|
|
||||||
weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
|
|
||||||
dst_iter_desc));
|
|
||||||
|
|
||||||
CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind,
|
|
||||||
src_layer_desc, src_iter_desc, weights_layer_desc,
|
|
||||||
weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc));
|
|
||||||
|
|
||||||
// Create the descriptor
|
|
||||||
mkldnn_rnn_desc_t rd = zero_rnn_desc();
|
|
||||||
|
|
||||||
rd.primitive_kind = primitive_kind::rnn;
|
|
||||||
rd.prop_kind = prop_kind;
|
|
||||||
rd.cell_desc = *rnn_cell_desc;
|
|
||||||
rd.direction = direction;
|
|
||||||
rd.src_layer_desc = copy_maybe_null(src_layer_desc);
|
|
||||||
rd.src_iter_desc = copy_maybe_null(src_iter_desc);
|
|
||||||
rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
|
|
||||||
rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
|
|
||||||
rd.bias_desc = copy_maybe_null(bias_desc);
|
|
||||||
rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
|
|
||||||
rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
|
|
||||||
|
|
||||||
*rnn_desc = rd;
|
|
||||||
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
|
|
||||||
prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
|
|
||||||
const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
|
|
||||||
const memory_desc_t *src_iter_desc,
|
|
||||||
const memory_desc_t *weights_layer_desc,
|
|
||||||
const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
|
|
||||||
const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
|
|
||||||
const memory_desc_t *diff_src_layer_desc,
|
|
||||||
const memory_desc_t *diff_src_iter_desc,
|
|
||||||
const memory_desc_t *diff_weights_layer_desc,
|
|
||||||
const memory_desc_t *diff_weights_iter_desc,
|
|
||||||
const memory_desc_t *diff_bias_desc,
|
|
||||||
const memory_desc_t *diff_dst_layer_desc,
|
|
||||||
const memory_desc_t *diff_dst_iter_desc) {
|
|
||||||
bool args_ok = true
|
|
||||||
&& !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
|
|
||||||
dst_layer_desc, diff_src_layer_desc,
|
|
||||||
diff_weights_layer_desc, diff_weights_iter_desc,
|
|
||||||
diff_dst_layer_desc);
|
|
||||||
if (!args_ok)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) {
|
|
||||||
return is_zero_md(a_md) == is_zero_md(b_md);
|
|
||||||
};
|
|
||||||
|
|
||||||
args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc)
|
|
||||||
&& xnor_md(dst_iter_desc, diff_dst_iter_desc)
|
|
||||||
&& xnor_md(src_iter_desc, diff_src_iter_desc);
|
|
||||||
if (!args_ok)
|
|
||||||
return invalid_arguments;
|
|
||||||
|
|
||||||
//check dimensions consistency
|
|
||||||
int L = weights_layer_desc->dims[0];
|
|
||||||
int T = src_layer_desc->dims[0];
|
|
||||||
int N = src_layer_desc->dims[1];
|
|
||||||
const int D = one_of(direction, mkldnn_unidirectional_left2right,
|
|
||||||
mkldnn_unidirectional_right2left) ?
|
|
||||||
1 :
|
|
||||||
2;
|
|
||||||
int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
|
|
||||||
int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
|
|
||||||
int SLC = src_layer_desc->dims[2];
|
|
||||||
int SIC = weights_iter_desc->dims[2];
|
|
||||||
int DLC = dst_layer_desc->dims[2];
|
|
||||||
int DIC = weights_layer_desc->dims[4];
|
|
||||||
|
|
||||||
status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
|
|
||||||
G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
|
|
||||||
weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
|
|
||||||
dst_iter_desc);
|
|
||||||
if (st != success) return st;
|
|
||||||
|
|
||||||
st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
|
|
||||||
G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc,
|
|
||||||
diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc,
|
|
||||||
diff_dst_layer_desc, diff_dst_iter_desc);
|
|
||||||
if (st != success) return st;
|
|
||||||
|
|
||||||
mkldnn_rnn_desc_t rd = zero_rnn_desc();
|
|
||||||
|
|
||||||
rd.primitive_kind = primitive_kind::rnn;
|
|
||||||
rd.prop_kind = prop_kind;
|
|
||||||
rd.cell_desc = *rnn_cell_desc;
|
|
||||||
rd.direction = direction;
|
|
||||||
|
|
||||||
rd.src_layer_desc = copy_maybe_null(src_layer_desc);
|
|
||||||
rd.src_iter_desc = copy_maybe_null(src_iter_desc);
|
|
||||||
rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
|
|
||||||
rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
|
|
||||||
rd.bias_desc = copy_maybe_null(bias_desc);
|
|
||||||
rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
|
|
||||||
rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
|
|
||||||
rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc);
|
|
||||||
rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc);
|
|
||||||
rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc);
|
|
||||||
rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc);
|
|
||||||
rd.diff_bias_desc = copy_maybe_null(diff_bias_desc);
|
|
||||||
rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc);
|
|
||||||
rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc);
|
|
||||||
|
|
||||||
*rnn_desc = rd;
|
|
||||||
|
|
||||||
return success;
|
|
||||||
}
|
|
280
thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp
vendored
280
thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp
vendored
@ -1,280 +0,0 @@
|
|||||||
/*******************************************************************************
|
|
||||||
* Copyright 2018 Intel Corporation
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
#ifndef RNN_PD_HPP
|
|
||||||
#define RNN_PD_HPP
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
#include "c_types_map.hpp"
|
|
||||||
#include "primitive_desc.hpp"
|
|
||||||
#include "type_helpers.hpp"
|
|
||||||
|
|
||||||
namespace mkldnn {
|
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
struct rnn_fwd_pd_t;
|
|
||||||
|
|
||||||
struct rnn_pd_t : public primitive_desc_t {
|
|
||||||
static constexpr auto base_pkind = primitive_kind::rnn;
|
|
||||||
|
|
||||||
rnn_pd_t(engine_t *engine,
|
|
||||||
const rnn_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const rnn_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: primitive_desc_t(engine, attr, base_pkind)
|
|
||||||
, desc_(*adesc)
|
|
||||||
, hint_fwd_pd_(hint_fwd_pd)
|
|
||||||
, src_layer_md_(desc_.src_layer_desc)
|
|
||||||
, src_iter_md_(desc_.src_iter_desc)
|
|
||||||
, weights_layer_md_(desc_.weights_layer_desc)
|
|
||||||
, weights_iter_md_(desc_.weights_iter_desc)
|
|
||||||
, bias_md_(desc_.bias_desc)
|
|
||||||
, dst_layer_md_(desc_.dst_layer_desc)
|
|
||||||
, dst_iter_md_(desc_.dst_iter_desc)
|
|
||||||
, ws_md_()
|
|
||||||
{}
|
|
||||||
|
|
||||||
const rnn_desc_t *desc() const { return &desc_; }
|
|
||||||
virtual const op_desc_t *op_desc() const override
|
|
||||||
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
|
||||||
virtual void init_info() override { impl::init_info(this, this->info_); }
|
|
||||||
|
|
||||||
virtual status_t query(query_t what, int idx, void *result) const override {
|
|
||||||
switch (what) {
|
|
||||||
case query::rnn_d: *(const rnn_desc_t **)result = desc(); break;
|
|
||||||
default: return primitive_desc_t::query(what, idx, result);
|
|
||||||
}
|
|
||||||
return status::success;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *src_md(int index = 0) const override {
|
|
||||||
if (index == 0) return &src_layer_md_;
|
|
||||||
if (index == 1 && with_src_iter()) return &src_iter_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
virtual const memory_desc_t *weights_md(int index = 0) const override {
|
|
||||||
if (index == 0) return &weights_layer_md_;
|
|
||||||
if (index == 1) return &weights_iter_md_;
|
|
||||||
if (index == 2 && with_bias()) return &bias_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
virtual const memory_desc_t *dst_md(int index = 0) const override {
|
|
||||||
if (index == 0) return &dst_layer_md_;
|
|
||||||
if (index == 1 && with_dst_iter()) return &dst_iter_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
virtual const memory_desc_t *workspace_md(int index = 0) const override
|
|
||||||
{ return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
|
|
||||||
|
|
||||||
/* common pooling aux functions */
|
|
||||||
|
|
||||||
bool is_training() const {
|
|
||||||
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
|
||||||
prop_kind::backward);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_fwd() const {
|
|
||||||
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
|
||||||
prop_kind::forward_inference);
|
|
||||||
}
|
|
||||||
|
|
||||||
dim_t T() const { return desc_.src_layer_desc.dims[0]; }
|
|
||||||
dim_t MB() const { return desc_.src_layer_desc.dims[1]; }
|
|
||||||
|
|
||||||
dim_t L() const { return desc_.weights_layer_desc.dims[0]; }
|
|
||||||
dim_t D() const { return desc_.weights_layer_desc.dims[1]; }
|
|
||||||
|
|
||||||
dim_t SIC() const { return desc_.weights_iter_desc.dims[2]; }
|
|
||||||
|
|
||||||
dim_t SLC() const { return desc_.weights_layer_desc.dims[2]; }
|
|
||||||
dim_t G() const { return desc_.weights_layer_desc.dims[3]; }
|
|
||||||
dim_t DIC() const { return desc_.weights_layer_desc.dims[4]; }
|
|
||||||
|
|
||||||
dim_t DLC() const { return desc_.dst_layer_desc.dims[2]; }
|
|
||||||
|
|
||||||
bool with_bias() const
|
|
||||||
{ return !memory_desc_wrapper(desc_.bias_desc).is_zero(); }
|
|
||||||
|
|
||||||
bool with_src_iter() const
|
|
||||||
{ return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); }
|
|
||||||
|
|
||||||
bool with_dst_iter() const
|
|
||||||
{ return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); }
|
|
||||||
|
|
||||||
mkldnn::impl::alg_kind_t cell_kind() const
|
|
||||||
{ return desc_.cell_desc.cell_kind; }
|
|
||||||
mkldnn::impl::alg_kind_t activation_kind() const
|
|
||||||
{ return desc_.cell_desc.activation_kind; }
|
|
||||||
|
|
||||||
bool is_lbr() const
|
|
||||||
{ return cell_kind() == mkldnn_gru_linear_before_reset; }
|
|
||||||
|
|
||||||
mkldnn_rnn_direction_t direction() const { return desc_.direction; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
rnn_desc_t desc_;
|
|
||||||
const rnn_fwd_pd_t *hint_fwd_pd_;
|
|
||||||
|
|
||||||
memory_desc_t src_layer_md_;
|
|
||||||
memory_desc_t src_iter_md_;
|
|
||||||
memory_desc_t weights_layer_md_;
|
|
||||||
memory_desc_t weights_iter_md_;
|
|
||||||
memory_desc_t bias_md_;
|
|
||||||
memory_desc_t dst_layer_md_;
|
|
||||||
memory_desc_t dst_iter_md_;
|
|
||||||
|
|
||||||
memory_desc_t ws_md_;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct rnn_fwd_pd_t: public rnn_pd_t {
|
|
||||||
typedef rnn_fwd_pd_t base_class;
|
|
||||||
typedef rnn_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
rnn_fwd_pd_t(engine_t *engine,
|
|
||||||
const rnn_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const rnn_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: rnn_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (arg == MKLDNN_ARG_SRC_LAYER)
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_SRC_ITER && with_src_iter())
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER,
|
|
||||||
MKLDNN_ARG_WEIGHTS_ITER))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_BIAS && with_bias())
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DST_LAYER)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DST_ITER && with_dst_iter())
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_WORKSPACE && is_training())
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual int n_inputs() const override
|
|
||||||
{ return 3 + with_bias() + with_src_iter(); }
|
|
||||||
virtual int n_outputs() const override
|
|
||||||
{ return 1 + with_dst_iter() + is_training(); }
|
|
||||||
};
|
|
||||||
|
|
||||||
struct rnn_bwd_pd_t : public rnn_pd_t {
|
|
||||||
typedef rnn_bwd_pd_t base_class;
|
|
||||||
typedef rnn_fwd_pd_t hint_class;
|
|
||||||
|
|
||||||
rnn_bwd_pd_t(engine_t *engine,
|
|
||||||
const rnn_desc_t *adesc,
|
|
||||||
const primitive_attr_t *attr,
|
|
||||||
const rnn_fwd_pd_t *hint_fwd_pd)
|
|
||||||
: rnn_pd_t(engine, adesc, attr, hint_fwd_pd)
|
|
||||||
, diff_src_layer_md_(desc_.diff_src_layer_desc)
|
|
||||||
, diff_src_iter_md_(desc_.diff_src_iter_desc)
|
|
||||||
, diff_weights_layer_md_(desc_.diff_weights_layer_desc)
|
|
||||||
, diff_weights_iter_md_(desc_.diff_weights_iter_desc)
|
|
||||||
, diff_bias_md_(desc_.diff_bias_desc)
|
|
||||||
, diff_dst_layer_md_(desc_.diff_dst_layer_desc)
|
|
||||||
, diff_dst_iter_md_(desc_.diff_dst_iter_desc)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_SRC_LAYER, MKLDNN_ARG_DST_LAYER,
|
|
||||||
MKLDNN_ARG_DIFF_DST_LAYER))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (with_src_iter()) {
|
|
||||||
if (arg == MKLDNN_ARG_SRC_ITER)
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_SRC_ITER)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER,
|
|
||||||
MKLDNN_ARG_WEIGHTS_ITER))
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (with_bias()) {
|
|
||||||
if (arg == MKLDNN_ARG_BIAS)
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_DIFF_BIAS)
|
|
||||||
return arg_usage_t::output;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_DST_ITER, MKLDNN_ARG_DIFF_DST_ITER)
|
|
||||||
&& with_dst_iter())
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (arg == MKLDNN_ARG_WORKSPACE)
|
|
||||||
return arg_usage_t::input;
|
|
||||||
|
|
||||||
if (utils::one_of(arg, MKLDNN_ARG_DIFF_SRC_LAYER,
|
|
||||||
MKLDNN_ARG_DIFF_WEIGHTS_LAYER,
|
|
||||||
MKLDNN_ARG_DIFF_WEIGHTS_ITER))
|
|
||||||
return arg_usage_t::output;
|
|
||||||
|
|
||||||
return primitive_desc_t::arg_usage(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const memory_desc_t *diff_src_md(int index = 0) const override {
|
|
||||||
if (index == 0) return &diff_src_layer_md_;
|
|
||||||
if (index == 1 && with_src_iter()) return &diff_src_iter_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
virtual const memory_desc_t *diff_weights_md(
|
|
||||||
int index = 0) const override {
|
|
||||||
if (index == 0) return &diff_weights_layer_md_;
|
|
||||||
if (index == 1) return &diff_weights_iter_md_;
|
|
||||||
if (index == 2 && with_bias()) return &diff_bias_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
virtual const memory_desc_t *diff_dst_md(int index = 0) const override {
|
|
||||||
if (index == 0) return &diff_dst_layer_md_;
|
|
||||||
if (index == 1 && with_dst_iter()) return &diff_dst_iter_md_;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual int n_inputs() const override
|
|
||||||
{ return 6 + with_src_iter() + with_bias() + 2 * with_dst_iter(); }
|
|
||||||
virtual int n_outputs() const override
|
|
||||||
{ return 3 + with_src_iter() + with_bias(); }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
memory_desc_t diff_src_layer_md_;
|
|
||||||
memory_desc_t diff_src_iter_md_;
|
|
||||||
memory_desc_t diff_weights_layer_md_;
|
|
||||||
memory_desc_t diff_weights_iter_md_;
|
|
||||||
memory_desc_t diff_bias_md_;
|
|
||||||
memory_desc_t diff_dst_layer_md_;
|
|
||||||
memory_desc_t diff_dst_iter_md_;
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user